├── assets ├── train1.png └── train2.png ├── requirements.txt ├── unit_tests.py ├── README.md └── trainer.py /assets/train1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrii-khizbullin/area_predictor/HEAD/assets/train1.png -------------------------------------------------------------------------------- /assets/train2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrii-khizbullin/area_predictor/HEAD/assets/train2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | kiwisolver==1.1.0 3 | matplotlib==3.0.3 4 | numpy==1.16.3 5 | Pillow==6.0.0 6 | pyparsing==2.4.0 7 | python-dateutil==2.8.0 8 | six==1.12.0 9 | torch==1.0.1.post2 10 | torchsummary==1.5.1 11 | torchvision==0.2.2.post3 12 | -------------------------------------------------------------------------------- /unit_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from trainer import Predictor, CustomDataset 3 | 4 | 5 | class NetTest(unittest.TestCase): 6 | 7 | def test_predictor(self): 8 | dataset = CustomDataset("data/") 9 | image = dataset.get_item(dataset.get_list()[0]).image 10 | 11 | predictor = Predictor() 12 | 13 | area = predictor(image) 14 | 15 | self.assertTrue(isinstance(area, int)) 16 | self.assertGreater(area, 0) 17 | self.assertLessEqual(area, image.shape[0]*image.shape[1]) 18 | 19 | print(area) 20 | pass 21 | 22 | 23 | if __name__ == "__main__": 24 | unittest.main() 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Neural-network-based predictor for area of a rectangle 2 | 3 | #### Task 4 | The task is to determine a pixel area of one rectangle on an image. 5 | 6 | The size of provided dataset is obviously very small for training any deep learning model. Consequently I've decided to go with a synthetic dataset for training. As for the model, I've experimented with a couple of unconventional light models, which did not result in a success presumably due to the small size of the dataset and a very weak target (just 2 scalars for an image). Thus I've stopped on a classical semantic segmentation proxy task. In this case interior of a rectangle is annotated as foreground class (1 or white), exterior - as background class (0 or black). As a feature extractor I've chosen a standard solution: UNet. Channel number of the used model is significantly lower that of the original UNet since the task is relatively simple. 7 | 8 | The model turns out to be a bit too beefy - 30 MB. The footprint can be shrinked by reducing channels in the bottleneck of Unet hourglass. A trained model is included into this deliverable. 9 | 10 | The work resolution is chosen to be 128x128 to speedup training. Training took about 45 minutes on T4 in Google Colab. 11 | 12 | Accuracy of prediction I am estimating as a difference in predicted and ground truth areas divided by ground truth area. This metric (averaged over 69 annotated images) in the provided dataset is approximately 1.1%. It can be reduced to almost zero by increasing the resolution of the network to native 256x256 and extra shaping of channels in encoder and decoder of UNet. 13 | 14 | With regards to the provided dataset, one can notice that all the noise in the images comes to alpha channel. Thus I've decided to throw away alpha channel, and the network is trained on RGB channels. This also simplifies synthetic image generator. 15 | 16 | Intermediate results of training (from left to right: prediction map, ground truth segmentation, input image): 17 | 18 | ![](assets/train1.png) 19 | 20 | Fully trained model's predictions: 21 | 22 | ![](assets/train2.png) 23 | 24 | 25 | #### Code 26 | 27 | I am using PyTorch as a DL framework. The main code and entry point are located in trainer.py. 28 | 29 | To install requirements: 30 | ``` 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | Custom dataset is assumed to be located at `"data/"` folder next to the script. 35 | 36 | The default mode of the script is to launch training: 37 | ``` 38 | python trainer.py 39 | ``` 40 | After finishing the script will produce a PyTorch model - snapshot.pth. 41 | 42 | For evaluation, given a snapshot.pth, one can run 43 | ``` 44 | python trainer --evaluate 45 | ``` 46 | to get results of accuracy evaluation. 47 | 48 | I am also asked to provide a function to preform inference. It is implemented as a function class Predictor. It assumes snapshot.pth file present in the directory. How to use Predictor: 49 | ``` 50 | predictor = Predictor() 51 | area = predictor(image_np_hwc) 52 | area = predictor(image_np_hwc) 53 | ``` 54 | 55 | Predictor is tested with a unit test in unit_tests.py. 56 | ``` 57 | python unit_tests.py 58 | ``` 59 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import math 5 | import numpy as np 6 | from PIL import Image 7 | from typing import Union, List 8 | from argparse import ArgumentParser 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.modules.loss 13 | import torch.nn.functional as F 14 | import torchsummary 15 | 16 | 17 | class Sample: 18 | """ 19 | This class represents either one sample or a batch of samples. 20 | A sample keeps both input and target tensors as well as auxiliary 21 | payload such as sample name. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | image: Union[np.ndarray, List[np.ndarray]], 27 | anno_hw: Union[np.ndarray, List[np.ndarray]], 28 | name: Union[str, List[str]], 29 | image_tensor: torch.Tensor, 30 | segmentation_tensor: torch.Tensor): 31 | """ 32 | Constructor. 33 | :param image: original image 34 | :param anno_hw: annotated rectangle width/height 35 | :param name: text name of a sample 36 | :param image_tensor: input tensor to feed into NN 37 | :param segmentation_tensor: target for semantic segmentation 38 | """ 39 | 40 | self.image = image 41 | self.anno_hw = anno_hw 42 | self.name = name 43 | self.image_tensor = image_tensor 44 | self.segmentation_tensor = segmentation_tensor 45 | 46 | def cuda(self): 47 | """ Push all tensors to cuda. """ 48 | if self.image_tensor is not None: 49 | self.image_tensor = self.image_tensor.cuda() 50 | if self.segmentation_tensor is not None: 51 | self.segmentation_tensor = self.segmentation_tensor.cuda() 52 | 53 | def batchify(self): 54 | """ Create a batch out of one sample. """ 55 | if self.image_tensor is not None: 56 | self.image_tensor = self.image_tensor.unsqueeze(dim=0) 57 | if self.segmentation_tensor is not None: 58 | self.segmentation_tensor = self.segmentation_tensor.unsqueeze(dim=0) 59 | 60 | @staticmethod 61 | def collate(batch_list): 62 | """ Create a batch out of a list of samples. """ 63 | images = [t.image for t in batch_list] 64 | annos = [t.anno_hw for t in batch_list] 65 | image_tensor = torch.stack( 66 | [t.image_tensor for t in batch_list], dim=0) 67 | segmentation_tensor = torch.stack( 68 | [t.segmentation_tensor for t in batch_list], dim=0) 69 | names = [t.name for t in batch_list] 70 | batch_sample = Sample( 71 | images, 72 | annos, 73 | names, 74 | image_tensor, 75 | segmentation_tensor) 76 | return batch_sample 77 | 78 | 79 | class CustomDataset: 80 | """ Class that reads samples of a custom dataset from a disk and preprocesses them. """ 81 | 82 | def __init__(self, folder, preprocess_fn=None): 83 | self._ext = ".png" 84 | self._folder = folder 85 | self._preprocess_fn = preprocess_fn 86 | name_list = glob.glob(os.path.join(folder, "*" + self._ext)) 87 | raw_name_list = [os.path.splitext(os.path.split(p)[-1])[0] for p in name_list] 88 | self._name_list = [n for n in raw_name_list if self._parse_name(n) is not None] 89 | if len(self._name_list) == 0: 90 | print("Warning: custom dataset not found!") 91 | 92 | def get_list(self): 93 | return self._name_list 94 | 95 | def get_item(self, item): 96 | path = os.path.join(self._folder, item + self._ext) 97 | image_pil = Image.open(path) 98 | image_hwc = np.asarray(image_pil) 99 | anno_hw = self._parse_name(item) 100 | if self._preprocess_fn is not None: 101 | image_tensor_chw = self._preprocess_fn(image_hwc) 102 | else: 103 | image_tensor_chw = None 104 | return Sample(image_hwc, anno_hw, item, image_tensor_chw, None) 105 | 106 | def _parse_name(self, name: str) -> Union[np.ndarray, None]: 107 | word_list = name.split("_") 108 | if len(word_list) == 5: 109 | width = int(word_list[2]) 110 | height = int(word_list[4]) 111 | return np.array((height, width), dtype=np.int) 112 | else: 113 | return None 114 | 115 | def get_image_shape(self): 116 | sample = self.get_item(self._name_list[0]) 117 | return list(sample.image_tensor.size()) 118 | 119 | 120 | class SyntheticDataset: 121 | """ 122 | Since the custom dataset is too small, we train on a generated images with 123 | presumably wider representativeness of samples. 124 | """ 125 | 126 | def __init__(self, image_shape_chw): 127 | self._image_shape_chw = image_shape_chw 128 | 129 | def generate(self): 130 | image_chw = np.zeros(self._image_shape_chw, dtype=np.float32) 131 | segmentation_tensor = np.zeros(self._image_shape_chw[1:], dtype=np.float32) 132 | anno_hw = np.zeros(4, dtype=np.int) 133 | 134 | h, w = self._image_shape_chw[1:] 135 | 136 | min_h, min_w = h // 5, w // 5 137 | 138 | color_bg = np.random.randint(256, size=3) 139 | color_fg = np.random.randint(256, size=3) 140 | color_bg = np.reshape(color_bg, (-1, 1, 1)) 141 | color_fg = np.reshape(color_fg, (-1, 1, 1)) 142 | # Continue mining until a big enough sample is found 143 | while True: 144 | left, right = np.sort(np.random.randint(w, size=2)) 145 | top, bottom = np.sort(np.random.randint(h, size=2)) 146 | if right - left < min_w: 147 | continue 148 | if bottom - top < min_h: 149 | continue 150 | break 151 | 152 | # Create input image 153 | image_chw[...] = color_bg 154 | image_chw[:, top:bottom, left:right] = color_fg 155 | 156 | # Create target segmentation 157 | segmentation_tensor[top:bottom, left:right] = 1.0 158 | 159 | anno_hw[0] = bottom - top 160 | anno_hw[1] = right - left 161 | 162 | image_float_chw = image_chw / 255 163 | image_hwc = np.transpose(image_chw, (1, 2, 0)) 164 | image_tensor_chw = torch.from_numpy(image_float_chw) 165 | segmentation_tensor = torch.from_numpy(segmentation_tensor) 166 | 167 | return Sample(image_hwc, anno_hw, "synthetic", image_tensor_chw, 168 | segmentation_tensor) 169 | 170 | 171 | class DatasetDispatcher: 172 | """ 173 | Class to multiplex synthetic and custom datasets. 174 | Provides generators for train and validation batches. 175 | """ 176 | 177 | def __init__( 178 | self, 179 | work_shape = (3, 128, 128) # work resolution of a NN 180 | ): 181 | 182 | self._custom_dataset = CustomDataset("data/", self.preprocess_fn) 183 | self._image_shape = work_shape 184 | self._synthetic_dataset = SyntheticDataset(self._image_shape) 185 | self._val_names = self._custom_dataset.get_list() 186 | print("Custom dataset size =", len(self._val_names)) 187 | 188 | @staticmethod 189 | def preprocess_fn(image_hwc: np.ndarray) -> torch.Tensor: 190 | image_hwc3 = image_hwc[:, :, :3] # throw away alpha channel 191 | image_chw = np.transpose(image_hwc3, (2, 0, 1)) / 255.0 192 | image_tensor_chw = torch.from_numpy(image_chw).type(torch.float32) 193 | return image_tensor_chw 194 | 195 | def tensor_custom_to_work_reso(self, tensor): 196 | tensor = F.interpolate( 197 | tensor.unsqueeze(0), 198 | size=self._image_shape[1:], 199 | mode='nearest').squeeze(0) 200 | return tensor 201 | 202 | def tensor_work_to_custom(self, tensor): 203 | tensor = F.interpolate( 204 | tensor.unsqueeze(0), 205 | size=self._custom_dataset.get_image_shape()[1:], 206 | mode='nearest').squeeze(0) 207 | return tensor 208 | 209 | def decode_prediction(self, pred: torch.Tensor) -> np.ndarray: 210 | pred_custom_res = self.tensor_work_to_custom(pred) 211 | mask = pred_custom_res > 0.5 212 | area = mask.sum() 213 | return area.detach().cpu().item() 214 | 215 | def train_gen(self, batches_per_epoch, batch_size): 216 | """ Train generator returns full-size batches. """ 217 | for i in range(batches_per_epoch): 218 | sample_list = [] 219 | for ib in range(batch_size): 220 | sample = self._synthetic_dataset.generate() 221 | sample_list.append(sample) 222 | batch = Sample.collate(sample_list) 223 | yield batch 224 | 225 | def val_gen(self): 226 | """ Validation generator returns quasi-batches of size 1. """ 227 | for name in self._val_names: 228 | sample = self._custom_dataset.get_item(name) 229 | sample.image_tensor = \ 230 | self.tensor_custom_to_work_reso(sample.image_tensor) 231 | yield sample 232 | 233 | def get_image_shape(self): 234 | return self._image_shape 235 | 236 | 237 | def Conv1x1(in_channels, out_channels): 238 | return nn.Conv2d( 239 | in_channels, 240 | out_channels, 241 | kernel_size=1, 242 | stride=1) 243 | 244 | 245 | def Conv3x3(in_channels, out_channels): 246 | return nn.Conv2d( 247 | in_channels, 248 | out_channels, 249 | kernel_size=3, 250 | stride=1, 251 | padding=1, 252 | bias=True) 253 | 254 | 255 | def Upconv2x2(in_channels, out_channels): 256 | return nn.ConvTranspose2d( 257 | in_channels, 258 | out_channels, 259 | kernel_size=2, 260 | stride=2) 261 | 262 | 263 | class DownConv(nn.Module): 264 | """ Block UNet encoder. """ 265 | 266 | def __init__(self, in_channels, out_channels, has_pool=True): 267 | super().__init__() 268 | 269 | self._has_pool = has_pool 270 | 271 | self._conv1 = Conv3x3(in_channels, out_channels) 272 | self._conv2 = Conv3x3(out_channels, out_channels) 273 | if self._has_pool: 274 | self._pool = nn.MaxPool2d(kernel_size=2, stride=2) 275 | 276 | def forward(self, x): 277 | x = F.relu(self._conv1(x)) 278 | x = F.relu(self._conv2(x)) 279 | shortcut = x 280 | if self._has_pool: 281 | x = self._pool(x) 282 | return x, shortcut 283 | 284 | 285 | class UpConv(nn.Module): 286 | """ Block of UNet decoder. """ 287 | 288 | def __init__(self, in_channels, out_channels): 289 | super().__init__() 290 | 291 | self._upconv = Upconv2x2(in_channels, out_channels) 292 | self._conv1 = Conv3x3(2 * out_channels, out_channels) 293 | self._conv2 = Conv3x3(out_channels, out_channels) 294 | 295 | def forward(self, shortcut, decoder_path): 296 | 297 | decoder_path = self._upconv(decoder_path) 298 | x = torch.cat((decoder_path, shortcut), 1) 299 | x = F.relu(self._conv1(x)) 300 | x = F.relu(self._conv2(x)) 301 | return x 302 | 303 | 304 | class UNet(nn.Module): 305 | """ Implementation of UNet based on https://arxiv.org/abs/1505.04597 """ 306 | 307 | def __init__( 308 | self, 309 | out_channels, 310 | in_channels, 311 | num_levels=5, 312 | start_channels=64): 313 | 314 | super().__init__() 315 | 316 | self._out_channels = out_channels 317 | self._in_channels = in_channels 318 | self._start_channels = start_channels 319 | self._depth = num_levels 320 | 321 | down_convs = [] 322 | up_convs = [] 323 | 324 | # Create encoder blocks 325 | outs = None 326 | for i in range(num_levels): 327 | ins = self._in_channels if i == 0 else outs 328 | outs = self._start_channels * (2 ** i) 329 | pooling = True if i < num_levels - 1 else False 330 | 331 | down_conv = DownConv(ins, outs, has_pool=pooling) 332 | down_convs.append(down_conv) 333 | 334 | # Create decoder blocks 335 | for i in range(num_levels - 1): 336 | ins = outs 337 | outs = ins // 2 338 | up_conv = UpConv(ins, outs) 339 | up_convs.append(up_conv) 340 | 341 | self._last_conv = Conv1x1(outs, self._out_channels) 342 | 343 | self._down_convs = nn.ModuleList(down_convs) 344 | self._up_convs = nn.ModuleList(up_convs) 345 | 346 | self._init_weights() 347 | 348 | def _init_weights(self): 349 | for i, m in enumerate(self.modules()): 350 | if isinstance(m, nn.Conv2d): 351 | nn.init.xavier_normal_(m.weight) 352 | nn.init.constant_(m.bias, 0) 353 | 354 | def forward(self, x): 355 | encoder_outs = [] 356 | 357 | for _, module in enumerate(self._down_convs): 358 | x, shortcut = module(x) 359 | encoder_outs.append(shortcut) 360 | 361 | for i, module in enumerate(self._up_convs): 362 | shortcut = encoder_outs[-(i + 2)] 363 | x = module(shortcut, x) 364 | 365 | x = self._last_conv(x) 366 | 367 | return x 368 | 369 | 370 | class Net(nn.Module): 371 | """ The entire NN. Encapsulates UNet and corresponding loss. """ 372 | 373 | def __init__(self, input_shape_chw): 374 | super().__init__() 375 | 376 | self.input_shape_chw = input_shape_chw 377 | 378 | start_channels = 8 379 | out_channels = 1 380 | 381 | num_levels = min( 382 | int(math.log2(self.input_shape_chw[1])), 383 | int(math.log2(self.input_shape_chw[2])), 384 | ) 385 | print("num_levels = ", num_levels) 386 | 387 | self._net = UNet( 388 | out_channels, input_shape_chw[0], 389 | start_channels=start_channels, num_levels=num_levels) 390 | 391 | self._seg_loss = nn.modules.loss.BCELoss() 392 | 393 | def forward(self, image_tensor_batch: torch.Tensor): 394 | assert len(list(image_tensor_batch.size())) >= 3 395 | logits = self._net(image_tensor_batch) 396 | logits = logits.squeeze(1) 397 | pred = torch.sigmoid(logits) 398 | return pred 399 | 400 | def loss(self, pred: torch.Tensor, segmentation_gt: torch.Tensor): 401 | """ 402 | Loss could be more complex, but we just go with 403 | vanilla binary cross entropy. 404 | """ 405 | 406 | if segmentation_gt is not None: 407 | segmentation_loss = self._seg_loss( 408 | pred.view(-1), 409 | segmentation_gt.view(-1) 410 | ) 411 | else: 412 | segmentation_loss = torch.zeros((1,), dtype=torch.float32) 413 | 414 | total_loss = segmentation_loss 415 | details = { 416 | "seg_loss": segmentation_loss.detach().cpu().item() 417 | } 418 | return total_loss, details 419 | 420 | 421 | class Trainer: 422 | """ Class to manage train/validation cycles. """ 423 | 424 | def __init__(self, load_last_snapshot=False): 425 | """ 426 | Constructor. 427 | :param load_last_snapshot: whether to load the last snapshot from disk 428 | """ 429 | 430 | has_gpu = torch.cuda.device_count() > 0 431 | if has_gpu: 432 | print(torch.cuda.get_device_name(0)) 433 | else: 434 | print("GPU not found") 435 | self.use_gpu = has_gpu 436 | 437 | self._dispatcher = DatasetDispatcher() 438 | self._net = Net(self._dispatcher.get_image_shape()) 439 | if self.use_gpu: 440 | self._net.cuda() 441 | 442 | self._snapshot_name = "snapshot.pth" 443 | if load_last_snapshot: 444 | load_kwargs = {} if self.use_gpu else {'map_location': 'cpu'} 445 | self._net.load_state_dict(torch.load(self._snapshot_name, **load_kwargs)) 446 | 447 | # Print summary in Keras style 448 | shape_chw = tuple(self._dispatcher.get_image_shape()) 449 | print("Work image shape =", shape_chw) 450 | torchsummary.summary(self._net, input_size=shape_chw) 451 | pass 452 | 453 | def train(self): 454 | """ Perform training of the network. """ 455 | 456 | num_epochs = 50 457 | batch_size = 16 458 | batches_per_epoch = 1024 459 | learning_rate = 0.02 460 | 461 | optimizer = torch.optim.SGD(self._net.parameters(), lr=learning_rate) 462 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 463 | optimizer, [40, 45], gamma=0.1, last_epoch=-1) 464 | 465 | training_start_time = time.time() 466 | 467 | self.validate() 468 | 469 | for epoch in range(num_epochs): 470 | print("Epoch ------ ", epoch) 471 | 472 | train_gen = self._dispatcher.train_gen(batches_per_epoch, batch_size) 473 | 474 | self._net.train() 475 | 476 | for batch_index, batch in enumerate(train_gen): 477 | if self.use_gpu: 478 | batch.cuda() 479 | 480 | pred = self._net.forward(batch.image_tensor) 481 | 482 | loss, details = self._net.loss(pred, batch.segmentation_tensor) 483 | 484 | if batch_index % 50 == 0: 485 | print("epoch={} batch={} loss={:.4f}".format( 486 | epoch, batch_index, loss.item() 487 | )) 488 | self._render_prediction( 489 | pred.detach().cpu().numpy()[0], 490 | batch.segmentation_tensor.detach().cpu().numpy()[0], 491 | batch.image_tensor.detach().cpu().numpy()[0].transpose((1, 2, 0))) 492 | print("-------------------------------") 493 | 494 | optimizer.zero_grad() 495 | loss.backward() 496 | optimizer.step() 497 | pass 498 | 499 | scheduler.step() 500 | 501 | # Save after every epoch 502 | torch.save(self._net.state_dict(), self._snapshot_name) 503 | 504 | # Validate every epoch 505 | self.validate() 506 | 507 | pass 508 | # end of epoch 509 | 510 | training_end_time = time.time() 511 | print("Training took {} hours".format( 512 | (training_end_time - training_start_time)/3600)) 513 | 514 | print("Train finished!") 515 | 516 | def validate(self): 517 | """ Validation cycle. Performed over a custom dataset. """ 518 | print("Validation") 519 | 520 | self._net.eval() 521 | 522 | val_gen = self._dispatcher.val_gen() 523 | relative_error_list = [] 524 | 525 | for sample_idx, sample in enumerate(val_gen): 526 | if self.use_gpu: 527 | sample.cuda() 528 | sample.batchify() 529 | 530 | pred = self._net.forward(sample.image_tensor) 531 | loss, details = self._net.loss(pred, sample.segmentation_tensor) 532 | 533 | pred_area = self._dispatcher.decode_prediction(pred) 534 | 535 | anno_hw = sample.anno_hw 536 | gt_area = anno_hw[0] * anno_hw[1] 537 | relative_error = abs(pred_area - gt_area) / gt_area 538 | relative_error_list.append(relative_error) 539 | 540 | if sample_idx % 20 == 0: 541 | print("loss={:.4f} gt_area={} pred_area={}".format( 542 | loss.item(), gt_area, pred_area 543 | )) 544 | self._render_prediction( 545 | pred.detach().cpu().numpy()[0], 546 | None, 547 | sample.image_tensor.detach().cpu().numpy()[0].transpose((1, 2, 0))) 548 | 549 | average_relative_error = \ 550 | np.array(relative_error_list).sum() / len(relative_error_list) 551 | print("-------- Final metric -----------") 552 | print("Average relative area error = {:0.6f}".format(average_relative_error)) 553 | 554 | pass 555 | 556 | def get_net(self): 557 | return self._net 558 | 559 | 560 | def _render_prediction(self, pred: np.ndarray, gt: np.ndarray, input_image: np.ndarray): 561 | """ 562 | This function visualizes predictions. Works nicely only 563 | in ipython notebook thus commented out here. 564 | """ 565 | 566 | # % matplotlib inline 567 | # import matplotlib.pyplot as plt 568 | # 569 | # fig = plt.figure(figsize=(10, 3)) 570 | # fig.add_subplot(1, 3, 1) 571 | # plt.imshow(pred, cmap='gray', vmin=0.0, vmax=1.0) 572 | # fig.add_subplot(1, 3, 2) 573 | # if gt is not None: 574 | # plt.imshow(gt, cmap='gray', vmin=0.0, vmax=1.0) 575 | # fig.add_subplot(1, 3, 3) 576 | # plt.imshow(input_image, vmin=0.0, vmax=1.0) 577 | # 578 | # plt.show() 579 | 580 | pass 581 | 582 | 583 | 584 | class Predictor: 585 | """ Function object to perform predictions. """ 586 | 587 | def __init__(self): 588 | self._trainer = Trainer(load_last_snapshot=True) 589 | 590 | def __call__(self, img: np.ndarray) -> int: 591 | """ 592 | Makes prediction of the area based on the input image. 593 | :param img: numpy image in format [H, W, C] where C is RGB or RGBA 594 | :return: area of a rectangle in pixels 595 | """ 596 | assert len(img.shape) == 3 597 | assert img.shape[2] in (3, 4) 598 | dispatcher = DatasetDispatcher() 599 | tensor = dispatcher.preprocess_fn(img) 600 | tensor = dispatcher.tensor_custom_to_work_reso(tensor) 601 | tensor = tensor.unsqueeze(0) 602 | net = self._trainer.get_net() 603 | net.eval() 604 | pred = net(tensor) 605 | result = dispatcher.decode_prediction(pred) 606 | return result 607 | 608 | 609 | def main(): 610 | """ 611 | Default mode of operation is training. To run validation of a trained 612 | model over a custom dataset, run: 613 | trainer.py --validate 614 | """ 615 | parser = ArgumentParser() 616 | parser.add_argument('--validate', default=False, action='store_true') 617 | args = parser.parse_args() 618 | 619 | if args.validate: 620 | trainer = Trainer(load_last_snapshot=True) 621 | trainer.validate() 622 | else: 623 | trainer = Trainer() 624 | trainer.train() 625 | pass 626 | 627 | 628 | if __name__ == "__main__": 629 | main() 630 | 631 | --------------------------------------------------------------------------------