├── .gitignore ├── LICENSE ├── README.md ├── helper.py ├── images ├── output_0_1.png ├── output_2_2.png └── output_9_1.png ├── loss.py ├── pytorch_fcn.ipynb ├── pytorch_resnet18_unet.ipynb ├── pytorch_unet.ipynb ├── pytorch_unet.py ├── pytorch_unet_resnet18_colab.ipynb └── simulation.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Naoto Usuyama 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # UNet/FCN PyTorch [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/usuyama/pytorch-unet/blob/master/pytorch_unet_resnet18_colab.ipynb) 3 | 4 | This repository contains simple PyTorch implementations of U-Net and FCN, which are deep learning segmentation methods proposed by Ronneberger et al. and Long et al. 5 | 6 | - [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/) 7 | - [Fully Convolutional Networks for Semantic Segmentation](https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf) 8 | 9 | # Synthetic images/masks for training 10 | 11 | First clone the repository and cd into the project directory. 12 | 13 | ```python 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import helper 17 | import simulation 18 | 19 | # Generate some random images 20 | input_images, target_masks = simulation.generate_random_data(192, 192, count=3) 21 | 22 | for x in [input_images, target_masks]: 23 | print(x.shape) 24 | print(x.min(), x.max()) 25 | 26 | # Change channel-order and make 3 channels for matplot 27 | input_images_rgb = [x.astype(np.uint8) for x in input_images] 28 | 29 | # Map each channel (i.e. class) to each color 30 | target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks] 31 | 32 | # Left: Input image (black and white), Right: Target mask (6ch) 33 | helper.plot_side_by_side([input_images_rgb, target_masks_rgb]) 34 | ``` 35 | 36 | ## Left: Input image (black and white), Right: Target mask (6ch) 37 | ![png](https://raw.githubusercontent.com/usuyama/pytorch-unet/master/images/output_0_1.png) 38 | 39 | 40 | ## Prepare Dataset and DataLoader 41 | ```python 42 | from torch.utils.data import Dataset, DataLoader 43 | from torchvision import transforms, datasets, models 44 | 45 | class SimDataset(Dataset): 46 | def __init__(self, count, transform=None): 47 | self.input_images, self.target_masks = simulation.generate_random_data(192, 192, count=count) 48 | self.transform = transform 49 | 50 | def __len__(self): 51 | return len(self.input_images) 52 | 53 | def __getitem__(self, idx): 54 | image = self.input_images[idx] 55 | mask = self.target_masks[idx] 56 | if self.transform: 57 | image = self.transform(image) 58 | 59 | return [image, mask] 60 | 61 | # use the same transformations for train/val in this example 62 | trans = transforms.Compose([ 63 | transforms.ToTensor(), 64 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet 65 | ]) 66 | 67 | train_set = SimDataset(2000, transform = trans) 68 | val_set = SimDataset(200, transform = trans) 69 | 70 | image_datasets = { 71 | 'train': train_set, 'val': val_set 72 | } 73 | 74 | batch_size = 25 75 | 76 | dataloaders = { 77 | 'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0), 78 | 'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0) 79 | } 80 | ``` 81 | 82 | ## Check the outputs from DataLoader 83 | ```python 84 | import torchvision.utils 85 | 86 | def reverse_transform(inp): 87 | inp = inp.numpy().transpose((1, 2, 0)) 88 | mean = np.array([0.485, 0.456, 0.406]) 89 | std = np.array([0.229, 0.224, 0.225]) 90 | inp = std * inp + mean 91 | inp = np.clip(inp, 0, 1) 92 | inp = (inp * 255).astype(np.uint8) 93 | 
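    # inp is now the de-normalized image: the ImageNet Normalize from `trans` has been
    # inverted (multiply by std, add mean), values clipped to [0, 1] and scaled to uint8,
    # after transposing from C x H x W to H x W x C so plt.imshow can display it.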
94 | return inp 95 | 96 | # Get a batch of training data 97 | inputs, masks = next(iter(dataloaders['train'])) 98 | 99 | print(inputs.shape, masks.shape) 100 | 101 | plt.imshow(reverse_transform(inputs[3])) 102 | ``` 103 | 104 | torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192]) 105 | 106 | 107 | ![png](https://raw.githubusercontent.com/usuyama/pytorch-unet/master/images/output_2_2.png) 108 | 109 | 110 | 111 | # Create the UNet module 112 | 113 | ```python 114 | import torch 115 | import torch.nn as nn 116 | from torchvision import models 117 | 118 | def convrelu(in_channels, out_channels, kernel, padding): 119 | return nn.Sequential( 120 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding), 121 | nn.ReLU(inplace=True), 122 | ) 123 | 124 | 125 | class ResNetUNet(nn.Module): 126 | def __init__(self, n_class): 127 | super().__init__() 128 | 129 | self.base_model = models.resnet18(pretrained=True) 130 | self.base_layers = list(self.base_model.children()) 131 | 132 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) 133 | self.layer0_1x1 = convrelu(64, 64, 1, 0) 134 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) 135 | self.layer1_1x1 = convrelu(64, 64, 1, 0) 136 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) 137 | self.layer2_1x1 = convrelu(128, 128, 1, 0) 138 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) 139 | self.layer3_1x1 = convrelu(256, 256, 1, 0) 140 | self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) 141 | self.layer4_1x1 = convrelu(512, 512, 1, 0) 142 | 143 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 144 | 145 | self.conv_up3 = convrelu(256 + 512, 512, 3, 1) 146 | self.conv_up2 = convrelu(128 + 512, 256, 3, 1) 147 | self.conv_up1 = convrelu(64 + 256, 256, 3, 1) 148 | self.conv_up0 = convrelu(64 + 256, 128, 3, 1) 149 | 150 | self.conv_original_size0 = convrelu(3, 64, 3, 1) 151 | self.conv_original_size1 = convrelu(64, 64, 3, 1) 152 | self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1) 153 | 154 | self.conv_last = nn.Conv2d(64, n_class, 1) 155 | 156 | def forward(self, input): 157 | x_original = self.conv_original_size0(input) 158 | x_original = self.conv_original_size1(x_original) 159 | 160 | layer0 = self.layer0(input) 161 | layer1 = self.layer1(layer0) 162 | layer2 = self.layer2(layer1) 163 | layer3 = self.layer3(layer2) 164 | layer4 = self.layer4(layer3) 165 | 166 | layer4 = self.layer4_1x1(layer4) 167 | x = self.upsample(layer4) 168 | layer3 = self.layer3_1x1(layer3) 169 | x = torch.cat([x, layer3], dim=1) 170 | x = self.conv_up3(x) 171 | 172 | x = self.upsample(x) 173 | layer2 = self.layer2_1x1(layer2) 174 | x = torch.cat([x, layer2], dim=1) 175 | x = self.conv_up2(x) 176 | 177 | x = self.upsample(x) 178 | layer1 = self.layer1_1x1(layer1) 179 | x = torch.cat([x, layer1], dim=1) 180 | x = self.conv_up1(x) 181 | 182 | x = self.upsample(x) 183 | layer0 = self.layer0_1x1(layer0) 184 | x = torch.cat([x, layer0], dim=1) 185 | x = self.conv_up0(x) 186 | 187 | x = self.upsample(x) 188 | x = torch.cat([x, x_original], dim=1) 189 | x = self.conv_original_size2(x) 190 | 191 | out = self.conv_last(x) 192 | 193 | return out 194 | ``` 195 | 196 | ## Model summary 197 | ```python 198 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 199 | model = ResNetUNet(n_class=6) 200 | model = model.to(device) 201 | 202 | # check keras-like model summary using torchsummary 203 | from torchsummary import 
summary 204 | summary(model, input_size=(3, 224, 224)) 205 | ``` 206 | 207 | ---------------------------------------------------------------- 208 | Layer (type) Output Shape Param # 209 | ================================================================ 210 | Conv2d-1 [-1, 64, 224, 224] 1,792 211 | ReLU-2 [-1, 64, 224, 224] 0 212 | Conv2d-3 [-1, 64, 224, 224] 36,928 213 | ReLU-4 [-1, 64, 224, 224] 0 214 | Conv2d-5 [-1, 64, 112, 112] 9,408 215 | BatchNorm2d-6 [-1, 64, 112, 112] 128 216 | ReLU-7 [-1, 64, 112, 112] 0 217 | MaxPool2d-8 [-1, 64, 56, 56] 0 218 | Conv2d-9 [-1, 64, 56, 56] 4,096 219 | BatchNorm2d-10 [-1, 64, 56, 56] 128 220 | ReLU-11 [-1, 64, 56, 56] 0 221 | Conv2d-12 [-1, 64, 56, 56] 36,864 222 | BatchNorm2d-13 [-1, 64, 56, 56] 128 223 | ReLU-14 [-1, 64, 56, 56] 0 224 | Conv2d-15 [-1, 256, 56, 56] 16,384 225 | BatchNorm2d-16 [-1, 256, 56, 56] 512 226 | Conv2d-17 [-1, 256, 56, 56] 16,384 227 | BatchNorm2d-18 [-1, 256, 56, 56] 512 228 | ReLU-19 [-1, 256, 56, 56] 0 229 | Bottleneck-20 [-1, 256, 56, 56] 0 230 | Conv2d-21 [-1, 64, 56, 56] 16,384 231 | BatchNorm2d-22 [-1, 64, 56, 56] 128 232 | ReLU-23 [-1, 64, 56, 56] 0 233 | Conv2d-24 [-1, 64, 56, 56] 36,864 234 | BatchNorm2d-25 [-1, 64, 56, 56] 128 235 | ReLU-26 [-1, 64, 56, 56] 0 236 | Conv2d-27 [-1, 256, 56, 56] 16,384 237 | BatchNorm2d-28 [-1, 256, 56, 56] 512 238 | ReLU-29 [-1, 256, 56, 56] 0 239 | Bottleneck-30 [-1, 256, 56, 56] 0 240 | Conv2d-31 [-1, 64, 56, 56] 16,384 241 | BatchNorm2d-32 [-1, 64, 56, 56] 128 242 | ReLU-33 [-1, 64, 56, 56] 0 243 | Conv2d-34 [-1, 64, 56, 56] 36,864 244 | BatchNorm2d-35 [-1, 64, 56, 56] 128 245 | ReLU-36 [-1, 64, 56, 56] 0 246 | Conv2d-37 [-1, 256, 56, 56] 16,384 247 | BatchNorm2d-38 [-1, 256, 56, 56] 512 248 | ReLU-39 [-1, 256, 56, 56] 0 249 | Bottleneck-40 [-1, 256, 56, 56] 0 250 | Conv2d-41 [-1, 128, 56, 56] 32,768 251 | BatchNorm2d-42 [-1, 128, 56, 56] 256 252 | ReLU-43 [-1, 128, 56, 56] 0 253 | Conv2d-44 [-1, 128, 28, 28] 147,456 254 | BatchNorm2d-45 [-1, 128, 28, 28] 256 255 | ReLU-46 [-1, 128, 28, 28] 0 256 | Conv2d-47 [-1, 512, 28, 28] 65,536 257 | BatchNorm2d-48 [-1, 512, 28, 28] 1,024 258 | Conv2d-49 [-1, 512, 28, 28] 131,072 259 | BatchNorm2d-50 [-1, 512, 28, 28] 1,024 260 | ReLU-51 [-1, 512, 28, 28] 0 261 | Bottleneck-52 [-1, 512, 28, 28] 0 262 | Conv2d-53 [-1, 128, 28, 28] 65,536 263 | BatchNorm2d-54 [-1, 128, 28, 28] 256 264 | ReLU-55 [-1, 128, 28, 28] 0 265 | Conv2d-56 [-1, 128, 28, 28] 147,456 266 | BatchNorm2d-57 [-1, 128, 28, 28] 256 267 | ReLU-58 [-1, 128, 28, 28] 0 268 | Conv2d-59 [-1, 512, 28, 28] 65,536 269 | BatchNorm2d-60 [-1, 512, 28, 28] 1,024 270 | ReLU-61 [-1, 512, 28, 28] 0 271 | Bottleneck-62 [-1, 512, 28, 28] 0 272 | Conv2d-63 [-1, 128, 28, 28] 65,536 273 | BatchNorm2d-64 [-1, 128, 28, 28] 256 274 | ReLU-65 [-1, 128, 28, 28] 0 275 | Conv2d-66 [-1, 128, 28, 28] 147,456 276 | BatchNorm2d-67 [-1, 128, 28, 28] 256 277 | ReLU-68 [-1, 128, 28, 28] 0 278 | Conv2d-69 [-1, 512, 28, 28] 65,536 279 | BatchNorm2d-70 [-1, 512, 28, 28] 1,024 280 | ReLU-71 [-1, 512, 28, 28] 0 281 | Bottleneck-72 [-1, 512, 28, 28] 0 282 | Conv2d-73 [-1, 128, 28, 28] 65,536 283 | BatchNorm2d-74 [-1, 128, 28, 28] 256 284 | ReLU-75 [-1, 128, 28, 28] 0 285 | Conv2d-76 [-1, 128, 28, 28] 147,456 286 | BatchNorm2d-77 [-1, 128, 28, 28] 256 287 | ReLU-78 [-1, 128, 28, 28] 0 288 | Conv2d-79 [-1, 512, 28, 28] 65,536 289 | BatchNorm2d-80 [-1, 512, 28, 28] 1,024 290 | ReLU-81 [-1, 512, 28, 28] 0 291 | Bottleneck-82 [-1, 512, 28, 28] 0 292 | Conv2d-83 [-1, 256, 28, 28] 131,072 293 | BatchNorm2d-84 [-1, 256, 28, 
28] 512 294 | ReLU-85 [-1, 256, 28, 28] 0 295 | Conv2d-86 [-1, 256, 14, 14] 589,824 296 | BatchNorm2d-87 [-1, 256, 14, 14] 512 297 | ReLU-88 [-1, 256, 14, 14] 0 298 | Conv2d-89 [-1, 1024, 14, 14] 262,144 299 | BatchNorm2d-90 [-1, 1024, 14, 14] 2,048 300 | Conv2d-91 [-1, 1024, 14, 14] 524,288 301 | BatchNorm2d-92 [-1, 1024, 14, 14] 2,048 302 | ReLU-93 [-1, 1024, 14, 14] 0 303 | Bottleneck-94 [-1, 1024, 14, 14] 0 304 | Conv2d-95 [-1, 256, 14, 14] 262,144 305 | BatchNorm2d-96 [-1, 256, 14, 14] 512 306 | ReLU-97 [-1, 256, 14, 14] 0 307 | Conv2d-98 [-1, 256, 14, 14] 589,824 308 | BatchNorm2d-99 [-1, 256, 14, 14] 512 309 | ReLU-100 [-1, 256, 14, 14] 0 310 | Conv2d-101 [-1, 1024, 14, 14] 262,144 311 | BatchNorm2d-102 [-1, 1024, 14, 14] 2,048 312 | ReLU-103 [-1, 1024, 14, 14] 0 313 | Bottleneck-104 [-1, 1024, 14, 14] 0 314 | Conv2d-105 [-1, 256, 14, 14] 262,144 315 | BatchNorm2d-106 [-1, 256, 14, 14] 512 316 | ReLU-107 [-1, 256, 14, 14] 0 317 | Conv2d-108 [-1, 256, 14, 14] 589,824 318 | BatchNorm2d-109 [-1, 256, 14, 14] 512 319 | ReLU-110 [-1, 256, 14, 14] 0 320 | Conv2d-111 [-1, 1024, 14, 14] 262,144 321 | BatchNorm2d-112 [-1, 1024, 14, 14] 2,048 322 | ReLU-113 [-1, 1024, 14, 14] 0 323 | Bottleneck-114 [-1, 1024, 14, 14] 0 324 | Conv2d-115 [-1, 256, 14, 14] 262,144 325 | BatchNorm2d-116 [-1, 256, 14, 14] 512 326 | ReLU-117 [-1, 256, 14, 14] 0 327 | Conv2d-118 [-1, 256, 14, 14] 589,824 328 | BatchNorm2d-119 [-1, 256, 14, 14] 512 329 | ReLU-120 [-1, 256, 14, 14] 0 330 | Conv2d-121 [-1, 1024, 14, 14] 262,144 331 | BatchNorm2d-122 [-1, 1024, 14, 14] 2,048 332 | ReLU-123 [-1, 1024, 14, 14] 0 333 | Bottleneck-124 [-1, 1024, 14, 14] 0 334 | Conv2d-125 [-1, 256, 14, 14] 262,144 335 | BatchNorm2d-126 [-1, 256, 14, 14] 512 336 | ReLU-127 [-1, 256, 14, 14] 0 337 | Conv2d-128 [-1, 256, 14, 14] 589,824 338 | BatchNorm2d-129 [-1, 256, 14, 14] 512 339 | ReLU-130 [-1, 256, 14, 14] 0 340 | Conv2d-131 [-1, 1024, 14, 14] 262,144 341 | BatchNorm2d-132 [-1, 1024, 14, 14] 2,048 342 | ReLU-133 [-1, 1024, 14, 14] 0 343 | Bottleneck-134 [-1, 1024, 14, 14] 0 344 | Conv2d-135 [-1, 256, 14, 14] 262,144 345 | BatchNorm2d-136 [-1, 256, 14, 14] 512 346 | ReLU-137 [-1, 256, 14, 14] 0 347 | Conv2d-138 [-1, 256, 14, 14] 589,824 348 | BatchNorm2d-139 [-1, 256, 14, 14] 512 349 | ReLU-140 [-1, 256, 14, 14] 0 350 | Conv2d-141 [-1, 1024, 14, 14] 262,144 351 | BatchNorm2d-142 [-1, 1024, 14, 14] 2,048 352 | ReLU-143 [-1, 1024, 14, 14] 0 353 | Bottleneck-144 [-1, 1024, 14, 14] 0 354 | Conv2d-145 [-1, 512, 14, 14] 524,288 355 | BatchNorm2d-146 [-1, 512, 14, 14] 1,024 356 | ReLU-147 [-1, 512, 14, 14] 0 357 | Conv2d-148 [-1, 512, 7, 7] 2,359,296 358 | BatchNorm2d-149 [-1, 512, 7, 7] 1,024 359 | ReLU-150 [-1, 512, 7, 7] 0 360 | Conv2d-151 [-1, 2048, 7, 7] 1,048,576 361 | BatchNorm2d-152 [-1, 2048, 7, 7] 4,096 362 | Conv2d-153 [-1, 2048, 7, 7] 2,097,152 363 | BatchNorm2d-154 [-1, 2048, 7, 7] 4,096 364 | ReLU-155 [-1, 2048, 7, 7] 0 365 | Bottleneck-156 [-1, 2048, 7, 7] 0 366 | Conv2d-157 [-1, 512, 7, 7] 1,048,576 367 | BatchNorm2d-158 [-1, 512, 7, 7] 1,024 368 | ReLU-159 [-1, 512, 7, 7] 0 369 | Conv2d-160 [-1, 512, 7, 7] 2,359,296 370 | BatchNorm2d-161 [-1, 512, 7, 7] 1,024 371 | ReLU-162 [-1, 512, 7, 7] 0 372 | Conv2d-163 [-1, 2048, 7, 7] 1,048,576 373 | BatchNorm2d-164 [-1, 2048, 7, 7] 4,096 374 | ReLU-165 [-1, 2048, 7, 7] 0 375 | Bottleneck-166 [-1, 2048, 7, 7] 0 376 | Conv2d-167 [-1, 512, 7, 7] 1,048,576 377 | BatchNorm2d-168 [-1, 512, 7, 7] 1,024 378 | ReLU-169 [-1, 512, 7, 7] 0 379 | Conv2d-170 [-1, 512, 7, 7] 2,359,296 380 | 
BatchNorm2d-171 [-1, 512, 7, 7] 1,024 381 | ReLU-172 [-1, 512, 7, 7] 0 382 | Conv2d-173 [-1, 2048, 7, 7] 1,048,576 383 | BatchNorm2d-174 [-1, 2048, 7, 7] 4,096 384 | ReLU-175 [-1, 2048, 7, 7] 0 385 | Bottleneck-176 [-1, 2048, 7, 7] 0 386 | Conv2d-177 [-1, 1024, 7, 7] 2,098,176 387 | ReLU-178 [-1, 1024, 7, 7] 0 388 | Upsample-179 [-1, 1024, 14, 14] 0 389 | Conv2d-180 [-1, 512, 14, 14] 524,800 390 | ReLU-181 [-1, 512, 14, 14] 0 391 | Conv2d-182 [-1, 512, 14, 14] 7,078,400 392 | ReLU-183 [-1, 512, 14, 14] 0 393 | Upsample-184 [-1, 512, 28, 28] 0 394 | Conv2d-185 [-1, 512, 28, 28] 262,656 395 | ReLU-186 [-1, 512, 28, 28] 0 396 | Conv2d-187 [-1, 512, 28, 28] 4,719,104 397 | ReLU-188 [-1, 512, 28, 28] 0 398 | Upsample-189 [-1, 512, 56, 56] 0 399 | Conv2d-190 [-1, 256, 56, 56] 65,792 400 | ReLU-191 [-1, 256, 56, 56] 0 401 | Conv2d-192 [-1, 256, 56, 56] 1,769,728 402 | ReLU-193 [-1, 256, 56, 56] 0 403 | Upsample-194 [-1, 256, 112, 112] 0 404 | Conv2d-195 [-1, 64, 112, 112] 4,160 405 | ReLU-196 [-1, 64, 112, 112] 0 406 | Conv2d-197 [-1, 128, 112, 112] 368,768 407 | ReLU-198 [-1, 128, 112, 112] 0 408 | Upsample-199 [-1, 128, 224, 224] 0 409 | Conv2d-200 [-1, 64, 224, 224] 110,656 410 | ReLU-201 [-1, 64, 224, 224] 0 411 | Conv2d-202 [-1, 6, 224, 224] 390 412 | ================================================================ 413 | Total params: 40,549,382 414 | Trainable params: 40,549,382 415 | Non-trainable params: 0 416 | ---------------------------------------------------------------- 417 | 418 | 419 | # Define the main training loop 420 | 421 | ```python 422 | from collections import defaultdict 423 | import torch.nn.functional as F 424 | from loss import dice_loss 425 | 426 | def calc_loss(pred, target, metrics, bce_weight=0.5): 427 | bce = F.binary_cross_entropy_with_logits(pred, target) 428 | 429 | pred = F.sigmoid(pred) 430 | dice = dice_loss(pred, target) 431 | 432 | loss = bce * bce_weight + dice * (1 - bce_weight) 433 | 434 | metrics['bce'] += bce.data.cpu().numpy() * target.size(0) 435 | metrics['dice'] += dice.data.cpu().numpy() * target.size(0) 436 | metrics['loss'] += loss.data.cpu().numpy() * target.size(0) 437 | 438 | return loss 439 | 440 | def print_metrics(metrics, epoch_samples, phase): 441 | outputs = [] 442 | for k in metrics.keys(): 443 | outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples)) 444 | 445 | print("{}: {}".format(phase, ", ".join(outputs))) 446 | 447 | def train_model(model, optimizer, scheduler, num_epochs=25): 448 | best_model_wts = copy.deepcopy(model.state_dict()) 449 | best_loss = 1e10 450 | 451 | for epoch in range(num_epochs): 452 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 453 | print('-' * 10) 454 | 455 | since = time.time() 456 | 457 | # Each epoch has a training and validation phase 458 | for phase in ['train', 'val']: 459 | if phase == 'train': 460 | scheduler.step() 461 | for param_group in optimizer.param_groups: 462 | print("LR", param_group['lr']) 463 | 464 | model.train() # Set model to training mode 465 | else: 466 | model.eval() # Set model to evaluate mode 467 | 468 | metrics = defaultdict(float) 469 | epoch_samples = 0 470 | 471 | for inputs, labels in dataloaders[phase]: 472 | inputs = inputs.to(device) 473 | labels = labels.to(device) 474 | 475 | # zero the parameter gradients 476 | optimizer.zero_grad() 477 | 478 | # forward 479 | # track history if only in train 480 | with torch.set_grad_enabled(phase == 'train'): 481 | outputs = model(inputs) 482 | loss = calc_loss(outputs, labels, metrics) 483 | 484 | # backward + 
optimize only if in training phase 485 | if phase == 'train': 486 | loss.backward() 487 | optimizer.step() 488 | 489 | # statistics 490 | epoch_samples += inputs.size(0) 491 | 492 | print_metrics(metrics, epoch_samples, phase) 493 | epoch_loss = metrics['loss'] / epoch_samples 494 | 495 | # deep copy the model 496 | if phase == 'val' and epoch_loss < best_loss: 497 | print("saving best model") 498 | best_loss = epoch_loss 499 | best_model_wts = copy.deepcopy(model.state_dict()) 500 | 501 | time_elapsed = time.time() - since 502 | print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 503 | 504 | print('Best val loss: {:4f}'.format(best_loss)) 505 | 506 | # load best model weights 507 | model.load_state_dict(best_model_wts) 508 | return model 509 | ``` 510 | 511 | ## Training 512 | ```python 513 | import torch 514 | import torch.optim as optim 515 | from torch.optim import lr_scheduler 516 | import time 517 | import copy 518 | 519 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 520 | print(device) 521 | 522 | num_class = 6 523 | model = ResNetUNet(num_class).to(device) 524 | 525 | # freeze backbone layers 526 | #for l in model.base_layers: 527 | # for param in l.parameters(): 528 | # param.requires_grad = False 529 | 530 | optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4) 531 | 532 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1) 533 | 534 | model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=60) 535 | ``` 536 | 537 | cuda:0 538 | Epoch 0/59 539 | ---------- 540 | LR 0.0001 541 | train: bce: 0.070256, dice: 0.856320, loss: 0.463288 542 | val: bce: 0.014897, dice: 0.515814, loss: 0.265356 543 | saving best model 544 | 0m 51s 545 | Epoch 1/59 546 | ---------- 547 | LR 0.0001 548 | train: bce: 0.011369, dice: 0.309445, loss: 0.160407 549 | val: bce: 0.003790, dice: 0.113682, loss: 0.058736 550 | saving best model 551 | 0m 51s 552 | Epoch 2/59 553 | ---------- 554 | LR 0.0001 555 | train: bce: 0.003480, dice: 0.089928, loss: 0.046704 556 | val: bce: 0.002525, dice: 0.067604, loss: 0.035064 557 | saving best model 558 | 0m 51s 559 | 560 | (Omitted) 561 | 562 | Epoch 57/59 563 | ---------- 564 | LR 1e-05 565 | train: bce: 0.000523, dice: 0.010289, loss: 0.005406 566 | val: bce: 0.001558, dice: 0.030965, loss: 0.016261 567 | 0m 51s 568 | Epoch 58/59 569 | ---------- 570 | LR 1e-05 571 | train: bce: 0.000518, dice: 0.010209, loss: 0.005364 572 | val: bce: 0.001548, dice: 0.031034, loss: 0.016291 573 | 0m 51s 574 | Epoch 59/59 575 | ---------- 576 | LR 1e-05 577 | train: bce: 0.000518, dice: 0.010168, loss: 0.005343 578 | val: bce: 0.001566, dice: 0.030785, loss: 0.016176 579 | 0m 50s 580 | Best val loss: 0.016171 581 | 582 | 583 | ## Use the trained model 584 | 585 | ```python 586 | import math 587 | 588 | model.eval() # Set model to the evaluation mode 589 | 590 | # Create another simulation dataset for test 591 | test_dataset = SimDataset(3, transform = trans) 592 | test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0) 593 | 594 | # Get the first batch 595 | inputs, labels = next(iter(test_loader)) 596 | inputs = inputs.to(device) 597 | labels = labels.to(device) 598 | 599 | # Predict 600 | pred = model(inputs) 601 | # The loss functions include the sigmoid function. 
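# conv_last is a plain nn.Conv2d, so `pred` holds raw logits at this point.
# F.binary_cross_entropy_with_logits applied the sigmoid internally during training;
# for visualization we apply it explicitly to get per-class probabilities in [0, 1].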
602 | pred = F.sigmoid(pred) 603 | pred = pred.data.cpu().numpy() 604 | print(pred.shape) 605 | 606 | # Change channel-order and make 3 channels for matplot 607 | input_images_rgb = [reverse_transform(x) for x in inputs.cpu()] 608 | 609 | # Map each channel (i.e. class) to each color 610 | target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()] 611 | pred_rgb = [helper.masks_to_colorimg(x) for x in pred] 612 | 613 | helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb]) 614 | ``` 615 | 616 | (3, 6, 192, 192) 617 | 618 | ### Left: Input image, Middle: Correct mask (Ground-truth), Right: Predicted mask 619 | 620 | ![png](https://raw.githubusercontent.com/usuyama/pytorch-unet/master/images/output_9_1.png) 621 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | def plot_img_array(img_array, ncol=3): 5 | nrow = len(img_array) // ncol 6 | 7 | f, plots = plt.subplots(nrow, ncol, sharex='all', sharey='all', figsize=(ncol * 4, nrow * 4)) 8 | 9 | for i in range(len(img_array)): 10 | plots[i // ncol, i % ncol] 11 | plots[i // ncol, i % ncol].imshow(img_array[i]) 12 | 13 | from functools import reduce 14 | def plot_side_by_side(img_arrays): 15 | flatten_list = reduce(lambda x,y: x+y, zip(*img_arrays)) 16 | 17 | plot_img_array(np.array(flatten_list), ncol=len(img_arrays)) 18 | 19 | import itertools 20 | def plot_errors(results_dict, title): 21 | markers = itertools.cycle(('+', 'x', 'o')) 22 | 23 | plt.title('{}'.format(title)) 24 | 25 | for label, result in sorted(results_dict.items()): 26 | plt.plot(result, marker=next(markers), label=label) 27 | plt.ylabel('dice_coef') 28 | plt.xlabel('epoch') 29 | plt.legend(loc=3, bbox_to_anchor=(1, 0)) 30 | 31 | plt.show() 32 | 33 | def masks_to_colorimg(masks): 34 | colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228),(56, 34, 132), (160, 194, 56)]) 35 | 36 | colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255 37 | channels, height, width = masks.shape 38 | 39 | for y in range(height): 40 | for x in range(width): 41 | selected_colors = colors[masks[:,y,x] > 0.5] 42 | 43 | if len(selected_colors) > 0: 44 | colorimg[y,x,:] = np.mean(selected_colors, axis=0) 45 | 46 | return colorimg.astype(np.uint8) -------------------------------------------------------------------------------- /images/output_0_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usuyama/pytorch-unet/a0ec1374e06ca9161c21781302328b074f19391b/images/output_0_1.png -------------------------------------------------------------------------------- /images/output_2_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usuyama/pytorch-unet/a0ec1374e06ca9161c21781302328b074f19391b/images/output_2_2.png -------------------------------------------------------------------------------- /images/output_9_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usuyama/pytorch-unet/a0ec1374e06ca9161c21781302328b074f19391b/images/output_9_1.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 |
import torch.nn as nn 3 | 4 | def dice_loss(pred, target, smooth = 1.): 5 | pred = pred.contiguous() 6 | target = target.contiguous() 7 | 8 | intersection = (pred * target).sum(dim=2).sum(dim=2) 9 | 10 | loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth))) 11 | 12 | return loss.mean() 13 | -------------------------------------------------------------------------------- /pytorch_fcn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 55, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. To reload it, use:\n", 13 | " %reload_ext autoreload\n", 14 | "(3, 3, 192, 192) (3, 6, 192, 192)\n" 15 | ] 16 | }, 17 | { 18 | "data": { 19 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAKvCAYAAAArysUEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X+sXHd95//ne53CH2lWCc1dy3Jymx8yIKi2F7jyFtGgsCltEuVLSCOlsSoINOpN9CXSVlup30CqgipF6rak0VbdBm4UK2YFJrRuSlT525KNUEOrpOQavMYBAnbqKLaMfUlYEgEKdfzeP+65dLie6zsz58xnZs48H9JoznzmnDnv4/j4lc/nnPlMZCaSJKmcfzfqAiRJmjaGryRJhRm+kiQVZvhKklSY4StJUmGGryRJhQ0tfCPi6oh4JiIORcSdw9qPJEmTJobxPd+I2AR8C3g3cBR4CtiRmV9vfGeSJE2YYfV8twOHMvPZzPwx8Fng+iHtS5KkiXLOkD53K/B8x+ujwH9ab+ULL7wwL7nkkiGVIo23ffv2fTczZ0ZdR1M8nzXNej2fhxW+G4qIBWABYHZ2lqWlpVGVIo1URDw36hrq8nyWVvR6Pg9r2PkYcHHH64uqtp/IzMXMnM/M+ZmZ1vxPvzSVPJ+l/gwrfJ8CtkXEpRHxGuBm4JEh7UuSpIkylGHnzDwVEXcAfw9sAnZm5tPD2JckSZNmaNd8M3MvsHdYny9J0qRyhitJkgozfCVJKszwlSSpMMNXkqTCDF9JkgozfCVJKszwlSSpMMNXkqTCDF9JkgozfCVJKszwlSSpMMNXkqTCDF9JkgozfCVJKszwlSSpMMNXkqTCBg7fiLg4Ir4YEV+PiKcj4r9U7R+LiGMRsb96XNtcuZIkTb5zamx7CvjdzPxKRJwH7IuIR6v37s3Mj9cvT5Kk9hk4fDPzOHC8Wn45Ir4BbG2qMEmS2qqRa74RcQnwFuCfq6Y7IuJAROyMiAua2IckSW1RO3wj4meBPcDvZOZLwH3A5cAcKz3je9bZbiEiliJiaXl5uW4ZkkbI81nqT63wjYifYSV4P52Zfw2QmScy89XMPA3cD2zvtm1mLmbmfGbOz8zM1ClD0oh5Pkv9qXO3cwAPAN/IzD/taN/SsdoNwMHBy5MkqX3q3O38DuB9wNciYn/V9hFgR0TMAQkcAW6rVaEkSS1T527nfwSiy1t7By+njIggM0ddhqQGxCevI2/721GXIfVlame4Whk1l9QG8cnrRl2C1JepDV8wgKU2MYA1SaY6fMEAltrEANakmPrwBQNYahMDWJPA8K0YwFJ7GMAad4ZvBwNYag8DWOPM8F3DAJbawwDWuDJ8uzCApfYwgDWODN91GMBSexjAGjeG71kYwFJ7GMAaJ3Xmdp5YTi0ptYdTS2oS2fOVJKkww1eSpMIMX0mSCjN8JUkqbCpvuJIk9eftl9/d9zZPHL5rCJW0Q+3wjYgjwMvAq8CpzJyPiNcBDwGXAEeAmzLze3X3pckTEWRmz8+SNA2aGnZ+V2bOZeZ89fpO4LHM3AY8Vr3WFFoN1F6fJWkaDOua7/XArmp5F/DeIe1HY251opJenyVpGjQRvgl8ISL2RcRC1bY5M49Xy98BNjewH00ge76SdKYmbrj65cw8FhH/AXg0Ir7Z+WZmZkSc8S9rFdQLALOzsw2UoXHkNd/p4Pks9ad2zzczj1XPJ4GHge3AiYjYAlA9n+yy3WJmzmfm/MzMTN0yNKbs+U4Hz2epP7XCNyLOjYjzVpeBXwUOAo8At1Sr3QJ8vs5+NLm85itJZ6o77LwZeLj6h/Mc4DOZ+XcR8RTwuYi4FXgOuKnmfjSh7PlK0plqhW9mPgv8Ypf2F4Cr6ny22sFrvpJ0JqeX1FDZ85WkMxm+Giqv+UrSmQxfDZU9X0k6kz+sIEnakD+S0Cx7vpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUN/MMKEfEG4KGOpsuAPwDOB34bWK7aP5KZeweuUJKklhk4fDPzGWAOICI2AceAh4EPAvdm5scbqVCSpJZpatj5KuBwZj7X0OdJktRaTYXvzcDujtd3RMSBiNgZERc0tA9JklqhdvhGxGuA9wB/WTXdB1zOypD0ceCedbZbiIiliFhaXl7utoqkCeH5LPWniZ7vNcBXMvMEQGaeyMxXM/M0cD+wvdtGmbmYmfOZOT8zM9NAGZJGxfNZ6k8T4buDjiHniNjS8d4NwMEG9iFJUmsMfLczQEScC7wbuK2j+Y8jYg5I4Mia9yRJmnq1wjczfwD83Jq299WqSJKklnOGK0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMLO6WWliNgJXAeczMxfqNpeBzwEXAIcAW7KzO9FRAD/HbgW+CHwgcz8SvOlt8PKH1dvMnOIlUiqa2HPsZ7XXbxx6xAr0bjrtef7IHD1mrY7gccycxvwWPUa4BpgW/VYAO6rX2b7R
ERfwTvoNpKGb2HPsb6Cd9Bt1B49hW9mPg68uKb5emBXtbwLeG9H+6dyxZPA+RGxpYli26JugBrA0vioG6AG8HTqadh5HZsz83i1/B1gc7W8FXi+Y72jVdtxplyTobn6WQ5FS6PRZGiufpZD0dOjTvj+RGZmRPSVAhGxwMqwNLOzs02UMdY2Ct6zhejZto0IA1gjN23n80bBe7YQPdu2C3uOGcBTok74noiILZl5vBpWPlm1HwMu7ljvoqrtp2TmIrAIMD8/P7Xp0Utwrq7jcLPGlefzil6Cc3Udh5unW52vGj0C3FIt3wJ8vqP9/bHil4DvdwxPT6X1QrPfHut66xvKUjnrhWa/Pdb11jeUp0NP4RsRu4EngDdExNGIuBX4I+DdEfFt4Feq1wB7gWeBQ8D9wP/beNUTpKng3Wg7A1gavqaCd6PtDOD262nYOTN3rPPWVV3WTeBDdYpqu7rXaDPTsJXGRN1rtIs3bjVsp5AzXA1Rt4Bs6uaobp9jIEvD0y0gm7o5qtvnGMjtZvhKklSY4VtQ018J8itG0ug0/ZUgv2I0XQxfSZIKM3wlSSqskRmupBL8BSipPb78jnf1vO72f/riECsZDXu+kiQVZvhKklSY4StJUmGGb0FNT4LhpBrS6DQ9CYaTakwXw1eSpMIM3yEa5hSQw5y6UtKZhjkF5DCnrtR4MnxHoG4AO9wsjY+6Aexw83QyfIes6Z8AbPonCiX1rumfAGz6Jwo1OQzfApoKYINXGr2mAtjgnW7OcDViq4F6tgB1mFmaDKuBerYAdZhZ0EP4RsRO4DrgZGb+QtX2J8D/A/wYOAx8MDP/T0RcAnwDeKba/MnMvH0IdU+czDxriA4asNPU652mY9V4W7xx61lDdNCAnaZebxunjOxHLz3fB4E/Bz7V0fYo8OHMPBUR/w34MPD/Ve8dzsy5RqtsidXwaKInaxBJo7UalE30ZKcpdLViw2u+mfk48OKati9k5qnq5ZPARUOorbXqBqfBK42PusFp8E6nJq75/hbwUMfrSyPiq8BLwO9n5pca2EfrDNILNnSl8TRIL9jQnW61wjci7gJOAZ+umo4Ds5n5QkS8DfibiHhzZr7UZdsFYAFgdna2ThkTzUBVG3g+rzBQ1auBv2oUER9g5Uas38wqQTLzlcx8oVrex8rNWK/vtn1mLmbmfGbOz8zMDFqGpDHg+Sz1Z6DwjYirgd8D3pOZP+xon4mITdXyZcA24NkmCpUkqS16+arRbuBK4MKIOAp8lJW7m18LPFpds1z9StE7gT+MiH8FTgO3Z+aLXT9YkqQptWH4ZuaOLs0PrLPuHmBP3aIkSWozp5eUJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpML8PV9NvW7zazvtpzSZXt5/Zp/yvLnTI6jk7Oz5aqqt98MWTfzso6SyugXv2dpHafwqkgrZKGANYGlybBSw4xbA41WNJElTwPCVJKkww1eSpMIMX0mSCjN8NbU2+jqRXzeSJsdGXycat68bGb6aausFrMErTZ71AnbcghecZENjZvXrPSXDz6CVhmPXl94OwC1XPFFsn+MYtN1s2PONiJ0RcTIiDna0fSwijkXE/upxbcd7H46IQxHxTET82rAKlyRpUvUy7PwgcHWX9nszc6567AWIiDcBNwNvrrb5i4jY1FSxkiS1wYbhm5mPAy/2+HnXA5/NzFcy81+AQ8D2GvVJktQ6dW64uiMiDlTD0hdUbVuB5zvWOVq1SZKkyqDhex9wOTAHHAfu6fcDImIhIpYiYml5eXnAMiSNA89nqT8DhW9mnsjMVzPzNHA//za0fAy4uGPVi6q2bp+xmJnzmTk/MzMzSBmSxoTns9SfgcI3IrZ0vLwBWL0T+hHg5oh4bURcCmwDvlyvREmS2mXD7/lGxG7gSuDCiDgKfBS4MiLmgASOALcBZObTEfE54OvAKeBDmfnqcEqXJGkybRi+mbmjS/MDZ1n/buDuOkWpnfr5fdxe1nVyDGl0VifQaGrdkhNxjAOnl5QkqTCnl1QxvfRURzG9pKT+9dJTHcX0kpPCnq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhTrKhseLkGlJ7OLnG+uz5SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJU2IbhGxE7I+JkRBzsaHsoIvZXjyMRsb9qvyQiftTx3ieGWbwkSZOol68aPQj8OfCp1YbM/I3V5Yi4B/h+x/qHM3OuqQIlSWqbDcM3Mx+PiEu6vRcrv3x+E/Cfmy1LkqT2qnvN9wrgRGZ+u6Pt0oj4akT8Q0RcUfPzJUlqnbozXO0Adne8Pg7MZuYLEfE24G8i4s2Z+dLaDSNiAVgAmJ2drVmGpFHyfJb6M3DPNyLOAX4deGi1LTNfycwXquV9wGHg9d22z8zFzJzPzPmZmZlBy5A0Bjyfpf7UGXb+FeCbmXl0tSEiZiJiU7V8GbANeLZeiZIktUsvXzXaDTwBvCEijkbErdVbN/PTQ84A7wQOVF89+ivg9sx8scmCJUmadL3c7bxjnfYPdGnbA+ypX5YkSe3lDFeSJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYZGZo66BiFgGfgB8d9S1NOBCPI5xMgnH8fOZ2Zrf4YuIl4FnRl1HAybh704vPI6yejqfxyJ8ASJiKTPnR11HXR7HeGnLcUyStvyZexzjpS3HscphZ0mSCjN8JUkqbJzCd3HUBTTE4xgvbTmOSdKWP3OPY7y05TiAMbrmK0nStBinnq8kSVPB8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSpsaOEbEVdHxDMRcSgi7hzWfiRJmjSRmc1/aMQm4FvAu4GjwFPAjsz8euM7kyRpwgyr57sdOJSZz2bmj4HPAtcPaV+SJE2UYYXvVuD5jtdHqzZJkqbeOaPacUQsAAsA55577tve+MY3jqoUaaT27dv33cycGXUddXg+Syt6PZ+HFb7HgIs7Xl9Utf1EZi4CiwDz8/O5tLQ0pFKk8RYRz426hro8n6UVvZ7Pwxp2fgrYFhGXRsRrgJuBR4a0L0mSJspQer6ZeSoi7gD+HtgE7MzMp4exL0mSJs3Qrvlm5l5g77A+X5KkSeUMV5IkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUNHL4RcXFEfDEivh4RT0fEf6naPxYRxyJif/W4trlyJUmafOfU2PYU8LuZ+ZWIOA/YFxGPVu/dm5kfr1+eJEntM3D4ZuZx4Hi1/HJEfAPY2lRh
kiS1VSPXfCPiEuAtwD9XTXdExIGI2BkRFzSxD0mS2qJ2+EbEzwJ7gN/JzJeA+4DLgTlWesb3rLPdQkQsRcTS8vJy3TIkjZDns9SfWuEbET/DSvB+OjP/GiAzT2Tmq5l5Grgf2N5t28xczMz5zJyfmZmpU4akEfN8lvpT527nAB4AvpGZf9rRvqVjtRuAg4OXJ0lS+9S52/kdwPuAr0XE/qrtI8COiJgDEjgC3FarQkmSWqbO3c7/CESXt/YOXo4kSe3nDFeSJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhq4kSEUR0+yVLSZNm15fezq4vvX3UZYyE4StJUmHn1P2AiDgCvAy8CpzKzPmIeB3wEHAJcAS4KTO/V3dfkiS1QVM933dl5lxmzlev7wQey8xtwGPVa0mSxPCGna8HdlXLu4D3Dmk/kiRNnCbCN4EvRMS+iFio2jZn5vFq+TvA5gb2I0lSK9S+5gv8cmYei4j/ADwaEd/sfDMzMyJy7UZVUC8AzM7ONlCG2mSjO5rXez/zjL9qKsDzWWez0R3N671/yxVPDKOcsVC755uZx6rnk8DDwHbgRERsAaieT3bZbjEz5zNzfmZmpm4ZkkbI81nqT62eb0ScC/y7zHy5Wv5V4A+BR4BbgD+qnj9ft1BNl/V6sKs9Xnu40uRYrwe72uNtcw93PXWHnTcDD1f/IJ4DfCYz/y4ingI+FxG3As8BN9XcjyRJrVErfDPzWeAXu7S/AFxV57MlSWorZ7iSJKkww1eSpMIMX0mSCmvie75SMd7lLLXHNN7lvMqeryRJhRm+kiQVZvhKklSY4StJUmGGryRJhRm+kiQVZvhKklSY4StJUmGGryRJhRm+kiQVZvhKklSY4StJUmED/7BCRLwBeKij6TLgD4Dzgd8Glqv2j2Tm3oErlCSpZQYO38x8BpgDiIhNwDHgYeCDwL2Z+fFGKpQkqWWaGna+Cjicmc819HmSJLVWU+F7M7C74/UdEXEgInZGxAUN7UOSpFaoHb4R8RrgPcBfVk33AZezMiR9HLhnne0WImIpIpaWl5e7rSJpQng+S/1poud7DfCVzDwBkJknMvPVzDwN3A9s77ZRZi5m5nxmzs/MzDRQhqRR8XyW+tNE+O6gY8g5IrZ0vHcDcLCBfUiS1BoD3+0MEBHnAu8Gbuto/uOImAMSOLLmPUmSpl6t8M3MHwA/t6btfbUqkiSp5ZzhSpKkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwmrNcKVyIoLMrP0safTefvndtT/jicN3NVCJRsWe74RYDc66z5Kk0TN8J0RENPIsSRo9w3dC2POVpPYwfCeEPV9Jag/Dd0LY85Wk9jB8J4Q9X0lqj57CNyJ2RsTJiDjY0fa6iHg0Ir5dPV9QtUdE/FlEHIqIAxHx1mEVP03s+UpSe/Ta830QuHpN253AY5m5DXiseg1wDbCteiwA99UvU/Z8Jak9egrfzHwceHFN8/XArmp5F/DejvZP5YongfMjYksTxU4ze76S1B51ZrjanJnHq+XvAJur5a3A8x3rHa3ajqOBtWmGq3564eNSs6TuvvyOd/W87vZ/+uIQK5ksjdxwlSv/Qvb1r2RELETEUkQsLS8vN1FGq9nz1TjzfJb6Uyd8T6wOJ1fPJ6v2Y8DFHetdVLX9lMxczMz5zJyfmZmpUcZ08Jqvxpnns9SfOuH7CHBLtXwL8PmO9vdXdz3/EvD9juFpDcieryS1R0/XfCNiN3AlcGFEHAU+CvwR8LmIuBV4DripWn0vcC1wCPgh8MGGa55KbbrmK0nTrqfwzcwd67x1VZd1E/hQnaJ0Jnu+ktQeznA1IbzmK0ntUeerRirInq/UHk8cvmvUJWjE7PlKklSY4StJUmGGryRJhXnNV8V5/VlqD6eMHIw9X0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqbMPwjYidEXEyIg52tP1JRHwzIg5ExMMRcX7VfklE/Cgi9lePTwyzeEmSJlEvPd8HgavXtD0K/EJm/kfgW8CHO947nJlz1eP2ZsqUJKk9NgzfzHwceHFN2xcy81T18kngoiHUJklSKzVxzfe3gP+/4/WlEfHViPiHiLiigc+XJKlVav2kYETcBZwCPl01HQdmM/OFiHgb8DcR8ebMfKnLtgvAAsDs7GydMiSNmOez1J+Be74R8QHgOuA3s/qB1sx8JTNfqJb3AYeB13fbPjMXM3M+M+dnZmYGLUPSGPB8lvozUPhGxNXA7wHvycwfdrTPRMSmavkyYBvwbBOFSpLUFhsOO0fEbuBK4MKIOAp8lJW7m18LPBoRAE9Wdza/E/jDiPhX4DRwe2a+2PWDJUmaUhuGb2bu6NL8wDrr7gH21C1KkqQ2c4YrSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqbBav+c7aaofgeiq+lVESRPi5f3r9x3OmztdsBKpf1MRvmcL3bXrGMLSeDtb6K5dxxDWuGr9sHMvwVtnfUnl9BK8ddaXSvFvpiRJhbU6fAftxdr7lcbPoL1Ye78aRxv+rYyInRFxMiIOdrR9LCKORcT+6nFtx3sfjohDEfFMRPzasAqXJGlS9fK/hA8CV3dpvzcz56rHXoCIeBNwM/Dmapu/iIhNTRUrSVIbbBi+mfk48GKPn3c98NnMfCUz/wU4BGyvUZ8kSa1T52LIHRFxoBqWvqBq2wo837HO0apNkiRVBg3f+4DLgTngOHBPvx8QEQsRsRQRS8vLywOWIWkceD5L/RkofDPzRGa+mpmngfv5t6HlY8DFHateVLV1+4zFzJzPzPmZmZlBypA0Jjyfpf4MFL4RsaXj5Q3A6p3QjwA3R8RrI+JSYBvw5XolDm7Q2aqmeZYrv2alcTXobFXTPMtVfPK6UZegdWw4vWRE7AauBC6MiKPAR4ErI2IOSOAIcBtAZj4dEZ8Dvg6cAj6Uma8Op3QNS0RM9f+ASG0Sn7yOvO1vR12G1tgwfDNzR5fmB86y/t3A3XWKalJm9tWbM3RWGMAaR+fNne5r0oxp7vV2MoDHz1RM/ZKZGwZJL+tMG4egNY7Omzu9Yaj2ss60cQh6vEzFrxqtMlz7Zw9Y48pw7Z894PExFT1f1WMPWGoPe8DjwfBVTwxgqT0M4NEzfNUzA1hqDwN4tKbqmq/qG+U14F7C3+vTUu9GeQ14YU/X+Zd+yuKN7Z2d2PBV30oHcD897tV1DWGpN6UDuJfQXbtuG0PYYWcNpNQQ9KD7cYhc6l2pIeh+greJ7caZPV8NbNg94PUCtNs+u63r16Sk3g27B7xegHbr1XZbd2HPsVb1gO35qpZh9TC7fe7ZJkJZ7z17wFLvhtUD7hamizduXTdM13uvTT1gw1e1NR1w6wVvLwxgqZ6mA3i94O1FmwP
Y8FUjhhlw/Q4dO9Qs1TPMa8D9Dh23aai5k+GrxjQRwGs/o6mfhbT3K/WniQBe20sdNEjXbteG3q/hq0YZclJ7OBHH8Bi+alxTAVx3+NjhZ6m+pgK47vBx24afDV8NhT1gqT3sATdvw/CNiJ0RcTIiDna0PRQR+6vHkYjYX7VfEhE/6njvE8MsXuPNAJbawwBuVi893weBqzsbMvM3MnMuM+eAPcBfd7x9ePW9zLy9uVI1iQxgqT0M4OZsGL6Z+TjwYrf3YuVf1puA3Q3XpRYxgKX2MICbUXd6ySuAE5n57Y62SyPiq8BLwO9n5pdq7kMFjdNNSnWnhzT0Ne1G9YtF3dSdHrINXy/qVDd8d/DTvd7jwGxmvhARbwP+JiLenJkvrd0wIhaABYDZ2dmaZUgaJc9nqT8D3+0cEecAvw48tNqWma9k5gvV8j7gMPD6bttn5mJmzmfm/MzMzKBlqGWamhyjqck61BvPZ3XT1OQYTU3WMU7qfNXoV4BvZubR1YaImImITdXyZcA24Nl6JWra9RvADjdL46vfAG7bcPOqXr5qtBt4AnhDRByNiFurt27mzBut3gkcqL569FfA7ZnZ9WYtaT11fhyhzo8ySGpenR9HqPOjDONuw2u+mbljnfYPdGnbw8pXj6RaMvOMIF193U84G7zS6C3euPWMIF193U84tyV4of4NV9LQdAtg6L0XbPBK46NbAEPvveA2BS84vaTGXFO/aiRp9Jr6VaM2sOersbcapL30eA1dabytBmkvPd42hu4qw1cTw2CV2qPNwdoLh50lSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKi3GYuCAiloEfAN8ddS0NuBCPY5xMwnH8fGa25kdwI+Jl4JlR19GASfi70wuPo6yezuexCF+AiFjKzPlR11GXxzFe2nIck6Qtf+Yex3hpy3GscthZkqTCDF9Jkgobp/BdHHUBDfE4xktbjmOStOXP3OMYL205DmCMrvlKkjQtxqnnK0nSVDB8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSChta+EbE1RHxTEQciog7h7UfSZImTWRm8x8asQn4FvBu4CjwFLAjM7/e+M4kSZoww+r5bgcOZeazmflj4LPA9UPalyRJE+WcIX3uVuD5jtdHgf/UuUJELAALAOeee+7b3vjGNw6pFGm87du377uZOTPqOurwfJZW9HoA/shgAAAQwUlEQVQ+Dyt8N5SZi8AiwPz8fC4tLY2qFGmkIuK5UddQl+eztKLX83lYw87HgIs7Xl9UtUmSNPWGFb5PAdsi4tKIeA1wM/DIkPYlSdJEGcqwc2aeiog7gL8HNgE7M/PpYexLkqRJM7Rrvpm5F9g7rM+XJGlSOcOVJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjh25CIICJGXYakBuz60tvZ9aW3j7oMtZjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFDe0nBduol68SnW2dzGyyHEk19PJVorOtc8sVTzRZjqaMPV9JkgobuOcbERcDnwI2AwksZuZ/j4iPAb8NLFerfiQz99YtdBycree62uO1dytNhrP1XFd7vPZuNSx1hp1PAb+bmV+JiPOAfRHxaPXevZn58frlSZLUPgOHb2YeB45Xyy9HxDeArU0VJklSWzVyw1VEXAK8Bfhn4B3AHRHxfmCJld7x97psswAsAMzOzjZRhgYQEWRm7WdNN8/n8fD2y++u/RlPHL6rgUq0kdo3XEXEzwJ7gN/JzJeA+4DLgTlWesb3dNsuMxczcz4z52dmZuqWoQGtBmfdZ003z2epP7XCNyJ+hpXg/XRm/jVAZp7IzFcz8zRwP7C9fpkaltUbxeo+S5J6N3D4xsq/ug8A38jMP+1o39Kx2g3AwcHL07DZ85Wk8upc830H8D7gaxGxv2r7CLAjIuZY+frREeC2WhVOiEkNIa/5SmfyK0Yatjp3O/8j0G3MsRXf6Z0W9nwlqTxnuJpyXvOVpPIM3ylnz1eSyjN8p5w9X0kqz/CdcvZ8Jak8w3fK2fOVpPIM3ylnz1eSyjN8p5w9X0kqz/CdcvZ8Jak8w3fK2fOVpPIM3ylnz1eSyjN8p5w9X0kqr84PK6gF7PlK7fHE4btGXYJ6ZM9XkqTCDF9JkgozfCVJKszwlSSpsNo3XEXEEeBl4FXgVGbOR8TrgIeAS4AjwE2Z+b26+5IkqQ2a6vm+KzPnMnO+en0n8FhmbgMeq15LkiSGN+x8PbCrWt4FvHdI+5EkaeI0Eb4JfCEi9kXEQtW2OTOPV8vfATav3SgiFiJiKSKWlpeXGyhD0qh4Pkv9aSJ8fzkz3wpcA3woIt7Z+WauzMJwxkwMmbmYmfOZOT8zM9NAGZJGxfNZ6k/t8M3MY9XzSeBhYDtwIiK2AFTPJ+vuR5KktqgVvhFxbkSct7oM/CpwEHgEuKVa7Rbg83X2I0lSm9T9qtFm4OFqcv1zgM9k5t9FxFPA5yLiVuA54Kaa+5EkqTVqhW9mPgv8Ypf2F4Cr6ny2JElt5QxXkiQVZvhKklSY4StJUmG153bWZKtuljurla9qSxp3X37HuzZcZ/s/fbFAJdqIPV9JkgozfCVJKszwlSSpMMNXkqTCDF9JkgozfCVJKszwlSSpMMNXkqTCnGRjyjmBhtQeTqAxOez5SpJUmOErSVJhhq8kSYUNfM03It4APNTRdBnwB8D5wG8Dy1X7RzJz78AVSpLUMgOHb2Y+A8wBRMQm4BjwMPBB4N7M/HgjFUqS1DJNDTtfBRzOzOca+jxJklqrqfC9Gdjd8fqOiDgQETsj4oJuG0TEQkQsRcTS8vJyt1UkTQjPZ6k/tcM3Il4DvAf4y6rpPuByVoakjwP3dNsuMxczcz4z52dmZuqWIWmEPJ+l/jTR870G+EpmngDIzBOZ+WpmngbuB7Y3sA9JklqjifDdQceQc0Rs6XjvBuBgA/uQJKk1ak0vGRHnAu8Gbuto/uOImAMSOLLmPUmSpl6t8M3MHwA/t6btfbUqkiSp5ZzhSpKkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqbCewjcidkbEyYg42NH2uoh4NCK+XT1fULVHRPxZRByKiAMR8dZhFS9J0iTqtef7IHD1mrY7gccycxvwWPUa4BpgW/VYAO6rX6YkSe3RU/hm5uPAi2uarwd2Vcu7gPd2tH8qVzwJnB8RW5ooVpKkNqhzzXdzZh6vlr8DbK6WtwLPd6x3tGr7KRGxEBFLEbG0vLxcowxJo+b5LPWnkRuuMjOB7HObxcycz8
z5mZmZJsqQNCKez1J/6oTvidXh5Or5ZNV+DLi4Y72LqjZJkkS98H0EuKVavgX4fEf7+6u7nn8J+H7H8LQkSVPvnF5WiojdwJXAhRFxFPgo8EfA5yLiVuA54KZq9b3AtcAh4IfABxuuWZKkidZT+GbmjnXeuqrLugl8qE5RkiS1mTNcSZJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFbZh+EbEzog4GREHO9r+JCK+GREHIuLhiDi/ar8kIn4UEfurxyeGWbwkSZOol57vg8DVa9oeBX4hM/8j8C3gwx3vHc7MuepxezNlKiJGXYKkhsQnrxt1CRqxDcM3Mx8HXlzT9oXMPFW9fBK4aAi1aQ0DWGoPA3i6ndPAZ/wW8FDH60sj4qvAS8DvZ+aXum0UEQvAAsDs7GwDZYyXXoIyMwf63EG2k4ap7efzwp5jG66zeOPWvj83PnkdedvfDlKSJlytG64i4i7gFPDpquk4MJuZbwH+K/CZiPj33bbNzMXMnM/M+ZmZmTpljJ1ee6iD9mTtAWvctPl87iV4+1lvLXvA02ng8I2IDwDXAb+ZVVcsM1/JzBeq5X3AYeD1DdQ5MboFY2b+5NHL+oPuR1KzugXq4o1bf/LoZf1eGMDTZ6DwjYirgd8D3pOZP+xon4mITdXyZcA24NkmCp0EawOxW+B2azOApfGzNki7BW63NgNYvejlq0a7gSeAN0TE0Yi4Ffhz4Dzg0TVfKXoncCAi9gN/BdyemS92/eCW6Ra8Z2MAS+OrW/CejQGsfm14w1Vm7ujS/MA66+4B9tQtatL1ekNUZjYSnt6EJQ1PrzdSLd64deDQ7eRNWNPBGa5awh6w1B72gNvP8G1Yvz3QJnusBrDUrH6/PjTI143WYwC3m+HbMgaw1B4GcHsZvi1kAEvtYQC3UxMzXKlDvzc/9RqU3lAllbew51hfQ8m93nDlDVWy5ytJUmGG7xAMe3pJSeUMe3pJTSfDtyH9TprR76Qcksrpd9KMfiflkAzfBnUL4LUh263N4JXGT7cAXhuy3doMXvXCG64a1m3WqrP1gg1eaXx1m7XqbL1gg1e9suc7BP1MLylpvPUzvaTUK3u+Q2KwSu1hsKpp9nwlSSrM8JUkqTDDV5KkwgxfSZIK2zB8I2JnRJyMiIMdbR+LiGMRsb96XNvx3ocj4lBEPBMRvzaswiVJmlS99HwfBK7u0n5vZs5Vj70AEfEm4GbgzdU2fxERm5oqVpKkNtgwfDPzceDFHj/veuCzmflKZv4LcAjYXqM+SZJap8413zsi4kA1LH1B1bYVeL5jnaNV2xkiYiEiliJiaXl5uUYZkkbN81nqz6Dhex9wOTAHHAfu6fcDMnMxM+czc35mZmbAMiSNA89nqT8DhW9mnsjMVzPzNHA//za0fAy4uGPVi6o2SZJUGSh8I2JLx8sbgNU7oR8Bbo6I10bEpcA24Mv1SpQkqV02nNs5InYDVwIXRsRR4KPAlRExByRwBLgNIDOfjojPAV8HTgEfysxXh1O6JEmTacPwzcwdXZofOMv6dwN31ylKkqQ2c4YrSZIKM3wlSSrM3/OVgIjYcB1/o1maDC/v37hfed7c6QKVrM+er6ZeL8Hbz3qSRqeX4O1nvWGx56upNUiYrm5jL1gaL4OE6eo2o+gF2/OVJKkww1dTqe4QskPQ0vioO4Q8iiFow1dTp6ngNICl0WsqOEsHsOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJU2IbhGxE7I+JkRBzsaHsoIvZXjyMRsb9qvyQiftTx3ieGWbw0iKbmZXZ+Z2n0mpqXufT8zr38sMKDwJ8Dn1ptyMzfWF2OiHuA73esfzgz55oqUBqGzKw1Q5XBK42P8+ZO15qhahQ/rLBh+Gbm4xFxSbf3YuVfr5uA/9xsWZIktVfda75XACcy89sdbZdGxFcj4h8i4or1NoyIhYhYioil5eXlmmVI/cvMvnuwg2wzDTyfNWrnzZ3uuwc7yDZNqRu+O4DdHa+PA7OZ+RbgvwKfiYh/323DzFzMzPnMnJ+ZmalZhjS4XsPU0F2f57PGRa9hOqrQXdXLNd+uIuIc4NeBt622ZeYrwCvV8r6IOAy8HliqWac0VAar1B6jDtZe1On5/grwzcw8utoQETMRsalavgzYBjxbr0RJktqll68a7QaeAN4QEUcj4tbqrZv56SFngHcCB6qvHv0VcHtmvthkwZIkTbpe7nbesU77B7q07QH21C9LkqT2coYrSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqbAYh59Si4hl4AfAd0ddSwMuxOMYJ5NwHD+fma35EdyIeBl4ZtR1NGAS/u70wuMoq6fzeSzCFyAiljJzftR11OVxjJe2HMckacufuccxXtpyHKscdpYkqTDDV5KkwsYpfBdHXUBDPI7x0pbjmCRt+TP3OMZLW44DGKNrvpIkTYtx6vlKkjQVRh6+EXF1RDwTEYci4s5R19OPiDgSEV+LiP0RsVS1vS4iHo2Ib1fPF4y6zrUiYmdEnIyIgx1tXeuOFX9W/fc5EBFvHV3lP22d4/hYRByr/pvsj4hrO977cHUcz0TEr42m6nbzfC7P83kyz+eRhm9EbAL+B3AN8CZgR0S8aZQ1DeBdmTnXcQv8ncBjmbkNeKx6PW4eBK5e07Ze3dcA26rHAnBfoRp78SBnHgfAvdV/k7nM3AtQ/b26GXhztc1fVH//1BDP55F5EM/niTufR93z3Q4cysxnM/PHwGeB60dcU13XA7uq5V3Ae0dYS1eZ+Tjw4prm9eq+HvhUrngSOD8itpSp9OzWOY71XA98NjNfycx/AQ6x8vdPzfF8HgHP58k8n0cdvluB5zteH63aJkUCX4iIfRGxULVtzszj1fJ3gM2jKa1v69U9if+N7qiG1HZ2DBNO4nFMmkn/M/Z8Hk+tPJ9HHb6T7pcz862sDOV8KCLe2flmrtxKPnG3k09q3ZX7gMuBOeA4cM9oy9EE8XweP609n0cdvseAizteX1S1TYTMPFY9nwQeZmXY48TqME71fHJ0FfZlvbon6r9RZp7IzFcz8zRwP/82FDVRxzGhJvrP2PN5/LT5fB51+D4FbIuISyPiNaxcQH9kxDX1JCLOjYjzVpeBXwUOslL/LdVqtwCfH02FfVuv7keA91d3Sf4S8P2O4ayxs+b61Q2s/DeBleO4OSJeGxGXsnLDyZdL19dyns/jw/N53GXmSB/AtcC3gMPAXaOup4+6LwP+d/V4erV24OdYubvw28D/Al436lq71L6blSGcf2XlWsmt69UNBCt3sB4GvgbMj7r+DY7jf1Z1HmDlBN3Ssf5d1XE8A1wz6vrb+PB8Hkntns8TeD47w5UkSYWNethZkqSpY/hKklSY4StJUmGGryRJhRm+kiQVZvhKklSY4StJUmGGryRJh
f1faexlQszjvd8AAAAASUVORK5CYII=\n", 20 | "text/plain": [ 21 | "
" 22 | ] 23 | }, 24 | "metadata": {}, 25 | "output_type": "display_data" 26 | } 27 | ], 28 | "source": [ 29 | "%matplotlib inline\n", 30 | "%load_ext autoreload\n", 31 | "%autoreload 2\n", 32 | "\n", 33 | "import os,sys\n", 34 | "import pandas as pd\n", 35 | "import matplotlib.pyplot as plt\n", 36 | "import numpy as np\n", 37 | "import helper\n", 38 | "import simulation\n", 39 | "\n", 40 | "# Generate some random images\n", 41 | "input_images, target_masks = simulation.generate_random_data(192, 192, count=3)\n", 42 | "\n", 43 | "print(input_images.shape, target_masks.shape)\n", 44 | "\n", 45 | "# Change channel-order and make 3 channels for matplot\n", 46 | "input_images_rgb = [(x.swapaxes(0, 2).swapaxes(0,1) * -255 + 255).astype(np.uint8) for x in input_images]\n", 47 | "\n", 48 | "# Map each channel (i.e. class) to each color\n", 49 | "target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]\n", 50 | "\n", 51 | "# Left: Input image, Right: Target mask\n", 52 | "helper.plot_side_by_side([input_images_rgb, target_masks_rgb])" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 87, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),\n", 64 | " BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n", 65 | " ReLU(inplace),\n", 66 | " MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),\n", 67 | " Sequential(\n", 68 | " (0): BasicBlock(\n", 69 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 70 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 71 | " (relu): ReLU(inplace)\n", 72 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 73 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 74 | " )\n", 75 | " (1): BasicBlock(\n", 76 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 77 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 78 | " (relu): ReLU(inplace)\n", 79 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 80 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 81 | " )\n", 82 | " ),\n", 83 | " Sequential(\n", 84 | " (0): BasicBlock(\n", 85 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 86 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 87 | " (relu): ReLU(inplace)\n", 88 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 89 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 90 | " (downsample): Sequential(\n", 91 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 92 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 93 | " )\n", 94 | " )\n", 95 | " (1): BasicBlock(\n", 96 | " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 97 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 98 | " (relu): ReLU(inplace)\n", 99 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
bias=False)\n", 100 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 101 | " )\n", 102 | " ),\n", 103 | " Sequential(\n", 104 | " (0): BasicBlock(\n", 105 | " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 106 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 107 | " (relu): ReLU(inplace)\n", 108 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 109 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 110 | " (downsample): Sequential(\n", 111 | " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 112 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 113 | " )\n", 114 | " )\n", 115 | " (1): BasicBlock(\n", 116 | " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 117 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 118 | " (relu): ReLU(inplace)\n", 119 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 120 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 121 | " )\n", 122 | " ),\n", 123 | " Sequential(\n", 124 | " (0): BasicBlock(\n", 125 | " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 126 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 127 | " (relu): ReLU(inplace)\n", 128 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 129 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 130 | " (downsample): Sequential(\n", 131 | " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 132 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 133 | " )\n", 134 | " )\n", 135 | " (1): BasicBlock(\n", 136 | " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 137 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 138 | " (relu): ReLU(inplace)\n", 139 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 140 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 141 | " )\n", 142 | " ),\n", 143 | " AvgPool2d(kernel_size=7, stride=1, padding=0),\n", 144 | " Linear(in_features=512, out_features=1000, bias=True)]" 145 | ] 146 | }, 147 | "execution_count": 87, 148 | "metadata": {}, 149 | "output_type": "execute_result" 150 | } 151 | ], 152 | "source": [ 153 | "from torchvision import models\n", 154 | "\n", 155 | "base_model = models.resnet18(pretrained=True)\n", 156 | "\n", 157 | "def find_last_layer(layer):\n", 158 | " children = list(layer.children())\n", 159 | " if len(children) == 0:\n", 160 | " return layer\n", 161 | " else:\n", 162 | " return find_last_layer(children[-1])\n", 163 | " \n", 164 | "list(base_model.children())" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 86, 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "ename": "NameError", 174 | "evalue": "name 'OrderedDict' is not defined", 175 | "output_type": "error", 176 | "traceback": [ 177 | 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 178 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 179 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mmodel_wo_avgpool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSequential\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbase_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mOrderedDict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_wo_avgpool\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamed_children\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 180 | "\u001b[0;31mNameError\u001b[0m: name 'OrderedDict' is not defined" 181 | ] 182 | } 183 | ], 184 | "source": [ 185 | "from torch import nn\n", 186 | "\n", 187 | "model_wo_avgpool = nn.Sequential(*list(base_model.children())[:-2])\n", 188 | "\n", 189 | "#OrderedDict(model_wo_avgpool.named_children())" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 93, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "import torch\n", 199 | "\n", 200 | "class FCN(nn.Module):\n", 201 | "\n", 202 | " def __init__(self, n_class):\n", 203 | " super().__init__()\n", 204 | " \n", 205 | " self.base_model = models.resnet18(pretrained=True)\n", 206 | " \n", 207 | " layers = list(base_model.children())\n", 208 | " self.layer1 = nn.Sequential(*layers[:5]) # size=(N, 64, x.H/2, x.W/2)\n", 209 | " self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')\n", 210 | " self.layer2 = layers[5] # size=(N, 128, x.H/4, x.W/4)\n", 211 | " self.upsample2 = nn.Upsample(scale_factor=8, mode='bilinear')\n", 212 | " self.layer3 = layers[6] # size=(N, 256, x.H/8, x.W/8)\n", 213 | " self.upsample3 = nn.Upsample(scale_factor=16, mode='bilinear')\n", 214 | " self.layer4 = layers[7] # size=(N, 512, x.H/16, x.W/16)\n", 215 | " self.upsample4 = nn.Upsample(scale_factor=32, mode='bilinear')\n", 216 | " \n", 217 | " self.conv1k = nn.Conv2d(64 + 128 + 256 + 512, n_class, 1)\n", 218 | " self.sigmoid = nn.Sigmoid()\n", 219 | " \n", 220 | " def forward(self, x):\n", 221 | " x = self.layer1(x)\n", 222 | " up1 = self.upsample1(x)\n", 223 | " x = self.layer2(x)\n", 224 | " up2 = self.upsample2(x)\n", 225 | " x = self.layer3(x)\n", 226 | " up3 = self.upsample3(x)\n", 227 | " x = self.layer4(x)\n", 228 | " up4 = self.upsample4(x)\n", 229 | " \n", 230 | " merge = torch.cat([up1, up2, up3, up4], dim=1)\n", 231 | " merge = self.conv1k(merge)\n", 232 | " out = self.sigmoid(merge)\n", 233 | " \n", 234 | " return out\n", 235 | "\n", 236 | "fcn_model = FCN(6)\n", 237 | "\n", 238 | "import torchsummary\n", 239 | "\n", 240 | "#torchsummary.summary(fcn_model, input_size=(3, 224, 224), device='cpu')" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 109, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n", 250 | " since = time.time()\n", 251 | "\n", 252 | " best_model_wts = 
copy.deepcopy(model.state_dict())\n", 253 | " best_loss = 1e10\n", 254 | "\n", 255 | " for epoch in range(num_epochs):\n", 256 | " print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n", 257 | " print('-' * 10)\n", 258 | "\n", 259 | " # Each epoch has a training and validation phase\n", 260 | " for phase in ['train', 'val']:\n", 261 | " if phase == 'train':\n", 262 | " scheduler.step()\n", 263 | " model.train() # Set model to training mode\n", 264 | " else:\n", 265 | " model.eval() # Set model to evaluate mode\n", 266 | "\n", 267 | " running_loss = 0.0\n", 268 | " running_corrects = 0\n", 269 | "\n", 270 | " # Iterate over data.\n", 271 | " batch_size = 10\n", 272 | " epoch_steps = 10\n", 273 | " for i in range(epoch_steps):\n", 274 | " input_images, target_masks = simulation.generate_random_data(192, 192, count=batch_size)\n", 275 | "\n", 276 | " inputs = torch.from_numpy(input_images)\n", 277 | " labels = torch.from_numpy(target_masks)\n", 278 | " inputs = inputs.to(device)\n", 279 | " labels = labels.to(device) \n", 280 | "\n", 281 | " # zero the parameter gradients\n", 282 | " optimizer.zero_grad()\n", 283 | "\n", 284 | " # forward\n", 285 | " # track history if only in train\n", 286 | " with torch.set_grad_enabled(phase == 'train'):\n", 287 | " outputs = model(inputs)\n", 288 | " loss = criterion(outputs, labels)\n", 289 | "\n", 290 | " # backward + optimize only if in training phase\n", 291 | " if phase == 'train':\n", 292 | " loss.backward()\n", 293 | " optimizer.step()\n", 294 | "\n", 295 | " # statistics\n", 296 | " running_loss += loss.item() * inputs.size(0)\n", 297 | "\n", 298 | " epoch_loss = running_loss / (batch_size * epoch_steps)\n", 299 | "\n", 300 | " print('{} Loss: {:.4f}'.format(\n", 301 | " phase, epoch_loss))\n", 302 | "\n", 303 | " # deep copy the model\n", 304 | " if phase == 'val' and epoch_loss < best_loss:\n", 305 | " best_loss = epoch_loss\n", 306 | " best_model_wts = copy.deepcopy(model.state_dict())\n", 307 | "\n", 308 | " print()\n", 309 | "\n", 310 | " time_elapsed = time.time() - since\n", 311 | " print('Training complete in {:.0f}m {:.0f}s'.format(\n", 312 | " time_elapsed // 60, time_elapsed % 60))\n", 313 | " print('Best val loss: {:4f}'.format(best_loss))\n", 314 | "\n", 315 | " # load best model weights\n", 316 | " model.load_state_dict(best_model_wts)\n", 317 | " return model" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 110, 323 | "metadata": {}, 324 | "outputs": [ 325 | { 326 | "name": "stdout", 327 | "output_type": "stream", 328 | "text": [ 329 | "Epoch 0/24\n", 330 | "----------\n" 331 | ] 332 | }, 333 | { 334 | "name": "stderr", 335 | "output_type": "stream", 336 | "text": [ 337 | "/home/user/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/functional.py:1749: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. 
See the documentation of nn.Upsample for details.\n", 338 | " \"See the documentation of nn.Upsample for details.\".format(mode))\n" 339 | ] 340 | }, 341 | { 342 | "name": "stdout", 343 | "output_type": "stream", 344 | "text": [ 345 | "train Loss: 0.4876\n", 346 | "val Loss: 0.2639\n", 347 | "\n", 348 | "Epoch 1/24\n", 349 | "----------\n", 350 | "train Loss: 0.1647\n", 351 | "val Loss: 0.0984\n", 352 | "\n", 353 | "Epoch 2/24\n", 354 | "----------\n", 355 | "train Loss: 0.0781\n", 356 | "val Loss: 0.0628\n", 357 | "\n", 358 | "Epoch 3/24\n", 359 | "----------\n", 360 | "train Loss: 0.0570\n", 361 | "val Loss: 0.0539\n", 362 | "\n", 363 | "Epoch 4/24\n", 364 | "----------\n", 365 | "train Loss: 0.0469\n", 366 | "val Loss: 0.0464\n", 367 | "\n", 368 | "Epoch 5/24\n", 369 | "----------\n", 370 | "train Loss: 0.0443\n", 371 | "val Loss: 0.0441\n", 372 | "\n", 373 | "Epoch 6/24\n", 374 | "----------\n", 375 | "train Loss: 0.0420\n", 376 | "val Loss: 0.0425\n", 377 | "\n", 378 | "Epoch 7/24\n", 379 | "----------\n", 380 | "train Loss: 0.0416\n", 381 | "val Loss: 0.0417\n", 382 | "\n", 383 | "Epoch 8/24\n", 384 | "----------\n", 385 | "train Loss: 0.0407\n", 386 | "val Loss: 0.0419\n", 387 | "\n", 388 | "Epoch 9/24\n", 389 | "----------\n", 390 | "train Loss: 0.0411\n", 391 | "val Loss: 0.0413\n", 392 | "\n", 393 | "Epoch 10/24\n", 394 | "----------\n", 395 | "train Loss: 0.0405\n", 396 | "val Loss: 0.0397\n", 397 | "\n", 398 | "Epoch 11/24\n", 399 | "----------\n", 400 | "train Loss: 0.0402\n", 401 | "val Loss: 0.0392\n", 402 | "\n", 403 | "Epoch 12/24\n", 404 | "----------\n", 405 | "train Loss: 0.0406\n", 406 | "val Loss: 0.0395\n", 407 | "\n", 408 | "Epoch 13/24\n", 409 | "----------\n", 410 | "train Loss: 0.0402\n", 411 | "val Loss: 0.0396\n", 412 | "\n", 413 | "Epoch 14/24\n", 414 | "----------\n", 415 | "train Loss: 0.0405\n", 416 | "val Loss: 0.0406\n", 417 | "\n", 418 | "Epoch 15/24\n", 419 | "----------\n", 420 | "train Loss: 0.0411\n", 421 | "val Loss: 0.0407\n", 422 | "\n", 423 | "Epoch 16/24\n", 424 | "----------\n", 425 | "train Loss: 0.0406\n", 426 | "val Loss: 0.0411\n", 427 | "\n", 428 | "Epoch 17/24\n", 429 | "----------\n", 430 | "train Loss: 0.0410\n", 431 | "val Loss: 0.0401\n", 432 | "\n", 433 | "Epoch 18/24\n", 434 | "----------\n", 435 | "train Loss: 0.0396\n", 436 | "val Loss: 0.0398\n", 437 | "\n", 438 | "Epoch 19/24\n", 439 | "----------\n", 440 | "train Loss: 0.0403\n", 441 | "val Loss: 0.0403\n", 442 | "\n", 443 | "Epoch 20/24\n", 444 | "----------\n", 445 | "train Loss: 0.0410\n", 446 | "val Loss: 0.0415\n", 447 | "\n", 448 | "Epoch 21/24\n", 449 | "----------\n", 450 | "train Loss: 0.0402\n", 451 | "val Loss: 0.0402\n", 452 | "\n", 453 | "Epoch 22/24\n", 454 | "----------\n", 455 | "train Loss: 0.0393\n", 456 | "val Loss: 0.0398\n", 457 | "\n", 458 | "Epoch 23/24\n", 459 | "----------\n", 460 | "train Loss: 0.0405\n", 461 | "val Loss: 0.0398\n", 462 | "\n", 463 | "Epoch 24/24\n", 464 | "----------\n", 465 | "train Loss: 0.0409\n", 466 | "val Loss: 0.0410\n", 467 | "\n", 468 | "Training complete in 1m 11s\n", 469 | "Best val loss: 0.039217\n" 470 | ] 471 | } 472 | ], 473 | "source": [ 474 | "import torch\n", 475 | "import torch.nn as nn\n", 476 | "import torch.optim as optim\n", 477 | "from torch.optim import lr_scheduler\n", 478 | "import time\n", 479 | "import copy\n", 480 | "\n", 481 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 482 | "\n", 483 | "model_ft = FCN(6).to(device)\n", 484 | "\n", 485 | "criterion = nn.BCELoss()\n", 
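The `FCN` cell above slices a ResNet-18 backbone and bilinearly upsamples each stage back to the input resolution before concatenating them for the final 1x1 convolution. The size comments in that cell look off by a factor of two: `layers[:5]` already downsamples by 4 (a stride-2 conv followed by a stride-2 max-pool), which is exactly why `scale_factor=4` on `upsample1`, and 8/16/32 on the deeper stages, brings every feature map back to the original height and width. The snippet below is a quick shape check, not part of the notebook, that confirms this for a 192x192 input.

```python
# Hypothetical shape check (not in the original notebook): layers[:5] of
# ResNet-18 yields H/4 x W/4 feature maps, so scale_factor=4 restores H x W.
import torch
import torch.nn as nn
from torchvision import models

base_model = models.resnet18(pretrained=False)   # weights are irrelevant for shapes
layers = list(base_model.children())

stage1 = nn.Sequential(*layers[:5])   # conv1 + bn + relu + maxpool + layer1
stage2 = layers[5]                    # layer2

x = torch.randn(1, 3, 192, 192)
f1 = stage1(x)                        # (1, 64, 48, 48)  = H/4
f2 = stage2(f1)                       # (1, 128, 24, 24) = H/8
up = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)(f1)
print(f1.shape, f2.shape, up.shape)   # up is back to (1, 64, 192, 192)
```

Since the model ends with `nn.Sigmoid`, pairing it with `nn.BCELoss` as done here is consistent; an equivalent and more numerically stable alternative would be to drop the `Sigmoid` and train against logits with `nn.BCEWithLogitsLoss`.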
486 | "\n", 487 | "# Observe that all parameters are being optimized\n", 488 | "optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)\n", 489 | "\n", 490 | "# Decay LR by a factor of 0.1 every 7 epochs\n", 491 | "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)\n", 492 | "\n", 493 | "model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)" 494 | ] 495 | } 496 | ], 497 | "metadata": { 498 | "kernelspec": { 499 | "display_name": "Python [conda env:py36]", 500 | "language": "python", 501 | "name": "conda-env-py36-py" 502 | }, 503 | "language_info": { 504 | "codemirror_mode": { 505 | "name": "ipython", 506 | "version": 3 507 | }, 508 | "file_extension": ".py", 509 | "mimetype": "text/x-python", 510 | "name": "python", 511 | "nbconvert_exporter": "python", 512 | "pygments_lexer": "ipython3", 513 | "version": "3.6.4" 514 | } 515 | }, 516 | "nbformat": 4, 517 | "nbformat_minor": 2 518 | } 519 | -------------------------------------------------------------------------------- /pytorch_resnet18_unet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "(3, 192, 192, 3)\n", 13 | "0 255\n", 14 | "(3, 6, 192, 192)\n", 15 | "0.0 1.0\n" 16 | ] 17 | }, 18 | { 19 | "data": { 20 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAKvCAYAAAArysUEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3X+sXOV95/HPZ+0mUikVEAbWa3ANyCEKVXsTRm4QIoLSpAahODRKaqtK3AT1ghRWzbZ/FIK0oK6Qsm0o2igbkouwMKvEQOPQoKzbxkLZkFRQuE4cxwQINnHCtS37BkeBLRFZm+/+cc80h8vce+fOOfOcH/N+SaOZeeY5c77n+h5/7vOcM2ccEQIAAOn8h6oLAABg3BC+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkNrLwtb3B9rO299u+aVTrAQCgaTyKz/naXiHph5LeI2lG0pOSNkfED0pfGQAADTOqke96Sfsj4vmI+KWk+yVtHNG6AABolJUjet/Vkl7IPZ+R9HsLdbbNZbYwzn4aEZ2qiyjLmWeeGWvXrq26DKASu3fvHmh/HlX4uk/b6wLW9qSkyRGtH2iSH1ddQFH5/XnNmjWanp6uuCKgGrYH2p9HNe08I+nc3PNzJB3Od4iIqYjoRkR3RDUASCS/P3c6rRnEAyMzqvB9UtI62+fZfpOkTZIeHtG6AABolJFMO0fECds3SvpnSSskbY2Ip0axLgAAmmZUx3wVETsl7RzV+wMA0FRc4QoAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACCxocPX9rm2v2H7adtP2f7zrP0224ds78luV5dXLgAAzbeywLInJP1lRHzH9qmSdtvelb12Z0R8unh5AAC0z9DhGxFHJB3JHr9s+2lJq8sqDACAtirlmK/ttZLeIelfs6Ybbe+1vdX26WWsAwCAtigcvrZ/Q9IOSZ+IiJck3SXpAkkTmhsZ37HAcpO2p21PF60BQLXy+/Ps7GzV5QC1Vyh8bf+a5oL3ixHxFUmKiKMRcTIiXpN0t6T1/ZaNiKmI6EZEt0gNAKqX3587nU7V5QC1N/QxX9uWdI+kpyPi73Ltq7LjwZJ0raR9xUoEpIhY8LW5X0UATTG549CCr019YDxOHSpytvOlkj4s6fu292Rtn5S02faEpJB0UNL1hSrEWFssdOf3IYSBelssdOf3aXsIFznb+duS+v1vt3P4coBfmR+8/cI13yciCGCgpuYHb79wzfeZ3HGo1QHMFa7QCAuFKmELNM9CodrmsJ2P8EUt5Ue0SwVs/vVBpqkBpJUf0S4VsPnXB5mmbirCF7U26MiWETBQf4OObMdhBEz4AgCQGOELAEBihC8AAIkRvgAAJEb4zhMRnDFbI4P+W/Bvhn62fesSbfvWJVWXgcygZy+3+SznHsIXtbScjw8t52NJANJbzseHlvOxpCYjfNEICwUwI16geRYK4HEY8fYUubYzMFK233D5yKX6A6inqQ+sfsPlI5fq32aEL2qtF6h8qxHQfL1A5VuNCF80BAELtMe4BOxiOOYLAEBihC8AAIkRvgAAJFb4mK/tg5JelnRS0omI6No+Q9IDktZKOijpQxHxs6LrAgCgDcoa+V4RERMR0c2e3yTpkYhYJ+mR7DkAANDopp03StqWPd4m6f0jWg8AAI1TxkeNQtLXbYekL0TElKSzI+KIJEXEEdtnlbCe4Yob8gpIy12Oj8IAozfsdZqXu9yWyx4baj3AoMoI30sj4nAWsLtsPzPIQrYnJU2WsH4AFcvvz2vWrKm4GqD+XOa1cW3fJun/SvozSZdno95Vkv5PRFy4yHK1uUBv7+fBSBYJ7c6dL9F43W43pqenqy5D0q9GvIxkkYrtgfbnQsd8bZ9i+9TeY0nvlbRP0sOStmTdtkj6apH1AADQJkWnnc+W9FA2Slwp6UsR8U+2n5T0oO3rJP1E0gcLrg
s19/uH8rA/V4D9uZn7c9Xhu1rSC7nnM1lbU4Skr9vebXsyazs7Io5IUnZ/VmXVLc9CdTfx3+jGbEpta26asInb0TRN/xmzP9dTK/fnqsPXfdqadPr1pRHxTs1N5Xzc9rurLmgEmvZvdJekCyRNSDoi6Y6svWnb0URN/xmzP9dPa/fnqsN3RtK5uefnSDpcUS3LFhGHs/tjkh7S3LTH0d40TnZ/rLoKl2Whuhv1bxQRRyPiZES8Julu/WoqqlHb0VCN/hmzP9dPm/fnqsP3SUnrbJ9n+02aO4D+cMU1DcT2KbZP7T2W9F5J+zRX/5as2xZJX62mwmVbqO6HJX0kO0vyXZJ+3pvOqqN5x6+u1dy/iTS3HZtsv9n2eZo74eSJ1PW1HPtzfbA/111EVHqTdLWkH0o6IOmWqutZRt3nS/pednuqV7ukt2ju7MLnsvszqq61T+3bNTeF8/809xfkdQvVrbnpnf+Z/ft8X1K36vqX2I7/ldW5V3M76Kpc/1uy7XhW0lVV19/GG/tzJbWzPzdwf+YKVwAAJFb1tDMAAGOH8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxAhfAAAS+/8+rgO+nO/fhAAAAABJRU5ErkJggg==\n", 21 | "text/plain": [ 22 | "
" 23 | ] 24 | }, 25 | "metadata": { 26 | "needs_background": "light" 27 | }, 28 | "output_type": "display_data" 29 | } 30 | ], 31 | "source": [ 32 | "%matplotlib inline\n", 33 | "%load_ext autoreload\n", 34 | "%autoreload 2\n", 35 | "\n", 36 | "import os,sys\n", 37 | "import matplotlib.pyplot as plt\n", 38 | "import numpy as np\n", 39 | "import helper\n", 40 | "import simulation\n", 41 | "\n", 42 | "# Generate some random images\n", 43 | "input_images, target_masks = simulation.generate_random_data(192, 192, count=3)\n", 44 | "\n", 45 | "for x in [input_images, target_masks]:\n", 46 | " print(x.shape)\n", 47 | " print(x.min(), x.max())\n", 48 | "\n", 49 | "# Change channel-order and make 3 channels for matplot\n", 50 | "input_images_rgb = [x.astype(np.uint8) for x in input_images]\n", 51 | "\n", 52 | "# Map each channel (i.e. class) to each color\n", 53 | "target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]\n", 54 | "\n", 55 | "# Left: Input image, Right: Target mask\n", 56 | "helper.plot_side_by_side([input_images_rgb, target_masks_rgb])" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "{'train': 2000, 'val': 200}" 68 | ] 69 | }, 70 | "execution_count": 2, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "from torch.utils.data import Dataset, DataLoader\n", 77 | "from torchvision import transforms, datasets, models\n", 78 | "\n", 79 | "class SimDataset(Dataset):\n", 80 | " def __init__(self, count, transform=None):\n", 81 | " self.input_images, self.target_masks = simulation.generate_random_data(192, 192, count=count) \n", 82 | " self.transform = transform\n", 83 | " \n", 84 | " def __len__(self):\n", 85 | " return len(self.input_images)\n", 86 | " \n", 87 | " def __getitem__(self, idx): \n", 88 | " image = self.input_images[idx]\n", 89 | " mask = self.target_masks[idx]\n", 90 | " if self.transform:\n", 91 | " image = self.transform(image)\n", 92 | " \n", 93 | " return [image, mask]\n", 94 | "\n", 95 | "# use same transform for train/val for this example\n", 96 | "trans = transforms.Compose([\n", 97 | " transforms.ToTensor(),\n", 98 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet\n", 99 | "])\n", 100 | "\n", 101 | "train_set = SimDataset(2000, transform=trans)\n", 102 | "val_set = SimDataset(200, transform=trans)\n", 103 | "\n", 104 | "image_datasets = {\n", 105 | " 'train': train_set, 'val': val_set\n", 106 | "}\n", 107 | "\n", 108 | "batch_size = 25\n", 109 | "\n", 110 | "dataloaders = {\n", 111 | " 'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),\n", 112 | " 'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)\n", 113 | "}\n", 114 | "\n", 115 | "dataset_sizes = {\n", 116 | " x: len(image_datasets[x]) for x in image_datasets.keys()\n", 117 | "}\n", 118 | "\n", 119 | "dataset_sizes" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 3, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192])\n", 132 | "-2.117904 2.64 -1.8839339 0.6775894\n", 133 | "0.0 1.0 0.004708478 0.06845663\n" 134 | ] 135 | }, 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "" 140 | ] 141 | }, 142 | "execution_count": 3, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | }, 
146 | { 147 | "data": { 148 | "image/png": "
SWiYKhao6UlXHquo48El+NUVYAc4fW/U84NA6r3Ggqi7dTHEKSbMzaS3Jc8buXgWsfjJxN7A3yQuTXMioluTXpuuipFmatJbk5Ul2M5oaPAVcC1BVjya5E/gWoxL111XVsWG6LmkI1pKUdg5rSUraOkNBUsNQkNQwFCQ1DAVJDUNBUsNQkNQwFCQ1DAVJDUNBUsNQkNQwFCQ1DAVJDUNBUsNQkNQwFCQ1DAVJDUNBUsNQkNSYtJbkZ8fqSD6V5OGu/YIkPx977BNDdl5S/za8mjOjWpJ/B3xqtaGq/mB1OclNwI/H1n+iqnb31UFJs7VhKFTVV5NcsNZjSQK8C/idfrslaV6mPabwBuBIVT0+1nZhkq8n+UqSN0z5+pJmbDPTh5O5Brh97P5h4GVV9cMkrwE+n+TiqvrJiU9Msh/YP+X2JfVs4pFCkhcAvw98drWtK0H/w275IeAJ4BVrPd9aktJimmb68LvAd6pqZbUhyUuTnNotv5xRLcknp+uipFnazEeStwP/DrwyyUqS93UP7aWdOgC8EXgkyX8B/wS8v6p+1GeHJQ3LWpLSzmEtSUlbZyhIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIakx7Oba+PAf8X3e7zF7Ccu/jsu8fbO99/M3NrLQQ11MASPLgsl+abdn3cdn3D3bGPjp9kNQwFCQ1FikUDsy7AzOw7Pu47PsHO2AfF+aYgqTFsEgjBUkLYO6hkOSKJI8lOZjkhnn3py9dNe5vdNW3H+zaXpzk3iSPd7dnzrufW7FOBfI19ykjH+t+r48kuWR+Pd+cdfbvQ0meGaukfuXYYzd2+/dYkrfNp9f9m2sodIVj/h54O/Aq4Jokr5pnn3r2pqraPfYR1g3AfVW1C7ivu7+d3ApccULbevv0dkbFgHYxKg9484z6OI1bef7+AXy0+z3urqp7ALr/p3uBi7vnfHy1ENJ2N++RwmXAwap6sqp+CdwB7Jlzn4a0B7itW74NeOcc+7JlVfVV4MTiPuvt0x7gUzVyP3BGknNm09PJrLN/69kD3NGVSvwecJDR/+dtb96hcC7w9Nj9la5tGRTw5SQPdcV0Ac6uqsMA3e1Zc+tdf9bbp2X63V7fTYFuGZvyLdP+NeYdClmjbVk+Dnl9VV3CaBh9XZI3zrtDM7Ysv9ubgYuA3Yyqqt/UtS/L/j3PvENhBTh/7P55wKE59aVXVXWou30WuIvR0PLI6hC6u312fj3szXr7tBS/26o6UlXHquo48El+NUVYiv1by7xD4QFgV5ILk5zG6MDN3XPu09SSvCjJ6avLwFuBbzLat33davuAL8ynh71ab5/uBt7dfQrxOuDHq9OM7eSE4yBXMfo9wmj/9iZ5YZILGR1Q/dqs+zeEuX5LsqqOJrke+BJwKnBLVT06zz715GzgriQw+jf+TFV9MckDwJ1d5e4fAFfPsY9b1lUgvxx4SZIV4IPAh1l7n+4BrmR0AO5nwHtm3uEtWmf/Lk+ym9HU4CngWoCqejTJncC3gKPAdVV1bB797ptnNEpqzHv6IGnBGAqSGoaCpIahIKlhKEhqGAqSGoaCpIahIKnx/34czn/DIR47AAAAAElFTkSuQmCC\n", 149 | "text/plain": [ 150 | "
" 151 | ] 152 | }, 153 | "metadata": { 154 | "needs_background": "light" 155 | }, 156 | "output_type": "display_data" 157 | } 158 | ], 159 | "source": [ 160 | "import torchvision.utils\n", 161 | "\n", 162 | "def reverse_transform(inp):\n", 163 | " inp = inp.numpy().transpose((1, 2, 0))\n", 164 | " mean = np.array([0.485, 0.456, 0.406])\n", 165 | " std = np.array([0.229, 0.224, 0.225])\n", 166 | " inp = std * inp + mean\n", 167 | " inp = np.clip(inp, 0, 1)\n", 168 | " inp = (inp * 255).astype(np.uint8)\n", 169 | " \n", 170 | " return inp\n", 171 | "\n", 172 | "# Get a batch of training data\n", 173 | "inputs, masks = next(iter(dataloaders['train']))\n", 174 | "\n", 175 | "print(inputs.shape, masks.shape)\n", 176 | "for x in [inputs.numpy(), masks.numpy()]:\n", 177 | " print(x.min(), x.max(), x.mean(), x.std())\n", 178 | "\n", 179 | "plt.imshow(reverse_transform(inputs[3]))" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 4, 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "data": { 189 | "text/plain": [ 190 | "[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),\n", 191 | " BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n", 192 | " ReLU(inplace),\n", 193 | " MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),\n", 194 | " Sequential(\n", 195 | " (0): BasicBlock(\n", 196 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 197 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 198 | " (relu): ReLU(inplace)\n", 199 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 200 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 201 | " )\n", 202 | " (1): BasicBlock(\n", 203 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 204 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 205 | " (relu): ReLU(inplace)\n", 206 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 207 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 208 | " )\n", 209 | " ),\n", 210 | " Sequential(\n", 211 | " (0): BasicBlock(\n", 212 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 213 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 214 | " (relu): ReLU(inplace)\n", 215 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 216 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 217 | " (downsample): Sequential(\n", 218 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 219 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 220 | " )\n", 221 | " )\n", 222 | " (1): BasicBlock(\n", 223 | " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 224 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 225 | " (relu): ReLU(inplace)\n", 226 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 227 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 228 | " )\n", 229 | " 
),\n", 230 | " Sequential(\n", 231 | " (0): BasicBlock(\n", 232 | " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 233 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 234 | " (relu): ReLU(inplace)\n", 235 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 236 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 237 | " (downsample): Sequential(\n", 238 | " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 239 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 240 | " )\n", 241 | " )\n", 242 | " (1): BasicBlock(\n", 243 | " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 244 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 245 | " (relu): ReLU(inplace)\n", 246 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 247 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 248 | " )\n", 249 | " ),\n", 250 | " Sequential(\n", 251 | " (0): BasicBlock(\n", 252 | " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 253 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 254 | " (relu): ReLU(inplace)\n", 255 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 256 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 257 | " (downsample): Sequential(\n", 258 | " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 259 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 260 | " )\n", 261 | " )\n", 262 | " (1): BasicBlock(\n", 263 | " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 264 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 265 | " (relu): ReLU(inplace)\n", 266 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 267 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 268 | " )\n", 269 | " ),\n", 270 | " AvgPool2d(kernel_size=7, stride=1, padding=0),\n", 271 | " Linear(in_features=512, out_features=1000, bias=True)]" 272 | ] 273 | }, 274 | "execution_count": 4, 275 | "metadata": {}, 276 | "output_type": "execute_result" 277 | } 278 | ], 279 | "source": [ 280 | "from torchvision import models\n", 281 | "\n", 282 | "base_model = models.resnet18(pretrained=False)\n", 283 | " \n", 284 | "list(base_model.children())" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 5, 290 | "metadata": {}, 291 | "outputs": [ 292 | { 293 | "name": "stdout", 294 | "output_type": "stream", 295 | "text": [ 296 | "----------------------------------------------------------------\n", 297 | " Layer (type) Output Shape Param #\n", 298 | "================================================================\n", 299 | " Conv2d-1 [-1, 64, 112, 112] 9,408\n", 300 | " BatchNorm2d-2 [-1, 64, 112, 112] 128\n", 301 | " ReLU-3 [-1, 64, 112, 112] 0\n", 302 | " MaxPool2d-4 [-1, 64, 56, 56] 0\n", 303 | " Conv2d-5 [-1, 64, 56, 56] 36,864\n", 304 | " BatchNorm2d-6 [-1, 64, 56, 56] 128\n", 
305 | " ReLU-7 [-1, 64, 56, 56] 0\n", 306 | " Conv2d-8 [-1, 64, 56, 56] 36,864\n", 307 | " BatchNorm2d-9 [-1, 64, 56, 56] 128\n", 308 | " ReLU-10 [-1, 64, 56, 56] 0\n", 309 | " BasicBlock-11 [-1, 64, 56, 56] 0\n", 310 | " Conv2d-12 [-1, 64, 56, 56] 36,864\n", 311 | " BatchNorm2d-13 [-1, 64, 56, 56] 128\n", 312 | " ReLU-14 [-1, 64, 56, 56] 0\n", 313 | " Conv2d-15 [-1, 64, 56, 56] 36,864\n", 314 | " BatchNorm2d-16 [-1, 64, 56, 56] 128\n", 315 | " ReLU-17 [-1, 64, 56, 56] 0\n", 316 | " BasicBlock-18 [-1, 64, 56, 56] 0\n", 317 | " Conv2d-19 [-1, 128, 28, 28] 73,728\n", 318 | " BatchNorm2d-20 [-1, 128, 28, 28] 256\n", 319 | " ReLU-21 [-1, 128, 28, 28] 0\n", 320 | " Conv2d-22 [-1, 128, 28, 28] 147,456\n", 321 | " BatchNorm2d-23 [-1, 128, 28, 28] 256\n", 322 | " Conv2d-24 [-1, 128, 28, 28] 8,192\n", 323 | " BatchNorm2d-25 [-1, 128, 28, 28] 256\n", 324 | " ReLU-26 [-1, 128, 28, 28] 0\n", 325 | " BasicBlock-27 [-1, 128, 28, 28] 0\n", 326 | " Conv2d-28 [-1, 128, 28, 28] 147,456\n", 327 | " BatchNorm2d-29 [-1, 128, 28, 28] 256\n", 328 | " ReLU-30 [-1, 128, 28, 28] 0\n", 329 | " Conv2d-31 [-1, 128, 28, 28] 147,456\n", 330 | " BatchNorm2d-32 [-1, 128, 28, 28] 256\n", 331 | " ReLU-33 [-1, 128, 28, 28] 0\n", 332 | " BasicBlock-34 [-1, 128, 28, 28] 0\n", 333 | " Conv2d-35 [-1, 256, 14, 14] 294,912\n", 334 | " BatchNorm2d-36 [-1, 256, 14, 14] 512\n", 335 | " ReLU-37 [-1, 256, 14, 14] 0\n", 336 | " Conv2d-38 [-1, 256, 14, 14] 589,824\n", 337 | " BatchNorm2d-39 [-1, 256, 14, 14] 512\n", 338 | " Conv2d-40 [-1, 256, 14, 14] 32,768\n", 339 | " BatchNorm2d-41 [-1, 256, 14, 14] 512\n", 340 | " ReLU-42 [-1, 256, 14, 14] 0\n", 341 | " BasicBlock-43 [-1, 256, 14, 14] 0\n", 342 | " Conv2d-44 [-1, 256, 14, 14] 589,824\n", 343 | " BatchNorm2d-45 [-1, 256, 14, 14] 512\n", 344 | " ReLU-46 [-1, 256, 14, 14] 0\n", 345 | " Conv2d-47 [-1, 256, 14, 14] 589,824\n", 346 | " BatchNorm2d-48 [-1, 256, 14, 14] 512\n", 347 | " ReLU-49 [-1, 256, 14, 14] 0\n", 348 | " BasicBlock-50 [-1, 256, 14, 14] 0\n", 349 | " Conv2d-51 [-1, 512, 7, 7] 1,179,648\n", 350 | " BatchNorm2d-52 [-1, 512, 7, 7] 1,024\n", 351 | " ReLU-53 [-1, 512, 7, 7] 0\n", 352 | " Conv2d-54 [-1, 512, 7, 7] 2,359,296\n", 353 | " BatchNorm2d-55 [-1, 512, 7, 7] 1,024\n", 354 | " Conv2d-56 [-1, 512, 7, 7] 131,072\n", 355 | " BatchNorm2d-57 [-1, 512, 7, 7] 1,024\n", 356 | " ReLU-58 [-1, 512, 7, 7] 0\n", 357 | " BasicBlock-59 [-1, 512, 7, 7] 0\n", 358 | " Conv2d-60 [-1, 512, 7, 7] 2,359,296\n", 359 | " BatchNorm2d-61 [-1, 512, 7, 7] 1,024\n", 360 | " ReLU-62 [-1, 512, 7, 7] 0\n", 361 | " Conv2d-63 [-1, 512, 7, 7] 2,359,296\n", 362 | " BatchNorm2d-64 [-1, 512, 7, 7] 1,024\n", 363 | " ReLU-65 [-1, 512, 7, 7] 0\n", 364 | " BasicBlock-66 [-1, 512, 7, 7] 0\n", 365 | " AvgPool2d-67 [-1, 512, 1, 1] 0\n", 366 | " Linear-68 [-1, 1000] 513,000\n", 367 | "================================================================\n", 368 | "Total params: 11,689,512\n", 369 | "Trainable params: 11,689,512\n", 370 | "Non-trainable params: 0\n", 371 | "----------------------------------------------------------------\n", 372 | "Input size (MB): 0.57\n", 373 | "Forward/backward pass size (MB): 62.79\n", 374 | "Params size (MB): 44.59\n", 375 | "Estimated Total Size (MB): 107.96\n", 376 | "----------------------------------------------------------------\n" 377 | ] 378 | } 379 | ], 380 | "source": [ 381 | "# check keras-like model summary using torchsummary\n", 382 | "import torch\n", 383 | "from torchsummary import summary\n", 384 | "\n", 385 | "device = torch.device('cuda' if torch.cuda.is_available() 
else 'cpu')\n", 386 | "base_model = base_model.to(device)\n", 387 | "\n", 388 | "summary(base_model, input_size=(3, 224, 224))" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 6, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "import torch\n", 398 | "import torch.nn as nn\n", 399 | "\n", 400 | "def convrelu(in_channels, out_channels, kernel, padding):\n", 401 | " return nn.Sequential(\n", 402 | " nn.Conv2d(in_channels, out_channels, kernel, padding=padding),\n", 403 | " nn.ReLU(inplace=True),\n", 404 | " )\n", 405 | "\n", 406 | "class ResNetUNet(nn.Module):\n", 407 | "\n", 408 | " def __init__(self, n_class):\n", 409 | " super().__init__()\n", 410 | " \n", 411 | " self.base_model = models.resnet18(pretrained=True)\n", 412 | " \n", 413 | " self.base_layers = list(base_model.children()) \n", 414 | " \n", 415 | " self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)\n", 416 | " self.layer0_1x1 = convrelu(64, 64, 1, 0)\n", 417 | " self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) \n", 418 | " self.layer1_1x1 = convrelu(64, 64, 1, 0) \n", 419 | " self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) \n", 420 | " self.layer2_1x1 = convrelu(128, 128, 1, 0) \n", 421 | " self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) \n", 422 | " self.layer3_1x1 = convrelu(256, 256, 1, 0) \n", 423 | " self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32)\n", 424 | " self.layer4_1x1 = convrelu(512, 512, 1, 0) \n", 425 | " \n", 426 | " self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n", 427 | " \n", 428 | " self.conv_up3 = convrelu(256 + 512, 512, 3, 1)\n", 429 | " self.conv_up2 = convrelu(128 + 512, 256, 3, 1)\n", 430 | " self.conv_up1 = convrelu(64 + 256, 256, 3, 1)\n", 431 | " self.conv_up0 = convrelu(64 + 256, 128, 3, 1)\n", 432 | " \n", 433 | " self.conv_original_size0 = convrelu(3, 64, 3, 1)\n", 434 | " self.conv_original_size1 = convrelu(64, 64, 3, 1)\n", 435 | " self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)\n", 436 | " \n", 437 | " self.conv_last = nn.Conv2d(64, n_class, 1)\n", 438 | " \n", 439 | " def forward(self, input):\n", 440 | " x_original = self.conv_original_size0(input)\n", 441 | " x_original = self.conv_original_size1(x_original)\n", 442 | " \n", 443 | " layer0 = self.layer0(input) \n", 444 | " layer1 = self.layer1(layer0)\n", 445 | " layer2 = self.layer2(layer1)\n", 446 | " layer3 = self.layer3(layer2) \n", 447 | " layer4 = self.layer4(layer3)\n", 448 | " \n", 449 | " layer4 = self.layer4_1x1(layer4)\n", 450 | " x = self.upsample(layer4)\n", 451 | " layer3 = self.layer3_1x1(layer3)\n", 452 | " x = torch.cat([x, layer3], dim=1)\n", 453 | " x = self.conv_up3(x)\n", 454 | " \n", 455 | " x = self.upsample(x)\n", 456 | " layer2 = self.layer2_1x1(layer2)\n", 457 | " x = torch.cat([x, layer2], dim=1)\n", 458 | " x = self.conv_up2(x)\n", 459 | "\n", 460 | " x = self.upsample(x)\n", 461 | " layer1 = self.layer1_1x1(layer1)\n", 462 | " x = torch.cat([x, layer1], dim=1)\n", 463 | " x = self.conv_up1(x)\n", 464 | "\n", 465 | " x = self.upsample(x)\n", 466 | " layer0 = self.layer0_1x1(layer0)\n", 467 | " x = torch.cat([x, layer0], dim=1)\n", 468 | " x = self.conv_up0(x)\n", 469 | " \n", 470 | " x = self.upsample(x)\n", 471 | " x = torch.cat([x, x_original], dim=1)\n", 472 | " x = self.conv_original_size2(x) \n", 473 | " \n", 474 | " out = self.conv_last(x) \n", 475 | " \n", 476 | " return out" 477 | ] 478 | }, 479 | { 480 | 
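One quirk worth flagging in the `ResNetUNet` definition above: `__init__` assigns `self.base_model = models.resnet18(pretrained=True)` but then builds `self.base_layers` from the module-level `base_model` created in the earlier torchsummary cell with `pretrained=False`, so the encoder actually used in `forward` is not sliced from the pretrained copy stored on the instance. The `FCN` class in pytorch_fcn.ipynb slices the module-level `base_model` in the same way. The sketch below shows the presumably intended wiring; the class name and structure are illustrative, not the repository's code.

```python
# Minimal sketch of the presumably intended wiring: slice the instance's own
# pretrained backbone rather than a module-level base_model from another cell.
# ResNetUNetEncoder is a hypothetical name used only for this illustration.
import torch.nn as nn
from torchvision import models

class ResNetUNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = models.resnet18(pretrained=True)  # pretrained weights
        layers = list(self.base_model.children())           # slice this instance's copy
        self.layer0 = nn.Sequential(*layers[:3])   # (N, 64, H/2, W/2)
        self.layer1 = nn.Sequential(*layers[3:5])  # (N, 64, H/4, W/4)
        self.layer2 = layers[5]                    # (N, 128, H/8, W/8)
```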
"cell_type": "code", 481 | "execution_count": 7, 482 | "metadata": {}, 483 | "outputs": [ 484 | { 485 | "name": "stdout", 486 | "output_type": "stream", 487 | "text": [ 488 | "----------------------------------------------------------------\n", 489 | " Layer (type) Output Shape Param #\n", 490 | "================================================================\n", 491 | " Conv2d-1 [-1, 64, 224, 224] 1,792\n", 492 | " ReLU-2 [-1, 64, 224, 224] 0\n", 493 | " Conv2d-3 [-1, 64, 224, 224] 36,928\n", 494 | " ReLU-4 [-1, 64, 224, 224] 0\n", 495 | " Conv2d-5 [-1, 64, 112, 112] 9,408\n", 496 | " BatchNorm2d-6 [-1, 64, 112, 112] 128\n", 497 | " ReLU-7 [-1, 64, 112, 112] 0\n", 498 | " MaxPool2d-8 [-1, 64, 56, 56] 0\n", 499 | " Conv2d-9 [-1, 64, 56, 56] 36,864\n", 500 | " BatchNorm2d-10 [-1, 64, 56, 56] 128\n", 501 | " ReLU-11 [-1, 64, 56, 56] 0\n", 502 | " Conv2d-12 [-1, 64, 56, 56] 36,864\n", 503 | " BatchNorm2d-13 [-1, 64, 56, 56] 128\n", 504 | " ReLU-14 [-1, 64, 56, 56] 0\n", 505 | " BasicBlock-15 [-1, 64, 56, 56] 0\n", 506 | " Conv2d-16 [-1, 64, 56, 56] 36,864\n", 507 | " BatchNorm2d-17 [-1, 64, 56, 56] 128\n", 508 | " ReLU-18 [-1, 64, 56, 56] 0\n", 509 | " Conv2d-19 [-1, 64, 56, 56] 36,864\n", 510 | " BatchNorm2d-20 [-1, 64, 56, 56] 128\n", 511 | " ReLU-21 [-1, 64, 56, 56] 0\n", 512 | " BasicBlock-22 [-1, 64, 56, 56] 0\n", 513 | " Conv2d-23 [-1, 128, 28, 28] 73,728\n", 514 | " BatchNorm2d-24 [-1, 128, 28, 28] 256\n", 515 | " ReLU-25 [-1, 128, 28, 28] 0\n", 516 | " Conv2d-26 [-1, 128, 28, 28] 147,456\n", 517 | " BatchNorm2d-27 [-1, 128, 28, 28] 256\n", 518 | " Conv2d-28 [-1, 128, 28, 28] 8,192\n", 519 | " BatchNorm2d-29 [-1, 128, 28, 28] 256\n", 520 | " ReLU-30 [-1, 128, 28, 28] 0\n", 521 | " BasicBlock-31 [-1, 128, 28, 28] 0\n", 522 | " Conv2d-32 [-1, 128, 28, 28] 147,456\n", 523 | " BatchNorm2d-33 [-1, 128, 28, 28] 256\n", 524 | " ReLU-34 [-1, 128, 28, 28] 0\n", 525 | " Conv2d-35 [-1, 128, 28, 28] 147,456\n", 526 | " BatchNorm2d-36 [-1, 128, 28, 28] 256\n", 527 | " ReLU-37 [-1, 128, 28, 28] 0\n", 528 | " BasicBlock-38 [-1, 128, 28, 28] 0\n", 529 | " Conv2d-39 [-1, 256, 14, 14] 294,912\n", 530 | " BatchNorm2d-40 [-1, 256, 14, 14] 512\n", 531 | " ReLU-41 [-1, 256, 14, 14] 0\n", 532 | " Conv2d-42 [-1, 256, 14, 14] 589,824\n", 533 | " BatchNorm2d-43 [-1, 256, 14, 14] 512\n", 534 | " Conv2d-44 [-1, 256, 14, 14] 32,768\n", 535 | " BatchNorm2d-45 [-1, 256, 14, 14] 512\n", 536 | " ReLU-46 [-1, 256, 14, 14] 0\n", 537 | " BasicBlock-47 [-1, 256, 14, 14] 0\n", 538 | " Conv2d-48 [-1, 256, 14, 14] 589,824\n", 539 | " BatchNorm2d-49 [-1, 256, 14, 14] 512\n", 540 | " ReLU-50 [-1, 256, 14, 14] 0\n", 541 | " Conv2d-51 [-1, 256, 14, 14] 589,824\n", 542 | " BatchNorm2d-52 [-1, 256, 14, 14] 512\n", 543 | " ReLU-53 [-1, 256, 14, 14] 0\n", 544 | " BasicBlock-54 [-1, 256, 14, 14] 0\n", 545 | " Conv2d-55 [-1, 512, 7, 7] 1,179,648\n", 546 | " BatchNorm2d-56 [-1, 512, 7, 7] 1,024\n", 547 | " ReLU-57 [-1, 512, 7, 7] 0\n", 548 | " Conv2d-58 [-1, 512, 7, 7] 2,359,296\n", 549 | " BatchNorm2d-59 [-1, 512, 7, 7] 1,024\n", 550 | " Conv2d-60 [-1, 512, 7, 7] 131,072\n", 551 | " BatchNorm2d-61 [-1, 512, 7, 7] 1,024\n", 552 | " ReLU-62 [-1, 512, 7, 7] 0\n", 553 | " BasicBlock-63 [-1, 512, 7, 7] 0\n", 554 | " Conv2d-64 [-1, 512, 7, 7] 2,359,296\n", 555 | " BatchNorm2d-65 [-1, 512, 7, 7] 1,024\n", 556 | " ReLU-66 [-1, 512, 7, 7] 0\n", 557 | " Conv2d-67 [-1, 512, 7, 7] 2,359,296\n", 558 | " BatchNorm2d-68 [-1, 512, 7, 7] 1,024\n", 559 | " ReLU-69 [-1, 512, 7, 7] 0\n", 560 | " BasicBlock-70 [-1, 512, 7, 7] 0\n", 561 | " Conv2d-71 [-1, 
512, 7, 7] 262,656\n", 562 | " ReLU-72 [-1, 512, 7, 7] 0\n", 563 | " Upsample-73 [-1, 512, 14, 14] 0\n", 564 | " Conv2d-74 [-1, 256, 14, 14] 65,792\n", 565 | " ReLU-75 [-1, 256, 14, 14] 0\n", 566 | " Conv2d-76 [-1, 512, 14, 14] 3,539,456\n", 567 | " ReLU-77 [-1, 512, 14, 14] 0\n", 568 | " Upsample-78 [-1, 512, 28, 28] 0\n", 569 | " Conv2d-79 [-1, 128, 28, 28] 16,512\n", 570 | " ReLU-80 [-1, 128, 28, 28] 0\n", 571 | " Conv2d-81 [-1, 256, 28, 28] 1,474,816\n", 572 | " ReLU-82 [-1, 256, 28, 28] 0\n", 573 | " Upsample-83 [-1, 256, 56, 56] 0\n", 574 | " Conv2d-84 [-1, 64, 56, 56] 4,160\n", 575 | " ReLU-85 [-1, 64, 56, 56] 0\n", 576 | " Conv2d-86 [-1, 256, 56, 56] 737,536\n", 577 | " ReLU-87 [-1, 256, 56, 56] 0\n", 578 | " Upsample-88 [-1, 256, 112, 112] 0\n", 579 | " Conv2d-89 [-1, 64, 112, 112] 4,160\n", 580 | " ReLU-90 [-1, 64, 112, 112] 0\n", 581 | " Conv2d-91 [-1, 128, 112, 112] 368,768\n", 582 | " ReLU-92 [-1, 128, 112, 112] 0\n", 583 | " Upsample-93 [-1, 128, 224, 224] 0\n", 584 | " Conv2d-94 [-1, 64, 224, 224] 110,656\n", 585 | " ReLU-95 [-1, 64, 224, 224] 0\n", 586 | " Conv2d-96 [-1, 6, 224, 224] 390\n", 587 | "================================================================\n", 588 | "Total params: 17,800,134\n", 589 | "Trainable params: 17,800,134\n", 590 | "Non-trainable params: 0\n", 591 | "----------------------------------------------------------------\n", 592 | "Input size (MB): 0.57\n", 593 | "Forward/backward pass size (MB): 354.87\n", 594 | "Params size (MB): 67.90\n", 595 | "Estimated Total Size (MB): 423.34\n", 596 | "----------------------------------------------------------------\n" 597 | ] 598 | }, 599 | { 600 | "name": "stderr", 601 | "output_type": "stream", 602 | "text": [ 603 | "C:\\Users\\naotous\\AppData\\Local\\Continuum\\anaconda3\\envs\\torch\\lib\\site-packages\\torch\\nn\\modules\\upsampling.py:129: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.\n", 604 | " warnings.warn(\"nn.{} is deprecated. 
Use nn.functional.interpolate instead.\".format(self.name))\n" 605 | ] 606 | } 607 | ], 608 | "source": [ 609 | "# check keras-like model summary using torchsummary\n", 610 | "\n", 611 | "from torchsummary import summary\n", 612 | "\n", 613 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 614 | "model = ResNetUNet(6)\n", 615 | "model = model.to(device)\n", 616 | "\n", 617 | "summary(model, input_size=(3, 224, 224))" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 8, 623 | "metadata": {}, 624 | "outputs": [], 625 | "source": [ 626 | "from collections import defaultdict\n", 627 | "import torch.nn.functional as F\n", 628 | "import torch\n", 629 | "from loss import dice_loss\n", 630 | "\n", 631 | "def calc_loss(pred, target, metrics, bce_weight=0.5):\n", 632 | " bce = F.binary_cross_entropy_with_logits(pred, target)\n", 633 | " \n", 634 | " pred = torch.sigmoid(pred)\n", 635 | " dice = dice_loss(pred, target)\n", 636 | " \n", 637 | " loss = bce * bce_weight + dice * (1 - bce_weight)\n", 638 | " \n", 639 | " metrics['bce'] += bce.data.cpu().numpy() * target.size(0)\n", 640 | " metrics['dice'] += dice.data.cpu().numpy() * target.size(0)\n", 641 | " metrics['loss'] += loss.data.cpu().numpy() * target.size(0)\n", 642 | " \n", 643 | " return loss\n", 644 | "\n", 645 | "def print_metrics(metrics, epoch_samples, phase): \n", 646 | " outputs = []\n", 647 | " for k in metrics.keys():\n", 648 | " outputs.append(\"{}: {:4f}\".format(k, metrics[k] / epoch_samples))\n", 649 | " \n", 650 | " print(\"{}: {}\".format(phase, \", \".join(outputs))) \n", 651 | "\n", 652 | "def train_model(model, optimizer, scheduler, num_epochs=25):\n", 653 | " best_model_wts = copy.deepcopy(model.state_dict())\n", 654 | " best_loss = 1e10\n", 655 | "\n", 656 | " for epoch in range(num_epochs):\n", 657 | " print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n", 658 | " print('-' * 10)\n", 659 | " \n", 660 | " since = time.time()\n", 661 | "\n", 662 | " # Each epoch has a training and validation phase\n", 663 | " for phase in ['train', 'val']:\n", 664 | " if phase == 'train':\n", 665 | " scheduler.step()\n", 666 | " for param_group in optimizer.param_groups:\n", 667 | " print(\"LR\", param_group['lr'])\n", 668 | " \n", 669 | " model.train() # Set model to training mode\n", 670 | " else:\n", 671 | " model.eval() # Set model to evaluate mode\n", 672 | "\n", 673 | " metrics = defaultdict(float)\n", 674 | " epoch_samples = 0\n", 675 | " \n", 676 | " for inputs, labels in dataloaders[phase]:\n", 677 | " inputs = inputs.to(device)\n", 678 | " labels = labels.to(device) \n", 679 | "\n", 680 | " # zero the parameter gradients\n", 681 | " optimizer.zero_grad()\n", 682 | "\n", 683 | " # forward\n", 684 | " # track history if only in train\n", 685 | " with torch.set_grad_enabled(phase == 'train'):\n", 686 | " outputs = model(inputs)\n", 687 | " loss = calc_loss(outputs, labels, metrics)\n", 688 | "\n", 689 | " # backward + optimize only if in training phase\n", 690 | " if phase == 'train':\n", 691 | " loss.backward()\n", 692 | " optimizer.step()\n", 693 | "\n", 694 | " # statistics\n", 695 | " epoch_samples += inputs.size(0)\n", 696 | "\n", 697 | " print_metrics(metrics, epoch_samples, phase)\n", 698 | " epoch_loss = metrics['loss'] / epoch_samples\n", 699 | "\n", 700 | " # deep copy the model\n", 701 | " if phase == 'val' and epoch_loss < best_loss:\n", 702 | " print(\"saving best model\")\n", 703 | " best_loss = epoch_loss\n", 704 | " best_model_wts = 
copy.deepcopy(model.state_dict())\n", 705 | "\n", 706 | " time_elapsed = time.time() - since\n", 707 | " print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n", 708 | " \n", 709 | " print('Best val loss: {:4f}'.format(best_loss))\n", 710 | "\n", 711 | " # load best model weights\n", 712 | " model.load_state_dict(best_model_wts)\n", 713 | " return model" 714 | ] 715 | }, 716 | { 717 | "cell_type": "code", 718 | "execution_count": 9, 719 | "metadata": {}, 720 | "outputs": [ 721 | { 722 | "name": "stdout", 723 | "output_type": "stream", 724 | "text": [ 725 | "cuda:0\n", 726 | "Epoch 0/14\n", 727 | "----------\n", 728 | "LR 0.0001\n", 729 | "train: bce: 0.106049, dice: 0.939754, loss: 0.522902\n", 730 | "val: bce: 0.021187, dice: 0.740403, loss: 0.380795\n", 731 | "saving best model\n", 732 | "1m 13s\n", 733 | "Epoch 1/14\n", 734 | "----------\n", 735 | "LR 0.0001\n", 736 | "train: bce: 0.021152, dice: 0.539552, loss: 0.280352\n", 737 | "val: bce: 0.015782, dice: 0.403593, loss: 0.209687\n", 738 | "saving best model\n", 739 | "1m 23s\n", 740 | "Epoch 2/14\n", 741 | "----------\n", 742 | "LR 0.0001\n", 743 | "train: bce: 0.011438, dice: 0.297177, loss: 0.154307\n", 744 | "val: bce: 0.008063, dice: 0.226069, loss: 0.117066\n", 745 | "saving best model\n", 746 | "1m 35s\n", 747 | "Epoch 3/14\n", 748 | "----------\n", 749 | "LR 0.0001\n", 750 | "train: bce: 0.007809, dice: 0.202010, loss: 0.104910\n", 751 | "val: bce: 0.007454, dice: 0.183544, loss: 0.095499\n", 752 | "saving best model\n", 753 | "1m 40s\n", 754 | "Epoch 4/14\n", 755 | "----------\n", 756 | "LR 0.0001\n", 757 | "train: bce: 0.006656, dice: 0.175859, loss: 0.091257\n", 758 | "val: bce: 0.006215, dice: 0.169616, loss: 0.087916\n", 759 | "saving best model\n", 760 | "1m 39s\n", 761 | "Epoch 5/14\n", 762 | "----------\n", 763 | "LR 0.0001\n", 764 | "train: bce: 0.004646, dice: 0.149397, loss: 0.077021\n", 765 | "val: bce: 0.003372, dice: 0.134775, loss: 0.069073\n", 766 | "saving best model\n", 767 | "1m 39s\n", 768 | "Epoch 6/14\n", 769 | "----------\n", 770 | "LR 0.0001\n", 771 | "train: bce: 0.002749, dice: 0.107913, loss: 0.055331\n", 772 | "val: bce: 0.002363, dice: 0.093581, loss: 0.047972\n", 773 | "saving best model\n", 774 | "1m 40s\n", 775 | "Epoch 7/14\n", 776 | "----------\n", 777 | "LR 0.0001\n", 778 | "train: bce: 0.002137, dice: 0.079275, loss: 0.040706\n", 779 | "val: bce: 0.002504, dice: 0.086800, loss: 0.044652\n", 780 | "saving best model\n", 781 | "1m 39s\n", 782 | "Epoch 8/14\n", 783 | "----------\n", 784 | "LR 0.0001\n", 785 | "train: bce: 0.001844, dice: 0.062099, loss: 0.031971\n", 786 | "val: bce: 0.002087, dice: 0.069803, loss: 0.035945\n", 787 | "saving best model\n", 788 | "1m 40s\n", 789 | "Epoch 9/14\n", 790 | "----------\n", 791 | "LR 0.0001\n", 792 | "train: bce: 0.001642, dice: 0.050935, loss: 0.026289\n", 793 | "val: bce: 0.001931, dice: 0.064805, loss: 0.033368\n", 794 | "saving best model\n", 795 | "1m 40s\n", 796 | "Epoch 10/14\n", 797 | "----------\n", 798 | "LR 1e-05\n", 799 | "train: bce: 0.001492, dice: 0.043337, loss: 0.022415\n", 800 | "val: bce: 0.001986, dice: 0.063748, loss: 0.032867\n", 801 | "saving best model\n", 802 | "1m 41s\n", 803 | "Epoch 11/14\n", 804 | "----------\n", 805 | "LR 1e-05\n", 806 | "train: bce: 0.001475, dice: 0.042005, loss: 0.021740\n", 807 | "val: bce: 0.002009, dice: 0.063995, loss: 0.033002\n", 808 | "1m 54s\n", 809 | "Epoch 12/14\n", 810 | "----------\n", 811 | "LR 1e-05\n", 812 | "train: bce: 0.001460, dice: 0.041238, loss: 0.021349\n", 813 | 
"val: bce: 0.002017, dice: 0.064332, loss: 0.033175\n", 814 | "1m 55s\n", 815 | "Epoch 13/14\n", 816 | "----------\n", 817 | "LR 1e-05\n", 818 | "train: bce: 0.001450, dice: 0.040528, loss: 0.020989\n", 819 | "val: bce: 0.002030, dice: 0.064902, loss: 0.033466\n", 820 | "1m 54s\n", 821 | "Epoch 14/14\n", 822 | "----------\n", 823 | "LR 1e-05\n", 824 | "train: bce: 0.001433, dice: 0.039867, loss: 0.020650\n", 825 | "val: bce: 0.002121, dice: 0.066033, loss: 0.034077\n", 826 | "1m 54s\n", 827 | "Best val loss: 0.032867\n" 828 | ] 829 | } 830 | ], 831 | "source": [ 832 | "import torch\n", 833 | "import torch.optim as optim\n", 834 | "from torch.optim import lr_scheduler\n", 835 | "import time\n", 836 | "import copy\n", 837 | "\n", 838 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 839 | "print(device)\n", 840 | "\n", 841 | "num_class = 6\n", 842 | "\n", 843 | "model = ResNetUNet(num_class).to(device)\n", 844 | "\n", 845 | "# freeze backbone layers\n", 846 | "# Comment out to finetune further\n", 847 | "for l in model.base_layers:\n", 848 | " for param in l.parameters():\n", 849 | " param.requires_grad = False\n", 850 | "\n", 851 | "optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)\n", 852 | "\n", 853 | "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1) \n", 854 | " \n", 855 | "model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=15)" 856 | ] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "execution_count": 10, 861 | "metadata": {}, 862 | "outputs": [ 863 | { 864 | "name": "stdout", 865 | "output_type": "stream", 866 | "text": [ 867 | "(3, 6, 192, 192)\n" 868 | ] 869 | }, 870 | { 871 | "data": { 872 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsQAAAKvCAYAAABtZtkaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzs3X+sbGd5H/rvY7uJVOoKiDfI18YxIIfcELUnYcstQlAoJTG+3DoUJbVVJU6CekACqVHzR0iQCmqFFLWh6FZpCQdh2VwlBm4cGpS6bbgoDUkEhePEcUwcB5uYcGzLPuDcxLdE5NrnuX+c2WU47B+z98zsmTXr85G29p531pp5Zh8/er9+9ztrqrsDAABjddGqCwAAgFUSiAEAGDWBGACAUROIAQAYNYEYAIBRE4gBABi1pQXiqrququ6vqgeq6m3Leh4AAJhHLeM6xFV1cZI/TvKaJGeSfDbJTd39hwt/MgAAmMOyVoivTfJAd3+hu/8qyYeS3LCk5wIAgCO7ZEmPe0WSL03dPpPk7+x1cFX5uDzG7MvdvbXqIg7jsssu66uvvnrVZcBK3HXXXYPqWf3KmM3ar8sKxLXL2DeE3qo6meTkkp4fhuSLqy5gFtM9e9VVV+X06dMrrghWo6rWvmf1K5w3a78ua8vEmSTPm7p9ZZJHpg/o7lPdvd3d20uqAVig6Z7d2hrM4hiMkn6Fw1lWIP5skmuq6vlV9S1JbkzysSU9FwAAHNlStkx091NV9dYk/zXJxUlu6e7PLeO5AABgHsvaQ5zuvjPJnct6fAAAWASfVAcAwKgJxAAAjJpADADAqAnEAACMmkAMAMCoCcQAAIyaQAwAwKgJxAAAjJpADADAqAnEAACMmkAMAMCoCcQAAIyaQAwAwKgJxAAAjJpADADAqAnEAACM2pEDcVU9r6p+o6ruq6rPVdU/m4y/s6oerqq7J1/XL65cAABYrEvmOPepJD/Z3b9bVZcmuauqPj657z3d/XPzlwcAAMt15EDc3Y8meXTy85NVdV+SKxZVGAAAHIeF7CGuqquTfE+S/z4ZemtV3VNVt1TVsxbxHAAAsAxzB+Kq+htJ7kjyE939F0nem+SFSU7k/Aryu/c472RVna6q0/PWACzfdM+ePXt21eUA+9CvcDhzBeKq+ms5H4Z/sbt/JUm6+7Hufrq7zyV5f5Jrdzu3u09193Z3b89TA3A8pnt2a2tr1eUA+9CvcDjzXGWiknwgyX3d/W+nxi+fOuz1Se49enkAALBc81xl4mVJfjjJH1TV3ZOxn0lyU1WdSNJJHkryprkqBACAJZrnKhO/naR2uevOo5cDAADHyyfVAQAwagIxAACjJhADADBqAjEAAKMmEAMAMGoCMQAAoyYQAwAwagIxAACjJhADADBqAjEAAKMmEAMAMGoCMQAAoyYQAwAwagIxAACjJhADADBqAjEAAKN2ybwPUFUPJXkyydNJnuru7ap6dpIPJ7k6yUNJfqi7/2ze5wIAgEVb1Arxq7r7RHdvT26/LcknuvuaJJ+Y3AYAgLWzrC0TNyS5bfLzbUl+YEnPAwAAc1lEIO4kv15Vd1XVycnYc7v70SSZfH/OAp4HAAAWbu49xEle1t2PVNVzkny8qv5olpMm4fnkgQeyNs51H3jMRVXHUAmrMN2zV1111YqrYRYn73j4wGNOveGKY6iE46Zfh0e/rlb1DCFn5geremeS/zfJP03yyu5+tKouT/LfuvtF+5y3uCJYuFmC8IUE40O5a2r//SBsb2/36dOn
[... base64 PNG data omitted (display_data output: input images, ground-truth masks, and predicted masks plotted side by side) ...]
Vd38yyRMXDO9V9w1JPtjnfTrJM6vq8uOpdH97vI693JDkQ939te7+kyQP5Px/fxyOnl0BPatnj0i/roB+HWa/rjoQX5HkS1O3z0zGhqKT/HpV3VVVJydjz+3uR5Nk8v05K6vucPaqe4j/Rm+d/Onplqk/pw3xdayjof8e9ex60rPLMfTfoX5dTxvZr6sOxLXL2JAue/Gy7v7enP+Tx1uq6hWrLmgJhvZv9N4kL0xyIsmjSd49GR/a61hXQ/896tn1o2eXZ+i/Q/26fja2X1cdiM8ked7U7SuTPLKiWg6tux+ZfH88yUdz/s8Dj+38uWPy/fHVVXgoe9U9qH+j7n6su5/u7nNJ3p+v/8lmUK9jjQ3696hn14+eXapB/w716/rZ5H5ddSD+bJJrqur5VfUtOb8h+2MrrmkmVfWMqrp05+ck35fk3pyv/+bJYTcn+dXVVHhoe9X9sSQ/Mnkn7N9N8uc7f/ZZRxfsvXp9zv+bJOdfx41V9a1V9fycfwPDZ467vg2gZ9eHnuUg+nV96Nd1190r/UpyfZI/TvJgkrevup5D1P2CJL8/+frcTu1Jvi3n30H6+cn3Z6+61l1qvz3n/9Tx/+X8/9W9ca+6c/7PIP9+8u/zB0m2V13/Aa/j/5zUeU/ON+jlU8e/ffI67k/y2lXXP9QvPbuS2vWsnj3q71y/Hn/t+nWA/eqT6gAAGLVVb5kAAICVEogBABg1gRgAgFETiAEAGDWBGACAUROIAQAYNYEY+P/brQMBAAAAAEH+1oNcFAHAmhADALAWTxXZI62AmxkAAAAASUVORK5CYII=\n", 873 | "text/plain": [ 874 | "
" 875 | ] 876 | }, 877 | "metadata": { 878 | "needs_background": "light" 879 | }, 880 | "output_type": "display_data" 881 | } 882 | ], 883 | "source": [ 884 | "#### prediction\n", 885 | "\n", 886 | "import math\n", 887 | "\n", 888 | "model.eval() # Set model to evaluate mode\n", 889 | "\n", 890 | "test_dataset = SimDataset(3, transform = trans)\n", 891 | "test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)\n", 892 | " \n", 893 | "inputs, labels = next(iter(test_loader))\n", 894 | "inputs = inputs.to(device)\n", 895 | "labels = labels.to(device)\n", 896 | "\n", 897 | "pred = model(inputs)\n", 898 | "pred = torch.sigmoid(pred)\n", 899 | "pred = pred.data.cpu().numpy()\n", 900 | "print(pred.shape)\n", 901 | "\n", 902 | "# Change channel-order and make 3 channels for matplot\n", 903 | "input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]\n", 904 | "\n", 905 | "# Map each channel (i.e. class) to each color\n", 906 | "target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]\n", 907 | "pred_rgb = [helper.masks_to_colorimg(x) for x in pred]\n", 908 | "\n", 909 | "helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])" 910 | ] 911 | } 912 | ], 913 | "metadata": { 914 | "kernelspec": { 915 | "display_name": "Python 3", 916 | "language": "python", 917 | "name": "python3" 918 | }, 919 | "language_info": { 920 | "codemirror_mode": { 921 | "name": "ipython", 922 | "version": 3 923 | }, 924 | "file_extension": ".py", 925 | "mimetype": "text/x-python", 926 | "name": "python", 927 | "nbconvert_exporter": "python", 928 | "pygments_lexer": "ipython3", 929 | "version": "3.6.7" 930 | } 931 | }, 932 | "nbformat": 4, 933 | "nbformat_minor": 2 934 | } 935 | -------------------------------------------------------------------------------- /pytorch_unet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "(3, 192, 192, 3)\n", 13 | "0 255\n", 14 | "(3, 6, 192, 192)\n", 15 | "0.0 1.0\n" 16 | ] 17 | }, 18 | { 19 | "data": { 20 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAd8AAAKvCAYAAAArysUEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X+sHOV97/HPp3bDH5QrIGwty+AakJMoRO0JHLlBCRGEJjEIxSGRqK0qcVJ0D0hw1bSVekmRLqgSUpWGoFv1huQgLJyrxEDrkKDKbcNFuSGtoHBMfB1DINjECB859glEAYWIxPb3/nFmk+Gw692zM/vMj32/pNXuPjOz851zzvjj55nZGUeEAABAOr9VdQEAAEwawhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMTGFr62N9p+1vZ+2zeNaz0AADSNx/E9X9srJP1Q0gclHZL0hKQtEfF06SsDAKBhxtXz3SBpf0Q8HxG/lHSvpE1jWhcAAI2yckyfu0bSi7n3hyT9Yb+ZbXOZLUyyn0REp+oiynLWWWfFunXrqi4DqMTu3buH2p/HFb4D2Z6RNFPV+oEaeaHqAorK789r167V3NxcxRUB1bA91P48rmHneUnn5N6fnbX9WkTMRsR0REyPqQYAieT3506nNZ14YGzGFb5PSFpv+1zbb5G0WdKDY1oXAACNMpZh54g4ZvtGSf8maYWkbRHx1DjWBQBA04ztmG9E7JK0a1yfDwBAU3GFKwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXzRGRFRdAoCS+MtXVV1CpQhfNAoBDLTHJAcw4YvGIYCB9pjUACZ80UgEMNAekxjAI4ev7XNsf9v207afsv1nWfuttudt78keV5ZXLvAbBDDQHpMWwEV6vsck/WVEvFPSeyTdYPud2bQ7ImIqe+wqXCXQBwEMtMckBfDI4RsRhyPiyez1q5J+IGlNWYUBwyKAgfaYlAAu5Ziv7XWS3i3pP7OmG23vtb3N9hllrAM4GQIYaI9JCODC4Wv7dyTtlPSZiHhF0p2Szpc0JemwpNv7LDdje872XNEaAIkArlJ+f15YWKi6HLRA2wO4UPja/m0tBu9XI+LrkhQRRyLieESckHSXpA29lo2I2YiYjojpIjUAeQRwNfL7c6fTqboctESbA7jI2c6WdLekH0TEF3Ltq3OzXS1p3+jlActHAAPt0dYALtLzfa+kT0j6wJKvFX3O9vdt75V0maQ/L6NQYDkIYKA92hjAK0ddMCL+XZJ7TOKrRaiFiNDiAA2ApvOXr1Jc989Vl1EarnCFVqMHDLRHm3rAhC9ajwAG2qMtAUz4YiIQwEB7tCGARz7mC6TG8VugPdp0/HYU9HwBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxLi2MwDgpC4+/7bSPuvRAzeX9llNVrjna/ug7e/b3mN7Lms70/ZDtp/Lns8oXiqapHsXobKeAaBNyhp2viwipiJiOnt/k6SHI2K9pIez95gg3TsQlfUMAG0yrmO+myRtz15vl/TRMa0HNUXPFwD6KyN8Q9K3bO+2PZO1rYqIw9nrH0taVcJ60CD0fAGgvzJOuHpfRMzb/l1JD9l+Jj8xIsL2m7ovWVDPLG1HO0SEbJf2jHrL789r166tuBqg/gr3fCNiPns+KukBSRskHbG9WpKy56M9lpuNiOnccWK0CD3fyZLfnzudTtXlALVXKHxtn2r7tO5rSR+StE/Sg5K2ZrNtlfTNIutB83DMFwD6KzrsvErSA1nvZKWkr0XEv9p+QtL9tq+V9IKkawquBw1DzxcA+isUvhHxvKQ/6NH+kqTLi3w2mo1jvgDQH5eXxFjQ8wWA/ghfjAXHfAGgP8IXY0HPFwD6I3wxFvR8AaA/whdjQc8XAPojfDEW9HwBoD/CF2NBzxcA+iN8MRb0fAGgP8IXY0HPFwD6K+OuRpUatWfEP+pA/Tz+3st+/frJiz/Wc55e7bMfXzO2miA9euDmqktoncaH76i6oU0IA/XSL3RPZmbnvCRCGM0x8cPOHFME6mOU4M3rhjBQd40O37KCkwAGqldWcBLAaILGDjsvJzDzQ8v9luPuOUB1ThaYFz769Te8v/7z/23gcjM75xmCRq01sudbpKdqu2/I0gMG0usXoBc++vU3Be9Ssx9f0zdk6QGjzhoXvmUFJAEMVO9kwbscBDCaZuTwtf1223tyj1dsf8b2rbbnc+1XlllwmRhmBupnucHbxTAzmmTk8I2IZyNiKiKmJF0k6TVJD2ST7+hOi4hdZRSarbOsj/q1XgFM77d6EcHvoeV69UpHDd6uXgFM77d62797sbZ/9+Kqy6iVsk64ulzSgYh4IXVvkt4r0B4b/uPbVZcAJFHWMd/Nknbk3t9oe6/tbbbPKGkdb1JW8BLgQPXKGjZm+BlNUDh8bb9F0kck/WPWdKek8yVNSTos6fY+y83YnrM9V7QGANXK788LCwtVlwPUXhk93yskPRkRRyQpIo5ExPGIOCHpLkkbei0UEbMRMR0R0yXUAKBC+f250+lUXQ5Qe2WE7xblhpxtr85Nu1rSvhLWAQBAaxQ64cr2qZI+KOm6XPPnbE9JCkkHl0wDAGDiFQrfiPi5pLcuaftEoYoAAGi5xl3hCgCApmvsjRWk8m6GwMUc0lrOz3uYefmqWDuUdTMELqqR1nIunjHMvFsvebRIOY1BzxcAgMRch16f7aGL6FVvkZ5P2Z+HcnR/LxPyu9jdpq/cTU9Px9zccF/f79VLLdL7LfvzUI5uj3cSerW2h9qfW9HzHfU/EHX4jweANxp12JjhZjRJ48K3rFsB9pt/QnpaQC2UdSvAfvPT60VdNS58pZMH8KAQPtk8BC+Q3skCeFAIn2weghd11tiznW33DdFRhpMJXqA6sx9f0zdERxlOJnhRd43s+XZxVyOgPbirESZJo8NXKh6cBC9QH0WDk+BFUzR22DmvG6DLGW4mdIF66gbocoabCV00TSvCt4tABdqDQEWbtSp80R78Rwpoj0m4uMZytS58Tzb0zD/oQLO8uqf/aSmnTZ1IWAlQrtaE7zDHeyfskoVAY50sdJfOQwijiRp/trNU3tWtAFRvmOAtMj9QB0P91dreZvuo7X25tjNtP2T7uez5jKzdtv/e9n7be21fOK7iAQBoomH/y3iPpI1L2m6S9HBErJf0cPZekq6QtD57zEi6s3iZ/XFTBaA9Ru3F0vtF0wz1FxsRj0h6eUnzJknbs9fbJX001/6VWPSYpNNtry6jWAAA2qDIfxdXRcTh7PWPJa3KXq+R9GJuvkNZGwAAUEknXMXiGO6yxnFtz9iesz3cXbcB1FZ+f15YWKi6HKD2ioTvke5wcvZ8NGufl3RObr6zs7Y3iIjZiJiOiOkCNQCogfz+3Ol0qi4HqL0i4fugpK3Z662Svplr/2R21vN7JP0sNzwNAM
DEG+oiG7Z3SLpU0lm2D0m6RdLfSrrf9rWSXpB0TTb7LklXStov6TVJny65ZgAAGm2o8I2ILX0mXd5j3pB0Q5GilsP2SF8b4ipXQP2cNnVipK8NcZUrNA1fjgMAILFWhO9ye7H0eoH6Wm4vll4vmqg1N1boBip3NQKarxuo3NUIbdWa8O0iYIH2IGDRVq0YdgYAoEkIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACCxgeFre5vto7b35dr+zvYztvfafsD26Vn7Otu/sL0ne3xpnMUDANBEw/R875G0cUnbQ5LeFRG/L+mHkj6bm3YgIqayx/XllAkAQHsMDN+IeETSy0vavhURx7K3j0k6ewy1AQDQSmUc8/1TSf+Se3+u7e/Z/o7tS0r4fAAAWmVlkYVt3yzpmKSvZk2HJa2NiJdsXyTpG7YviIhXeiw7I2mmyPoB1EN+f167dm3F1QD1N3LP1/anJF0l6U8iIiQpIl6PiJey17slHZD0tl7LR8RsRExHxPSoNQCoh/z+3Ol0qi4HqL2Rwtf2Rkl/JekjEfFarr1je0X2+jxJ6yU9X0ahAAC0xcBhZ9s7JF0q6SzbhyTdosWzm0+R9JBtSXosO7P5/ZL+xvavJJ2QdH1EvNzzgwEAmFADwzcitvRovrvPvDsl7SxaFAAAbcYVrgAASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIbGL62t9k+antfru1W2/O292SPK3PTPmt7v+1nbX94XIUDANBUw/R875G0sUf7HRExlT12SZLtd0raLOmCbJkv2l5RVrEAALTBwPCNiEckvTzk522SdG9EvB4RP5K0X9KGAvUBANA6RY753mh7bzYsfUbWtkbSi7l5DmVtAAAgM2r43inpfElTkg5Lun25H2B7xvac7bkRawBQE/n9eWFhoepygNobKXwj4khEHI+IE5Lu0m+GluclnZOb9eysrddnzEbEdERMj1IDgPrI78+dTqfqcoDaGyl8ba/Ovb1aUvdM6AclbbZ9iu1zJa2X9HixEgEAaJeVg2awvUPSpZLOsn1I0i2SLrU9JSkkHZR0nSRFxFO275f0tKRjkm6IiOPjKR0AgGYaGL4RsaVH890nmf82SbcVKQoAgDbjClcAACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJDQxf29tsH7W9L9d2n+092eOg7T1Z+zrbv8hN+9I4iwcAoIlWDjHPPZL+QdJXug0R8cfd17Zvl/Sz3PwHImKqrAIBAGibgeEbEY/YXtdrmm1LukbSB8otCwCA9ip6zPcSSUci4rlc27m2v2f7O7YvKfj5AAC0zjDDziezRdKO3PvDktZGxEu2L5L0DdsXRMQrSxe0PSNppuD6AdRAfn9eu3ZtxdUA9Tdyz9f2Skkfk3Rfty0iXo+Il7LXuyUdkPS2XstHxGxETEfE9Kg1AKiH/P7c6XSqLgeovSLDzn8k6ZmIONRtsN2xvSJ7fZ6k9ZKeL1YiAADtMsxXjXZIelTS220fsn1tNmmz3jjkLEnvl7Q3++rRP0m6PiJeLrNgAACabpiznbf0af9Uj7adknYWLwsAgPbiClcAACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYo6IqmuQ7QVJP5f0k6prKcFZYjvqpAnb8XsR0Zr78Nl+VdKzVddRgib87QyD7UhrqP25FuErSbbn2nBvX7ajXtqyHU3Slp8521EvbdmOLoadAQBIjPAFACCxOoXvbNUFlITtqJe2bEeTtOVnznbUS1u2Q1KNjvkCADAp6tTzBQBgIhC+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiY0tfG1vtP2s7f22bxrXegAAaBpHRPkfaq+Q9ENJH5R0SNITkrZExNOlrwwAgIYZV893g6T9EfF8RPxS0r2SNo1pXQAANMq4wneNpBdz7w9lbQAATLyVVa3Y9oykmeztRVXVAdTATyKiU3URReT351NPPfWid7zjHRVXBFRj9+7dQ+3P4wrfeUnn5N6fnbX9WkTMSpqVJNvlH3gGmuOFqgsoKr8/T09Px9zcXMUVAdWwPdT+PK5h5yckrbd9ru23SNos6cExrQsAgEYZS883Io7ZvlHSv0laIWlbRDw1jnUBANA0YzvmGxG7JO0a1+cDANBUXOEKAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgsZHD1/Y5tr9t+2nbT9n+s6z9VtvztvdkjyvLKxcAgOZbWWDZY5L+MiKetH2apN22H8qm3RERny9eHgAA7TNy+EbEYUmHs9ev2v6BpDVlFQYAQFuVcszX9jpJ75b0n1nTjbb32t5m+4wy1gEAQFsUDl/bvyNpp6TPRMQrku6UdL6kKS32jG/vs9yM7Tnbc0VrAFCt/P68sLBQdTlA7RUKX9u/rcXg/WpEfF2SIuJIRByPiBOS7pK0odeyETEbEdMRMV2kBgDVy+/PnU6n6nKA2itytrMl3S3pBxHxhVz76txsV0vaN3p5AAC0T5Gznd8r6ROSvm97T9b215K22J6SFJIOSrquUIUAALRMkbOd/12Se0zaNXo5AAC0H1e4AgAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASGxl0Q+wfVDSq5KOSzoWEdO2z5R0n6R1kg5KuiYiflp0XQAAtEFZPd/LImIqIqaz9zdJejgi1kt6OHsPAAA0vmHnTZK2Z6+3S/romNYDAEDjlBG+IelbtnfbnsnaVkXE4ez1jyWtKmE9AAC0QuFjvpLeFxHztn9X0kO2n8lPjIiwHUsXyoJ6Zmk7qhex+OuyXXElaIr8/rx27dqKq0He9u9eLEnaesmjFVeCvMI934iYz56PSnpA0gZJR2yvlqTs+WiP5WYjYjp3nBhAQ+X3506nU3U5QO0VCl/bp9o+rfta0ock7ZP0oKSt2WxbJX2zyHoAAGiTosPOqyQ9kA1PrpT0tYj4V9tPSLrf9rWSXpB0TcH1AADQGoXCNyKel/QHPdpfknR5kc8GAKCtuMIVAACJEb4AACRG+
AIAkFgZ3/NFw3S/x1t0Pr4HDFSv+z3eovPxPeC06PkCAJAYPd8JNKjHyhWugOYY1GPlClf1RM8XAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzv+eJN+H4v0B58v7ee6PkCAJAY4QsAQGIjDzvbfruk+3JN50n6H5JOl/RfJS1k7X8dEbtGrhAAgJYZOXwj4llJU5Jke4WkeUkPSPq0pDsi4vOlVAgAQMuUNex8uaQDEfFCSZ8HAEBrlRW+myXtyL2/0fZe29tsn1HSOgAAaIXC4Wv7LZI+Iukfs6Y7JZ2vxSHpw5Ju77PcjO0523NFawBQrfz+vLCwMHgBYMKV0fO9QtKTEXFEkiLiSEQcj4gTku6StKHXQhExGxHTETFdQg0AKpTfnzudTtXlALVXRvhuUW7I2fbq3LSrJe0rYR0AALRGoStc2T5V0gclXZdr/pztKUkh6eCSaQAATLxC4RsRP5f01iVtnyhUEQAALccVrgAASIzwBQAgMcIXAIDECF8AABLjfr6YKBEx9Lzc1xiot8ffe9nQ8274j2+PsZLlo+cLAEBi9HwxEZbT4wVQbzM75yVJF178sYHzXvjo18ddzkgI34IiQrYHPqMahC6GdfH5tw0976MHbh5jJeinG7rL8WQW0D2vc1whhp0L6gbroGekVzR4CW6gPkYJ3jKXLxs934Lo+dZPmaHZ/Sx+h0A1ygzN7mfNfnxNaZ85KsK3IHq+9TKu3ir/iQLSGxS8Jzue++RJjgfP7JyvPIAJ34Lo+TZD93fAUDLQfN3gfPwL/efpBvPJQrhKHPMtiJ5vffQL1rJ+BwQ3kE6/Xu9ye6z9esdVHwMmfAvq/oM86BnjNe7gHbQeAOUpK3i76hjAhG9B9Hzrq+jPnt8dUB9Fj9FWfYx3qdYf8z1ZT6WMf1w55lu9Xr/jfj/z5f4uur/Dpevjd1qNV/f07y+cNnUiYSUYl1690X7BudxLRs5+fM2bPr+qk6+G6vna3mb7qO19ubYzbT9k+7ns+Yys3bb/3vZ+23ttXziu4gcZNERYxhAiPV8gjZMF7zDTgToZ9q/1Hkkbl7TdJOnhiFgv6eHsvSRdIWl99piRdGfxMpdv2GAt60IMHPOtj7L/w8N/oKo3bLASwO1Tdq+0LsPPQw07R8Qjttctad4k6dLs9XZJ/1fSf8/avxKLqfOY7dNtr46Iw2UUPIzlBl6RYUR6vsB4LTdQX93zWyMNQXPJSKRU5L+Jq3KB+mNJq7LXayS9mJvvUNYGAABU0tnOWS93Wd1N2zO252zPlVEDgOrk9+eFhYWqywFqr0j4HrG9WpKy56NZ+7ykc3LznZ21vUFEzEbEdERMF6gBQA3k9+dOp1N1OUDtFQnfByVtzV5vlfTNXPsns7Oe3yPpZymP9wIAUHdDnXBle4cWT646y/YhSbdI+ltJ99u+VtILkq7JZt8l6UpJ+yW9JunTJdcMAECjDXu285Y+ky7vMW9IuqFIUUARZV8Eg6+LAdUp+yIYVV/TuauVX4ob5SpGAOppuV8b4kpXaIJWhq80fKASvM3X63dYVm91OZeuxPgMG6gEb/P16uWW1VtdzqUrx6214SsN/keSf0Tbrayrl6EeBgUrwdtuRQO4LsPNXa0OX2kxYPs90B79fp+jBmiqWxRieU6bOtH3gfbo1xsdNUDLvkVhGVofvpgcZQUwwQtUr6wArmPwShNwS0FA+k2gnixAGWYGmqEbqCcL0LoNMy9F+KJVet1/N2/UgKXXC6TX6/67eaMGbNW9XonwRQt1g7LM+zUDqEY3KMvoydYhdLs45ovWKhqcBC9QH0WDs07BK9HzRcuN0gsmdIF6GqUXXLfQ7SJ8MREIVKA96hqoy8GwMwAAiRG+AAAkRvgCAJAY4QsAQGKELwAAiQ0MX9vbbB+1vS/X9ne2n7G91/YDtk/P2tfZ/oXtPdnjS+MsHgCAJhqm53uPpI1L2h6S9K6I+H1JP5T02dy0AxExlT2uL6dMAADaY2D4RsQjkl5e0vatiDiWvX1M0tljqA0AgFYq45jvn0r6l9z7c21/z/Z3bF9SwucDANAqha5wZftmScckfTVrOixpbUS8ZPsiSd+wfUFEvNJj2RlJM0XWD6Ae8vvz2rVrK64GqL+Re762PyXpKkl/EtmFcyPi9Yh4KXu9W9IBSW/rtXxEzEbEdERMj1oDgHrI78+dTqfqcoDaGyl8bW+U9FeSPhIRr+XaO7ZXZK/Pk7Re0vNlFAoAQFsMHHa2vUPSpZLOsn1I0i1aPLv5FEkPZResfyw7s/n9kv7G9q8knZB0fUS83PODAQCYUAPDNyK29Gi+u8+8OyXtLFpU1SKCu+AALeEvX6W47p+rLgN4A65w1cdy7v8KoN785auqLgF4A8L3JAhgoD0IYNQJ4TsAAQy0BwGMuiB8h0AAA+1BAKMOCN8hEcBAexDAqBrhuwwEMNAeBDCqRPguEwEMtAcBjKoQviMggIH2IIBRBcJ3RAQw0B4EMFIjfAsggIH2IICREuFbEAEMtAcBjFQI3xIQwEB7EMBIgfAtCQEMtAcBjHEbeFejScQdjYD24I5GqCN6vgAAJEb4AgCQ2MDwtb3N9lHb+3Jtt9qet70ne1yZm/ZZ2/ttP2v7w+MqHACAphqm53uPpI092u+IiKnssUuSbL9T0mZJF2TLfNH2irKKBQCgDQaGb0Q8IunlIT9vk6R7I+L1iPiRpP2SNhSoDwCA1ilyzPdG23uzYekzsrY1kl7MzXMoawMAAJlRw/dOSedLmpJ0WNLty/0A2zO252zPjVgDgJrI788LCwtVlwPU3kjhGxFHIuJ4RJyQdJd+M7Q8L+mc3KxnZ229PmM2IqYjYnqUGgDUR35/7nQ6VZcD1N5I4Wt7de7t1ZK6Z0I/KGmz7VNsnytpvaTHi5UIAEC7DLzCle0dki6VdJbtQ5JukXSp7SlJIemgpOskKSKesn2/pKclHZN0Q0QcH0/pAAA008DwjYgtPZrvPsn8t0m6rUhRAAC0GVe4AgAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASGxg+NreZvuo7X25tvts78keB23vydrX2f5FbtqXxlk8AABNtHKIee6R9A+SvtJtiIg/7r62fbukn+XmPxARU2UVCABA2wwM34h4xPa6XtNsW9I1kj5QblkAALRX0WO+l0g6EhHP5drOtf0929+xfUnBzwcAoHWGGXY+mS2SduTeH5a0NiJesn2RpG/YviAiXlm6oO0ZSTMF1w+gBvL789q1ayuuBqi/kXu+tldK+pik+7ptEfF6RLyUvd4t6YCkt/VaPiJmI2I6IqZHrQFAPeT3506nU3U5QO0VGXb+I0nPRMShboPtju0V2evzJK2X9HyxEgEAaJdhvmq0Q9Kjkt5u+5Dta7NJm/XGIWdJer+kvdlXj/5J0vUR8XKZBQMA0HTDnO28pU/7p3q07ZS0s3hZAAC0F1e4AgAgMcIX
AIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASMwRUXUNsr0g6eeSflJ1LSU4S2xHnTRhO34vIlpzE1zbr0p6tuo6StCEv51hsB1pDbU/1yJ8Jcn2XERMV11HUWxHvbRlO5qkLT9ztqNe2rIdXQw7AwCQGOELAEBidQrf2aoLKAnbUS9t2Y4macvPnO2ol7Zsh6QaHfMFAGBS1KnnCwDARCB8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEhtb+NreaPurKeZ9AAART0lEQVRZ2/tt3zSu9QAA0DSOiPI/1F4h6YeSPijpkKQnJG2JiKdLXxkAAA0zrp7vBkn7I+L5iPilpHslbRrTugAAaJSVY/rcNZJezL0/JOkP8zPYnpE0k729aEx1AE3wk4joVF1EEfn9+dRTT73oHe94R8UVAdXYvXv3UPvzuMJ3oIiYlTQrSbbLH/sGmuOFqgsoKr8/T09Px9zcXMUVAdWwPdT+PK5h53lJ5+Ten521AQAw8cYVvk9IWm/7XNtvkbRZ0oNjWhcAAI0ylmHniDhm+0ZJ/yZphaRtEfHUONYFAEDTjO2Yb0TskrRrXJ8PAEBTcYUrAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAIDHCFwCAxEYOX9vn2P627adtP2X7z7L2W23P296TPa4sr1wAAJpvZYFlj0n6y4h40vZpknbbfiibdkdEfL54ecB4RUTfabYTVgKgqJmd832nzX58TcJKBhs5fCPisKTD2etXbf9AUr22DujjZKG7dB5CGKi3k4Xu0nnqEsJFer6/ZnudpHdL+k9J75V0o+1PSprTYu/4pz2WmZE0U8b6geVYGry9wjU/T0QQwAPk9+e1a9dWXA0mydLg7RWu+Xlmds7XIoALn3Bl+3ck7ZT0mYh4RdKdks6XNKXFnvHtvZaLiNmImI6I6aI1AKPqF6qE7fLk9+dOp1N1OZhQ/UK1DmG7VKHwtf3bWgzer0bE1yUpIo5ExPGIOCHpLkkbipcJlCPfox0UsPnpwwxTA0gr36MdFLD56cMMU49bkbOdLeluST+IiC/k2lfnZrta0r7RywPGY9ieLT1goP6G7dnWqQdc5JjveyV9QtL3be/J2v5a0hbbU5JC0kFJ1xWqEACAlilytvO/S+rVLdg1ejkAALQfV7gCACAxwhcAgMQIX0ykYc9e5ixnoP6GPXu5Dmc5dxG+mCjL+frQcr6WBCC95Xx9aDlfS0qB8MVE6xfA9HiB5ukXwHXq8XaVcnlJoElsv+nykYPmB1BPsx9f86bLRw6avw4IX0ykbqByVyOg+bqBOhF3NQLagIAF2qNuAXsyHPMFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIrPBFNmwflPSqpOOSjkXEtO0zJd0naZ2kg5KuiYifFl0XAABtUFbP97KImIqI6ez9TZIejoj1kh7O3gMAAI1v2HmTpO3Z6+2SPjqm9QAA0DhlhG9I+pbt3bZnsrZVEXE4e/1jSauWLmR7xvac7bkSagBQofz+vLCwUHU5QO2VcWOF90XEvO3flfSQ7WfyEyMibL/p1jERMStpVpJ6TW+7iPj1re2GeQbqLL8/T09PT9T+fPH5ty17mUcP3DyGStAkhXu+ETGfPR+V9ICkDZKO2F4tSdnz0aLraZtuoA77DABoj0Lha/tU26d1X0v6kKR9kh6UtDWbbaukbxZZTxt17yM77DMAoD2KDjuvkvRA1jtbKelrEfGvtp+QdL/tayW9IOmagutpHXq+ADC5CoVvRDwv6Q96tL8k6fIin912HPMFgMnFFa4qQs8XACYX4VsRjvkCwOQifCtCzxcAJhfhWxF6vgAwuQjfitDzBYDJRfhWhJ4vAEwuwrci9HwBYHIRvhWh5wsAk6uMGytgBPR8F3ERETQdN0n4DX/5KsV1/1x1GY1AzxeVo3cPtIe/fFXVJTQC4YtaIICB9iCAByN8URsEMNAeBPDJEb6oFQIYaA8CuD/CF7VDAAPtQQD3RviilghgoD0I4DcjfFFbBDDQHgTwG40cvrbfbntP7vGK7c/YvtX2fK79yjILxmQhgIH2IIB/Y+TwjYhnI2IqIqYkXSTpNUkPZJPv6E6LiF1lFIrJRQAD7UEALypr2PlySQci4oWSPg94AwIYaA8CuLzw3SxpR+79jbb32t5m+4xeC9iesT1ne66kGtByBHB95ffnhYWFqstBA0x6ABcOX9tvkfQRSf+YNd0p6XxJU5IOS7q913IRMRsR0xExXbQGTA4CuJ7y+3On06m6HDTEJAdwGT3fKyQ9GRFHJCkijkTE8Yg4IekuSRtKWAfwawQw0B6TGsBlhO8W5Yacba/OTbta0r4S1gG8AQEMtMckBnCh8LV9qqQPSvp6rvlztr9ve6+kyyT9eZF1AP0QwEB7TFoAF7qfb0T8XNJbl7R9olBFwDJwP2CgPSbpfsCFwhcoiuAE2mNSgrMMXF4SAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASIzwBQAgMcIXAIDECF8AABIjfAEASGyo+/na3ibpKklHI+JdWduZku6TtE7SQUnXRMRPvXiD1v8p6UpJr0n6VEQ8WX7paUXEspfhXrVAPT3+3suWvcyG//j2GCrBpBq253uPpI1L2m6S9HBErJf0cPZekq6QtD57zEi6s3iZAAC0x1DhGxGPSHp5SfMmSduz19slfTTX/pVY9Jik022vLqNYAADaoMgx31URcTh7/WNJq7LXayS9mJvvUNb2BrZnbM/ZnitQA4AayO/PCwsLVZcD1F4pJ1zF4gHRZR0UjYjZiJiOiOkyagBQnfz+3Ol0qi4HqL0i4XukO5ycPR/N2uclnZOb7+ysDQAAqFj4Pihpa/Z6q6Rv5to/6UXvkfSz3PA0AAATb9ivGu2QdKmks2wfknSLpL+VdL/tayW9IOmabPZdWvya0X4tftXo0yXXDABAow0VvhGxpc+ky3vMG5JuKFIUAABtxhWuAABIjPAFACAxwhcAgMQIXwAAEhvqhCtwkwSgTbhJAqpGzxcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIjPAFACAxwhcAgMQIXwAAEiN8AQBIbGD42t5m+6jtfbm2v7P9jO29th+wfXrWvs72L2zvyR5fGmfxAAA00TA933skbVzS9pCkd0XE70v6oaTP5qYdiIip7HF9OWUCANAeA8M3Ih6R9PKStm9FxLHs7WOSzh5
DbQAAtFIZx3z/VNK/5N6fa/t7tr9j+5J+C9mesT1ne66EGgBUKL8/LywsVF0OUHuFwtf2zZKOSfpq1nRY0tqIeLekv5D0Ndv/pdeyETEbEdMRMV2kBgDVy+/PnU6n6nKA2hs5fG1/StJVkv4kIkKSIuL1iHgpe71b0gFJbyuhTgAAWmOk8LW9UdJfSfpIRLyWa+/YXpG9Pk/SeknPl1EoAABtsXLQDLZ3SLpU0lm2D0m6RYtnN58i6SHbkvRYdmbz+yX9je1fSToh6fqIeLnnBwMAMKEGhm9EbOnRfHefeXdK2lm0KAAA2owrXAEAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACQ2MHxtb7N91Pa+XNuttudt78keV+amfdb2ftvP2v7wuAoHAKCphun53iNpY4/2OyJiKnvskiTb75S0WdIF2TJftL2irGIBAGiDlYNmiIhHbK8b8vM2Sbo3Il6X9CPb+yVtkPToyBWOQUT0nWY7YSUAinp1T/8+xGlTJxJWAgyvyDHfG23vzYalz8ja1kh6MTfPoaztTWzP2J6zPVeghmWJiJMG77DzAHij/P68sLCQZJ2v7vmtkwbvsPMAVRj1r/JOSedLmpJ0WNLty/2AiJiNiOmImB6xhuWub6zzA5Msvz93Op2xr2+5gUoAo25G+ouMiCMRcTwiTki6S4tDy5I0L+mc3KxnZ20AACAzUvjaXp17e7Wk7pnQD0rabPsU2+dKWi/p8WIlFjdqL5beL1A/o/Zi6f2iTgaecGV7h6RLJZ1l+5CkWyRdantKUkg6KOk6SYqIp2zfL+lpScck3RARx8dTOgAAzTTM2c5bejTffZL5b5N0W5GiAABoM8ZhAABIjPAFACAxwhcAgMQIXwAAEpuI8B31kpFcanJ0XCkM4zLqJSO51OTotn/3Ym3/7sVVl9EqExG+AADUycSE73J7sfR6gfpabi+WXi/qZuD3fNukG6jc1Qhovm6gclcjNNFEhW8XAQu0BwGLJpqYYWcAAOqC8AUAIDHCFwCAxCbymC+KG/Y7vIPm4/g7UL1hv8M7aL6tlzxaRjkTgZ4vAACJ0fPFSAb1WLs9Xnq2QP0N6rF2e7z0bMtDzxcAgMQGhq/tbbaP2t6Xa7vP9p7scdD2nqx9ne1f5KZ9aZzFAwDQRMMMO98j6R8kfaXbEBF/3H1t+3ZJP8vNfyAipsoqEACAthkYvhHxiO11vaZ58YDeNZI+UG5ZAAC0V9FjvpdIOhIRz+XazrX9PdvfsX1JvwVtz9iesz1XsAYAFcvvzwsLC1WXA9Re0fDdImlH7v1hSWsj4t2S/kLS12z/l14LRsRsRExHxHTBGgBULL8/dzqdqssBam/k8LW9UtLHJN3XbYuI1yPipez1bkkHJL2taJEAALRJke/5/pGkZyLiULfBdkfSyxFx3PZ5ktZLer5gjWggvt8LtAff7y3fMF812iHpUUlvt33I9rXZpM1645CzJL1f0t7sq0f/JOn6iHi5zIIBAGi6Yc523tKn/VM92nZK2lm8LAAA2osrXAEAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJEb4AACRG+AIAkBjhCwBAYoQvAACJOSKqrkG2FyT9XNJPqq6lBGeJ7aiTJmzH70VEa26Ca/tVSc9WXUcJmvC3Mwy2I62h9udahK8k2Z6LiOmq6yiK7aiXtmxHk7TlZ8521EtbtqOLYWcAABIjfAEASKxO4TtbdQElYTvqpS3b0SRt+ZmzHfXSlu2QVKNjvgAATIo69XwBAJgIlYev7Y22n7W93/ZNVdezHLYP2v6+7T2257K2M20/ZPu57PmMqutcyvY220dt78u19azbi/4++/3stX1hdZW/UZ/tuNX2fPY72WP7yty0z2bb8aztD1dTdbuxP6fH/tzM/bnS8LW9QtL/knSFpHdK2mL7nVXWNILLImIqdwr8TZIejoj1kh7O3tfNPZI2LmnrV/cVktZnjxlJdyaqcRj36M3bIUl3ZL+TqYjYJUnZ39VmSRdky3wx+/tDSdifK3OP2J8btz9X3fPdIGl/RDwfEb+UdK+kTRXXVNQmSduz19slfbTCWnqKiEckvbykuV/dmyR9JRY9Jul026vTVHpyfbajn02S7o2I1yPiR5L2a/HvD+Vhf64A+3Mz9+eqw3eNpBdz7w9lbU0Rkr5le7ftmaxtVUQczl7/WNKqakpbtn51N/F3dGM2pLYtN0zYxO1omqb/jNmf66mV+3PV4dt074uIC7U4lHOD7ffnJ8biqeSNO528qXVn7pR0vqQpSYcl3V5tOWgQ9uf6ae3+XHX4zks6J/f+7KytESJiPns+KukBLQ57HOkO42TPR6urcFn61d2o31FEHImI4xFxQtJd+s1QVKO2o6Ea/TNmf66fNu/PVYfvE5LW2z7X9lu0eAD9wYprGortU22f1n0t6UOS9mmx/q3ZbFslfbOaCpetX90PSvpkdpbkeyT9LDecVTtLjl9drcXfibS4HZttn2L7XC2ecPJ46vpajv25Ptif6y4iKn1IulLSDyUdkHRz1fUso+7zJP2/7PFUt3ZJb9Xi2YXPSfo/ks6sutYete/Q4hDOr7R4rOTafnVLshbPYD0g6fuSpquuf8B2/O+szr1a3EFX5+a/OduOZyVdUXX9bXywP1dSO/tzA/dnrnAFAEBiVQ87AwAwcQhfAAASI3wBAEiM8AUAIDHCFwCAxAhfAAASI3wBAEiM8AUAILH/D3LKQakHH5b0AAAAAElFTkSuQmCC\n", 21 | "text/plain": [ 22 | "
" 23 | ] 24 | }, 25 | "metadata": {}, 26 | "output_type": "display_data" 27 | } 28 | ], 29 | "source": [ 30 | "%matplotlib inline\n", 31 | "%load_ext autoreload\n", 32 | "%autoreload 2\n", 33 | "\n", 34 | "import os,sys\n", 35 | "import pandas as pd\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "import numpy as np\n", 38 | "import helper\n", 39 | "import simulation\n", 40 | "\n", 41 | "# Generate some random images\n", 42 | "input_images, target_masks = simulation.generate_random_data(192, 192, count=3)\n", 43 | "\n", 44 | "for x in [input_images, target_masks]:\n", 45 | " print(x.shape)\n", 46 | " print(x.min(), x.max())\n", 47 | "\n", 48 | "# Change channel-order and make 3 channels for matplot\n", 49 | "input_images_rgb = [x.astype(np.uint8) for x in input_images]\n", 50 | "\n", 51 | "# Map each channel (i.e. class) to each color\n", 52 | "target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]\n", 53 | "\n", 54 | "# Left: Input image, Right: Target mask (Ground-truth)\n", 55 | "helper.plot_side_by_side([input_images_rgb, target_masks_rgb])" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "{'train': 2000, 'val': 200}" 67 | ] 68 | }, 69 | "execution_count": 2, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 75 | "from torch.utils.data import Dataset, DataLoader\n", 76 | "from torchvision import transforms, datasets, models\n", 77 | "\n", 78 | "class SimDataset(Dataset):\n", 79 | " def __init__(self, count, transform=None):\n", 80 | " self.input_images, self.target_masks = simulation.generate_random_data(192, 192, count=count) \n", 81 | " self.transform = transform\n", 82 | " \n", 83 | " def __len__(self):\n", 84 | " return len(self.input_images)\n", 85 | " \n", 86 | " def __getitem__(self, idx): \n", 87 | " image = self.input_images[idx]\n", 88 | " mask = self.target_masks[idx]\n", 89 | " if self.transform:\n", 90 | " image = self.transform(image)\n", 91 | " \n", 92 | " return [image, mask]\n", 93 | "\n", 94 | "# use same transform for train/val for this example\n", 95 | "trans = transforms.Compose([\n", 96 | " transforms.ToTensor(),\n", 97 | "])\n", 98 | "\n", 99 | "train_set = SimDataset(2000, transform = trans)\n", 100 | "val_set = SimDataset(200, transform = trans)\n", 101 | "\n", 102 | "image_datasets = {\n", 103 | " 'train': train_set, 'val': val_set\n", 104 | "}\n", 105 | "\n", 106 | "batch_size = 25\n", 107 | "\n", 108 | "dataloaders = {\n", 109 | " 'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),\n", 110 | " 'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)\n", 111 | "}\n", 112 | "\n", 113 | "dataset_sizes = {\n", 114 | " x: len(image_datasets[x]) for x in image_datasets.keys()\n", 115 | "}\n", 116 | "\n", 117 | "dataset_sizes" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 3, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192])\n", 130 | "0.0 1.0 0.02312283 0.1502936\n", 131 | "0.0 1.0 0.004655129 0.06806962\n" 132 | ] 133 | }, 134 | { 135 | "data": { 136 | "text/plain": [ 137 | "" 138 | ] 139 | }, 140 | "execution_count": 3, 141 | "metadata": {}, 142 | "output_type": "execute_result" 143 | }, 144 | { 145 | "data": { 146 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQUAAAD8CAYAAAB+fLH0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADn9JREFUeJzt3X+sZGV9x/H3p1j9w5oA1W4IYEGzmohptkrQpGqwrYqkcaV/0CVNpWq6mEDSP5o0aJNq2jRpWqmJqT+ypgRMFKStCDFUpaTRf0plqQQFRRaEsJt1qdioraa68O0fc67Oc7337tyZM7/fr2QyZ545M+c5e3c+93nOmXu+qSokacMvzLsDkhaLoSCpYShIahgKkhqGgqSGoSCpMbVQSHJJkoeSHEly7bS2I6lfmcb3FJKcBnwTeANwFLgHuKKqHux9Y5J6Na2RwkXAkap6tKp+DNwM7J/StiT16FlTet+zgSeGHh8FXrXdykn8WqU0fd+pqhecaqVphcIpJTkIHJzX9qU19PgoK00rFI4B5w49Pqdr+6mqOgQcAkcK0iKZ1jGFe4C9Sc5P8mzgAHD7lLYlqUdTGSlU1ckk1wCfB04Drq+qB6axLUn9msopyV13wumDNAv3VtWFp1rJbzRKahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpMXYoJDk3yb8leTDJA0n+uGt/X5JjSe7rbpf2111J0zbJJd5PAn9SVf+Z5HnAvUnu7J77QFW9f/LuSZq1sUOhqo4Dx7vlHyT5OoMakpKWWC/HFJKcB/w68B9d0zVJ7k9yfZIztnnNwSSHkxzuow+S+jFxMZgkvwR8Efirqvp0kj3Ad4AC/hI4q6recYr3sBiMNH3TLwaT5BeBfwY+UVWfBqiqE1X1dFU9A3wMuGiSbUiarUnOPgT4B+DrVfV3Q+1nDa12GfC18bsnadYmOfvwG8AfAF9Ncl/X9h7giiT7GEwfHgOumqiHkmbKArPS+rDArKTdMxQkNQwFSQ1DQVLDUBiyCAddpXkzFDYxGLTuDIUtGAxaZ4bCNqrKcNBaMhROwWDQujEURmAwaJ0YCiMyGLQuDAVJDUNhFxwtaB0YCrtkMGjVGQpjMBi0ygyFMfk9Bq0qQ2FCBoNWjaHQA4NBq8RQ6InBoFUxyYVbAUjyGPAD4GngZFVdmORM4FPAeQwu3np5Vf33pNuatsEFqqX11tdI4fVVtW/oopDXAndV1V7gru6xpCUwrenDfuDGbvlG4K1T2o6knvURCgV8Icm9SQ52bXu6ArQA3wb29LAdSTMw8TEF4DVVdSzJrwB3JvnG8JNVVVvVdegC5ODmdknzNfFIoaqOdfdPArcyqB15YqN8XHf/5BavO1RVF45SnELS7ExaYPa5SZ63sQy8kUHtyNuBK7vVrgRum2Q7kmZn0unDHuDW7lTes4BPVtXnktwD3JLkncDjwOUTbkfSjFhLUlof1pKUtHuGgqSGoSCpYShIahgKkhqGgqSGoSCp0cffPmgJjPp9FK8pIUcKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGX3NeE359WaMaOxSSvJRBvcgNLwL+HDgd+CPgv7r291TVHWP3UNJM9XLh1iSnAceAVwFvB/6nqt6/i9d74VZp+mZ64dbfAh6pqsd7ej9Jc9JXKBwAbhp6fE2S+5Ncn+SMnrahORgeSW4sb3W/CKUC1I+JQyHJs4G3AP/YNX0EeDGwDzgOXLfN6w4mOZzk8KR90PQMH6DcWN7q3gOZq2PiYwpJ9gNXV9Ubt3juPOCzVfXyU7yHv2YWVFX99AO/sbzVPXiGYwnM7JjCFQxNHTYKy3YuY1BbUtKSmOh7Cl1R2TcAVw01/02SfUABj216Tktm1OmDVoe1JLUjpw8rZaTpg99o1I4cKawf//ZBO/KU5PpxpKAdOVJYP44UtCNHCuvHkYJ25Ehh/ThSkNQwFLQjpw/rx+mDduT0Yf04UpDUMBQkNQwFSQ1DQVLDUJDUMBQkNQwFSQ1DQVLDUJDUMBQkNQwFSQ1DQVJjpFDoKj09meRrQ21nJrkzycPd/Rlde5J8MMmRrkrUK6bVeUn9G3WkcANwyaa2a4G7qmovcFf3GODNwN7udpBBxShJS2KkUKiqLwHf3dS8H7ixW74ReOtQ+8dr4G7g9E0FYiQtsEmOKeypquPd8reBPd3y2cATQ+sd7doa1pKUFlMvF1mpqtptQZeqOgQcAovBSItkkpHCiY1pQXf/ZNd+DDh3aL1zujZJS2CSULgduLJbvhK4baj9bd1ZiFcD3xuaZkhadBsX3dzpxqCq9HHgJwyOEbwT+GUGZx0eBv4VOLNbN8CHgEeArwIXjvD+5c2bt6nfDo/yebfArLQ+Riow6zcaJTUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FS45ShsE0dyb9N8o2uVuStSU7v2s9L8qMk93W3j06z85L6N8pI4QZ+vo7kncDLq+rXgG8C7x567pGq2tfd3tVPNyXNyilDYas6klX1hao62T28m0HBF0kroI9jCu8A/mXo8flJvpLki0leu92LrCUpLaaJakkm+TPgJPCJruk48MKqeirJK4HPJLmgqr6/+bXWkpQW09gjhSR/CPwO8Pu1Ueap6v+q6qlu+V4GVaJe0kM/Jc3IWKGQ5BLgT4G3VNUPh9pfkOS0bvlFwF7g0T46Kmk2Tjl9SHITcDHw/CRHgfcyONvwHODOJAB3d2caXgf8RZKfAM8A76qq7275xpIWkrUkpfVhLUlJu2coSGpMdEpyUY0zJeqOjUhrb6VCYZLjIxuvXfZwWJX90PyszPShrwOmi3DgVZqnpR8pTONDPPye/sbVulnqUBglEE71oT7Ve1SVwaC1sjLTB0n9WOqRwnZ285t9Y12PJUgDSzlSqKptP8TjDvV3ep2BoXWydKGw0wd00rm/wSAtYShsp6+DgR5U1LpbqlDoe8qwne3ez9GC1sFKHmhcRbsNpFHXd2SkzZZqpLCVaf2nTuIHRmvJkcIuzes38Kjv598+aFJLP1KQ1C9DYUo8KKllZShM0U5fspIW1bi1JN+X5NhQzchLh557d5IjSR5K8qZpdXyZGAxaJuPWkgT4wFDNyDsAkrwMOABc0L3mwxuXfF93BoOWxVi1JHewH7i5KwrzLeAIcNEE/ZM0Y5McU7imK0V/fZIzurazgSeG1jnatU3NtH4DezxA62rcUPgI8GJgH4P6kdft9g0sMCstprFCoapOVNXTVfUM8DF+NkU4Bpw7tOo5XdtW73Goqi4cpTiFpNkZt5bkWUMPLwM2zkzcDhxI8pwk5zOoJfnlybrYbHfL9r6H+U4btM7GrSV5cZJ9QAGPAVcBVNUDSW4BHmRQov7qqnp6Ol1v9XUtxWUPBL/erEktZS
3JaV1oZdr/Fn5gNWerW0typ79gHPeDvQjhKC2Clfwryd3UbTAMpNZSh0KSkeo2THsbo76PtAyWOhRgOpdo3/wBnjQYDAQtk6U8piBpelYmFKZ9Nedx3t9LumkZrUwowGQfwlFeO07lKWnZLP0xha1M8wPph12rbqVGCpImZyhIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhrj1pL81FAdyceS3Ne1n5fkR0PPfXSanZfUv1H+IOoG4O+Bj280VNXvbSwnuQ743tD6j1TVvr46KGm2ThkKVfWlJOdt9VwGfzJ4OfCb/XZL0rxMekzhtcCJqnp4qO38JF9J8sUkr53w/SXN2KTXU7gCuGno8XHghVX1VJJXAp9JckFVfX/zC5McBA5OuH1JPRt7pJDkWcDvAp/aaOtK0D/VLd8LPAK8ZKvXW0tSWkyTTB9+G/hGVR3daEjygiSndcsvYlBL8tHJuihplkY5JXkT8O/AS5McTfLO7qkDtFMHgNcB93enKP8JeFdVfbfPDkuarqWsJSlpLKtbS1LS9BgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqGgqSGoSCpYShIahgKkhqTXo6tL98B/re7X2XPZ7X3cdX3D5Z7H391lJUW4noKAEkOr/ql2VZ9H1d9/2A99tHpg6SGoSCpsUihcGjeHZiBVd/HVd8/WIN9XJhjCpIWwyKNFCQtgLmHQpJLkjyU5EiSa+fdn7501bi/2lXfPty1nZnkziQPd/dnzLufu7FNBfIt9ykDH+x+rvcnecX8ej6abfbvfUmODVVSv3TouXd3+/dQkjfNp9f9m2sodIVjPgS8GXgZcEWSl82zTz17fVXtGzqFdS1wV1XtBe7qHi+TG4BLNrVtt09vZlAMaC+D8oAfmVEfJ3EDP79/AB/ofo77quoOgO7/6QHggu41H94ohLTs5j1SuAg4UlWPVtWPgZuB/XPu0zTtB27slm8E3jrHvuxaVX0J2FzcZ7t92g98vAbuBk5PctZsejqebfZvO/uBm7tSid8CjjD4/7z05h0KZwNPDD0+2rWtggK+kOTerpguwJ6qOt4tfxvYM5+u9Wq7fVqln+013RTo+qEp3yrtX2PeobDKXlNVr2AwjL46yeuGn6zBaZ+VOvWzivvEYNrzYmAfg6rq1823O9M371A4Bpw79Picrm3pVdWx7v5J4FYGQ8sTG0Po7v7J+fWwN9vt00r8bKvqRFU9XVXPAB/jZ1OEldi/rcw7FO4B9iY5P8mzGRy4uX3OfZpYkucmed7GMvBG4GsM9u3KbrUrgdvm08NebbdPtwNv685CvBr43tA0Y2lsOg5yGYOfIwz270CS5yQ5n8EB1S/Pun/TMNe/kqyqk0muAT4PnAZcX1UPzLNPPdkD3JoEBv/Gn6yqzyW5B7ilq9z9OHD5HPu4a10F8ouB5yc5CrwX+Gu23qc7gEsZHID7IfD2mXd4l7bZv4uT7GMwLXoMuAqgqh5IcgvwIHASuLqqnp5Hv/vmNxolNeY9fZC0YAwFSQ1DQVLDUJDUMBQkNQwFSQ1DQVLDUJDU+H/1Y/rQJgrTIwAAAABJRU5ErkJggg==\n", 147 | "text/plain": [ 148 | "
" 149 | ] 150 | }, 151 | "metadata": {}, 152 | "output_type": "display_data" 153 | } 154 | ], 155 | "source": [ 156 | "import torchvision.utils\n", 157 | "\n", 158 | "def reverse_transform(inp):\n", 159 | " inp = inp.numpy().transpose((1, 2, 0))\n", 160 | " inp = np.clip(inp, 0, 1)\n", 161 | " inp = (inp * 255).astype(np.uint8)\n", 162 | " \n", 163 | " return inp\n", 164 | "\n", 165 | "# Get a batch of training data\n", 166 | "inputs, masks = next(iter(dataloaders['train']))\n", 167 | "\n", 168 | "print(inputs.shape, masks.shape)\n", 169 | "for x in [inputs.numpy(), masks.numpy()]:\n", 170 | " print(x.min(), x.max(), x.mean(), x.std())\n", 171 | "\n", 172 | "plt.imshow(reverse_transform(inputs[3]))" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 4, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "----------------------------------------------------------------\n", 185 | " Layer (type) Output Shape Param #\n", 186 | "================================================================\n", 187 | " Conv2d-1 [-1, 64, 224, 224] 1,792\n", 188 | " ReLU-2 [-1, 64, 224, 224] 0\n", 189 | " Conv2d-3 [-1, 64, 224, 224] 36,928\n", 190 | " ReLU-4 [-1, 64, 224, 224] 0\n", 191 | " MaxPool2d-5 [-1, 64, 112, 112] 0\n", 192 | " Conv2d-6 [-1, 128, 112, 112] 73,856\n", 193 | " ReLU-7 [-1, 128, 112, 112] 0\n", 194 | " Conv2d-8 [-1, 128, 112, 112] 147,584\n", 195 | " ReLU-9 [-1, 128, 112, 112] 0\n", 196 | " MaxPool2d-10 [-1, 128, 56, 56] 0\n", 197 | " Conv2d-11 [-1, 256, 56, 56] 295,168\n", 198 | " ReLU-12 [-1, 256, 56, 56] 0\n", 199 | " Conv2d-13 [-1, 256, 56, 56] 590,080\n", 200 | " ReLU-14 [-1, 256, 56, 56] 0\n", 201 | " MaxPool2d-15 [-1, 256, 28, 28] 0\n", 202 | " Conv2d-16 [-1, 512, 28, 28] 1,180,160\n", 203 | " ReLU-17 [-1, 512, 28, 28] 0\n", 204 | " Conv2d-18 [-1, 512, 28, 28] 2,359,808\n", 205 | " ReLU-19 [-1, 512, 28, 28] 0\n", 206 | " Upsample-20 [-1, 512, 56, 56] 0\n", 207 | " Conv2d-21 [-1, 256, 56, 56] 1,769,728\n", 208 | " ReLU-22 [-1, 256, 56, 56] 0\n", 209 | " Conv2d-23 [-1, 256, 56, 56] 590,080\n", 210 | " ReLU-24 [-1, 256, 56, 56] 0\n", 211 | " Upsample-25 [-1, 256, 112, 112] 0\n", 212 | " Conv2d-26 [-1, 128, 112, 112] 442,496\n", 213 | " ReLU-27 [-1, 128, 112, 112] 0\n", 214 | " Conv2d-28 [-1, 128, 112, 112] 147,584\n", 215 | " ReLU-29 [-1, 128, 112, 112] 0\n", 216 | " Upsample-30 [-1, 128, 224, 224] 0\n", 217 | " Conv2d-31 [-1, 64, 224, 224] 110,656\n", 218 | " ReLU-32 [-1, 64, 224, 224] 0\n", 219 | " Conv2d-33 [-1, 64, 224, 224] 36,928\n", 220 | " ReLU-34 [-1, 64, 224, 224] 0\n", 221 | " Conv2d-35 [-1, 6, 224, 224] 390\n", 222 | "================================================================\n", 223 | "Total params: 7,783,238\n", 224 | "Trainable params: 7,783,238\n", 225 | "Non-trainable params: 0\n", 226 | "----------------------------------------------------------------\n" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "from torchsummary import summary\n", 232 | "import torch\n", 233 | "import torch.nn as nn\n", 234 | "import pytorch_unet\n", 235 | "\n", 236 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 237 | "\n", 238 | "model = pytorch_unet.UNet(6)\n", 239 | "model = model.to(device)\n", 240 | "\n", 241 | "summary(model, input_size=(3, 224, 224))" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 5, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "from collections import defaultdict\n", 251 | "import 
torch.nn.functional as F\n", 252 | "from loss import dice_loss\n", 253 | "\n", 254 | "def calc_loss(pred, target, metrics, bce_weight=0.5):\n", 255 | " bce = F.binary_cross_entropy_with_logits(pred, target)\n", 256 | " \n", 257 | " pred = F.sigmoid(pred)\n", 258 | " dice = dice_loss(pred, target)\n", 259 | " \n", 260 | " loss = bce * bce_weight + dice * (1 - bce_weight)\n", 261 | " \n", 262 | " metrics['bce'] += bce.data.cpu().numpy() * target.size(0)\n", 263 | " metrics['dice'] += dice.data.cpu().numpy() * target.size(0)\n", 264 | " metrics['loss'] += loss.data.cpu().numpy() * target.size(0)\n", 265 | " \n", 266 | " return loss\n", 267 | "\n", 268 | "def print_metrics(metrics, epoch_samples, phase): \n", 269 | " outputs = []\n", 270 | " for k in metrics.keys():\n", 271 | " outputs.append(\"{}: {:4f}\".format(k, metrics[k] / epoch_samples))\n", 272 | " \n", 273 | " print(\"{}: {}\".format(phase, \", \".join(outputs))) \n", 274 | "\n", 275 | "def train_model(model, optimizer, scheduler, num_epochs=25):\n", 276 | " best_model_wts = copy.deepcopy(model.state_dict())\n", 277 | " best_loss = 1e10\n", 278 | "\n", 279 | " for epoch in range(num_epochs):\n", 280 | " print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n", 281 | " print('-' * 10)\n", 282 | " \n", 283 | " since = time.time()\n", 284 | "\n", 285 | " # Each epoch has a training and validation phase\n", 286 | " for phase in ['train', 'val']:\n", 287 | " if phase == 'train':\n", 288 | " scheduler.step()\n", 289 | " for param_group in optimizer.param_groups:\n", 290 | " print(\"LR\", param_group['lr'])\n", 291 | " \n", 292 | " model.train() # Set model to training mode\n", 293 | " else:\n", 294 | " model.eval() # Set model to evaluate mode\n", 295 | "\n", 296 | " metrics = defaultdict(float)\n", 297 | " epoch_samples = 0\n", 298 | " \n", 299 | " for inputs, labels in dataloaders[phase]:\n", 300 | " inputs = inputs.to(device)\n", 301 | " labels = labels.to(device) \n", 302 | "\n", 303 | " # zero the parameter gradients\n", 304 | " optimizer.zero_grad()\n", 305 | "\n", 306 | " # forward\n", 307 | " # track history if only in train\n", 308 | " with torch.set_grad_enabled(phase == 'train'):\n", 309 | " outputs = model(inputs)\n", 310 | " loss = calc_loss(outputs, labels, metrics)\n", 311 | "\n", 312 | " # backward + optimize only if in training phase\n", 313 | " if phase == 'train':\n", 314 | " loss.backward()\n", 315 | " optimizer.step()\n", 316 | "\n", 317 | " # statistics\n", 318 | " epoch_samples += inputs.size(0)\n", 319 | "\n", 320 | " print_metrics(metrics, epoch_samples, phase)\n", 321 | " epoch_loss = metrics['loss'] / epoch_samples\n", 322 | "\n", 323 | " # deep copy the model\n", 324 | " if phase == 'val' and epoch_loss < best_loss:\n", 325 | " print(\"saving best model\")\n", 326 | " best_loss = epoch_loss\n", 327 | " best_model_wts = copy.deepcopy(model.state_dict())\n", 328 | "\n", 329 | " time_elapsed = time.time() - since\n", 330 | " print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n", 331 | " print('Best val loss: {:4f}'.format(best_loss))\n", 332 | "\n", 333 | " # load best model weights\n", 334 | " model.load_state_dict(best_model_wts)\n", 335 | " return model" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 6, 341 | "metadata": { 342 | "scrolled": true 343 | }, 344 | "outputs": [ 345 | { 346 | "name": "stdout", 347 | "output_type": "stream", 348 | "text": [ 349 | "cuda:0\n", 350 | "Epoch 0/39\n", 351 | "----------\n", 352 | "LR 0.0001\n", 353 | "train: bce: 0.210124, dice: 
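# calc_loss above mixes BCE-with-logits and dice_loss imported from the repository's loss.py,
# which is not reproduced in this notebook. A typical soft dice loss with the same call
# signature is sketched below; the smoothing constant and the reduction are assumptions and
# may differ from what loss.py actually implements. (Also note that F.sigmoid is deprecated
# in newer PyTorch in favor of torch.sigmoid; both compute the same values.)
def soft_dice_loss(pred, target, smooth=1.0):
    # pred and target: (N, C, H, W); pred is expected to already hold sigmoid probabilities
    intersection = (pred * target).sum(dim=(2, 3))
    denom = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    return (1 - (2.0 * intersection + smooth) / (denom + smooth)).mean()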
0.994346, loss: 0.602235\n", 354 | "val: bce: 0.030143, dice: 0.986439, loss: 0.508291\n", 355 | "saving best model\n", 356 | "0m 43s\n", 357 | "Epoch 1/39\n", 358 | "----------\n", 359 | "LR 0.0001\n", 360 | "train: bce: 0.022030, dice: 0.806168, loss: 0.414099\n", 361 | "val: bce: 0.023499, dice: 0.671528, loss: 0.347514\n", 362 | "saving best model\n", 363 | "0m 43s\n", 364 | "Epoch 2/39\n", 365 | "----------\n", 366 | "LR 0.0001\n", 367 | "train: bce: 0.023134, dice: 0.522101, loss: 0.272618\n", 368 | "val: bce: 0.017994, dice: 0.439513, loss: 0.228753\n", 369 | "saving best model\n", 370 | "0m 43s\n", 371 | "Epoch 3/39\n", 372 | "----------\n", 373 | "LR 0.0001\n", 374 | "train: bce: 0.015791, dice: 0.392756, loss: 0.204273\n", 375 | "val: bce: 0.015154, dice: 0.353304, loss: 0.184229\n", 376 | "saving best model\n", 377 | "0m 43s\n", 378 | "Epoch 4/39\n", 379 | "----------\n", 380 | "LR 0.0001\n", 381 | "train: bce: 0.012854, dice: 0.299000, loss: 0.155927\n", 382 | "val: bce: 0.011838, dice: 0.235490, loss: 0.123664\n", 383 | "saving best model\n", 384 | "0m 43s\n", 385 | "Epoch 5/39\n", 386 | "----------\n", 387 | "LR 0.0001\n", 388 | "train: bce: 0.010764, dice: 0.217516, loss: 0.114140\n", 389 | "val: bce: 0.010928, dice: 0.202027, loss: 0.106478\n", 390 | "saving best model\n", 391 | "0m 43s\n", 392 | "Epoch 6/39\n", 393 | "----------\n", 394 | "LR 0.0001\n", 395 | "train: bce: 0.010902, dice: 0.222725, loss: 0.116813\n", 396 | "val: bce: 0.010661, dice: 0.192998, loss: 0.101830\n", 397 | "saving best model\n", 398 | "0m 43s\n", 399 | "Epoch 7/39\n", 400 | "----------\n", 401 | "LR 0.0001\n", 402 | "train: bce: 0.009604, dice: 0.184641, loss: 0.097122\n", 403 | "val: bce: 0.010067, dice: 0.181135, loss: 0.095601\n", 404 | "saving best model\n", 405 | "0m 43s\n", 406 | "Epoch 8/39\n", 407 | "----------\n", 408 | "LR 0.0001\n", 409 | "train: bce: 0.009128, dice: 0.176201, loss: 0.092664\n", 410 | "val: bce: 0.008653, dice: 0.176254, loss: 0.092453\n", 411 | "saving best model\n", 412 | "0m 43s\n", 413 | "Epoch 9/39\n", 414 | "----------\n", 415 | "LR 0.0001\n", 416 | "train: bce: 0.008457, dice: 0.170643, loss: 0.089550\n", 417 | "val: bce: 0.008299, dice: 0.171656, loss: 0.089977\n", 418 | "saving best model\n", 419 | "0m 43s\n", 420 | "Epoch 10/39\n", 421 | "----------\n", 422 | "LR 0.0001\n", 423 | "train: bce: 0.007046, dice: 0.151076, loss: 0.079061\n", 424 | "val: bce: 0.005749, dice: 0.138535, loss: 0.072142\n", 425 | "saving best model\n", 426 | "0m 43s\n", 427 | "Epoch 11/39\n", 428 | "----------\n", 429 | "LR 0.0001\n", 430 | "train: bce: 0.004789, dice: 0.094846, loss: 0.049817\n", 431 | "val: bce: 0.004794, dice: 0.082758, loss: 0.043776\n", 432 | "saving best model\n", 433 | "0m 43s\n", 434 | "Epoch 12/39\n", 435 | "----------\n", 436 | "LR 0.0001\n", 437 | "train: bce: 0.003822, dice: 0.066693, loss: 0.035258\n", 438 | "val: bce: 0.004868, dice: 0.075574, loss: 0.040221\n", 439 | "saving best model\n", 440 | "0m 43s\n", 441 | "Epoch 13/39\n", 442 | "----------\n", 443 | "LR 0.0001\n", 444 | "train: bce: 0.003647, dice: 0.065981, loss: 0.034814\n", 445 | "val: bce: 0.005102, dice: 0.078447, loss: 0.041774\n", 446 | "0m 43s\n", 447 | "Epoch 14/39\n", 448 | "----------\n", 449 | "LR 0.0001\n", 450 | "train: bce: 0.003680, dice: 0.068849, loss: 0.036265\n", 451 | "val: bce: 0.004177, dice: 0.066650, loss: 0.035413\n", 452 | "saving best model\n", 453 | "0m 43s\n", 454 | "Epoch 15/39\n", 455 | "----------\n", 456 | "LR 0.0001\n", 457 | "train: bce: 0.003029, dice: 
0.053153, loss: 0.028091\n", 458 | "val: bce: 0.003654, dice: 0.061158, loss: 0.032406\n", 459 | "saving best model\n", 460 | "0m 43s\n", 461 | "Epoch 16/39\n", 462 | "----------\n", 463 | "LR 0.0001\n", 464 | "train: bce: 0.002797, dice: 0.050167, loss: 0.026482\n", 465 | "val: bce: 0.003610, dice: 0.059508, loss: 0.031559\n", 466 | "saving best model\n", 467 | "0m 43s\n", 468 | "Epoch 17/39\n", 469 | "----------\n", 470 | "LR 0.0001\n", 471 | "train: bce: 0.002720, dice: 0.049958, loss: 0.026339\n", 472 | "val: bce: 0.003184, dice: 0.057431, loss: 0.030307\n", 473 | "saving best model\n", 474 | "0m 43s\n", 475 | "Epoch 18/39\n", 476 | "----------\n", 477 | "LR 0.0001\n", 478 | "train: bce: 0.002537, dice: 0.046737, loss: 0.024637\n", 479 | "val: bce: 0.003113, dice: 0.054996, loss: 0.029055\n", 480 | "saving best model\n", 481 | "0m 43s\n", 482 | "Epoch 19/39\n", 483 | "----------\n", 484 | "LR 0.0001\n", 485 | "train: bce: 0.002300, dice: 0.044468, loss: 0.023384\n", 486 | "val: bce: 0.002945, dice: 0.051255, loss: 0.027100\n", 487 | "saving best model\n", 488 | "0m 43s\n", 489 | "Epoch 20/39\n", 490 | "----------\n", 491 | "LR 0.0001\n", 492 | "train: bce: 0.002042, dice: 0.040555, loss: 0.021299\n", 493 | "val: bce: 0.002866, dice: 0.050504, loss: 0.026685\n", 494 | "saving best model\n", 495 | "0m 43s\n", 496 | "Epoch 21/39\n", 497 | "----------\n", 498 | "LR 0.0001\n", 499 | "train: bce: 0.001988, dice: 0.038980, loss: 0.020484\n", 500 | "val: bce: 0.002593, dice: 0.047394, loss: 0.024993\n", 501 | "saving best model\n", 502 | "0m 43s\n", 503 | "Epoch 22/39\n", 504 | "----------\n", 505 | "LR 0.0001\n", 506 | "train: bce: 0.001841, dice: 0.036638, loss: 0.019239\n", 507 | "val: bce: 0.002522, dice: 0.045939, loss: 0.024230\n", 508 | "saving best model\n", 509 | "0m 43s\n", 510 | "Epoch 23/39\n", 511 | "----------\n", 512 | "LR 0.0001\n", 513 | "train: bce: 0.001795, dice: 0.035693, loss: 0.018744\n", 514 | "val: bce: 0.002727, dice: 0.044743, loss: 0.023735\n", 515 | "saving best model\n", 516 | "0m 43s\n", 517 | "Epoch 24/39\n", 518 | "----------\n", 519 | "LR 0.0001\n", 520 | "train: bce: 0.001691, dice: 0.034025, loss: 0.017858\n", 521 | "val: bce: 0.002360, dice: 0.043020, loss: 0.022690\n", 522 | "saving best model\n", 523 | "0m 43s\n", 524 | "Epoch 25/39\n", 525 | "----------\n", 526 | "LR 1e-05\n", 527 | "train: bce: 0.001572, dice: 0.031303, loss: 0.016437\n", 528 | "val: bce: 0.002217, dice: 0.040832, loss: 0.021524\n", 529 | "saving best model\n", 530 | "0m 43s\n", 531 | "Epoch 26/39\n", 532 | "----------\n", 533 | "LR 1e-05\n", 534 | "train: bce: 0.001514, dice: 0.030473, loss: 0.015993\n", 535 | "val: bce: 0.002166, dice: 0.040488, loss: 0.021327\n", 536 | "saving best model\n", 537 | "0m 43s\n", 538 | "Epoch 27/39\n", 539 | "----------\n", 540 | "LR 1e-05\n", 541 | "train: bce: 0.001501, dice: 0.030128, loss: 0.015815\n", 542 | "val: bce: 0.002229, dice: 0.040340, loss: 0.021285\n", 543 | "saving best model\n", 544 | "0m 43s\n", 545 | "Epoch 28/39\n", 546 | "----------\n", 547 | "LR 1e-05\n", 548 | "train: bce: 0.001496, dice: 0.029890, loss: 0.015693\n", 549 | "val: bce: 0.002166, dice: 0.040157, loss: 0.021162\n", 550 | "saving best model\n", 551 | "0m 43s\n", 552 | "Epoch 29/39\n", 553 | "----------\n", 554 | "LR 1e-05\n", 555 | "train: bce: 0.001488, dice: 0.029740, loss: 0.015614\n", 556 | "val: bce: 0.002215, dice: 0.040059, loss: 0.021137\n", 557 | "saving best model\n", 558 | "0m 43s\n", 559 | "Epoch 30/39\n", 560 | "----------\n", 561 | "LR 1e-05\n", 562 | 
"train: bce: 0.001479, dice: 0.029537, loss: 0.015508\n", 563 | "val: bce: 0.002149, dice: 0.039748, loss: 0.020948\n", 564 | "saving best model\n", 565 | "0m 43s\n", 566 | "Epoch 31/39\n", 567 | "----------\n", 568 | "LR 1e-05\n", 569 | "train: bce: 0.001469, dice: 0.029364, loss: 0.015416\n", 570 | "val: bce: 0.002212, dice: 0.039819, loss: 0.021016\n", 571 | "0m 43s\n", 572 | "Epoch 32/39\n", 573 | "----------\n", 574 | "LR 1e-05\n", 575 | "train: bce: 0.001470, dice: 0.029170, loss: 0.015320\n", 576 | "val: bce: 0.002146, dice: 0.039689, loss: 0.020918\n", 577 | "saving best model\n", 578 | "0m 43s\n", 579 | "Epoch 33/39\n", 580 | "----------\n", 581 | "LR 1e-05\n", 582 | "train: bce: 0.001456, dice: 0.029055, loss: 0.015255\n", 583 | "val: bce: 0.002180, dice: 0.039492, loss: 0.020836\n", 584 | "saving best model\n", 585 | "0m 43s\n", 586 | "Epoch 34/39\n", 587 | "----------\n", 588 | "LR 1e-05\n", 589 | "train: bce: 0.001451, dice: 0.028900, loss: 0.015175\n", 590 | "val: bce: 0.002170, dice: 0.039412, loss: 0.020791\n", 591 | "saving best model\n", 592 | "0m 43s\n", 593 | "Epoch 35/39\n", 594 | "----------\n", 595 | "LR 1e-05\n", 596 | "train: bce: 0.001432, dice: 0.028700, loss: 0.015066\n", 597 | "val: bce: 0.002203, dice: 0.039768, loss: 0.020985\n", 598 | "0m 43s\n", 599 | "Epoch 36/39\n", 600 | "----------\n", 601 | "LR 1e-05\n", 602 | "train: bce: 0.001433, dice: 0.028581, loss: 0.015007\n", 603 | "val: bce: 0.002091, dice: 0.039245, loss: 0.020668\n", 604 | "saving best model\n", 605 | "0m 43s\n", 606 | "Epoch 37/39\n", 607 | "----------\n", 608 | "LR 1e-05\n", 609 | "train: bce: 0.001422, dice: 0.028358, loss: 0.014890\n", 610 | "val: bce: 0.002160, dice: 0.039272, loss: 0.020716\n", 611 | "0m 43s\n", 612 | "Epoch 38/39\n", 613 | "----------\n", 614 | "LR 1e-05\n", 615 | "train: bce: 0.001414, dice: 0.028230, loss: 0.014822\n", 616 | "val: bce: 0.002143, dice: 0.039213, loss: 0.020678\n", 617 | "0m 43s\n", 618 | "Epoch 39/39\n", 619 | "----------\n", 620 | "LR 1e-05\n", 621 | "train: bce: 0.001406, dice: 0.027994, loss: 0.014700\n", 622 | "val: bce: 0.002083, dice: 0.039034, loss: 0.020559\n", 623 | "saving best model\n", 624 | "0m 43s\n", 625 | "Best val loss: 0.020559\n" 626 | ] 627 | } 628 | ], 629 | "source": [ 630 | "import torch\n", 631 | "import torch.optim as optim\n", 632 | "from torch.optim import lr_scheduler\n", 633 | "import time\n", 634 | "import copy\n", 635 | "\n", 636 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 637 | "print(device)\n", 638 | "\n", 639 | "num_class = 6\n", 640 | "\n", 641 | "model = pytorch_unet.UNet(num_class).to(device)\n", 642 | "\n", 643 | "# Observe that all parameters are being optimized\n", 644 | "optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)\n", 645 | "\n", 646 | "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)\n", 647 | "\n", 648 | "model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=40)" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": 8, 654 | "metadata": {}, 655 | "outputs": [ 656 | { 657 | "name": "stdout", 658 | "output_type": "stream", 659 | "text": [ 660 | "(3, 6, 192, 192)\n" 661 | ] 662 | }, 663 | { 664 | "data": { 665 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAsQAAAKvCAYAAABtZtkaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X+sbGd5H/rvUztwJUIFxPtavv5RG8shCml6AltOEIELBRJDURzCFbVVJU6DekAFqU1y1ZJQFW6vkKI2BCnKjclBWDZXiSGtQ7Fy3QYX0UCQKRwT1zEEg+0YcY6MvcERWElEYvu5f5zZYTjsHzN7ZvbM7PX5SFt75p21Zp61fR6/3/Puddaq7g4AAAzV31l2AQAAsEwCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIO2sEBcVVdV1b1VdV9VvWVRnwMAALOoRVyHuKrOSfKFJK9IcirJp5Nc292fm/uHAQDADBa1Qnxlkvu6+4Hu/usk709y9YI+CwAADuzcBb3vhUm+PPb8VJIf3m3jqnK7PIbsq929sewipnHeeef1pZdeuuwyYCnuvPPOtepZ/cqQTdqviwrE+6qq40mOL+vzYYV8adkFTGK8Zy+55JKcPHlyyRXBclTVyvesfoUzJu3XRZ0ycTrJxWPPLxqN/a3uPtHdm929uaAagDka79mNjbVZHINB0q8wnUUF4k8nuaKqLquqpyS5JsmtC/osAAA4sIWcMtHdj1fVm5P8QZJzktzQ3Z9dxGcBAMAsFnYOcXffluS2Rb0/AADMgzvVAQAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIN24EBcVRdX1Uer6nNV9dmq+hej8bdX1emqumv09ar5lQsAAPN17gz7Pp7kF7v7M1X19CR3VtXto9fe1d2/Ont5AACwWAcOxN39UJKHRo8fq6o/TXLhvAoDAIDDMJdziKvq0iQ/lOR/jIbeXFV3V9UNVfXMeXwGAAAswsyBuKq+O8ktSf5ld38jyfVJLk9yLGdWkN+5y37Hq+pkVZ2ctQZg8cZ7dmtra9nlAHvQrzCdmQJxVX1XzoTh3+7u30uS7n64u5/o7ieTvCfJlTvt290nunuzuzdnqQE4HOM9u7GxsexygD3oV5jOLFeZqCTvTfKn3f1rY+MXjG32miT3HLw8AABYrFmuMvHCJD+d5E+q6q7R2C8nubaqjiXpJA8mecNMFQIAwALNcpWJP0pSO7x028HLAQCAw+VOdQAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgnTvrG1TVg0keS/JEkse7e7OqnpXkA0kuTfJgktd195/P+lkAADBv81ohfml3H+vuzdHztyT5SHdfkeQjo+cAALByFnXKxNVJbho9vinJTy7ocwAAYCbzCMSd5MNVdWdVHR+Nnd/dD40efyXJ+XP4HAAAmLuZzyFO8qPdfbqq/tckt1fV58df7O6uqj57p1F4Pn72OLCaxnv2kksuWXI1wF70K0xn5hXi7j49+v5Ikg8muTLJw1V1QZKMvj+yw34nuntz7LxjYIWN9+zGxsayywH2oF9hOjMF4qp6WlU9fftxkh9Lck+SW5NcN9rsuiQfmuVzAABgUWY9ZeL8JB+squ33+p3u/q9V9ekkv1tVr0/ypSSvm/FzAABgIWYKxN39QJJ/sMP415K8bJb3BgCAw+BOdQAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgnXvQHavqOUk+MDb07CT/NskzkvyzJFuj8V/u7tsOXCEAACzQgQNxd9+b5FiSVNU5SU4n+WCSf5rkXd39q3OpEAAAFmhep0y8LMn93f2lOb0fAAAcinkF4muS3Dz2/M1VdXdV3VBVz5zTZwAAwNzNHIir6ilJfiLJfxwNXZ/k8pw5neKhJO/cZb/jVXWyqk7OWgOweOM9u7W1tf8OwNLoV5jOPFaIX5nkM939cJJ098Pd/UR3P5nkPUmu3Gmn7j7R3ZvdvTmHGoAFG+/ZjY2NZZcD7EG/wnTmEYivzdjpElV1wdhrr0lyzxw+AwAAFuLAV5lIkqp6WpJXJHnD2PC/r6pjSTrJg2e9tjDdvetrVXUYJQBTeOyu3f8+/vRjTx5iJcB+9CtH3UyBuLv/Isn3nDX20zNVdLA69n1dKIbVsdfkuv26SRZWg35lCNb+TnX7heFptwMWa7/JddrtgMXRrwzFWv8JnjbkCsWwXNNOmiZZWB79ypD40wsAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAzaWgfiaW+24eYcsFzTXrzfxf5hefQrQ7LWgTiZPOQKw7AaJp00Ta6wfPqVoVj7QJzsH3aFYVgt+02eJldYHfqVITh32QXMi9AL68UkCutDv3LUHYkVYgAAOCiBGACAQROIAQAYNIGYuejuZZcATKF+69XLLgGYkH5dvIkCcVXdUFWPVNU9Y2PPqqrbq+qLo+/PHI1XVf16Vd1XVXdX1fMWVTyrRSiG9WKShfWhXxdr0qtM3JjkN5K8b2zsLUk+0t2/UlVvGT3/10lemeSK0dcPJ7l+9P1ImiYEDuFKGN09iONkfR2/5fTE25547YULrGQ11G+9Ov2G3192GbAj/frt9OviTLRC3N0fS/LoWcNXJ7lp9PimJD85Nv6+PuOTSZ5RVRfMo9hV0t1Tr4geZJ91NIRjZP0cv+X0VJPrQfdZR1aeWDX6dXf6dTFmOYf4/O5+aPT4K0nOHz2+MMmXx7Y7NRo7MmYNfEMIjEM4RtbHrJOkSRYOj37dn36dv7n8o7o+k36mSkBVdbyqTlbVyXnUcFjmFfSGEBiHcIxDMt6zW1tbyy5nYvOaHE2yrBP9ql+Zzix3qnu4qi7o7odGp0Q8Mho/neTise0uGo19m+4+keREklTVWiSnvQLeXufN7rbfEM63HcIxDsV4z25ubq5Fz+41Ke51vuFu+x2/5fRKn6f4gsvfMfG2F73woh3HK85RPAr06+r367bxvr3j/rdOta9ziudnlhXiW5NcN3p8XZIPjY3/zOhqEz+S5Otjp1asrd1CbVXtG/j22mbVV1G369vv+yTvAYdpt0nyxGsv3HeS3GsbK08wf0Pt17P/EvuCy9/xbV+T0K/zMell125OckeS51TVqap6fZJfSfKKqvpikpePnifJbUkeSHJfkvck+edzr/qQ7RWGp7GOoXi75v2+72eVj5GjZ6/JdRrrOsnOg0mWwzLUfp0k8G5vs98qsH6d3aRXmbi2uy/o7u/q7ou6+73d/bXufll3X9HdL+/uR0fbdne/qbsv7+6/391rdY
7wpA56GsC6nT4wjxXis98LluGgvzpdh1+5LopJlmXRr99ipfhwuFPdPnYKcbOG2p32X9WwOK8V4m2repwcHTutBs06Se60/6quOs2bSZZF0q+TecHl75joXGH9enACMXua5wrx2e8JrAeTLCzfdih2+sRizHKViUGa1ykPVbUWwXDSFeJ1OxWE4ZjXr1BPvPbCtV9lGnfRJy7ecfzL73v3IVcC36Jf9/a35xTf78oS82aFmD1NukI8j5VjAGB/01xikckIxOxp2hViK8YAwLoRiNmTFWIA4KgTiNmTFWIA4KgTiNmTFWIAWD3OI54vgZg9WSEGAI46gXhK81r5XJcVVCvErLt5XXrpKF7CCVaNfmVZBGL2ZIUYAObvjvvfuuwSGCMQ72MRt1lexO2gF8UKMetmEbdtXcTtZQH9yuoQiA/ooIFv3YKiFWKOioNOsn71CodvKP16x/1vPdBK8UH3Y3e1CgGtqpZfxD52+zlNE/zm8R4cSXd29+ayi5jG5uZmnzx5ctll7Gm3iXGalaJ5vAdHT1WtVc/qV/06ZJP2qxXiCe0WWrt731XfvbYRhmExdpsEj99yet9VpL22MbnC/OlXlu3cZRewTqpq12B7kJV2YRgW68RrL9x1ojzIr1ZNrrA4+pVl2neFuKpuqKpHquqesbH/UFWfr6q7q+qDVfWM0filVfVXVXXX6Ovdiyx+GeYVYoVhOBzzmhRNrrB4+pVlmeSUiRuTXHXW2O1JfqC7fzDJF5L80thr93f3sdHXG+dT5mqZNcwKw3C4Zp0cTa5wePQry7DvKRPd/bGquvSssQ+PPf1kkv9jvmWtvu1QO82pEoIwLM/2JDnNr15NrLAc+pXDNo9ziH8uyQfGnl9WVX+c5BtJ/k13f3wOn7GyhFxYLyZNWB/6lcMyUyCuqrcmeTzJb4+GHkpySXd/raqen+Q/V9Vzu/sbO+x7PMnxWT4fODzjPXvJJZcsuRpgL/oVpnPgy65V1c8meXWSf9Kj8wa6+5vd/bXR4zuT3J/ke3fav7tPdPfmOl3LEYZsvGc3NjaWXQ6wB/0K0zlQIK6qq5L8qyQ/0d1/OTa+UVXnjB4/O8kVSR6YR6EAALAI+54yUVU3J3lJkvOq6lSSt+XMVSWemuT20Tm0nxxdUeLFSf5dVf1NkieTvLG7H11Q7QAAMLNJrjJx7Q7D791l21uS3DJrUQAAcFjcuhkAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYtH0DcVXdUFWPVNU9Y2Nvr6rTVXXX6OtVY6/9UlXdV1X3VtWPL6pwAACYh0lWiG9MctUO4+/q7mOjr9uSpKq+P8k1SZ472uc3q+qceRULAADztm8g7u6PJXl0wve7Osn7u/ub3f1nSe5LcuUM9QEAwELNcg7xm6vq7tEpFc8cjV2Y5Mtj25wajQEAwEo694D7XZ/k/07So+/vTPJz07xBVR1PcvyAn88R0N1T71NVC6iESYz37CWXXLLkaliGT73wpVPvc+UnPrqAStiPfkW/TudAK8Td/XB3P9HdTyZ5T751WsTpJBePbXrRaGyn9zjR3ZvdvXmQGoDDNd6zGxsbyy4H2IN+hekcKBBX1QVjT1+TZPsKFLcmuaaqnlpVlyW5IsmnZisRAAAWZ99TJqrq5iQvSXJeVZ1K8rYkL6mqYzlzysSDSd6QJN392ar63SSfS/J4kjd19xOLKZ1Vtn06hFMcYD3c9PEXJEmue9EdS64E2I9+nb99A3F3X7vD8Hv32P4dSd4xS1EAAHBY3KkOAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBB2/c6xHC27ZtuzHtbYDG2L+I/67bnvfzMlLHx30wdsCjz6NfrXnRHPvOCn8rz7vi9eZV15Pm/GkvnbnawHr768seTJP/o//r4kisB9nLTx1+Q593xv/zt8ys/8dElVrMeBGKmNkmAdetmWB2T3N7VrWBhNcytXz8xr4qGwTnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoO0biKvqhqp6pKruGRv7QFXdNfp6sKruGo1fWlV/NfbauxdZPAAAzGqS6xDfmOQ3krxve6C7//H246p6Z5Kvj21/f3cfm1eBAACwSPsG4u7+WFVdutNrdeauC69L8g/nWxbrzg05YL24IQesD/06f7OeQ/yiJA939xfHxi6rqj+uqj+sqhfN+P4AALBQs966+dokN489fyjJJd39tap6fpL/XFXP7e5vnL1jVR1PcnzGzwcOyXjPXnLJJUuuBtiLfoXpHHiFuKrOTfJTST6wPdbd3+zur40e35nk/iTfu9P+3X2iuze7e/OgNQCHZ7xnNzY2ll0OsAf9CtOZ5ZSJlyf5fHef2h6oqo2qOmf0+NlJrkjywGwlAgDA4kxy2bWbk9yR5DlVdaqqXj966Zp8++kSSfLiJHePLsP2n5K8sbsfnWfBAAAwT5NcZeLaXcZ/doexW5LcMntZAABwONypDgCAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEGr7l52DamqrSR/keSry65lDs6L41gl63Acf6+7N5ZdxDSq6rEk9y67jjlYhz8fk3Ach2utela/rhzHcbgm6tdzD6OS/XT3RlWd7O7NZdcyK8exWo7Kcayge4/Cz/Wo/PlwHOxDv64Qx7GanDIBAMCgCcQAAAzaKgXiE8suYE4cx2o5Ksexao7Kz9VxrJajchyr5qj8XB3Hajkqx5FkRf5RHQAALMsqrRADAMChE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQVtYIK6qq6rq3qq6r6resqjPAQCAWVR3z/9Nq85J8oUkr0hyKsmnk1zb3Z+b+4cBAMAMFrVCfGWS+7r7ge7+6yTvT3L1gj4LAAAObFGB+MIkXx57fmo0BgAAK+XcZX1wVR1Pcnz09PnLqgNWwFe7e2PZRexnvGef9rSnPf/7vu/7llwRLMedd9658j2rX+GMSft1UYH4dJKLx55fNBr7W919IsmJJKmq+Z/IDOvjS8suYBLjPbu5udknT55cckWwHFW18j2rX+GMSft1UadMfDrJFVV1WVU9Jck1SW5d0GcBAMCBLWSFuLsfr6o3J/mDJOckuaG7P7uIzwIAgFks7Bzi7r4tyW2Len8AAJgHd6oDAGDQBGIAAAZNIAYAYNAEYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAZNIAYAYNAEYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAA
AZNIAYAYNAEYgAABu3AgbiqLq6qj1bV56rqs1X1L0bjb6+q01V11+jrVfMrFwAA5uvcGfZ9PMkvdvdnqurpSe6sqttHr72ru3919vIAAGCxDhyIu/uhJA+NHj9WVX+a5MJ5FQYAAIdhLucQV9WlSX4oyf8YDb25qu6uqhuq6pnz+AwAAFiEmQNxVX13kluS/Mvu/kaS65NcnuRYzqwgv3OX/Y5X1cmqOjlrDcDijffs1tbWsssB9qBfYTozBeKq+q6cCcO/3d2/lyTd/XB3P9HdTyZ5T5Ird9q3u09092Z3b85SA3A4xnt2Y2Nj2eUAe9CvMJ1ZrjJRSd6b5E+7+9fGxi8Y2+w1Se45eHkAALBYs1xl4oVJfjrJn1TVXaOxX05ybVUdS9JJHkzyhpkqBACABZrlKhN/lKR2eOm2g5cDAACHy53qAAAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEE7d9Y3qKoHkzyW5Ikkj3f3ZlU9K8kHklya5MEkr+vuP5/1swAAYN7mtUL80u4+1t2bo+dvSfKR7r4iyUdGzwEAYOUs6pSJq5PcNHp8U5KfXNDnHEh3L7sEYAr1W69edgnAhPQr62gegbiTfLiq7qyq46Ox87v7odHjryQ5fw6fM1dCMawXkyysD/3KuplHIP7R7n5eklcmeVNVvXj8xT6TPL8jfVbV8ao6WVUn51DDgQjFMLnxnt3a2lpODSZZmIh+henMHIi7+/To+yNJPpjkyiQPV9UFSTL6/sgO+53o7s2x846XQiiGyYz37MbGxtLqMMnC/vQrTGemQFxVT6uqp28/TvJjSe5JcmuS60abXZfkQ7N8zqIJxbBeTLKwPvQr62DWFeLzk/xRVf3PJJ9K8v91939N8itJXlFVX0zy8tHzlSYUw3oxycL60K+supkCcXc/0N3/YPT13O5+x2j8a939su6+ortf3t2PzqfcxRKKYb2YZGF96FdWmTvVnUUohvVikoX1oV9ZVQLxDoRiWC8mWVgf+pVVJBDvQiiG9WKShfWhX1k1AvEehGJYLyZZWB/6lVUiEO9DKIb1YpKF9aFfWRUC8QSEYlgvJllYH/qVVSAQT0gohvVikoX1oV9ZNoF4CkIxrBeTLKwP/coyCcRTEophvZhkYX3oV5ZFID4AoRjWi0kW1od+ZRkE4gMSimG9mGRhfehXDptAPAOhGNaLSRbWh37lMAnEMxKKYb2YZGF96FcOy7nLLmAZqmrZJQBT6Df8/rJLACakX1lHVogBABg0gRgAgEE78CkTVfWcJB8YG3p2kn+b5BlJ/lmSrdH4L3f3bQeuEAAAFujAgbi7701yLEmq6pwkp5N8MMk/TfKu7v7VuVQIAAALNK9TJl6W5P7u/tKc3g8AAA7FvALxNUluHnv+5qq6u6puqKpnzukzAABg7mYOxFX1lCQ/keQ/joauT3J5zpxO8VCSd+6y3/GqOllVJ2etAVi88Z7d2trafwdgafQrTGceK8SvTPKZ7n44Sbr74e5+orufTPKeJFfutFN3n+juze7enEMNwIKN9+zGxsayywH2oF9hOvMIxNdm7HSJqrpg7LXXJLlnDp8BAAALMdOd6qrqaUlekeQNY8P/vqqOJekkD571GgAArJSZAnF3/0WS7zlr7KdnqggAAA6RO9UBADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgzZRIK6qG6rqkaq6Z2zsWVV1e1V9cfT9maPxqqpfr6r7quruqnreoooHAIBZTbpCfGOSq84ae0uSj3T3FUk+MnqeJK9McsXo63iS62cvEwAAFmOiQNzdH0vy6FnDVye5afT4piQ/OTb+vj7jk0meUVUXzKNYAACYt1nOIT6/ux8aPf5KkvNHjy9M8uWx7U6NxgAAYOXM5R/VdXcn6Wn2qarjVXWyqk7OowZgscZ7dmtra9nlAHvQrzCdWQLxw9unQoy+PzIaP53k4rHtLhqNfZvuPtHdm929OUMNwCEZ79mNjY1llwPsQb/CdGYJxLcmuW70+LokHxob/5nR1SZ+JMnXx06tAACAlXLuJBtV1c1JXpLkvKo6leRtSX4lye9W1euTfCnJ60ab35bkVUnuS/KXSf7pnGsGAIC5mSgQd/e1u7z0sh227SRvmqUoAAA4LO5UBwDAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAM2rnLLgDmqbtTVXP7DizWCy5/x9ze64773zq39wK+01HuVyvEHCnbIXZe3wGAo2/fQFxVN1TVI1V1z9jYf6iqz1fV3VX1wap6xmj80qr6q6q6a/T17kUWD2fr7rl+BwCOvklWiG9MctVZY7cn+YHu/sEkX0jyS2Ov3d/dx0Zfb5xPmTAZK8QAwLT2DcTd/bEkj5419uHufnz09JNJLlpAbTA1K8QAwLTmcQ7xzyX5L2PPL6uqP66qP6yqF83h/WFiVogBgGnNFIir6q1JHk/y26Ohh5Jc0t0/lOQXkvxOVf3dXfY9XlUnq+rkLDXAOCvEizPes1tbW8suB9iDfoXpHDgQV9XPJnl1kn/So/TQ3d/s7q+NHt+Z5P4k37vT/t19ors3u3vzoDXA2awQL854z25sbCy7HGAP+hWmc6BAXFVXJflXSX6iu/9ybHyjqs4ZPX52kiuSPDCPQmESVogBgGnte2OOqro5yUuSnFdVp5K8LWeuKvHUJLePVtI+ObqixIuT/Luq+pskTyZ5Y3c/uuMbwwJYIQYAprVvIO7ua3cYfu8u296S5JZZi4KDcqc6AGBa7lTHkWKFGACYlkDMkeIcYgBgWgIxR4oVYgBgWgIxR4oVYgBgWgIxR4oVYgBgWgIxR4oVYgBgWgIxR4oVYgBgWgIxR4oVYgBgWvvemAPWiRViWC933P/WZZcATOgo96sVYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAZNIAYAYNAEYgAABm3fQFxVN1TVI1V1z9jY26vqdFXdNfp61dhrv1RV91XVvVX144sqHAAA5mGSFeIbk1y1w/i7uvvY6Ou2JKmq709yTZLnjvb5zao6Z17FcrR0tzvCwRq56eMvyE0ff8GyywAmoF+ns++d6rr7Y1V16YTvd3WS93f3N5P8WVXdl+TKJHccuMKBmiYouqsaLN/xW05PvO2J1164wEqA/ehXzjbLrZvfXFU/k+Rkkl/s7j9PcmGST45tc2o0xoQOsmK6vY9gDIdvmon17H1MtHC49Cu7Oeg/qrs+yeVJjiV5KMk7p32DqjpeVSer6uQBazhyZj19wOkHLNJ4z25tbS27nJVwkMl1nvvDbvTrd9Kv7OVAK8Td/fD246p6T5LfHz09neTisU0vGo3t
9B4nkpwYvcegk9w8g6zVYhZlvGc3NzcH3bPznBitPrEI+vVb9CuTOFAgrqoLuvuh0dPXJNm+AsWtSX6nqn4tyf+W5Iokn5q5yiNsvzC8V7Dda9/uFophAfabXPeaKPfa9/gtp02yMGf6lUntG4ir6uYkL0lyXlWdSvK2JC+pqmNJOsmDSd6QJN392ar63SSfS/J4kjd19xOLKf1omyTMbm/jVAlYvkkmx+1t/OoVlku/crZJrjJx7Q7D791j+3ckeccsRR2WZV/JYbfPn/azqmrH97JKzFHzqRe+dOJtr/zER+f++btNjNOuFJ147YU7vpdVJ44S/co6cae6JZlXGN5vP6vHMB/zmlz3289qFMxOvzItgXiFzLqaazUYDtesq0NWl+Dw6Ff2IhAvwU6rtvMKszu9j1VimM1Oq0Dzmhx3eh+rTnBw+pWDEIgBABi0We5Ux5ytNTNaAAAZK0lEQVTM+1SH3f6R3aId9DOn3c+pISzbvH91uts/2lm0mz7+gkPZ77oX3XGgz4F50K/6dRIC8YRmuV4wcLjOe9WDeeyuvX8B9vRjTx5SNcB+xq9IceUnPvod/atfWTSBeE5c4mz6vxS4qx6LcN6rHpxou8fu+juDn2SnXQnaXmka6goSi7Xduw+847Js/KNvf02/6tdFE4jnSMCD5Zk0CI/bXoUa+kQLyzTeuxv/6Eu7bqdfWSSBGABYiu0wvFcQhsPgKhML4DJncLgOsjo8br/zjYHp7Xf3uYOGYf3KIgx6hXiR1+x1TjHM304T7LwmR+cowvyd3bPb/3hu1pVh/cq8+WvWCpj3irIValiseV9yyYX9GZLDPk1CvzIJgRgAOBTbK8bOGWbVCMRLsOhTNSb5PGByi7xd6yJvMwuraNFhWL9yEALxCpk1FDtVAg7XrJOsX73C4dGv7EUgXpLdVm3nfftjq8MwH7utAh10ktxtP6tNMDv9yrT2DcRVdUNVPVJV94yNfaCq7hp9PVhVd43GL62qvxp77d2LLH7dzSsUC8NwOOY1yZpcYfH0K9OY5LJrNyb5jSTv2x7o7n+8/biq3pnk62Pb39/dx+ZV4FBNctc7p0jA6tieNPeaJP3KFVaDfuVs+wbi7v5YVV2602t1Jq29Lsk/nG9Zw1FVewbbg4bedVgdXoca4WwnXnvhnhPlQSfRdVhtuu5Fdyy7BJiKfmVSs96Y40VJHu7uL46NXVZVf5zkG0n+TXd/fMbPOFT7BdRp3mfabQ/7c+EoePqxJ+dyc45pLvK/PRnOYwVpHSZWWDX6lXmbdRa5NsnNY88fSnJJd/9Qkl9I8jtV9Xd32rGqjlfVyao6OWMNczdrqDzo/sv6XJjEeM9ubW0tu5xvM+sdqw66/6yTo8mVRTnK/XpQ+pW91CSrkqNTJn6/u39gbOzcJKeTPL+7T+2y339P8n92956ht6pW7mTYWVZr5xFMp/l8QXjt3dndm8suYhqbm5t98uRq/V12llXieUzQ06w+mVjXW1WtVc+uYr8mB+9Z/co0Ju3XWU6ZeHmSz4+H4araSPJodz9RVc9OckWSB2b4jKU5yGkM8wymQi5MZ3uSnGaSnedKlUkTpjPt6U76lUWa5LJrNye5I8lzqupUVb1+9NI1+fbTJZLkxUnuHl2G7T8leWN3PzrPgg/bpMFUgIXVMOmkuaxf2wLfol9ZFZNcZeLaXcZ/doexW5LcMntZq0XYhfVi8oT1oV9ZBe5UBwDAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADFp197JrSFVtJfmLJF9ddi1zcF4cxypZh+P4e929sewiplFVjyW5d9l1zME6/PmYhOM4XGvVs/p15TiOwzVRv557GJXsp7s3qupkd28uu5ZZOY7VclSOYwXdexR+rkflz4fjYB/6dYU4jtXklAkAAAZNIAYAYNBWKRCfWHYBc+I4VstROY5Vc1R+ro5jtRyV41g1R+Xn6jhWy1E5jiQr8o/qAABgWVZphRgAAA6dQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAM2sICcVVdVVX3VtV9VfWWRX0OAADMorp7/m9adU6SLyR5RZJTST6d5Nru/tzcPwwAAGawqBXiK5Pc190PdPdfJ3l/kqsX9FkAAHBg5y7ofS9M8uWx56eS/PD4BlV1PMnx0dPnL6gOWAdf7e6NZRexn/GefdrTnvb87/u+71tyRbAcd95558r3rH6FMybt10UF4n1194kkJ5KkquZ/3gasjy8tu4BJjPfs5uZmnzx5cskVwXJU1cr3rH6FMybt10WdMnE6ycVjzy8ajQEAwEpZVCD+dJIrquqyqnpKkmuS3LqgzwIAgANbyCkT3f14Vb05yR8kOSfJDd392UV8FgAAzGJh5xB3921JblvU+wMAwDy4Ux0AAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIMmEAMAMGgCMQAAgyYQAwAwaAcOxFV1cVV9tKo+V1Wfrap/MRp/e1Wdrqq7Rl+vml+5AAAwX+fOsO/jSX6xuz9TVU9PcmdV3T567V3d/auzlwcAAIt14EDc3Q8leWj0+LGq+tMkF86rMAAAOAxzOYe4qi5N8kNJ/sdo6M1VdXdV3VBVz9xln+NVdbKqTs6jBmCxxnt2a2tr2eUAe9CvMJ2ZA3FVfXeSW5L8y+7+RpLrk1ye5FjOrCC/c6f9uvtEd2929+asNQCLN96zGxsbyy4H2IN+henMFIir6rtyJgz/dnf/XpJ098Pd/UR3P5nkPUmunL1MAABYjFmuMlFJ3pvkT7v718bGLxjb7DVJ7jl4eQAAsFizXGXihUl+OsmfVNVdo7FfTnJtVR1L0kkeTPKGmSoEAIAFmuUqE3+UpHZ46baDlwMAAIfLneoAABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQROIAQAYNIEYAIBBE4gBABg0gRgAgEETiAEAGDSBGACAQTt31jeoqgeTPJbkiSSPd/dmVT0ryQeSXJrkwSSv6+4/n/WzAABg3ua1QvzS7j7W3Zuj529J8pHuviLJR0bPAQBg5SzqlImrk9w0enxTkp9c0OcAAMBM5hGIO8mHq+rOqjo+Gju/ux8aPf5KkvPP3qmqjlfVyao6OYcagAUb79mtra1llwPsQb/CdOYRiH+0u5+X5JVJ3lRVLx5/sbs7Z0Jzzho/0d2bY6dZACtsvGc3NjaWXQ6wB/0K05k5EHf36dH3R5J8MMmVSR6uqguSZPT9kVk/BwAAFmGmQFx
VT6uqp28/TvJjSe5JcmuS60abXZfkQ7N8DgAALMqsl107P8kHq2r7vX6nu/9rVX06ye9W1euTfCnJ62b8HACAI+UFl79j6n3uuP+tC6iEmQJxdz+Q5B/sMP61JC+b5b1Zru5OVc38HTgcB5lYz2aihcMxS79u76tf58ud6tjRdpid9TsAwKoTiNnRmYuDzP4dAGDVCcTsyAoxAKyueZwmxbcIxOzICjEAMBQCMTuyQgwADIVAzI6sEAMAQyEQsyMrxADAUAjE7MgKMQAwFAIxO7JCDAAMhUDMjqwQAwBDIRCzIyvEAMBQCMTsyAoxADAUAjE7skIMAAyFQMyOrBADwOq64/63LruEI0UgZkdWiAGAoRCI2ZEVYgBgKM496I5V9ZwkHxgbenaSf5vkGUn+WZKt0fgvd/dtB66QpbBCDOvFr09hfejX1XPgFeLuvre7j3X3sSTPT/KXST44evld268dlTBsxRPWS/3Wq5ddAjAh/cqyzeuUiZclub+7vzSn91tJQjGsF5MsrA/9yjLNKxBfk+Tmsedvrqq7q+qGqnrmTjtU1fGqOllVJ+dUw6EQihmq8Z7d2traf4cVYZJliPQrTGfmQFxVT0nyE0n+42jo+iSXJzmW5KEk79xpv+4+0d2b3b05aw2HTShmiMZ7dmNjY9nlTMUky9DoV5jOPFaIX5nkM939cJJ098Pd/UR3P5nkPUmunMNnrByhGNaLSRbWh37lsM0jEF+bsdMlquqCsddek+SeOXzGShKKYb2YZGF96FcO00yBuKqeluQVSX5vbPjfV9WfVNXdSV6a5Odn+YxVJxTDejHJwvrQrxyWmQJxd/9Fd39Pd399bOynu/vvd/cPdvdPdPdDs5e52oRiWC8mWVgf+pXD4E51cyIUw3oxycL60K8smkA8R0IxrBeTLKwP/coiCcRzJhTDejHJwvrQryyKQLwAQjGsF5MsrA/9yiIIxAsiFMN6McnC+tCvzJtAvEBCMawXkyysD/3KPAnECyYUw3oxycL60K/Mi0B8CIRiWC8mWVgf+pV5EIgPiVAM68UkC+tDvzIrgfgQCcWwXkyysD70K7M4d9kFrIuqWnYJwBT6Db+/7BKACelXls0KMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIM2USCuqhuq6pGqumds7FlVdXtVfXH0/Zmj8aqqX6+q+6rq7qp63qKKBwCAWU162bUbk/xGkveNjb0lyUe6+1eq6i2j5/86ySuTXDH6+uEk14++swR7XfvYpeRg9Tx21+7rFE8/9uQhVgJMYree1a/rZaJA3N0fq6pLzxq+OslLRo9vSvLfcyYQX53kfX0miX2yqp5RVRd090PzKJjvdNAbfnS3UAxLcPyW0zuO/9gzfy3JzydJfvxZ7/qO1x+76++YZOGQnd2vJ157YZK9//K6/bp+XR+znEN8/ljI/UqS80ePL0zy5bHtTo3Gvk1VHa+qk1V1coYaBm/Wu9+5ex6TGu/Zra2tZZeztvYOw9/yB4/+/I7b7TcJQ6Jf52Wnfj1+y+mJ+1C/ro+5/JcarQZPlay6+0R3b3b35jxqGKKdwmxVWfVlIcZ7dmNjY9nlrKWdJtcTr70wJ1574Y4rwruFYtiPfp3dXv3K0TPLrZsf3j4VoqouSPLIaPx0kovHtrtoNMYcnR2GhWBYbbv92nXcdigeD8J/8OjP7xiWgcWZpF85WmZZIb41yXWjx9cl+dDY+M+MrjbxI0m+7vzh+RKGYb1MO7meHYCtFMPhEYaHadLLrt2c5I4kz6mqU1X1+iS/kuQVVfXFJC8fPU+S25I8kOS+JO9J8s/nXjV/SxiG9TLp5GpVGJZPGB6OSa8yce0uL71sh207yZtmKQoAAA6Lf/64xqwOw3qZdrXJKjEsz1796nJqR49AfERNGpaFalgNk06wJmJYH/p1fQjER9h+YVcYhtWy3+RpcoXVoV+PFoF4jU1yU43t6xKfHX6FYTh8u92UY9zTjz35t1+/eP+X8+E//4V8+M9/weQKh2yaft1pnPUyy3WIAQAGTwBef1aI19ykt152i2ZYDZOsOk2zHbA40/Tr9hfrSSBeQ2ef7rBf2HUjD1ius/+1+n6TphsDwPLM2q+sJ4F4Te0Uis8OvjuNCcOwHDtNsmdPpDuNCcNw+PTr8DiHeI1V1Y4heK/tgeU58doLd5xU99oeWA79OiwC8ZrbKRTvtt062D6Ww653mnOs1+VnyWraaZLdbbt1cNPHX5Akue5Fdxzq537qhS+deNsrP/HRBVbCUaZf52Md+lUgPgIENFgv6zJ5Avp1KJxDDADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKDtG4ir6oaqeqSq7hkb+w9V9fmquruqPlhVzxiNX1pVf1VVd42+3r3I4gEAYFaTrBDfmOSqs8ZuT/ID3f2DSb6Q5JfGXru/u4+Nvt44nzIBAGAx9g3E3f2xJI+eNfbh7n589PSTSS5aQG0AALBw8ziH+OeS/Jex55dV1R9X1R9W1Yt226mqjlfVyao6OYcagAUb79mtra1llwPsQb/CdGrC2/5emuT3u/sHzhp/a5LNJD/V3V1VT03y3d39tap6fpL/nOS53f2Nfd5/8vvmsnamuS3yQRyBO/Xd2d2byy5iGpubm33ypL/LHlXbt3ddlMO+bey8VdVa9ax+Pdr0694m7dcDrxBX1c8meXWSf9KjxNPd3+zur40e35nk/iTfe9DPAACARTv3IDtV1VVJ/lWS/727/3JsfCPJo939RFU9O8kVSR6YS6WsrWlWcLdXk4/Aqi+srWlWhLZXp9Z9FQnWlX6dj30DcVXdnOQlSc6rqlNJ3pYzV5V4apLbR8Hlk6MrSrw4yb+rqr9J8mSSN3b3ozu+MQAArIB9A3F3X7vD8Ht32faWJLfMWhQAABwWd6oDAGDQBGIAAAZNIAYAYNAEYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAbtQLduhkVxy2ZYL24BC+tDv+7OCjEAAIMmEAMAMGgCMQAAgyYQAwAwaAIxAACDJhADADBoAjEAAIO2byCuqhuq6pGqumds7O1Vdbqq7hp9vWrstV+qqvuq6t6q+vFFFQ4AAPMwyQrxjUmu2mH8Xd19bPR1W5JU1fcnuSbJc0f7/GZVnTOvYgEAYN72DcTd/bEkj074flcneX93f7O7/yzJfUmunKE+AABYqFnOIX5zVd09OqXimaOxC5N8eWybU6Ox71BVx6vqZFWdnKEG4JCM9+zW1tayywH2oF9hOgcNxNcnuTzJsSQPJXnntG/Q3Se6e7O7Nw9YA3CIxnt2Y2Nj2eUAe9CvMJ0DBeLufri7n+juJ5O8J986LeJ0kovHNr1oNAYAACvpQIG4qi4Ye/qaJNtXoLg1yTVV9dSquizJFUk+NVuJAACwOOfut0FV3ZzkJUnOq6pTSd6W5CVVdSxJJ3kwyRuSpLs/W1W/m+RzSR5P8qbufmIxpQMAwOz2DcTdfe0Ow+/dY/t3JHnHLEUBAMBhcac6AAAGTSAGAGDQBGIAAAZNIAYAYNAEYgAABk0gBg
Bg0ARiAAAGTSAGAGDQBGIAAAZNIAYAYNAEYgAABk0gBgBg0ARiAAAGTSAGAGDQBGIAAAZNIAYAYND2DcRVdUNVPVJV94yNfaCq7hp9PVhVd43GL62qvxp77d2LLB4AAGZ17gTb3JjkN5K8b3ugu//x9uOqemeSr49tf393H5tXgQAAsEj7BuLu/lhVXbrTa1VVSV6X5B/OtywAADgcs55D/KIkD3f3F8fGLquqP66qP6yqF+22Y1Udr6qTVXVyxhqAQzDes1tbW8suB9iDfoXpzBqIr01y89jzh5Jc0t0/lOQXkvxOVf3dnXbs7hPdvdndmzPWAByC8Z7d2NhYdjnAHvQrTOfAgbiqzk3yU0k+sD3W3d/s7q+NHt+Z5P4k3ztrkQAAsCizrBC/PMnnu/vU9kBVbVTVOaPHz05yRZIHZisRAAAWZ5LLrt2c5I4kz6mqU1X1+tFL1+TbT5dIkhcnuXt0Gbb/lOSN3f3oPAsGAIB5muQqE9fuMv6zO4zdkuSW2csCAIDD4U51AAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADFp197JrSFVtJfmLJF9ddi1zcF4cxypZh+P4e929sewiplFVjyW5d9l1zME6/PmYhOM4XGvVs/p15TiOwzVRv557GJXsp7s3qupkd28uu5ZZOY7VclSOYwXdexR+rkflz4fjYB/6dYU4jtXklAkAAAZNIAYAYNBWKRCfWHYBc+I4VstROY5Vc1R+ro5jtRyV41g1R+Xn6jhWy1E5jiQr8o/qAABgWVZphRgAAA7d0gNxVV1VVfdW1X1V9ZZl1zONqnqwqv6kqu6qqpOjsWdV1e1V9cXR92cuu86zVdUNVfVIVd0zNrZj3XXGr4/++9xdVc9bXuXfbpfjeHtVnR79N7mrql419tovjY7j3qr68eVUvf707OHTs3r2oPTr4dOv69mvSw3EVXVOkv8nySuTfH+Sa6vq+5dZ0wG8tLuPjV165C1JPtLdVyT5yOj5qrkxyVVnje1W9yuTXDH6Op7k+kOqcRI35juPI0neNfpvcqy7b0uS0Z+ra5I8d7TPb47+/DEFPbs0N0bP6tkp6deluTH6de36ddkrxFcmua+7H+juv07y/iRXL7mmWV2d5KbR45uS/OQSa9lRd38syaNnDe9W99VJ3tdnfDLJM6rqgsOpdG+7HMdurk7y/u7+Znf/WZL7cubPH9PRs0ugZ/XsAenXJdCv69mvyw7EFyb58tjzU6OxddFJPlxVd1bV8dHY+d390OjxV5Kcv5zSprZb3ev43+jNo1893TD267R1PI5VtO4/Rz27mvTsYqz7z1C/rqYj2a/LDsTr7ke7+3k58yuPN1XVi8df7DOX8Fi7y3isa90j1ye5PMmxJA8leedyy2HF6NnVo2fZjX5dPUe2X5cdiE8nuXjs+UWjsbXQ3adH3x9J8sGc+fXAw9u/7hh9f2R5FU5lt7rX6r9Rdz/c3U9095NJ3pNv/cpmrY5jha31z1HPrh49u1Br/TPUr6vnKPfrsgPxp5NcUVWXVdVTcuaE7FuXXNNEquppVfX07cdJfizJPTlT/3Wjza5L8qHlVDi13eq+NcnPjP4l7I8k+frYr31WzlnnXr0mZ/6bJGeO45qqempVXZYz/4DhU4dd3xGgZ1eHnmU/+nV16NdV191L/UryqiRfSHJ/krcuu54p6n52kv85+vrsdu1Jvidn/gXpF5P8tyTPWnatO9R+c878quNvcuY8n9fvVneSypl/pXx/kj9Jsrns+vc5jv93VOfdOdOgF4xt/9bRcdyb5JXLrn9dv/TsUmrXs3r2oD9z/Xr4tevXNexXd6oDAGDQln3KBAAALJVADADAoAnEAAAMmkAMAMCgCcQAAAyaQAwAwKAJxAAADJpADADAoP3/1Wo52AeZZucAAAAASUVORK5CYII=\n", 666 | "text/plain": [ 667 | "
" 668 | ] 669 | }, 670 | "metadata": {}, 671 | "output_type": "display_data" 672 | } 673 | ], 674 | "source": [ 675 | "# prediction\n", 676 | "\n", 677 | "import math\n", 678 | "\n", 679 | "model.eval() # Set model to evaluate mode\n", 680 | "\n", 681 | "test_dataset = SimDataset(3, transform = trans)\n", 682 | "test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)\n", 683 | " \n", 684 | "inputs, labels = next(iter(test_loader))\n", 685 | "inputs = inputs.to(device)\n", 686 | "labels = labels.to(device)\n", 687 | "\n", 688 | "pred = model(inputs)\n", 689 | "\n", 690 | "pred = pred.data.cpu().numpy()\n", 691 | "print(pred.shape)\n", 692 | "\n", 693 | "# Change channel-order and make 3 channels for matplot\n", 694 | "input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]\n", 695 | "\n", 696 | "# Map each channel (i.e. class) to each color\n", 697 | "target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]\n", 698 | "pred_rgb = [helper.masks_to_colorimg(x) for x in pred]\n", 699 | "\n", 700 | "helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])" 701 | ] 702 | } 703 | ], 704 | "metadata": { 705 | "kernelspec": { 706 | "display_name": "Python [conda env:py36]", 707 | "language": "python", 708 | "name": "conda-env-py36-py" 709 | }, 710 | "language_info": { 711 | "codemirror_mode": { 712 | "name": "ipython", 713 | "version": 3 714 | }, 715 | "file_extension": ".py", 716 | "mimetype": "text/x-python", 717 | "name": "python", 718 | "nbconvert_exporter": "python", 719 | "pygments_lexer": "ipython3", 720 | "version": "3.6.4" 721 | } 722 | }, 723 | "nbformat": 4, 724 | "nbformat_minor": 2 725 | } 726 | -------------------------------------------------------------------------------- /pytorch_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def double_conv(in_channels, out_channels): 5 | return nn.Sequential( 6 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 7 | nn.ReLU(inplace=True), 8 | nn.Conv2d(out_channels, out_channels, 3, padding=1), 9 | nn.ReLU(inplace=True) 10 | ) 11 | 12 | 13 | class UNet(nn.Module): 14 | 15 | def __init__(self, n_class): 16 | super().__init__() 17 | 18 | self.dconv_down1 = double_conv(3, 64) 19 | self.dconv_down2 = double_conv(64, 128) 20 | self.dconv_down3 = double_conv(128, 256) 21 | self.dconv_down4 = double_conv(256, 512) 22 | 23 | self.maxpool = nn.MaxPool2d(2) 24 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 25 | 26 | self.dconv_up3 = double_conv(256 + 512, 256) 27 | self.dconv_up2 = double_conv(128 + 256, 128) 28 | self.dconv_up1 = double_conv(128 + 64, 64) 29 | 30 | self.conv_last = nn.Conv2d(64, n_class, 1) 31 | 32 | 33 | def forward(self, x): 34 | conv1 = self.dconv_down1(x) 35 | x = self.maxpool(conv1) 36 | 37 | conv2 = self.dconv_down2(x) 38 | x = self.maxpool(conv2) 39 | 40 | conv3 = self.dconv_down3(x) 41 | x = self.maxpool(conv3) 42 | 43 | x = self.dconv_down4(x) 44 | 45 | x = self.upsample(x) 46 | x = torch.cat([x, conv3], dim=1) 47 | 48 | x = self.dconv_up3(x) 49 | x = self.upsample(x) 50 | x = torch.cat([x, conv2], dim=1) 51 | 52 | x = self.dconv_up2(x) 53 | x = self.upsample(x) 54 | x = torch.cat([x, conv1], dim=1) 55 | 56 | x = self.dconv_up1(x) 57 | 58 | out = self.conv_last(x) 59 | 60 | return out 61 | -------------------------------------------------------------------------------- /simulation.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | def generate_random_data(height, width, count): 5 | x, y = zip(*[generate_img_and_mask(height, width) for i in range(0, count)]) 6 | 7 | X = np.asarray(x) * 255 8 | X = X.repeat(3, axis=1).transpose([0, 2, 3, 1]).astype(np.uint8) 9 | Y = np.asarray(y) 10 | 11 | return X, Y 12 | 13 | def generate_img_and_mask(height, width): 14 | shape = (height, width) 15 | 16 | triangle_location = get_random_location(*shape) 17 | circle_location1 = get_random_location(*shape, zoom=0.7) 18 | circle_location2 = get_random_location(*shape, zoom=0.5) 19 | mesh_location = get_random_location(*shape) 20 | square_location = get_random_location(*shape, zoom=0.8) 21 | plus_location = get_random_location(*shape, zoom=1.2) 22 | 23 | # Create input image 24 | arr = np.zeros(shape, dtype=bool) 25 | arr = add_triangle(arr, *triangle_location) 26 | arr = add_circle(arr, *circle_location1) 27 | arr = add_circle(arr, *circle_location2, fill=True) 28 | arr = add_mesh_square(arr, *mesh_location) 29 | arr = add_filled_square(arr, *square_location) 30 | arr = add_plus(arr, *plus_location) 31 | arr = np.reshape(arr, (1, height, width)).astype(np.float32) 32 | 33 | # Create target masks 34 | masks = np.asarray([ 35 | add_filled_square(np.zeros(shape, dtype=bool), *square_location), 36 | add_circle(np.zeros(shape, dtype=bool), *circle_location2, fill=True), 37 | add_triangle(np.zeros(shape, dtype=bool), *triangle_location), 38 | add_circle(np.zeros(shape, dtype=bool), *circle_location1), 39 | add_filled_square(np.zeros(shape, dtype=bool), *mesh_location), 40 | # add_mesh_square(np.zeros(shape, dtype=bool), *mesh_location), 41 | add_plus(np.zeros(shape, dtype=bool), *plus_location) 42 | ]).astype(np.float32) 43 | 44 | return arr, masks 45 | 46 | def add_square(arr, x, y, size): 47 | s = int(size / 2) 48 | arr[x-s,y-s:y+s] = True 49 | arr[x+s,y-s:y+s] = True 50 | arr[x-s:x+s,y-s] = True 51 | arr[x-s:x+s,y+s] = True 52 | 53 | return arr 54 | 55 | def add_filled_square(arr, x, y, size): 56 | s = int(size / 2) 57 | 58 | xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]] 59 | 60 | return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, yy > y - s, yy < y + s])) 61 | 62 | def logical_and(arrays): 63 | new_array = np.ones(arrays[0].shape, dtype=bool) 64 | for a in arrays: 65 | new_array = np.logical_and(new_array, a) 66 | 67 | return new_array 68 | 69 | def add_mesh_square(arr, x, y, size): 70 | s = int(size / 2) 71 | 72 | xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]] 73 | 74 | return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, xx % 2 == 1, yy > y - s, yy < y + s, yy % 2 == 1])) 75 | 76 | def add_triangle(arr, x, y, size): 77 | s = int(size / 2) 78 | 79 | triangle = np.tril(np.ones((size, size), dtype=bool)) 80 | 81 | arr[x-s:x-s+triangle.shape[0],y-s:y-s+triangle.shape[1]] = triangle 82 | 83 | return arr 84 | 85 | def add_circle(arr, x, y, size, fill=False): 86 | xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]] 87 | circle = np.sqrt((xx - x) ** 2 + (yy - y) ** 2) 88 | new_arr = np.logical_or(arr, np.logical_and(circle < size, circle >= size * 0.7 if not fill else True)) 89 | 90 | return new_arr 91 | 92 | def add_plus(arr, x, y, size): 93 | s = int(size / 2) 94 | arr[x-1:x+1,y-s:y+s] = True 95 | arr[x-s:x+s,y-1:y+1] = True 96 | 97 | return arr 98 | 99 | def get_random_location(width, height, zoom=1.0): 100 | x = int(width * random.uniform(0.1, 0.9)) 101 | y = int(height * random.uniform(0.1, 
0.9)) 102 | 103 | size = int(min(width, height) * random.uniform(0.06, 0.12) * zoom) 104 | 105 | return (x, y, size) --------------------------------------------------------------------------------
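
As a quick sanity check on the pieces above, the `UNet` in `pytorch_unet.py` preserves spatial resolution: its three `MaxPool2d(2)` stages are undone by three bilinear 2x upsamples, so a 192x192 input comes back at 192x192 with `n_class` channels, which matches the `(3, 6, 192, 192)` prediction shape printed in the notebook. The sketch below is a minimal smoke test, not part of the repository; it only assumes `pytorch_unet.py` is importable from the project root.

```python
import torch

from pytorch_unet import UNet

# Same number of output channels as the notebook (6 shape classes).
model = UNet(n_class=6)
model.eval()

# Random batch standing in for two normalized 192x192 RGB images.
dummy = torch.randn(2, 3, 192, 192)

with torch.no_grad():
    out = model(dummy)

# Three pool/upsample pairs cancel out, so the spatial size is unchanged
# and the channel count equals n_class.
assert out.shape == (2, 6, 192, 192)
print(out.shape)  # torch.Size([2, 6, 192, 192])
```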