├── LICENSE
├── README.md
├── README.png
├── crf_loss.py
├── data_loader_gan.py
├── models.py
├── regularize.py
├── train_gan_net.py
├── training_utils.py
└── variables.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
SOFTWARE LICENSE AGREEMENT

ICG Software - 2023, all rights reserved, hereinafter "the Software".

This software has been developed by researchers of ICG (Institute of Computer Graphics and Vision).

Institute of Computer Graphics and Vision (ICG), Inffeldgasse 16/II, 8010 Graz, Austria

ICG holds all the ownership rights on the Software.

The Software is still being currently developed. It is the ICG's aim for the Software to be used by the scientific community so as to test it and, evaluate it so that ICG may improve it.

For these reasons ICG has decided to distribute the Software.

The academic user explicitly acknowledges having received from ICG all information allowing him to appreciate the adequacy between of the Software and his needs and to undertake all necessary precautions for his execution and use.

The Software is provided only as a source.

In case of using the Software for a publication or other results obtained through the use of the Software, user should cite the Software as follows:

@inproceedings{zorzi2021machine,
  title={Machine-learned regularization and polygonization of building segmentation masks},
  author={Zorzi, Stefano and Bittner, Ksenia and Fraundorfer, Friedrich},
  booktitle={2020 25th International Conference on Pattern Recognition (ICPR)},
  pages={3098--3105},
  year={2021},
  organization={IEEE}
}

Every user of the Software could communicate to the developers [stefano.zorzi@icg.tugraz.at] his or her remarks as to the use of the Software.

EVERY USER CAN USE, EXPLOIT OR COMMERCIALLY DISTRIBUTE THE SOFTWARE AFTER INFORMATION TO ICG (fraundorfer@icg.tugraz.at). IN ANY CASE OF USE, THE SOFTWARE HAS TO BE CITED AS STATED ABOVE.

THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALIZATION OR ADAPTATION. NO BACKGROUND OF ICG IS TRANSFERRED OR LICENCED UNDER THIS AGREEMENT.

UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL ICG OR THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Regularization of Building Boundaries in Satellite and Aerial Images

This repository contains the implementation of our ICPR 2021 publication "Machine-learned regularization and polygonization of building segmentation masks".
If you use this implementation, please cite the following publications:

~~~
@inproceedings{zorzi2021machine,
  title={Machine-learned regularization and polygonization of building segmentation masks},
  author={Zorzi, Stefano and Bittner, Ksenia and Fraundorfer, Friedrich},
  booktitle={2020 25th International Conference on Pattern Recognition (ICPR)},
  pages={3098--3105},
  year={2021},
  organization={IEEE}
}
~~~

and

~~~
@inproceedings{zorzi2019regularization,
  title={Regularization of building boundaries in satellite images using adversarial and regularized losses},
  author={Zorzi, Stefano and Fraundorfer, Friedrich},
  booktitle={IGARSS 2019-2019 IEEE International Geoscience and Remote Sensing Symposium},
  pages={5140--5143},
  year={2019},
  organization={IEEE}
}
~~~

![](README.png)

Explanatory video of the approach:

[![Watch the video](https://img.youtube.com/vi/07YQOlwIOMs/0.jpg)](https://www.youtube.com/watch?v=07YQOlwIOMs)

# Dependencies

* cuda 10.2
* pytorch >= 1.3
* opencv
* gdal

# Running the implementation

After installing the dependencies listed above, download the pretrained weights from [here](https://drive.google.com/drive/folders/1IPrDpvFq9ODW7UtPAJR_T-gGzxDat_uu?usp=sharing).

Unzip the archive and place the *saved_models_gan* folder in the main *projectRegularization* directory.

Please note that the polygonization step is not yet available!

## Evaluation

Modify *variables.py* accordingly, then run the prediction by issuing the command

~~~
python regularize.py
~~~

## Training

Modify *variables.py* accordingly, then run the training by issuing the command

~~~
python train_gan_net.py
~~~
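
For reference, these are the inference-related entries in *variables.py* (the values below are the defaults shipped with the repository; adapt the paths to your own data layout):

~~~python
# Inference configuration in variables.py (default values, adjust as needed)
INF_RGB = "./test_data/rgb/*.tif"    # input RGB tiles
INF_SEG = "./test_data/seg/*.tif"    # initial building segmentation masks
INF_OUT = "./test_data/reg_output/"  # regularized masks are written here

MODEL_ENCODER = "./saved_models_gan/E140000_e1"     # pretrained encoder weights
MODEL_GENERATOR = "./saved_models_gan/E140000_net"  # pretrained generator weights
~~~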
--------------------------------------------------------------------------------
/README.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zorzi-s/projectRegularization/c03b94dbcf66549518117c635cf61d843ee662ef/README.png
--------------------------------------------------------------------------------
/crf_loss.py:
--------------------------------------------------------------------------------
import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
from math import exp
import random

#from torchvision.utils import save_image
#from torchvision import datasets

from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

kernel_size = 9  # gaussian kernel dimension
dilation = 1     # cheating :) the kernel stores kernel_size x kernel_size taps, but its effective window is dilation*(kernel_size-1) + 1 pixels
padding = (kernel_size // 2) * dilation  # do not touch this
bs = 4     # batch size
win = 256  # window size

sigma_X = 3.0  # for distance gaussian
sigma_I = 0.1  # for RGB/grayscale gaussian

sample_interval = 20  # sample image every
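
# Descriptive note (added for clarity, not part of the original code): the
# pairwise weight assembled in kernel_loss.forward() is a bilateral,
# dense-CRF-style kernel,
#   W(p, q) = exp(-||p - q||^2 / sigma_X^2) * exp(-||I(p) - I(q)||^2 / sigma_I^2),
# where dist_kernel() precomputes the spatial term for every offset in the
# window, color_tensor() computes the appearance term per image, and the Potts
# loss penalizes neighboring pixels that look alike (large W) yet receive
# different labels.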

class kernel_loss(torch.nn.Module):

    def sub_kernel(self):
        filters = kernel_size * kernel_size
        middle = kernel_size // 2
        kernel = Variable(torch.zeros((filters, 1, kernel_size, kernel_size))).cuda()
        for i in range(kernel_size):
            for j in range(kernel_size):
                kernel[i*kernel_size+j, 0, i, j] = -1
                kernel[i*kernel_size+j, 0, middle, middle] = kernel[i*kernel_size+j, 0, middle, middle] + 1
        return kernel

    def dist_kernel(self):
        filters = kernel_size * kernel_size
        middle = kernel_size // 2
        kernel = Variable(torch.zeros((bs, filters, 1, 1))).cuda()

        for i in range(kernel_size):
            for j in range(kernel_size):
                ii = i - middle
                jj = j - middle
                distance = pow(ii,2) + pow(jj,2)
                kernel[:, i*kernel_size+j, 0, 0] = exp(-distance / pow(sigma_X,2))
        #print(kernel.view(4,1,kernel_size,kernel_size))
        return kernel

    def central_kernel(self):
        filters = kernel_size * kernel_size
        middle = kernel_size // 2
        kernel = Variable(torch.zeros((filters, 1, kernel_size, kernel_size))).cuda()
        for i in range(kernel_size):
            for j in range(kernel_size):
                kernel[i*kernel_size+j, 0, middle, middle] = 1
        return kernel

    def select_kernel(self):
        filters = kernel_size * kernel_size
        middle = kernel_size // 2
        kernel = Variable(torch.zeros((filters, 1, kernel_size, kernel_size))).cuda()
        for i in range(kernel_size):
            for j in range(kernel_size):
                kernel[i*kernel_size+j, 0, i, j] = 1
        return kernel

    def color_tensor(self, x):
        result = Variable(torch.zeros((bs, kernel_size*kernel_size, win-2*padding, win-2*padding))).cuda()

        for i in range(x.shape[1]):
            channel = x[:,i,:,:].unsqueeze(1)
            sub = nn.Conv2d(in_channels=1, out_channels=kernel_size*kernel_size, kernel_size=kernel_size, bias=False, padding=0, dilation=dilation)
            sub.weight.data = self.sub_matrix
            color = sub(channel)
            color = torch.pow(color,2)
            result = result + color

        result = torch.exp(-result / pow(sigma_I,2))
        return result

    def probability_tensor(self, y):
        conv = nn.Conv2d(in_channels=1, out_channels=kernel_size*kernel_size, kernel_size=kernel_size, bias=False, padding=0, dilation=dilation)
        conv.weight.data = self.select_matrix
        prob = conv(y)
        return prob

    #def probability_central(self, y):
    #    conv = nn.Conv2d(in_channels=1, out_channels=kernel_size*kernel_size, kernel_size=kernel_size, bias=False, padding=padding)
    #    conv.weight.data = self.one_matrix
    #    prob = conv(y)
    #    return prob

    def __init__(self):
        super(kernel_loss,self).__init__()
        #self.softmax = nn.Softmax(dim=1)
        self.dist_tensor = self.dist_kernel()
        #self.one_matrix = self.central_kernel()
        self.select_matrix = self.select_kernel()
        self.sub_matrix = self.sub_kernel()  #shape: [filters, 1, h, w]


    def forward(self, x, y):
        """
        x --> Image. It can also have just 1 channel (grayscale). Values between 0 and 1
        y --> Mask. Values between 0 and 1
        """
        #y = self.softmax(y)
        y0 = y[:,0,:,:].unsqueeze(1)  #build: 0, background: 1, default 1
        y1 = y[:,1,:,:].unsqueeze(1)  #build: 1, background: 0, default 0
        y0p = y0[:,:,padding:-padding,padding:-padding]
        y1p = y1[:,:,padding:-padding,padding:-padding]

        W = self.color_tensor(x)
        W = (W * self.dist_tensor.expand_as(W))

        potts_loss_0 = y0p.expand_as(W) * W * self.probability_tensor(y1)
        potts_loss_1 = y1p.expand_as(W) * W * self.probability_tensor(y0)

        numel = potts_loss_0.numel()
        #ncut_loss_0 = (potts_loss_0 / (self.probability_tensor(y0) * W)).mean()
        #ncut_loss_1 = (potts_loss_1 / (self.probability_tensor(y1) * W)).mean()

        """
        if random.randint(0,sample_interval) == 0:
            r = random.randint(0,20)

            img = torch.mean(W, dim=1).unsqueeze(1)
            #amin = torch.min(img)
            #amax = torch.max(img)
            #img = (img - amin) / (amax - amin)
            save_image(img, "./debug/%d_img.png" % r, nrow=2)

            #img2 = torch.mean(potts_loss_0, dim=1).unsqueeze(1)
            #amin = torch.min(img2)
            #amax = torch.max(img2)
            #img2 = (img2 - amin) / (amax - amin)
            #save_image(img2, "./debug/%d_b.png" % r, nrow=2)

            img3 = torch.mean(potts_loss_0, dim=1).unsqueeze(1)
            #amin = torch.min(img3)
            #amax = torch.max(img3)
            #img3 = (img3 - amin) / (amax - amin)
            save_image(img3, "./debug/%d_loss.png" % r, nrow=2)

            #img4 = torch.mean(loss_matrix, dim=1).unsqueeze(1)
            ##amin = torch.min(img4)
            ##amax = torch.max(img4)
            ##img4 = (img4 - amin) / (amax - amin)
            #save_image(img4, "./debug/%d_d.png" % r, nrow=2)
            save_image(x, "./debug/%d_map.png" % r, nrow=2)
        """

        potts_loss_0 = (potts_loss_0).mean()
        potts_loss_1 = (potts_loss_1).mean()
        potts_loss = potts_loss_0 + potts_loss_1

        return potts_loss

        """
        #ncut_loss_0 = potts_loss_0 / (self.probability_tensor(y0) * W).mean()
        #ncut_loss_1 = potts_loss_1 / (self.probability_tensor(y1) * W).mean()
        ncut_loss_0 = potts_loss_0 / (y0p.expand_as(W) * W).mean()
        ncut_loss_1 = potts_loss_1 / (y1p.expand_as(W) * W).mean()

        #ncut_loss_0 = ncut_loss_0.mean()
        #ncut_loss_1 = ncut_loss_1.mean()
        ncut_loss = ncut_loss_0 + ncut_loss_1

        #potts_loss = potts_loss_0 + potts_loss_1
        #ncut_loss = ncut_loss_0 + ncut_loss_1

        return (potts_loss, ncut_loss, numel)
        """
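
if __name__ == "__main__":
    # Minimal smoke test, a sketch added for illustration (not part of the
    # original training pipeline). It assumes a CUDA device is available,
    # since the kernels above are allocated with .cuda(). Random inputs use
    # the module-level batch size and window size.
    x = torch.rand(bs, 3, win, win).cuda()                        # image in [0, 1]
    y = torch.softmax(torch.rand(bs, 2, win, win).cuda(), dim=1)  # soft 2-class mask
    criterion = kernel_loss().cuda()
    print("potts loss:", criterion(x, y).item())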
--------------------------------------------------------------------------------
/data_loader_gan.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
import random
from skimage import io
from skimage.segmentation import mark_boundaries
from skimage.transform import rotate
import variables as var

TEST = False

def to_categorical(y, num_classes=None, dtype='float32'):

    y = np.array(y, dtype='int')
    input_shape = y.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    y = y.ravel()
    if not num_classes:
        num_classes = np.max(y) + 1
    n = y.shape[0]
    categorical = np.zeros((n, num_classes), dtype=dtype)
    categorical[np.arange(n), y] = 1
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)
    return categorical

class DataLoader():

    def __init__(self, ws=512, nb=10000, bs=8):
        self.nb = nb
        self.bs = bs
        self.ws = ws

        #self.rgb_files = self.rgb_files[:10]
        #self.dsm_files = self.dsm_files[:10]
        #self.gti_files = self.gti_files[:10]

        self.load_data()
        self.num_tiles = len(self.rgb_imgs)
        self.sliding_index = 0

    def generator(self):
        for _ in range(self.nb):
            batch_rgb = []
            batch_gti = []
            batch_seg = []
            for _ in range(self.bs):
                rgb, gti, seg = self.extract_image()

                batch_rgb.append(rgb)

                # the ground truth is categorized
                gti = to_categorical(gti != 0, 2)
                batch_gti.append(gti)

                # the segmentation is categorized
                seg = to_categorical(seg != 0, 2)
                batch_seg.append(seg)

            batch_rgb = np.asarray(batch_rgb)
            batch_gti = np.asarray(batch_gti)
            batch_seg = np.asarray(batch_seg)
            batch_rgb = batch_rgb / 255.0

            #batch_gti = batch_gti[:,:,:,np.newaxis] / 255.0

            yield (batch_rgb, batch_gti, batch_seg)


    def test_shape(self, a):
        # crop the array so both dimensions become multiples of the window size
        # (a dimension that is already a multiple is left untouched)
        ri = a.shape[0] % self.ws
        rj = a.shape[1] % self.ws
        return a[:a.shape[0]-ri, :a.shape[1]-rj]


    def random_hsv(self, img, value_h=30, value_s=30, value_v=30):
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(hsv)

        h = np.int16(h)
        s = np.int16(s)
        v = np.int16(v)

        h += value_h
        h[h < 0] = 0
        h[h > 255] = 255

        s += value_s
        s[s < 0] = 0
        s[s > 255] = 255

        v += value_v
        v[v < 0] = 0
        v[v > 255] = 255

        h = np.uint8(h)
        s = np.uint8(s)
        v = np.uint8(v)

        final_hsv = cv2.merge((h, s, v))
        img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
        return img


    def extract_image(self, mode="sequential"):
        if mode == "random":
            rand_t = random.randint(0, self.num_tiles-1)
        else:
            if self.sliding_index < self.num_tiles:
                rand_t = self.sliding_index
                self.sliding_index = self.sliding_index + 1
            else:
                rand_t = 0
                self.sliding_index = 0

        rgb = self.rgb_imgs[rand_t].copy()
        gti = self.gti_imgs[rand_t].copy()
        seg = self.seg_imgs[rand_t].copy()

        h = rgb.shape[0]  # rows
        w = rgb.shape[1]  # columns

        void = True
        while void:
            rot = random.randint(0,90)
            ri = random.randint(0, int(h-self.ws*2))
            rj = random.randint(0, int(w-self.ws*2))
            win_rgb = rgb[ri:ri+int(self.ws*2), rj:rj+int(self.ws*2)]
            win_gti = gti[ri:ri+int(self.ws*2), rj:rj+int(self.ws*2)]
            win_seg = seg[ri:ri+int(self.ws*2), rj:rj+int(self.ws*2)]

            win_rgb = np.uint8(rotate(win_rgb, rot, resize=False, preserve_range=True))
            win_gti = np.uint8(rotate(win_gti, rot, resize=False, preserve_range=True))
            win_seg = np.uint8(rotate(win_seg, rot, resize=False, preserve_range=True))

            win_rgb = win_rgb[self.ws//2:-self.ws//2, self.ws//2:-self.ws//2]
            win_gti = win_gti[self.ws//2:-self.ws//2, self.ws//2:-self.ws//2]
            win_seg = win_seg[self.ws//2:-self.ws//2, self.ws//2:-self.ws//2]

            if np.count_nonzero(win_seg):
                void = False

        # Perform some data augmentation
        rot = random.randint(0,3)
        win_rgb = np.rot90(win_rgb, k=rot)
        win_gti = np.rot90(win_gti, k=rot)
        win_seg = np.rot90(win_seg, k=rot)
        if random.randint(0,1) == 1:
            win_rgb = np.fliplr(win_rgb)
            win_gti = np.fliplr(win_gti)
            win_seg = np.fliplr(win_seg)

        r_h = random.randint(-20,20)
        r_s = random.randint(-20,20)
        r_v = random.randint(-20,20)
        win_rgb = self.random_hsv(win_rgb, r_h, r_s, r_v)

        win_rgb = win_rgb.astype(np.float32)
        win_gti = win_gti.astype(np.float32)
        win_seg = win_seg.astype(np.float32)
        return (win_rgb, win_gti, win_seg)


    def load_data(self):
        self.rgb_imgs = []
        self.gti_imgs = []
        self.seg_imgs = []

        rgb_files = glob(var.DATASET_RGB)
        gti_files = glob(var.DATASET_GTI)
        seg_files = glob(var.DATASET_SEG)

        rgb_files.sort()
        gti_files.sort()
        seg_files.sort()

        combined = list(zip(rgb_files, gti_files, seg_files))
        random.shuffle(combined)

        rgb_files[:], gti_files[:], seg_files[:] = zip(*combined)

        if TEST:
            rgb_files = rgb_files[:4]
            gti_files = gti_files[:4]
            seg_files = seg_files[:4]

        for rgb_name, gti_name, seg_name in tqdm(zip(rgb_files, gti_files, seg_files), total=len(rgb_files), desc="Loading dataset into RAM"):

            tmp = io.imread(rgb_name)
            tmp = tmp.astype(np.uint8)
            self.rgb_imgs.append(tmp)

            tmp = io.imread(gti_name)
            tmp = tmp.astype(np.uint8)
            self.gti_imgs.append(tmp)

            tmp = io.imread(seg_name)
            tmp = tmp.astype(np.uint8)
            self.seg_imgs.append(tmp)
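
if __name__ == "__main__":
    # Quick self-check of to_categorical, a sketch added for illustration.
    # (DataLoader itself needs the tiles referenced by variables.DATASET_* on
    # disk, so only the pure helper is exercised here.) A binary mask becomes
    # a 2-channel one-hot volume: channel 0 marks background, channel 1 buildings.
    mask = np.array([[0, 255],
                     [7, 0]])
    onehot = to_categorical(mask != 0, 2)
    print(onehot.shape)    # (2, 2, 2)
    print(onehot[..., 1])  # [[0. 1.] [1. 0.]], i.e. 1.0 wherever the mask was non-zero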
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)



class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.block = nn.Sequential(
            #nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3, stride=1, padding=1),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            #nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3, stride=1, padding=1),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return x + self.block(x)



class GeneratorResNet(nn.Module):
    def __init__(self, num_residual_blocks=8, in_features=256):
        super(GeneratorResNet, self).__init__()

        out_features = in_features

        model = []

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer
        #model += [nn.ReflectionPad2d(2), nn.Conv2d(out_features, 2, 7), nn.Softmax()]
        model += [nn.Conv2d(out_features, 2, 7, stride=1, padding=3), nn.Sigmoid()]

        self.model = nn.Sequential(*model)

    def forward(self, feature_map):
        x = self.model(feature_map)
        return x


class Encoder(nn.Module):
    def __init__(self, channels=3+2):
        super(Encoder, self).__init__()

        # Initial convolution block
        out_features = 64
        model = [
            nn.Conv2d(channels, out_features, 7, stride=1, padding=3),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, stride=2),
            ]
            in_features = out_features

        self.model = nn.Sequential(*model)

    def forward(self, arguments):
        x = torch.cat(arguments, dim=1)
        x = self.model(x)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        channels = 2
        out_channels = 2

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 3, stride=1, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.ReLU())

            layers.append(nn.Conv2d(out_filters, out_filters, 3, stride=1, padding=1))
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(2, stride=2))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.Conv2d(512, out_channels, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, img):
        #img = torch.cat((rgb, mask), dim=1)
        img = self.model(img)
        return img
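
if __name__ == "__main__":
    # CPU shape check, a sketch added for illustration: the encoder halves the
    # resolution twice (two 2x2 max-pools) while going to 256 channels, the
    # generator upsamples back to the input size, and the discriminator maps a
    # 2-channel mask to a 16x down-scaled 2-channel score map.
    rgb = torch.rand(1, 3, 256, 256)
    mask = torch.rand(1, 2, 256, 256)
    features = Encoder(channels=3 + 2)([rgb, mask])
    out = GeneratorResNet()(features)
    score = Discriminator()(out)
    print(features.shape)  # torch.Size([1, 256, 64, 64])
    print(out.shape)       # torch.Size([1, 2, 256, 256])
    print(score.shape)     # torch.Size([1, 2, 16, 16])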
--------------------------------------------------------------------------------
/regularize.py:
--------------------------------------------------------------------------------
import random
from skimage import io
from skimage.transform import rotate
import numpy as np
import torch
from tqdm import tqdm
import gdal
import os
import sys
import glob
from skimage.segmentation import mark_boundaries
from PIL import Image, ImageDraw, ImageFont
from numpy.linalg import svd
import cv2
from skimage import measure

from models import GeneratorResNet, Encoder
from skimage.transform import rescale
import variables as var




def compute_IoU(mask, pred):
    mask = mask != 0
    pred = pred != 0

    m1 = np.logical_and(mask, pred)
    m2 = np.logical_and(np.logical_not(mask), np.logical_not(pred))
    m3 = np.logical_and(mask == 0, pred == 1)
    m4 = np.logical_and(mask == 1, pred == 0)
    m5 = np.logical_or(mask, pred)

    tp = np.count_nonzero(m1)
    fp = np.count_nonzero(m3)
    fn = np.count_nonzero(m4)

    IoU = tp / (tp + (fn + fp))
    return IoU


def to_categorical(y, num_classes=None, dtype='float32'):
    y = np.array(y, dtype='int')
    input_shape = y.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    y = y.ravel()
    if not num_classes:
        num_classes = np.max(y) + 1
    n = y.shape[0]
    categorical = np.zeros((n, num_classes), dtype=dtype)
    categorical[np.arange(n), y] = 1
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)
    return categorical
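
# Worked example for compute_IoU (descriptive note added for clarity): with
#   mask = [[1, 1], [0, 0]] and pred = [[1, 0], [0, 1]]
# the overlap gives tp = 1, fp = 1, fn = 1, hence IoU = 1 / (1 + 1 + 1) = 1/3.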

def predict_building(rgb, mask, model):
    Tensor = torch.cuda.FloatTensor

    mask = to_categorical(mask, 2)

    rgb = rgb[np.newaxis, :, :, :]
    mask = mask[np.newaxis, :, :, :]

    E, G = model

    rgb = Tensor(rgb)
    mask = Tensor(mask)
    rgb = rgb.permute(0,3,1,2)
    mask = mask.permute(0,3,1,2)

    rgb = rgb / 255.0

    # PREDICTION
    pred = G(E([rgb, mask]))
    pred = pred.permute(0,2,3,1)

    pred = pred.detach().cpu().numpy()

    pred = np.argmax(pred[0,:,:,:], axis=-1)
    return pred



def fix_limits(i_min, i_max, j_min, j_max, min_image_size=256):

    def closest_divisible_size(size, factor=4):
        while size % factor:
            size += 1
        return size

    height = i_max - i_min
    width = j_max - j_min

    # pad the rows
    if height < min_image_size:
        diff = min_image_size - height
    else:
        diff = closest_divisible_size(height) - height + 16

    i_min -= (diff // 2)
    i_max += (diff // 2 + diff % 2)

    # pad the columns
    if width < min_image_size:
        diff = min_image_size - width
    else:
        diff = closest_divisible_size(width) - width + 16

    j_min -= (diff // 2)
    j_max += (diff // 2 + diff % 2)

    return i_min, i_max, j_min, j_max
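
# fix_limits in numbers (descriptive note added for clarity): a 100-pixel-tall
# box (i_min=400, i_max=500) is padded symmetrically to the 256-pixel minimum
# crop (i_min=322, i_max=578), while a 300-pixel side grows to the next
# multiple of 4 plus a 16-pixel margin, i.e. 300 -> 316 (the margin presumably
# keeps some context around the building crop).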

def regularization(rgb, ins_segmentation, model, in_mode="instance", out_mode="instance", min_size=10):
    assert in_mode == "instance" or in_mode == "semantic"
    assert out_mode == "instance" or out_mode == "semantic"

    if in_mode == "semantic":
        ins_segmentation = np.uint16(measure.label(ins_segmentation, background=0))

    max_instance = np.amax(ins_segmentation)
    border = 256

    ins_segmentation = np.uint16(cv2.copyMakeBorder(ins_segmentation,border,border,border,border,cv2.BORDER_CONSTANT,value=0))
    rgb = np.uint8(cv2.copyMakeBorder(rgb,border,border,border,border,cv2.BORDER_CONSTANT,value=(0,0,0)))

    regularization = np.zeros(ins_segmentation.shape, dtype=np.uint16)

    for ins in tqdm(range(1, max_instance+1), desc="Regularization"):
        indices = np.argwhere(ins_segmentation==ins)
        building_size = indices.shape[0]
        if building_size > min_size:
            i_min = np.amin(indices[:,0])
            i_max = np.amax(indices[:,0])
            j_min = np.amin(indices[:,1])
            j_max = np.amax(indices[:,1])

            i_min, i_max, j_min, j_max = fix_limits(i_min, i_max, j_min, j_max)

            mask = np.copy(ins_segmentation[i_min:i_max, j_min:j_max] == ins)
            rgb_mask = np.copy(rgb[i_min:i_max, j_min:j_max, :])

            # very large buildings are rescaled to fit the network; keep the
            # original value ranges in both branches so the 0-255 RGB crop is
            # not silently normalized to [0, 1]
            max_building_size = 1024
            rescaled = False
            if mask.shape[0] > max_building_size and mask.shape[0] >= mask.shape[1]:
                f = max_building_size / mask.shape[0]
                mask = rescale(mask, f, anti_aliasing=False, preserve_range=True)
                rgb_mask = rescale(rgb_mask, f, anti_aliasing=False, preserve_range=True)
                rescaled = True
            elif mask.shape[1] > max_building_size and mask.shape[1] >= mask.shape[0]:
                f = max_building_size / mask.shape[1]
                mask = rescale(mask, f, anti_aliasing=False, preserve_range=True)
                rgb_mask = rescale(rgb_mask, f, anti_aliasing=False, preserve_range=True)
                rescaled = True

            pred = predict_building(rgb_mask, mask, model)

            if rescaled:
                pred = rescale(pred, 1/f, anti_aliasing=False, preserve_range=True)

            pred_indices = np.argwhere(pred != 0)

            if pred_indices.shape[0] > 0:
                pred_indices[:,0] = pred_indices[:,0] + i_min
                pred_indices[:,1] = pred_indices[:,1] + j_min
                x, y = zip(*pred_indices)
                if out_mode == "semantic":
                    regularization[x,y] = 1
                else:
                    regularization[x,y] = ins

    return regularization[border:-border, border:-border]



def copyGeoreference(inp, output):
    dataset = gdal.Open(inp)
    if dataset is None:
        print('Unable to open', inp, 'for reading')
        sys.exit(1)

    projection = dataset.GetProjection()
    geotransform = dataset.GetGeoTransform()

    if projection is None and geotransform is None:
        print('No projection or geotransform found on file ' + inp)
        sys.exit(1)

    dataset2 = gdal.Open(output, gdal.GA_Update)

    if dataset2 is None:
        print('Unable to open', output, 'for writing')
        sys.exit(1)

    if geotransform is not None and geotransform != (0, 1, 0, 0, 0, 1):
        dataset2.SetGeoTransform(geotransform)

    if projection is not None and projection != '':
        dataset2.SetProjection(projection)

    gcp_count = dataset.GetGCPCount()
    if gcp_count != 0:
        dataset2.SetGCPs(dataset.GetGCPs(), dataset.GetGCPProjection())

    dataset = None
    dataset2 = None



def regularize_segmentations(img_folder, seg_folder, out_folder, in_mode="semantic", out_mode="instance", samples=None):
    """
    BUILDING REGULARIZATION
    Inputs:
    - satellite image (3 channels)
    - building segmentation (1 channel)
    Output:
    - regularized mask
    """

    img_files = glob.glob(img_folder)
    seg_files = glob.glob(seg_folder)

    img_files.sort()
    seg_files.sort()

    # load the pretrained encoder and generator once, outside the tile loop
    E1 = Encoder()
    G = GeneratorResNet()
    G.load_state_dict(torch.load(var.MODEL_GENERATOR))
    E1.load_state_dict(torch.load(var.MODEL_ENCODER))
    E1 = E1.cuda()
    G = G.cuda()

    model = [E1, G]

    for num, (satellite_image_file, building_segmentation_file) in enumerate(zip(img_files, seg_files)):
        print(satellite_image_file, building_segmentation_file)
        _, rgb_name = os.path.split(satellite_image_file)
        _, seg_name = os.path.split(building_segmentation_file)
        assert rgb_name == seg_name

        output_file = out_folder + seg_name

        M = io.imread(building_segmentation_file)
        M = np.uint16(M)
        P = io.imread(satellite_image_file)
        P = np.uint8(P)

        R = regularization(P, M, model, in_mode=in_mode, out_mode=out_mode)

        if out_mode == "instance":
            io.imsave(output_file, np.uint16(R))
        else:
            io.imsave(output_file, np.uint8(R*255))

        if samples is not None:
            i = 1000
            j = 1000
            h, w = 1080, 1920
            P = P[i:i+h, j:j+w]
            R = R[i:i+h, j:j+w]
            M = M[i:i+h, j:j+w]

            R = mark_boundaries(P, R, mode="thick")
            M = mark_boundaries(P, M, mode="thick")

            R = np.uint8(R*255)
            M = np.uint8(M*255)

            font = cv2.FONT_HERSHEY_SIMPLEX
            bottomLeftCornerOfText = (20,1060)
            fontScale = 1
            fontColor = (255,255,0)
            lineType = 2

            cv2.putText(R, "INRIA dataset, " + rgb_name + ", regularization",
                bottomLeftCornerOfText,
                font,
                fontScale,
                fontColor,
                lineType)

            cv2.putText(M, "INRIA dataset, " + rgb_name + ", segmentation",
                bottomLeftCornerOfText,
                font,
                fontScale,
                fontColor,
                lineType)

            io.imsave(samples + "./%d_2reg.png" % num, np.uint8(R))
            io.imsave(samples + "./%d_1seg.png" % num, np.uint8(M))

        copyGeoreference(satellite_image_file, output_file)
        copyGeoreference(satellite_image_file, building_segmentation_file)



regularize_segmentations(img_folder=var.INF_RGB, seg_folder=var.INF_SEG, out_folder=var.INF_OUT, in_mode="semantic", out_mode="instance", samples=None)
--------------------------------------------------------------------------------
/train_gan_net.py:
--------------------------------------------------------------------------------
import os
import glob

import torch
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.autograd import Variable

from tqdm import tqdm
import click
import numpy as np
import cv2
from skimage.segmentation import mark_boundaries
from skimage import io
import itertools

from models import GeneratorResNet, Encoder, Discriminator
from data_loader_gan import DataLoader
from training_utils import sample_images, LossBuffer, LambdaLR
import variables as var
from crf_loss import kernel_loss



def crf_factor(batch_index, start_crf_batch, end_crf_batch, crf_initial_factor, crf_final_factor):
    if batch_index <= start_crf_batch:
        return 0.0
    elif start_crf_batch < batch_index < end_crf_batch:
        return crf_initial_factor + ((crf_final_factor - crf_initial_factor) * (batch_index - start_crf_batch) / (end_crf_batch - start_crf_batch))
    else:
        return crf_final_factor
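
# crf_factor in numbers (descriptive note added for clarity, using the
# defaults of train() below): the CRF term is off until batch 60000, then
# ramps linearly, e.g. at batch 90000 it is 0.0 + 175.0 * (30000 / 60000) = 87.5,
# and it stays at 175.0 from batch 120000 on.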

def train(
        models_path='./saved_models_gan/',
        restore=False,
        batch_size=4,
        start_batch=0, n_batches=140000,
        start_crf_batch=60000, end_crf_batch=120000, crf_initial_factor=0.0, crf_final_factor=175.0,
        start_lr_decay=120000,
        start_lr=0.00004, win_size=256, sample_interval=20, backup_interval=5000):

    patch_size = int(win_size / pow(2, 4))

    Tensor = torch.cuda.FloatTensor

    e1 = Encoder(channels=3+2)
    e2 = Encoder(channels=2)
    net = GeneratorResNet()
    disc = Discriminator()

    if restore:
        print("Restoring model number %d" % start_batch)
        e1.load_state_dict(torch.load(models_path + "E%d_e1" % start_batch))
        e2.load_state_dict(torch.load(models_path + "E%d_e2" % start_batch))
        net.load_state_dict(torch.load(models_path + "E%d_net" % start_batch))
        disc.load_state_dict(torch.load(models_path + "E%d_disc" % start_batch))

    e1 = e1.cuda()
    e2 = e2.cuda()
    net = net.cuda()
    disc = disc.cuda()

    os.makedirs(models_path, exist_ok=True)

    loss_0_buffer = LossBuffer()
    loss_1_buffer = LossBuffer()
    loss_2_buffer = LossBuffer()
    loss_3_buffer = LossBuffer()
    loss_4_buffer = LossBuffer()
    loss_5_buffer = LossBuffer()

    gen_obj = DataLoader(bs=batch_size, nb=n_batches, ws=win_size)

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(net.parameters(), e1.parameters(), e2.parameters()), lr=start_lr)
    optimizer_D = torch.optim.Adam(disc.parameters(), lr=start_lr)

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_batches, start_lr_decay).step)
    lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(n_batches, start_lr_decay).step)

    bce_criterion = nn.BCELoss()
    bce_criterion = bce_criterion.cuda()

    densecrflosslayer = kernel_loss()
    densecrflosslayer = densecrflosslayer.cuda()

    loader = gen_obj.generator()
    train_iterator = tqdm(loader, total=(n_batches + 1 - start_batch))
    img_index = 0

    for batch_index, (rgb, gti, seg) in enumerate(train_iterator):

        batch_index = batch_index + start_batch

        rgb = Variable(Tensor(rgb))
        gti = Variable(Tensor(gti))
        seg = Variable(Tensor(seg))

        rgb = rgb.permute(0,3,1,2)
        gti = gti.permute(0,3,1,2)
        seg = seg.permute(0,3,1,2)

        # Adversarial ground truths
        ones = Variable(Tensor(np.ones((batch_size, 1, patch_size, patch_size))), requires_grad=False)
        zeros = Variable(Tensor(np.zeros((batch_size, 1, patch_size, patch_size))), requires_grad=False)
        valid = torch.cat((ones, zeros), dim=1)
        fake = torch.cat((zeros, ones), dim=1)

        # ------------------
        #  Train Generators
        # ------------------

        #e1.train()
        #e2.train()
        #net.train()

        optimizer_G.zero_grad()

        reg = net(e1([rgb, seg]))
        rec = net(e2([gti]))

        # Identity loss (reconstruction loss)
        loss_rec_1 = bce_criterion(reg, seg)
        loss_rec_2 = bce_criterion(rec, gti)

        # GAN loss
        loss_GAN = bce_criterion(disc(reg), valid)

        # CRF loss
        pot_multiplier = crf_factor(batch_index, start_crf_batch, end_crf_batch, crf_initial_factor, crf_final_factor)
        loss_pot = densecrflosslayer(rgb, reg)
        loss_pot = loss_pot.cuda()
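
        # Note on the weighting below: the adversarial term and the clean
        # ground-truth reconstruction are weighted 3x, the noisy-mask
        # reconstruction 1x, and the CRF/Potts term is scaled by
        # pot_multiplier, which crf_factor() ramps from crf_initial_factor to
        # crf_final_factor between start_crf_batch and end_crf_batch.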
        # Total loss
        loss_G = 3 * loss_GAN + 1 * loss_rec_1 + 3 * loss_rec_2 + pot_multiplier * loss_pot

        loss_G.backward()
        optimizer_G.step()


        # -----------------------
        #  Train Discriminator A
        # -----------------------

        #disc.train()

        optimizer_D.zero_grad()

        loss_real = bce_criterion(disc(rec.detach()), valid)
        loss_fake = bce_criterion(disc(reg.detach()), fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Update LR
        # --------------

        lr_scheduler_G.step(batch_index)
        lr_scheduler_D.step(batch_index)

        for g in optimizer_D.param_groups:
            current_lr = g['lr']

        # --------------
        #  Log Progress
        # --------------

        status = "[Batch %d][D loss: %f][G loss: %f, adv: %f, rec1: %f, rec2: %f][pot: %f, pot_mul: %f][lr: %f]" % \
            (batch_index,
             loss_0_buffer.push(loss_D.item()),
             loss_1_buffer.push(loss_G.item()), loss_2_buffer.push(loss_GAN.item()),
             loss_3_buffer.push(loss_rec_1.item()), loss_4_buffer.push(loss_rec_2.item()),
             loss_5_buffer.push(loss_pot.item()), pot_multiplier, current_lr)

        train_iterator.set_description(status)

        if (batch_index % sample_interval == 0):
            img_index += 1
            void_mask = torch.zeros(gti.shape).cuda()
            sample_images(img_index, rgb, [void_mask, gti, rec, seg, reg])
            if img_index >= 100:
                img_index = 0

        if (batch_index % backup_interval == 0):
            torch.save(e1.state_dict(), models_path + "E" + str(batch_index) + "_e1")
            torch.save(e2.state_dict(), models_path + "E" + str(batch_index) + "_e2")
            torch.save(net.state_dict(), models_path + "E" + str(batch_index) + "_net")
            torch.save(disc.state_dict(), models_path + "E" + str(batch_index) + "_disc")


if __name__ == '__main__':
    train()
--------------------------------------------------------------------------------
/training_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import glob
from tqdm import tqdm
import random
import time
import datetime
import sys

from skimage import io
from skimage.segmentation import mark_boundaries

from torch.autograd import Variable
import torch

import gdal

import variables as var


def sample_images(sample_index, img, masks):
    batch = img.shape[0]

    img = img.permute(0,2,3,1)

    for i in range(len(masks)):
        masks[i] = masks[i].permute(0,2,3,1)

    img = img.cpu().numpy()
    ip = np.uint8(img * 255)
    for i in range(len(masks)):
        masks[i] = masks[i].detach().cpu().numpy()
        masks[i] = np.argmax(masks[i], axis=-1)
        masks[i] = np.uint8(masks[i] * 255)

    line_mode = "inner"

    for i in range(len(masks)):
        row = np.copy(ip[0,:,:,:])
        line = cv2.Canny(masks[i][0,:,:], 0, 255)
        row = mark_boundaries(row, line, color=(1,1,0), mode=line_mode) * 255  #, outline_color=(self.red,self.greed,0))
        for b in range(1, batch):
            pic = np.copy(ip[b,:,:,:])
            line = cv2.Canny(masks[i][b,:,:], 0, 255)
            pic = mark_boundaries(pic, line, color=(1,1,0), mode=line_mode) * 255  #, outline_color=(self.red,self.greed,0))
            row = np.concatenate((row, pic), 1)
        masks[i] = row

    img = np.concatenate(masks, 0)
    img = np.uint8(img)
    io.imsave(var.DEBUG_DIR + "debug_%s.png" % str(sample_index), img)


class LossBuffer():
    def __init__(self, max_size=100):
        self.data = []
        self.max_size = max_size

    def push(self, data):
        self.data.append(data)
        if len(self.data) > self.max_size:
            self.data = self.data[1:]
        return sum(self.data) / len(self.data)
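
# LossBuffer in numbers (descriptive note added for clarity): it keeps the
# last max_size values and returns their running mean on every push, which is
# what smooths the figures in the tqdm status line of train_gan_net.py. For
# example, with max_size=3:
#   push(1.0) -> 1.0, push(2.0) -> 1.5, push(3.0) -> 2.0,
#   push(4.0) -> 3.0 (the oldest value has been dropped)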

class LambdaLR():
    def __init__(self, n_batches, decay_start_batch):
        assert ((n_batches - decay_start_batch) > 0), "Decay must start before the training session ends!"
        self.n_batches = n_batches
        self.decay_start_batch = decay_start_batch

    def step(self, batch):
        if batch > self.decay_start_batch:
            factor = 1.0 - (batch - self.decay_start_batch) / (self.n_batches - self.decay_start_batch)
            if factor > 0:
                return factor
            else:
                return 0.0
        else:
            return 1.0
--------------------------------------------------------------------------------
/variables.py:
--------------------------------------------------------------------------------
# TRAINING
DATASET_RGB = "./data/rgb/*.tif"
DATASET_GTI = "./data/gti/*.tif"
DATASET_SEG = "./data/seg/*.tif"

DEBUG_DIR = "./debug/"


# INFERENCE
INF_RGB = "./test_data/rgb/*.tif"
INF_SEG = "./test_data/seg/*.tif"
INF_OUT = "./test_data/reg_output/"

MODEL_ENCODER = "./saved_models_gan/E140000_e1"
MODEL_GENERATOR = "./saved_models_gan/E140000_net"
--------------------------------------------------------------------------------