├── LICENSE
├── README.md
├── README.png
├── crf_loss.py
├── data_loader_gan.py
├── models.py
├── regularize.py
├── train_gan_net.py
├── training_utils.py
└── variables.py
/LICENSE:
--------------------------------------------------------------------------------
1 | SOFTWARE LICENSE AGREEMENT
2 |
3 | ICG Software - 2023, all rights reserved, hereinafter "the Software".
4 |
5 | This software has been developed by researchers of ICG (Institute of Computer Graphics and Vision).
6 |
7 | Institute of Computer Graphics and Vision (ICG), Inffeldgasse 16/II, 8010 Graz, Austria
8 |
9 | ICG holds all the ownership rights on the Software.
10 |
11 | The Software is still under development. It is ICG's aim for the Software to be used by the scientific community to test and evaluate it, so that ICG may improve it.
12 |
13 | For these reasons ICG has decided to distribute the Software.
14 |
15 | The academic user explicitly acknowledges having received from ICG all information allowing him to appreciate the adequacy of the Software for his needs and to undertake all necessary precautions for its execution and use.
16 |
17 | The Software is provided only as a source.
18 |
19 | In case of using the Software for a publication or other results obtained through the use of the Software, user should cite the Software as follows:
20 |
21 | @inproceedings{zorzi2021machine,
22 | title={Machine-learned regularization and polygonization of building segmentation masks},
23 | author={Zorzi, Stefano and Bittner, Ksenia and Fraundorfer, Friedrich},
24 | booktitle={2020 25th International Conference on Pattern Recognition (ICPR)},
25 | pages={3098--3105},
26 | year={2021},
27 | organization={IEEE}
28 | }
29 |
30 | Every user of the Software could communicate to the developers [stefano.zorzi@icg.tugraz.at] his or her remarks as to the use of the Software.
31 |
32 | EVERY USER CAN USE, EXPLOIT OR COMMERCIALLY DISTRIBUTE THE SOFTWARE AFTER INFORMATION TO ICG (fraundorfer@icg.tugraz.at). IN ANY CASE OF USE, THE SOFTWARE HAS TO BE CITED AS STATED ABOVE.
33 |
34 | THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALIZATION OR ADAPTATION. NO BACKGROUND OF ICG IS TRANSFERRED OR LICENCED UNDER THIS AGREEMENT.
35 |
36 | UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL ICG OR THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Regularization of Building Boundaries in Satellite and Aerial Images
2 | This repository contains the implementation for our publication "Machine-learned regularization and polygonization of building segmentation masks", ICPR 2021.
3 | If you use this implementation, please cite the following publications:
4 |
5 | ~~~
6 | @inproceedings{zorzi2021machine,
7 | title={Machine-learned regularization and polygonization of building segmentation masks},
8 | author={Zorzi, Stefano and Bittner, Ksenia and Fraundorfer, Friedrich},
9 | booktitle={2020 25th International Conference on Pattern Recognition (ICPR)},
10 | pages={3098--3105},
11 | year={2021},
12 | organization={IEEE}
13 | }
14 | ~~~
15 | and
16 | ~~~
17 | @inproceedings{zorzi2019regularization,
18 | title={Regularization of building boundaries in satellite images using adversarial and regularized losses},
19 | author={Zorzi, Stefano and Fraundorfer, Friedrich},
20 | booktitle={IGARSS 2019-2019 IEEE International Geoscience and Remote Sensing Symposium},
21 | pages={5140--5143},
22 | year={2019},
23 | organization={IEEE}
24 | }
25 | ~~~
26 |
27 |
28 |
29 | Explanatory video of the approach:
30 |
31 | [Watch the video on YouTube](https://www.youtube.com/watch?v=07YQOlwIOMs)
32 |
33 | # Dependencies
34 |
35 | * cuda 10.2
36 | * pytorch >= 1.3
37 | * opencv
38 | * gdal (Python bindings)
39 | * numpy, scikit-image, tqdm, click, Pillow
40 | # Running the implementation
41 | After installing the dependencies above, download the pretrained weights from [here](https://drive.google.com/drive/folders/1IPrDpvFq9ODW7UtPAJR_T-gGzxDat_uu?usp=sharing).
42 |
43 | Unzip the archive and place the *saved_models_gan* folder in the main *projectRegularization* directory.
44 |
45 | Please note that the polygonization step is not yet available!
46 |
47 | ## Evaluation
48 | Set the input/output paths (`INF_RGB`, `INF_SEG`, `INF_OUT`) and the pretrained weights (`MODEL_ENCODER`, `MODEL_GENERATOR`) in *variables.py*, then run the prediction with:
49 |
50 | ~~~
51 | python regularize.py
52 | ~~~
53 |
54 | ## Training
55 | Set the training data paths (`DATASET_RGB`, `DATASET_GTI`, `DATASET_SEG`) in *variables.py*, then start the training with:
56 |
57 | ~~~
58 | python train_gan_net.py
59 | ~~~
60 |
--------------------------------------------------------------------------------
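The README asks you to edit *variables.py*, but that file's contents are not part of this listing. The sketch below is a hypothetical example. It collects the module-level names that the scripts read: `DATASET_*` in data_loader_gan.py, and `INF_*` and `MODEL_*` in regularize.py. Every path is a placeholder, not a value from the authors.

The `*_RGB`/`*_GTI`/`*_SEG` entries are passed to `glob`, so they must be patterns. The matched lists are sorted and paired by position, and regularize.py asserts that each RGB image and its segmentation share a file name.

```python
# variables.py -- HYPOTHETICAL example; all paths are placeholders

# Training data, read by data_loader_gan.py (glob patterns, paired after sorting)
DATASET_RGB = "/data/train/rgb/*.tif"
DATASET_GTI = "/data/train/gti/*.tif"
DATASET_SEG = "/data/train/seg/*.tif"

# Inference, read by regularize.py (RGB and segmentation files must share names)
INF_RGB = "/data/test/rgb/*.tif"
INF_SEG = "/data/test/seg/*.tif"
INF_OUT = "/data/test/regularized/"   # trailing slash: see note below

# Pretrained weights from the saved_models_gan folder (file names are placeholders)
MODEL_ENCODER = "./saved_models_gan/<encoder weights file>"
MODEL_GENERATOR = "./saved_models_gan/<generator weights file>"
```

A trailing slash on `INF_OUT` keeps the output path correct whether it is built by string concatenation or by `os.path.join`.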
/README.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zorzi-s/projectRegularization/c03b94dbcf66549518117c635cf61d843ee662ef/README.png
--------------------------------------------------------------------------------
/crf_loss.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import math
5 | import itertools
6 | import time
7 | import datetime
8 | import sys
9 | from math import exp
10 | import random
11 |
12 | #from torchvision.utils import save_image
13 | #from torchvision import datasets
14 |
15 | from torch.utils.data import DataLoader
16 | from torch.autograd import Variable
17 |
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import torch
21 |
22 | kernel_size = 9 #gaussian kernel dimension
23 | dilation = 1 #enlarges the neighbourhood without extra taps: the kernel has kernel_size taps per side, but its effective extent is (kernel_size - 1) * dilation + 1 pixels
24 | padding = (kernel_size // 2) * dilation #do not touch this
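25 | bs = 4 #batch size: must equal the training batch size, since tensors below are preallocated with it
26 | win = 256 #window size: must equal the training window size (win_size)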
27 |
28 | sigma_X = 3.0 #for distance gaussian
29 | sigma_I = 0.1 #for RGB/grayscale gaussian
30 |
31 | sample_interval = 20 #debug only: used by the disabled image-dumping block in forward()
32 |
33 | class kernel_loss(torch.nn.Module):
34 |
35 | def sub_kernel(self):
36 | filters = kernel_size * kernel_size
37 | middle = kernel_size // 2
38 | kernel = Variable(torch.zeros((filters, 1, kernel_size, kernel_size))).cuda()
39 | for i in range(kernel_size):
40 | for j in range(kernel_size):
41 | kernel[i*kernel_size+j, 0, i, j] = -1
42 | kernel[i*kernel_size+j, 0, middle, middle] = kernel[i*kernel_size+j, 0, middle, middle] + 1
43 | return kernel
44 |
45 | def dist_kernel(self):
46 | filters = kernel_size * kernel_size
47 | middle = kernel_size // 2
48 | kernel = Variable(torch.zeros((bs, filters, 1, 1))).cuda()
49 |
50 | for i in range(kernel_size):
51 | for j in range(kernel_size):
52 | ii = i - middle
53 | jj = j - middle
54 | distance = pow(ii,2) + pow(jj,2)
55 | kernel[:, i*kernel_size+j, 0, 0] = exp(-distance / pow(sigma_X,2))
56 | #print(kernel.view(4,1,kernel_size,kernel_size))
57 | return kernel
58 |
59 | def central_kernel(self):
60 | filters = kernel_size * kernel_size
61 | middle = kernel_size // 2
62 | kernel = Variable(torch.zeros((filters, 1, kernel_size, kernel_size))).cuda()
63 | for i in range(kernel_size):
64 | for j in range(kernel_size):
65 | kernel[i*kernel_size+j, 0, middle, middle] = 1
66 | return kernel
67 |
68 | def select_kernel(self):
69 | filters = kernel_size * kernel_size
70 | middle = kernel_size // 2
71 | kernel = Variable(torch.zeros((filters, 1, kernel_size, kernel_size))).cuda()
72 | for i in range(kernel_size):
73 | for j in range(kernel_size):
74 | kernel[i*kernel_size+j, 0, i, j] = 1
75 | return kernel
76 |
77 | def color_tensor(self, x):
78 | result = Variable(torch.zeros((bs, kernel_size*kernel_size, win-2*padding, win-2*padding))).cuda()
79 |
80 | for i in range(x.shape[1]):
81 | channel = x[:,i,:,:].unsqueeze(1)
82 | sub = nn.Conv2d(in_channels=1, out_channels=kernel_size*kernel_size, kernel_size=kernel_size, bias=False, padding=0, dilation=dilation)
83 | sub.weight.data = self.sub_matrix
84 | color = sub(channel)
85 | color = torch.pow(color,2)
86 | result = result + color
87 |
88 | result = torch.exp(-result / pow(sigma_I,2))
89 | return result
90 |
91 | def probability_tensor(self, y):
92 | conv = nn.Conv2d(in_channels=1, out_channels=kernel_size*kernel_size, kernel_size=kernel_size, bias=False, padding=0, dilation=dilation)
93 | conv.weight.data = self.select_matrix
94 | prob = conv(y)
95 | return prob
96 |
97 | #def probability_central(self, y):
98 | # conv = nn.Conv2d(in_channels=1, out_channels=kernel_size*kernel_size, kernel_size=kernel_size, bias=False, padding=padding)
99 | # conv.weight.data = self.one_matrix
100 | # prob = conv(y)
101 | # return prob
102 |
103 | def __init__(self):
104 | super(kernel_loss,self).__init__()
105 | #self.softmax = nn.Softmax(dim=1)
106 | self.dist_tensor = self.dist_kernel()
107 | #self.one_matrix = self.central_kernel()
108 | self.select_matrix = self.select_kernel()
109 | self.sub_matrix = self.sub_kernel() #shape: [filters, 1, h, w]
110 |
111 |
112 | def forward(self,x,y):
113 | """
114 | x --> Image. It can also have just 1 channel (grayscale). Values between 0 and 1
115 | y --> Mask. Values between 0 and 1
116 | """
117 | #y = self.softmax(y)
118 | y0 = y[:,0,:,:].unsqueeze(1) #build: 0, background: 1, default 1
119 | y1 = y[:,1,:,:].unsqueeze(1) #build: 1, background: 0, default 0
120 | y0p = y0[:,:,padding:-padding,padding:-padding]
121 | y1p = y1[:,:,padding:-padding,padding:-padding]
122 |
123 | W = self.color_tensor(x)
124 | W = (W * self.dist_tensor.expand_as(W))
125 |
126 | potts_loss_0 = y0p.expand_as(W) * W * self.probability_tensor(y1)
127 | potts_loss_1 = y1p.expand_as(W) * W * self.probability_tensor(y0)
128 |
129 | numel = potts_loss_0.numel()
130 | #ncut_loss_0 = (potts_loss_0 / (self.probability_tensor(y0) * W)).mean()
131 | #ncut_loss_1 = (potts_loss_1 / (self.probability_tensor(y1) * W)).mean()
132 |
133 | """
134 | if random.randint(0,sample_interval) == 0:
135 | r = random.randint(0,20)
136 |
137 | img = torch.mean(W, dim=1).unsqueeze(1)
138 | #amin = torch.min(img)
139 | #amax = torch.max(img)
140 | #img = (img - amin) / (amax - amin)
141 | save_image(img, "./debug/%d_img.png" % r, nrow=2)
142 |
143 | #img2 = torch.mean(potts_loss_0, dim=1).unsqueeze(1)
144 | #amin = torch.min(img2)
145 | #amax = torch.max(img2)
146 | #img2 = (img2 - amin) / (amax - amin)
147 | #save_image(img2, "./debug/%d_b.png" % r, nrow=2)
148 |
149 | img3 = torch.mean(potts_loss_0, dim=1).unsqueeze(1)
150 | #amin = torch.min(img3)
151 | #amax = torch.max(img3)
152 | #img3 = (img3 - amin) / (amax - amin)
153 | save_image(img3, "./debug/%d_loss.png" % r, nrow=2)
154 |
155 | #img4 = torch.mean(loss_matrix, dim=1).unsqueeze(1)
156 | ##amin = torch.min(img4)
157 | ##amax = torch.max(img4)
158 | ##img4 = (img4 - amin) / (amax - amin)
159 | #save_image(img4, "./debug/%d_d.png" % r, nrow=2)
160 | save_image(x, "./debug/%d_map.png" % r, nrow=2)
161 | """
162 |
163 | potts_loss_0 = (potts_loss_0).mean()
164 | potts_loss_1 = (potts_loss_1).mean()
165 | potts_loss = potts_loss_0 + potts_loss_1
166 |
167 | return potts_loss
168 |
169 | """
170 | #ncut_loss_0 = potts_loss_0 / (self.probability_tensor(y0) * W).mean()
171 | #ncut_loss_1 = potts_loss_1 / (self.probability_tensor(y1) * W).mean()
172 | ncut_loss_0 = potts_loss_0 / (y0p.expand_as(W) * W).mean()
173 | ncut_loss_1 = potts_loss_1 / (y1p.expand_as(W) * W).mean()
174 |
175 | #ncut_loss_0 = ncut_loss_0.mean()
176 | #ncut_loss_1 = ncut_loss_1.mean()
177 | ncut_loss = ncut_loss_0 + ncut_loss_1
178 |
179 | #potts_loss = potts_loss_0 + potts_loss_1
180 | #ncut_loss = ncut_loss_0 + ncut_loss_1
181 |
182 | return (potts_loss, ncut_loss, numel)
183 | """
184 |
185 |
--------------------------------------------------------------------------------
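`kernel_loss` builds its pairwise term from several fixed convolutions (`sub_kernel`, `select_kernel`, `dist_kernel`), which makes it hard to read. The loop below is one reading of `forward()` for a single grayscale image with `dilation = 1`, written in plain Python so it can be checked by hand. It is not the authors' code, and the name `potts_loss_reference` is hypothetical.

The loop visits every centre pixel whose full window fits inside the image, and every neighbour in the `kernel_size x kernel_size` window. For each pair it weights the product "centre in one class, neighbour in the other" by a spatial Gaussian and an intensity Gaussian. It then averages the result, as the two `.mean()` calls do.

`y0` and `y1` are passed separately because `GeneratorResNet` ends in a per-channel `Sigmoid`, so the two channels need not sum to 1.

```python
import math


def potts_loss_reference(img, y0, y1, kernel_size=9, sigma_x=3.0, sigma_i=0.1):
    """Loop-based sketch of kernel_loss.forward for ONE grayscale image.

    img    -- 2-D list of intensities in [0, 1]
    y0, y1 -- 2-D lists with background / building probabilities
    Assumes dilation = 1 and an odd kernel_size.
    """
    pad = kernel_size // 2
    h, w = len(img), len(img[0])
    total, count = 0.0, 0
    # Centre pixels: only positions where the whole window fits ("valid" convolution)
    for i in range(pad, h - pad):
        for j in range(pad, w - pad):
            for di in range(-pad, pad + 1):
                for dj in range(-pad, pad + 1):
                    ni, nj = i + di, j + dj
                    # spatial Gaussian (dist_kernel) times intensity Gaussian (color_tensor)
                    weight = math.exp(-(di * di + dj * dj) / sigma_x ** 2)
                    weight *= math.exp(-((img[i][j] - img[ni][nj]) ** 2) / sigma_i ** 2)
                    # potts_loss_0 + potts_loss_1: centre and neighbour in different classes
                    total += weight * (y0[i][j] * y1[ni][nj] + y1[i][j] * y0[ni][nj])
                    count += 1
    # both terms are averaged over the same number of elements, so one mean suffices
    return total / count
```

With these weights, a mask boundary that follows an intensity edge costs almost nothing: across a 0/1 edge with `sigma_i = 0.1`, the intensity Gaussian is about exp(-100). A boundary that cuts through a uniform region is penalised instead.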
/data_loader_gan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from glob import glob
4 | from tqdm import tqdm
5 | import random
6 | from skimage import io
7 | from skimage.segmentation import mark_boundaries
8 | from skimage.transform import rotate
9 | import variables as var
10 |
11 | TEST = False
12 |
13 | def to_categorical(y, num_classes=None, dtype='float32'):
14 |
15 | y = np.array(y, dtype='int')
16 | input_shape = y.shape
17 | if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
18 | input_shape = tuple(input_shape[:-1])
19 | y = y.ravel()
20 | if not num_classes:
21 | num_classes = np.max(y) + 1
22 | n = y.shape[0]
23 | categorical = np.zeros((n, num_classes), dtype=dtype)
24 | categorical[np.arange(n), y] = 1
25 | output_shape = input_shape + (num_classes,)
26 | categorical = np.reshape(categorical, output_shape)
27 | return categorical
28 |
29 | class DataLoader():
30 |
31 | def __init__(self, ws=512, nb=10000, bs=8):
32 | self.nb = nb
33 | self.bs = bs
34 | self.ws = ws
35 |
36 | #self.rgb_files = self.rgb_files[:10]
37 | #self.dsm_files = self.dsm_files[:10]
38 | #self.gti_files = self.gti_files[:10]
39 |
40 | self.load_data()
41 | self.num_tiles = len(self.rgb_imgs)
42 | self.sliding_index = 0
43 |
44 | def generator(self):
45 | for _ in range(self.nb):
46 | batch_rgb = []
47 | batch_gti = []
48 | batch_seg = []
49 | for _ in range(self.bs):
50 | rgb, gti, seg = self.extract_image()
51 |
52 | batch_rgb.append(rgb)
53 |
54 | # the ground truth is categorized
55 | gti = to_categorical(gti != 0, 2)
56 | batch_gti.append(gti)
57 |
58 | # the segmentation is categorized
59 | seg = to_categorical(seg != 0, 2)
60 | batch_seg.append(seg)
61 |
62 | batch_rgb = np.asarray(batch_rgb)
63 | batch_gti = np.asarray(batch_gti)
64 | batch_seg = np.asarray(batch_seg)
65 | batch_rgb = batch_rgb / 255.0
66 |
67 | #batch_gti = batch_gti[:,:,:,np.newaxis] / 255.0
68 |
69 | yield (batch_rgb, batch_gti, batch_seg)
70 |
71 |
72 | def test_shape(self, a):
73 | ri = a.shape[0] % self.ws
74 | rj = a.shape[1] % self.ws
75 |         return a[:a.shape[0] - ri, :a.shape[1] - rj]
76 |
77 |
78 | def random_hsv(self, img, value_h=30, value_s=30, value_v=30):
79 | hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
80 | h, s, v = cv2.split(hsv)
81 |
82 | h = np.int16(h)
83 | s = np.int16(s)
84 | v = np.int16(v)
85 |
86 | h += value_h
87 | h[h < 0] = 0
88 | h[h > 255] = 255
89 |
90 | s += value_s
91 | s[s < 0] = 0
92 | s[s > 255] = 255
93 |
94 | v += value_v
95 | v[v < 0] = 0
96 | v[v > 255] = 255
97 |
98 | h = np.uint8(h)
99 | s = np.uint8(s)
100 | v = np.uint8(v)
101 |
102 | final_hsv = cv2.merge((h, s, v))
103 | img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
104 | return img
105 |
106 |
107 | def extract_image(self, mode="sequential"):
108 |         if mode == "random":
109 | rand_t = random.randint(0, self.num_tiles-1)
110 | else:
111 | if self.sliding_index < self.num_tiles:
112 | rand_t = self.sliding_index
113 | self.sliding_index = self.sliding_index + 1
114 | else:
115 | rand_t = 0
116 |                 self.sliding_index = 1
117 |
118 | rgb = self.rgb_imgs[rand_t].copy()
119 | gti = self.gti_imgs[rand_t].copy()
120 | seg = self.seg_imgs[rand_t].copy()
121 |
122 |         h = rgb.shape[0]
123 |         w = rgb.shape[1]
124 |
125 | void = True
126 | while void:
127 | rot = random.randint(0,90)
128 | ri = random.randint(0, int(h-self.ws*2))
129 | rj = random.randint(0, int(w-self.ws*2))
130 | win_rgb = rgb[ri:ri+int(self.ws*2), rj:rj+int(self.ws*2)]
131 | win_gti = gti[ri:ri+int(self.ws*2), rj:rj+int(self.ws*2)]
132 | win_seg = seg[ri:ri+int(self.ws*2), rj:rj+int(self.ws*2)]
133 |
134 | win_rgb = np.uint8(rotate(win_rgb, rot, resize=False, preserve_range=True))
135 | win_gti = np.uint8(rotate(win_gti, rot, resize=False, preserve_range=True))
136 | win_seg = np.uint8(rotate(win_seg, rot, resize=False, preserve_range=True))
137 |
138 | win_rgb = win_rgb[self.ws//2:-self.ws//2, self.ws//2:-self.ws//2]
139 | win_gti = win_gti[self.ws//2:-self.ws//2, self.ws//2:-self.ws//2]
140 | win_seg = win_seg[self.ws//2:-self.ws//2, self.ws//2:-self.ws//2]
141 |
142 | if np.count_nonzero(win_seg):
143 | void = False
144 |
145 | # Perform some data augmentation
146 | rot = random.randint(0,3)
147 | win_rgb = np.rot90(win_rgb, k=rot)
148 | win_gti = np.rot90(win_gti, k=rot)
149 | win_seg = np.rot90(win_seg, k=rot)
150 |         if random.randint(0,1) == 1:
151 | win_rgb = np.fliplr(win_rgb)
152 | win_gti = np.fliplr(win_gti)
153 | win_seg = np.fliplr(win_seg)
154 |
155 | r_h = random.randint(-20,20)
156 | r_s = random.randint(-20,20)
157 | r_v = random.randint(-20,20)
158 | win_rgb = self.random_hsv(win_rgb, r_h, r_s, r_v)
159 |
160 | win_rgb = win_rgb.astype(np.float32)
161 | win_gti = win_gti.astype(np.float32)
162 | win_seg = win_seg.astype(np.float32)
163 | return (win_rgb, win_gti, win_seg)
164 |
165 |
166 | def load_data(self):
167 | self.rgb_imgs = []
168 | self.gti_imgs = []
169 | self.seg_imgs = []
170 |
171 | rgb_files = glob(var.DATASET_RGB)
172 | gti_files = glob(var.DATASET_GTI)
173 | seg_files = glob(var.DATASET_SEG)
174 |
175 | rgb_files.sort()
176 | gti_files.sort()
177 | seg_files.sort()
178 |
179 | combined = list(zip(rgb_files, gti_files, seg_files))
180 | random.shuffle(combined)
181 |
182 | rgb_files[:], gti_files[:], seg_files[:] = zip(*combined)
183 |
184 | if TEST:
185 | rgb_files = rgb_files[:4]
186 | gti_files = gti_files[:4]
187 | seg_files = seg_files[:4]
188 |
189 | for rgb_name, gti_name, seg_name in tqdm(zip(rgb_files, gti_files, seg_files), total=len(rgb_files), desc="Loading dataset into RAM"):
190 |
191 | tmp = io.imread(rgb_name)
192 | tmp = tmp.astype(np.uint8)
193 | self.rgb_imgs.append(tmp)
194 |
195 | tmp = io.imread(gti_name)
196 | tmp = tmp.astype(np.uint8)
197 | self.gti_imgs.append(tmp)
198 |
199 | tmp = io.imread(seg_name)
200 | tmp = tmp.astype(np.uint8)
201 | self.seg_imgs.append(tmp)
202 |
203 |
--------------------------------------------------------------------------------
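`extract_image` cuts a window of `2 * ws` pixels, rotates it by a random angle in [0, 90] degrees and keeps the central `ws x ws` patch. The `[ws//2:-ws//2]` slice of a `2 * ws` window leaves exactly `ws` pixels.

The geometric check below is a sketch that ignores interpolation at the border. It shows why the doubled window is enough. The worst case is 45 degrees, where the rotated crop reaches a half-extent of about 0.71 * ws, while the window offers ws. The function name and the `window_factor` parameter are hypothetical.

```python
import math


def rotated_crop_fits(ws, angle_deg, window_factor=2.0):
    """Return True if a centred ws x ws crop of a (window_factor * ws)-sized
    window, rotated by angle_deg about its centre, only samples pixels that lie
    inside the original window (no zero-filled corners after rotation)."""
    theta = math.radians(angle_deg)
    # A crop corner sits at (+-ws/2, +-ws/2) from the centre; after rotation its
    # largest coordinate along either axis is (ws / 2) * (|cos| + |sin|).
    needed_half_extent = (ws / 2) * (abs(math.cos(theta)) + abs(math.sin(theta)))
    available_half_extent = window_factor * ws / 2
    return needed_half_extent <= available_half_extent
```

With a window only 1.2 times the crop size, a 45 degree rotation would already pull zero-filled corners into the patch.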
/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 |
6 | def weights_init_normal(m):
7 | classname = m.__class__.__name__
8 | if classname.find("Conv") != -1:
9 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
10 | if hasattr(m, "bias") and m.bias is not None:
11 | torch.nn.init.constant_(m.bias.data, 0.0)
12 | elif classname.find("BatchNorm2d") != -1:
13 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | torch.nn.init.constant_(m.bias.data, 0.0)
15 |
16 |
17 |
18 | class ResidualBlock(nn.Module):
19 | def __init__(self, in_features):
20 | super(ResidualBlock, self).__init__()
21 |
22 | self.block = nn.Sequential(
23 | #nn.ReflectionPad2d(1),
24 | nn.Conv2d(in_features, in_features, 3, stride=1, padding=1),
25 | nn.InstanceNorm2d(in_features),
26 | nn.ReLU(inplace=True),
27 | #nn.ReflectionPad2d(1),
28 | nn.Conv2d(in_features, in_features, 3, stride=1, padding=1),
29 | nn.InstanceNorm2d(in_features),
30 | nn.ReLU(inplace=True),
31 | )
32 |
33 | def forward(self, x):
34 | return x + self.block(x)
35 |
36 |
37 |
38 | class GeneratorResNet(nn.Module):
39 | def __init__(self, num_residual_blocks=8, in_features=256):
40 | super(GeneratorResNet, self).__init__()
41 |
42 | out_features = in_features
43 |
44 | model = []
45 |
46 | # Residual blocks
47 | for _ in range(num_residual_blocks):
48 | model += [ResidualBlock(out_features)]
49 |
50 | # Upsampling
51 | for _ in range(2):
52 | out_features //= 2
53 | model += [
54 | nn.Upsample(scale_factor=2),
55 | nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
56 | nn.InstanceNorm2d(out_features),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(out_features, out_features, 3, stride=1, padding=1),
59 | nn.InstanceNorm2d(out_features),
60 | nn.ReLU(inplace=True),
61 | nn.Conv2d(out_features, out_features, 3, stride=1, padding=1),
62 | nn.InstanceNorm2d(out_features),
63 | nn.ReLU(inplace=True),
64 | ]
65 | in_features = out_features
66 |
67 | # Output layer
68 | #model += [nn.ReflectionPad2d(2), nn.Conv2d(out_features, 2, 7), nn.Softmax()]
69 | model += [nn.Conv2d(out_features, 2, 7, stride=1, padding=3), nn.Sigmoid()]
70 |
71 | self.model = nn.Sequential(*model)
72 |
73 | def forward(self, feature_map):
74 | x = self.model(feature_map)
75 | return x
76 |
77 |
78 | class Encoder(nn.Module):
79 | def __init__(self, channels=3+2):
80 | super(Encoder, self).__init__()
81 |
82 | # Initial convolution block
83 | out_features = 64
84 | model = [
85 | nn.Conv2d(channels, out_features, 7, stride=1, padding=3),
86 | nn.InstanceNorm2d(out_features),
87 | nn.ReLU(inplace=True),
88 | ]
89 | in_features = out_features
90 |
91 | # Downsampling
92 | for _ in range(2):
93 | out_features *= 2
94 | model += [
95 | nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
96 | nn.InstanceNorm2d(out_features),
97 | nn.ReLU(inplace=True),
98 | nn.Conv2d(out_features, out_features, 3, stride=1, padding=1),
99 | nn.InstanceNorm2d(out_features),
100 | nn.ReLU(inplace=True),
101 | nn.MaxPool2d(2, stride=2),
102 | ]
103 | in_features = out_features
104 |
105 | self.model = nn.Sequential(*model)
106 |
107 | def forward(self, arguments):
108 | x = torch.cat(arguments, dim=1)
109 | x = self.model(x)
110 | return x
111 |
112 |
113 | class Discriminator(nn.Module):
114 | def __init__(self):
115 | super(Discriminator, self).__init__()
116 |
117 | channels = 2
118 | out_channels = 2
119 |
120 | def discriminator_block(in_filters, out_filters, normalize=True):
121 | """Returns downsampling layers of each discriminator block"""
122 | layers = [nn.Conv2d(in_filters, out_filters, 3, stride=1, padding=1)]
123 | if normalize:
124 | layers.append(nn.InstanceNorm2d(out_filters))
125 | layers.append(nn.ReLU())
126 |
127 | layers.append(nn.Conv2d(out_filters, out_filters, 3, stride=1, padding=1))
128 | if normalize:
129 | layers.append(nn.InstanceNorm2d(out_filters))
130 | layers.append(nn.ReLU())
131 | layers.append(nn.MaxPool2d(2, stride=2))
132 | return layers
133 |
134 | self.model = nn.Sequential(
135 | *discriminator_block(channels, 64, normalize=False),
136 | *discriminator_block(64, 128),
137 | *discriminator_block(128, 256),
138 | *discriminator_block(256, 512),
139 | nn.Conv2d(512, out_channels, 3, padding=1),
140 | nn.Sigmoid()
141 | )
142 |
143 | def forward(self, img):
144 | #img = torch.cat((rgb, mask), dim=1)
145 | img = self.model(img)
146 | return img
147 |
--------------------------------------------------------------------------------
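The three networks change spatial size only through `MaxPool2d(2, stride=2)` and `Upsample(scale_factor=2)`. Every convolution uses "same" padding. The shape arithmetic can therefore be traced without PyTorch. The helper functions below are hypothetical and mirror the layer lists above.

```python
def encoder_shape(h, w):
    """Output (channels, height, width) of Encoder for an h x w input."""
    c = 64                      # 7x7 conv, stride 1, padding 3: size unchanged
    for _ in range(2):          # two stages: 3x3 convs (padding 1) + MaxPool2d(2, stride=2)
        c *= 2
        h, w = h // 2, w // 2   # max pooling floors odd sizes
    return c, h, w


def generator_shape(c, h, w, in_features=256):
    """Output shape of GeneratorResNet for a (c, h, w) feature map."""
    assert c == in_features, "GeneratorResNet expects in_features input channels"
    # residual blocks keep the shape; two Upsample(scale_factor=2) stages double it twice
    for _ in range(2):
        h, w = h * 2, w * 2
    return 2, h, w              # final 7x7 conv to 2 channels, then Sigmoid


def discriminator_shape(h, w):
    """Output shape of Discriminator for a 2-channel h x w mask."""
    for _ in range(4):          # four blocks, each ending with MaxPool2d(2, stride=2)
        h, w = h // 2, w // 2
    return 2, h, w
```

For a 256 x 256 window:

- The encoder yields a 256 x 64 x 64 feature map.
- The generator restores a 2 x 256 x 256 mask.
- The discriminator returns 2 x 16 x 16, which matches `patch_size = win_size / 2**4` in train_gan_net.py.

Sides that are not multiples of 4 come back smaller (258 returns 256). This is presumably why `fix_limits` in regularize.py pads each crop to a multiple of 4.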
/regularize.py:
--------------------------------------------------------------------------------
1 | import random
2 | from skimage import io
3 | from skimage.transform import rotate
4 | import numpy as np
5 | import torch
6 | from tqdm import tqdm
7 | from osgeo import gdal
8 | import os
9 | import glob
10 | from skimage.segmentation import mark_boundaries
11 | from PIL import Image, ImageDraw, ImageFont
12 | from numpy.linalg import svd
13 | import cv2
14 | from skimage import measure
15 |
16 | from models import GeneratorResNet, Encoder
17 | from skimage.transform import rescale
18 | import variables as var
19 |
20 |
21 |
22 |
23 | def compute_IoU(mask, pred):
24 | mask = mask!=0
25 | pred = pred!=0
26 |
27 | m1 = np.logical_and(mask, pred)
28 | m2 = np.logical_and(np.logical_not(mask), np.logical_not(pred))
29 | m3 = np.logical_and(mask==0, pred==1)
30 | m4 = np.logical_and(mask==1, pred==0)
31 | m5 = np.logical_or(mask, pred)
32 |
33 | tp = np.count_nonzero(m1)
34 | fp = np.count_nonzero(m3)
35 | fn = np.count_nonzero(m4)
36 |
37 | IoU = tp/(tp+(fn+fp))
38 | return IoU
39 |
40 |
41 | def to_categorical(y, num_classes=None, dtype='float32'):
42 | y = np.array(y, dtype='int')
43 | input_shape = y.shape
44 | if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
45 | input_shape = tuple(input_shape[:-1])
46 | y = y.ravel()
47 | if not num_classes:
48 | num_classes = np.max(y) + 1
49 | n = y.shape[0]
50 | categorical = np.zeros((n, num_classes), dtype=dtype)
51 | categorical[np.arange(n), y] = 1
52 | output_shape = input_shape + (num_classes,)
53 | categorical = np.reshape(categorical, output_shape)
54 | return categorical
55 |
56 | @torch.no_grad()  # inference only: do not build an autograd graph
57 | def predict_building(rgb, mask, model):
58 | Tensor = torch.cuda.FloatTensor
59 |
60 | mask = to_categorical(mask, 2)
61 |
62 | rgb = rgb[np.newaxis, :, :, :]
63 | mask = mask[np.newaxis, :, :, :]
64 |
65 | E, G = model
66 |
67 | rgb = Tensor(rgb)
68 | mask = Tensor(mask)
69 | rgb = rgb.permute(0,3,1,2)
70 | mask = mask.permute(0,3,1,2)
71 |
72 | rgb = rgb / 255.0
73 |
74 | # PREDICTION
75 | pred = G(E([rgb, mask]))
76 | pred = pred.permute(0,2,3,1)
77 |
78 | pred = pred.detach().cpu().numpy()
79 |
80 | pred = np.argmax(pred[0,:,:,:], axis=-1)
81 | return pred
82 |
83 |
84 |
85 | def fix_limits(i_min, i_max, j_min, j_max, min_image_size=256):
86 |
87 | def closest_divisible_size(size, factor=4):
88 | while size % factor:
89 | size += 1
90 | return size
91 |
92 | height = i_max - i_min
93 | width = j_max - j_min
94 |
95 | # pad the rows
96 | if height < min_image_size:
97 | diff = min_image_size - height
98 | else:
99 | diff = closest_divisible_size(height) - height + 16
100 |
101 | i_min -= (diff // 2)
102 | i_max += (diff // 2 + diff % 2)
103 |
104 | # pad the columns
105 | if width < min_image_size:
106 | diff = min_image_size - width
107 | else:
108 | diff = closest_divisible_size(width) - width + 16
109 |
110 | j_min -= (diff // 2)
111 | j_max += (diff // 2 + diff % 2)
112 |
113 | return i_min, i_max, j_min, j_max
114 |
115 |
116 |
117 | def regularization(rgb, ins_segmentation, model, in_mode="instance", out_mode="instance", min_size=10):
118 | assert in_mode == "instance" or in_mode == "semantic"
119 | assert out_mode == "instance" or out_mode == "semantic"
120 |
121 | if in_mode == "semantic":
122 | ins_segmentation = np.uint16(measure.label(ins_segmentation, background=0))
123 |
124 | max_instance = np.amax(ins_segmentation)
125 | border = 256
126 |
127 | ins_segmentation = np.uint16(cv2.copyMakeBorder(ins_segmentation,border,border,border,border,cv2.BORDER_CONSTANT,value=0))
128 | rgb = np.uint8(cv2.copyMakeBorder(rgb,border,border,border,border,cv2.BORDER_CONSTANT,value=(0,0,0)))
129 |
130 | regularization = np.zeros(ins_segmentation.shape, dtype=np.uint16)
131 |
132 | for ins in tqdm(range(1, max_instance+1), desc="Regularization"):
133 | indices = np.argwhere(ins_segmentation==ins)
134 | building_size = indices.shape[0]
135 | if building_size > min_size:
136 | i_min = np.amin(indices[:,0])
137 | i_max = np.amax(indices[:,0])
138 | j_min = np.amin(indices[:,1])
139 | j_max = np.amax(indices[:,1])
140 |
141 | i_min, i_max, j_min, j_max = fix_limits(i_min, i_max, j_min, j_max)
142 |
143 | mask = np.copy(ins_segmentation[i_min:i_max, j_min:j_max] == ins)
144 | rgb_mask = np.copy(rgb[i_min:i_max, j_min:j_max, :])
145 |
146 |
147 |
148 | max_building_size = 1024
149 | rescaled = False
150 | if mask.shape[0] > max_building_size and mask.shape[0] >= mask.shape[1]:
151 | f = max_building_size / mask.shape[0]
152 | mask = rescale(mask, f, anti_aliasing=False, preserve_range=True)
153 |                 rgb_mask = rescale(rgb_mask, f, anti_aliasing=False, preserve_range=True, channel_axis=-1)
154 | rescaled = True
155 | elif mask.shape[1] > max_building_size and mask.shape[1] >= mask.shape[0]:
156 | f = max_building_size / mask.shape[1]
157 |                 mask = rescale(mask, f, anti_aliasing=False, preserve_range=True)
158 |                 rgb_mask = rescale(rgb_mask, f, anti_aliasing=False, preserve_range=True, channel_axis=-1)
159 | rescaled = True
160 |
161 | pred = predict_building(rgb_mask, mask, model)
162 |
163 | if rescaled:
164 | pred = rescale(pred, 1/f, anti_aliasing=False, preserve_range=True)
165 |
166 |
167 |
168 | pred_indices = np.argwhere(pred != 0)
169 |
170 | if pred_indices.shape[0] > 0:
171 | pred_indices[:,0] = pred_indices[:,0] + i_min
172 | pred_indices[:,1] = pred_indices[:,1] + j_min
173 | x, y = zip(*pred_indices)
174 | if out_mode == "semantic":
175 | regularization[x,y] = 1
176 | else:
177 | regularization[x,y] = ins
178 |
179 | return regularization[border:-border, border:-border]
180 |
181 |
182 |
183 | def copyGeoreference(inp, output):
184 | dataset = gdal.Open(inp)
185 | if dataset is None:
186 | print('Unable to open', inp, 'for reading')
187 |         raise SystemExit(1)
188 |
189 |     projection = dataset.GetProjection()
190 |     geotransform = dataset.GetGeoTransform()
191 |
192 |     if projection is None and geotransform is None:
193 |         print('No projection or geotransform found on file', inp)
194 |         raise SystemExit(1)
195 |
196 |     dataset2 = gdal.Open(output, gdal.GA_Update)
197 |
198 |     if dataset2 is None:
199 |         print('Unable to open', output, 'for writing')
200 |         raise SystemExit(1)
201 |
202 | if geotransform is not None and geotransform != (0, 1, 0, 0, 0, 1):
203 | dataset2.SetGeoTransform(geotransform)
204 |
205 | if projection is not None and projection != '':
206 | dataset2.SetProjection(projection)
207 |
208 | gcp_count = dataset.GetGCPCount()
209 | if gcp_count != 0:
210 | dataset2.SetGCPs(dataset.GetGCPs(), dataset.GetGCPProjection())
211 |
212 | dataset = None
213 | dataset2 = None
214 |
215 |
216 |
217 | def regularize_segmentations(img_folder, seg_folder, out_folder, in_mode="semantic", out_mode="instance", samples=None):
218 | """
219 | BUILDING REGULARIZATION
220 | Inputs:
221 | - satellite image (3 channels)
222 | - building segmentation (1 channel)
223 | Output:
224 | - regularized mask
225 | """
226 |
227 |     img_files = glob.glob(img_folder)
228 |     seg_files = glob.glob(seg_folder)
229 |
230 |     img_files.sort()
231 |     seg_files.sort()
232 |
233 |     # Load the encoder and generator once, outside the per-image loop
234 |     E1 = Encoder()
235 |     G = GeneratorResNet()
236 |     G.load_state_dict(torch.load(var.MODEL_GENERATOR))
237 |     E1.load_state_dict(torch.load(var.MODEL_ENCODER))
238 |     E1 = E1.cuda()
239 |     G = G.cuda()
240 |     model = [E1, G]
241 |
242 |     for num, (satellite_image_file, building_segmentation_file) in enumerate(zip(img_files, seg_files)):
243 |         print(satellite_image_file, building_segmentation_file)
244 |         _, rgb_name = os.path.split(satellite_image_file)
245 |         _, seg_name = os.path.split(building_segmentation_file)
246 |         assert rgb_name == seg_name
247 |
248 |         output_file = os.path.join(out_folder, seg_name)
249 |
250 | M = io.imread(building_segmentation_file)
251 | M = np.uint16(M)
252 | P = io.imread(satellite_image_file)
253 | P = np.uint8(P)
254 |
255 | R = regularization(P, M, model, in_mode=in_mode, out_mode=out_mode)
256 |
257 | if out_mode == "instance":
258 | io.imsave(output_file, np.uint16(R))
259 | else:
260 | io.imsave(output_file, np.uint8(R*255))
261 |
262 | if samples is not None:
263 | i = 1000
264 | j = 1000
265 | h, w = 1080, 1920
266 | P = P[i:i+h, j:j+w]
267 | R = R[i:i+h, j:j+w]
268 | M = M[i:i+h, j:j+w]
269 |
270 | R = mark_boundaries(P, R, mode="thick")
271 | M = mark_boundaries(P, M, mode="thick")
272 |
273 | R = np.uint8(R*255)
274 | M = np.uint8(M*255)
275 |
276 | font = cv2.FONT_HERSHEY_SIMPLEX
277 | bottomLeftCornerOfText = (20,1060)
278 | fontScale = 1
279 | fontColor = (255,255,0)
280 | lineType = 2
281 |
282 | cv2.putText(R, "INRIA dataset, " + rgb_name + ", regularization",
283 | bottomLeftCornerOfText,
284 | font,
285 | fontScale,
286 | fontColor,
287 | lineType)
288 |
289 | cv2.putText(M, "INRIA dataset, " + rgb_name + ", segmentation",
290 | bottomLeftCornerOfText,
291 | font,
292 | fontScale,
293 | fontColor,
294 | lineType)
295 |
296 |             io.imsave(os.path.join(samples, "%d_2reg.png" % num), np.uint8(R))
297 |             io.imsave(os.path.join(samples, "%d_1seg.png" % num), np.uint8(M))
298 |
299 | copyGeoreference(satellite_image_file, output_file)
300 | copyGeoreference(satellite_image_file, building_segmentation_file)
301 |
302 |
303 | if __name__ == "__main__":
304 |     regularize_segmentations(img_folder=var.INF_RGB, seg_folder=var.INF_SEG, out_folder=var.INF_OUT, in_mode="semantic", out_mode="instance", samples=None)
305 |
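The two `io.imsave` branches in `regularize_segmentations` encode the result differently: instance maps are written as `uint16` label images, semantic maps as 0/255 `uint8` masks. The sketch below uses a hypothetical helper, `encode_output` (not part of the repository), and assumes `R` is an integer label map in instance mode and a {0, 1} mask otherwise. It shows why the instance branch needs 16 bits: casting labels above 255 to `uint8` silently wraps them.

```python
import numpy as np


def encode_output(R, out_mode):
    """Hypothetical helper mirroring the io.imsave branches in regularize.py.

    Assumes R is an integer instance-label map when out_mode == "instance",
    otherwise a binary {0, 1} building mask.
    """
    if out_mode == "instance":
        # uint16 can hold up to 65535 distinct building labels
        return np.uint16(R)
    # binary mask -> 0/255 so the mask is visible in ordinary image viewers
    return np.uint8(R * 255)


labels = np.array([[0, 1], [256, 300]])
print(encode_output(labels, "instance"))       # labels preserved as uint16
print(labels.astype(np.uint8))                 # 256 -> 0 and 300 -> 44 (wrapped modulo 256)
print(encode_output(np.array([[0, 1]]), "semantic"))
```

A scene with more than 255 buildings (common for INRIA tiles) would therefore lose instance identities if it were saved as `uint8`.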
--------------------------------------------------------------------------------
/train_gan_net.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | import torch
5 | from torch import nn
6 | from torch import optim
7 | from torch.autograd import Variable
8 | # DataLoader is the project's loader from data_loader_gan, not torch.utils.data.DataLoader
9 |
10 |
11 | from tqdm import tqdm
12 | import click
13 | import numpy as np
14 | import cv2
15 | from skimage.segmentation import mark_boundaries
16 | from skimage import io
17 | import itertools
18 |
19 | from models import GeneratorResNet, Encoder, Discriminator
20 | from data_loader_gan import DataLoader
21 | from training_utils import sample_images, LossBuffer, LambdaLR
22 | import variables as var
23 | from crf_loss import kernel_loss
24 |
25 |
26 |
27 | def crf_factor(batch_index, start_crf_batch, end_crf_batch, crf_initial_factor, crf_final_factor):
28 | if batch_index <= start_crf_batch:
29 | return 0.0
30 | elif start_crf_batch < batch_index < end_crf_batch:
31 | return crf_initial_factor + ((crf_final_factor - crf_initial_factor) * (batch_index - start_crf_batch) / (end_crf_batch - start_crf_batch))
32 | else:
33 | return crf_final_factor
34 |
35 |
36 | def train(
37 | models_path='./saved_models_gan/',
38 | restore=False,
39 | batch_size=4,
40 | start_batch=0, n_batches=140000,
41 | start_crf_batch=60000, end_crf_batch=120000, crf_initial_factor=0.0, crf_final_factor=175.0,
42 | start_lr_decay=120000,
43 | start_lr=0.00004, win_size=256, sample_interval=20, backup_interval=5000):
44 |
45 | patch_size = int(win_size / pow(2, 4))
46 |
47 | Tensor = torch.cuda.FloatTensor
48 |
49 | e1 = Encoder(channels=3+2)
50 | e2 = Encoder(channels=2)
51 | net = GeneratorResNet()
52 | disc = Discriminator()
53 |
54 | if restore:
55 | print("Restoring model number %d" % start_batch)
56 | e1.load_state_dict(torch.load(models_path + "E%d_e1" % start_batch))
57 | e2.load_state_dict(torch.load(models_path + "E%d_e2" % start_batch))
58 | net.load_state_dict(torch.load(models_path + "E%d_net" % start_batch))
59 | disc.load_state_dict(torch.load(models_path + "E%d_disc" % start_batch))
60 |
61 | e1 = e1.cuda()
62 | e2 = e2.cuda()
63 | net = net.cuda()
64 | disc = disc.cuda()
65 |
66 | os.makedirs(models_path, exist_ok=True)
67 | os.makedirs(var.DEBUG_DIR, exist_ok=True)  # sample_images() saves its debug PNGs here
68 | loss_0_buffer = LossBuffer()
69 | loss_1_buffer = LossBuffer()
70 | loss_2_buffer = LossBuffer()
71 | loss_3_buffer = LossBuffer()
72 | loss_4_buffer = LossBuffer()
73 | loss_5_buffer = LossBuffer()
74 |
75 | gen_obj = DataLoader(bs=batch_size, nb=n_batches, ws=win_size)
76 |
77 | # Optimizers
78 | optimizer_G = torch.optim.Adam(itertools.chain(net.parameters(), e1.parameters(), e2.parameters()), lr=start_lr)
79 | optimizer_D = torch.optim.Adam(disc.parameters(), lr=start_lr)
80 |
81 | # Learning rate update schedulers
82 | lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_batches, start_lr_decay).step)
83 | lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(n_batches, start_lr_decay).step)
84 |
85 | bce_criterion = nn.BCELoss()
86 | bce_criterion = bce_criterion.cuda()
87 |
88 | densecrflosslayer = kernel_loss()
89 | densecrflosslayer = densecrflosslayer.cuda()
90 |
91 | loader = gen_obj.generator()
92 | train_iterator = tqdm(loader, total=(n_batches + 1 - start_batch))
93 | img_index = 0
94 |
95 | for batch_index, (rgb, gti, seg) in enumerate(train_iterator):
96 |
97 | batch_index = batch_index + start_batch
98 |
99 | rgb = Variable(Tensor(rgb))
100 | gti = Variable(Tensor(gti))
101 | seg = Variable(Tensor(seg))
102 |
103 | rgb = rgb.permute(0,3,1,2)
104 | gti = gti.permute(0,3,1,2)
105 | seg = seg.permute(0,3,1,2)
106 |
107 | # Adversarial ground truths
108 | ones = Variable(Tensor(np.ones((batch_size, 1, patch_size, patch_size))), requires_grad=False)
109 | zeros = Variable(Tensor(np.zeros((batch_size, 1, patch_size, patch_size))), requires_grad=False)
110 | valid = torch.cat((ones, zeros), dim=1)
111 | fake = torch.cat((zeros, ones), dim=1)
112 |
113 | # ------------------
114 | # Train Generators
115 | # ------------------
116 |
117 | #e1.train()
118 | #e2.train()
119 | #net.train()
120 |
121 | optimizer_G.zero_grad()
122 |
123 | reg = net(e1([rgb, seg]))
124 | rec = net(e2([gti]))
125 |
126 | # Identity loss (reconstruction loss)
127 | loss_rec_1 = bce_criterion(reg, seg)
128 | loss_rec_2 = bce_criterion(rec, gti)
129 |
130 | # GAN loss
131 | loss_GAN = bce_criterion(disc(reg), valid)
132 |
133 | # CRF loss
134 | pot_multiplier = crf_factor(batch_index, start_crf_batch, end_crf_batch, crf_initial_factor, crf_final_factor)
135 | loss_pot = densecrflosslayer(rgb, reg)
136 | loss_pot = loss_pot.cuda()
137 |
138 | # Total loss
139 | loss_G = 3 * loss_GAN + 1 * loss_rec_1 + 3 * loss_rec_2 + pot_multiplier * loss_pot
140 |
141 | loss_G.backward()
142 | optimizer_G.step()
143 |
144 |
145 | # -----------------------
146 | # Train Discriminator
147 | # -----------------------
148 |
149 | #disc.train()
150 |
151 | optimizer_D.zero_grad()
152 |
153 | loss_real = bce_criterion(disc(rec.detach()), valid)
154 | loss_fake = bce_criterion(disc(reg.detach()), fake)
155 |
156 | # Total loss
157 | loss_D = (loss_real + loss_fake) / 2
158 |
159 | loss_D.backward()
160 | optimizer_D.step()
161 |
162 | # --------------
163 | # Update LR
164 | # --------------
165 |
166 | lr_scheduler_G.step(batch_index)
167 | lr_scheduler_D.step(batch_index)
168 |
169 | for g in optimizer_D.param_groups:
170 | current_lr = g['lr']
171 |
172 | # --------------
173 | # Log Progress
174 | # --------------
175 |
176 | status = "[Batch %d][D loss: %f][G loss: %f, adv: %f, rec1: %f, rec2: %f][pot: %f, pot_mul: %f][lr: %f]" % \
177 | (batch_index, \
178 | loss_0_buffer.push(loss_D.item()), \
179 | loss_1_buffer.push(loss_G.item()), loss_2_buffer.push(loss_GAN.item()), loss_3_buffer.push(loss_rec_1.item()), loss_4_buffer.push(loss_rec_2.item()),
180 | loss_5_buffer.push(loss_pot.item()), pot_multiplier, current_lr, )
181 |
182 | train_iterator.set_description(status)
183 |
184 | if (batch_index % sample_interval == 0):
185 | img_index += 1
186 | void_mask = torch.zeros(gti.shape).cuda()
187 | sample_images(img_index, rgb, [void_mask, gti, rec, seg, reg])
188 | if img_index >= 100:
189 | img_index = 0
190 |
191 | if (batch_index % backup_interval == 0):
192 | torch.save(e1.state_dict(), models_path + "E" + str(batch_index) + "_e1")
193 | torch.save(e2.state_dict(), models_path + "E" + str(batch_index) + "_e2")
194 | torch.save(net.state_dict(), models_path + "E" + str(batch_index) + "_net")
195 | torch.save(disc.state_dict(), models_path + "E" + str(batch_index) + "_disc")
196 |
197 |
198 | if __name__ == '__main__':
199 | train()
200 |
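The CRF term is switched on gradually. With the default arguments of `train()`, `crf_factor` returns 0 up to batch 60000, rises linearly to 175 at batch 120000 and stays there, so the learning-rate decay (`start_lr_decay=120000`) begins only once the CRF weight has reached its final value. The standalone restatement below is torch-free so it can run on its own; `generator_loss` is a hypothetical helper name that just spells out the weighting of `loss_G`.

```python
def crf_factor(batch_index, start_crf_batch, end_crf_batch,
               crf_initial_factor, crf_final_factor):
    """Piecewise-linear CRF weight, restated from train_gan_net.crf_factor."""
    if batch_index <= start_crf_batch:
        return 0.0  # CRF loss disabled during the first phase
    if batch_index < end_crf_batch:
        progress = (batch_index - start_crf_batch) / (end_crf_batch - start_crf_batch)
        return crf_initial_factor + (crf_final_factor - crf_initial_factor) * progress
    return crf_final_factor  # held constant after the ramp


def generator_loss(loss_gan, loss_rec_1, loss_rec_2, loss_pot, pot_multiplier):
    """Hypothetical helper: the weighted sum used for loss_G in train()."""
    return 3 * loss_gan + 1 * loss_rec_1 + 3 * loss_rec_2 + pot_multiplier * loss_pot


# Defaults from train(): weight ramps from 0.0 to 175.0 between batches 60000 and 120000
for batch in (0, 60000, 90000, 120000, 140000):
    print(batch, crf_factor(batch, 60000, 120000, 0.0, 175.0))
# the weights printed are 0.0, 0.0, 87.5, 175.0, 175.0
```

Because `pot_multiplier` is 0 for the first 60000 batches, the generator is initially trained only with the adversarial and reconstruction terms.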
--------------------------------------------------------------------------------
/training_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | import datetime
4 | import sys
5 | import glob
6 |
7 | import numpy as np
8 | import cv2
9 | import torch
10 | from torch.autograd import Variable
11 | from tqdm import tqdm
12 | from skimage import io
13 | from skimage.segmentation import mark_boundaries
14 |
15 | # The legacy top-level "gdal" module is deprecated and removed in newer GDAL
16 | # releases; the osgeo namespace works across GDAL versions.
17 | from osgeo import gdal
18 |
19 |
20 | import variables as var
21 |
22 |
23 | def sample_images(sample_index, img, masks):
24 | batch = img.shape[0]
25 |
26 | img = img.permute(0,2,3,1)
27 |
28 | for i in range(len(masks)):
29 | masks[i] = masks[i].permute(0,2,3,1)
30 |
31 | img = img.cpu().numpy()
32 | ip = np.uint8(img * 255)
33 | for i in range(len(masks)):
34 | masks[i] = masks[i].detach().cpu().numpy()
35 | masks[i] = np.argmax(masks[i], axis=-1)
36 | masks[i] = np.uint8(masks[i] * 255)
37 |
38 | line_mode = "inner"
39 |
40 | for i in range(len(masks)):
41 | row = np.copy(ip[0,:,:,:])
42 | line = cv2.Canny(masks[i][0,:,:], 0, 255)
43 | row = mark_boundaries(row, line, color=(1,1,0), mode=line_mode) * 255
44 | for b in range(1,batch):
45 | pic = np.copy(ip[b,:,:,:])
46 | line = cv2.Canny(masks[i][b,:,:], 0, 255)
47 | pic = mark_boundaries(pic, line, color=(1,1,0), mode=line_mode) * 255
48 | row = np.concatenate((row, pic), 1)
49 | masks[i] = row
50 |
51 | img = np.concatenate(masks, 0)
52 | img = np.uint8(img)
53 | io.imsave(var.DEBUG_DIR + "debug_%s.png" % str(sample_index), img)
54 |
55 |
56 | class LossBuffer():
57 | def __init__(self, max_size=100):
58 | self.data = []
59 | self.max_size = max_size
60 |
61 | def push(self, data):
62 | self.data.append(data)
63 | if len(self.data) > self.max_size:
64 | self.data = self.data[1:]
65 | return sum(self.data) / len(self.data)
66 |
67 |
68 | class LambdaLR():
69 | def __init__(self, n_batches, decay_start_batch):
70 | assert ((n_batches - decay_start_batch) > 0), "Decay must start before the training session ends!"
71 | self.n_batches = n_batches
72 | self.decay_start_batch = decay_start_batch
73 |
74 | def step(self, batch):
75 | if batch > self.decay_start_batch:
76 | factor = 1.0 - (batch - self.decay_start_batch) / (self.n_batches - self.decay_start_batch)
77 | if factor > 0:
78 | return factor
79 | else:
80 | return 0.0
81 | else:
82 | return 1.0
83 |
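`LambdaLR.step` returns the multiplier that `torch.optim.lr_scheduler.LambdaLR` applies to `start_lr`: 1.0 up to `decay_start_batch`, then a linear decay reaching 0 at `n_batches`. `LossBuffer.push` returns the mean of the last `max_size` values. The plain-Python sketch below restates both with the defaults from `train()`. The `deque(maxlen=...)` buffer is an alternative we suggest, not the repository's code; it drops the oldest value in O(1) instead of copying the list on every push.

```python
from collections import deque


def lr_factor(batch, n_batches=140000, decay_start_batch=120000):
    """Same rule as training_utils.LambdaLR.step, with train()'s defaults."""
    if batch <= decay_start_batch:
        return 1.0
    return max(0.0, 1.0 - (batch - decay_start_batch) / (n_batches - decay_start_batch))


class DequeLossBuffer:
    """Sliding-window mean; deque(maxlen=...) evicts the oldest value automatically."""

    def __init__(self, max_size=100):
        self.data = deque(maxlen=max_size)

    def push(self, value):
        self.data.append(value)
        return sum(self.data) / len(self.data)


start_lr = 0.00004
for batch in (100000, 120000, 130000, 140000):
    print(batch, start_lr * lr_factor(batch))  # full lr, full lr, half lr, zero

buf = DequeLossBuffer(max_size=3)
print([buf.push(v) for v in (1.0, 2.0, 3.0, 4.0)])  # [1.0, 1.5, 2.0, 3.0]
```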
--------------------------------------------------------------------------------
/variables.py:
--------------------------------------------------------------------------------
1 | # TRAINING
2 | DATASET_RGB = "./data/rgb/*.tif"
3 | DATASET_GTI = "./data/gti/*.tif"
4 | DATASET_SEG = "./data/seg/*.tif"
5 |
6 | DEBUG_DIR = "./debug/"
7 |
8 |
9 | # INFERENCE
10 | INF_RGB = "./test_data/rgb/*.tif"
11 | INF_SEG = "./test_data/seg/*.tif"
12 | INF_OUT = "./test_data/reg_output/"
13 |
14 | MODEL_ENCODER = "./saved_models_gan/E140000_e1"
15 | MODEL_GENERATOR = "./saved_models_gan/E140000_net"
16 |
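`regularize_segmentations` asserts that each RGB tile and its segmentation share a filename (`assert rgb_name == seg_name`), so the `INF_RGB` and `INF_SEG` globs must match the same set of names. A pre-flight check along the following lines can catch a mismatch before any model is loaded. `pair_inference_files` is a hypothetical helper, and pairing files by basename is our assumption about how the tiles are meant to correspond.

```python
import glob
import os


def pair_inference_files(rgb_pattern, seg_pattern):
    """Hypothetical pre-flight check for the INF_RGB / INF_SEG globs.

    Returns (rgb_path, seg_path) pairs sorted by filename and raises
    ValueError if the two folders do not contain the same tile names.
    """
    rgb = {os.path.basename(p): p for p in glob.glob(rgb_pattern)}
    seg = {os.path.basename(p): p for p in glob.glob(seg_pattern)}
    if not rgb:
        raise ValueError("no files match %s" % rgb_pattern)
    unpaired = sorted(set(rgb) ^ set(seg))  # names present in only one folder
    if unpaired:
        raise ValueError("unpaired tiles: %s" % ", ".join(unpaired))
    return [(rgb[name], seg[name]) for name in sorted(rgb)]


# Example with the default inference paths above:
# pairs = pair_inference_files("./test_data/rgb/*.tif", "./test_data/seg/*.tif")
```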
--------------------------------------------------------------------------------