├── .gitignore
├── Code
│   ├── dataset_figure.py
│   ├── extract_images.py
│   ├── extract_masks.py
│   ├── extract_sets.py
│   ├── learning_classes.py
│   ├── train_script.py
│   └── utils.py
├── Figures
│   ├── Segmentations_labels.png
│   ├── indices.png
│   ├── pansharpening.png
│   ├── sample_ex.png
│   └── train_samples.gif
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | Code/__pycache__/*
2 | models/*
3 | .DS_Store
4 | 
--------------------------------------------------------------------------------
/Code/dataset_figure.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from ast import literal_eval
4 | import matplotlib.pyplot as plt
5 | import matplotlib
6 | 
7 | from learning_classes import dataset
8 | 
9 | path_to_drive = '/Volumes/SupMem/'
10 | path_to_data = path_to_drive + 'DSTL_data/'
11 | path_to_mask = path_to_drive + 'DSTL_data/masks/'
12 | path_to_img = path_to_drive + 'DSTL_data/processed_img/'
13 | 
14 | # index of class in masks
15 | class_position = {'building':0, 'misc':1, 'road':2, 'track':3, \
16 |                   'tree':4, 'crop':5, 'water':6, 'vehicle':7}
17 | 
18 | # %% load sample_df
19 | df = pd.read_csv(path_to_data+'train_samples.csv', index_col=0, converters={'classes' : literal_eval})
20 | 
21 | # %% plot
22 | N_img = 6
23 | fig, axs = plt.subplots(8,N_img,figsize=(N_img*3,8*3+2), gridspec_kw={'wspace':0.05, 'hspace':0.2})
24 | fig.patch.set_alpha(0)
25 | 
26 | for i, (class_name, class_pos) in enumerate(class_position.items()):
27 |     # create dataset
28 |     df_tmp = df[pd.DataFrame(df.classes.tolist()).isin([class_name]).any(1)]
29 |     data_set = dataset(df_tmp, path_to_img, path_to_mask, class_position[class_name], augment=True, crop_size=(144,144))
30 |     # draw group rectangles
31 |     pos = axs[i,0].get_position()
32 |     fig.patches.extend([plt.Rectangle((pos.x0-0.06,pos.y0-0.05*pos.height), pos.width*1.125*N_img , pos.height*1.1,
33 |                                       facecolor='whitesmoke', ec='black', alpha=1, zorder=-1,
34 |                                       transform=fig.transFigure, figure=fig)])
35 |     fig.text(pos.x0-0.03, pos.y0+0.5*pos.height, class_name.title(), rotation=90, rotation_mode='anchor', \
36 |              fontweight='bold', fontsize=14, ha='center', va='center')
37 |     # plot image + mask
38 |     for j in range(N_img):
39 |         img, mask = data_set.__getitem__(np.random.randint(0, df_tmp.shape[0]))
40 |         axs[i,j].imshow(np.moveaxis(np.array(img[[4,2,1], :, :]), 0, 2))
41 |         m = np.ma.masked_where(mask == 0, mask)
42 |         axs[i,j].imshow(m, cmap = matplotlib.colors.ListedColormap(['white', 'red']), vmin=0, vmax=1, alpha=0.3)
43 |         axs[i,j].set_axis_off()
44 | #fig.tight_layout()
45 | fig.savefig('../Figures/sample_ex.png', dpi=150, bbox_inches='tight')
46 | plt.show()
47 | 
--------------------------------------------------------------------------------
/Code/extract_images.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib
3 | import numpy as np
4 | import pandas as pd
5 | import os
6 | import skimage
7 | 
8 | import warnings
9 | warnings.filterwarnings("ignore", message="Possible precision loss when converting from float64 to uint16")
10 | 
11 | from utils import load_image, save_geotiff, contrast_stretch, NDVI, EVI, pansharpen
12 | 
13 | """ BANDS
14 | |-- A
15 | |--- 0 : SWIR-1 1195 - 1225 nm
16 | |--- 1 : SWIR-2 1550 - 1590 nm
17 | |--- 2 : SWIR-3 1640 - 1680 nm
18 | |--- 3 : SWIR-4 1710 - 1750 nm
19 | |--- 4 : SWIR-5 2145 - 2185 nm
20 | |--- 5 : SWIR-6 2185 - 2225 nm
21 | |--- 6 : SWIR-7 2235 - 2285 nm
22 | |--- 7 : SWIR-8 2295 - 2365 nm
23 | |-- M
24 | |--- 0 : Coastal Blue 400 - 450 nm
25 | |--- 1 : Blue 450 - 510 nm
26 | |--- 2 : Green 510 - 580 nm
27 | |--- 3 : Yellow 585 - 625 nm
28 | |--- 4 : Red 630 - 690 nm
29 | |--- 5 : Red-edges 705 - 745 nm
30 | |--- 6 : NIR-1 770 - 895 nm
31 | |--- 7 : NIR-2 860 - 1040 nm
32 | |-- P
33 | |--- 0 : 450 - 800 nm
34 | """
35 | 
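# Channel layout of the processed images written below -- a summary inferred
# from the concatenation order in this script, not an authoritative spec:
# channels 0-7 are the pansharpened multispectral bands (same order as M above),
# channel 8 is NDVI, channel 9 is NDWI and channel 10 is EVI, i.e. 11 channels
# in total, matching the in_channel=11 passed to the U-net in train_script.py.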
36 | # %% Get to the external drive and check for its presence
37 | path_to_drive = '/Volumes/SupMem/DSTL_data/'
38 | if not os.path.isdir(path_to_drive): print("Cannot find the external drive!")
39 | 
40 | # load the img_id
41 | wkt_df = pd.read_csv(path_to_drive+'train_wkt_v4.csv')
42 | 
43 | # get all the image id
44 | img_id_list = list(wkt_df.ImageId.unique())
45 | 
46 | #%%
47 | print(f'>>>> Pansharpen images and compute indices \n'+'-'*80)
48 | for i, img_id in enumerate(img_id_list):
49 |     print(f'\t|---- {i+1:02} : Processing image {img_id}')
50 |     _, m, p = load_image(path_to_drive+'sixteen_band/', img_id)
51 |     img_fused = pansharpen(m, p, order=3, W=1.5, stretch_perc=(1,99))
52 |     ndvi = NDVI(img_fused[:,:,4], img_fused[:,:,6])
53 |     ndwi = NDVI(img_fused[:,:,6], img_fused[:,:,2])
54 |     evi = EVI(img_fused[:,:,4], img_fused[:,:,6], img_fused[:,:,1])
55 |     img = np.concatenate([img_fused, np.expand_dims(ndvi,2), np.expand_dims(ndwi,2), np.expand_dims(evi,2)], axis=2)
56 |     img = np.moveaxis(img, 2, 0)
57 |     #skimage.external.tifffile.imsave(path_to_drive+'processed_img/'+img_id+'.tiff', skimage.img_as_uint(img))
58 |     save_geotiff(path_to_drive+'processed_img/'+img_id+'.tif', skimage.img_as_uint(img), dtype='uint16')
59 | 
60 | # %% -----------------------------------------------------------------------------------
61 | # processing example
62 | img_id = '6120_2_2'
63 | _, m, p = load_image(path_to_drive+'sixteen_band/', img_id)
64 | m_tmp = contrast_stretch(m)
65 | p_tmp = np.squeeze(contrast_stretch(np.expand_dims(p,2)))
66 | m_up = skimage.transform.resize(m_tmp, p_tmp.shape, order=3)
67 | img_fused = pansharpen(m, p, order=3, W=1.5, stretch_perc=(1,99))
68 | 
69 | #%%
70 | fig, axs = plt.subplots(1,3,figsize=(12,6))
71 | fig.patch.set_alpha(0)
72 | axs[0].imshow(m_tmp[300:400,300:400,[6,4,2]])
73 | axs[0].set_title('Multispectral Image', fontsize=12)
74 | axs[1].imshow(m_up[1200:1600,1200:1600,[6,4,2]])
75 | axs[1].set_title('Bicubic upsampling', fontsize=12)
76 | axs[2].imshow(img_fused[1200:1600,1200:1600,[6,4,2]])
77 | axs[2].set_title('Image fusion', fontsize=12)
78 | for ax in axs: ax.set_axis_off()
79 | fig.tight_layout()
80 | fig.savefig('../Figures/pansharpening.png', dpi=150, bbox_inches='tight')
81 | plt.show()
82 | 
83 | #%% Index example
84 | img_id = '6100_2_2'
85 | _, m, p = load_image(path_to_drive+'sixteen_band/', img_id)
86 | img_fused = pansharpen(m, p, order=3, W=1.5, stretch_perc=(1,99))
87 | 
88 | #%%
89 | ndvi = NDVI(img_fused[:,:,4], img_fused[:,:,6])
90 | ndwi = NDVI(img_fused[:,:,6], img_fused[:,:,2])
91 | evi = EVI(img_fused[:,:,4], img_fused[:,:,6], img_fused[:,:,1])
92 | fig, axs = plt.subplots(1,4,figsize=(16,6))
93 | fig.patch.set_alpha(0)
94 | axs[0].set_title('NDVI', fontsize=12)
95 | axs[0].imshow(ndvi[:,:], cmap='PiYG', vmin=-1, vmax=1)
96 | axs[1].set_title('NDWI', fontsize=12)
97 | axs[1].imshow(ndwi[:,:], cmap='coolwarm_r', vmin=-1, vmax=1)
98 | axs[2].set_title('EVI', fontsize=12)
99 | axs[2].imshow(evi[:,:], cmap='PiYG', vmin=-1, vmax=1)
100 | axs[3].set_title('True Color Composition', fontsize=12)
101 | axs[3].imshow(img_fused[:,:,[4,2,1]])
102 | for ax in axs:
ax.set_axis_off() 103 | fig.tight_layout() 104 | fig.savefig('../Figures/indices.png', dpi=150, bbox_inches='tight') 105 | plt.show() 106 | -------------------------------------------------------------------------------- /Code/extract_masks.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | import shapely 7 | import descartes 8 | import skimage 9 | 10 | from utils import get_polygon_dict, get_polygons_masks, plot_polygons, plot_masks 11 | 12 | Class_dict = {'building':[1], 'misc':[2], 'road':[3], 'track':[4], \ 13 | 'tree':[5], 'crop':[6], 'water':[7,8], 'vehicle':[9,10]} 14 | 15 | order_dict = {'building':1, 'misc':2, 'road':3, 'track':4, \ 16 | 'tree':5, 'crop':6, 'water':7, 'vehicle':8} 17 | 18 | color_dict = {'building':'0.7', 'misc':'0.4', 'road':'#b35806', 'track':'#dfc27d', \ 19 | 'tree':'#1b7837', 'crop':'#a6dba0', 'water':'#74add1', 'vehicle':'#f46d43'} 20 | 21 | zorder_dict = {'crop':1, 'water':2, 'road':3, 'track':4,\ 22 | 'building':5, 'misc':6, 'vehicle':7, 'tree':8} 23 | 24 | # %% Get to the external drive and check for its presence 25 | path_to_drive = '/Volumes/SupMem/DSTL_data/' 26 | if not os.path.isdir(path_to_drive): print("Cannot find the external drive!") 27 | 28 | # load the grid for conversion of polygon into image space 29 | grid = pd.read_csv(path_to_drive+'grid_sizes.csv').rename(columns={'Unnamed: 0':'ImageId'}) 30 | 31 | # load the polygons 32 | wkt_df = pd.read_csv(path_to_drive+'train_wkt_v4.csv') 33 | 34 | # get all the image id 35 | img_id_list = list(wkt_df.ImageId.unique()) 36 | 37 | # %% 38 | masks = {} 39 | print(f'>>>> Convert Polygon to mask \n'+'-'*80) 40 | for i, img_id in enumerate(img_id_list): 41 | print(f'\t|---- {i+1:02} : Getting segmentation of image {img_id}') 42 | img_size = skimage.io.imread(path_to_drive+'sixteen_band/'+img_id+'_P.tif', plugin="tifffile").shape 43 | pdict = get_polygon_dict(img_id, Class_dict, img_size, wkt_df, grid) 44 | masks[img_id] = get_polygons_masks(pdict, order_dict, img_size, filename=path_to_drive+'masks/'+img_id+'_mask.tif') 45 | 46 | # %% plot masks 47 | fig, axs = plt.subplots(5,5,figsize=(20,20)) 48 | fig.patch.set_alpha(0) 49 | for (id, mask), ax in zip(masks.items(), axs.reshape(-1)): 50 | plot_masks(ax, mask, order_dict, color_dict, zorder_dict, legend=False) 51 | ax.set_title(id, fontsize=14) 52 | ax.tick_params(axis='both', which='both',bottom=False, top=False,\ 53 | labelbottom=False, right=False, left=False, labelleft=False) 54 | 55 | handles = [matplotlib.patches.Patch(facecolor=pcol) for pcol in color_dict.values()] 56 | labels = list(color_dict.keys()) 57 | lgd = fig.legend(handles, labels, ncol=8, loc='lower center', fontsize=14, \ 58 | bbox_to_anchor=(0.5, -0.03), bbox_transform=fig.transFigure) 59 | fig.tight_layout() 60 | fig.savefig('../Figures/Segmentations_labels.png', dpi=200, bbox_extra_artists=(lgd,), bbox_inches='tight') 61 | plt.show() 62 | -------------------------------------------------------------------------------- /Code/extract_sets.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib 3 | import matplotlib.animation 4 | import numpy as np 5 | import pandas as pd 6 | import skimage 7 | import tifffile 8 | import pickle 9 | import os 10 | import rasterio 11 | 12 | from utils import get_crops_grid, load_image_part, get_represented_classes, 
get_samples 13 | 14 | path_to_drive = '/Volumes/SupMem/' 15 | path_to_data = path_to_drive + 'DSTL_data/' 16 | path_to_mask = path_to_drive + 'DSTL_data/masks/' 17 | path_to_img = path_to_drive + 'DSTL_data/processed_img/' 18 | if not os.path.isdir(path_to_drive): print("Cannot find the external drive!") 19 | 20 | order_dict = {1:'building', 2:'misc', 3:'road', 4:'track', \ 21 | 5:'tree', 6:'crop', 7:'water', 8:'vehicle'} 22 | 23 | # %% load the img_id 24 | img_id_list = list(pd.read_csv(path_to_drive+'DSTL_data/train_wkt_v4.csv', usecols=['ImageId']).ImageId.unique()) 25 | 26 | # define which image is for test and which is for train 27 | id_test = ['6100_2_2', '6060_2_3', '6110_4_0', '6160_2_1'] 28 | id_train = [id for id in img_id_list if id not in id_test] 29 | 30 | #%% Generate samples 31 | crop_size_train = (160,160) 32 | overlap_train = (80,80) 33 | class_offset_train = (40,40) 34 | class_area_train = (80,80) 35 | 36 | crop_size_test = (1024,1024) 37 | overlap_test = None 38 | class_offset_test = (0,0) 39 | class_area_test = crop_size_test 40 | 41 | df_train = get_samples(id_train, \ 42 | path_to_img, path_to_mask, \ 43 | crop_size_train, overlap_train, \ 44 | order_dict, \ 45 | class_offset_train, class_area_train, \ 46 | verbose=True) 47 | 48 | df_test = get_samples(id_test, \ 49 | path_to_img, path_to_mask, \ 50 | crop_size_test, overlap_test, \ 51 | order_dict, \ 52 | class_offset_test, class_area_test, \ 53 | verbose=True) 54 | 55 | # the test set cut similarly to the train for performance estimations 56 | df_test_small = get_samples(id_test, \ 57 | path_to_img, path_to_mask, \ 58 | crop_size_train, overlap_train, \ 59 | order_dict, \ 60 | class_offset_train, class_area_train, \ 61 | verbose=True) 62 | 63 | df_train.to_csv(path_to_data+'train_samples.csv') 64 | df_test.to_csv(path_to_data+'test_samples.csv') 65 | df_test_small.to_csv(path_to_data+'test_samples_small.csv') 66 | 67 | # ----------------------------------------------------------------------------------------- 68 | #%% 69 | subimg = load_image_part((0,0), (5*80,8*80), path_to_img+id_train[2]+'.tif') 70 | crops = get_crops_grid(subimg.shape[1], subimg.shape[2], crop_size_train, overlap_train) 71 | #%% 72 | fig, ax = plt.subplots(1,1,figsize=(10,7)) 73 | #fig.patch.set_alpha(0) 74 | ax.imshow(np.moveaxis(subimg[[4,2,1], :, :], 0, 2)) 75 | ax.set_axis_off() 76 | ax.set_title('train sample generation example', fontsize=12) 77 | P1 = ax.add_patch(matplotlib.patches.Rectangle((0, 0), 0, 0, fc=(0,0,0,0), ec='Orangered', lw=2)) 78 | P2 = ax.add_patch(matplotlib.patches.Rectangle((0, 0), 0, 0, fc=(0.1,0.1,0.1,0.3), ec='dodgerblue', lw=1)) 79 | fig.legend([P1, P2], ['train sample', 'loss-relevant area'], ncol=2, loc='lower center', fontsize=12) 80 | fig.tight_layout() 81 | 82 | def init(): 83 | return [] 84 | 85 | def animate(c): 86 | P1.set_xy((c[1], c[0])) 87 | P1.set_height(crop_size_train[0]) 88 | P1.set_width(crop_size_train[1]) 89 | #P2.set_xy((c[1]+overlap_train[0]/2, c[0]+overlap_train[1]/2)) 90 | #P2.set_height(overlap_train[0]) 91 | #P2.set_width(overlap_train[1]) 92 | P2 = ax.add_patch(matplotlib.patches.Rectangle((c[1]+overlap_train[0]/2, c[0]+overlap_train[1]/2), overlap_train[0], overlap_train[1], fc=(0.1,0.1,0.1,0.5), ec='dodgerblue', lw=1)) 93 | return [P1,P2] 94 | 95 | anim = matplotlib.animation.FuncAnimation(fig, animate, init_func=init, frames=crops, interval=500, blit=True) 96 | anim.save('../Figures/train_samples.gif', writer='imagemagick', dpi=150) 97 | 
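#%% sanity check of the cropping grid (illustrative sketch): consecutive crops
# are spaced crop_size - overlap apart, so 160x160 crops with a total overlap
# of 80 px give upper-left corners every 80 px.
demo_crops = get_crops_grid(400, 400, crop_size_train, overlap_train)
print(demo_crops)  # expected: (0,0), (0,80), ..., (240,240) -- 16 crops in total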
--------------------------------------------------------------------------------
/Code/learning_classes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import skimage
3 | import skimage.transform
4 | import rasterio
5 | 
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 | from torch.utils import data
10 | 
11 | from utils import load_image_part
12 | 
13 | class dataset(data.Dataset):
14 |     """
15 |     Define a pytorch dataset for the DSTL data.
16 |     """
17 |     def __init__(self, sample_df, image_path, target_path, class_type, augment=True, crop_size=(144,144)):
18 |         """
19 |         Initialize a dataset for the DSTL data from the crop information.
20 |         ------------
21 |         INPUT
22 |         |---- sample_df (pandas.DataFrame) Dataframe with the sample information:
23 |         |           each row is a sample with image_id, crop coordinates
24 |         |           and crop dimensions.
25 |         |---- image_path (str) the path to the folder of images
26 |         |---- target_path (str) the path to the folder of masks
27 |         |---- class_type (int) the channel dimension to use (aka the class
28 |         |           coordinate in the mask)
29 |         |---- augment (bool) whether to perform data augmentation
30 |         |---- crop_size (tuple) the random crop size to perform
31 |         Output
32 |         |---- NONE
33 |         """
34 |         self.sample_df = sample_df
35 |         self.image_path = image_path
36 |         self.target_path = target_path
37 |         self.class_type = class_type
38 |         self.augment = augment
39 |         self.crop_size = crop_size
40 | 
41 |     def transform(self, image, mask):
42 |         """
43 |         Transform the passed image and mask and perform data augmentation if
44 |         self.augment is True (random crop + random horizontal and vertical flip
45 |         + random 90° rotations).
46 |         ------------
47 |         INPUT
48 |         |---- image (3D numpy.array) the image to transform
49 |         |---- mask (2D numpy.array) the corresponding mask to transform
50 |         Output
51 |         |---- image (3D torch.Tensor) the transformed image as B x H x W
52 |         |---- mask (2D torch.Tensor) the transformed associated mask for the selected class
53 |         """
54 |         # Random crop
55 |         if self.augment:
56 |             r, c = self.get_crop_param(image.shape[1:3], self.crop_size)
57 |         else:
58 |             r, c = int((image.shape[1]-self.crop_size[0])/2), int((image.shape[2]-self.crop_size[1])/2)
59 |         image = image[:,r:r+self.crop_size[0],c:c+self.crop_size[1]]
60 |         mask = mask[r:r+self.crop_size[0],c:c+self.crop_size[1]]
61 | 
62 |         if self.augment:
63 |             # Random Vertical flip
64 |             if np.random.random() > 0.5:
65 |                 image = image[:, ::-1, :]
66 |                 mask = mask[::-1, :]
67 |             # Random Horizontal flip
68 |             if np.random.random() > 0.5:
69 |                 image = image[:, :, ::-1]
70 |                 mask = mask[:, ::-1]
71 |             # Random Rotate
72 |             angle = 90*np.random.randint(0,4) # number between 0 and 3
73 |             image = np.moveaxis(skimage.transform.rotate(np.moveaxis(image, 0, 2), angle, preserve_range=True), 2, 0)
74 |             mask = skimage.transform.rotate(mask, angle, preserve_range=True)
75 | 
76 |         # Transform to Tensor
77 |         image = torch.from_numpy(image).float()
78 |         mask = torch.from_numpy(mask).float()
79 | 
80 |         return image, mask
81 | 
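    # NOTE: the random flips and 90° rotations above together sample the 8
    # symmetries of the square (the dihedral group D4), a common augmentation
    # for overhead imagery where the orientation of the scene is arbitrary.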
82 |     def __len__(self):
83 |         """
84 |         Return the length of the dataset (the number of samples).
85 |         ------------
86 |         INPUT
87 |         |---- NONE
88 |         Output
89 |         |---- (int) the number of samples
90 |         """
91 |         return self.sample_df.shape[0]
92 | 
93 |     def __getitem__(self, index):
94 |         """
95 |         Return one sample of the dataset associated with the index (row of the dataframe).
96 |         ------------
97 |         INPUT
98 |         |---- index (int) the index of the sample to load
99 |         Output
100 |         |---- image (3D torch.Tensor) the image as B x H x W
101 |         |---- mask (2D torch.Tensor) the associated mask for the selected class
102 |         """
103 |         sample = self.sample_df.iloc[index, :]
104 |         xy = (sample.row, sample.col)
105 |         hw = (sample.h, sample.w)
106 |         id = sample.img_id
107 |         image = load_image_part(xy, hw, self.image_path+id+'.tif')
108 |         mask = load_image_part(xy, hw, self.target_path+id+'_mask.tif', as_float=False)[self.class_type, :, :]
109 |         return self.transform(image, mask)
110 | 
111 |     def get_crop_param(self, img_dim, crop_size):
112 |         """
113 |         Return random crop parameters given the input size and the crop size.
114 |         ------------
115 |         INPUT
116 |         |---- img_dim (tuple) the input image dimension as (row, col)
117 |         |---- crop_size (tuple) the crop dimension as (row, col)
118 |         Output
119 |         |---- r_crop (int) the row coordinate of the crop
120 |         |---- c_crop (int) the column coordinate of the crop
121 |         """
122 |         r_crop = np.random.randint(0,img_dim[0]-crop_size[0])
123 |         c_crop = np.random.randint(0,img_dim[1]-crop_size[1])
124 |         return r_crop, c_crop
125 | 
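# For reference, the smoothed Dice loss implemented below, with prediction
# probabilities p_i and binary targets t_i flattened per image, smooth term s
# and power q (attribute self.p, default 2):
#
#   loss = 1 - (2 * sum_i(p_i * t_i) + s) / (sum_i(p_i^q) + sum_i(t_i^q) + s)
#
# i.e. one minus a smoothed Dice coefficient; s keeps the loss defined when
# both the prediction and the target are empty.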
126 | class BinaryDiceLoss(nn.Module):
127 |     """
128 |     Define a Binary DiceLoss.
129 |     """
130 |     def __init__(self, smooth=1, p=2, reduction='mean'):
131 |         """
132 |         Constructor of the BinaryDiceLoss.
133 |         ------------
134 |         INPUT
135 |         |---- smooth (int) smooth number for the diceloss
136 |         |---- p (int) the power to use in the denominator
137 |         |---- reduction (str) how the loss should be reduced (should be one
138 |         |           of : 'mean', 'sum', or 'none')
139 |         Output
140 |         |---- None
141 |         """
142 |         nn.Module.__init__(self)
143 |         self.smooth = smooth
144 |         self.reduction = reduction
145 |         self.p = p
146 | 
147 |     def forward(self, input, target):
148 |         """
149 |         Forward pass of the BinaryDiceLoss.
150 |         ------------
151 |         INPUT
152 |         |---- input (torch.FloatTensor) the binary input with dimension B x 2 x H x W.
153 |         |           The positive class is encoded as
154 |         |           ones in the tensor.
155 |         |---- target (torch.FloatTensor) the binary target with dimension B x H x W.
156 |         |           The positive class is encoded as
157 |         |           ones in the tensor.
158 |         Output
159 |         |---- loss (torch.FloatTensor) the Dice loss with dimension depending
160 |         |           on the reduction chosen.
161 |         """
162 |         # check input
163 |         assert input.shape[0] == target.shape[0], 'Input and Target must have the same batch size.'
164 |         assert input.dim() == 4, f'Input dimension {input.shape} does not match. Should be 4D : Batch x 2 x Height x Width'
165 |         assert target.dim() == 3, f'Target dimension {target.shape} does not match. Should be 3D : Batch x Height x Width'
166 |         # convert input and target to float
167 |         input, target = input.float(), target.float()
168 |         # softmax on input and keep only the class of 1
169 |         input = F.softmax(input, dim=1)[:,1,:,:]
170 |         # linearize input and target as vector
171 |         input = input.contiguous().view(input.shape[0], -1)
172 |         target = target.contiguous().view(target.shape[0], -1)
173 |         # compute numerator and denominator
174 |         numerator = 2*(input*target).sum(dim=1) + self.smooth
175 |         denominator = torch.sum(input.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
176 |         # compute the dice loss
177 |         loss = 1 - numerator / denominator
178 |         # return according to the reduction
179 |         if self.reduction == 'mean':
180 |             return loss.mean()
181 |         elif self.reduction == 'sum':
182 |             return loss.sum()
183 |         elif self.reduction == 'none':
184 |             return loss
185 |         else:
186 |             raise Exception(f'Unexpected reduction {self.reduction}')
187 | 
188 | class ResBlock(nn.Module):
189 |     """
190 |     Define a convolution block for the U-net: two 3x3 convolutions, each followed
191 |     by BatchNorm and a SELU activation (despite the name, there is no skip connection).
192 |     """
193 |     def __init__(self, in_channel, out_channel):
194 |         """
195 |         Constructor of the ResBlock.
196 |         ------------
197 |         INPUT
198 |         |---- in_channel (int) number of input channel
199 |         |---- out_channel (int) number of output channel
200 |         Output
201 |         |---- None
202 |         """
203 |         nn.Module.__init__(self)
204 |         self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
205 |         self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
206 |         self.BN = nn.BatchNorm2d(out_channel) # shared by both convolutions
207 | 
208 |     def forward(self, x):
209 |         """
210 |         Forward method of the block.
211 |         ------------
212 |         INPUT
213 |         |---- x (torch.Tensor) the input of dimension (batch, in_channel, img_H, img_W)
214 |         Output
215 |         |---- x (torch.Tensor) the output of dimension (batch, out_channel, img_H, img_W)
216 |         """
217 |         x = F.selu(self.BN(self.conv1(x)))
218 |         x = F.selu(self.BN(self.conv2(x)))
219 |         return x
220 | 
221 | class U_net(nn.Module):
222 |     """
223 |     Definition of the U-net model: a contracting path with 5 ResBlocks and MaxPool
224 |     layers, followed by an expanding path with 4 ResBlocks and ConvTranspose layers.
225 |     """
226 |     def __init__(self, in_channel):
227 |         """
228 |         Constructor of the Unet.
229 |         ------------
230 |         INPUT
231 |         |---- in_channel (int) number of input channel of the model
232 |         Output
233 |         |---- None
234 |         """
235 |         nn.Module.__init__(self)
236 |         # Down blocks
237 |         self.RBD1 = ResBlock(in_channel,32)
238 |         self.RBD2 = ResBlock(32,64)
239 |         self.RBD3 = ResBlock(64,128)
240 |         self.RBD4 = ResBlock(128,256)
241 |         self.RBD5 = ResBlock(256,512)
242 | 
243 |         # Up blocks
244 |         self.convT1 = nn.ConvTranspose2d(512, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
245 |         self.RBU1 = ResBlock(128+256,256)
246 |         self.convT2 = nn.ConvTranspose2d(256, 256, kernel_size=3, padding=1, stride=2, output_padding=1)
247 |         self.RBU2 = ResBlock(256+128,128)
248 |         self.convT3 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
249 |         self.RBU3 = ResBlock(128+64,64)
250 |         self.convT4 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
251 |         self.RBU4 = ResBlock(64+32,32)
252 |         self.convFinal = nn.Conv2d(32,2, kernel_size=3, padding=1)
253 | 
254 |     def forward(self, x):
255 |         """
256 |         Forward method of the Unet.
257 | ------------ 258 | INPUT 259 | |---- x (torch.Tensor) the input of dimension (batch, in_channel, img_H, img_W) 260 | Output 261 | |---- x (torch.Tensor) the output of dimension (batch, 2, img_H, img_W) 262 | """ 263 | # dimension Batch x Channel x Width x Height 264 | # down 265 | r1 = self.RBD1.forward(x) 266 | x = F.max_pool2d(r1, kernel_size=2, stride=2) 267 | r2 = self.RBD2.forward(x) 268 | x = F.max_pool2d(r2, kernel_size=2, stride=2) 269 | r3 = self.RBD3.forward(x) 270 | x = F.max_pool2d(r3, kernel_size=2, stride=2) 271 | r4 = self.RBD4.forward(x) 272 | x = F.max_pool2d(r4, kernel_size=2, stride=2) 273 | x = self.RBD5.forward(x) 274 | # up 275 | #x = F.interpolate(x, scale_factor=2, mode='bilinear') 276 | x = F.selu(self.convT1(x)) 277 | x = self.RBU1.forward(torch.cat((r4, x), dim=1)) 278 | #x = F.interpolate(x, scale_factor=2, mode='bilinear') 279 | x = F.selu(self.convT2(x)) 280 | x = self.RBU2.forward(torch.cat((r3, x), dim=1)) 281 | #x = F.interpolate(x, scale_factor=2, mode='bilinear') 282 | x = F.selu(self.convT3(x)) 283 | x = self.RBU3.forward(torch.cat((r2, x), dim=1)) 284 | #x = F.interpolate(x, scale_factor=2, mode='bilinear') 285 | x = F.selu(self.convT4(x)) 286 | x = self.RBU4.forward(torch.cat((r1, x), dim=1)) 287 | x = F.selu(self.convFinal(x)) 288 | return x 289 | 290 | 291 | # Shallower 292 | class U_net2(nn.Module): 293 | """ """ 294 | def __init__(self, in_channel): 295 | """ """ 296 | nn.Module.__init__(self) 297 | # Down blocks 298 | self.RBD1 = ResBlock(in_channel,32) 299 | self.RBD2 = ResBlock(32,64) 300 | self.RBD3 = ResBlock(64,128) 301 | self.RBD4 = ResBlock(128,256) 302 | 303 | # Up blocks 304 | self.convT2 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 305 | self.RBU2 = ResBlock(128+128,128) 306 | self.convT3 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 307 | self.RBU3 = ResBlock(128+64,64) 308 | self.convT4 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 309 | self.RBU4 = ResBlock(64+32,32) 310 | self.convFinal = nn.Conv2d(32,2, kernel_size=3, padding=1) 311 | 312 | def forward(self, x): 313 | """ """ 314 | # dimension Batch x Channel x Width x Height 315 | # down 316 | r1 = self.RBD1.forward(x) 317 | x = F.max_pool2d(r1, kernel_size=2, stride=2) 318 | r2 = self.RBD2.forward(x) 319 | x = F.max_pool2d(r2, kernel_size=2, stride=2) 320 | r3 = self.RBD3.forward(x) 321 | x = F.max_pool2d(r3, kernel_size=2, stride=2) 322 | x = self.RBD4.forward(x) 323 | # up 324 | #x = F.interpolate(x, scale_factor=2, mode='bilinear') 325 | x = F.selu(self.convT2(x)) 326 | x = self.RBU2.forward(torch.cat((r3, x), dim=1)) 327 | #x = F.interpolate(x, scale_factor=2, mode='bilinear') 328 | x = F.selu(self.convT3(x)) 329 | x = self.RBU3.forward(torch.cat((r2, x), dim=1)) 330 | #x = F.interpolate(x, scale_factor=2, mode='bilinear') 331 | x = F.selu(self.convT4(x)) 332 | x = self.RBU4.forward(torch.cat((r1, x), dim=1)) 333 | x = F.selu(self.convFinal(x)) 334 | return x 335 | 336 | # Shallower 337 | class U_net3(nn.Module): 338 | """ """ 339 | def __init__(self, in_channel): 340 | """ """ 341 | nn.Module.__init__(self) 342 | # Down blocks 343 | self.RBD1 = ResBlock(in_channel,32) 344 | self.RBD2 = ResBlock(32,64) 345 | 346 | # Up blocks 347 | self.convT4 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 348 | self.RBU4 = ResBlock(64+32,32) 349 | self.convFinal = nn.Conv2d(32,2, kernel_size=3, padding=1) 350 | 351 | def forward(self, 
x): 352 | """ """ 353 | # dimension Batch x Channel x Width x Height 354 | # down 355 | r1 = self.RBD1.forward(x) 356 | x = F.max_pool2d(r1, kernel_size=2, stride=2) 357 | x = self.RBD2.forward(x) 358 | # up 359 | x = F.selu(self.convT4(x)) 360 | x = self.RBU4.forward(torch.cat((r1, x), dim=1)) 361 | x = F.selu(self.convFinal(x)) 362 | return x 363 | -------------------------------------------------------------------------------- /Code/train_script.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | #from torch.autograd import Variable 4 | import torch.cuda as cuda 5 | 6 | from sklearn.metrics import jaccard_score, f1_score 7 | from sklearn.model_selection import train_test_split 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from ast import literal_eval 12 | 13 | import os 14 | import sys 15 | import pickle 16 | import matplotlib.pyplot as plt 17 | import matplotlib 18 | 19 | from learning_classes import dataset, BinaryDiceLoss, U_net, U_net2, U_net3 20 | from utils import print_param_summary, append_scores, load_sample_df, get_dataset_scores, print_progessbar 21 | 22 | #%%------------------------------------------------------------------------------------------------- 23 | # General declaration 24 | 25 | path_to_drive = '/Volumes/SupMem/' 26 | path_to_data = path_to_drive + 'DSTL_data/' 27 | path_to_mask = path_to_drive + 'DSTL_data/masks/' 28 | path_to_img = path_to_drive + 'DSTL_data/processed_img/' 29 | path_to_output = '../Outputs/' 30 | if not os.path.isdir(path_to_drive): print("Cannot find the external drive!") 31 | 32 | # index of class in masks 33 | class_position = {'building':0, 'misc':1, 'road':2, 'track':3, \ 34 | 'tree':4, 'crop':5, 'water':6, 'vehicle':7} 35 | 36 | #%%------------------------------------------------------------------------------------------------- 37 | # Get the Dataset and Dataloader 38 | 39 | # recover the class to train from passed argument or define it 40 | class_type = 'building' 41 | # if sys.argv[1]: 42 | # if sys.argv[1] in list(class_position.keys()): 43 | # class_type = sys.argv[1] 44 | # else: 45 | # raise ValueError(f'Wrong Input class type. 
Should be one of {list(class_position.keys())}') 46 | 47 | # dataset parameters 48 | augment_data = True 49 | crop_size = (144,144) 50 | train_frac = 0.85 51 | non_class_fraction = 0.15 52 | 53 | # the full data 54 | df = load_sample_df(path_to_data+'train_samples.csv', class_type=class_type, others_frac=non_class_fraction, seed=1) 55 | # train validation split 56 | train_df, val_df = train_test_split(df, train_size=train_frac, random_state=1) 57 | # load the test set 58 | test_df = load_sample_df(path_to_data+'test_samples_small.csv', class_type=class_type, others_frac=non_class_fraction, seed=1) 59 | 60 | # the train and validation datasets 61 | train_set = dataset(train_df, path_to_img, path_to_mask, class_position[class_type], augment=augment_data, crop_size=crop_size) 62 | val_set = dataset(val_df, path_to_img, path_to_mask, class_position[class_type], augment=augment_data, crop_size=crop_size) 63 | test_set = dataset(test_df, path_to_img, path_to_mask, class_position[class_type], augment=augment_data, crop_size=crop_size) 64 | 65 | # parameters for the dataloader , adapt batch size to ensure at least 4 batches 66 | train_dataloader_params = {'batch_size': min(64, int(train_set.__len__()/4)), 'shuffle': True, 'num_workers': 6} 67 | val_dataloader_params = {'batch_size': min(64, int(val_set.__len__()/4)), 'shuffle': True, 'num_workers': 6} 68 | 69 | # The data loader 70 | train_dataloader = torch.utils.data.DataLoader(train_set, **train_dataloader_params) 71 | val_dataloader = torch.utils.data.DataLoader(val_set, **val_dataloader_params) 72 | 73 | #%%------------------------------------------------------------------------------------------------- 74 | # Training Settings 75 | # get GPU if available 76 | if cuda.is_available(): 77 | device = torch.device('cuda') 78 | else: 79 | device = torch.device('cpu') 80 | 81 | # initialize output dict 82 | log_train = {} 83 | 84 | # learning parameters 85 | n_epoch = 100 86 | nb_epochs_finished = 0 87 | lr = 0.0001 88 | 89 | # the loss 90 | loss_function = BinaryDiceLoss()#nn.CrossEntropyLoss() 91 | 92 | # initialize de model 93 | model = U_net3(in_channel=11) 94 | model = model.to(device) 95 | 96 | # the optimizer 97 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 98 | 99 | # model name 100 | model_name = 'Unet_' + class_type 101 | # save params to log 102 | log_train['params'] = {'device':device, 'train_dataloader_params':train_dataloader_params, \ 103 | 'val_dataloader_params':val_dataloader_params, \ 104 | 'n_epoch':n_epoch, 'lr':lr, 'loss_function':str(loss_function), \ 105 | 'optimizer':str(optimizer), 'class':class_type, 'data_augmentation':augment_data, \ 106 | 'train size':train_df.shape[0], 'validation size':val_df.shape[0], \ 107 | 'non_class_fraction':non_class_fraction, \ 108 | 'N parameters Unet':sum(p.numel() for p in model.parameters())} 109 | # save split indices 110 | log_train['split_indices'] = {'train':train_set.sample_df.index.tolist(), 'val':val_set.sample_df.index.tolist()} 111 | # print parameter summary 112 | print_param_summary(**log_train['params']) 113 | 114 | #%%------------------------------------------------------------------------------------------------- 115 | # Training procedure 116 | print('-'*100+'\n'+'Training'.center(100)+'\n'+'-'*100) 117 | # load model from checkpoint if any 118 | checkpoint_name = path_to_output+model_name+'_checkpoint.pth' 119 | try: 120 | checkpoint = torch.load(checkpoint_name) 121 | nb_epochs_finished = checkpoint['nb_epochs_finished'] 122 | 
model.load_state_dict(checkpoint['model_state'])
123 |     optimizer.load_state_dict(checkpoint['optimizer_state'])
124 |     log_train = checkpoint['log_train']
125 |     print(f'\n>>> Checkpoint loaded with {nb_epochs_finished} epochs finished.\n')
126 | except FileNotFoundError:
127 |     print('\n>>> Starting from scratch.\n')
128 | except:
129 |     print('Error when loading the checkpoint.')
130 |     exit(1)
131 | print('-'*100)
132 | 
133 | # train
134 | best_F1 = 0.0
135 | for epoch in range(nb_epochs_finished, n_epoch):
136 |     print(f'|- Epoch {epoch+1:02d}/{n_epoch:02d}'.ljust(16))
137 |     sum_loss = 0.0
138 |     jaccard_train = []
139 |     jaccard_val = []
140 |     f1_train = []
141 |     f1_val = []
142 | 
143 |     for b, (train_data, train_label) in enumerate(train_dataloader):
144 |         # note: no requires_grad needed on the data or labels, autograd tracks the model parameters
145 |         # train_data.requires_grad = True
146 |         # train_label.requires_grad = True
147 |         # Transfer to GPU if available
148 |         train_data, train_label = train_data.to(device), train_label.to(device)
149 |         # forward pass
150 |         output = model(train_data)
151 |         # compute the loss
152 |         loss = loss_function(output, train_label) # Output B x 2 x H x W ; Target B x H x W
153 |         sum_loss += loss.item()
154 |         # reset gradient
155 |         optimizer.zero_grad()
156 |         # backward pass
157 |         loss.backward()
158 |         # Gradient step
159 |         optimizer.step()
160 |         # compute train scores
161 |         with torch.set_grad_enabled(False):
162 |             l, o = train_label.long().flatten().cpu(), output.argmax(dim=1).flatten().cpu()
163 |             jaccard_train.append(jaccard_score(l, o))
164 |             f1_train.append(f1_score(l, o))
165 |         # print the progress bar
166 |         print_progessbar(b, train_dataloader.__len__(), Name='|--- Train', Size=20)
167 |     # print the train loss
168 |     print(f' | Loss {sum_loss:.3f}')
169 | 
170 |     # compute validation scores
171 |     with torch.set_grad_enabled(False):
172 |         for b, (val_data, val_label) in enumerate(val_dataloader):
173 |             val_data, val_label = val_data.to(device), val_label.to(device)
174 |             val_pred = model(val_data).argmax(dim=1)
175 |             l, o = val_label.long().flatten().cpu(), val_pred.flatten().cpu()
176 |             jaccard_val.append(jaccard_score(l, o))
177 |             f1_val.append(f1_score(l, o))
178 |             # print the progress bar
179 |             print_progessbar(b, val_dataloader.__len__(), Name='|--- Validation', Size=20)
180 | 
181 |     # append values to log
182 |     append_scores(log_train, epoch=epoch+1, loss=sum_loss, \
183 |                   jaccard_train=jaccard_train, jaccard_val=jaccard_val, \
184 |                   f1_train=f1_train, f1_val=f1_val)
185 |     # print validation scores
186 |     print(f' | Jaccard {log_train["jaccard_val"]["mean"][-1]:.2%}'.ljust(15) + \
187 |           f' | F1-score {log_train["f1_val"]["mean"][-1]:.2%}'.ljust(15))
188 |     print('-'*100)
189 | 
190 |     # save the current model state as checkpoint
191 |     checkpoint = {'nb_epochs_finished': epoch+1, \
192 |                   'model_state': model.state_dict(), \
193 |                   'optimizer_state': optimizer.state_dict(), \
194 |                   'log_train':log_train}
195 |     torch.save(checkpoint, checkpoint_name)
196 | 
197 |     # save the model state if better F1-score
198 |     if best_F1 < log_train["f1_val"]["mean"][-1]:
199 |         best_model = {'nb_epochs': epoch+1, \
200 |                       'model_state': model.state_dict(), \
201 |                       'F1_score': log_train["f1_val"]["mean"][-1], \
202 |                       'Jaccard': log_train["jaccard_val"]["mean"][-1]}
203 |         torch.save(best_model, path_to_output+model_name+'_best.pickle')
204 |         best_F1 = log_train["f1_val"]["mean"][-1]
205 | 
206 | # Save the log of training
207 | with open(path_to_output+model_name+'_log_train.pickle', 'wb') as handle:
208 |     pickle.dump(log_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
209 | print('\n>>> LOG saved on disk at '+path_to_output+model_name+'_log_train.pickle')
210 | 
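#%% optional: visualise the training curves from the saved log (illustrative
# sketch; plot_loss comes from utils.py and is not in this script's imports
# above, and the output filename is a suggestion, not an existing convention)
from utils import plot_loss
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
plot_loss(log_train, ax, title='U-net '+class_type, train=True)
fig.savefig(path_to_output+model_name+'_training_curves.png', dpi=150, bbox_inches='tight')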
211 | # Save trained model state dict
212 | torch.save(model.state_dict(), path_to_output+model_name+'_trained.pt')
213 | print('\n>>> Trained model saved on disk at '+path_to_output+model_name+'_trained.pt')
214 | 
215 | #%%-------------------------------------------------------------------------------------------------
216 | # Get the scores for the train, validation and test dataset
217 | 
218 | # Loading the best model
219 | best_model_param = torch.load(path_to_output+model_name+'_best.pickle')
220 | model.load_state_dict(best_model_param['model_state'])
221 | print(f'>>> Best Model loaded from epoch {best_model_param["nb_epochs"]}\n')
222 | 
223 | # Compute the scores
224 | datasets = {'train':train_set, 'validation':val_set, 'test':test_set}
225 | all_scores = {}
226 | 
227 | print('-'*100+'\n'+'Computing Scores'.center(100)+'\n'+'-'*100)
228 | with torch.set_grad_enabled(False):
229 |     for name, data_set in datasets.items():
230 |         print('|-- ' + name)
231 |         all_scores[name] = get_dataset_scores(data_set, model, augmented_pred=True, verbose=True)
232 | 
233 | # save the scores
234 | with open(path_to_output+model_name+'_scores.pickle', 'wb') as handle:
235 |     pickle.dump(all_scores, handle, protocol=pickle.HIGHEST_PROTOCOL)
236 | print('\n>>> Scores saved on disk at '+path_to_output+model_name+'_scores.pickle')
--------------------------------------------------------------------------------
/Code/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib
3 | import numpy as np
4 | import pandas as pd
5 | from ast import literal_eval
6 | import os
7 | import shapely
8 | import shapely.wkt
9 | import descartes
10 | import skimage
11 | import rasterio
12 | import torch
13 | import torch.cuda as cuda
14 | import pickle
15 | 
16 | import warnings
17 | warnings.filterwarnings("ignore", message="Dataset has no geotransform set. The identity matrix may be returned")
18 | 
19 | # ------------------------ Extract Polygons ------------------------------------
20 | 
21 | def get_polygon_list(img_id, class_id, wkt_df):
22 |     """
23 |     Load the polygons from the wkt dataframe for the specified img and class.
24 |     ------------
25 |     INPUT
26 |     |---- img_id (str) the image id
27 |     |---- class_id (int) the class identifier
28 |     |---- wkt_df (pandas.DataFrame) dataframe containing the wkt
29 |     OUTPUT
30 |     |---- polygon_list (shapely.MultiPolygon) the list of polygon
31 |     """
32 |     all_polygon = wkt_df[wkt_df.ImageId == img_id]
33 |     polygon = all_polygon[all_polygon.ClassType == class_id].MultipolygonWKT
34 |     polygon_list = shapely.wkt.loads(polygon.values[0])
35 |     return polygon_list
36 | 
37 | def get_scale_factor(img_id, grid_df, img_size):
38 |     """
39 |     Compute the scaling factors to map the polygons into the pixel space of the specified image.
40 | ------------ 41 | INPUT 42 | |---- img_id (str) the image id 43 | |---- grid_df (pandas.DataFrame) dataframe containing the scalling info 44 | |---- img_size (tuple) dimension of the image 45 | OUTPUT 46 | |---- scale (tuple) the scalling factor for the given image 47 | """ 48 | img_h, img_w = img_size 49 | xmax = grid_df.loc[grid_df['ImageId']==img_id, 'Xmax'].values[0] 50 | ymin = grid_df.loc[grid_df['ImageId']==img_id, 'Ymin'].values[0] 51 | scale = ((img_w-2)/xmax, (img_h-2)/ymin) 52 | return scale 53 | 54 | def scale_polygon_list(polygon_list, scale): 55 | """ 56 | Scale the polygon list. 57 | ------------ 58 | INPUT 59 | |---- polygon_list (shapely.MultiPolygon) the list of polygon 60 | |---- scale (tuple) the scalling factor for the given image 61 | OUTPUT 62 | |---- polygon_list_scaled () scaled polygon_list 63 | """ 64 | polygon_list_scaled = shapely.affinity.scale(polygon_list, \ 65 | xfact=scale[0], \ 66 | yfact=scale[1], \ 67 | origin = [0., 0., 0.]) 68 | return polygon_list_scaled 69 | 70 | def plot_polygons(ax, polygon_dict, color_dict, zorder_dict, legend=True, **legend_kwargs): 71 | """ 72 | Plot the polygon on a matplotlib Axes. 73 | ------------ 74 | INPUT 75 | |---- ax (matplotlib.Axes) the axes on which to plot 76 | |---- polygon_dict (dict) dictionnary of shapely.MultiPolygon for each classe by name 77 | |---- color_dict (dict) the color associated with each classes 78 | |---- legend (bool) whether to add a legend to the plot 79 | |---- zorder_dict (dict) the order of stacking the classes 80 | |---- legend_kwargs (kwargs) keywords arguments for the legend 81 | OUTPUT 82 | |---- None 83 | """ 84 | for class_name in list(zorder_dict.keys()): 85 | ax.add_patch(descartes.PolygonPatch(polygon_dict[class_name], \ 86 | color=color_dict[class_name], \ 87 | zorder=zorder_dict[class_name], \ 88 | linewidth=0, label=class_name)) 89 | if legend : ax.legend(**legend_kwargs) 90 | 91 | def plot_masks(ax, masks, order_dict, color_dict, zorder_dict, mask_alpha=1, legend=False, **legend_kwargs): 92 | """ 93 | Plot the masks on a matplotlib Axes. 94 | ------------ 95 | INPUT 96 | |---- ax (matplotlib.Axes) the axes on which to plot 97 | |---- masks (3D numpy.array) the masks to plot (H x W x C) 98 | |---- order_dict (dict) dictionnary of class order in the resulting mask 99 | |---- color_dict (dict) the color associated with each classes 100 | |---- zorder_dict (dict) the order of stacking the classes 101 | |---- legend_kwargs (kwargs) keywords arguments for the legend 102 | OUTPUT 103 | |---- None 104 | """ 105 | for class_name in list(zorder_dict.keys()): 106 | pos = order_dict[class_name] 107 | m = np.ma.masked_where(masks[:,:,pos-1] == 0, masks[:,:,pos-1]) 108 | ax.imshow(m, cmap = matplotlib.colors.ListedColormap(['white', color_dict[class_name]]), \ 109 | vmin=0, vmax=1, alpha=mask_alpha, zorder=zorder_dict[class_name]) 110 | 111 | if legend : ax.legend([matplotlib.patches.Patch(fc=color, alpha=mask_alpha) for color in list(color_dict.values())], list(color_dict.keys()), **legend_kwargs) 112 | 113 | def compute_polygon_mask(polygon_list, img_size): 114 | """ 115 | Convert the shapely.MultiPolygon into a numpy mask. 
116 | ------------ 117 | INPUT 118 | |---- polygon_list (shapely.MultiPolygon) the list of polygon 119 | |---- img_size (tuple) dimension of the image 120 | OUTPUT 121 | |---- mask (2D numpy.array) the mask 122 | """ 123 | # fill mask image 124 | mask = np.zeros(img_size, np.uint8) 125 | 126 | # add polygon to mask 127 | for poly in polygon_list: 128 | rc = np.array(list(poly.exterior.coords)) 129 | rr, cc = skimage.draw.polygon(rc[:,1], rc[:,0]) 130 | mask[rr, cc] = 1 131 | # remove holes 132 | for poly in polygon_list: 133 | for poly_int in poly.interiors: 134 | rc = np.array(list(poly_int.coords)) 135 | rr, cc = skimage.draw.polygon(rc[:,1], rc[:,0]) 136 | mask[rr, cc] = 0 137 | 138 | return mask 139 | 140 | def get_polygons_masks(polygon_dict, order_dict, img_size, filename=None): 141 | """ 142 | Convert the shapely polygons_dict into one numpy mask. 143 | ------------ 144 | INPUT 145 | |---- polygon_dict (dict) dictionnary of shapely.MultiPolygon for each classe by name 146 | |---- order_dict (dict) dictionnary of class order in the resulting mask 147 | |---- img_size (tuple) dimension of the image 148 | |---- filename (str) path to save the mask (save only if given) 149 | OUTPUT 150 | |---- all_mask (3D numpy.array) the mask in dimension (C x H x W) 151 | """ 152 | all_mask = np.zeros((img_size[0], img_size[1], len(order_dict)), np.uint8) 153 | for class_name, poly_list in polygon_dict.items(): 154 | all_mask[:,:,order_dict[class_name]-1] = compute_polygon_mask(poly_list, img_size) 155 | 156 | if filename is not None: 157 | #skimage.external.tifffile.imsave(filename, np.moveaxis(all_mask, 2, 0)) 158 | save_geotiff(filename, np.moveaxis(all_mask, 2, 0), dtype='uint8') 159 | 160 | return all_mask 161 | 162 | def save_geotiff(filename, img, dtype='uint16'): 163 | """ 164 | Save the image in GeoTiff. 165 | ------------ 166 | INPUT 167 | |---- filename (str) the filename to save the image 168 | |---- img (3D numpy array) the image to save as B x H x W 169 | OUTPUT 170 | |---- None 171 | """ 172 | with rasterio.open(filename, \ 173 | mode='w', \ 174 | driver='GTiff', \ 175 | width=img.shape[2], \ 176 | height=img.shape[1], \ 177 | count=img.shape[0], \ 178 | dtype=dtype) as dst: 179 | dst.write(img) 180 | 181 | def get_polygon_dict(img_id, class_dict, img_size, wkt_df, grid_df): 182 | """ 183 | Get the polygon list for the specified image classes from the raw information. 
184 |     ------------
185 |     INPUT
186 |     |---- img_id (str) the image id
187 |     |---- class_dict (dict) dictionary specifying groups of classes and their names
188 |     |---- img_size (tuple) dimension of the image
189 |     |---- wkt_df (pandas.DataFrame) dataframe containing the wkt
190 |     |---- grid_df (pandas.DataFrame) dataframe containing the scaling info
191 |     OUTPUT
192 |     |---- polygon_dict (dict) dictionary of shapely.MultiPolygon for each class by name
193 |     """
194 |     polygon_dict = {}
195 |     for class_name, class_val in class_dict.items():
196 |         # get first polygon list
197 |         poly_list = get_polygon_list(img_id, class_val[0], wkt_df)
198 |         scale = get_scale_factor(img_id, grid_df, img_size)
199 |         poly_list = scale_polygon_list(poly_list, scale)
200 |         # get polygon_list for next class
201 |         if len(class_val) > 1:
202 |             for next_class in class_val[1:]:
203 |                 next_poly_list = get_polygon_list(img_id, next_class, wkt_df)
204 |                 scale = get_scale_factor(img_id, grid_df, img_size)
205 |                 next_poly_list = scale_polygon_list(next_poly_list, scale)
206 |                 poly_list = poly_list.union(next_poly_list)
207 | 
208 |         polygon_dict[class_name] = poly_list
209 |     return polygon_dict
210 | 
211 | # ------------------------ Preprocess Images -----------------------------------
212 | 
213 | def load_image(filepath, img_id):
214 |     """
215 |     Load the image associated with the id provided.
216 |     ------------
217 |     INPUT
218 |     |---- filepath (str) the path to the sixteen bands images
219 |     |---- img_id (str) the id of the image to load
220 |     OUTPUT
221 |     |---- img_A (3D numpy array) the SWIR bands
222 |     |---- img_M (3D numpy array) the Multispectral bands
223 |     |---- img_P (2D numpy array) the Panchromatic band
224 |     """
225 |     img_M = skimage.img_as_float(skimage.io.imread(filepath+img_id+'_M.tif', plugin="tifffile"))
226 |     img_A = skimage.img_as_float(skimage.io.imread(filepath+img_id+'_A.tif', plugin="tifffile"))
227 |     img_P = skimage.img_as_float(skimage.io.imread(filepath+img_id+'_P.tif', plugin="tifffile"))
228 | 
229 |     return np.moveaxis(img_A, 0, 2), np.moveaxis(img_M, 0, 2), img_P
230 | 
231 | def contrast_stretch(img, percentile=(0.5,99.5), out_range=(0,1)):
232 |     """
233 |     Stretch the image histogram for each channel independently. The image histogram
234 |     is stretched such that the lower and upper percentile are saturated.
235 |     ------------
236 |     INPUT
237 |     |---- img (3D numpy array) the image to stretch (H x W x B)
238 |     |---- percentile (tuple) the two percentile value to saturate
239 |     |---- out_range (tuple) the output range value
240 |     OUTPUT
241 |     |---- img_adj (3D numpy array) the stretched image (H x W x B)
242 |     """
243 |     n_band = img.shape[2]
244 |     q = [tuple(np.percentile(img[:,:,i], percentile)) for i in range(n_band)] # use the requested percentiles for each band
245 |     img_adj = np.stack([skimage.exposure.rescale_intensity(img[:,:,i], in_range=q[i], out_range=out_range) for i in range(n_band)], axis=2)
246 |     return img_adj
247 | 
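# Example (illustrative): saturate the darkest and brightest 1% of each band
# and rescale to [0, 1]:
#   img_adj = contrast_stretch(img, percentile=(1, 99), out_range=(0, 1))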
248 | def pansharpen(img_MS, img_Pan, order=2, W=1.5, stretch_perc=(0.5,99.5)):
249 |     """
250 |     Perform an image fusion of the multispectral image using the panchromatic one.
251 |     The multispectral image is first upsampled and interpolated, then multiplied
252 |     by the weighted panchromatic band, and the resulting histogram is stretched.
253 |     ------------
254 |     INPUT
255 |     |---- img_MS (3D numpy.array) the multispectral data as H x W x B
256 |     |---- img_Pan (2D numpy.array) the Panchromatic data
257 |     |---- order (int) the interpolation method to use (according to the scikit image method resize)
258 |     |---- W (float) the weight applied to the panchromatic band
259 |     OUTPUT
260 |     |---- img_fused (3D numpy.array) the pansharpened multispectral data as H x W x B
261 |     """
262 |     m_up = skimage.transform.resize(img_MS, img_Pan.shape, order=order)
263 |     img_fused = np.multiply(m_up, W*np.expand_dims(img_Pan, axis=2))
264 |     img_fused = contrast_stretch(img_fused, stretch_perc)
265 |     return img_fused
266 | 
267 | def NDVI(R, NIR):
268 |     """
269 |     Compute the NDVI from the red and near infrared bands. Note that this
270 |     function can be used to compute the NDWI by calling it with NDVI(NIR, G).
271 |     ------------
272 |     INPUT
273 |     |---- R (2D numpy.array) the red band
274 |     |---- NIR (2D numpy.array) the near infrared band
275 |     OUTPUT
276 |     |---- NDVI (2D numpy.array) the NDVI
277 |     """
278 |     return (NIR - R)/(NIR + R + 1e-9)
279 | 
280 | def EVI(R, NIR, B):
281 |     """
282 |     Compute the Enhanced Vegetation Index from the red, near infrared and
283 |     blue bands.
284 |     ------------
285 |     INPUT
286 |     |---- R (2D numpy.array) the red band
287 |     |---- NIR (2D numpy.array) the near infrared band
288 |     |---- B (2D numpy.array) the blue band
289 |     OUTPUT
290 |     |---- evi (2D numpy.array) the EVI
291 |     """
292 |     L, C1, C2 = 1.0, 6.0, 7.5
293 |     evi = (NIR - R) / (NIR + C1 * R - C2 * B + L) # note: the gain factor G=2.5 of the standard EVI formula is omitted here
294 |     evi = evi.clip(max=np.percentile(evi, 99), min=np.percentile(evi, 1))
295 |     evi = evi.clip(max=1, min=-1) # clip if too big
296 |     return evi
297 | 
298 | # ------------------------ Extract images --------------------------------------
299 | 
300 | def get_crops_grid(img_h, img_w, crop_size, overlap=None):
301 |     """
302 |     Get a list of crop coordinates for the image in a grid fashion starting in
303 |     the upper left corner. There might be some un-considered pixels on the
304 |     right and bottom.
305 |     ------------
306 |     INPUT
307 |     |---- img_h (int) the image height
308 |     |---- img_w (int) the image width
309 |     |---- crop_size (tuple) the height and width of the crop
310 |     |---- overlap (tuple) the total overlap between adjacent crops in the grid
311 |     OUTPUT
312 |     |---- crops (list of tuple) list of upper-left crop coordinates
313 |     """
314 |     # compute the overlap if none given
315 |     if overlap is None:
316 |         nr = np.ceil(img_h / crop_size[0]) # number of crops along the rows
317 |         nc = np.ceil(img_w / crop_size[1]) # number of crops along the columns
318 |         excess_r = nr*crop_size[0] - img_h
319 |         excess_c = nc*crop_size[1] - img_w
320 |         overlap = (np.ceil(excess_r / (nr-1)), np.ceil(excess_c / (nc-1)))
321 | 
322 |     crops = [] # (row, col)
323 |     for i in np.arange(0,img_h+crop_size[0],crop_size[0]-overlap[0])[:]:
324 |         for j in np.arange(0,img_w+crop_size[1],crop_size[1]-overlap[1])[:]:
325 |             if i+crop_size[0] <= img_h and j+crop_size[1] <= img_w:
326 |                 crops.append((i, j))
327 |     return crops
328 | 
329 | def load_image_part(xy, hw, filename, as_float=True):
330 |     """
331 |     Load an image subpart specified by the crop coordinates and dimensions.
332 |     ------------
333 |     INPUT
334 |     |---- xy (tuple) crop coordinates as (row, col)
335 |     |---- hw (tuple) crop dimensions as (h, w)
336 |     |---- filename (str) the filename
337 |     |---- as_float (bool) whether to convert the image in float
338 |     OUTPUT
339 |     |---- img (3D numpy array) the image subpart
340 |     """
341 |     with rasterio.open(filename, mode='r') as src:
342 |         img = src.read(window=rasterio.windows.Window(xy[1], xy[0], hw[1], hw[0]))
343 |     if as_float: img = skimage.img_as_float(img)
344 |     return img
345 | 
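# NOTE: rasterio windows are specified column-first -- Window(col_off, row_off,
# width, height) -- hence the (xy[1], xy[0], hw[1], hw[0]) reordering above,
# since crops are handled as (row, col) tuples throughout this code base.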
346 | def get_represented_classes(filename_mask, order_dict, crop_coord, crop_size):
347 |     """
348 |     Find which classes are represented on the specified crop.
349 |     ------------
350 |     INPUT
351 |     |---- filename (str) the filename
352 |     |---- order_dict (dict) a dictionary specifying the class name associated
353 |     |           with which dimension of the masks {dim+1:'name'}
354 |     |---- crop_coord (tuple) crop coordinates as (row, col)
355 |     |---- crop_size (tuple) crop dimensions as (h, w)
356 |     OUTPUT
357 |     |---- classes (list) list of present class names
358 |     """
359 |     mask = load_image_part(crop_coord, crop_size, filename_mask, as_float=False)
360 |     classes = [order_dict[cl+1] for cl in np.unique(mask.nonzero()[0])]
361 |     return classes
362 | 
363 | def get_samples(img_id_list, img_path, mask_path, crop_size, overlap, order_dict, cl_offset, cl_size, as_fraction=False, verbose=False):
364 |     """
365 |     Produce a pandas dataframe containing all the crop information.
366 |     ------------
367 |     INPUT
368 |     |---- img_id_list (list) the list of image ids to process
369 |     |---- img_path (str) the folder path to the images
370 |     |---- mask_path (str) the folder path to the masks
371 |     |---- crop_size (tuple) crop dimensions as (h, w)
372 |     |---- overlap (tuple) the total overlap between adjacent crops in the grid
373 |     |---- order_dict (dict) a dictionary specifying the class name associated
374 |     |           with which dimension of the masks {dim+1:'name'}
375 |     |---- cl_offset (tuple) the offset from crop to check class presence (row, col)
376 |     |---- cl_size (tuple) the size of the patch where class presence is checked (h, w)
377 |     |---- as_fraction (bool) whether to specify crop_size as fraction.
if 378 | | True, the crop_size value represent fraction and 379 | | should be between 0 and 1 380 | |---- verbose (bool) whether to display processing 381 | OUTPUT 382 | |---- sample_df (pandas dataframe) informations for all samples 383 | """ 384 | # DataFrame (img_id, x, y, h, w classes) 385 | if verbose : 386 | print(f'>>>> Extract samples from images \n'+'-'*80) 387 | summary = {'building':0, 'misc':0, 'road':0, 'track':0, \ 388 | 'tree':0, 'crop':0, 'water':0, 'vehicle':0} 389 | # storing variables 390 | ids, row, col, H, W, cl_list = [], [], [], [], [], [] 391 | for i, id in enumerate(img_id_list): 392 | if verbose: 393 | print(f'\t|---- {i+1:02n} : cropping image {id}') 394 | summary2 = {'building':0, 'misc':0, 'road':0, 'track':0, \ 395 | 'tree':0, 'crop':0, 'water':0, 'vehicle':0} 396 | # get height width 397 | with rasterio.open(img_path+id+'.tif', mode='r') as src: 398 | img_h, img_w = src.height, src.width 399 | # define crop size from fraction if requested 400 | if as_fraction: 401 | crop_size = np.floor(img_h*crop_size[0]), np.floor(img_w*crop_size[1]) 402 | # get the grid crops 403 | crops = get_crops_grid(img_h, img_w, crop_size, overlap) 404 | # fill lists 405 | for crop in crops: 406 | ids.append(id) 407 | row.append(crop[0]) 408 | col.append(crop[1]) 409 | H.append(crop_size[0]) 410 | W.append(crop_size[1]) 411 | classes = get_represented_classes(mask_path+id+'_mask.tif', \ 412 | order_dict, \ 413 | (crop[0]+cl_offset[0], crop[1]+cl_offset[1]), \ 414 | cl_size) 415 | # count classes 416 | if verbose: 417 | for cl in classes: 418 | summary[cl] += 1 419 | summary2[cl] += 1 420 | cl_list.append(classes) 421 | #display count 422 | if verbose: 423 | for cl, count in summary2.items(): 424 | print(f'\t\t|---- {cl} : {count}') 425 | # display total count 426 | if verbose: 427 | print('-'*80+'\n>>>> Total \n') 428 | for cl, count in summary.items(): 429 | print(f'\t|---- {cl} : {count}') 430 | # build dataframe 431 | sample_df = pd.DataFrame({'img_id':ids, \ 432 | 'row':row, 'col':col, \ 433 | 'h':H, 'w':W, \ 434 | 'classes':cl_list}) 435 | return sample_df 436 | 437 | # ------------------------ Training functions --------------------------------- 438 | 439 | def print_param_summary(**params): 440 | """ 441 | Print the dictionnary passed as a table. 442 | ------------ 443 | INPUT 444 | |---- params (keyword arguments) value to display 445 | OUTPUT 446 | |---- None 447 | """ 448 | # get the max length of values and keys 449 | max_len = max([len(str(key)) for key in params.keys()])+5 450 | max_len_val = max([max([len(subval) for subval in str(val).split('\n')]) for val in params.values()])+3 451 | # print header 452 | print('-'*(max_len+max_len_val+1)) 453 | print('| Parameter'.ljust(max_len) + '| Value'.ljust(max_len_val)+'|') 454 | print('-'*(max_len+max_len_val+1)) 455 | # print values and subvalues 456 | for key, value in params.items(): 457 | for i, subvalue in enumerate(str(value).split('\n')): 458 | if i == 0 : 459 | print(f'| {key}'.ljust(max_len)+f'| {subvalue}'.ljust(max_len_val)+'|') 460 | else : 461 | print('| '.ljust(max_len)+f'| {subvalue}'.ljust(max_len_val)+'|') 462 | print('-'*(max_len+max_len_val+1)) 463 | 464 | def load_sample_df(filename, class_type, others_frac=0, seed=None): 465 | """ 466 | Load the sample information dataframe for a given class by adding a fraction 467 | of non-class_type to it. 468 | ------------ 469 | INPUT 470 | |---- filename (str) patht to the csv file. 
471 |     |---- class_type (str) class name to select
472 |     |---- others_frac (float) between 0 and 1, the fraction of non-class_type
473 |     |      elements to add, relative to the size of the class_type
474 |     |      dataframe.
475 |     |---- seed (int) passed to the random_state of pandas.DataFrame.sample
476 |     OUTPUT
477 |     |---- sub_df (pandas.DataFrame) the dataframe of the class_type
478 |     """
479 |     df = pd.read_csv(filename, index_col=0, converters={'classes' : literal_eval})
480 |     # get samples with the given class
481 |     sub_df = df[pd.DataFrame(df.classes.tolist()).isin([class_type]).any(1)]
482 |     # get samples without the given class
483 |     other_df = df[~pd.DataFrame(df.classes.tolist()).isin([class_type]).any(1)]
484 |     # add others_frac of the other samples to sub_df and shuffle the rows
485 |     other_df = other_df.sample(n=int(others_frac*sub_df.shape[0]), random_state=seed, axis=0)
486 |     sub_df = pd.concat([sub_df, other_df], axis=0).sample(frac=1, random_state=seed)
487 |     return sub_df
488 | 
489 | def stat_from_list(values):
490 |     """
491 |     Compute the mean and standard deviation of a list of values.
492 |     ------------
493 |     INPUT
494 |     |---- values (list) list of values
495 |     OUTPUT
496 |     |---- mean (float) the mean of the values
497 |     |---- std (float) the standard deviation of the values
498 |     """
499 |     values = torch.Tensor(values)
500 |     return values.mean().item(), values.std().item()
501 | 
502 | def append_scores(dest_dict, **keys):
503 |     """
504 |     Add the kwargs to the passed dictionary. Each entry (key) maps to a list
505 |     of values, and the passed value is appended to that list. If a key is not
506 |     present in the dictionary, it is added. If a value is a list, its mean and
507 |     std are appended to the dictionary instead.
508 |     ------------
509 |     INPUT
510 |     |---- dest_dict (dict) where the values are appended (modified in place)
511 |     |---- keys (keyword arguments) values to append
512 |     OUTPUT
513 |     |---- None
514 |     """
515 |     for name, val in keys.items():
516 |         if type(val) is list:
517 |             m, s = stat_from_list(val)
518 |             if name in dest_dict.keys():
519 |                 dest_dict[name]['mean'].append(m)
520 |                 dest_dict[name]['std'].append(s)
521 |             else:
522 |                 dest_dict[name] = {}
523 |                 dest_dict[name]['mean'] = [m]
524 |                 dest_dict[name]['std'] = [s]
525 |         else:
526 |             if name in dest_dict.keys():
527 |                 dest_dict[name].append(val)
528 |             else:
529 |                 dest_dict[name] = [val]
530 | 
531 | # ------------------------- Testing functions ---------------------------------
532 | 
533 | def load_models(folder_path, class_names, model_class, **model_param):
534 |     """
535 |     Load the models in the given folder.
536 |     ------------
537 |     INPUT
538 |     |---- folder_path (str) path to the folder
539 |     |---- class_names (list) list of class names to load
540 |     |---- model_class (torch.nn.Module) a model class
541 |     |---- model_param (kwargs) the model parameters as kwargs
542 |     OUTPUT
543 |     |---- models (dict) dictionary of models (value) for each class name (key)
544 |     """
545 |     models = {}
546 |     for class_type in class_names:
547 |         try:
548 |             state_dict = torch.load(folder_path+'Unet_'+class_type+'_trained.pt', map_location=torch.device('cpu'))
549 |             models[class_type] = model_class(**model_param)
550 |             models[class_type].load_state_dict(state_dict)
551 |         except FileNotFoundError:
552 |             pass
553 |     return models
554 | 
555 | def load_logs(folder_path, class_names):
556 |     """
557 |     Load the training LOGs in the given folder.
558 |     ------------
559 |     INPUT
560 |     |---- folder_path (str) path to the folder
561 |     |---- class_names (list) list of class names to load
562 |     OUTPUT
563 |     |---- logs_train (dict) dictionary of LOGs (value) for each class name (key)
564 |     """
565 |     logs_train = {}
566 |     for class_type in class_names:
567 |         try:
568 |             with open(folder_path+'Unet_'+class_type+'_log_train.pickle', 'rb') as f:
569 |                 logs_train[class_type] = pickle.load(f)
570 |         except FileNotFoundError:
571 |             pass
572 |     return logs_train
573 | 
574 | def plot_loss(log_train, ax, title=None, train=False, disp_std=True):
575 |     """
576 |     Plot the loss and the validation Jaccard and F1-scores on the given axis.
577 |     ------------
578 |     INPUT
579 |     |---- log_train (dict) dictionary containing the training LOG
580 |     |---- ax (matplotlib.Axes) the axes on which to plot
581 |     |---- title (str) optional title
582 |     |---- train (bool) whether to include the train score curves
583 |     |---- disp_std (bool) whether to add the 2 STD range around score curves
584 |     OUTPUT
585 |     |---- None
586 |     """
587 |     # plot loss
588 |     ax.plot(log_train['epoch'], log_train['loss'], lw=2.5, color='gray', label='loss')
589 |     ax.set_xlim([log_train['epoch'][0], log_train['epoch'][-1]])
590 |     # twin ax
591 |     ax_r = ax.twinx()
592 |     scores = ['jaccard_val', 'f1_val']
593 |     colors = ['coral', 'crimson']
594 |     if train:
595 |         scores += ['jaccard_train', 'f1_train']
596 |         colors += ['cornflowerblue','dodgerblue']
597 |     # plot scores evolution
598 |     for score, color in zip(scores, colors):
599 |         mean, std = np.array(log_train[score]['mean']), np.array(log_train[score]['std'])
600 |         ax_r.plot(log_train['epoch'], mean, lw=2, color=color, label=score.title().replace("_", " "))
601 |         if disp_std: ax_r.fill_between(log_train['epoch'], mean-2*std, mean+2*std, fc=color, alpha=0.15)
602 |     # plot goodies
603 |     ax_r.set_ylim([0,1])
604 |     ax.set_xlabel('Epochs')
605 |     ax.set_ylabel('Loss')
606 |     ax_r.set_ylabel('Scores')
607 |     handles, labels = ax.get_legend_handles_labels()
608 |     handles_r, labels_r = ax_r.get_legend_handles_labels()
609 |     ax.legend(handles+handles_r, labels+labels_r)
610 |     ax.set_title(title, loc='left', fontweight='bold')
611 | 
612 | def class_prediction(model, input, augmented_pred=True, device=torch.device('cpu')):
613 |     """
614 |     Segment the passed image with the passed model and return the mask. The
615 |     prediction can be the average over flipped and rotated inputs or just the direct prediction.
616 |     ------------
617 |     INPUT
618 |     |---- model (torch.nn.Module) the model to use
619 |     |---- input (torch.Tensor) the input image (B x C x H x W)
620 |     |---- augmented_pred (bool) whether to take the average prediction over
621 |     |      flips and rotations
622 |     |---- device (torch.device) the device on which operations are performed
623 |     OUTPUT
624 |     |---- prediction (2D torch.Tensor) the segmentation
625 |     """
626 |     if augmented_pred:
627 |         predictions = []
628 |         # horizontal or vertical flip
629 |         for dim_flip in [None, 2, 3]:
630 |             # 0, 90, 180, 270 degree rotations
631 |             for rot_angle in [0,1,2,3]:
632 |                 if dim_flip is not None:
633 |                     transformed_input = torch.flip(input, dims=[dim_flip]) # flip
634 |                 else:
635 |                     transformed_input = input
636 |                 transformed_input = torch.rot90(transformed_input, k=rot_angle, dims=[2,3]) # rotate
637 |                 pred = model(transformed_input).argmax(dim=1) # predict
638 |                 corrected_pred = torch.rot90(pred, k=-rot_angle, dims=[1,2]) # un-rotate
639 |                 if dim_flip is not None:
640 |                     corrected_pred = torch.flip(corrected_pred, dims=[dim_flip-1]) # un-flip
641 |                 predictions.append(corrected_pred)
642 |         prediction = torch.stack(predictions, dim=3).float() # stack the 12 predictions
643 |         prediction = prediction.mean(dim=3) # average them
644 |         prediction = torch.where(prediction >= 0.5, torch.ones(pred.shape).to(device), torch.zeros(pred.shape).to(device)) # thresholding
645 |     else:
646 |         prediction = model(input).argmax(dim=1) # predict
647 |     return prediction
648 | 
649 | def get_dataset_scores(data_set, model, augmented_pred=True, verbose=False):
650 |     """
651 |     Compute the Jaccard, Recall, Precision and F1-scores from the prediction of
652 |     each sample in the dataset with the passed model.
653 |     ------------
654 |     INPUT
655 |     |---- data_set (torch.utils.data.Dataset) the dataset to use
656 |     |---- model (torch.nn.Module) the model to use
657 |     |---- augmented_pred (bool) whether to take the average prediction over
658 |     |      flips and rotations
659 |     |---- verbose (bool) whether to print a summary of the processing
660 |     OUTPUT
661 |     |---- scores (dict) dictionary of scores
662 |     """
663 |     # define dataloader
664 |     dataloader = torch.utils.data.DataLoader(data_set, batch_size=16, shuffle=False, num_workers=4)
665 |     # get GPU if available
666 |     device = torch.device('cuda' if cuda.is_available() else 'cpu')
667 |     # make sure the model is on the same device as the inputs and in eval mode
668 |     model = model.to(device).eval()
669 | 
670 | 
671 |     scores = {'jaccard':[], 'f1-score':[], 'recall':[], 'precision':[]}
672 |     metrics = {'jaccard':jaccard_score, 'f1-score':f1_score, 'recall':recall_score, 'precision':precision_score}
673 | 
674 |     for b, (input, mask) in enumerate(dataloader):
675 |         input, mask = input.to(device), mask.to(device).long()
676 |         output = class_prediction(model, input, augmented_pred=augmented_pred, device=device)
677 |         # compute scores for each image separately
678 |         for i in range(output.shape[0]):
679 |             m, o = mask[i,:,:].flatten().cpu(), output[i,:,:].flatten().cpu()
680 |             for name, metric in metrics.items():
681 |                 scores[name].append(metric(m, o))
682 |         if verbose: print_progessbar(b, len(dataloader), '|---- Batch', Size=20, end_char='\n')
683 |     return scores
684 | 
685 | def print_progessbar(N, Max, Name='', Size=10, end_char=''):
686 |     """
687 |     Print a progress bar. To be used in a for-loop and called at each iteration
688 |     with the current iteration number and the maximum number of iterations.
689 |     ------------
690 |     INPUT
691 |     |---- N (int) the current iteration number
692 |     |---- Max (int) the total number of iterations
693 |     |---- Name (str) an optional name for the progress bar
694 |     |---- Size (int) the size of the progress bar
695 |     |---- end_char (str) the print end parameter to use at the end of the
696 |     |      progress bar (default is '')
697 |     OUTPUT
698 |     |---- None
699 |     """
700 |     print(f'\r{Name} {N+1:03d}/{Max:03d}'.ljust(26) \
701 |           + f'[{"#"*int(Size*(N+1)/Max)}'.ljust(Size+1) + f'] {(int(100*(N+1)/Max))}%'.ljust(6), \
702 |           end=end_char)
703 | 
704 | 
705 | 
706 | 
707 | 
708 | 
709 | 
710 | 
711 | 
712 | 
713 | 
714 | 
715 | 
716 | 
717 | 
718 | 
719 | 
720 | 
721 | 
722 | 
723 | 
724 | 
725 | 
726 | 
727 | 
728 | 
729 | 
730 | 
731 | 
732 | 
733 | 
734 | 
735 | #
736 | 
--------------------------------------------------------------------------------
/Figures/Segmentations_labels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antoine-spahr/Deep-Satellite-Image-Segmentation/ead36b52258d8f3b9dc9fce527cb50d8810565e9/Figures/Segmentations_labels.png
--------------------------------------------------------------------------------
/Figures/indices.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antoine-spahr/Deep-Satellite-Image-Segmentation/ead36b52258d8f3b9dc9fce527cb50d8810565e9/Figures/indices.png
--------------------------------------------------------------------------------
/Figures/pansharpening.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antoine-spahr/Deep-Satellite-Image-Segmentation/ead36b52258d8f3b9dc9fce527cb50d8810565e9/Figures/pansharpening.png
--------------------------------------------------------------------------------
/Figures/sample_ex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antoine-spahr/Deep-Satellite-Image-Segmentation/ead36b52258d8f3b9dc9fce527cb50d8810565e9/Figures/sample_ex.png
--------------------------------------------------------------------------------
/Figures/train_samples.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/antoine-spahr/Deep-Satellite-Image-Segmentation/ead36b52258d8f3b9dc9fce527cb50d8810565e9/Figures/train_samples.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Semantic Segmentation of Satellite Images 🛰️
2 | ---
3 | ## Goal
4 | Wouldn't it be nice to automatically get a map from a satellite image? Maps are easier to read and simplify the representation of the world. Segmentation using deep learning is a popular approach to this problem: _deep semantic segmentation_. The goal is to produce a map composed of a few categories (buildings, roads, tracks, trees, crops, water, ...) from a multispectral satellite image. Each pixel is thus classified into one of the categories.
5 | 
6 | ## Data
7 | The problem is addressed using the _DSTL dataset_ from this [kaggle competition](https://www.kaggle.com/c/dstl-satellite-imagery-feature-detection/data). It is composed of 25 WorldView-3 satellite images already labelled for 10 categories: building, man-made structure (misc), road, track, trees, crops, standing water, running water, large vehicles and small vehicles. Eight categories are used here, by combining running and standing water on one hand, and large and small vehicles on the other. The masks of the 25 images are presented below. Each image is composed of:
8 | * 8 multispectral bands ranging from the visible spectrum to the infrared, at a medium spatial resolution.
9 | * 8 short-wave infrared (SWIR) bands at a low spatial resolution.
10 | * 1 panchromatic band at a high spatial resolution.
11 | 
12 | ![labels overview](Figures/Segmentations_labels.png "Labels")
13 | 
14 | The 25 scenes are heterogeneous and the classes do not seem evenly distributed among them. Some contain only crops and trees while others are mainly covered by urban areas.
15 | 
16 | ## Data Preparation
17 | The resolution of the SWIR bands is much lower than that of the multispectral and panchromatic bands, so they may not bring relevant information at the required spatial resolution. Only the 8 multispectral bands and the panchromatic one are therefore used. The inputs are preprocessed to improve the spatial resolution (pansharpening) and to compute some common remote sensing indices that usually help discriminate the classes.
18 | 
19 | ### Indices
20 | 
21 | Three indices are computed from the multispectral bands (a code sketch follows the figure below).
22 | 1. The [NDVI](https://en.wikipedia.org/wiki/Normalized_difference_vegetation_index) reports the vegetation content of a pixel.
23 | 2. The [NDWI](https://en.wikipedia.org/wiki/Normalized_difference_water_index) highlights water bodies.
24 | 3. The [EVI](https://en.wikipedia.org/wiki/Enhanced_vegetation_index) also highlights vegetation, taking into account the blue band, which is more sensitive to atmospheric effects, hence correcting for them.
25 | 
26 | An example of the three indices is presented below for image 6100_2_2. With the 3 indices, a total of 11 bands is available for each pixel.
27 | 
28 | ![Indices](Figures/indices.png "indices")
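
A minimal sketch of these computations, assuming a `(bands, H, W)` reflectance array in [0, 1]; the function name, the band positions (taken from the usual WorldView-3 multispectral band order) and the `eps` guard are illustrative, not the repository's exact implementation:

```python
import numpy as np

def spectral_indices(ms, blue=1, green=2, red=4, nir=6, eps=1e-8):
    """NDVI, NDWI and EVI from a (bands, H, W) reflectance array in [0, 1].
    Band positions are assumptions based on the WorldView-3 band order."""
    b, g, r, n = (ms[i] for i in (blue, green, red, nir))
    ndvi = (n - r) / (n + r + eps)                              # vegetation content
    ndwi = (g - n) / (g + n + eps)                              # water bodies
    evi = 2.5 * (n - r) / (n + 6.0 * r - 7.5 * b + 1.0 + eps)   # vegetation, atmosphere-corrected
    return np.stack([ndvi, ndwi, evi])
```

Stacking the three indices onto the 8 multispectral bands yields the 11 bands mentioned above.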
29 | 
30 | ### Pansharpening
31 | The spatial resolution of the images is increased using the information carried by the panchromatic band. First, the 11 bands are upsampled to the size of the panchromatic one with bicubic upsampling. Each band is then multiplied by 1.5 times the panchromatic image. Finally, each band's histogram is stretched to cover the full value range, enhancing the contrast of the image. The effect of pansharpening on a subset of an image is presented below.
32 | 
33 | ![pansharpening](Figures/pansharpening.png "pansharpening")
34 | 
35 | ### Training Sample Generation
36 | 
37 | Of the 25 labelled images, 21 are used for the training (and validation) set, while 4 images are kept aside as the test set (6100_2_2, 6060_2_3, 6110_4_0, 6160_2_1). A sample of the train/validation set is a 160x160 crop from one of the images. The dataset is backed by a `pandas.DataFrame` containing the crop coordinates for each image id. The crops overlap in such a way that the central 80x80 regions of the samples cover the whole image, since the center region is weighted more heavily in the loss calculation; a sketch of this grid logic is given below. The GIF further below shows how the samples are taken from an image. For each crop, the classes present in the center region are also recorded, so that one can, for example, extract only the crops containing a building.
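
A minimal sketch of the overlapping grid, assuming a 160x160 crop with an 80x80 center (the actual implementation is `get_crops_grid` in `Code/utils.py`, which parameterizes the overlap explicitly; the function below is a simplified illustration):

```python
import numpy as np

def center_covering_crops(img_h, img_w, crop=160, center=80):
    """Top-left (row, col) corners of crop x crop windows whose central
    center x center regions tile the image (stride = center size)."""
    rows = np.arange(0, img_h - crop + 1, center)
    cols = np.arange(0, img_w - crop + 1, center)
    # add a crop flush with the bottom/right edges so the whole image is covered
    rows = np.unique(np.append(rows, img_h - crop))
    cols = np.unique(np.append(cols, img_w - crop))
    return [(int(r), int(c)) for r in rows for c in cols]

# e.g. a 400x400 image gives corners at rows/cols 0, 80, 160 and 240
```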
38 | 
39 | ![train_sample_generation](Figures/train_samples.gif "sample_generation")
40 | 
41 | Below, several training samples for each of the eight classes are presented as true-color images with the mask overlaid. The masks are sometimes inexact or approximate, especially for the crop class, where trees are counted as crops.
42 | 
43 | ![train sample example](Figures/sample_ex.png "sample example")
44 | 
45 | ## Training Procedure
46 | 
47 | ## Results
48 | 
--------------------------------------------------------------------------------