├── .gitignore
├── .idea
│   └── clusternet_segmentation.iml
├── Dockerfile
├── README.md
├── interactive_tool
│   ├── clusters_to_class.py
│   ├── crop_img.py
│   └── inter_cluster.py
├── main.py
├── models
│   ├── ContextEncoder.py
│   ├── UNetValid.py
│   ├── VGG_UNet.py
│   ├── __init__.py
│   └── __pycache__
│       ├── UNetValid.cpython-36.pyc
│       └── __init__.cpython-36.pyc
├── requirements.txt
├── super_train.py
├── tools
│   ├── __init__.py
│   ├── clusters2classes.py
│   ├── evaluation.py
│   ├── full_prediction.py
│   ├── patch_generator.py
│   └── rasterization.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   └── clustering.cpython-36.pyc
    └── clustering.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.npy
3 | *.csv
4 | *.hdf5
5 | *.h5
6 | *.png
7 | *.json
8 | *.gz
9 | *.tar
10 | *.tif
11 | *.tiff
12 | *.shp
13 | *.xml
14 | *.jpg
15 | *.tfrecords
16 | *.JPEG
17 | data/*
--------------------------------------------------------------------------------
/.idea/clusternet_segmentation.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
12 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
2 |
3 | ARG FAISS_CPU_OR_GPU=cpu
4 | ARG FAISS_VERSION=1.4.0
5 |
6 | ENV HOME /root
7 |
8 | RUN apt-get update && \
9 |     apt-get install -y build-essential curl software-properties-common bzip2 python3-pip && \
10 |     curl https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh > /tmp/conda.sh && \
11 |     bash /tmp/conda.sh -b -p /opt/conda && \
12 |     /opt/conda/bin/conda update -n base conda && \
13 |     /opt/conda/bin/conda install -y -c pytorch faiss-${FAISS_CPU_OR_GPU}=${FAISS_VERSION} && \
14 |     apt-get remove -y --auto-remove curl bzip2 && \
15 |     apt-get clean && \
16 |     rm -fr /tmp/conda.sh
17 |
18 | #RUN /opt/conda/bin/conda install -c conda-forge tensorflow-gpu \
19 | #                                                keras \
20 | #                                                scikit-image \
21 | #                                                click
22 |
23 | RUN /opt/conda/bin/pip install sklearn \
24 |     scikit-image \
25 |     click \
26 |     pandas \
27 |     tensorflow-gpu==1.11.0 \
28 |     Keras==2.2.2 \
29 |     Keras-Applications==1.0.4 \
30 |     Keras-Preprocessing==1.0.2
31 |
32 | ENV LC_ALL=C.UTF-8
33 | ENV LANG=C.UTF-8
34 |
35 | ENV PATH="/opt/conda/bin:${PATH}"
36 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ClusterNet for Unsupervised Segmentation on Satellite Images
2 | Semantic segmentation on satellite images is used to automatically detect and classify
3 | objects of interest while preserving the geometric structure of these objects. In this work,
4 | we present a novel weakly supervised segmentation approach based on
5 | applying the K-Means clustering technique to the high-dimensional output of a Deep
6 | Neural Network. The approach is generic and can be applied to any image and to any
7 | class of objects in that image. At the same time, it does not require ground truth
8 | during the training stage.
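In practice, training alternates between two steps: cluster the per-pixel features produced by the network, then fit the network to the resulting cluster assignments used as pseudo-labels. The sketch below condenses what one iteration of `main.py` does (the helper name `pseudo_label_batch` is illustrative, not part of the repository API):

```
import numpy as np
from keras.utils import to_categorical

from utils import clustering  # faiss-backed K-Means wrapper used by main.py


def pseudo_label_batch(feature_model, images_batch, num_clusters):
    """Cluster per-pixel DNN features into one-hot pseudo-segmentation masks."""
    feats = feature_model.predict(images_batch)       # (N, H, W, C) feature maps
    flat = feats.reshape(-1, feats.shape[-1])         # one C-dim vector per pixel
    kmeans = clustering.Kmeans(num_clusters)
    kmeans.cluster(flat, verbose=False)
    labels = np.zeros(flat.shape[0])
    for clust_ind, pixel_ids in enumerate(kmeans.images_lists):
        labels[pixel_ids] = clust_ind                 # cluster id assigned to each pixel
    one_hot = to_categorical(labels, num_clusters)
    return one_hot.reshape(feats.shape[:3] + (num_clusters,))
```

These pseudo-masks are then fed back to the model as training targets for the `clust_output` head with a categorical cross-entropy loss.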
9 |
10 |
11 |
12 | ### Prerequisites
13 |
14 | ```
15 | Keras 2.2.2
16 | faiss 1.4.0
17 | ```
18 |
19 | ### Installing
20 |
21 | All library requirements needed to run the scripts are listed in the requirements.txt file.
22 |
23 | To install them in your local environment, use the command:
24 |
25 | ```
26 | pip install -r requirements.txt
27 | ```
28 |
29 | ## Running the experiments
30 |
31 | 1. Generate the train and, optionally, test datasets of equally sized image patches from the original images. To do this, call tools/patch_generator.py from the root of the repository.
32 | For example, to generate patches of size 512×512 with no overlap, use the command:
33 | ```
34 | python ./tools/patch_generator.py <input_folder> --output_folder <output_folder> --patch_shape 512 512 --stride 512 512
35 | ```
36 |
37 | 2. To train the Keras model in an unsupervised fashion, run:
38 | ```
39 | python ./main.py <train_img_folder>
40 | ```
41 | *Note*: in this work we use ImageDataGenerator from keras.preprocessing.image, which requires that your images are stored in subfolders one level deeper than the folder you specified as <train_img_folder>.
42 |
43 | ## Inspired by
44 | *Deep Clustering for Unsupervised Learning of Visual Features* - [DeepCluster](https://github.com/facebookresearch/deepcluster)
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/interactive_tool/clusters_to_class.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import skimage.io
7 | import os
8 | import click
9 |
10 |
11 | @click.command()
12 | @click.argument('base_image_file', type=click.STRING)
13 | @click.argument('pos_neg_mask_file', type=click.STRING)
14 | @click.argument('masks_folder', type=click.STRING)
15 | @click.option('--output_folder', type=click.STRING, default='')
16 | def main(base_image_file, pos_neg_mask_file, masks_folder, output_folder):
17 |     base_image = skimage.io.imread(base_image_file)
18 |     pos_neg_mask = np.load(pos_neg_mask_file)
19 |
20 |     pos_clusts = np.unique(base_image[pos_neg_mask == 1])
21 |     neg_clusts = np.unique(base_image[pos_neg_mask == -1])
22 |
23 |     intersect = np.intersect1d(pos_clusts, neg_clusts)
24 |
25 |     for clust in intersect:
26 |         pos = np.sum((base_image == clust) * (pos_neg_mask == 1))
27 |         neg = np.sum((base_image == clust) * (pos_neg_mask == -1))
28 |         pos_ratio = pos / (pos + neg)
29 |
30 |         if pos_ratio < 0.3:  # drop clusters whose strokes are mostly negative (same threshold as inter_cluster.py)
31 |             index = np.argwhere(pos_clusts == clust)
32 |             pos_clusts = np.delete(pos_clusts, index)
33 |
34 |     for file in os.listdir(masks_folder):
35 |         file_path = os.path.join(masks_folder, file)
36 |         mask = skimage.io.imread(file_path).astype('float32')
37 |
38 |         if output_folder != '':
39 |             out_file_path = os.path.join(output_folder, file)
40 |             skimage.io.imsave(out_file_path, np.isin(mask, pos_clusts).astype('uint8'))  # cast: imsave cannot write bool arrays
41 |
42 |
43 | if __name__ == "__main__":
44 |     main()
45 |
--------------------------------------------------------------------------------
/interactive_tool/crop_img.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import click
6 | import os
7 | import skimage.io
8 | import skimage.util
9 | import numpy
10 | import math
11 |
12 | @click.command()
13 | @click.argument('input_images', type=click.STRING)
14 | @click.option('--crop_shape', nargs=2, type=click.INT,
default=(309, 309)) 15 | @click.option('--output_folder', type=click.STRING) 16 | def main(input_images, crop_shape, output_folder): 17 | assert os.path.isdir(input_images) 18 | 19 | height_crop = int((512 - crop_shape[0])) / 2 20 | width_crop = int((512 - crop_shape[1])) / 2 21 | 22 | for file in os.listdir(input_images): 23 | image_file = os.path.join(input_images, file) 24 | image = skimage.io.imread(image_file) 25 | 26 | crop_img = skimage.util.crop(image, ((math.ceil(height_crop), int(height_crop)), (math.ceil(width_crop), int(width_crop)))) 27 | 28 | out_image_file = os.path.join(output_folder, file) 29 | skimage.io.imsave(out_image_file, crop_img) 30 | 31 | return 0 32 | 33 | 34 | if __name__ == "__main__": 35 | main() -------------------------------------------------------------------------------- /interactive_tool/inter_cluster.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import skimage.io 6 | import numpy as np 7 | from keras import Model 8 | from tkinter import * 9 | from PIL import ImageTk, Image 10 | 11 | import matplotlib.pyplot as plt 12 | from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg 13 | from matplotlib.figure import Figure 14 | import matplotlib.image as mpimg 15 | import skimage.io 16 | 17 | from models.UNetValid import get_model 18 | from utils import clustering 19 | 20 | img_path = '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/crop_img.png' 21 | 22 | 23 | class Paint(object): 24 | 25 | def __init__(self): 26 | self.root = Tk() 27 | self.mask_clust = Toplevel() 28 | 29 | self.img_np = skimage.io.imread(img_path) 30 | img = Image.fromarray(self.img_np) 31 | self.img = ImageTk.PhotoImage(image=img) 32 | frame = Frame(self.root, width=self.img.width(), height=self.img.height()) 33 | frame.grid(row=1, columnspan=2) 34 | self.canvas = Canvas(frame, bg='#FFFFFF', width=512, height=512, 35 | scrollregion=(0, 0, self.img.width(), self.img.height())) 36 | hbar = Scrollbar(frame, orient=HORIZONTAL) 37 | hbar.pack(side=BOTTOM, fill=X) 38 | hbar.config(command=self.canvas.xview) 39 | vbar = Scrollbar(frame, orient=VERTICAL) 40 | vbar.pack(side=RIGHT, fill=Y) 41 | vbar.config(command=self.canvas.yview) 42 | self.canvas.config(width=512, height=512) 43 | self.canvas.config(xscrollcommand=hbar.set, yscrollcommand=vbar.set) 44 | self.canvas.pack(side=LEFT, expand=True, fill=BOTH) 45 | 46 | # self.c = Canvas(self.root, bg='black', width=self.img.width(), height=self.img.height()) 47 | # self.c.grid(row=1, columnspan=5) 48 | 49 | self.done_button = Button(self.root, text='DONE', command=self.done) 50 | self.done_button.grid(row=0, column=1) 51 | 52 | self.pos_stroke = BooleanVar() 53 | self.stroke_type_button = Checkbutton(self.root, text='positive', variable=self.pos_stroke) 54 | self.stroke_type_button.grid(row=0, column=0) 55 | 56 | self.fig_mask_clust = Figure(figsize=(4, 4)) 57 | self.subplot_mask = self.fig_mask_clust.add_subplot(111) 58 | self.subplot_mask.axis('off') 59 | 60 | self.mask_clust_c = FigureCanvasTkAgg(self.fig_mask_clust, master=self.mask_clust) 61 | self.mask_clust_c.show() 62 | self.mask_clust_c.get_tk_widget().pack(side=TOP, fill=BOTH, expand=1) 63 | 64 | self.setup() 65 | self.root.mainloop() 66 | 67 | def setup(self): 68 | self.old_x = None 69 | self.old_y = None 70 | 71 | self.num_clusters = 100 72 | 73 | model_weights_path = 
'/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/results/38/model.h5'
74 |         img_path = '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/raw/test/images/JNCASR_Overall_Inducer_123_Folder_F_30_ebss_07_R3D_D3D_PRJ.png'
75 |         self.image = skimage.io.imread(img_path).astype('float32')
76 |
77 |         self.image = np.expand_dims(self.image, -1)
78 |         # mean = np.array([9.513245, 10.786118, 10.060172], dtype=np.float32).reshape(1, 1, 3)
79 |         # std = np.array([18.76071, 19.97462, 19.616455], dtype=np.float32).reshape(1, 1, 3)
80 |         mean = np.array([30.17722], dtype='float32').reshape(1, 1, 1)
81 |         std = np.array([33.690296], dtype='float32').reshape(1, 1, 1)
82 |         self.image -= mean
83 |         self.image /= std
84 |         # self.image /= 255
85 |
86 |         self.model = get_model(self.image.shape, self.num_clusters)
87 |         self.model.load_weights(model_weights_path)
88 |
89 |         before_last_layer_model = Model(inputs=self.model.input, outputs=self.model.get_layer('for_clust').output)
90 |
91 |         self.x = before_last_layer_model.predict(np.expand_dims(self.image, axis=0), verbose=1, batch_size=1)
92 |
93 |         self.pos_neg_mask = np.zeros((self.x.shape[1], self.x.shape[2]))
94 |         # self.pos_neg_mask = np.load(
95 |         #     '/home/zhygallo/zhygallo/inria/cluster_seg/deepcluster_segmentation/data/patches_512_lux/pos_neg_mask.npy')
96 |         self.x_flat = np.reshape(self.x, (self.x.shape[0] * self.x.shape[1] * self.x.shape[2], self.x.shape[-1]))
97 |
98 |         self.n_pix, self.dim_pix = self.x_flat.shape[0], self.x_flat.shape[1]
99 |
100 |         self.deepcluster = clustering.Kmeans(self.num_clusters)
101 |         self.centroids = self.deepcluster.cluster(self.x_flat, verbose=True, init_cents=None)
102 |
103 |         masks_flat = np.zeros((self.n_pix, 1))
104 |         for clust_ind, clust in enumerate(self.deepcluster.images_lists):
105 |             masks_flat[clust] = clust_ind
106 |         self.masks = masks_flat.reshape((self.x.shape[1], self.x.shape[2]))
107 |
108 |         self.count = 1
109 |         skimage.io.imsave(
110 |             '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/binar_pred/clust_%i.png' % self.count,
111 |             self.masks.astype('uint32'))
112 |         self.subplot_mask.imshow(self.masks)
113 |         self.canvas.create_image(0, 0, image=self.img, anchor=NW)
114 |
115 |         self.canvas.bind('<B1-Motion>', self.paint)        # paint while the left button is dragged
116 |         self.canvas.bind('<ButtonRelease-1>', self.reset)  # end the stroke on release
117 |
118 |     def activate_button(self, some_button):
119 |         self.active_button.config(relief=RAISED)
120 |         some_button.config(relief=SUNKEN)
121 |         self.active_button = some_button
122 |
123 |     def paint(self, event):
124 |         self.line_width = 3.0
125 |         if self.pos_stroke.get():
126 |             self.color = 'blue'
127 |         else:
128 |             self.color = 'red'
129 |         paint_color = self.color
130 |         if self.old_x and self.old_y:
131 |             self.canvas.create_line(self.old_x, self.old_y,
132 |                                     self.canvas.canvasx(event.x), self.canvas.canvasy(event.y),
133 |                                     width=self.line_width, fill=paint_color,
134 |                                     capstyle=ROUND, smooth=TRUE, splinesteps=36)
135 |         self.old_x = self.canvas.canvasx(event.x)
136 |         self.old_y = self.canvas.canvasy(event.y)
137 |         if self.pos_stroke.get():
138 |             self.pos_neg_mask[int(self.canvas.canvasy(event.y)), int(self.canvas.canvasx(event.x))] = 1
139 |         else:
140 |             self.pos_neg_mask[int(self.canvas.canvasy(event.y)), int(self.canvas.canvasx(event.x))] = -1
141 |
142 |     def reset(self, event):
143 |         self.old_x, self.old_y = None, None
144 |
145 |     def done(self):
146 |         np.save('/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/pos_neg_mask_transfer.npy',
147 |                 self.pos_neg_mask)
148 |         pos_clusts = np.unique(self.masks[self.pos_neg_mask == 1])
neg_clusts = np.unique(self.masks[self.pos_neg_mask == -1]) 150 | 151 | intersect = np.intersect1d(pos_clusts, neg_clusts) 152 | 153 | for clust in intersect: 154 | pos = np.sum((self.masks == clust) * (self.pos_neg_mask == 1)) 155 | neg = np.sum((self.masks == clust) * (self.pos_neg_mask == -1)) 156 | pos_ratio = pos / (pos + neg) 157 | neg_ratio = neg / (pos + neg) 158 | 159 | if pos_ratio < 0.3 or neg_ratio < 0.3: 160 | continue 161 | 162 | self.centroids = np.delete(self.centroids, clust, axis=0) 163 | neg_centroid = self.x[0][((self.masks == clust) * 164 | (self.pos_neg_mask == -1))].mean(axis=0).reshape(1, self.centroids.shape[1]) 165 | pos_centroid = self.x[0][((self.masks == clust) * 166 | (self.pos_neg_mask == 1))].mean(axis=0).reshape(1, self.centroids.shape[1]) 167 | 168 | self.centroids = np.concatenate((self.centroids, neg_centroid, pos_centroid), axis=0) 169 | 170 | self.deepcluster = clustering.Kmeans(self.centroids.shape[0]) 171 | self.centroids = self.deepcluster.cluster(self.x_flat, verbose=True, init_cents=self.centroids) 172 | 173 | masks_flat = np.zeros((self.n_pix, 1)) 174 | for clust_ind, clust in enumerate(self.deepcluster.images_lists): 175 | masks_flat[clust] = clust_ind 176 | self.masks = masks_flat.reshape((self.x.shape[1], self.x.shape[2])) 177 | 178 | pos_clusts = np.unique(self.masks[self.pos_neg_mask == 1]) 179 | neg_clusts = np.unique(self.masks[self.pos_neg_mask == -1]) 180 | 181 | intersect = np.intersect1d(pos_clusts, neg_clusts) 182 | 183 | pos_neg_arr = np.zeros(self.masks.shape) 184 | pos_neg_arr[np.isin(self.masks, pos_clusts)] = 1 185 | pos_neg_arr[np.isin(self.masks, neg_clusts)] = 2 186 | pos_neg_arr[np.isin(self.masks, intersect)] = 3 187 | self.subplot_mask.imshow(pos_neg_arr) 188 | 189 | self.mask_clust_c.show() 190 | 191 | pos_arr = np.zeros(self.masks.shape) 192 | pos_arr[np.isin(self.masks, pos_clusts)] = 1 193 | for clust in intersect: 194 | pos = np.sum((self.masks == clust) * (self.pos_neg_mask == 1)) 195 | neg = np.sum((self.masks == clust) * (self.pos_neg_mask == -1)) 196 | pos_ratio = pos / (pos + neg) 197 | 198 | if pos_ratio < 0.3: 199 | pos_arr[self.masks == clust] = 0 200 | 201 | # np.save('/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/pos_arr.npy', pos_clusts) 202 | 203 | skimage.io.imsave( 204 | '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/binar_pred/pos_%i.png' % self.count, 205 | np.isin(self.masks, pos_clusts)) 206 | skimage.io.imsave( 207 | '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/binar_pred/only_pos_%i.png' % self.count, 208 | np.isin(self.masks, pos_clusts) * (1 - np.isin(self.masks, neg_clusts))) 209 | skimage.io.imsave( 210 | '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/binar_pred/filt_pos_%i.png' % self.count, 211 | pos_arr.astype('uint16')) 212 | self.count += 1 213 | 214 | 215 | def main(): 216 | Paint() 217 | 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import click 7 | import os 8 | from keras.preprocessing.image import ImageDataGenerator 9 | from keras.utils import to_categorical 10 | from keras.optimizers import Adam 11 | from keras import Model 12 | from keras.callbacks import ModelCheckpoint 13 | from keras.callbacks 
import Callback
14 | import keras.backend as K
15 |
16 | from models.UNetValid import get_model
17 | from utils import clustering
18 | from tools.evaluation import final_prediction
19 |
20 | @click.command()
21 | @click.argument('train_img_folder', type=click.STRING)
22 | @click.option('--test_img_folder', type=click.STRING, default='')
23 | @click.option('--img_shape', nargs=2, type=click.INT, default=(256, 256))
24 | @click.option('--num_clusters', type=click.INT, default=100)
25 | @click.option('--num_epochs', type=click.INT, default=100)
26 | @click.option('--learn_rate', type=click.FLOAT, default=1e-4)
27 | @click.option('--clust_batch_size', type=click.INT, default=32)
28 | @click.option('--train_batch_size', type=click.INT, default=4)
29 | @click.option('--current_epoch', type=click.INT, default=0)
30 | @click.option('--pretrained_weights', type=click.STRING, default='')
31 | @click.option('--out_weights_file', type=click.STRING, default='')
32 | @click.option('--out_pred_masks_test', type=click.STRING, default='')
33 | def main(train_img_folder, test_img_folder, img_shape, num_clusters, num_epochs, learn_rate,
34 |          clust_batch_size, train_batch_size, current_epoch, pretrained_weights, out_weights_file,
35 |          out_pred_masks_test):
36 |     assert os.path.isdir(train_img_folder)
37 |     if train_img_folder[-1] != '/':
38 |         train_img_folder += '/'
39 |     assert os.path.isdir(test_img_folder)  # validate the folder before indexing test_img_folder[-1]
40 |     if test_img_folder[-1] != '/':
41 |         test_img_folder += '/'
42 |
43 |     train_image_datagen = ImageDataGenerator(
44 |         featurewise_center=True,
45 |         featurewise_std_normalization=True,
46 |         data_format='channels_last',
47 |     )
48 |
49 |     # dataset mean/std, precomputed on the (512, 512) training patches
50 |     train_image_datagen.mean = np.array([30.17722], dtype='float32').reshape(1, 1, 1)
51 |     train_image_datagen.std = np.array([33.690296], dtype='float32').reshape(1, 1, 1)
52 |
53 |     SEED = 1
54 |     train_image_generator = train_image_datagen.flow_from_directory(train_img_folder,
55 |                                                                     color_mode='grayscale',
56 |                                                                     target_size=img_shape,
57 |                                                                     batch_size=clust_batch_size,
58 |                                                                     class_mode=None,
59 |                                                                     shuffle=True,
60 |                                                                     seed=SEED)
61 |
62 |
63 |     test_image_datagen = ImageDataGenerator(
64 |         featurewise_center=True,
65 |         featurewise_std_normalization=True,
66 |         data_format='channels_last',
67 |     )
68 |
69 |     test_image_datagen.mean = train_image_datagen.mean
70 |     test_image_datagen.std = train_image_datagen.std
71 |
72 |
73 |     test_image_generator = test_image_datagen.flow_from_directory(test_img_folder,
74 |                                                                   color_mode='grayscale',
75 |                                                                   target_size=img_shape,
76 |                                                                   batch_size=32,
77 |                                                                   class_mode=None,
78 |                                                                   shuffle=False)
79 |
80 |     num_images = train_image_generator.n
81 |     model = get_model(train_image_generator.image_shape, num_classes=num_clusters)
82 |
83 |     alpha = K.variable(1.)
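    # `alpha` weights the clustering loss below; as a Keras backend variable it
    # could be re-scaled between epochs with K.set_value(alpha, ...) without
    # rebuilding the graph, though it stays fixed at 1 here.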
84 | losses = {'clust_output': 'categorical_crossentropy'} 85 | loss_weights = {'clust_output': alpha} 86 | 87 | before_last_layer_model = Model(inputs=model.input, outputs=model.get_layer('for_clust').output) 88 | 89 | if pretrained_weights != '': 90 | # for layer in model.layers[-1:]: 91 | # layer.name += '_lux' 92 | model.load_weights(pretrained_weights, by_name=True) 93 | 94 | deepcluster = clustering.Kmeans(num_clusters) 95 | 96 | while current_epoch <= num_epochs: 97 | print('\n############# Epoch %i / %i #############' % ((current_epoch), num_epochs)) 98 | num_iter = int(np.ceil(num_images / float(clust_batch_size))) 99 | for i in range(num_iter): 100 | print('%i:%i/%i' % (i * clust_batch_size, (i + 1) * clust_batch_size, num_images)) 101 | 102 | images_batch = train_image_generator[i] 103 | 104 | print('Prediction:') 105 | features_batch = before_last_layer_model.predict(images_batch, verbose=1, batch_size=train_batch_size) 106 | 107 | flat_feat_batch = features_batch.reshape(-1, features_batch.shape[-1]) 108 | print('Clustering:') 109 | deepcluster.cluster(flat_feat_batch, verbose=True) 110 | 111 | masks_batch_flat = np.zeros((flat_feat_batch.shape[0], 1)) 112 | for clust_ind, clust in enumerate(deepcluster.images_lists): 113 | masks_batch_flat[clust] = clust_ind 114 | masks_batch_flat = to_categorical(masks_batch_flat, num_clusters) 115 | masks_batch = masks_batch_flat.reshape((features_batch.shape[0], features_batch.shape[1], 116 | features_batch.shape[2], num_clusters)) 117 | 118 | masks = {'clust_output': masks_batch} 119 | 120 | print('Training:') 121 | 122 | model.compile(optimizer=Adam(lr=(learn_rate)), loss=losses, loss_weights=loss_weights) 123 | model.fit(images_batch, masks, batch_size=train_batch_size, epochs=1, verbose=1) 124 | 125 | if out_pred_masks_test != '': 126 | out_pred_epoch_fold = os.path.join(out_pred_masks_test, str(current_epoch)) 127 | final_pred_model = Model(inputs=model.input, outputs=model.get_layer('clust_output').output) 128 | final_prediction(out_pred_epoch_fold, test_image_generator, final_pred_model) 129 | model.save_weights(os.path.join(out_pred_epoch_fold, 'model.h5')) 130 | 131 | current_epoch += 1 132 | 133 | 134 | if __name__ == '__main__': 135 | main() -------------------------------------------------------------------------------- /models/ContextEncoder.py: -------------------------------------------------------------------------------- 1 | """Keras implementation of UNet model""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import math 7 | from keras.models import Model 8 | from keras.layers import Input, Conv2D, UpSampling2D, Concatenate, Cropping2D, BatchNormalization, Activation, MaxPooling2D 9 | from keras import backend as K 10 | 11 | K.set_image_data_format('channels_last') 12 | 13 | 14 | def get_model(img_shape, num_classes=2): 15 | inputs = Input(shape=img_shape) 16 | 17 | conv1_1 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(inputs) 18 | conv1_2 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv1_1) 19 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1_2) 20 | 21 | conv2_1 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool1) 22 | conv2_2 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv2_1) 23 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2_2) 24 | 25 | conv3_1 = Conv2D(256, 3, activation='relu', 
padding='valid', kernel_initializer='he_normal')(pool2) 26 | conv3_2 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv3_1) 27 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3_2) 28 | 29 | conv4_1 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool3) 30 | conv4_2 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv4_1) 31 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4_2) 32 | 33 | conv5_1 = Conv2D(1024, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool4) 34 | conv5_2 = Conv2D(1024, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv5_1) 35 | 36 | up6 = Conv2D(512, 2, activation='relu', padding='valid', kernel_initializer='he_normal')( 37 | UpSampling2D(size=(2, 2))(conv5_2)) 38 | height_dif = int((conv4_2.shape[1] - up6.shape[1])) / 2 39 | width_dif = int((conv4_2.shape[2] - up6.shape[2])) / 2 40 | crop_conv4_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv4_2) 41 | concat6 = Concatenate()([crop_conv4_2, up6]) 42 | conv6_1 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat6) 43 | conv6_2 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv6_1) 44 | 45 | up7 = Conv2D(256, 2, activation='relu', padding='valid', kernel_initializer='he_normal')( 46 | UpSampling2D(size=(2, 2))(conv6_2)) 47 | height_dif = int((conv3_2.shape[1] - up7.shape[1])) / 2 48 | width_dif = int((conv3_2.shape[2] - up7.shape[2])) / 2 49 | crop_conv3_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv3_2) 50 | concat7 = Concatenate()([crop_conv3_2, up7]) 51 | conv7_1 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat7) 52 | conv7_2 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv7_1) 53 | 54 | up8 = Conv2D(128, 2, activation='relu', padding='valid', kernel_initializer='he_normal')( 55 | UpSampling2D(size=(2, 2))(conv7_2)) 56 | height_dif = int((conv2_2.shape[1] - up8.shape[1])) / 2 57 | width_dif = int((conv2_2.shape[2] - up8.shape[2])) / 2 58 | crop_conv2_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv2_2) 59 | concat8 = Concatenate()([crop_conv2_2, up8]) 60 | conv8_1 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat8) 61 | conv8_2 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv8_1) 62 | 63 | up9 = Conv2D(64, 2, activation='relu', padding='valid', kernel_initializer='he_normal')( 64 | UpSampling2D(size=(2, 2))(conv8_2)) 65 | height_dif = int((conv1_2.shape[1] - up9.shape[1])) / 2 66 | width_dif = int((conv1_2.shape[2] - up9.shape[2])) / 2 67 | crop_conv1_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv1_2) 68 | concat9 = Concatenate()([crop_conv1_2, up9]) 69 | conv9_1 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat9) 70 | conv9_2 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal', name='for_clust')(conv9_1) 71 | 72 | conv10 = Conv2D(num_classes, (1, 1), activation='softmax', name='clust_output')(conv9_2) 73 | 74 | conv11 = Conv2D(3, (1, 1), activation='sigmoid', name='rgb_output')(conv9_2) 75 | 76 | 
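    # Two output heads share the 'for_clust' features: 'clust_output' is a softmax
    # over cluster pseudo-labels, and 'rgb_output' is a sigmoid RGB reconstruction.
    # super_train.py pretrains this model in the spirit of Context Encoders,
    # blanking the image centre with crop_center() and regressing it with an MSE loss.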
model = Model(inputs=inputs, outputs=[conv10, conv11]) 77 | 78 | return model 79 | 80 | 81 | def main(): 82 | model = get_model(img_shape=(512, 512, 3)) 83 | return 0 84 | 85 | 86 | if __name__ == "__main__": 87 | main() -------------------------------------------------------------------------------- /models/UNetValid.py: -------------------------------------------------------------------------------- 1 | """Keras implementation of UNet model""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import math 7 | from keras.models import Model 8 | from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, Cropping2D 9 | from keras import backend as K 10 | 11 | K.set_image_data_format('channels_last') 12 | 13 | 14 | def get_model(img_shape, num_classes=2): 15 | inputs = Input(shape=img_shape) 16 | 17 | conv1_1 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(inputs) 18 | conv1_2 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv1_1) 19 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1_2) 20 | 21 | conv2_1 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool1) 22 | conv2_2 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv2_1) 23 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2_2) 24 | 25 | conv3_1 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool2) 26 | conv3_2 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv3_1) 27 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3_2) 28 | 29 | conv4_1 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool3) 30 | conv4_2 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv4_1) 31 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4_2) 32 | 33 | conv5_1 = Conv2D(1024, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool4) 34 | conv5_2 = Conv2D(1024, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv5_1) 35 | 36 | 37 | up6 = Conv2D(512, 2, activation='relu', padding='valid', kernel_initializer='he_normal')( 38 | UpSampling2D(size=(2, 2))(conv5_2)) 39 | height_dif = int((conv4_2.shape[1] - up6.shape[1])) / 2 40 | width_dif = int((conv4_2.shape[2] - up6.shape[2])) / 2 41 | crop_conv4_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv4_2) 42 | concat6 = Concatenate()([crop_conv4_2, up6]) 43 | conv6_1 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat6) 44 | conv6_2 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv6_1) 45 | 46 | up7 = Conv2D(256, 2, activation='relu', padding='valid', kernel_initializer='he_normal')( 47 | UpSampling2D(size=(2, 2))(conv6_2)) 48 | height_dif = int((conv3_2.shape[1] - up7.shape[1])) / 2 49 | width_dif = int((conv3_2.shape[2] - up7.shape[2])) / 2 50 | crop_conv3_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv3_2) 51 | concat7 = Concatenate()([crop_conv3_2, up7]) 52 | conv7_1 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat7) 53 | conv7_2 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv7_1) 54 | 55 | up8 = Conv2D(128, 2, 
activation='relu', padding='valid', kernel_initializer='he_normal')( 56 | UpSampling2D(size=(2, 2))(conv7_2)) 57 | height_dif = int((conv2_2.shape[1] - up8.shape[1])) / 2 58 | width_dif = int((conv2_2.shape[2] - up8.shape[2])) / 2 59 | crop_conv2_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv2_2) 60 | concat8 = Concatenate()([crop_conv2_2, up8]) 61 | conv8_1 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat8) 62 | conv8_2 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv8_1) 63 | 64 | up9 = Conv2D(64, 2, activation='relu', padding='valid', kernel_initializer='he_normal')( 65 | UpSampling2D(size=(2, 2))(conv8_2)) 66 | height_dif = int((conv1_2.shape[1] - up9.shape[1])) / 2 67 | width_dif = int((conv1_2.shape[2] - up9.shape[2])) / 2 68 | crop_conv1_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv1_2) 69 | concat9 = Concatenate()([crop_conv1_2, up9]) 70 | conv9_1 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat9) 71 | conv9_2 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal', name='for_clust')(conv9_1) 72 | 73 | conv10 = Conv2D(num_classes, (1, 1), activation='softmax', name='clust_output')(conv9_2) 74 | 75 | # conv11 = Conv2D(2, (1, 1), activation='softmax', name='rgb_output')(conv9_2) 76 | 77 | 78 | model = Model(inputs=inputs, outputs=[conv10]) 79 | 80 | return model 81 | -------------------------------------------------------------------------------- /models/VGG_UNet.py: -------------------------------------------------------------------------------- 1 | """Keras implementation of UNet model""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import math 7 | from keras.models import Model 8 | from keras.layers import Conv2D, UpSampling2D, Concatenate, Cropping2D, BatchNormalization, Activation, MaxPooling2D 9 | from keras.applications.vgg16 import VGG16 10 | from keras import backend as K 11 | 12 | K.set_image_data_format('channels_last') 13 | 14 | 15 | def get_model(img_shape, num_classes=2): 16 | class_model = VGG16(weights='imagenet', include_top=False, input_shape=img_shape) 17 | 18 | class_model.get_layer('block1_conv1').trainable = False 19 | class_model.get_layer('block1_conv2').trainable = False 20 | class_model.get_layer('block2_conv1').trainable = False 21 | class_model.get_layer('block2_conv2').trainable = False 22 | 23 | class_model = Model(inputs=class_model.inputs, outputs=class_model.get_layer('block3_pool').output) 24 | 25 | skip1 = class_model.get_layer('block3_conv3').output 26 | skip2 = class_model.get_layer('block2_conv2').output 27 | skip3 = class_model.get_layer('block1_conv2').output 28 | 29 | conv4_1 = Conv2D(512, 3, padding='valid')(class_model.output) 30 | batch4_1 = BatchNormalization()(conv4_1) 31 | act4_1 = Activation('relu')(batch4_1) 32 | conv4_2 = Conv2D(512, 3, padding='valid')(act4_1) 33 | batch4_2 = BatchNormalization()(conv4_2) 34 | act4_2 = Activation('relu')(batch4_2) 35 | 36 | # up5 = UpSampling2D(size=(2, 2))(act4_2) 37 | up5 = Conv2D(256, 2, activation='relu', padding='valid')( 38 | UpSampling2D(size=(2, 2))(act4_2)) 39 | height_dif = int((skip1.shape[1] - up5.shape[1])) / 2 40 | width_dif = int((skip1.shape[2] - up5.shape[2])) / 2 41 | crop_skip1 = 
Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), 42 | (math.ceil(width_dif), int(width_dif))))(skip1) 43 | concat5 = Concatenate()([crop_skip1, up5]) 44 | conv5_1 = Conv2D(256, 3, padding='valid')(concat5) 45 | batch5_1 = BatchNormalization()(conv5_1) 46 | act5_1 = Activation('relu')(batch5_1) 47 | conv5_2 = Conv2D(256, 3, padding='valid')(act5_1) 48 | batch5_2 = BatchNormalization()(conv5_2) 49 | act5_2 = Activation('relu')(batch5_2) 50 | 51 | # up6 = UpSampling2D(size=(2, 2))(act5_2) 52 | up6 = Conv2D(128, 2, activation='relu', padding='valid')( 53 | UpSampling2D(size=(2, 2))(act5_2)) 54 | height_dif = int((skip2.shape[1] - up6.shape[1])) / 2 55 | width_dif = int((skip2.shape[2] - up6.shape[2])) / 2 56 | crop_skip2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), 57 | (math.ceil(width_dif), int(width_dif))))(skip2) 58 | concat6 = Concatenate()([crop_skip2, up6]) 59 | conv6_1 = Conv2D(128, 3, padding='valid')(concat6) 60 | batch6_1 = BatchNormalization()(conv6_1) 61 | act6_1 = Activation('relu')(batch6_1) 62 | conv6_2 = Conv2D(128, 3, padding='valid')(act6_1) 63 | batch6_2 = BatchNormalization()(conv6_2) 64 | act6_2 = Activation('relu')(batch6_2) 65 | 66 | up7 = Conv2D(64, 2, activation='relu', padding='valid')( 67 | UpSampling2D(size=(2, 2))(act6_2)) 68 | height_dif = int((skip3.shape[1] - up7.shape[1])) / 2 69 | width_dif = int((skip3.shape[2] - up7.shape[2])) / 2 70 | crop_skip3 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), 71 | (math.ceil(width_dif), int(width_dif))))(skip3) 72 | concat7 = Concatenate()([crop_skip3, up7]) 73 | conv7_1 = Conv2D(64, 3, padding='valid')(concat7) 74 | batch7_1 = BatchNormalization()(conv7_1) 75 | act7_1 = Activation('relu')(batch7_1) 76 | conv7_2 = Conv2D(64, 3, padding='valid')(act7_1) 77 | batch7_2 = BatchNormalization()(conv7_2) 78 | act7_2 = Activation('relu', name='for_clust')(batch7_2) 79 | 80 | conv8 = Conv2D(num_classes, (1, 1), activation='softmax', name='clust_output')(act7_2) 81 | 82 | # conv9 = Conv2D(3, (1, 1), activation='sigmoid', name='rgb_output')(act7_2) 83 | 84 | model = Model(inputs=class_model.inputs, outputs=[conv8]) 85 | 86 | return model 87 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/UNetValid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/models/__pycache__/UNetValid.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.4.1 2 | affine==2.2.1 3 | astor==0.7.1 4 | attrs==18.2.0 5 | certifi==2018.11.29 6 | chardet==3.0.4 7 | click==6.7 8 | click-plugins==1.0.4 9 | 
cligj==0.4.0 10 | cloudpickle==0.5.6 11 | cycler==0.10.0 12 | dask==0.19.2 13 | decorator==4.3.0 14 | faiss==1.4.0 15 | Fiona==1.7.13 16 | gast==0.2.0 17 | grpcio==1.15.0 18 | h5py==2.8.0 19 | idna==2.7 20 | imageio==2.4.1 21 | Keras==2.2.2 22 | Keras-Applications==1.0.6 23 | Keras-Preprocessing==1.0.5 24 | kiwisolver==1.0.1 25 | lxml==4.2.5 26 | Markdown==2.6.11 27 | matplotlib==2.2.3 28 | mkl-fft==1.0.6 29 | mkl-random==1.0.1 30 | munch==2.3.2 31 | networkx==2.2 32 | numpy==1.15.2 33 | opencv-python==4.0.0.21 34 | pandas==0.23.4 35 | Pillow==5.3.0 36 | progressbar2==3.38.0 37 | protobuf==3.6.1 38 | pydensecrf==1.0rc3 39 | PyGraph==0.2.1 40 | PyMaxflow==1.2.11 41 | pyparsing==2.2.1 42 | python-dateutil==2.7.3 43 | python-utils==2.3.0 44 | pytz==2018.5 45 | PyWavelets==1.0.0 46 | PyYAML==3.13 47 | rasterio==1.0.4 48 | requests==2.19.1 49 | scikit-image==0.14.0 50 | scikit-learn==0.19.2 51 | scipy==1.1.0 52 | Shapely==1.6.4.post2 53 | six==1.11.0 54 | sklearn==0.0 55 | snuggs==1.4.2 56 | tensorboard==1.10.0 57 | tensorflow==1.10.1 58 | tensorlayer==1.10.1 59 | termcolor==1.1.0 60 | toolz==0.9.0 61 | tqdm==4.25.0 62 | urllib3==1.23 63 | Werkzeug==0.14.1 64 | wrapt==1.10.11 65 | -------------------------------------------------------------------------------- /super_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import click 7 | import os 8 | from keras.preprocessing.image import ImageDataGenerator 9 | from keras.utils import to_categorical 10 | from keras.optimizers import Adam 11 | from keras.callbacks import ModelCheckpoint 12 | from keras import Model 13 | 14 | from models.ContextEncoder import get_model 15 | from tools.evaluation import final_prediction 16 | 17 | def crop_center(image): 18 | crop_shape = (309, 309) 19 | crop_bef_h = (image.shape[0] - crop_shape[0]) // 2 20 | crop_bef_w = (image.shape[1] - crop_shape[1]) // 2 21 | out_img = image.copy() 22 | out_img[crop_bef_h:crop_bef_h + crop_shape[0], crop_bef_w:crop_bef_w + crop_shape[1]] = 0.5 23 | return out_img 24 | 25 | @click.command() 26 | @click.argument('train_img_folder', type=click.STRING) 27 | @click.argument('train_mask_folder', type=click.STRING) 28 | @click.option('--test_img_folder', type=click.STRING, default='') 29 | @click.option('--test_mask_folder', type=click.STRING, default='') 30 | @click.option('--in_img_shape', nargs=2, type=click.INT, default=(512, 512)) 31 | @click.option('--out_img_shape', nargs=2, type=click.INT, default=(309, 309)) 32 | @click.option('--num_classes', type=click.INT, default=800) 33 | @click.option('--num_epochs', type=click.INT, default=100) 34 | @click.option('--learn_rate', type=click.FLOAT, default=1e-3) 35 | @click.option('--batch_size', type=click.INT, default=4) 36 | @click.option('--current_epoch', type=click.INT, default=0) 37 | @click.option('--pretrained_weights', type=click.STRING, default='') 38 | @click.option('--out_weights_file', type=click.STRING, default='') 39 | @click.option('--out_pred_masks_test', type=click.STRING, default='') 40 | def main(train_img_folder, train_mask_folder, test_img_folder, test_mask_folder, 41 | in_img_shape, out_img_shape, num_classes, num_epochs, learn_rate, batch_size, 42 | current_epoch, pretrained_weights, out_weights_file, out_pred_masks_test): 43 | assert os.path.isdir(train_img_folder) 44 | if train_img_folder[-1] != '/': 45 | train_img_folder += 
'/' 46 | assert os.path.isdir(test_img_folder) 47 | if test_img_folder[-1] != '/': 48 | test_img_folder += '/' 49 | 50 | train_image_datagen = ImageDataGenerator( 51 | # featurewise_center=True, 52 | # featurewise_std_normalization=True, 53 | rescale=1./255, 54 | data_format='channels_last', 55 | preprocessing_function=crop_center 56 | ) 57 | 58 | train_mask_datagen = ImageDataGenerator( 59 | rescale=1. / 255, 60 | # featurewise_center=True, 61 | # featurewise_std_normalization=True, 62 | ) 63 | 64 | # (512, 512) lux 65 | # train_image_datagen.mean = np.array([11.4296465, 13.140564, 12.277675], dtype=np.float32).reshape(1, 1, 3) 66 | # train_image_datagen.std = np.array([27.561337, 29.903225, 29.525864], dtype=np.float32).reshape(1, 1, 3) 67 | 68 | SEED = 1 69 | train_image_gen = train_image_datagen.flow_from_directory(train_img_folder, 70 | target_size=in_img_shape, 71 | batch_size=batch_size, 72 | class_mode=None, 73 | shuffle=True, 74 | seed=SEED) 75 | 76 | train_mask_gen = train_mask_datagen.flow_from_directory(train_mask_folder, 77 | target_size=out_img_shape, 78 | batch_size=batch_size, 79 | # color_mode='grayscale', 80 | class_mode=None, 81 | shuffle=True, 82 | seed=SEED) 83 | 84 | test_image_datagen = ImageDataGenerator( 85 | # featurewise_center=True, 86 | # featurewise_std_normalization=True, 87 | rescale=1. / 255, 88 | data_format='channels_last', 89 | preprocessing_function=crop_center 90 | ) 91 | test_image_gen = test_image_datagen.flow_from_directory(test_img_folder, 92 | target_size=in_img_shape, 93 | batch_size=batch_size, 94 | class_mode=None, 95 | shuffle=False) 96 | 97 | # test_image_datagen.mean = train_image_datagen.mean 98 | # test_image_datagen.std = train_image_datagen.mean 99 | 100 | model = get_model(train_image_gen.image_shape, num_classes=num_classes) 101 | model.compile(optimizer=Adam(lr=(learn_rate)), loss='mse') 102 | 103 | if pretrained_weights != '': 104 | # for layer in model.layers[-2:]: 105 | # layer.name += '_lux' 106 | model.load_weights(pretrained_weights, by_name=True) 107 | 108 | callback_list = [] 109 | if out_weights_file != '': 110 | model_checkpoint = ModelCheckpoint(out_weights_file, period=1, save_weights_only=True) 111 | callback_list.append(model_checkpoint) 112 | 113 | while current_epoch <= num_epochs: 114 | print('\n############# Epoch %i / %i #############' % ((current_epoch), num_epochs)) 115 | steps_train_epoch = train_image_gen.n // batch_size 116 | for i in range(steps_train_epoch): 117 | train_img_batch = train_image_gen[i] 118 | train_mask_batch = train_mask_gen[i] 119 | # train_mask_batch = to_categorical(train_mask_gen[i], num_classes) 120 | 121 | model.fit(train_img_batch, train_mask_batch, batch_size=8, epochs=1, 122 | callbacks=callback_list, verbose=1) 123 | 124 | 125 | if out_pred_masks_test != '': 126 | out_pred_epoch_fold = os.path.join(out_pred_masks_test, str(current_epoch)) 127 | final_prediction(out_pred_epoch_fold, test_image_gen, model) 128 | model.save_weights(os.path.join(out_pred_epoch_fold, 'model.h5')) 129 | 130 | current_epoch += 1 131 | 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/tools/__init__.py -------------------------------------------------------------------------------- /tools/clusters2classes.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import click 6 | import skimage.io 7 | import numpy as np 8 | import random 9 | import os 10 | 11 | 12 | def gen_classes_from_clusters(in_clust_test, in_clust_train, in_gt_train, gt_fraction, keep_clust_threshold, 13 | out_fold, per_cluster_out_fold=''): 14 | assert os.path.isdir(in_clust_test) 15 | if in_clust_test[-1] != '/': 16 | in_clust_test += '/' 17 | assert os.path.isdir(in_clust_train) 18 | if in_clust_train[-1] != '/': 19 | in_clust_train += '/' 20 | assert os.path.isdir(in_gt_train) 21 | if in_gt_train[-1] != '/': 22 | in_gt_train += '/' 23 | assert type(gt_fraction) is float 24 | assert gt_fraction > 0 and gt_fraction <= 1 25 | assert type(keep_clust_threshold) is float 26 | assert keep_clust_threshold >= 0 and keep_clust_threshold <= 1 27 | if out_fold != '': 28 | if not os.path.isdir(out_fold): 29 | os.mkdir(out_fold) 30 | if per_cluster_out_fold != '': 31 | if not os.path.isdir(per_cluster_out_fold): 32 | os.mkdir(per_cluster_out_fold) 33 | 34 | all_clust_test_files = [] 35 | for dirpath, dirnames, filenames in os.walk(in_clust_test): 36 | all_clust_test_files.extend([os.path.join(dirpath, f) 37 | for f in filenames if f.endswith('.png')]) 38 | 39 | all_clust_train_files = [] 40 | for dirpath, dirnames, filenames in os.walk(in_clust_train): 41 | all_clust_train_files.extend([os.path.join(os.path.relpath(dirpath, in_clust_train), f) 42 | for f in filenames if f.endswith('.png')]) 43 | 44 | all_gt_train_files = [] 45 | for dirpath, dirnames, filenames in os.walk(in_gt_train): 46 | all_gt_train_files.extend([os.path.join(os.path.relpath(dirpath, in_gt_train), f) 47 | for f in filenames if f.endswith('.png')]) 48 | 49 | 50 | assert len(all_gt_train_files) * gt_fraction > 0 51 | assert len(all_clust_train_files) * gt_fraction > 0 52 | 53 | SEED = 1 54 | random.seed(SEED) 55 | random.shuffle(all_gt_train_files) 56 | num_gt = int(gt_fraction*len(all_gt_train_files)) 57 | sub_gt_files = all_gt_train_files[:num_gt] 58 | assert np.sum([gt_file in all_clust_train_files for gt_file in sub_gt_files]) == len(sub_gt_files) 59 | 60 | sub_gt_train_full_path = [os.path.join(in_gt_train, file) for file in sub_gt_files] 61 | sub_clust_train_full_path = [os.path.join(in_clust_train, file) for file in sub_gt_files] 62 | 63 | gt_train_masks = np.asarray(skimage.io.imread_collection(sub_gt_train_full_path)) 64 | clust_train_masks = np.asarray(skimage.io.imread_collection(sub_clust_train_full_path)) 65 | 66 | selected_clusters = [] 67 | for clust in np.unique(clust_train_masks): 68 | clust_pixel_num = (clust_train_masks == clust).sum() 69 | gt_pixel_num = (gt_train_masks[clust_train_masks == clust] == 1).sum() 70 | if gt_pixel_num / clust_pixel_num >= keep_clust_threshold: 71 | selected_clusters.append(clust) 72 | 73 | result_masks = [] 74 | for clust_test_file in all_clust_test_files: 75 | clust_image = skimage.io.imread(clust_test_file) 76 | result_mask = np.zeros(clust_image.shape) 77 | for c in selected_clusters: 78 | result_mask[clust_image == c] = 1 79 | result_masks.append(result_mask) 80 | 81 | if out_fold != '': 82 | rel_path = os.path.relpath(clust_test_file, in_clust_test) 83 | full_out_path = os.path.join(out_fold, rel_path) 84 | if not os.path.isdir(os.path.dirname(full_out_path)): 85 | os.makedirs(os.path.dirname(full_out_path)) 86 | skimage.io.imsave(full_out_path, 
result_mask.astype('uint8'))
87 |
88 |     return result_masks
89 |
90 |
91 | @click.command()
92 | @click.argument('in_clust_test', type=click.STRING)
93 | @click.argument('in_clust_train', type=click.STRING)
94 | @click.argument('in_gt_train', type=click.STRING)
95 | @click.option('--gt_fraction', type=click.FLOAT, default=1.0)  # use all available ground-truth masks by default
96 | @click.option('--keep_clust_threshold', type=click.FLOAT, default=0)
97 | @click.option('--out_fold', type=click.STRING, default='')
98 | @click.option('--per_cluster_out_fold', type=click.STRING, default='')
99 | def main(in_clust_test, in_clust_train, in_gt_train, gt_fraction,
100 |          keep_clust_threshold, out_fold, per_cluster_out_fold):
101 |     gen_classes_from_clusters(in_clust_test, in_clust_train, in_gt_train, gt_fraction=gt_fraction,
102 |                               keep_clust_threshold=keep_clust_threshold, out_fold=out_fold,
103 |                               per_cluster_out_fold=per_cluster_out_fold)
104 |
105 |
106 | if __name__ == '__main__':
107 |     main()
108 |
--------------------------------------------------------------------------------
/tools/evaluation.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import click
6 | import os
7 | import scipy.spatial
8 | import sklearn.metrics
9 | import skimage.io
10 | import json
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import pandas as pd
14 |
15 | from tools.clusters2classes import gen_classes_from_clusters
16 | from models.UNetValid import get_model
17 |
18 |
19 | def final_prediction(out_fold, image_generator, model):
20 |     num_iter = int(np.ceil(image_generator.n / float(image_generator.batch_size)))
21 |     batch_size = image_generator.batch_size
22 |
23 |     test_files = image_generator.filenames
24 |
25 |     print('Final Prediction:')
26 |     for i in range(num_iter):
27 |         print('%i:%i/%i' % (i * batch_size, (i + 1) * batch_size, image_generator.n))
28 |
29 |         if i < num_iter - 1:
30 |             batch_files = test_files[i * batch_size: (i + 1) * batch_size]
31 |         else:
32 |             batch_files = test_files[i * batch_size:]
33 |
34 |         batch_imgs = image_generator[i]
35 |
36 |         pred_mask_probs = model.predict(batch_imgs, batch_size=1, verbose=1)
37 |         for file, pred_mask_prob in zip(batch_files, pred_mask_probs):
38 |             file_dir = os.path.dirname(file)
39 |             file_name = file.split('/')[-1]
40 |             full_out_dir = os.path.join(out_fold, file_dir)
41 |             if not os.path.isdir(full_out_dir):
42 |                 os.makedirs(full_out_dir)
43 |             mask = pred_mask_prob.argmax(axis=-1).astype('uint16')
44 |             skimage.io.imsave(os.path.join(full_out_dir, file_name), mask)
45 |
46 |
47 | def evaluate(input_gt, input_pred_mask, output_file=''):
48 |     assert os.path.isdir(input_gt)
49 |     if input_gt[-1] != '/':
50 |         input_gt += '/'
51 |     assert os.path.isdir(input_pred_mask)
52 |     if input_pred_mask[-1] != '/':
53 |         input_pred_mask += '/'
54 |
55 |     result_dict = {}
56 |     all_gt_files = []
57 |     for dirpath, dirnames, filenames in os.walk(input_gt):
58 |         all_gt_files.extend([os.path.join(os.path.relpath(dirpath, input_gt), f)
59 |                              for f in filenames if f.endswith('.png')])
60 |
61 |     all_pred_files_full_path = [os.path.join(input_pred_mask, file) for file in all_gt_files]
62 |     all_gt_files_full_path = [os.path.join(input_gt, file) for file in all_gt_files]
63 |
64 |     all_gt = np.asarray(skimage.io.imread_collection(all_gt_files_full_path))
65 |     all_pred_mask = np.asarray(skimage.io.imread_collection(all_pred_files_full_path))
66 |
67 |     result_dict['all_dice'] =
1 - scipy.spatial.distance.dice(all_gt.flatten(), all_pred_mask.flatten()) 68 | result_dict['all_accuracy'] = sklearn.metrics.accuracy_score(all_gt.flatten(), all_pred_mask.flatten()) 69 | 70 | if output_file != '': 71 | with open(output_file, 'w') as fp: 72 | json.dump(result_dict, fp, indent=2, sort_keys=True) 73 | 74 | return result_dict 75 | 76 | 77 | def evaluate_per_epoch(): 78 | test = 'test' 79 | in_result_dict = '' 80 | out_figure = 'data/patches_512_lux/figure_%s.png' % test 81 | 82 | if in_result_dict == '': 83 | # in_gt_train_fold = 'data/patches_512_lux/gt_masks/train_masks' 84 | in_gt_train_fold = 'data/patches_512_lux/gt_masks/test_masks' 85 | in_gt_test_fold = 'data/patches_512_lux/gt_masks/%s_masks' % test 86 | in_pred_clust_fold = 'data/patches_512_lux/pred_masks' 87 | out_fold_final_preds = 'data/patches_512_lux/final_preds' 88 | out_result_file = 'data/patches_512_lux/result_dict.json' 89 | 90 | gt_fraction = 0.6 91 | keep_clust_threshold = 0.4 92 | 93 | epochs = list(range(20)) 94 | 95 | result = {} 96 | for epoch_dir in sorted(os.listdir(in_pred_clust_fold), key=int): 97 | if not int(epoch_dir) in epochs: 98 | continue 99 | epoch_dirpath = os.path.join(in_pred_clust_fold, epoch_dir) 100 | epoch_test_dirpath = os.path.join(epoch_dirpath) 101 | epoch_train_dirpath = os.path.join(epoch_dirpath) 102 | # epoch_test_dirpath = os.path.join(epoch_dirpath, test) 103 | # epoch_train_dirpath = os.path.join(epoch_dirpath, 'train') 104 | 105 | epoch_out_fold = os.path.join(out_fold_final_preds, epoch_dir) 106 | if not os.path.isdir(epoch_out_fold): 107 | os.makedirs(epoch_out_fold) 108 | 109 | if not set(os.listdir(epoch_test_dirpath)) < set(os.listdir(epoch_out_fold)): 110 | gen_classes_from_clusters(in_clust_test=epoch_test_dirpath, 111 | in_clust_train=epoch_train_dirpath, 112 | in_gt_train=in_gt_train_fold, 113 | gt_fraction=gt_fraction, 114 | keep_clust_threshold=keep_clust_threshold, 115 | out_fold=epoch_out_fold) 116 | 117 | result[epoch_dir] = evaluate(input_gt=in_gt_test_fold, input_pred_mask=epoch_out_fold) 118 | 119 | with open(out_result_file, 'w') as fp: 120 | json.dump(result, fp, indent=2, sort_keys=True) 121 | 122 | else: 123 | with open(in_result_dict, 'r') as json_data: 124 | result = json.load(json_data) 125 | 126 | df = pd.DataFrame({'epochs': list(result.keys()), 'dice': [item['all_dice'] for item in result.values()], 127 | 'accuracy': [item['all_accuracy'] for item in result.values()]}) 128 | 129 | fig, ax = plt.subplots(nrows=1, ncols=1) 130 | ax.plot('epochs', 'dice', data=df, marker='o', markerfacecolor='black', color='blue', linewidth=4) 131 | # ax.plot('epochs', 'accuracy', data=df, marker='o', markerfacecolor='black', color='green', linewidth=4) 132 | ax.set_xlabel('Epochs') 133 | ax.set_ylabel('Dice') 134 | fig.savefig(out_figure) 135 | 136 | 137 | @click.command() 138 | @click.option('--image_path', type=click.STRING) 139 | @click.option('--pred_mask_path', type=click.STRING, default='') 140 | @click.option('--model_weights', type=click.STRING, default='') 141 | @click.option('--num_clust', type=click.INT, default=2) 142 | @click.option('--gt_mask_path', type=click.STRING, default='') 143 | @click.option('--out_pred_mask', type=click.STRING, default='') 144 | def main(image_path, pred_mask_path, model_weights, num_clust, gt_mask_path, out_pred_mask): 145 | if pred_mask_path == '': 146 | image = skimage.io.imread(image_path).astype('float32') 147 | mean = np.array([11.4296465, 13.140564, 12.277675], dtype=np.float32).reshape(1, 1, 3) 148 | std = 
np.array([27.561337, 29.903225, 29.525864], dtype=np.float32).reshape(1, 1, 3)
149 |         image -= mean
150 |         image /= std
151 |         model = get_model(image.shape, num_clust)
152 |         model.load_weights(model_weights)
153 |
154 |         pred_mask = model.predict(np.expand_dims(image, axis=0)).argmax(axis=-1).astype('uint16')[0]
155 |     else:
156 |         pred_mask = skimage.io.imread(pred_mask_path)
157 |
158 |     gt_mask = skimage.io.imread(gt_mask_path)
159 |
160 |     dice = 1 - scipy.spatial.distance.dice(gt_mask.flatten(), pred_mask.flatten())
161 |     accuracy = sklearn.metrics.accuracy_score(gt_mask.flatten(), pred_mask.flatten())
162 |
163 |     overlap = gt_mask.flatten() * pred_mask.flatten()
164 |     union = gt_mask.flatten() + pred_mask.flatten()
165 |
166 |     IOU = overlap.sum() / float((union > 0).sum())
167 |
168 |     print('Dice Score: %f' % dice)
169 |     print('Accuracy: %f' % accuracy)
170 |     print('IoU: %f' % IOU)
171 |
172 |     if out_pred_mask != '' and pred_mask_path == '':
173 |         skimage.io.imsave(out_pred_mask, pred_mask)
174 |
175 |
176 | if __name__ == '__main__':
177 |     main()
178 |
--------------------------------------------------------------------------------
/tools/full_prediction.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import click
6 | import os
7 | import skimage.io
8 |
9 | from osgeo import gdal
10 |
11 |
12 | def array_to_tiff_file(image, output_file, geoprojection=None, geotransform=None):
13 |     """Takes images in the form of numpy arrays and saves them as images with spatial information
14 |     Args:
15 |         image: numpy array. input image of shape (num_channels, h, w), or (h, w) in the single-channel case
16 |         output_file: string. path to the output file
17 |         geoprojection: driver_in.GetProjection()
18 |         geotransform: driver_in.GetGeoTransform()
19 |     Returns:
20 |         -
21 |     """
22 |     driver = gdal.GetDriverByName("GTiff")
23 |
24 |     if image.ndim == 3:
25 |         n_channels, h, w = image.shape
26 |     elif image.ndim == 2:
27 |         h, w = image.shape
28 |         n_channels = 1
29 |     ds_out = driver.Create(output_file, w, h, n_channels, gdal.GDT_Float32)
30 |     if geoprojection and geotransform:
31 |         ds_out.SetProjection(geoprojection)
32 |         ds_out.SetGeoTransform(geotransform)
33 |
34 |     # write channels separately
35 |     for i in range(n_channels):
36 |         band = ds_out.GetRasterBand(i + 1)  # 1-based index
37 |         if n_channels > 1:
38 |             channel = image[i]
39 |         else:
40 |             channel = image
41 |         band.WriteArray(channel)
42 |         band.FlushCache()
43 |
44 |     del ds_out
45 |
46 |
47 | def get_prediction(in_full_img_fold, in_pred_mask_fold, crop_shape=(512, 512), out_fold=''):
48 |     """Georeferences predicted masks of crop_shape for the images in in_full_img_fold
49 |     Args:
50 |         in_full_img_fold: string. path to a folder with the full-size images to which the model was applied
51 |         out_fold: string.
--------------------------------------------------------------------------------
/tools/patch_generator.py:
--------------------------------------------------------------------------------
1 | """Script for generating patches"""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | 
6 | import os
7 | import click
8 | import skimage.io
9 | import skimage.util
10 | import numpy as np
11 | 
12 | # mean = np.array([11.4296465, 13.140564, 12.277675], dtype=np.float32).reshape(1, 1, 3)
13 | # std = np.array([27.561337, 29.903225, 29.525864], dtype=np.float32).reshape(1, 1, 3)
14 | 
15 | def get_patches(input_fold, output_folder='', output_np_folder='', masks_folder='',
16 |                 patch_shape=(512, 512), crop_shape=None, stride=(100, 100), keep_dark_region_ratio=0.01):
17 |     """
18 |     Generates cropped patches from the full-size input images
19 |     Args:
20 |         input_fold: string. path to the raw train data folder
21 |         output_folder: string. path to the output folder for generated patches
22 |         output_np_folder: string. path to the output folder with the numpy files
23 |         masks_folder: string. if not '', also generates the mask patch corresponding to each image patch
24 |         patch_shape: tuple of ints. (height, width), specifies patch shape
25 |         crop_shape: tuple of ints. (height, width), specifies the area cropped from the patch.
26 |         stride: tuple of ints. stride in height and width direction of a sliding window
27 |         keep_dark_region_ratio: float. maximum fraction of black pixels for which a patch is still kept
28 |     Returns:
29 |         lists of numpy arrays. image patches and mask patches of the last processed image
30 |     """
31 | 
32 |     assert os.path.isdir(input_fold)
33 |     if input_fold[-1] != '/':
34 |         input_fold += '/'
35 |     if output_folder != '':
36 |         if not os.path.isdir(output_folder):
37 |             os.mkdir(output_folder)
38 |         if output_folder[-1] != '/':
39 |             output_folder += '/'
40 |     if output_np_folder != '':
41 |         assert os.path.isdir(output_np_folder)
42 |         if output_np_folder[-1] != '/':
43 |             output_np_folder += '/'
44 |     if masks_folder != '':
45 |         assert os.path.isdir(masks_folder)
46 |         if masks_folder[-1] != '/':
47 |             masks_folder += '/'
48 |     assert type(patch_shape) is tuple
49 |     assert type(patch_shape[0]) is int and type(patch_shape[1]) is int
50 |     assert patch_shape[0] > 0 and patch_shape[1] > 0
51 | 
52 |     if crop_shape is not None:
53 |         crop_bef_h = (patch_shape[0] - crop_shape[0]) // 2
54 |         crop_bef_w = (patch_shape[1] - crop_shape[1]) // 2
55 | 
56 |     assert type(stride) is tuple
57 |     assert type(stride[0]) is int and type(stride[1]) is int
58 |     assert stride[0] > 0 and stride[1] > 0
59 | 
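60 |     # (editor's note, not in the original script) Worked example for the sliding
61 |     # window below: with patch_shape=(512, 512) and stride=(512, 512), a
62 |     # 2048 x 2048 image gives row/col starts range(0, 2048 - 512 + 1, 512) =
63 |     # (0, 512, 1024, 1536), i.e. a 4 x 4 grid of 16 non-overlapping patches.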
64 |     img_count = 0
65 |     for subdir, dirs, files in os.walk(input_fold):
66 |         for file in files:
67 |             if file.split('.')[-1] not in ['tiff', 'tif', 'png', 'jpg', 'JPEG', 'jpeg']:
68 |                 continue
69 | 
70 |             img_patches = []
71 |             mask_patches = []
72 | 
73 |             img_file = os.path.join(subdir, file)
74 |             full_img = skimage.io.imread(img_file)
75 | 
76 |             if masks_folder != '':
77 |                 mask_file = os.path.join(masks_folder, file.split('.')[0] + '_mask.' + file.split('.')[-1])
78 |                 # mask_file = os.path.join(masks_folder, file.split('.')[0] + '.' + file.split('.')[-1])
79 |                 assert os.path.isfile(mask_file)
80 |                 full_mask = skimage.io.imread(mask_file)
81 |                 # assert full_img.shape[0] == full_mask.shape[0] and full_img.shape[1] == full_mask.shape[1]
82 | 
83 |             for row in range(0, full_img.shape[0] - patch_shape[0] + 1, stride[0]):  # '+ 1' keeps the last aligned patch (mirrored in tools/full_prediction.py)
84 |                 for col in range(0, full_img.shape[1] - patch_shape[1] + 1, stride[1]):
85 |                     crop_img = full_img[row:row + patch_shape[0], col:col + patch_shape[1]]
86 |                     # skip patches whose fraction of black pixels exceeds keep_dark_region_ratio
87 |                     if (crop_img == 0).sum() / crop_img.size > keep_dark_region_ratio:
88 |                         continue
89 | 
90 |                     img_patches.append(crop_img)
91 | 
92 |                     if masks_folder != '':
93 |                         crop_mask = full_mask[row:row + patch_shape[0], col:col + patch_shape[1]]
94 |                         mask_patches.append(crop_mask)
95 | 
96 |             if output_folder != '':
97 |                 full_path = os.path.join(output_folder, file.split('.')[0])
98 |                 if not os.path.isdir(full_path):
99 |                     os.mkdir(full_path)
100 |                 for ind, img_patch in enumerate(img_patches):
101 |                     skimage.io.imsave(os.path.join(full_path, str(img_count + ind).zfill(6) + '.png'), img_patch)
102 |                     if crop_shape is not None:
103 |                         crop_patch_out = os.path.join(output_folder, 'crops', file.split('.')[0])
104 |                         if not os.path.isdir(crop_patch_out):
105 |                             os.makedirs(crop_patch_out)
106 |                         crop_patch = img_patch[crop_bef_h:crop_bef_h + crop_shape[0],
107 |                                                crop_bef_w:crop_bef_w + crop_shape[1]]
108 |                         skimage.io.imsave(os.path.join(crop_patch_out, str(img_count + ind).zfill(6) + '.png'),
109 |                                           crop_patch)
110 | 
111 |                 if masks_folder != '':
112 |                     mask_output_folder = os.path.join(output_folder, 'masks', file.split('.')[0])
113 |                     if not os.path.isdir(mask_output_folder):
114 |                         os.makedirs(mask_output_folder)
115 |                     for ind, mask_patch in enumerate(mask_patches):
116 |                         if crop_shape is not None:
117 |                             mask_patch = mask_patch[crop_bef_h:crop_bef_h + crop_shape[0],
118 |                                                     crop_bef_w:crop_bef_w + crop_shape[1]]
119 |                         skimage.io.imsave(os.path.join(mask_output_folder, str(img_count + ind).zfill(6) + '.png'),
120 |                                           mask_patch.astype('uint8'))
121 | 
122 |             img_count += len(img_patches)
123 | 
124 |             if output_np_folder != '':
125 |                 full_path = os.path.join(output_np_folder, file.split('.')[0])  # was output_folder; the numpy output belongs in output_np_folder
126 |                 if not os.path.isdir(full_path):
127 |                     os.mkdir(full_path)
128 |                 img_patches_np = np.asarray(img_patches, dtype='float32')
129 |                 np.save(os.path.join(full_path, 'img_patches'), img_patches_np)
130 |                 mask_patches_np = np.asarray(mask_patches)
131 |                 np.save(os.path.join(full_path, 'mask_patches'), mask_patches_np)
132 | 
133 |     return img_patches, mask_patches
134 | 
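135 | # (editor's sketch) Example call with hypothetical paths, assuming GeoTIFF scenes
136 | # in 'data/raw_images/' and rasterized ground-truth masks in 'data/gt_masks/';
137 | # kept commented out so the module still runs only through its CLI:
138 | # img_patches, mask_patches = get_patches('data/raw_images/',
139 | #                                         output_folder='data/patches_512/',
140 | #                                         masks_folder='data/gt_masks/',
141 | #                                         patch_shape=(512, 512),
142 | #                                         crop_shape=(309, 309),
143 | #                                         stride=(512, 512))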
144 | 
145 | @click.command()
146 | @click.argument('input_folder', type=click.STRING)
147 | @click.option('--output_folder', type=click.STRING, default='')
148 | @click.option('--output_np_folder', type=click.STRING, default='')
149 | @click.option('--masks_folder', type=click.STRING, default='')
150 | @click.option('--patch_shape', nargs=2, type=click.INT, default=(512, 512), help='Height and width of a patch')
151 | @click.option('--crop_shape', nargs=2, type=click.INT, default=(309, 309))
152 | @click.option('--stride', nargs=2, type=click.INT, default=(100, 100), help='Stride in height and width direction')
153 | @click.option('--keep_dark_region_ratio', type=click.FLOAT, default=0.01)
154 | def main(input_folder, output_folder, output_np_folder, masks_folder, patch_shape, crop_shape,
155 |          stride, keep_dark_region_ratio):
156 |     get_patches(input_folder, output_folder=output_folder, output_np_folder=output_np_folder,
157 |                 masks_folder=masks_folder, patch_shape=patch_shape, crop_shape=crop_shape, stride=stride,
158 |                 keep_dark_region_ratio=keep_dark_region_ratio)
159 | 
160 | 
161 | if __name__ == '__main__':
162 |     main()
163 | 
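When `crop_shape` is given, `get_patches` keeps a centered window of every patch, presumably to match a valid-padding network whose output is smaller than its input (the CLI default is 309 x 309 inside a 512 x 512 patch). A minimal sketch of the centering arithmetic, mirroring the code above:

```python
import numpy as np

patch_shape = (512, 512)
crop_shape = (309, 309)

# Centered crop offsets, as computed in get_patches.
crop_bef_h = (patch_shape[0] - crop_shape[0]) // 2  # (512 - 309) // 2 = 101
crop_bef_w = (patch_shape[1] - crop_shape[1]) // 2  # 101

patch = np.zeros(patch_shape, dtype='uint8')
crop = patch[crop_bef_h:crop_bef_h + crop_shape[0],
             crop_bef_w:crop_bef_w + crop_shape[1]]
print(crop.shape)  # (309, 309)
```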
--------------------------------------------------------------------------------
/tools/rasterization.py:
--------------------------------------------------------------------------------
1 | """Script for generating rasterized masks from shapefiles"""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | 
6 | import rasterio
7 | from rasterio import features
8 | import geopandas as gpd
9 | import click
10 | import os
11 | 
12 | 
13 | def get_rasterization(input_shp, meta_data=None, ref_tif='', output_tif=''):
14 |     """
15 |     Get rasterized image from shapefile.
16 |     Args:
17 |         input_shp: string. path to the shapefile (.shp).
18 |         meta_data: dictionary. meta_data passed to rasterio.open() function
19 |             in case the reference image is not specified.
20 |         ref_tif: string. path to the reference .tif image.
21 |             When specified, meta_data is ignored.
22 |         output_tif: string. path to the output rasterized image.
23 | 
24 |     Returns:
25 |         numpy array. rasterized image.
26 |     """
27 | 
28 |     # Checking the correctness of the arguments
29 |     if meta_data is None:  # avoid a mutable default argument
30 |         meta_data = {}
31 |     assert os.path.isfile(input_shp), 'no file: input_shp'
32 |     assert type(meta_data) is dict, 'meta_data has to be a dictionary'
33 |     if ref_tif != '':
34 |         assert os.path.isfile(ref_tif), 'no file: ref_tif'
35 | 
36 |     shape_df = gpd.read_file(input_shp)
37 | 
38 |     if ref_tif != '':
39 |         rst = rasterio.open(ref_tif)
40 |         meta_data = rst.meta.copy()
41 |         meta_data['count'] = 1
42 | 
43 |     with rasterio.open(output_tif, 'w', **meta_data) as out:
44 |         # create a generator of (geom, value) pairs to use in rasterizing
45 |         shapes = ((geom, 1) for geom in shape_df.geometry)
46 |         result = features.rasterize(shapes=shapes, out_shape=out.shape, fill=0, transform=out.transform)
47 |         if output_tif:
48 |             out.write_band(1, result)
49 | 
50 |     return result
51 | 
52 | 
53 | @click.command()
54 | @click.argument('input_shp', type=click.STRING)
55 | @click.option('--ref_tif', type=click.STRING, default='', help='path to the reference .tif image')
56 | @click.option('--output_tif', type=click.STRING, default='', help='path to the output rasterized image.')
57 | def main(input_shp, ref_tif, output_tif):
58 |     meta_data = {
59 |         'driver': 'GTiff',
60 |         'dtype': 'uint8',
61 |         'nodata': None,
62 |         'width': 13500,
63 |         'height': 15300,
64 |         'count': 1,
65 |         'crs': {'init': 'epsg:26911'},
66 |         'transform': (0.5, 0.0, 481649.0, 0.0, -0.5, 3623101.0)}
67 | 
68 |     files = os.listdir(input_shp)  # here input_shp is a folder containing .shp files
69 |     files = [file for file in files if file.endswith('.shp')]
70 |     for file in files:
71 |         input_shp_path = os.path.join(input_shp, file)
72 |         ref_tif_path = os.path.join(ref_tif, file.split('.')[0] + '.tif')
73 |         output_tif_path = os.path.join(output_tif, file.split('.')[0] + '.tif')
74 |         get_rasterization(input_shp=input_shp_path, meta_data=meta_data, ref_tif=ref_tif_path, output_tif=output_tif_path)
75 | 
76 | 
77 | if __name__ == '__main__':
78 |     main()
79 | 
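A hedged usage sketch for `get_rasterization` (the shapefile and image paths are hypothetical; rasterio and geopandas must be installed, and the reference GeoTIFF supplies the output grid, CRS, and transform):

```python
from tools.rasterization import get_rasterization

# Burn the polygons of a hypothetical shapefile onto the pixel grid of a
# matching reference image; pixels inside a polygon become 1, the rest 0.
mask = get_rasterization(input_shp='data/shapefiles/buildings.shp',
                         ref_tif='data/raw_images/buildings.tif',
                         output_tif='data/gt_masks/buildings.tif')
print(mask.shape, mask.max())  # grid of the reference image, values in {0, 1}
```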
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/clustering.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/utils/__pycache__/clustering.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/clustering.py:
--------------------------------------------------------------------------------
1 | # Script based on https://github.com/facebookresearch/deepcluster/blob/master/clustering.py
2 | # Facebook, Inc.
3 | 
4 | import time
5 | 
6 | import faiss
7 | from PIL import ImageFile
8 | 
9 | 
10 | ImageFile.LOAD_TRUNCATED_IMAGES = True
11 | 
12 | __all__ = ['Kmeans']
13 | 
14 | 
15 | def run_kmeans(x, nmb_clusters, verbose=False, init_cents=None):
16 |     """Runs k-means with faiss.
17 |     Args:
18 |         x (np.array N * dim, float32): data
19 |         nmb_clusters (int): number of clusters
20 |         init_cents (np.array, optional): centroids used to warm-start the clustering
21 |     Returns:
22 |         list: cluster id of each data point
23 |         float: final k-means loss
24 |         np.array: centroids of shape (nmb_clusters, dim)
25 |     """
26 |     n_data, d = x.shape
27 | 
28 |     # faiss implementation of k-means
29 |     clus = faiss.Clustering(d, nmb_clusters)
30 |     clus.niter = 20
31 |     clus.max_points_per_centroid = 10000000
32 | 
33 |     if init_cents is not None:
34 |         clus.centroids.resize(init_cents.size)
35 |         faiss.memcpy(clus.centroids.data(), faiss.swig_ptr(init_cents), init_cents.size * 4)  # 4 bytes per float32
36 | 
37 |     index = faiss.IndexFlatL2(d)
38 | 
39 |     # perform the training
40 |     clus.train(x, index)
41 |     _, I = index.search(x, 1)
42 |     losses = faiss.vector_to_array(clus.obj)
43 |     if verbose:
44 |         print('k-means loss evolution: {0}'.format(losses))
45 | 
46 |     centroids = faiss.vector_to_array(clus.centroids).reshape((nmb_clusters, d))
47 | 
48 |     return [int(n[0]) for n in I], losses[-1], centroids
49 | 
50 | 
51 | class Kmeans:
52 |     def __init__(self, k):
53 |         self.k = k
54 | 
55 |     def cluster(self, data, verbose=False, init_cents=None):
56 |         """Performs k-means clustering.
57 |         Args:
58 |             data (np.array N * dim, float32): data to cluster
59 |             init_cents (np.array, optional): centroids used to warm-start the clustering
60 |         Returns:
61 |             np.array: centroids; per-sample assignments are stored in self.images_lists
62 |         """
63 |         end = time.time()
64 | 
65 |         xb = data
66 |         # cluster the data
67 |         I, loss, cents = run_kmeans(xb, self.k, verbose, init_cents)
68 |         self.images_lists = [[] for _ in range(self.k)]
69 |         for i in range(len(data)):
70 |             self.images_lists[I[i]].append(i)
71 | 
72 |         if verbose:
73 |             print('k-means time: {0:.0f} s'.format(time.time() - end))
74 | 
75 |         return cents
76 | 
77 | 
78 | def make_graph(xb, nnn):
79 |     """Builds a graph of nearest neighbors.
80 |     Args:
81 |         xb (np.array): data
82 |         nnn (int): number of nearest neighbors
83 |     Returns:
84 |         list: for each data the list of ids to its nnn nearest neighbors
85 |         list: for each data the list of distances to its nnn NN
86 |     """
87 |     N, dim = xb.shape
88 | 
89 |     # L2
90 |     index = faiss.IndexFlatL2(dim)
91 | 
92 |     index.add(xb)
93 |     D, I = index.search(xb, nnn + 1)
94 |     return I, D
95 | 
--------------------------------------------------------------------------------
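Finally, a minimal sketch of driving the `Kmeans` wrapper above with random features in place of real network activations (faiss expects C-contiguous float32 input; the shapes are illustrative):

```python
import numpy as np
from utils.clustering import Kmeans

# 10000 random 64-dimensional feature vectors standing in for network outputs.
features = np.random.rand(10000, 64).astype('float32')

kmeans = Kmeans(k=5)
centroids = kmeans.cluster(features, verbose=True)  # (5, 64) centroid array

# kmeans.images_lists[j] holds the indices of the samples assigned to cluster j.
print(centroids.shape, [len(lst) for lst in kmeans.images_lists])
```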