├── .gitignore
├── .idea
│   └── clusternet_segmentation.iml
├── Dockerfile
├── README.md
├── interactive_tool
│   ├── clusters_to_class.py
│   ├── crop_img.py
│   └── inter_cluster.py
├── main.py
├── models
│   ├── ContextEncoder.py
│   ├── UNetValid.py
│   ├── VGG_UNet.py
│   ├── __init__.py
│   └── __pycache__
│       ├── UNetValid.cpython-36.pyc
│       └── __init__.cpython-36.pyc
├── requirements.txt
├── super_train.py
├── tools
│   ├── __init__.py
│   ├── clusters2classes.py
│   ├── evaluation.py
│   ├── full_prediction.py
│   ├── patch_generator.py
│   └── rasterization.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   └── clustering.cpython-36.pyc
    └── clustering.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.npy
3 | *.csv
4 | *.hdf5
5 | *.h5
6 | *.png
7 | *.json
8 | *.gz
9 | *.tar
10 | *.tif
11 | *.tiff
12 | *.shp
13 | *.xml
14 | *.jpg
15 | *.tfrecords
16 | *.JPEG
17 | data/*
--------------------------------------------------------------------------------
/.idea/clusternet_segmentation.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
2 |
3 | ARG FAISS_CPU_OR_GPU=cpu
4 | ARG FAISS_VERSION=1.4.0
5 |
6 | ENV HOME /root
7 |
8 | RUN apt-get update && \
9 | apt-get install -y build-essential curl software-properties-common bzip2 python3-pip && \
10 | curl https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh > /tmp/conda.sh && \
11 | bash /tmp/conda.sh -b -p /opt/conda && \
12 | /opt/conda/bin/conda update -n base conda && \
13 | /opt/conda/bin/conda install -y -c pytorch faiss-${FAISS_CPU_OR_GPU}=${FAISS_VERSION} && \
14 | apt-get remove -y --auto-remove curl bzip2 && \
15 | apt-get clean && \
16 | rm -fr /tmp/conda.sh
17 |
18 | #RUN /opt/conda/bin/conda install -c conda-forge tensorflow-gpu \
19 | # keras \
20 | # scikit-image \
21 | # click
22 |
23 | RUN /opt/conda/bin/pip install sklearn \
24 | scikit-image \
25 | click \
26 | pandas \
27 | tensorflow-gpu==1.11.0 \
28 | Keras==2.2.2 \
29 | Keras-Applications==1.0.4 \
30 | Keras-Preprocessing==1.0.2
31 |
32 | ENV LC_ALL=C.UTF-8
33 | ENV LANG=C.UTF-8
34 |
35 | ENV PATH="/opt/conda/bin:${PATH}"
36 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ClusterNet for Unsupervised Segmentation on Satellite Images
2 | Semantic segmentation on satellite images is used to automatically detect and classify
3 | objects of interest while preserving the geometric structure of these objects. In this work,
4 | we present a novel weakly supervised segmentation approach based on applying the
5 | K-Means clustering technique to the high-dimensional output of a Deep
6 | Neural Network. This approach is generic and can be applied to any image and to any
7 | class of objects in that image. At the same time, it does not require ground truth
8 | during the training stage.
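The gist of the method, as a minimal sketch (scikit-learn's KMeans stands in here for the repo's faiss-based `utils.clustering`; shapes are illustrative):

```
import numpy as np
from sklearn.cluster import KMeans

num_clusters = 10
# per-pixel features from the network's 'for_clust' layer: (batch, H, W, C)
features = np.random.rand(4, 64, 64, 32).astype('float32')
flat = features.reshape(-1, features.shape[-1])   # one feature vector per pixel

kmeans = KMeans(n_clusters=num_clusters).fit(flat)
pseudo_labels = kmeans.labels_.reshape(features.shape[:3])

# The pseudo-labels are one-hot encoded and used as targets to train the
# softmax 'clust_output' head with categorical cross-entropy (see main.py).
```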
9 |
10 |
11 |
12 | ### Prerequisites
13 |
14 | ```
15 | Keras 2.2.2
16 | faiss 1.4.0
17 | ```
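If pip cannot resolve faiss at this version, the Dockerfile installs it from the pytorch conda channel instead, e.g.:

```
conda install -c pytorch faiss-cpu=1.4.0
```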
18 |
19 | ### Installing
20 |
21 | All library requirements needed to run the scripts are listed in the requirements.txt file.
22 |
23 | To install them in your local environment, run:
24 |
25 | ```
26 | pip install -r requirements.txt
27 | ```
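Alternatively, build the provided Docker image (the image tag and GPU runtime flag below are illustrative):

```
docker build -t clusternet .
docker run --runtime=nvidia -it clusternet bash
```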
28 |
29 | ## Running the experiments
30 |
31 | 1. Generate the train and (optionally) test datasets of fixed-size patches from the original images. For this, call tools/patch_generator.py from the root of the repository.
32 | For example, to generate patches of size 512×512 with no overlap (bracketed values are placeholders), use the command:
33 | ```
34 | python ./tools/patch_generator.py --output_folder <output_folder> --patch_shape 512 512 --stride 512 512
35 | ```
36 |
37 | 2. To train the Keras model in an unsupervised fashion, run:
38 | ```
39 | python ./main.py <train_img_folder>
40 | ```
41 | *Note*: in this work we use ImageDataGenerator from keras.preprocessing.image, which requires that your images are stored in subfolders one level deeper than the folder you pass as `<train_img_folder>`; see the example layout below.
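For example, a layout like the following works (folder names are illustrative):

```
data/train/
└── images
    ├── patch_000000.png
    └── patch_000001.png
```

and the training command then becomes `python ./main.py data/train/`.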
42 |
43 | ## Inspired by
44 | *Deep Clustering for Unsupervised Learning of Visual Features* - [DeepCluster](https://github.com/facebookresearch/deepcluster)
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/interactive_tool/clusters_to_class.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import skimage.io
7 | import os
8 | import click
9 |
10 |
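# Maps clusters to a binary class from scribble annotations: keeps every
# cluster touched by a positive stroke unless negative strokes dominate it.
# Example invocation (paths are illustrative):
#   python interactive_tool/clusters_to_class.py clust_img.png pos_neg_mask.npy masks/ --output_folder out/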
11 | @click.command()
12 | @click.argument('base_image_file', type=click.STRING)
13 | @click.argument('pos_neg_mask_file', type=click.STRING)
14 | @click.argument('masks_folder', type=click.STRING)
15 | @click.option('--output_folder', type=click.STRING, default='')
16 | def main(base_image_file, pos_neg_mask_file, masks_folder, output_folder):
17 | base_image = skimage.io.imread(base_image_file)
18 | pos_neg_mask = np.load(pos_neg_mask_file)
19 |
20 | pos_clusts = np.unique(base_image[pos_neg_mask == 1])
21 | neg_clusts = np.unique(base_image[pos_neg_mask == -1])
22 |
23 | intersect = np.intersect1d(pos_clusts, neg_clusts)
24 |
25 | for clust in intersect:
26 | pos = np.sum((base_image == clust) * (pos_neg_mask == 1))
27 | neg = np.sum((base_image == clust) * (pos_neg_mask == -1))
28 | pos_ratio = pos / (pos + neg)
29 |
30 |         if pos_ratio < 0.3:  # drop clusters dominated by negative strokes (same threshold as inter_cluster.py)
31 | index = np.argwhere(pos_clusts == clust)
32 | pos_clusts = np.delete(pos_clusts, index)
33 |
34 | for file in os.listdir(masks_folder):
35 | file_path = os.path.join(masks_folder, file)
36 | mask = skimage.io.imread(file_path).astype('float32')
37 |
38 | if output_folder != '':
39 | out_file_path = os.path.join(output_folder, file)
40 |             skimage.io.imsave(out_file_path, np.isin(mask, pos_clusts).astype('uint8'))  # cast bool mask for imsave
41 |
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
/interactive_tool/crop_img.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import click
6 | import os
7 | import skimage.io
8 | import skimage.util
9 | import numpy
10 | import math
11 |
12 | @click.command()
13 | @click.argument('input_images', type=click.STRING)
14 | @click.option('--crop_shape', nargs=2, type=click.INT, default=(309, 309))
15 | @click.option('--output_folder', type=click.STRING)
16 | def main(input_images, crop_shape, output_folder):
17 | assert os.path.isdir(input_images)
18 |
19 |     height_crop = (512 - crop_shape[0]) / 2  # input patches are assumed to be 512x512
20 |     width_crop = (512 - crop_shape[1]) / 2
21 |
22 | for file in os.listdir(input_images):
23 | image_file = os.path.join(input_images, file)
24 | image = skimage.io.imread(image_file)
25 |
26 | crop_img = skimage.util.crop(image, ((math.ceil(height_crop), int(height_crop)), (math.ceil(width_crop), int(width_crop))))
27 |
28 | out_image_file = os.path.join(output_folder, file)
29 | skimage.io.imsave(out_image_file, crop_img)
30 |
31 | return 0
32 |
33 |
34 | if __name__ == "__main__":
35 | main()
--------------------------------------------------------------------------------
/interactive_tool/inter_cluster.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import skimage.io
6 | import numpy as np
7 | from keras import Model
8 | from tkinter import *
9 | from PIL import ImageTk, Image
10 |
11 | import matplotlib.pyplot as plt
12 | from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
13 | from matplotlib.figure import Figure
14 | import matplotlib.image as mpimg
15 | import skimage.io
16 |
17 | from models.UNetValid import get_model
18 | from utils import clustering
19 |
20 | img_path = '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/crop_img.png'
21 |
22 |
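# Interactive refinement tool: the image is clustered once with K-Means on the
# network's 'for_clust' features; the user then draws positive (blue) or
# negative (red) strokes on the canvas, and pressing DONE splits every cluster
# hit by both stroke types into a positive and a negative centroid before
# re-clustering. The hard-coded paths below must be adapted to your setup.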
23 | class Paint(object):
24 |
25 | def __init__(self):
26 | self.root = Tk()
27 | self.mask_clust = Toplevel()
28 |
29 | self.img_np = skimage.io.imread(img_path)
30 | img = Image.fromarray(self.img_np)
31 | self.img = ImageTk.PhotoImage(image=img)
32 | frame = Frame(self.root, width=self.img.width(), height=self.img.height())
33 | frame.grid(row=1, columnspan=2)
34 | self.canvas = Canvas(frame, bg='#FFFFFF', width=512, height=512,
35 | scrollregion=(0, 0, self.img.width(), self.img.height()))
36 | hbar = Scrollbar(frame, orient=HORIZONTAL)
37 | hbar.pack(side=BOTTOM, fill=X)
38 | hbar.config(command=self.canvas.xview)
39 | vbar = Scrollbar(frame, orient=VERTICAL)
40 | vbar.pack(side=RIGHT, fill=Y)
41 | vbar.config(command=self.canvas.yview)
42 | self.canvas.config(width=512, height=512)
43 | self.canvas.config(xscrollcommand=hbar.set, yscrollcommand=vbar.set)
44 | self.canvas.pack(side=LEFT, expand=True, fill=BOTH)
45 |
46 | # self.c = Canvas(self.root, bg='black', width=self.img.width(), height=self.img.height())
47 | # self.c.grid(row=1, columnspan=5)
48 |
49 | self.done_button = Button(self.root, text='DONE', command=self.done)
50 | self.done_button.grid(row=0, column=1)
51 |
52 | self.pos_stroke = BooleanVar()
53 | self.stroke_type_button = Checkbutton(self.root, text='positive', variable=self.pos_stroke)
54 | self.stroke_type_button.grid(row=0, column=0)
55 |
56 | self.fig_mask_clust = Figure(figsize=(4, 4))
57 | self.subplot_mask = self.fig_mask_clust.add_subplot(111)
58 | self.subplot_mask.axis('off')
59 |
60 | self.mask_clust_c = FigureCanvasTkAgg(self.fig_mask_clust, master=self.mask_clust)
61 | self.mask_clust_c.show()
62 | self.mask_clust_c.get_tk_widget().pack(side=TOP, fill=BOTH, expand=1)
63 |
64 | self.setup()
65 | self.root.mainloop()
66 |
67 | def setup(self):
68 | self.old_x = None
69 | self.old_y = None
70 |
71 | self.num_clusters = 100
72 |
73 | model_weights_path = '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/results/38/model.h5'
74 | img_path = '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/raw/test/images/JNCASR_Overall_Inducer_123_Folder_F_30_ebss_07_R3D_D3D_PRJ.png'
75 | self.image = skimage.io.imread(img_path).astype('float32')
76 |
77 | self.image = np.expand_dims(self.image, -1)
78 | # mean = np.array([9.513245, 10.786118, 10.060172], dtype=np.float32).reshape(1, 1, 3)
79 | # std = np.array([18.76071, 19.97462, 19.616455], dtype=np.float32).reshape(1, 1, 3)
80 | mean = np.array([30.17722], dtype='float32').reshape(1, 1, 1)
81 | std = np.array([33.690296], dtype='float32').reshape(1, 1, 1)
82 | self.image -= mean
83 | self.image /= std
84 | # self.image /= 255
85 |
86 | self.model = get_model(self.image.shape, self.num_clusters)
87 | self.model.load_weights(model_weights_path)
88 |
89 | before_last_layer_model = Model(inputs=self.model.input, outputs=self.model.get_layer('for_clust').output)
90 |
91 | self.x = before_last_layer_model.predict(np.expand_dims(self.image, axis=0), verbose=1, batch_size=1)
92 |
93 | self.pos_neg_mask = np.zeros((self.x.shape[1], self.x.shape[2]))
94 | # self.pos_neg_mask = np.load(
95 | # '/home/zhygallo/zhygallo/inria/cluster_seg/deepcluster_segmentation/data/patches_512_lux/pos_neg_mask.npy')
96 | self.x_flat = np.reshape(self.x, (self.x.shape[0] * self.x.shape[1] * self.x.shape[2], self.x.shape[-1]))
97 |
98 | self.n_pix, self.dim_pix = self.x_flat.shape[0], self.x_flat.shape[1]
99 |
100 | self.deepcluster = clustering.Kmeans(self.num_clusters)
101 | self.centroids = self.deepcluster.cluster(self.x_flat, verbose=True, init_cents=None)
102 |
103 | masks_flat = np.zeros((self.n_pix, 1))
104 | for clust_ind, clust in enumerate(self.deepcluster.images_lists):
105 | masks_flat[clust] = clust_ind
106 | self.masks = masks_flat.reshape((self.x.shape[1], self.x.shape[2]))
107 |
108 | self.count = 1
109 | skimage.io.imsave(
110 | '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/binar_pred/clust_%i.png' % self.count,
111 | self.masks.astype('uint32'))
112 | self.subplot_mask.imshow(self.masks)
113 | self.canvas.create_image(0, 0, image=self.img, anchor=NW)
114 |
115 |         self.canvas.bind('<B1-Motion>', self.paint)
116 |         self.canvas.bind('<ButtonRelease-1>', self.reset)
117 |
118 | def activate_button(self, some_button):
119 | self.active_button.config(relief=RAISED)
120 | some_button.config(relief=SUNKEN)
121 | self.active_button = some_button
122 |
123 | def paint(self, event):
124 | self.line_width = 3.0
125 | if self.pos_stroke.get():
126 | self.color = 'blue'
127 | else:
128 | self.color = 'red'
129 | paint_color = self.color
130 | if self.old_x and self.old_y:
131 | self.canvas.create_line(self.old_x, self.old_y,
132 | self.canvas.canvasx(event.x), self.canvas.canvasy(event.y),
133 | width=self.line_width, fill=paint_color,
134 | capstyle=ROUND, smooth=TRUE, splinesteps=36)
135 | self.old_x = self.canvas.canvasx(event.x)
136 | self.old_y = self.canvas.canvasy(event.y)
137 | if self.pos_stroke.get():
138 | self.pos_neg_mask[int(self.canvas.canvasy(event.y)), int(self.canvas.canvasx(event.x))] = 1
139 | else:
140 | self.pos_neg_mask[int(self.canvas.canvasy(event.y)), int(self.canvas.canvasx(event.x))] = -1
141 |
142 | def reset(self, event):
143 | self.old_x, self.old_y = None, None
144 |
145 | def done(self):
146 | np.save('/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/pos_neg_mask_transfer.npy',
147 | self.pos_neg_mask)
148 | pos_clusts = np.unique(self.masks[self.pos_neg_mask == 1])
149 | neg_clusts = np.unique(self.masks[self.pos_neg_mask == -1])
150 |
151 | intersect = np.intersect1d(pos_clusts, neg_clusts)
152 |
153 | for clust in intersect:
154 | pos = np.sum((self.masks == clust) * (self.pos_neg_mask == 1))
155 | neg = np.sum((self.masks == clust) * (self.pos_neg_mask == -1))
156 | pos_ratio = pos / (pos + neg)
157 | neg_ratio = neg / (pos + neg)
158 |
159 | if pos_ratio < 0.3 or neg_ratio < 0.3:
160 | continue
161 |
162 | self.centroids = np.delete(self.centroids, clust, axis=0)
163 | neg_centroid = self.x[0][((self.masks == clust) *
164 | (self.pos_neg_mask == -1))].mean(axis=0).reshape(1, self.centroids.shape[1])
165 | pos_centroid = self.x[0][((self.masks == clust) *
166 | (self.pos_neg_mask == 1))].mean(axis=0).reshape(1, self.centroids.shape[1])
167 |
168 | self.centroids = np.concatenate((self.centroids, neg_centroid, pos_centroid), axis=0)
169 |
170 | self.deepcluster = clustering.Kmeans(self.centroids.shape[0])
171 | self.centroids = self.deepcluster.cluster(self.x_flat, verbose=True, init_cents=self.centroids)
172 |
173 | masks_flat = np.zeros((self.n_pix, 1))
174 | for clust_ind, clust in enumerate(self.deepcluster.images_lists):
175 | masks_flat[clust] = clust_ind
176 | self.masks = masks_flat.reshape((self.x.shape[1], self.x.shape[2]))
177 |
178 | pos_clusts = np.unique(self.masks[self.pos_neg_mask == 1])
179 | neg_clusts = np.unique(self.masks[self.pos_neg_mask == -1])
180 |
181 | intersect = np.intersect1d(pos_clusts, neg_clusts)
182 |
183 | pos_neg_arr = np.zeros(self.masks.shape)
184 | pos_neg_arr[np.isin(self.masks, pos_clusts)] = 1
185 | pos_neg_arr[np.isin(self.masks, neg_clusts)] = 2
186 | pos_neg_arr[np.isin(self.masks, intersect)] = 3
187 | self.subplot_mask.imshow(pos_neg_arr)
188 |
189 | self.mask_clust_c.show()
190 |
191 | pos_arr = np.zeros(self.masks.shape)
192 | pos_arr[np.isin(self.masks, pos_clusts)] = 1
193 | for clust in intersect:
194 | pos = np.sum((self.masks == clust) * (self.pos_neg_mask == 1))
195 | neg = np.sum((self.masks == clust) * (self.pos_neg_mask == -1))
196 | pos_ratio = pos / (pos + neg)
197 |
198 | if pos_ratio < 0.3:
199 | pos_arr[self.masks == clust] = 0
200 |
201 | # np.save('/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/pos_arr.npy', pos_clusts)
202 |
203 | skimage.io.imsave(
204 | '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/binar_pred/pos_%i.png' % self.count,
205 | np.isin(self.masks, pos_clusts))
206 | skimage.io.imsave(
207 | '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/binar_pred/only_pos_%i.png' % self.count,
208 | np.isin(self.masks, pos_clusts) * (1 - np.isin(self.masks, neg_clusts)))
209 | skimage.io.imsave(
210 | '/home/zhygallo/zhygallo/zeiss/clusternet_segmentation/data/binar_pred/filt_pos_%i.png' % self.count,
211 | pos_arr.astype('uint16'))
212 | self.count += 1
213 |
214 |
215 | def main():
216 | Paint()
217 |
218 |
219 | if __name__ == "__main__":
220 | main()
221 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import click
7 | import os
8 | from keras.preprocessing.image import ImageDataGenerator
9 | from keras.utils import to_categorical
10 | from keras.optimizers import Adam
11 | from keras import Model
12 | from keras.callbacks import ModelCheckpoint
13 | from keras.callbacks import Callback
14 | import keras.backend as K
15 |
16 | from models.UNetValid import get_model
17 | from utils import clustering
18 | from tools.evaluation import final_prediction
19 |
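# Unsupervised training loop: for each batch, extract per-pixel features from
# the 'for_clust' layer, cluster them with K-Means into pseudo-labels, and fit
# the softmax head on those labels (one pass per batch per epoch).
# Example invocation (paths are illustrative):
#   python main.py data/train/ --test_img_folder data/test/ --img_shape 512 512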
20 | @click.command()
21 | @click.argument('train_img_folder', type=click.STRING)
22 | @click.option('--test_img_folder', type=click.STRING, default='')
23 | @click.option('--img_shape', nargs=2, type=click.INT, default=(256, 256))
24 | @click.option('--num_clusters', type=click.INT, default=100)
25 | @click.option('--num_epochs', type=click.INT, default=100)
26 | @click.option('--learn_rate', type=click.FLOAT, default=1e-4)
27 | @click.option('--clust_batch_size', type=click.INT, default=32)
28 | @click.option('--train_batch_size', type=click.INT, default=4)
29 | @click.option('--current_epoch', type=click.INT, default=0)
30 | @click.option('--pretrained_weights', type=click.STRING, default='')
31 | @click.option('--out_weights_file', type=click.STRING, default='')
32 | @click.option('--out_pred_masks_test', type=click.STRING, default='')
33 | def main(train_img_folder, test_img_folder, img_shape, num_clusters, num_epochs, learn_rate,
34 | clust_batch_size, train_batch_size, current_epoch, pretrained_weights, out_weights_file,
35 | out_pred_masks_test):
36 | assert os.path.isdir(train_img_folder)
37 |     if train_img_folder[-1] != '/':
38 |         train_img_folder += '/'
39 |     if test_img_folder != '':
40 |         assert os.path.isdir(test_img_folder)
41 |         test_img_folder = test_img_folder.rstrip('/') + '/'
42 |
43 | train_image_datagen = ImageDataGenerator(
44 | featurewise_center=True,
45 | featurewise_std_normalization=True,
46 | data_format='channels_last',
47 | )
48 |
49 | # (512, 512)
50 | train_image_datagen.mean = np.array([30.17722], dtype='float32').reshape(1, 1, 1)
51 | train_image_datagen.std = np.array([33.690296], dtype='float32').reshape(1, 1, 1)
52 |
53 | SEED = 1
54 | train_image_generator = train_image_datagen.flow_from_directory(train_img_folder,
55 | color_mode='grayscale',
56 | target_size=img_shape,
57 | batch_size=clust_batch_size,
58 | class_mode=None,
59 | shuffle=True,
60 | seed=SEED)
61 |
62 |
63 | test_image_datagen = ImageDataGenerator(
64 | featurewise_center=True,
65 | featurewise_std_normalization=True,
66 | data_format='channels_last',
67 | )
68 |
69 | test_image_datagen.mean = train_image_datagen.mean
70 | test_image_datagen.std = train_image_datagen.std
71 |
72 |
73 | test_image_generator = test_image_datagen.flow_from_directory(test_img_folder,
74 | color_mode='grayscale',
75 | target_size=img_shape,
76 | batch_size=32,
77 | class_mode=None,
78 | shuffle=False)
79 |
80 | num_images = train_image_generator.n
81 | model = get_model(train_image_generator.image_shape, num_classes=num_clusters)
82 |
83 | alpha = K.variable(1.)
84 | losses = {'clust_output': 'categorical_crossentropy'}
85 | loss_weights = {'clust_output': alpha}
86 |
87 | before_last_layer_model = Model(inputs=model.input, outputs=model.get_layer('for_clust').output)
88 |
89 | if pretrained_weights != '':
90 | # for layer in model.layers[-1:]:
91 | # layer.name += '_lux'
92 | model.load_weights(pretrained_weights, by_name=True)
93 |
94 | deepcluster = clustering.Kmeans(num_clusters)
95 |
96 | while current_epoch <= num_epochs:
97 | print('\n############# Epoch %i / %i #############' % ((current_epoch), num_epochs))
98 | num_iter = int(np.ceil(num_images / float(clust_batch_size)))
99 | for i in range(num_iter):
100 | print('%i:%i/%i' % (i * clust_batch_size, (i + 1) * clust_batch_size, num_images))
101 |
102 | images_batch = train_image_generator[i]
103 |
104 | print('Prediction:')
105 | features_batch = before_last_layer_model.predict(images_batch, verbose=1, batch_size=train_batch_size)
106 |
107 | flat_feat_batch = features_batch.reshape(-1, features_batch.shape[-1])
108 | print('Clustering:')
109 | deepcluster.cluster(flat_feat_batch, verbose=True)
110 |
111 | masks_batch_flat = np.zeros((flat_feat_batch.shape[0], 1))
112 | for clust_ind, clust in enumerate(deepcluster.images_lists):
113 | masks_batch_flat[clust] = clust_ind
114 | masks_batch_flat = to_categorical(masks_batch_flat, num_clusters)
115 | masks_batch = masks_batch_flat.reshape((features_batch.shape[0], features_batch.shape[1],
116 | features_batch.shape[2], num_clusters))
117 |
118 | masks = {'clust_output': masks_batch}
119 |
120 | print('Training:')
121 |
122 | model.compile(optimizer=Adam(lr=(learn_rate)), loss=losses, loss_weights=loss_weights)
123 | model.fit(images_batch, masks, batch_size=train_batch_size, epochs=1, verbose=1)
124 |
125 | if out_pred_masks_test != '':
126 | out_pred_epoch_fold = os.path.join(out_pred_masks_test, str(current_epoch))
127 | final_pred_model = Model(inputs=model.input, outputs=model.get_layer('clust_output').output)
128 | final_prediction(out_pred_epoch_fold, test_image_generator, final_pred_model)
129 | model.save_weights(os.path.join(out_pred_epoch_fold, 'model.h5'))
130 |
131 | current_epoch += 1
132 |
133 |
134 | if __name__ == '__main__':
135 | main()
--------------------------------------------------------------------------------
/models/ContextEncoder.py:
--------------------------------------------------------------------------------
1 | """Keras implementation of UNet model"""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import math
7 | from keras.models import Model
8 | from keras.layers import Input, Conv2D, UpSampling2D, Concatenate, Cropping2D, BatchNormalization, Activation, MaxPooling2D
9 | from keras import backend as K
10 |
11 | K.set_image_data_format('channels_last')
12 |
13 |
14 | def get_model(img_shape, num_classes=2):
15 | inputs = Input(shape=img_shape)
16 |
17 | conv1_1 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(inputs)
18 | conv1_2 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv1_1)
19 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1_2)
20 |
21 | conv2_1 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool1)
22 | conv2_2 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv2_1)
23 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2_2)
24 |
25 | conv3_1 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool2)
26 | conv3_2 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv3_1)
27 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3_2)
28 |
29 | conv4_1 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool3)
30 | conv4_2 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv4_1)
31 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4_2)
32 |
33 | conv5_1 = Conv2D(1024, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool4)
34 | conv5_2 = Conv2D(1024, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv5_1)
35 |
36 | up6 = Conv2D(512, 2, activation='relu', padding='valid', kernel_initializer='he_normal')(
37 | UpSampling2D(size=(2, 2))(conv5_2))
38 | height_dif = int((conv4_2.shape[1] - up6.shape[1])) / 2
39 | width_dif = int((conv4_2.shape[2] - up6.shape[2])) / 2
40 | crop_conv4_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv4_2)
41 | concat6 = Concatenate()([crop_conv4_2, up6])
42 | conv6_1 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat6)
43 | conv6_2 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv6_1)
44 |
45 | up7 = Conv2D(256, 2, activation='relu', padding='valid', kernel_initializer='he_normal')(
46 | UpSampling2D(size=(2, 2))(conv6_2))
47 | height_dif = int((conv3_2.shape[1] - up7.shape[1])) / 2
48 | width_dif = int((conv3_2.shape[2] - up7.shape[2])) / 2
49 | crop_conv3_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv3_2)
50 | concat7 = Concatenate()([crop_conv3_2, up7])
51 | conv7_1 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat7)
52 | conv7_2 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv7_1)
53 |
54 | up8 = Conv2D(128, 2, activation='relu', padding='valid', kernel_initializer='he_normal')(
55 | UpSampling2D(size=(2, 2))(conv7_2))
56 | height_dif = int((conv2_2.shape[1] - up8.shape[1])) / 2
57 | width_dif = int((conv2_2.shape[2] - up8.shape[2])) / 2
58 | crop_conv2_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv2_2)
59 | concat8 = Concatenate()([crop_conv2_2, up8])
60 | conv8_1 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat8)
61 | conv8_2 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv8_1)
62 |
63 | up9 = Conv2D(64, 2, activation='relu', padding='valid', kernel_initializer='he_normal')(
64 | UpSampling2D(size=(2, 2))(conv8_2))
65 | height_dif = int((conv1_2.shape[1] - up9.shape[1])) / 2
66 | width_dif = int((conv1_2.shape[2] - up9.shape[2])) / 2
67 | crop_conv1_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv1_2)
68 | concat9 = Concatenate()([crop_conv1_2, up9])
69 | conv9_1 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat9)
70 | conv9_2 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal', name='for_clust')(conv9_1)
71 |
72 | conv10 = Conv2D(num_classes, (1, 1), activation='softmax', name='clust_output')(conv9_2)
73 |
74 | conv11 = Conv2D(3, (1, 1), activation='sigmoid', name='rgb_output')(conv9_2)
75 |
76 | model = Model(inputs=inputs, outputs=[conv10, conv11])
77 |
78 | return model
79 |
80 |
81 | def main():
82 | model = get_model(img_shape=(512, 512, 3))
83 | return 0
84 |
85 |
86 | if __name__ == "__main__":
87 | main()
--------------------------------------------------------------------------------
/models/UNetValid.py:
--------------------------------------------------------------------------------
1 | """Keras implementation of UNet model"""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import math
7 | from keras.models import Model
8 | from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, Cropping2D
9 | from keras import backend as K
10 |
11 | K.set_image_data_format('channels_last')
12 |
13 |
14 | def get_model(img_shape, num_classes=2):
15 | inputs = Input(shape=img_shape)
16 |
17 | conv1_1 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(inputs)
18 | conv1_2 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv1_1)
19 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1_2)
20 |
21 | conv2_1 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool1)
22 | conv2_2 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv2_1)
23 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2_2)
24 |
25 | conv3_1 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool2)
26 | conv3_2 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv3_1)
27 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3_2)
28 |
29 | conv4_1 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool3)
30 | conv4_2 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv4_1)
31 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4_2)
32 |
33 | conv5_1 = Conv2D(1024, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(pool4)
34 | conv5_2 = Conv2D(1024, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv5_1)
35 |
36 |
37 | up6 = Conv2D(512, 2, activation='relu', padding='valid', kernel_initializer='he_normal')(
38 | UpSampling2D(size=(2, 2))(conv5_2))
39 | height_dif = int((conv4_2.shape[1] - up6.shape[1])) / 2
40 | width_dif = int((conv4_2.shape[2] - up6.shape[2])) / 2
41 | crop_conv4_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv4_2)
42 | concat6 = Concatenate()([crop_conv4_2, up6])
43 | conv6_1 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat6)
44 | conv6_2 = Conv2D(512, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv6_1)
45 |
46 | up7 = Conv2D(256, 2, activation='relu', padding='valid', kernel_initializer='he_normal')(
47 | UpSampling2D(size=(2, 2))(conv6_2))
48 | height_dif = int((conv3_2.shape[1] - up7.shape[1])) / 2
49 | width_dif = int((conv3_2.shape[2] - up7.shape[2])) / 2
50 | crop_conv3_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv3_2)
51 | concat7 = Concatenate()([crop_conv3_2, up7])
52 | conv7_1 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat7)
53 | conv7_2 = Conv2D(256, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv7_1)
54 |
55 | up8 = Conv2D(128, 2, activation='relu', padding='valid', kernel_initializer='he_normal')(
56 | UpSampling2D(size=(2, 2))(conv7_2))
57 | height_dif = int((conv2_2.shape[1] - up8.shape[1])) / 2
58 | width_dif = int((conv2_2.shape[2] - up8.shape[2])) / 2
59 | crop_conv2_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv2_2)
60 | concat8 = Concatenate()([crop_conv2_2, up8])
61 | conv8_1 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat8)
62 | conv8_2 = Conv2D(128, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(conv8_1)
63 |
64 | up9 = Conv2D(64, 2, activation='relu', padding='valid', kernel_initializer='he_normal')(
65 | UpSampling2D(size=(2, 2))(conv8_2))
66 | height_dif = int((conv1_2.shape[1] - up9.shape[1])) / 2
67 | width_dif = int((conv1_2.shape[2] - up9.shape[2])) / 2
68 | crop_conv1_2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)), (math.ceil(width_dif), int(width_dif))))(conv1_2)
69 | concat9 = Concatenate()([crop_conv1_2, up9])
70 | conv9_1 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal')(concat9)
71 | conv9_2 = Conv2D(64, 3, activation='relu', padding='valid', kernel_initializer='he_normal', name='for_clust')(conv9_1)
72 |
73 | conv10 = Conv2D(num_classes, (1, 1), activation='softmax', name='clust_output')(conv9_2)
74 |
75 | # conv11 = Conv2D(2, (1, 1), activation='softmax', name='rgb_output')(conv9_2)
76 |
77 |
78 | model = Model(inputs=inputs, outputs=[conv10])
79 |
80 | return model
81 |
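if __name__ == "__main__":
    # Smoke test (illustrative shapes): with valid convolutions a 512x512x1
    # input yields a 309x309 per-pixel softmax over num_classes.
    model = get_model((512, 512, 1), num_classes=100)
    model.summary()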
--------------------------------------------------------------------------------
/models/VGG_UNet.py:
--------------------------------------------------------------------------------
1 | """Keras implementation of UNet model"""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import math
7 | from keras.models import Model
8 | from keras.layers import Conv2D, UpSampling2D, Concatenate, Cropping2D, BatchNormalization, Activation, MaxPooling2D
9 | from keras.applications.vgg16 import VGG16
10 | from keras import backend as K
11 |
12 | K.set_image_data_format('channels_last')
13 |
14 |
15 | def get_model(img_shape, num_classes=2):
16 | class_model = VGG16(weights='imagenet', include_top=False, input_shape=img_shape)
17 |
18 | class_model.get_layer('block1_conv1').trainable = False
19 | class_model.get_layer('block1_conv2').trainable = False
20 | class_model.get_layer('block2_conv1').trainable = False
21 | class_model.get_layer('block2_conv2').trainable = False
22 |
23 | class_model = Model(inputs=class_model.inputs, outputs=class_model.get_layer('block3_pool').output)
24 |
25 | skip1 = class_model.get_layer('block3_conv3').output
26 | skip2 = class_model.get_layer('block2_conv2').output
27 | skip3 = class_model.get_layer('block1_conv2').output
28 |
29 | conv4_1 = Conv2D(512, 3, padding='valid')(class_model.output)
30 | batch4_1 = BatchNormalization()(conv4_1)
31 | act4_1 = Activation('relu')(batch4_1)
32 | conv4_2 = Conv2D(512, 3, padding='valid')(act4_1)
33 | batch4_2 = BatchNormalization()(conv4_2)
34 | act4_2 = Activation('relu')(batch4_2)
35 |
36 | # up5 = UpSampling2D(size=(2, 2))(act4_2)
37 | up5 = Conv2D(256, 2, activation='relu', padding='valid')(
38 | UpSampling2D(size=(2, 2))(act4_2))
39 | height_dif = int((skip1.shape[1] - up5.shape[1])) / 2
40 | width_dif = int((skip1.shape[2] - up5.shape[2])) / 2
41 | crop_skip1 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)),
42 | (math.ceil(width_dif), int(width_dif))))(skip1)
43 | concat5 = Concatenate()([crop_skip1, up5])
44 | conv5_1 = Conv2D(256, 3, padding='valid')(concat5)
45 | batch5_1 = BatchNormalization()(conv5_1)
46 | act5_1 = Activation('relu')(batch5_1)
47 | conv5_2 = Conv2D(256, 3, padding='valid')(act5_1)
48 | batch5_2 = BatchNormalization()(conv5_2)
49 | act5_2 = Activation('relu')(batch5_2)
50 |
51 | # up6 = UpSampling2D(size=(2, 2))(act5_2)
52 | up6 = Conv2D(128, 2, activation='relu', padding='valid')(
53 | UpSampling2D(size=(2, 2))(act5_2))
54 | height_dif = int((skip2.shape[1] - up6.shape[1])) / 2
55 | width_dif = int((skip2.shape[2] - up6.shape[2])) / 2
56 | crop_skip2 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)),
57 | (math.ceil(width_dif), int(width_dif))))(skip2)
58 | concat6 = Concatenate()([crop_skip2, up6])
59 | conv6_1 = Conv2D(128, 3, padding='valid')(concat6)
60 | batch6_1 = BatchNormalization()(conv6_1)
61 | act6_1 = Activation('relu')(batch6_1)
62 | conv6_2 = Conv2D(128, 3, padding='valid')(act6_1)
63 | batch6_2 = BatchNormalization()(conv6_2)
64 | act6_2 = Activation('relu')(batch6_2)
65 |
66 | up7 = Conv2D(64, 2, activation='relu', padding='valid')(
67 | UpSampling2D(size=(2, 2))(act6_2))
68 | height_dif = int((skip3.shape[1] - up7.shape[1])) / 2
69 | width_dif = int((skip3.shape[2] - up7.shape[2])) / 2
70 | crop_skip3 = Cropping2D(cropping=((math.ceil(height_dif), int(height_dif)),
71 | (math.ceil(width_dif), int(width_dif))))(skip3)
72 | concat7 = Concatenate()([crop_skip3, up7])
73 | conv7_1 = Conv2D(64, 3, padding='valid')(concat7)
74 | batch7_1 = BatchNormalization()(conv7_1)
75 | act7_1 = Activation('relu')(batch7_1)
76 | conv7_2 = Conv2D(64, 3, padding='valid')(act7_1)
77 | batch7_2 = BatchNormalization()(conv7_2)
78 | act7_2 = Activation('relu', name='for_clust')(batch7_2)
79 |
80 | conv8 = Conv2D(num_classes, (1, 1), activation='softmax', name='clust_output')(act7_2)
81 |
82 | # conv9 = Conv2D(3, (1, 1), activation='sigmoid', name='rgb_output')(act7_2)
83 |
84 | model = Model(inputs=class_model.inputs, outputs=[conv8])
85 |
86 | return model
87 |
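if __name__ == "__main__":
    # Smoke test (illustrative; downloads ImageNet weights on first use).
    # VGG16 expects a 3-channel input.
    model = get_model((512, 512, 3), num_classes=100)
    model.summary()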
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/models/__init__.py
--------------------------------------------------------------------------------
/models/__pycache__/UNetValid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/models/__pycache__/UNetValid.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.4.1
2 | affine==2.2.1
3 | astor==0.7.1
4 | attrs==18.2.0
5 | certifi==2018.11.29
6 | chardet==3.0.4
7 | click==6.7
8 | click-plugins==1.0.4
9 | cligj==0.4.0
10 | cloudpickle==0.5.6
11 | cycler==0.10.0
12 | dask==0.19.2
13 | decorator==4.3.0
14 | faiss==1.4.0
15 | Fiona==1.7.13
16 | gast==0.2.0
17 | grpcio==1.15.0
18 | h5py==2.8.0
19 | idna==2.7
20 | imageio==2.4.1
21 | Keras==2.2.2
22 | Keras-Applications==1.0.6
23 | Keras-Preprocessing==1.0.5
24 | kiwisolver==1.0.1
25 | lxml==4.2.5
26 | Markdown==2.6.11
27 | matplotlib==2.2.3
28 | mkl-fft==1.0.6
29 | mkl-random==1.0.1
30 | munch==2.3.2
31 | networkx==2.2
32 | numpy==1.15.2
33 | opencv-python==4.0.0.21
34 | pandas==0.23.4
35 | Pillow==5.3.0
36 | progressbar2==3.38.0
37 | protobuf==3.6.1
38 | pydensecrf==1.0rc3
39 | PyGraph==0.2.1
40 | PyMaxflow==1.2.11
41 | pyparsing==2.2.1
42 | python-dateutil==2.7.3
43 | python-utils==2.3.0
44 | pytz==2018.5
45 | PyWavelets==1.0.0
46 | PyYAML==3.13
47 | rasterio==1.0.4
48 | requests==2.19.1
49 | scikit-image==0.14.0
50 | scikit-learn==0.19.2
51 | scipy==1.1.0
52 | Shapely==1.6.4.post2
53 | six==1.11.0
54 | sklearn==0.0
55 | snuggs==1.4.2
56 | tensorboard==1.10.0
57 | tensorflow==1.10.1
58 | tensorlayer==1.10.1
59 | termcolor==1.1.0
60 | toolz==0.9.0
61 | tqdm==4.25.0
62 | urllib3==1.23
63 | Werkzeug==0.14.1
64 | wrapt==1.10.11
65 |
--------------------------------------------------------------------------------
/super_train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import click
7 | import os
8 | from keras.preprocessing.image import ImageDataGenerator
9 | from keras.utils import to_categorical
10 | from keras.optimizers import Adam
11 | from keras.callbacks import ModelCheckpoint
12 | from keras import Model
13 |
14 | from models.ContextEncoder import get_model
15 | from tools.evaluation import final_prediction
16 |
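# Supervised pre-training script: fits the ContextEncoder with an MSE loss on
# (input image, target) pairs; crop_center blanks the central region of each
# input (fills it with 0.5) so the network must predict it from context.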
17 | def crop_center(image):
18 | crop_shape = (309, 309)
19 | crop_bef_h = (image.shape[0] - crop_shape[0]) // 2
20 | crop_bef_w = (image.shape[1] - crop_shape[1]) // 2
21 | out_img = image.copy()
22 | out_img[crop_bef_h:crop_bef_h + crop_shape[0], crop_bef_w:crop_bef_w + crop_shape[1]] = 0.5
23 | return out_img
24 |
25 | @click.command()
26 | @click.argument('train_img_folder', type=click.STRING)
27 | @click.argument('train_mask_folder', type=click.STRING)
28 | @click.option('--test_img_folder', type=click.STRING, default='')
29 | @click.option('--test_mask_folder', type=click.STRING, default='')
30 | @click.option('--in_img_shape', nargs=2, type=click.INT, default=(512, 512))
31 | @click.option('--out_img_shape', nargs=2, type=click.INT, default=(309, 309))
32 | @click.option('--num_classes', type=click.INT, default=800)
33 | @click.option('--num_epochs', type=click.INT, default=100)
34 | @click.option('--learn_rate', type=click.FLOAT, default=1e-3)
35 | @click.option('--batch_size', type=click.INT, default=4)
36 | @click.option('--current_epoch', type=click.INT, default=0)
37 | @click.option('--pretrained_weights', type=click.STRING, default='')
38 | @click.option('--out_weights_file', type=click.STRING, default='')
39 | @click.option('--out_pred_masks_test', type=click.STRING, default='')
40 | def main(train_img_folder, train_mask_folder, test_img_folder, test_mask_folder,
41 | in_img_shape, out_img_shape, num_classes, num_epochs, learn_rate, batch_size,
42 | current_epoch, pretrained_weights, out_weights_file, out_pred_masks_test):
43 |     assert os.path.isdir(train_img_folder)
44 |     if train_img_folder[-1] != '/':
45 |         train_img_folder += '/'
46 |     if test_img_folder != '':
47 |         assert os.path.isdir(test_img_folder)
48 |         test_img_folder = test_img_folder.rstrip('/') + '/'
49 |
50 | train_image_datagen = ImageDataGenerator(
51 | # featurewise_center=True,
52 | # featurewise_std_normalization=True,
53 | rescale=1./255,
54 | data_format='channels_last',
55 | preprocessing_function=crop_center
56 | )
57 |
58 | train_mask_datagen = ImageDataGenerator(
59 | rescale=1. / 255,
60 | # featurewise_center=True,
61 | # featurewise_std_normalization=True,
62 | )
63 |
64 | # (512, 512) lux
65 | # train_image_datagen.mean = np.array([11.4296465, 13.140564, 12.277675], dtype=np.float32).reshape(1, 1, 3)
66 | # train_image_datagen.std = np.array([27.561337, 29.903225, 29.525864], dtype=np.float32).reshape(1, 1, 3)
67 |
68 | SEED = 1
69 | train_image_gen = train_image_datagen.flow_from_directory(train_img_folder,
70 | target_size=in_img_shape,
71 | batch_size=batch_size,
72 | class_mode=None,
73 | shuffle=True,
74 | seed=SEED)
75 |
76 | train_mask_gen = train_mask_datagen.flow_from_directory(train_mask_folder,
77 | target_size=out_img_shape,
78 | batch_size=batch_size,
79 | # color_mode='grayscale',
80 | class_mode=None,
81 | shuffle=True,
82 | seed=SEED)
83 |
84 | test_image_datagen = ImageDataGenerator(
85 | # featurewise_center=True,
86 | # featurewise_std_normalization=True,
87 | rescale=1. / 255,
88 | data_format='channels_last',
89 | preprocessing_function=crop_center
90 | )
91 | test_image_gen = test_image_datagen.flow_from_directory(test_img_folder,
92 | target_size=in_img_shape,
93 | batch_size=batch_size,
94 | class_mode=None,
95 | shuffle=False)
96 |
97 | # test_image_datagen.mean = train_image_datagen.mean
98 | # test_image_datagen.std = train_image_datagen.mean
99 |
100 | model = get_model(train_image_gen.image_shape, num_classes=num_classes)
101 | model.compile(optimizer=Adam(lr=(learn_rate)), loss='mse')
102 |
103 | if pretrained_weights != '':
104 | # for layer in model.layers[-2:]:
105 | # layer.name += '_lux'
106 | model.load_weights(pretrained_weights, by_name=True)
107 |
108 | callback_list = []
109 | if out_weights_file != '':
110 | model_checkpoint = ModelCheckpoint(out_weights_file, period=1, save_weights_only=True)
111 | callback_list.append(model_checkpoint)
112 |
113 | while current_epoch <= num_epochs:
114 | print('\n############# Epoch %i / %i #############' % ((current_epoch), num_epochs))
115 | steps_train_epoch = train_image_gen.n // batch_size
116 | for i in range(steps_train_epoch):
117 | train_img_batch = train_image_gen[i]
118 | train_mask_batch = train_mask_gen[i]
119 | # train_mask_batch = to_categorical(train_mask_gen[i], num_classes)
120 |
121 | model.fit(train_img_batch, train_mask_batch, batch_size=8, epochs=1,
122 | callbacks=callback_list, verbose=1)
123 |
124 |
125 | if out_pred_masks_test != '':
126 | out_pred_epoch_fold = os.path.join(out_pred_masks_test, str(current_epoch))
127 | final_prediction(out_pred_epoch_fold, test_image_gen, model)
128 | model.save_weights(os.path.join(out_pred_epoch_fold, 'model.h5'))
129 |
130 | current_epoch += 1
131 |
132 |
133 | if __name__ == '__main__':
134 | main()
135 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/tools/__init__.py
--------------------------------------------------------------------------------
/tools/clusters2classes.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import click
6 | import skimage.io
7 | import numpy as np
8 | import random
9 | import os
10 |
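# Turns unsupervised cluster maps into a binary class mask: using ground truth
# for a gt_fraction of the training patches, a cluster is kept as "foreground"
# if at least keep_clust_threshold of its pixels fall on ground-truth pixels;
# the kept clusters are then binarized in every test cluster map.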
11 |
12 | def gen_classes_from_clusters(in_clust_test, in_clust_train, in_gt_train, gt_fraction, keep_clust_threshold,
13 | out_fold, per_cluster_out_fold=''):
14 | assert os.path.isdir(in_clust_test)
15 | if in_clust_test[-1] != '/':
16 | in_clust_test += '/'
17 | assert os.path.isdir(in_clust_train)
18 | if in_clust_train[-1] != '/':
19 | in_clust_train += '/'
20 | assert os.path.isdir(in_gt_train)
21 | if in_gt_train[-1] != '/':
22 | in_gt_train += '/'
23 | assert type(gt_fraction) is float
24 | assert gt_fraction > 0 and gt_fraction <= 1
25 | assert type(keep_clust_threshold) is float
26 | assert keep_clust_threshold >= 0 and keep_clust_threshold <= 1
27 | if out_fold != '':
28 | if not os.path.isdir(out_fold):
29 | os.mkdir(out_fold)
30 | if per_cluster_out_fold != '':
31 | if not os.path.isdir(per_cluster_out_fold):
32 | os.mkdir(per_cluster_out_fold)
33 |
34 | all_clust_test_files = []
35 | for dirpath, dirnames, filenames in os.walk(in_clust_test):
36 | all_clust_test_files.extend([os.path.join(dirpath, f)
37 | for f in filenames if f.endswith('.png')])
38 |
39 | all_clust_train_files = []
40 | for dirpath, dirnames, filenames in os.walk(in_clust_train):
41 | all_clust_train_files.extend([os.path.join(os.path.relpath(dirpath, in_clust_train), f)
42 | for f in filenames if f.endswith('.png')])
43 |
44 | all_gt_train_files = []
45 | for dirpath, dirnames, filenames in os.walk(in_gt_train):
46 | all_gt_train_files.extend([os.path.join(os.path.relpath(dirpath, in_gt_train), f)
47 | for f in filenames if f.endswith('.png')])
48 |
49 |
50 | assert len(all_gt_train_files) * gt_fraction > 0
51 | assert len(all_clust_train_files) * gt_fraction > 0
52 |
53 | SEED = 1
54 | random.seed(SEED)
55 | random.shuffle(all_gt_train_files)
56 | num_gt = int(gt_fraction*len(all_gt_train_files))
57 | sub_gt_files = all_gt_train_files[:num_gt]
58 | assert np.sum([gt_file in all_clust_train_files for gt_file in sub_gt_files]) == len(sub_gt_files)
59 |
60 | sub_gt_train_full_path = [os.path.join(in_gt_train, file) for file in sub_gt_files]
61 | sub_clust_train_full_path = [os.path.join(in_clust_train, file) for file in sub_gt_files]
62 |
63 | gt_train_masks = np.asarray(skimage.io.imread_collection(sub_gt_train_full_path))
64 | clust_train_masks = np.asarray(skimage.io.imread_collection(sub_clust_train_full_path))
65 |
66 | selected_clusters = []
67 | for clust in np.unique(clust_train_masks):
68 | clust_pixel_num = (clust_train_masks == clust).sum()
69 | gt_pixel_num = (gt_train_masks[clust_train_masks == clust] == 1).sum()
70 | if gt_pixel_num / clust_pixel_num >= keep_clust_threshold:
71 | selected_clusters.append(clust)
72 |
73 | result_masks = []
74 | for clust_test_file in all_clust_test_files:
75 | clust_image = skimage.io.imread(clust_test_file)
76 | result_mask = np.zeros(clust_image.shape)
77 | for c in selected_clusters:
78 | result_mask[clust_image == c] = 1
79 | result_masks.append(result_mask)
80 |
81 | if out_fold != '':
82 | rel_path = os.path.relpath(clust_test_file, in_clust_test)
83 | full_out_path = os.path.join(out_fold, rel_path)
84 | if not os.path.isdir(os.path.dirname(full_out_path)):
85 | os.makedirs(os.path.dirname(full_out_path))
86 | skimage.io.imsave(full_out_path, result_mask.astype('uint8'))
87 |
88 | return result_masks
89 |
90 |
91 | @click.command()
92 | @click.argument('in_clust_test', type=click.STRING)
93 | @click.argument('in_clust_train', type=click.STRING)
94 | @click.argument('in_gt_train', type=click.STRING)
95 | @click.option('--gt_fraction', type=click.FLOAT, default=1.0)
96 | @click.option('--keep_clust_threshold', type=click.FLOAT, default=0)
97 | @click.option('--out_fold', type=click.STRING, default='')
98 | @click.option('--per_cluster_out_fold', type=click.STRING, default='')
99 | def main(in_clust_test, in_clust_train, in_gt_train, gt_fraction,
100 | keep_clust_threshold, out_fold, per_cluster_out_fold):
101 | gen_classes_from_clusters(in_clust_test, in_clust_train, in_gt_train, gt_fraction=gt_fraction,
102 | keep_clust_threshold=keep_clust_threshold, out_fold=out_fold,
103 | per_cluster_out_fold=per_cluster_out_fold)
104 |
105 |
106 | if __name__ == '__main__':
107 | main()
108 |
--------------------------------------------------------------------------------
/tools/evaluation.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import click
6 | import os
7 | import scipy.spatial
8 | import sklearn.metrics
9 | import skimage.io
10 | import json
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import pandas as pd
14 |
15 | from tools.clusters2classes import gen_classes_from_clusters
16 | from models.UNetValid import get_model
17 |
18 |
19 | def final_prediction(out_fold, image_generator, model):
20 | num_iter = int(np.ceil(image_generator.n / float(image_generator.batch_size)))
21 | batch_size = image_generator.batch_size
22 |
23 | test_files = image_generator.filenames
24 |
25 |     print('Final Prediction:')
26 | for i in range(num_iter):
27 | print('%i:%i/%i' % (i * batch_size, (i + 1) * batch_size, image_generator.n))
28 |
29 | if i < num_iter - 1:
30 | batch_files = test_files[i * batch_size: (i + 1) * batch_size]
31 | else:
32 | batch_files = test_files[i * batch_size:]
33 |
34 | batch_imgs = image_generator[i]
35 |
36 | pred_mask_probs = model.predict(batch_imgs, batch_size=1, verbose=1)
37 | for file, pred_mask_prob in zip(batch_files, pred_mask_probs):
38 | file_dir = os.path.dirname(file)
39 | file_name = file.split('/')[-1]
40 | full_out_dir = os.path.join(out_fold, file_dir)
41 | if not os.path.isdir(full_out_dir):
42 | os.makedirs(full_out_dir)
43 | mask = pred_mask_prob.argmax(axis=-1).astype('uint16')
44 | skimage.io.imsave(os.path.join(full_out_dir, file_name), mask)
45 |
46 |
47 | def evaluate(input_gt, input_pred_mask, output_file=''):
48 | assert os.path.isdir(input_gt)
49 | if input_gt[-1] != '/':
50 | input_gt += '/'
51 | assert os.path.isdir(input_pred_mask)
52 | if input_pred_mask[-1] != '/':
53 | input_pred_mask += '/'
54 |
55 | result_dict = {}
56 | all_gt_files = []
57 | for dirpath, dirnames, filenames in os.walk(input_gt):
58 | all_gt_files.extend([os.path.join(os.path.relpath(dirpath, input_gt), f)
59 | for f in filenames if f.endswith('.png')])
60 |
61 | all_pred_files_full_path = [os.path.join(input_pred_mask, file) for file in all_gt_files]
62 | all_gt_files_full_path = [os.path.join(input_gt, file) for file in all_gt_files]
63 |
64 | all_gt = np.asarray(skimage.io.imread_collection(all_gt_files_full_path))
65 | all_pred_mask = np.asarray(skimage.io.imread_collection(all_pred_files_full_path))
66 |
67 | result_dict['all_dice'] = 1 - scipy.spatial.distance.dice(all_gt.flatten(), all_pred_mask.flatten())
68 | result_dict['all_accuracy'] = sklearn.metrics.accuracy_score(all_gt.flatten(), all_pred_mask.flatten())
69 |
70 | if output_file != '':
71 | with open(output_file, 'w') as fp:
72 | json.dump(result_dict, fp, indent=2, sort_keys=True)
73 |
74 | return result_dict
75 |
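# Worked example for evaluate(): with gt = [0, 1, 1, 0] and pred = [0, 1, 0, 0],
# Dice = 1 - dice_distance = 2*1 / (2 + 1) ~ 0.667 and accuracy = 3/4 = 0.75.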
76 |
77 | def evaluate_per_epoch():
78 | test = 'test'
79 | in_result_dict = ''
80 | out_figure = 'data/patches_512_lux/figure_%s.png' % test
81 |
82 | if in_result_dict == '':
83 | # in_gt_train_fold = 'data/patches_512_lux/gt_masks/train_masks'
84 | in_gt_train_fold = 'data/patches_512_lux/gt_masks/test_masks'
85 | in_gt_test_fold = 'data/patches_512_lux/gt_masks/%s_masks' % test
86 | in_pred_clust_fold = 'data/patches_512_lux/pred_masks'
87 | out_fold_final_preds = 'data/patches_512_lux/final_preds'
88 | out_result_file = 'data/patches_512_lux/result_dict.json'
89 |
90 | gt_fraction = 0.6
91 | keep_clust_threshold = 0.4
92 |
93 | epochs = list(range(20))
94 |
95 | result = {}
96 | for epoch_dir in sorted(os.listdir(in_pred_clust_fold), key=int):
97 | if not int(epoch_dir) in epochs:
98 | continue
99 | epoch_dirpath = os.path.join(in_pred_clust_fold, epoch_dir)
100 |         epoch_test_dirpath = epoch_dirpath
101 |         epoch_train_dirpath = epoch_dirpath
102 | # epoch_test_dirpath = os.path.join(epoch_dirpath, test)
103 | # epoch_train_dirpath = os.path.join(epoch_dirpath, 'train')
104 |
105 | epoch_out_fold = os.path.join(out_fold_final_preds, epoch_dir)
106 | if not os.path.isdir(epoch_out_fold):
107 | os.makedirs(epoch_out_fold)
108 |
109 | if not set(os.listdir(epoch_test_dirpath)) < set(os.listdir(epoch_out_fold)):
110 | gen_classes_from_clusters(in_clust_test=epoch_test_dirpath,
111 | in_clust_train=epoch_train_dirpath,
112 | in_gt_train=in_gt_train_fold,
113 | gt_fraction=gt_fraction,
114 | keep_clust_threshold=keep_clust_threshold,
115 | out_fold=epoch_out_fold)
116 |
117 | result[epoch_dir] = evaluate(input_gt=in_gt_test_fold, input_pred_mask=epoch_out_fold)
118 |
119 | with open(out_result_file, 'w') as fp:
120 | json.dump(result, fp, indent=2, sort_keys=True)
121 |
122 | else:
123 | with open(in_result_dict, 'r') as json_data:
124 | result = json.load(json_data)
125 |
126 | df = pd.DataFrame({'epochs': list(result.keys()), 'dice': [item['all_dice'] for item in result.values()],
127 | 'accuracy': [item['all_accuracy'] for item in result.values()]})
128 |
129 | fig, ax = plt.subplots(nrows=1, ncols=1)
130 | ax.plot('epochs', 'dice', data=df, marker='o', markerfacecolor='black', color='blue', linewidth=4)
131 | # ax.plot('epochs', 'accuracy', data=df, marker='o', markerfacecolor='black', color='green', linewidth=4)
132 | ax.set_xlabel('Epochs')
133 | ax.set_ylabel('Dice')
134 | fig.savefig(out_figure)
135 |
136 |
137 | @click.command()
138 | @click.option('--image_path', type=click.STRING)
139 | @click.option('--pred_mask_path', type=click.STRING, default='')
140 | @click.option('--model_weights', type=click.STRING, default='')
141 | @click.option('--num_clust', type=click.INT, default=2)
142 | @click.option('--gt_mask_path', type=click.STRING, default='')
143 | @click.option('--out_pred_mask', type=click.STRING, default='')
144 | def main(image_path, pred_mask_path, model_weights, num_clust, gt_mask_path, out_pred_mask):
145 | if pred_mask_path == '':
146 | image = skimage.io.imread(image_path).astype('float32')
147 | mean = np.array([11.4296465, 13.140564, 12.277675], dtype=np.float32).reshape(1, 1, 3)
148 | std = np.array([27.561337, 29.903225, 29.525864], dtype=np.float32).reshape(1, 1, 3)
149 | image -= mean
150 | image /= std
151 | model = get_model(image.shape, num_clust)
152 | model.load_weights(model_weights)
153 |
154 | pred_mask = model.predict(np.expand_dims(image, axis=0)).argmax(axis=-1).astype('uint16')[0]
155 | else:
156 | pred_mask = skimage.io.imread(pred_mask_path)
157 |
158 | gt_mask = skimage.io.imread(gt_mask_path)
159 |
160 | dice = 1 - scipy.spatial.distance.dice(gt_mask.flatten(), pred_mask.flatten())
161 | accuracy = sklearn.metrics.accuracy_score(gt_mask.flatten(), pred_mask.flatten())
162 |
163 |     overlap = gt_mask.flatten() * pred_mask.flatten()  # intersection (binary {0,1} masks assumed)
164 |     union = gt_mask.flatten() + pred_mask.flatten()     # > 0 where either mask is set
165 |
166 | IOU = overlap.sum()/float((union>0).sum())
167 |
168 | print('Dice Score: %f' % dice)
169 | print('Accuracy: %f' % accuracy)
170 | print('IoU: %f' % IOU)
171 |
172 | if out_pred_mask != '' and pred_mask_path == '':
173 | skimage.io.imsave(out_pred_mask, pred_mask)
174 |
175 |
176 | if __name__ == '__main__':
177 | main()
178 |
--------------------------------------------------------------------------------
/tools/full_prediction.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import click
6 | import os
7 | import skimage.io
8 |
9 | from osgeo import gdal
10 |
11 |
12 | def array_to_tiff_file(image, output_file, geoprojection=None, geotransform=None):
13 | """Gets images in a form of numpy arrays and saves them as images with spacial information
14 | Args:
15 | image: numpy array. input image of shape (num_channels, w, h) or (w, h) in case of 1 channel
16 | output_file: string. path to the output file
17 | geoprojection: driver_in.GetProjection()
18 | geotransform: driver_in.GetGeoTransform()
19 | Returns:
20 | -
21 | """
22 | driver = gdal.GetDriverByName("GTiff")
23 |
24 | if image.ndim == 3:
25 | n_channels, h, w = image.shape
26 | elif image.ndim == 2:
27 | h, w = image.shape
28 | n_channels = 1
29 | ds_out = driver.Create(output_file, w, h, n_channels, gdal.GDT_Float32)
30 | if geoprojection and geotransform:
31 | ds_out.SetProjection(geoprojection)
32 | ds_out.SetGeoTransform(geotransform)
33 |
34 | # write channels separately
35 | for i in range(n_channels):
36 | band = ds_out.GetRasterBand(i + 1) # 1-based index
37 | if n_channels > 1:
38 | channel = image[i]
39 | else:
40 | channel = image
41 | band.WriteArray(channel)
42 | band.FlushCache()
43 |
44 | del ds_out
45 |
46 |
47 | def get_prediction(in_full_img_fold, in_pred_mask_fold, crop_shape=(512, 512), out_fold=''):
48 | """Generates masks of a crop_shape for images in the input_folder
49 | Args:
50 | input_folder: string. path to a folder with images on which model will be applied
51 | output_folder: string. path to the output folder with predicted masks
52 | Returns:
53 | -
54 | """
55 | assert os.path.isdir(in_full_img_fold)
56 | if in_full_img_fold[-1] != '/':
57 | in_full_img_fold += '/'
58 | if out_fold != '':
59 | if not os.path.isdir(out_fold):
60 | os.mkdir(out_fold)
61 | if out_fold[-1] != '/':
62 | out_fold += '/'
63 |
64 | img_files = os.listdir(in_full_img_fold)
65 | for img_file in img_files:
66 | dataset = gdal.Open(os.path.join(in_full_img_fold, img_file))
67 |
68 | geotransform = dataset.GetGeoTransform()
69 | geoprojection = dataset.GetProjection()
70 |
71 | x_origin = geotransform[0]
72 | y_origin = geotransform[3]
73 | pixel_width = geotransform[1]
74 | pixel_height = geotransform[5]
75 |
76 | image = dataset.ReadAsArray()
77 | image = image.transpose((1, 2, 0))
78 |
79 | mask_files = os.listdir(in_pred_mask_fold)
80 | mask_files.sort()  # sorted names must align with the crop enumeration below
81 |
82 | count = 0
83 | for row in range(0, image.shape[0] - crop_shape[0], crop_shape[0]):
84 | for col in range(0, image.shape[1] - crop_shape[1], crop_shape[1]):
85 | crop_img = image[row:row + crop_shape[0], col:col + crop_shape[1], :].astype('float32')
86 | if (crop_img == 0).sum() / crop_img.size > 0.1:  # skip crops that are more than 10% black (nodata)
87 | continue
88 | crop_geotransform = (x_origin + col * pixel_width, pixel_width, geotransform[2],
89 | y_origin + row * pixel_height, geotransform[4], pixel_height)
90 | full_path_mask = os.path.join(in_pred_mask_fold, mask_files[count])
91 | pred_mask = skimage.io.imread(full_path_mask)
92 |
93 | if out_fold != '':
94 | if not os.path.isdir(os.path.join(out_fold, img_file.split('.')[0])):
95 | os.mkdir(os.path.join(out_fold, img_file.split('.')[0]))
96 | array_to_tiff_file(pred_mask,
97 | os.path.join(out_fold, img_file.split('.')[0],
98 | 'pred_mask_' + str(count).zfill(6) + '.tif'),
99 | geoprojection=geoprojection,
100 | geotransform=crop_geotransform)
101 |
102 | count += 1
103 |
104 |
105 | @click.command()
106 | @click.argument('in_full_img_fold', type=click.STRING)
107 | @click.argument('in_pred_mask_fold', type=click.STRING)
108 | @click.option('--crop_shape', nargs=2, type=click.INT, default=(512, 512))
109 | @click.option('--out_fold', type=click.STRING, default='')
110 | def main(in_full_img_fold, in_pred_mask_fold, crop_shape, out_fold):
111 | get_prediction(in_full_img_fold, in_pred_mask_fold, crop_shape=crop_shape, out_fold=out_fold)
112 |
113 |
114 | if __name__ == '__main__':
115 | main()
116 |
--------------------------------------------------------------------------------
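Note: the crop_geotransform arithmetic in get_prediction above follows the GDAL convention (x_origin, pixel_width, x_rotation, y_origin, y_rotation, pixel_height): moving the window by whole pixels shifts only the two origin terms. A standalone sketch with illustrative numbers:

```python
# GDAL geotransform: (x_origin, pixel_width, x_rot, y_origin, y_rot, pixel_height)
geotransform = (481649.0, 0.5, 0.0, 3623101.0, 0.0, -0.5)  # illustrative values
row, col = 512, 1024  # top-left pixel of the crop within the full image

x_origin, pixel_width = geotransform[0], geotransform[1]
y_origin, pixel_height = geotransform[3], geotransform[5]

# same shift as in get_prediction: move the origin by col/row whole pixels
crop_geotransform = (x_origin + col * pixel_width, pixel_width, geotransform[2],
                     y_origin + row * pixel_height, geotransform[4], pixel_height)
print(crop_geotransform)  # (482161.0, 0.5, 0.0, 3622845.0, 0.0, -0.5)
```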
/tools/patch_generator.py:
--------------------------------------------------------------------------------
1 | """Script for generating patches"""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import os
7 | import click
8 | import skimage.io
9 | import skimage.util
10 | import numpy as np
11 |
12 | # mean = np.array([11.4296465, 13.140564, 12.277675], dtype=np.float32).reshape(1, 1, 3)
13 | # std = np.array([27.561337, 29.903225, 29.525864], dtype=np.float32).reshape(1, 1, 3)
14 |
15 | def get_patches(input_fold, output_folder='', output_np_folder='', masks_folder='',
16 | patch_shape=(512, 512), crop_shape=None, stride=(100, 100), keep_dark_region_ratio=0.01):
17 | """
18 | Generates cropped patches from the full-size input images
19 | Args:
20 | input_fold: string. path to the raw train data folder
21 | output_folder: string. path to the output folder for the generated patches
22 | output_np_folder: string. path to the output folder for the stacked numpy arrays
23 | masks_folder: string. if not '', also generates a mask patch corresponding to each image patch
24 | patch_shape: tuple of ints. (height, width), specifies patch shape
25 | crop_shape: tuple of ints. (height, width), specifies the central area cropped from each patch
26 | stride: tuple of ints. stride of the sliding window in the height and width directions
27 | keep_dark_region_ratio: float. maximum fraction of black pixels allowed for a patch to be kept
28 | Returns:
29 | lists of numpy arrays. image and mask patches of the last processed image
30 | """
31 |
32 | assert os.path.isdir(input_fold)
33 | if input_fold[-1] != '/':
34 | input_fold += '/'
35 | if output_folder != '':
36 | if not os.path.isdir(output_folder):
37 | os.mkdir(output_folder)
38 | if output_folder[-1] != '/':
39 | output_folder += '/'
40 | if output_np_folder != '':
41 | assert os.path.isdir(output_np_folder)
42 | if output_np_folder[-1] != '/':
43 | output_np_folder += '/'
44 | if masks_folder != '':
45 | assert os.path.isdir(masks_folder)
46 | if masks_folder[-1] != '/':
47 | masks_folder += '/'
48 | assert type(patch_shape) is tuple
49 | assert type(patch_shape[0]) is int and type(patch_shape[1]) is int
50 | assert patch_shape[0] > 0 and patch_shape[1] > 0
51 |
52 | if crop_shape is not None:
53 | crop_bef_h = (patch_shape[0] - crop_shape[0]) // 2
54 | crop_bef_w = (patch_shape[1] - crop_shape[1]) // 2
55 |
56 | assert type(stride) is tuple
57 | assert type(stride[0]) is int and type(stride[1]) is int
58 | assert stride[0] > 0 and stride[1] > 0
59 |
60 | img_count = 0
61 | for subdir, dirs, files in os.walk(input_fold):
62 | for file in files:
63 | if file.split('.')[-1] not in ['tiff', 'tif', 'png', 'jpg', 'JPEG', 'jpeg']:
64 | continue
65 |
66 | img_patches = []
67 | mask_patches = []
68 |
69 | img_file = os.path.join(subdir, file)
70 | full_img = skimage.io.imread(img_file)
71 |
72 | if masks_folder != '':
73 | mask_file = os.path.join(masks_folder, file.split('.')[0] + '_mask.' + file.split('.')[-1])
74 | # mask_file = os.path.join(masks_folder, file.split('.')[0] + '.' + file.split('.')[-1])
75 | assert os.path.isfile(mask_file)
76 | full_mask = skimage.io.imread(mask_file)
77 | # assert full_img.shape[0] == full_mask.shape[0] and full_img.shape[1] == full_mask.shape[1]
78 |
79 | for row in range(0, full_img.shape[0] - patch_shape[0], stride[0]):  # remainders at the bottom/right edges are dropped
80 | for col in range(0, full_img.shape[1] - patch_shape[1], stride[1]):
81 | crop_img = full_img[row:row + patch_shape[0], col:col + patch_shape[1]]
82 | # skip patches where the fraction of black pixels exceeds keep_dark_region_ratio
83 | if (crop_img == 0).sum() / crop_img.size > keep_dark_region_ratio:
84 | continue
85 |
86 | img_patches.append(crop_img)
87 |
88 | if masks_folder != '':
89 | crop_mask = full_mask[row:row + patch_shape[0], col:col + patch_shape[1]]
90 | mask_patches.append(crop_mask)
91 |
92 | if output_folder != '':
93 | full_path = os.path.join(output_folder, file.split('.')[0])
94 | if not os.path.isdir(full_path):
95 | os.mkdir(full_path)
96 | for ind, img_patch in enumerate(img_patches):
97 | skimage.io.imsave(os.path.join(full_path, str(img_count + ind).zfill(6) + '.png'), img_patch)
98 | if crop_shape is not None:
99 | crop_patch_out = os.path.join(output_folder, 'crops', file.split('.')[0])
100 | if not os.path.isdir(crop_patch_out):
101 | os.makedirs(crop_patch_out)
102 | crop_patch = img_patch[crop_bef_h:crop_bef_h + crop_shape[0],
103 | crop_bef_w:crop_bef_w + crop_shape[1]]
104 | skimage.io.imsave(os.path.join(crop_patch_out, str(img_count + ind).zfill(6) + '.png'),
105 | crop_patch)
106 |
107 | if masks_folder != '':
108 | mask_output_folder = os.path.join(output_folder, 'masks', file.split('.')[0])
109 | if not os.path.isdir(mask_output_folder):
110 | os.makedirs(mask_output_folder)
111 | for ind, mask_patch in enumerate(mask_patches):
112 | if crop_shape is not None:
113 | mask_patch = mask_patch[crop_bef_h:crop_bef_h + crop_shape[0],
114 | crop_bef_w:crop_bef_w + crop_shape[1]]
115 | skimage.io.imsave(os.path.join(mask_output_folder, str(img_count + ind).zfill(6) + '.png'),
116 | mask_patch.astype('uint8'))
117 |
118 | img_count += len(img_patches)
119 |
120 | if output_np_folder != '':
121 | full_path = os.path.join(output_np_folder, file.split('.')[0])
122 | if not os.path.isdir(full_path):
123 | os.mkdir(full_path)
124 | img_patches_np = np.asarray(img_patches, dtype='float32')
125 | np.save(os.path.join(full_path, 'img_patches'), img_patches_np)
126 | mask_patches_np = np.asarray(mask_patches)
127 | np.save(os.path.join(full_path, 'mask_patches'), mask_patches_np)
128 |
129 | return img_patches, mask_patches
130 |
131 |
132 | @click.command()
133 | @click.argument('input_folder', type=click.STRING)
134 | @click.option('--output_folder', type=click.STRING, default='')
135 | @click.option('--output_np_folder', type=click.STRING, default='')
136 | @click.option('--masks_folder', type=click.STRING, default='')
137 | @click.option('--patch_shape', nargs=2, type=click.INT, default=(512, 512), help='Height and width of a crop')
138 | @click.option('--crop_shape', nargs=2, type=click.INT, default=(309, 309))
139 | @click.option('--stride', nargs=2, type=click.INT, default=(100, 100), help='Stride in the height and width directions')
140 | @click.option('--keep_dark_region_ratio', type=click.FLOAT, default=0.01)
141 | def main(input_folder, output_folder, output_np_folder, masks_folder, patch_shape, crop_shape,
142 | stride, keep_dark_region_ratio):
143 | get_patches(input_folder, output_folder=output_folder, output_np_folder=output_np_folder,
144 | masks_folder=masks_folder, patch_shape=patch_shape, crop_shape=crop_shape, stride=stride,
145 | keep_dark_region_ratio=keep_dark_region_ratio)
146 |
147 |
148 | if __name__ == '__main__':
149 | main()
150 |
--------------------------------------------------------------------------------
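Note: the nested range loops in get_patches enumerate top-left corners with a fixed stride and drop the remainder at the bottom/right edges, so an image yields roughly ceil((H - patch_h) / stride_h) * ceil((W - patch_w) / stride_w) patches before the dark-region filter. A quick count with illustrative sizes:

```python
# count the patches produced by:
#   for row in range(0, H - ph, sh):
#       for col in range(0, W - pw, sw): ...
H, W = 1000, 1000   # illustrative full-image size
ph, pw = 512, 512   # patch_shape
sh, sw = 100, 100   # stride

rows = len(range(0, H - ph, sh))  # offsets 0, 100, 200, 300, 400 -> 5
cols = len(range(0, W - pw, sw))  # 5
print(rows * cols)                # 25 patches before the dark-region filter
```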
/tools/rasterization.py:
--------------------------------------------------------------------------------
1 | """Script for generating rasterized masks from shapefiles"""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import rasterio
7 | from rasterio import features
8 | import geopandas as gpd
9 | import click
10 | import os
11 |
12 |
13 | def get_rasterization(input_shp, meta_data=None, ref_tif='', output_tif=''):
14 | """
15 | Get rasterized image from shapefile.
16 | Args:
17 | input_shp: string. path to the shapefile (.shp).
18 | meta_data: dictionary. meta_data passed to rasterio.open() function
19 | in case reference image is not specified.
20 | ref_tif: string. path to the reference .tif image.
21 | When specified meta_data is ignored.
22 | output_tif: string. path to the output rasterized image.
23 |
24 | Returns:
25 | numpy array. rasterized image.
26 | """
27 |
28 | # Checking the correctness of the arguments
29 | assert os.path.isfile(input_shp), 'no file: input_shp'
30 | assert ref_tif != '' or type(meta_data) is dict, 'meta_data dict is required when ref_tif is not given'
31 | if ref_tif != '':
32 | assert os.path.isfile(ref_tif), 'no file: ref_tif'
33 |
34 | shape_df = gpd.read_file(input_shp)
35 |
36 | if ref_tif != '':
37 | rst = rasterio.open(ref_tif)
38 | meta_data = rst.meta.copy()
39 | meta_data['count'] = 1
40 |
41 | # create a generator of (geometry, value) pairs to use in rasterizing
42 | shapes = ((geom, 1) for geom in shape_df.geometry)
43 | result = features.rasterize(shapes=shapes, out_shape=(meta_data['height'], meta_data['width']), fill=0, transform=meta_data['transform'])
44 | if output_tif:
45 | with rasterio.open(output_tif, 'w', **meta_data) as out:
46 | out.write_band(1, result)
47 |
48 | return result
49 |
50 |
51 | @click.command()
52 | @click.argument('input_shp', type=click.STRING)
53 | @click.option('--ref_tif', type=click.STRING, default='', help='path to the reference .tif image')
54 | @click.option('--output_tif', type=click.STRING, default='', help='path to the output rasterized image.')
55 | def main(input_shp, ref_tif, output_tif):
56 | meta_data = {  # dataset-specific fallback, ignored whenever the reference image exists
57 | 'driver': 'GTiff',
58 | 'dtype': 'uint8',
59 | 'nodata': None,
60 | 'width': 13500,
61 | 'height': 15300,
62 | 'count': 1,
63 | 'crs': {'init': 'epsg:26911'},
64 | 'transform': (0.5, 0.0, 481649.0, 0.0, -0.5, 3623101.0)}
65 |
66 | files = os.listdir(input_shp)
67 | files = [file for file in files if file.endswith('.shp')]
68 | for file in files:
69 | input_shp_path = os.path.join(input_shp, file)
70 | ref_tif_path = os.path.join(ref_tif, file.split('.')[0] + '.tif')
71 | output_tif_path = os.path.join(output_tif, file.split('.')[0] + '.tif')
72 | get_rasterization(input_shp=input_shp_path, meta_data=meta_data, ref_tif=ref_tif_path, output_tif=output_tif_path)
73 |
74 |
75 | if __name__ == '__main__':
76 | main()
77 |
--------------------------------------------------------------------------------
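Note: a minimal usage sketch for get_rasterization; the import path follows the repo layout and the data paths are hypothetical, not part of the repository:

```python
from tools.rasterization import get_rasterization

# burn all polygons of a shapefile into a single-band mask aligned with a reference image;
# 'data/buildings.shp' and 'data/area.tif' are illustrative paths
mask = get_rasterization(input_shp='data/buildings.shp',
                         ref_tif='data/area.tif',           # CRS/transform are copied from here
                         output_tif='data/buildings_mask.tif')
print(mask.shape, mask.max())  # numpy array: 1 inside polygons, 0 elsewhere
```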
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/clustering.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhygallo/clusternet_segmentation/7bea23aa04d0c60e125dc40230f21449cded5ae7/utils/__pycache__/clustering.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/clustering.py:
--------------------------------------------------------------------------------
1 | # Script based on https://github.com/facebookresearch/deepcluster/blob/master/clustering.py
2 | # Facebook, Inc.
3 |
4 | import time
5 |
6 | import faiss
7 | from PIL import ImageFile
8 |
9 |
10 | ImageFile.LOAD_TRUNCATED_IMAGES = True
11 |
12 | __all__ = ['Kmeans']
13 |
14 |
15 | def run_kmeans(x, nmb_clusters, verbose=False, init_cents=None):
16 | """Runs kmeans on 1 GPU.
17 | Args:
18 | x: data
19 | nmb_clusters (int): number of clusters
20 | Returns:
21 | list: ids of data in each cluster
22 | """
23 | n_data, d = x.shape
24 |
25 | # faiss implementation of k-means
26 | clus = faiss.Clustering(d, nmb_clusters)
27 | clus.niter = 20
28 | clus.max_points_per_centroid = 10000000
29 |
30 | if init_cents is not None:
31 | clus.centroids.resize(init_cents.size)
32 | faiss.memcpy(clus.centroids.data(), faiss.swig_ptr(init_cents), init_cents.size * 4)  # 4 bytes per float32
33 |
34 | index = faiss.IndexFlatL2(d)
35 |
36 | # perform the training
37 | clus.train(x, index)
38 | _, I = index.search(x, 1)
39 | losses = faiss.vector_to_array(clus.obj)
40 | if verbose:
41 | print('k-means loss evolution: {0}'.format(losses))
42 |
43 | centroids = faiss.vector_to_array(clus.centroids).reshape((nmb_clusters, d))
44 |
45 | return [int(n[0]) for n in I], losses[-1], centroids
46 |
47 |
48 |
49 | class Kmeans:
50 | def __init__(self, k):
51 | self.k = k
52 |
53 | def cluster(self, data, verbose=False, init_cents=None):
54 | """Performs k-means clustering.
55 | Args:
56 | x_data (np.array N * dim): data to cluster
57 | """
58 | end = time.time()
59 |
60 | xb = data
61 | # cluster the data
62 | I, loss, cents = run_kmeans(xb, self.k, verbose, init_cents)
63 | self.images_lists = [[] for i in range(self.k)]
64 | for i in range(len(data)):
65 | self.images_lists[I[i]].append(i)
66 |
67 | if verbose:
68 | print('k-means time: {0:.0f} s'.format(time.time() - end))
69 |
70 | return cents
71 |
72 | def make_graph(xb, nnn):
73 | """Builds a graph of nearest neighbors.
74 | Args:
75 | xb (np.array): data
76 | nnn (int): number of nearest neighbors
77 | Returns:
78 | list: for each data the list of ids to its nnn nearest neighbors
79 | list: for each data the list of distances to its nnn NN
80 | """
81 | N, dim = xb.shape
82 |
83 | # L2
84 | index = faiss.IndexFlatL2(dim)
85 |
86 | index.add(xb)
87 | D, I = index.search(xb, nnn + 1)  # +1 because each point returns itself as its own nearest neighbor
88 | return I, D
89 |
--------------------------------------------------------------------------------
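Note: a minimal usage sketch for the Kmeans wrapper above, assuming faiss is installed; faiss expects a contiguous float32 array of shape (N, dim), and the sizes below are illustrative:

```python
import numpy as np
from utils.clustering import Kmeans  # import path per the repo layout

features = np.random.rand(1000, 64).astype('float32')  # e.g. per-pixel network embeddings

kmeans = Kmeans(k=2)
centroids = kmeans.cluster(features, verbose=True)  # returns the (k, dim) centroids

# recover a flat per-sample label array from images_lists
labels = np.zeros(len(features), dtype='int32')
for clust_id, idxs in enumerate(kmeans.images_lists):
    labels[idxs] = clust_id

# init_cents lets a later run warm-start from previously found centroids
centroids2 = Kmeans(k=2).cluster(features, init_cents=centroids)
```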