├── LICENSE
├── README.md
└── code
    └── section1
        ├── images
        │   ├── base_conv.png
        │   ├── base_conv_skip.png
        │   ├── conv_functional2.png
        │   ├── down_conv.png
        │   ├── unet_paper.png
        │   ├── unet_skip_connection.png
        │   └── up_conv.png
        ├── utils.py
        ├── video1_1.ipynb
        ├── video1_2.ipynb
        ├── video1_3.ipynb
        ├── video1_4.ipynb
        ├── video1_5.ipynb
        └── video1_6.ipynb

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Packt
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mastering-PyTorch-for-Deep-Learning
2 | Mastering PyTorch for Deep Learning, Published by Packt
3 | 
--------------------------------------------------------------------------------
/code/section1/images/base_conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Mastering-PyTorch-for-Deep-Learning/013f4f03d3375d14f01779a9f6cc7be32200d003/code/section1/images/base_conv.png
--------------------------------------------------------------------------------
/code/section1/images/base_conv_skip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Mastering-PyTorch-for-Deep-Learning/013f4f03d3375d14f01779a9f6cc7be32200d003/code/section1/images/base_conv_skip.png
--------------------------------------------------------------------------------
/code/section1/images/conv_functional2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Mastering-PyTorch-for-Deep-Learning/013f4f03d3375d14f01779a9f6cc7be32200d003/code/section1/images/conv_functional2.png
--------------------------------------------------------------------------------
/code/section1/images/down_conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Mastering-PyTorch-for-Deep-Learning/013f4f03d3375d14f01779a9f6cc7be32200d003/code/section1/images/down_conv.png
--------------------------------------------------------------------------------
/code/section1/images/unet_paper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Mastering-PyTorch-for-Deep-Learning/013f4f03d3375d14f01779a9f6cc7be32200d003/code/section1/images/unet_paper.png
--------------------------------------------------------------------------------
/code/section1/images/unet_skip_connection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Mastering-PyTorch-for-Deep-Learning/013f4f03d3375d14f01779a9f6cc7be32200d003/code/section1/images/unet_skip_connection.png
--------------------------------------------------------------------------------
/code/section1/images/up_conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Mastering-PyTorch-for-Deep-Learning/013f4f03d3375d14f01779a9f6cc7be32200d003/code/section1/images/up_conv.png
--------------------------------------------------------------------------------
/code/section1/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Thu Mar 8 15:09:04 2018
5 | 
6 | Mastering PyTorch for Deep Learning
7 | 
8 | @author: pbialecki
9 | """
10 | 
11 | import re
12 | import os
13 | import errno
14 | 
15 | from sklearn.model_selection import train_test_split
16 | 
17 | SEED=2809
18 | 
19 | 
20 | def get_image_id(image_name):
21 |     '''
22 |     Returns the image id regardless of the smoothing factor.
23 |     '''
24 |     pattern = '_\w{1}(\d+_C\d+)_F\d+_(s\d+)'
25 |     match = re.findall(pattern, image_name)
26 |     image_id = '_'.join(match[0])
27 |     return image_id
28 | 
29 | 
30 | def get_image_name(image_path):
31 |     '''
32 |     Returns the image name given the path
33 |     '''
34 |     image_name = image_path.split('/')[-1].split('.')[0]
35 |     return image_name
36 | 
37 | 
38 | def get_number_of_cells(image_name):
39 |     '''
40 |     Returns the number of cells for the current image.
41 |     '''
42 |     pattern = '\w+_\w+\d+_C(\d+)_'
43 |     nb_cells = int(re.findall(pattern, image_name)[0])
44 |     return nb_cells
45 | 
46 | 
47 | def split_data(image_paths, target_paths):
48 |     '''
49 |     Splits the data into a training and a validation set.
50 |     '''
51 |     nb_cells = [get_number_of_cells(im_path) for im_path in image_paths]
52 |     im_path_train, im_path_val, tar_path_train, tar_path_val = train_test_split(
53 |         image_paths,
54 |         target_paths,
55 |         test_size=0.1,
56 |         random_state=SEED,
57 |         stratify=nb_cells)
58 | 
59 |     return im_path_train, im_path_val, tar_path_train, tar_path_val
60 | 
61 | 
62 | def download_data(root='./'):
63 |     '''
64 |     Downloads the BBBC005 dataset from:
65 |     https://data.broadinstitute.org/bbbc/BBBC005/
66 |     '''
67 |     from six.moves import urllib
68 |     import zipfile
69 | 
70 |     folder = os.path.expanduser('data')
71 |     data_url = 'https://data.broadinstitute.org/bbbc/BBBC005/BBBC005_v1_images.zip'
72 |     target_url = 'https://data.broadinstitute.org/bbbc/BBBC005/BBBC005_v1_ground_truth.zip'
73 | 
74 |     data_folder = data_url.split('/')[-1].replace('.zip', '')
75 |     target_folder = target_url.split('/')[-1].replace('.zip', '')
76 | 
77 |     if os.path.exists(os.path.join(root, folder)) and \
78 |             os.path.exists(os.path.join(root, folder, 'data_paths.txt')):
79 |         return
80 | 
81 |     # Download dataset if it doesn't exist already
82 |     try:
83 |         os.makedirs(os.path.join(root, folder))
84 |         os.makedirs(os.path.join(root, folder, data_folder))
85 |         os.makedirs(os.path.join(root, folder, target_folder))
86 |     except OSError as e:
87 |         if e.errno == errno.EEXIST:
88 |             pass
89 |         else:
90 |             raise
91 | 
92 |     print('Downloading ' + data_url)
93 |     data = urllib.request.urlopen(data_url)
94 |     filename = data_url.rpartition('/')[2]
95 |     file_path = os.path.join(root, folder, filename)
96 |     with open(file_path, 'wb') as f:
97 |         f.write(data.read())
98 |     with zipfile.ZipFile(file_path, 'r') as zip_f:
99 |         zip_f.extractall(os.path.join(root, folder))
100 |     os.unlink(file_path)
101 | 
102 |     print('Downloading ' + target_url)
103 |     data = urllib.request.urlopen(target_url)
104 |     filename = target_url.rpartition('/')[2]
105 |     file_path = os.path.join(root, folder, filename)
106 |     with open(file_path, 'wb') as f:
107 |         f.write(data.read())
108 |     with zipfile.ZipFile(file_path, 'r') as zip_f:
109 |         for name in zip_f.namelist():
110 |             if name.startswith('BBBC005_v1_ground_truth/'):
111 |                 zip_f.extract(name, os.path.join(root, folder))
112 |     os.unlink(file_path)
113 | 
--------------------------------------------------------------------------------
/code/section1/video1_1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Mastering PyTorch\n",
8 |     "\n",
9 |     "## Supervised learning\n",
10 |     "\n",
11 |     "### Powerful PyTorch\n",
12 |     "\n",
13 |     "#### Accompanying notebook to Video 1.1"
14 |    ]
15 |   },
16 |   {
17 |    "cell_type": "code",
18 |    "execution_count": null,
19 |    "metadata": {
20 |     "collapsed": true
21 |    },
22 |    "outputs": [],
23 |    "source": [
24 |     "# Import libs\n",
25 |     "from __future__ import print_function\n",
26 |     "\n",
27 |     "import torch\n",
28 |     "import torch.nn as nn\n",
29 |     "import torch.optim as optim\n",
30 |     "from torch.autograd import Variable\n",
31 |     "import torch.nn.functional as F\n",
32 |     "from torch.utils.data import Dataset, DataLoader\n",
33 |     "\n",
34 |     "import random\n",
35 |     "import time"
36 |    ]
37 |   },
38 |   {
39 |    "cell_type": "code",
40 |    "execution_count": null,
41 |    "metadata": {
42 |     "collapsed": true
43 |    },
44 |    "outputs": [],
45 |    "source": [
46 |     "# Setup globals\n",
47 |     "batch_size = 1\n",
48 |     "in_features = 10\n",
49 |     "hidden = 20\n",
50 |     "out_features = 1"
51 |    ]
52 |   },
53 |   {
54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "collapsed": false 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "# Sequential API example\n", 62 | "# Create model\n", 63 | "model = nn.Sequential(\n", 64 | " nn.Linear(in_features, hidden),\n", 65 | " nn.ReLU(),\n", 66 | " nn.Linear(hidden, out_features)\n", 67 | ")\n", 68 | "print(model)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "collapsed": false 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "# Create dummy input\n", 80 | "x = Variable(torch.randn(batch_size, in_features))\n", 81 | "# Run forward pass\n", 82 | "output = model(x)\n", 83 | "print(output)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": { 90 | "collapsed": false 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "# Functional API example\n", 95 | "# Create model\n", 96 | "class CustomNet(nn.Module):\n", 97 | " def __init__(self, in_features, hidden, out_features):\n", 98 | " \"\"\"\n", 99 | " Create three linear layers\n", 100 | " \"\"\"\n", 101 | " super(CustomNet, self).__init__()\n", 102 | " self.linear1 = nn.Linear(in_features, hidden)\n", 103 | " self.linear2 = nn.Linear(hidden, hidden)\n", 104 | " self.linear3 = nn.Linear(hidden, out_features)\n", 105 | "\n", 106 | " def forward(self, x):\n", 107 | " \"\"\"\n", 108 | " Draw a random number from [0, 10]. \n", 109 | " If it's 0, skip the second layer. Otherwise loop it!\n", 110 | " \"\"\"\n", 111 | " x = F.relu(self.linear1(x))\n", 112 | " while random.randint(0, 10) != 0: \n", 113 | " #while x.norm() > 2:\n", 114 | " print('2nd layer used')\n", 115 | " x = F.relu(self.linear2(x))\n", 116 | " x = self.linear3(x)\n", 117 | " return x\n", 118 | "\n", 119 | "custom_model = CustomNet(in_features, hidden, out_features)\n", 120 | "print(custom_model)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": { 127 | "collapsed": false 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "# Run forward pass with same dummy variable\n", 132 | "output = custom_model(x)\n", 133 | "print(output)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "collapsed": true 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "# ConvNet example" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "![ConvNet](images/conv_functional2.png)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": { 158 | "collapsed": true 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "# Debug example\n", 163 | "# Create Convnet\n", 164 | "class ConvNet(nn.Module):\n", 165 | " def __init__(self, in_channels, hidden, out_features):\n", 166 | " \"\"\"\n", 167 | " Create ConvNet with two parallel convolutions\n", 168 | " \"\"\"\n", 169 | " super(ConvNet, self).__init__()\n", 170 | " self.conv1_1 = nn.Conv2d(in_channels=in_channels,\n", 171 | " out_channels=10,\n", 172 | " kernel_size=3,\n", 173 | " padding=1)\n", 174 | " self.conv1_2 = nn.Conv2d(in_channels=in_channels,\n", 175 | " out_channels=10,\n", 176 | " kernel_size=3,\n", 177 | " padding=1)\n", 178 | " self.conv2 = nn.Conv2d(in_channels=20,\n", 179 | " out_channels=1,\n", 180 | " kernel_size=3,\n", 181 | " padding=1)\n", 182 | " self.linear1 = nn.Linear(hidden, out_features)\n", 183 | "\n", 184 | " def forward(self, x):\n", 185 | " \"\"\"\n", 186 | " Pass input 
through both ConvLayers and stack them afterwards\n", 187 | " \"\"\"\n", 188 | " x1 = F.relu(self.conv1_1(x))\n", 189 | " x2 = F.relu(self.conv1_2(x))\n", 190 | " x = torch.cat((x1, x2), dim=1)\n", 191 | " x = self.conv2(x)\n", 192 | " print('x size (after conv2): {}'.format(x.shape))\n", 193 | " x = x.view(x.size(0), -1)\n", 194 | " x = self.linear1(x)\n", 195 | " return x\n", 196 | " \n", 197 | "conv_model = ConvNet(in_channels=3, hidden=576, out_features=out_features)\n", 198 | "# Create dummy input\n", 199 | "x_conv = Variable(torch.randn(batch_size, 3, 24, 24))" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": { 206 | "collapsed": false 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "# Run forward pass\n", 211 | "output = conv_model(x_conv)\n", 212 | "print(output)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": { 219 | "collapsed": false 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "## Dataset / DataLoader example\n", 224 | "# Create a random Dataset\n", 225 | "class RandomDataset(Dataset):\n", 226 | " def __init__(self, nb_samples, consume_time=False):\n", 227 | " self.data = torch.randn(nb_samples, in_features)\n", 228 | " self.target = torch.randn(nb_samples, out_features)\n", 229 | " self.consume_time=consume_time\n", 230 | "\n", 231 | " def __getitem__(self, index):\n", 232 | " x = self.data[index]\n", 233 | " y = self.target[index]\n", 234 | "\n", 235 | " # Transform data\n", 236 | " x = x + torch.FloatTensor(x.shape).normal_() * 1e-2\n", 237 | " \n", 238 | " if self.consume_time:\n", 239 | " # Do some time consuming operation\n", 240 | " for i in xrange(5000000):\n", 241 | " j = i + 1\n", 242 | "\n", 243 | " return x, y\n", 244 | "\n", 245 | " def __len__(self):\n", 246 | " return len(self.data)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": { 253 | "collapsed": false 254 | }, 255 | "outputs": [], 256 | "source": [ 257 | "# Training loop\n", 258 | "criterion = nn.MSELoss()\n", 259 | "optimizer = optim.Adam(model.parameters(), lr=1e-2)\n", 260 | "def train(loader):\n", 261 | " for batch_idx, (data, target) in enumerate(loader):\n", 262 | " # Wrap data and target into a Variable\n", 263 | " data, target = Variable(data), Variable(target)\n", 264 | "\n", 265 | " # Clear gradients\n", 266 | " optimizer.zero_grad()\n", 267 | "\n", 268 | " # Forward pass\n", 269 | " output = model(data)\n", 270 | "\n", 271 | " # Calculate loss\n", 272 | " loss = criterion(output, target)\n", 273 | "\n", 274 | " # Backward pass\n", 275 | " loss.backward()\n", 276 | "\n", 277 | " # Weight update\n", 278 | " optimizer.step()\n", 279 | "\n", 280 | " print('Batch {}\\tLoss {}'.format(batch_idx, loss.data.numpy()[0]))" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": { 287 | "collapsed": true 288 | }, 289 | "outputs": [], 290 | "source": [ 291 | "# Create Dataset\n", 292 | "data = RandomDataset(nb_samples=30)\n", 293 | "# Create DataLoader\n", 294 | "loader = DataLoader(dataset=data,\n", 295 | " batch_size=batch_size,\n", 296 | " num_workers=0,\n", 297 | " shuffle=True)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": { 304 | "collapsed": false 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "# Start training\n", 309 | "t0 = time.time()\n", 310 | "train(loader)\n", 311 | "time_fast = time.time() - t0\n", 312 | 
"print('Training finished in {:.2f} seconds'.format(time_fast))" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": { 319 | "collapsed": false 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "# Create time consuming Dataset\n", 324 | "data_slow = RandomDataset(nb_samples=30, consume_time=True)\n", 325 | "loader_slow = DataLoader(dataset=data_slow,\n", 326 | " batch_size=batch_size,\n", 327 | " num_workers=0,\n", 328 | " shuffle=True)\n", 329 | "# Start training\n", 330 | "t0 = time.time()\n", 331 | "train(loader_slow)\n", 332 | "time_slow = time.time() - t0\n", 333 | "print('Training finished in {:.2f} seconds'.format(time_slow))" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": { 340 | "collapsed": false 341 | }, 342 | "outputs": [], 343 | "source": [ 344 | "loader_slow_multi_proc = DataLoader(dataset=data_slow,\n", 345 | " batch_size=batch_size,\n", 346 | " num_workers=4,\n", 347 | " shuffle=True)\n", 348 | "# Start training\n", 349 | "t0 = time.time()\n", 350 | "train(loader_slow_multi_proc)\n", 351 | "time_multi_proc = time.time() - t0\n", 352 | "print('Training finished in {:.2f} seconds'.format(time_multi_proc))" 353 | ] 354 | } 355 | ], 356 | "metadata": { 357 | "kernelspec": { 358 | "display_name": "Python 2", 359 | "language": "python", 360 | "name": "python2" 361 | }, 362 | "language_info": { 363 | "codemirror_mode": { 364 | "name": "ipython", 365 | "version": 2 366 | }, 367 | "file_extension": ".py", 368 | "mimetype": "text/x-python", 369 | "name": "python", 370 | "nbconvert_exporter": "python", 371 | "pygments_lexer": "ipython2", 372 | "version": "2.7.13" 373 | } 374 | }, 375 | "nbformat": 4, 376 | "nbformat_minor": 2 377 | } 378 | -------------------------------------------------------------------------------- /code/section1/video1_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Mastering PyTorch\n", 8 | "\n", 9 | "## Supervised learning\n", 10 | "\n", 11 | "### Build a UNet for segmenting cells\n", 12 | "\n", 13 | "#### Accompanying notebook to Video 1.2" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "This notebook will guide us through creating a UNet to segment cell data.\n", 21 | "\n", 22 | "The image set (BBBC005v1) was taken from the Broad Bioimage Benchmark Collection [Ljosa et al., Nature Methods, 2012]: https://data.broadinstitute.org/bbbc/BBBC005/\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "# Include libraries\n", 34 | "\n", 35 | "import numpy as np\n", 36 | "from PIL import Image\n", 37 | "\n", 38 | "import os\n", 39 | "\n", 40 | "import torch\n", 41 | "import torch.nn as nn\n", 42 | "import torch.optim as optim\n", 43 | "import torch.nn.functional as F\n", 44 | "from torch.utils.data import Dataset, DataLoader\n", 45 | "from torch.autograd import Variable\n", 46 | "\n", 47 | "from torchvision import transforms\n", 48 | "import torchvision.transforms.functional as TF\n", 49 | "\n", 50 | "from utils import get_image_name, get_number_of_cells, \\\n", 51 | " split_data, download_data, SEED\n", 52 | "\n", 53 | "from sklearn.model_selection import train_test_split\n", 54 | "\n", 55 | "import matplotlib.pyplot as plt\n", 56 | "%matplotlib inline" 57 | 
] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "collapsed": false 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "root = './'\n", 68 | "download_data(root=root)\n", 69 | "\n", 70 | "data_paths = os.path.join('./', 'data_paths.txt')\n", 71 | "if not os.path.exists(data_paths):\n", 72 | " !wget http://pbialecki.de/mastering_pytorch/data_paths.txt\n", 73 | "\n", 74 | "if not os.path.isfile(data_paths):\n", 75 | " print('data_paths.txt missing!')" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": { 82 | "collapsed": false 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "# Setup Globals\n", 87 | "use_cuda = torch.cuda.is_available()\n", 88 | "np.random.seed(SEED)\n", 89 | "torch.manual_seed(SEED)\n", 90 | "if use_cuda:\n", 91 | " torch.cuda.manual_seed(SEED)\n", 92 | " print('Using: {}'.format(torch.cuda.get_device_name(0)))\n", 93 | "print_steps = 10" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": { 100 | "collapsed": true 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "# Utility functions\n", 105 | "def weights_init(m):\n", 106 | " '''\n", 107 | " Initialize the weights of each Conv2d layer using xavier_uniform\n", 108 | " (\"Understanding the difficulty of training deep feedforward\n", 109 | " neural networks\" - Glorot, X. & Bengio, Y. (2010))\n", 110 | " '''\n", 111 | " if isinstance(m, nn.Conv2d):\n", 112 | " nn.init.xavier_uniform(m.weight.data)\n", 113 | " m.bias.data.zero_()\n", 114 | " if isinstance(m, nn.ConvTranspose2d):\n", 115 | " nn.init.xavier_uniform(m.weight.data)\n", 116 | " m.bias.data.zero_()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": { 123 | "collapsed": true 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "class CellDataset(Dataset):\n", 128 | " def __init__(self, image_paths, target_paths, size):\n", 129 | " self.image_paths = image_paths\n", 130 | " self.target_paths = target_paths\n", 131 | " self.resize_image = transforms.Resize(\n", 132 | " size=size, interpolation=Image.BILINEAR)\n", 133 | " self.resize_mask = transforms.Resize(\n", 134 | " size=size, interpolation=Image.NEAREST)\n", 135 | "\n", 136 | " def transform(self, image, mask):\n", 137 | " # Resize\n", 138 | " image = self.resize_image(image)\n", 139 | " mask = self.resize_mask(mask)\n", 140 | " # Transform to tensor\n", 141 | " image = TF.to_tensor(image)\n", 142 | " mask = TF.to_tensor(mask)\n", 143 | " return image, mask\n", 144 | "\n", 145 | " def __getitem__(self, index):\n", 146 | " image = Image.open(self.image_paths[index])\n", 147 | " mask = Image.open(self.target_paths[index])\n", 148 | " x, y = self.transform(image, mask)\n", 149 | " return x, y\n", 150 | "\n", 151 | " def __len__(self):\n", 152 | " return len(self.image_paths)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": { 159 | "collapsed": true 160 | }, 161 | "outputs": [], 162 | "source": [ 163 | "def get_random_sample(dataset):\n", 164 | " '''\n", 165 | " Get a random sample from the specified dataset.\n", 166 | " '''\n", 167 | " data, target = dataset[int(np.random.choice(len(dataset), 1))]\n", 168 | " data.unsqueeze_(0)\n", 169 | " target.unsqueeze_(0)\n", 170 | " if use_cuda:\n", 171 | " data = data.cuda()\n", 172 | " target = target.cuda()\n", 173 | " data = Variable(data)\n", 174 | " target = Variable(target)\n", 175 | " return data, target" 176 | ] 177 | }, 178 | 
{ 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "### Overview of UNet architecture\n", 183 | "\n", 184 | "![unet](./images/unet_paper.png)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "### BaseConv\n", 192 | "\n", 193 | "* Consists of 2 Conv layers with ReLU\n", 194 | "\n", 195 | "![base_conv](./images/base_conv.png)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": { 202 | "collapsed": true 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "class BaseConv(nn.Module):\n", 207 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 208 | " stride):\n", 209 | " super(BaseConv, self).__init__()\n", 210 | "\n", 211 | " self.act = nn.ReLU()\n", 212 | "\n", 213 | " self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding,\n", 214 | " stride)\n", 215 | "\n", 216 | " self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,\n", 217 | " padding, stride)\n", 218 | "\n", 219 | " def forward(self, x):\n", 220 | " x = self.act(self.conv1(x))\n", 221 | " x = self.act(self.conv2(x))\n", 222 | " return x" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "### DownConv\n", 230 | "* Consists of MaxPool layer and a BaseConv block\n", 231 | "\n", 232 | "![down_conv](images/down_conv.png)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": { 239 | "collapsed": true 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "class DownConv(nn.Module):\n", 244 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 245 | " stride):\n", 246 | " super(DownConv, self).__init__()\n", 247 | "\n", 248 | " self.pool1 = nn.MaxPool2d(kernel_size=2)\n", 249 | " self.conv_block = BaseConv(in_channels, out_channels, kernel_size,\n", 250 | " padding, stride)\n", 251 | "\n", 252 | " def forward(self, x):\n", 253 | " x = self.pool1(x)\n", 254 | " x = self.conv_block(x)\n", 255 | " return x" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "### UpConv\n", 263 | "* Consists of ConvTranspose layer (for upsampling) and a BaseConv block\n", 264 | "\n", 265 | "![up_conv](images/up_conv.png)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": { 272 | "collapsed": true 273 | }, 274 | "outputs": [], 275 | "source": [ 276 | "class UpConv(nn.Module):\n", 277 | " def __init__(self, in_channels, in_channels_skip, out_channels,\n", 278 | " kernel_size, padding, stride):\n", 279 | " super(UpConv, self).__init__()\n", 280 | "\n", 281 | " self.conv_trans1 = nn.ConvTranspose2d(\n", 282 | " in_channels, in_channels, kernel_size=2, padding=0, stride=2)\n", 283 | " self.conv_block = BaseConv(\n", 284 | " in_channels=in_channels + in_channels_skip,\n", 285 | " out_channels=out_channels,\n", 286 | " kernel_size=kernel_size,\n", 287 | " padding=padding,\n", 288 | " stride=stride)\n", 289 | "\n", 290 | " def forward(self, x, x_skip):\n", 291 | " x = self.conv_trans1(x)\n", 292 | " x = torch.cat((x, x_skip), dim=1)\n", 293 | " x = self.conv_block(x)\n", 294 | " return x" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "### Create a modified UNet\n", 302 | "![image](images/unet_paper.png)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": { 309 | 
"collapsed": true 310 | }, 311 | "outputs": [], 312 | "source": [ 313 | "class UNet(nn.Module):\n", 314 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 315 | " stride):\n", 316 | " super(UNet, self).__init__()\n", 317 | "\n", 318 | " self.init_conv = BaseConv(in_channels, out_channels, kernel_size,\n", 319 | " padding, stride)\n", 320 | "\n", 321 | " self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size,\n", 322 | " padding, stride)\n", 323 | "\n", 324 | " self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size,\n", 325 | " padding, stride)\n", 326 | "\n", 327 | " self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size,\n", 328 | " padding, stride)\n", 329 | "\n", 330 | " self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels,\n", 331 | " kernel_size, padding, stride)\n", 332 | "\n", 333 | " self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels,\n", 334 | " kernel_size, padding, stride)\n", 335 | "\n", 336 | " self.up1 = UpConv(2 * out_channels, out_channels, out_channels,\n", 337 | " kernel_size, padding, stride)\n", 338 | "\n", 339 | " self.out = nn.Conv2d(out_channels, 1, kernel_size, padding, stride)\n", 340 | "\n", 341 | " def forward(self, x):\n", 342 | " # Encoder\n", 343 | " x = self.init_conv(x)\n", 344 | " x1 = self.down1(x)\n", 345 | " x2 = self.down2(x1)\n", 346 | " x3 = self.down3(x2)\n", 347 | " # Decoder\n", 348 | " x_up = self.up3(x3, x2)\n", 349 | " x_up = self.up2(x_up, x1)\n", 350 | " x_up = self.up1(x_up, x)\n", 351 | " x_out = F.sigmoid(self.out(x_up))\n", 352 | " return x_out" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "metadata": { 359 | "collapsed": true 360 | }, 361 | "outputs": [], 362 | "source": [ 363 | "def train(epoch):\n", 364 | " '''\n", 365 | " Main training loop\n", 366 | " '''\n", 367 | " # Set model to train mode\n", 368 | " model.train()\n", 369 | " # Iterate training set\n", 370 | " for batch_idx, (data, mask) in enumerate(train_loader):\n", 371 | " if use_cuda:\n", 372 | " data = data.cuda()\n", 373 | " mask = mask.cuda()\n", 374 | " data = Variable(data)\n", 375 | " mask = Variable(mask.squeeze())\n", 376 | " optimizer.zero_grad()\n", 377 | " output = model(data)\n", 378 | " loss = criterion(output.squeeze(), mask)\n", 379 | " loss.backward()\n", 380 | " optimizer.step()\n", 381 | "\n", 382 | " if batch_idx % print_steps == 0:\n", 383 | " loss_data = loss.data[0]\n", 384 | " train_losses.append(loss_data)\n", 385 | " print(\n", 386 | " 'Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.\n", 387 | " format(epoch, batch_idx * len(data), len(train_loader.dataset),\n", 388 | " 100. 
* batch_idx / len(train_loader), loss_data))" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": { 395 | "collapsed": true 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "def validate():\n", 400 | " '''\n", 401 | " Validation loop\n", 402 | " '''\n", 403 | " # Set model to eval mode\n", 404 | " model.eval()\n", 405 | " # Setup val_loss\n", 406 | " val_loss = 0\n", 407 | " # Disable gradients (to save memory)\n", 408 | " with torch.no_grad():\n", 409 | " # Iterate validation set\n", 410 | " for data, mask in val_loader:\n", 411 | " if use_cuda:\n", 412 | " data = data.cuda()\n", 413 | " mask = mask.cuda()\n", 414 | " data = Variable(data)\n", 415 | " mask = Variable(mask.squeeze())\n", 416 | " output = model(data)\n", 417 | " val_loss += F.binary_cross_entropy(output.squeeze(), mask).data[0]\n", 418 | " # Calculate mean of validation loss\n", 419 | " val_loss /= len(val_loader)\n", 420 | " val_losses.append(val_loss)\n", 421 | " print('Validation loss: {:.4f}'.format(val_loss))" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": { 428 | "collapsed": true 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "# Get train data folders and split to training / validation set\n", 433 | "with open(data_paths, 'r') as f:\n", 434 | " data_paths_list = f.readlines()\n", 435 | "image_paths = [line.split(',')[0].strip() for line in data_paths_list]\n", 436 | "target_paths = [line.split(',')[1].strip() for line in data_paths_list]\n", 437 | "\n", 438 | "# Split data into train/validation datasets\n", 439 | "im_path_train, im_path_val, tar_path_train, tar_path_val = split_data(\n", 440 | " image_paths, target_paths)" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "metadata": { 447 | "collapsed": true 448 | }, 449 | "outputs": [], 450 | "source": [ 451 | "# Create datasets\n", 452 | "train_dataset = CellDataset(\n", 453 | " image_paths=im_path_train,\n", 454 | " target_paths=tar_path_train,\n", 455 | " size=(96, 96)\n", 456 | ")\n", 457 | "val_dataset = CellDataset(\n", 458 | " image_paths=im_path_val,\n", 459 | " target_paths=tar_path_val,\n", 460 | " size=(96, 96)\n", 461 | ")\n", 462 | "\n", 463 | "# Wrap in DataLoader\n", 464 | "train_loader = DataLoader(\n", 465 | " dataset=train_dataset,\n", 466 | " batch_size=32,\n", 467 | " num_workers=6,\n", 468 | " shuffle=True\n", 469 | ")\n", 470 | "val_loader = DataLoader(\n", 471 | " dataset=val_dataset,\n", 472 | " batch_size=64,\n", 473 | " num_workers=6,\n", 474 | " shuffle=True\n", 475 | ")" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "metadata": { 482 | "collapsed": true 483 | }, 484 | "outputs": [], 485 | "source": [ 486 | "# Creae model\n", 487 | "model = UNet(\n", 488 | " in_channels=1, out_channels=32, kernel_size=3, padding=1, stride=1)\n", 489 | "# Initialize weights\n", 490 | "model.apply(weights_init)\n", 491 | "# Push to GPU, if available\n", 492 | "if use_cuda:\n", 493 | " model.cuda()" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": null, 499 | "metadata": { 500 | "collapsed": true 501 | }, 502 | "outputs": [], 503 | "source": [ 504 | "# Create optimizer\n", 505 | "optimizer = optim.SGD(model.parameters(), lr=1e-3)\n", 506 | "# Create criterion\n", 507 | "criterion = nn.BCELoss()" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "metadata": { 514 | "collapsed": false, 515 | 
"scrolled": true 516 | }, 517 | "outputs": [], 518 | "source": [ 519 | "# Start training\n", 520 | "train_losses, val_losses = [], []\n", 521 | "epochs = 30\n", 522 | "for epoch in range(1, epochs):\n", 523 | " train(epoch)\n", 524 | " validate()" 525 | ] 526 | }, 527 | { 528 | "cell_type": "markdown", 529 | "metadata": {}, 530 | "source": [ 531 | "Let's visualize the loss curves and some validation images!" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": null, 537 | "metadata": { 538 | "collapsed": false, 539 | "scrolled": false 540 | }, 541 | "outputs": [], 542 | "source": [ 543 | "train_losses = np.array(train_losses)\n", 544 | "val_losses = np.array(val_losses)\n", 545 | "\n", 546 | "val_indices = np.linspace(0, (epochs-1)*len(train_loader)/print_steps, epochs-1)\n", 547 | "\n", 548 | "plt.plot(train_losses, '-', label='train loss')\n", 549 | "plt.plot(val_indices, val_losses, '--', label='val loss')\n", 550 | "plt.yscale(\"log\", nonposy='clip')\n", 551 | "plt.xlabel('Iterations')\n", 552 | "plt.ylabel('BCELoss')\n", 553 | "plt.legend()\n", 554 | "plt.show()" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": null, 560 | "metadata": { 561 | "collapsed": false 562 | }, 563 | "outputs": [], 564 | "source": [ 565 | "val_data, val_target = get_random_sample(val_dataset)\n", 566 | "\n", 567 | "val_pred = model(val_data)\n", 568 | "val_pred_arr = val_pred.data.cpu().squeeze_().numpy()\n", 569 | "val_target_arr = val_target.data.cpu().squeeze_().numpy()\n", 570 | "\n", 571 | "fig, (ax1, ax2, ax3) = plt.subplots(1, 3)\n", 572 | "ax1.imshow(val_pred_arr)\n", 573 | "ax1.set_title('Prediction')\n", 574 | "ax2.imshow(val_target_arr)\n", 575 | "ax2.set_title('Target')\n", 576 | "ax3.imshow(np.abs(val_pred_arr - val_target_arr))\n", 577 | "ax3.set_title('Absolute error')" 578 | ] 579 | } 580 | ], 581 | "metadata": { 582 | "kernelspec": { 583 | "display_name": "Python 2", 584 | "language": "python", 585 | "name": "python2" 586 | }, 587 | "language_info": { 588 | "codemirror_mode": { 589 | "name": "ipython", 590 | "version": 2 591 | }, 592 | "file_extension": ".py", 593 | "mimetype": "text/x-python", 594 | "name": "python", 595 | "nbconvert_exporter": "python", 596 | "pygments_lexer": "ipython2", 597 | "version": "2.7.13" 598 | } 599 | }, 600 | "nbformat": 4, 601 | "nbformat_minor": 2 602 | } 603 | -------------------------------------------------------------------------------- /code/section1/video1_3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Mastering PyTorch\n", 8 | "\n", 9 | "## Supervised learning\n", 10 | "\n", 11 | "### Extend the UNet with skip connections and data augmentation\n", 12 | "\n", 13 | "#### Accompanying notebook to Video 1.3" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "collapsed": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "# Include libraries\n", 25 | "\n", 26 | "import numpy as np\n", 27 | "from PIL import Image\n", 28 | "\n", 29 | "import os\n", 30 | "import random\n", 31 | "\n", 32 | "import torch\n", 33 | "import torch.nn as nn\n", 34 | "import torch.optim as optim\n", 35 | "import torch.nn.functional as F\n", 36 | "from torch.utils.data import Dataset, DataLoader\n", 37 | "from torch.autograd import Variable\n", 38 | "\n", 39 | "from torchvision import transforms\n", 40 | "import 
torchvision.transforms.functional as TF\n", 41 | "\n", 42 | "from utils import get_image_name, get_number_of_cells, \\\n", 43 | " split_data, download_data, SEED\n", 44 | "\n", 45 | "from sklearn.model_selection import train_test_split\n", 46 | "\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "%matplotlib inline" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "collapsed": true 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "root = './'\n", 60 | "download_data(root=root)\n", 61 | "\n", 62 | "data_paths = os.path.join('./', 'data_paths.txt')\n", 63 | "if not os.path.exists(data_paths):\n", 64 | " !wget http://pbialecki.de/mastering_pytorch/data_paths.txt\n", 65 | "\n", 66 | "if not os.path.isfile(data_paths):\n", 67 | " print('data_paths.txt missing!')" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": { 74 | "collapsed": false 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "# Setup Globals\n", 79 | "use_cuda = torch.cuda.is_available()\n", 80 | "np.random.seed(SEED)\n", 81 | "torch.manual_seed(SEED)\n", 82 | "if use_cuda:\n", 83 | " torch.cuda.manual_seed(SEED)\n", 84 | " print('Using: {}'.format(torch.cuda.get_device_name(0)))\n", 85 | "print_steps = 10" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "collapsed": true 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "# Utility functions\n", 97 | "def weights_init(m):\n", 98 | " '''\n", 99 | " Initialize the weights of each Conv2d layer using xavier_uniform\n", 100 | " (\"Understanding the difficulty of training deep feedforward\n", 101 | " neural networks\" - Glorot, X. & Bengio, Y. (2010))\n", 102 | " '''\n", 103 | " if isinstance(m, nn.Conv2d):\n", 104 | " nn.init.xavier_uniform(m.weight.data)\n", 105 | " m.bias.data.zero_()\n", 106 | " if isinstance(m, nn.ConvTranspose2d):\n", 107 | " nn.init.xavier_uniform(m.weight.data)\n", 108 | " m.bias.data.zero_()" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": { 115 | "collapsed": true 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "class CellDataset(Dataset):\n", 120 | " def __init__(self, image_paths, target_paths, size, train=False):\n", 121 | " self.image_paths = image_paths\n", 122 | " self.target_paths = target_paths\n", 123 | " self.size = size\n", 124 | " resize_size = [s+10 for s in self.size]\n", 125 | " self.resize_image = transforms.Resize(\n", 126 | " size=resize_size, interpolation=Image.BILINEAR)\n", 127 | " self.resize_mask = transforms.Resize(\n", 128 | " size=resize_size, interpolation=Image.NEAREST)\n", 129 | " self.train = train\n", 130 | " \n", 131 | " def transform(self, image, mask):\n", 132 | " # Resize\n", 133 | " image = self.resize_image(image)\n", 134 | " mask = self.resize_mask(mask)\n", 135 | " \n", 136 | " # Perform data augmentation\n", 137 | " if self.train: \n", 138 | " # Random cropping\n", 139 | " i, j, h, w = transforms.RandomCrop.get_params(\n", 140 | " image, output_size=self.size)\n", 141 | " image = TF.crop(image, i, j, h, w)\n", 142 | " mask = TF.crop(mask, i, j, h, w)\n", 143 | " \n", 144 | " # Random horizontal flipping\n", 145 | " if random.random() > 0.5:\n", 146 | " image = TF.hflip(image)\n", 147 | " mask = TF.hflip(mask)\n", 148 | " \n", 149 | " # Random vertical flipping\n", 150 | " if random.random() > 0.5:\n", 151 | " image = TF.vflip(image)\n", 152 | " mask = TF.vflip(mask)\n", 153 | " else:\n", 154 | " center_crop = 
transforms.CenterCrop(self.size)\n", 155 | " image = center_crop(image)\n", 156 | " mask = center_crop(mask)\n", 157 | " \n", 158 | " # Transform to tensor\n", 159 | " image = TF.to_tensor(image)\n", 160 | " mask = TF.to_tensor(mask)\n", 161 | " return image, mask\n", 162 | "\n", 163 | " def __getitem__(self, index):\n", 164 | " image = Image.open(self.image_paths[index])\n", 165 | " mask = Image.open(self.target_paths[index])\n", 166 | " x, y = self.transform(image, mask)\n", 167 | " return x, y\n", 168 | "\n", 169 | " def __len__(self):\n", 170 | " return len(self.image_paths)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": { 177 | "collapsed": true 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "def get_random_sample(dataset):\n", 182 | " '''\n", 183 | " Get a random sample from the specified dataset.\n", 184 | " '''\n", 185 | " data, target = dataset[int(np.random.choice(len(dataset), 1))]\n", 186 | " data.unsqueeze_(0)\n", 187 | " target.unsqueeze_(0)\n", 188 | " if use_cuda:\n", 189 | " data = data.cuda()\n", 190 | " target = target.cuda()\n", 191 | " data = Variable(data)\n", 192 | " target = Variable(target)\n", 193 | " return data, target" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "### Add Residual to BaseConv \n", 201 | "* Add an additional Conv layer, if channels do not match\n", 202 | "\n", 203 | "![base_conv_skip](images/base_conv_skip.png)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": { 210 | "collapsed": true 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "class BaseConv(nn.Module):\n", 215 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 216 | " stride):\n", 217 | " super(BaseConv, self).__init__()\n", 218 | "\n", 219 | " self.act = nn.ReLU()\n", 220 | "\n", 221 | " self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding,\n", 222 | " stride)\n", 223 | "\n", 224 | " self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,\n", 225 | " padding, stride)\n", 226 | " \n", 227 | " self.downsample = None\n", 228 | " if in_channels != out_channels:\n", 229 | " self.downsample = nn.Sequential(\n", 230 | " nn.Conv2d(\n", 231 | " in_channels, out_channels, kernel_size, padding, stride)\n", 232 | " )\n", 233 | "\n", 234 | " def forward(self, x):\n", 235 | " residual = x\n", 236 | " out = self.act(self.conv1(x))\n", 237 | " out = self.conv2(out)\n", 238 | " \n", 239 | " if self.downsample:\n", 240 | " residual = self.downsample(x)\n", 241 | " out += residual\n", 242 | " out = self.act(out)\n", 243 | " return out" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "collapsed": true 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "class DownConv(nn.Module):\n", 255 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 256 | " stride):\n", 257 | " super(DownConv, self).__init__()\n", 258 | "\n", 259 | " self.pool1 = nn.MaxPool2d(kernel_size=2)\n", 260 | " self.conv_block = BaseConv(in_channels, out_channels, kernel_size,\n", 261 | " padding, stride)\n", 262 | "\n", 263 | " def forward(self, x):\n", 264 | " x = self.pool1(x)\n", 265 | " x = self.conv_block(x)\n", 266 | " return x" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "metadata": { 273 | "collapsed": true 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "class 
UpConv(nn.Module):\n", 278 | " def __init__(self, in_channels, in_channels_skip, out_channels,\n", 279 | " kernel_size, padding, stride):\n", 280 | " super(UpConv, self).__init__()\n", 281 | "\n", 282 | " self.conv_trans1 = nn.ConvTranspose2d(\n", 283 | " in_channels, in_channels, kernel_size=2, padding=0, stride=2)\n", 284 | " self.conv_block = BaseConv(\n", 285 | " in_channels=in_channels + in_channels_skip,\n", 286 | " out_channels=out_channels,\n", 287 | " kernel_size=kernel_size,\n", 288 | " padding=padding,\n", 289 | " stride=stride)\n", 290 | "\n", 291 | " def forward(self, x, x_skip):\n", 292 | " x = self.conv_trans1(x)\n", 293 | " x = torch.cat((x, x_skip), dim=1)\n", 294 | " x = self.conv_block(x)\n", 295 | " return x" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": { 302 | "collapsed": true 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "class ResUNet(nn.Module):\n", 307 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 308 | " stride):\n", 309 | " super(ResUNet, self).__init__()\n", 310 | "\n", 311 | " self.init_conv = BaseConv(in_channels, out_channels, kernel_size, padding, stride)\n", 312 | "\n", 313 | " self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size,\n", 314 | " padding, stride)\n", 315 | "\n", 316 | " self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size,\n", 317 | " padding, stride)\n", 318 | "\n", 319 | " self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size,\n", 320 | " padding, stride)\n", 321 | "\n", 322 | " self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels,\n", 323 | " kernel_size, padding, stride)\n", 324 | "\n", 325 | " self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels,\n", 326 | " kernel_size, padding, stride)\n", 327 | "\n", 328 | " self.up1 = UpConv(2 * out_channels, out_channels, out_channels,\n", 329 | " kernel_size, padding, stride)\n", 330 | "\n", 331 | " self.out = nn.Conv2d(out_channels, 1, kernel_size, padding, stride)\n", 332 | "\n", 333 | " def forward(self, x):\n", 334 | " # Encoder\n", 335 | " x = self.init_conv(x)\n", 336 | " x1 = self.down1(x)\n", 337 | " x2 = self.down2(x1)\n", 338 | " x3 = self.down3(x2)\n", 339 | " # Decoder\n", 340 | " x_up = self.up3(x3, x2)\n", 341 | " x_up = self.up2(x_up, x1)\n", 342 | " x_up = self.up1(x_up, x)\n", 343 | " x_out = F.sigmoid(self.out(x_up))\n", 344 | " return x_out" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": { 351 | "collapsed": true 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "def train(epoch):\n", 356 | " '''\n", 357 | " Main training loop\n", 358 | " '''\n", 359 | " # Set model to train mode\n", 360 | " model.train()\n", 361 | " # Iterate training set\n", 362 | " for batch_idx, (data, mask) in enumerate(train_loader):\n", 363 | " if use_cuda:\n", 364 | " data = data.cuda()\n", 365 | " mask = mask.cuda()\n", 366 | " data = Variable(data)\n", 367 | " mask = Variable(mask.squeeze())\n", 368 | " optimizer.zero_grad()\n", 369 | " output = model(data)\n", 370 | " loss = criterion(output.squeeze(), mask)\n", 371 | " loss.backward()\n", 372 | " optimizer.step()\n", 373 | " \n", 374 | " if batch_idx % print_steps == 0:\n", 375 | " loss_data = loss.data[0]\n", 376 | " train_losses.append(loss_data)\n", 377 | " print(\n", 378 | " 'Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.\n", 379 | " format(epoch, batch_idx * len(data), len(train_loader.dataset),\n", 380 | " 100. 
* batch_idx / len(train_loader), loss_data))" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": { 387 | "collapsed": true 388 | }, 389 | "outputs": [], 390 | "source": [ 391 | "def validate():\n", 392 | " '''\n", 393 | " Validation loop\n", 394 | " '''\n", 395 | " # Set model to eval mode\n", 396 | " model.eval()\n", 397 | " # Setup val_loss\n", 398 | " val_loss = 0\n", 399 | " # Disable gradients (to save memory)\n", 400 | " with torch.no_grad():\n", 401 | " # Iterate validation set\n", 402 | " for data, mask in val_loader:\n", 403 | " if use_cuda:\n", 404 | " data = data.cuda()\n", 405 | " mask = mask.cuda()\n", 406 | " data = Variable(data)\n", 407 | " mask = Variable(mask.squeeze())\n", 408 | " output = model(data)\n", 409 | " val_loss += F.binary_cross_entropy(output.squeeze(), mask).data[0]\n", 410 | " # Calculate mean of validation loss\n", 411 | " val_loss /= len(val_loader)\n", 412 | " val_losses.append(val_loss)\n", 413 | " print('Validation loss: {:.4f}'.format(val_loss))" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": { 420 | "collapsed": true 421 | }, 422 | "outputs": [], 423 | "source": [ 424 | "# Get train data folders and split to training / validation set\n", 425 | "with open(data_paths, 'r') as f:\n", 426 | " data_paths = f.readlines()\n", 427 | "image_paths = [line.split(',')[0].strip() for line in data_paths]\n", 428 | "target_paths = [line.split(',')[1].strip() for line in data_paths]\n", 429 | "\n", 430 | "# Split data into train/validation datasets\n", 431 | "im_path_train, im_path_val, tar_path_train, tar_path_val = split_data(\n", 432 | " image_paths, target_paths)" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "metadata": { 439 | "collapsed": true 440 | }, 441 | "outputs": [], 442 | "source": [ 443 | "# Create datasets\n", 444 | "train_dataset = CellDataset(\n", 445 | " image_paths=im_path_train,\n", 446 | " target_paths=tar_path_train,\n", 447 | " size=(96, 96),\n", 448 | " train=True\n", 449 | ")\n", 450 | "val_dataset = CellDataset(\n", 451 | " image_paths=im_path_val,\n", 452 | " target_paths=tar_path_val,\n", 453 | " size=(96, 96),\n", 454 | " train=False\n", 455 | ")\n", 456 | "\n", 457 | "# Wrap in DataLoader\n", 458 | "train_loader = DataLoader(\n", 459 | " dataset=train_dataset,\n", 460 | " batch_size=32,\n", 461 | " num_workers=12,\n", 462 | " shuffle=True\n", 463 | ")\n", 464 | "val_loader = DataLoader(\n", 465 | " dataset=val_dataset,\n", 466 | " batch_size=64,\n", 467 | " num_workers=12,\n", 468 | " shuffle=True\n", 469 | ")" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": { 476 | "collapsed": true 477 | }, 478 | "outputs": [], 479 | "source": [ 480 | "# Creae model\n", 481 | "model = ResUNet(\n", 482 | " in_channels=1, out_channels=32, kernel_size=3, padding=1, stride=1)\n", 483 | "# Initialize weights\n", 484 | "model.apply(weights_init)\n", 485 | "# Push to GPU, if available\n", 486 | "if use_cuda:\n", 487 | " model.cuda()" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": null, 493 | "metadata": { 494 | "collapsed": true 495 | }, 496 | "outputs": [], 497 | "source": [ 498 | "# Create optimizer and scheduler\n", 499 | "optimizer = optim.SGD(model.parameters(), lr=1e-3)\n", 500 | "# Create criterion\n", 501 | "criterion = nn.BCELoss()" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | 
"metadata": { 508 | "collapsed": false, 509 | "scrolled": true 510 | }, 511 | "outputs": [], 512 | "source": [ 513 | "# Start training\n", 514 | "train_losses, val_losses = [], []\n", 515 | "epochs = 30\n", 516 | "for epoch in range(1, epochs):\n", 517 | " train(epoch)\n", 518 | " validate()" 519 | ] 520 | }, 521 | { 522 | "cell_type": "markdown", 523 | "metadata": {}, 524 | "source": [ 525 | "Let's visualize the loss curves and some validation images!" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": null, 531 | "metadata": { 532 | "collapsed": false 533 | }, 534 | "outputs": [], 535 | "source": [ 536 | "train_losses = np.array(train_losses)\n", 537 | "val_losses = np.array(val_losses)\n", 538 | "\n", 539 | "val_indices = np.linspace(0, (epochs-1)*len(train_loader)/print_steps, epochs-1)\n", 540 | "\n", 541 | "plt.plot(train_losses, '-', label='train loss')\n", 542 | "plt.plot(val_indices, val_losses, '--', label='val loss')\n", 543 | "plt.yscale(\"log\", nonposy='clip')\n", 544 | "plt.xlabel('Iterations')\n", 545 | "plt.ylabel('BCELoss')\n", 546 | "plt.legend()\n", 547 | "plt.show()" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": null, 553 | "metadata": { 554 | "collapsed": false 555 | }, 556 | "outputs": [], 557 | "source": [ 558 | "val_data, val_target = get_random_sample(val_dataset)\n", 559 | "\n", 560 | "val_pred = model(val_data)\n", 561 | "val_pred_arr = val_pred.data.cpu().squeeze_().numpy()\n", 562 | "val_target_arr = val_target.data.cpu().squeeze_().numpy()\n", 563 | "\n", 564 | "fig, (ax1, ax2, ax3) = plt.subplots(1, 3)\n", 565 | "ax1.imshow(val_pred_arr)\n", 566 | "ax1.set_title('Prediction')\n", 567 | "ax2.imshow(val_target_arr)\n", 568 | "ax2.set_title('Target')\n", 569 | "ax3.imshow(np.abs(val_pred_arr - val_target_arr))\n", 570 | "ax3.set_title('Absolute error')" 571 | ] 572 | } 573 | ], 574 | "metadata": { 575 | "kernelspec": { 576 | "display_name": "Python 2", 577 | "language": "python", 578 | "name": "python2" 579 | }, 580 | "language_info": { 581 | "codemirror_mode": { 582 | "name": "ipython", 583 | "version": 2 584 | }, 585 | "file_extension": ".py", 586 | "mimetype": "text/x-python", 587 | "name": "python", 588 | "nbconvert_exporter": "python", 589 | "pygments_lexer": "ipython2", 590 | "version": "2.7.13" 591 | } 592 | }, 593 | "nbformat": 4, 594 | "nbformat_minor": 2 595 | } 596 | -------------------------------------------------------------------------------- /code/section1/video1_4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Mastering PyTorch\n", 8 | "\n", 9 | "## Supervised learning\n", 10 | "\n", 11 | "### Tune the training with a cusom loss\n", 12 | "\n", 13 | "#### Accompanying notebook to Video 1.4" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "collapsed": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "# Include libraries\n", 25 | "\n", 26 | "import numpy as np\n", 27 | "from PIL import Image\n", 28 | "\n", 29 | "import os\n", 30 | "import random\n", 31 | "\n", 32 | "import torch\n", 33 | "import torch.nn as nn\n", 34 | "import torch.optim as optim\n", 35 | "import torch.nn.functional as F\n", 36 | "from torch.utils.data import Dataset, DataLoader\n", 37 | "from torch.autograd import Variable\n", 38 | "\n", 39 | "from torchvision import transforms\n", 40 | "import 
torchvision.transforms.functional as TF\n", 41 | "\n", 42 | "from utils import get_image_name, get_number_of_cells, \\\n", 43 | " split_data, download_data, SEED\n", 44 | "\n", 45 | "from sklearn.model_selection import train_test_split\n", 46 | "\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "%matplotlib inline" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "collapsed": true 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "root = './'\n", 60 | "download_data(root=root)\n", 61 | "\n", 62 | "data_paths = os.path.join('./', 'data_paths.txt')\n", 63 | "if not os.path.exists(data_paths):\n", 64 | " !wget http://pbialecki.de/mastering_pytorch/data_paths.txt\n", 65 | "\n", 66 | "if not os.path.isfile(data_paths):\n", 67 | " print('data_paths.txt missing!')" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": { 74 | "collapsed": false 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "# Setup Globals\n", 79 | "use_cuda = torch.cuda.is_available()\n", 80 | "np.random.seed(SEED)\n", 81 | "torch.manual_seed(SEED)\n", 82 | "if use_cuda:\n", 83 | " torch.cuda.manual_seed(SEED)\n", 84 | " print('Using: {}'.format(torch.cuda.get_device_name(0)))\n", 85 | "print_steps = 10" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "collapsed": true 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "# Utility functions\n", 97 | "def weights_init(m):\n", 98 | " '''\n", 99 | " Initialize the weights of each Conv2d layer using xavier_uniform\n", 100 | " (\"Understanding the difficulty of training deep feedforward\n", 101 | " neural networks\" - Glorot, X. & Bengio, Y. (2010))\n", 102 | " '''\n", 103 | " if isinstance(m, nn.Conv2d):\n", 104 | " nn.init.xavier_uniform(m.weight.data)\n", 105 | " m.bias.data.zero_()\n", 106 | " if isinstance(m, nn.ConvTranspose2d):\n", 107 | " nn.init.xavier_uniform(m.weight.data)\n", 108 | " m.bias.data.zero_()" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": { 115 | "collapsed": true 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "class CellDataset(Dataset):\n", 120 | " def __init__(self, image_paths, target_paths, size, train=False):\n", 121 | " self.image_paths = image_paths\n", 122 | " self.target_paths = target_paths\n", 123 | " self.size = size\n", 124 | " resize_size = [s+10 for s in self.size]\n", 125 | " self.resize_image = transforms.Resize(\n", 126 | " size=resize_size, interpolation=Image.BILINEAR)\n", 127 | " self.resize_mask = transforms.Resize(\n", 128 | " size=resize_size, interpolation=Image.NEAREST)\n", 129 | " self.train = train\n", 130 | " \n", 131 | " def transform(self, image, mask):\n", 132 | " # Resize\n", 133 | " image = self.resize_image(image)\n", 134 | " mask = self.resize_mask(mask)\n", 135 | " \n", 136 | " # Perform data augmentation\n", 137 | " if self.train: \n", 138 | " # Random cropping\n", 139 | " i, j, h, w = transforms.RandomCrop.get_params(\n", 140 | " image, output_size=self.size)\n", 141 | " image = TF.crop(image, i, j, h, w)\n", 142 | " mask = TF.crop(mask, i, j, h, w)\n", 143 | " \n", 144 | " # Random horizontal flipping\n", 145 | " if random.random() > 0.5:\n", 146 | " image = TF.hflip(image)\n", 147 | " mask = TF.hflip(mask)\n", 148 | " \n", 149 | " # Random vertical flipping\n", 150 | " if random.random() > 0.5:\n", 151 | " image = TF.vflip(image)\n", 152 | " mask = TF.vflip(mask)\n", 153 | " else:\n", 154 | " center_crop = 
transforms.CenterCrop(self.size)\n", 155 | " image = center_crop(image)\n", 156 | " mask = center_crop(mask)\n", 157 | " \n", 158 | " # Transform to tensor\n", 159 | " image = TF.to_tensor(image)\n", 160 | " mask = TF.to_tensor(mask)\n", 161 | " return image, mask\n", 162 | "\n", 163 | " def __getitem__(self, index):\n", 164 | " image = Image.open(self.image_paths[index])\n", 165 | " mask = Image.open(self.target_paths[index])\n", 166 | " x, y = self.transform(image, mask)\n", 167 | " return x, y\n", 168 | "\n", 169 | " def __len__(self):\n", 170 | " return len(self.image_paths)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": { 177 | "collapsed": true 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "def get_random_sample(dataset):\n", 182 | " '''\n", 183 | " Get a random sample from the specified dataset.\n", 184 | " '''\n", 185 | " data, target = dataset[int(np.random.choice(len(dataset), 1))]\n", 186 | " data.unsqueeze_(0)\n", 187 | " target.unsqueeze_(0)\n", 188 | " if use_cuda:\n", 189 | " data = data.cuda()\n", 190 | " target = target.cuda()\n", 191 | " data = Variable(data)\n", 192 | " target = Variable(target)\n", 193 | " return data, target" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": { 200 | "collapsed": true 201 | }, 202 | "outputs": [], 203 | "source": [ 204 | "class BaseConv(nn.Module):\n", 205 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 206 | " stride):\n", 207 | " super(BaseConv, self).__init__()\n", 208 | "\n", 209 | " self.act = nn.ReLU()\n", 210 | "\n", 211 | " self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding,\n", 212 | " stride)\n", 213 | "\n", 214 | " self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,\n", 215 | " padding, stride)\n", 216 | " \n", 217 | " self.downsample = None\n", 218 | " if in_channels != out_channels:\n", 219 | " self.downsample = nn.Sequential(\n", 220 | " nn.Conv2d(\n", 221 | " in_channels, out_channels, kernel_size, padding, stride)\n", 222 | " )\n", 223 | "\n", 224 | " def forward(self, x):\n", 225 | " residual = x\n", 226 | " out = self.act(self.conv1(x))\n", 227 | " out = self.conv2(out)\n", 228 | " \n", 229 | " if self.downsample:\n", 230 | " residual = self.downsample(x)\n", 231 | " out += residual\n", 232 | " out = self.act(out)\n", 233 | " return out" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": { 240 | "collapsed": true 241 | }, 242 | "outputs": [], 243 | "source": [ 244 | "class DownConv(nn.Module):\n", 245 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 246 | " stride):\n", 247 | " super(DownConv, self).__init__()\n", 248 | "\n", 249 | " self.pool1 = nn.MaxPool2d(kernel_size=2)\n", 250 | " self.conv_block = BaseConv(in_channels, out_channels, kernel_size,\n", 251 | " padding, stride)\n", 252 | "\n", 253 | " def forward(self, x):\n", 254 | " x = self.pool1(x)\n", 255 | " x = self.conv_block(x)\n", 256 | " return x" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": { 263 | "collapsed": true 264 | }, 265 | "outputs": [], 266 | "source": [ 267 | "class UpConv(nn.Module):\n", 268 | " def __init__(self, in_channels, in_channels_skip, out_channels,\n", 269 | " kernel_size, padding, stride):\n", 270 | " super(UpConv, self).__init__()\n", 271 | "\n", 272 | " self.conv_trans1 = nn.ConvTranspose2d(\n", 273 | " in_channels, in_channels, 
kernel_size=2, padding=0, stride=2)\n", 274 | " self.conv_block = BaseConv(\n", 275 | " in_channels=in_channels + in_channels_skip,\n", 276 | " out_channels=out_channels,\n", 277 | " kernel_size=kernel_size,\n", 278 | " padding=padding,\n", 279 | " stride=stride)\n", 280 | "\n", 281 | " def forward(self, x, x_skip):\n", 282 | " x = self.conv_trans1(x)\n", 283 | " x = torch.cat((x, x_skip), dim=1)\n", 284 | " x = self.conv_block(x)\n", 285 | " return x" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": { 292 | "collapsed": true 293 | }, 294 | "outputs": [], 295 | "source": [ 296 | "class ResUNet(nn.Module):\n", 297 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 298 | " stride):\n", 299 | " super(ResUNet, self).__init__()\n", 300 | "\n", 301 | " self.init_conv = BaseConv(in_channels, out_channels, kernel_size, padding, stride)\n", 302 | "\n", 303 | " self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size,\n", 304 | " padding, stride)\n", 305 | "\n", 306 | " self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size,\n", 307 | " padding, stride)\n", 308 | "\n", 309 | " self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size,\n", 310 | " padding, stride)\n", 311 | "\n", 312 | " self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels,\n", 313 | " kernel_size, padding, stride)\n", 314 | "\n", 315 | " self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels,\n", 316 | " kernel_size, padding, stride)\n", 317 | "\n", 318 | " self.up1 = UpConv(2 * out_channels, out_channels, out_channels,\n", 319 | " kernel_size, padding, stride)\n", 320 | "\n", 321 | " self.out = nn.Conv2d(out_channels, 1, kernel_size, padding, stride)\n", 322 | "\n", 323 | " def forward(self, x):\n", 324 | " # Encoder\n", 325 | " x = self.init_conv(x)\n", 326 | " x1 = self.down1(x)\n", 327 | " x2 = self.down2(x1)\n", 328 | " x3 = self.down3(x2)\n", 329 | " # Decoder\n", 330 | " x_up = self.up3(x3, x2)\n", 331 | " x_up = self.up2(x_up, x1)\n", 332 | " x_up = self.up1(x_up, x)\n", 333 | " x_out = F.sigmoid(self.out(x_up))\n", 334 | " return x_out" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "metadata": { 341 | "collapsed": true 342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "def train(epoch):\n", 346 | " '''\n", 347 | " Main training loop\n", 348 | " '''\n", 349 | " # Set model to train mode\n", 350 | " model.train()\n", 351 | " # Iterate training set\n", 352 | " for batch_idx, (data, mask) in enumerate(train_loader):\n", 353 | " if use_cuda:\n", 354 | " data = data.cuda()\n", 355 | " mask = mask.cuda()\n", 356 | " data = Variable(data)\n", 357 | " mask = Variable(mask.squeeze())\n", 358 | " optimizer.zero_grad()\n", 359 | " output = model(data)\n", 360 | " loss_mask = criterion(output.squeeze(), mask)\n", 361 | " loss_dice = dice_loss(mask, output.squeeze())\n", 362 | " loss = loss_mask + loss_dice\n", 363 | " loss.backward()\n", 364 | " optimizer.step()\n", 365 | " \n", 366 | " if batch_idx % print_steps == 0:\n", 367 | " loss_mask_data = loss_mask.data[0]\n", 368 | " loss_dice_data = loss_dice.data[0]\n", 369 | " train_losses.append(loss_mask_data)\n", 370 | " print(\n", 371 | " 'Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tLoss(dice): {:.6f}'.\n", 372 | " format(epoch, batch_idx * len(data),\n", 373 | " len(train_loader.dataset), 100. 
* batch_idx / len(\n", 374 | " train_loader), loss_mask_data, loss_dice_data))" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": { 381 | "collapsed": true 382 | }, 383 | "outputs": [], 384 | "source": [ 385 | "def validate():\n", 386 | " '''\n", 387 | " Validation loop\n", 388 | " '''\n", 389 | " # Set model to eval mode\n", 390 | " model.eval()\n", 391 | " # Setup val_loss\n", 392 | " val_mask_loss = 0\n", 393 | " val_dice_loss = 0\n", 394 | " # Disable gradients (to save memory)\n", 395 | " with torch.no_grad():\n", 396 | " # Iterate validation set\n", 397 | " for data, mask in val_loader:\n", 398 | " if use_cuda:\n", 399 | " data = data.cuda()\n", 400 | " mask = mask.cuda()\n", 401 | " data = Variable(data)\n", 402 | " mask = Variable(mask.squeeze())\n", 403 | " output = model(data)\n", 404 | " val_mask_loss += F.binary_cross_entropy(output.squeeze(), mask).data[0]\n", 405 | " val_dice_loss += dice_loss(mask, output.squeeze()).data[0]\n", 406 | " # Calculate mean of validation loss\n", 407 | " val_mask_loss /= len(val_loader)\n", 408 | " val_dice_loss /= len(val_loader)\n", 409 | " val_losses.append(val_mask_loss)\n", 410 | " print('Validation\\tLoss: {:.6f}\\tLoss(dice): {:.6f}'.format(val_mask_loss, val_dice_loss))" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "metadata": { 417 | "collapsed": true 418 | }, 419 | "outputs": [], 420 | "source": [ 421 | "# Get train data folders and split to training / validation set\n", 422 | "with open(data_paths, 'r') as f:\n", 423 | " data_paths = f.readlines()\n", 424 | "image_paths = [line.split(',')[0].strip() for line in data_paths]\n", 425 | "target_paths = [line.split(',')[1].strip() for line in data_paths]\n", 426 | "\n", 427 | "# Split data into train/validation datasets\n", 428 | "im_path_train, im_path_val, tar_path_train, tar_path_val = split_data(\n", 429 | " image_paths, target_paths)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "collapsed": true 437 | }, 438 | "outputs": [], 439 | "source": [ 440 | "# Create datasets\n", 441 | "train_dataset = CellDataset(\n", 442 | " image_paths=im_path_train,\n", 443 | " target_paths=tar_path_train,\n", 444 | " size=(96, 96),\n", 445 | " train=True\n", 446 | ")\n", 447 | "val_dataset = CellDataset(\n", 448 | " image_paths=im_path_val,\n", 449 | " target_paths=tar_path_val,\n", 450 | " size=(96, 96),\n", 451 | " train=False\n", 452 | ")\n", 453 | "\n", 454 | "# Wrap in DataLoader\n", 455 | "train_loader = DataLoader(\n", 456 | " dataset=train_dataset,\n", 457 | " batch_size=32,\n", 458 | " num_workers=12,\n", 459 | " shuffle=True\n", 460 | ")\n", 461 | "val_loader = DataLoader(\n", 462 | " dataset=val_dataset,\n", 463 | " batch_size=64,\n", 464 | " num_workers=12,\n", 465 | " shuffle=True\n", 466 | ")" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": null, 472 | "metadata": { 473 | "collapsed": true 474 | }, 475 | "outputs": [], 476 | "source": [ 477 | "# Creae model\n", 478 | "model = ResUNet(\n", 479 | " in_channels=1, out_channels=32, kernel_size=3, padding=1, stride=1)\n", 480 | "# Initialize weights\n", 481 | "model.apply(weights_init)\n", 482 | "# Push to GPU, if available\n", 483 | "if use_cuda:\n", 484 | " model.cuda()" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "metadata": { 491 | "collapsed": true 492 | }, 493 | "outputs": [], 494 | "source": [ 495 | "# Create 
optimizer and scheduler\n", 496 | "optimizer = optim.SGD(model.parameters(), lr=1e-3)\n", 497 | "# Create criterion\n", 498 | "criterion = nn.BCELoss()" 499 | ] 500 | }, 501 | { 502 | "cell_type": "markdown", 503 | "metadata": {}, 504 | "source": [ 505 | "#### Dice coefficient\n", 506 | "\n", 507 | "Calculate the dice coefficient.\n", 508 | "\n", 509 | "Divide the \"overlap\" between the predicted and the ground truth mask by\n", 510 | "the total size of the two objects.\n", 511 | "\n", 512 | "\\begin{align}\n", 513 | "QS = \\frac{2|X \\cap Y|}{|X| + |Y|}\n", 514 | "\\end{align}" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": { 521 | "collapsed": true 522 | }, 523 | "outputs": [], 524 | "source": [ 525 | "def dice_loss(y_target, y_pred, smooth=0.0):\n", 526 | " y_target = y_target.view(-1)\n", 527 | " y_pred = y_pred.view(-1)\n", 528 | " intersection = (y_target * y_pred).sum()\n", 529 | " dice_coef = (2. * intersection + smooth) / (\n", 530 | " y_target.sum() + y_pred.sum() + smooth)\n", 531 | " return 1. - dice_coef" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": null, 537 | "metadata": { 538 | "collapsed": false, 539 | "scrolled": true 540 | }, 541 | "outputs": [], 542 | "source": [ 543 | "# Start training\n", 544 | "train_losses, val_losses = [], []\n", 545 | "epochs = 30\n", 546 | "for epoch in range(1, epochs):\n", 547 | " train(epoch)\n", 548 | " validate()" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": null, 554 | "metadata": { 555 | "collapsed": false 556 | }, 557 | "outputs": [], 558 | "source": [ 559 | "train_losses = np.array(train_losses)\n", 560 | "val_losses = np.array(val_losses)\n", 561 | "\n", 562 | "val_indices = np.linspace(0, (epochs-1)*len(train_loader)/print_steps, epochs-1)\n", 563 | "\n", 564 | "plt.plot(train_losses, '-', label='train loss')\n", 565 | "plt.plot(val_indices, val_losses, '--', label='val loss')\n", 566 | "plt.yscale(\"log\", nonposy='clip')\n", 567 | "plt.xlabel('Iterations')\n", 568 | "plt.ylabel('BCELoss')\n", 569 | "plt.legend()\n", 570 | "plt.show()" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "metadata": { 577 | "collapsed": true 578 | }, 579 | "outputs": [], 580 | "source": [ 581 | "val_data, val_target = get_random_sample(val_dataset)" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": null, 587 | "metadata": { 588 | "collapsed": true 589 | }, 590 | "outputs": [], 591 | "source": [ 592 | "val_pred = model(val_data)\n", 593 | "val_pred_arr = val_pred.data.cpu().squeeze_().numpy()\n", 594 | "val_target_arr = val_target.data.cpu().squeeze_().numpy()" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": null, 600 | "metadata": { 601 | "collapsed": false 602 | }, 603 | "outputs": [], 604 | "source": [ 605 | "fig, (ax1, ax2, ax3) = plt.subplots(1, 3)\n", 606 | "ax1.imshow(val_pred_arr)\n", 607 | "ax1.set_title('Prediction')\n", 608 | "ax2.imshow(val_target_arr)\n", 609 | "ax2.set_title('Target')\n", 610 | "ax3.imshow(np.abs(val_pred_arr - val_target_arr))\n", 611 | "ax3.set_title('Absolute error')" 612 | ] 613 | } 614 | ], 615 | "metadata": { 616 | "kernelspec": { 617 | "display_name": "Python 2", 618 | "language": "python", 619 | "name": "python2" 620 | }, 621 | "language_info": { 622 | "codemirror_mode": { 623 | "name": "ipython", 624 | "version": 2 625 | }, 626 | "file_extension": ".py", 627 | "mimetype": "text/x-python", 628 | "name": 
"python", 629 | "nbconvert_exporter": "python", 630 | "pygments_lexer": "ipython2", 631 | "version": "2.7.13" 632 | } 633 | }, 634 | "nbformat": 4, 635 | "nbformat_minor": 2 636 | } 637 | -------------------------------------------------------------------------------- /code/section1/video1_5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Mastering PyTorch\n", 8 | "\n", 9 | "## Supervised learning\n", 10 | "\n", 11 | "### Visualize the training in Visdom\n", 12 | "\n", 13 | "#### Accompanying notebook to Video 1.5" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "collapsed": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "# Include libraries\n", 25 | "\n", 26 | "import numpy as np\n", 27 | "from PIL import Image\n", 28 | "\n", 29 | "import os\n", 30 | "import random\n", 31 | "\n", 32 | "import torch\n", 33 | "import torch.nn as nn\n", 34 | "import torch.optim as optim\n", 35 | "import torch.nn.functional as F\n", 36 | "from torch.utils.data import Dataset, DataLoader\n", 37 | "from torch.autograd import Variable\n", 38 | "\n", 39 | "from torchvision import transforms\n", 40 | "import torchvision.transforms.functional as TF\n", 41 | "\n", 42 | "from utils import get_image_name, get_number_of_cells, \\\n", 43 | " split_data, download_data, SEED\n", 44 | "\n", 45 | "from sklearn.model_selection import train_test_split\n", 46 | "\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "%matplotlib inline" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "collapsed": true 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "root = './'\n", 60 | "download_data(root=root)\n", 61 | "\n", 62 | "data_paths = os.path.join('./', 'data_paths.txt')\n", 63 | "if not os.path.exists(data_paths):\n", 64 | " !wget http://pbialecki.de/mastering_pytorch/data_paths.txt\n", 65 | "\n", 66 | "if not os.path.isfile(data_paths):\n", 67 | " print('data_paths.txt missing!')" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": { 74 | "collapsed": false 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "# Setup Globals\n", 79 | "use_cuda = torch.cuda.is_available()\n", 80 | "data_paths = os.path.join('./', 'data', 'data_paths.txt')\n", 81 | "np.random.seed(SEED)\n", 82 | "torch.manual_seed(SEED)\n", 83 | "if use_cuda:\n", 84 | " torch.cuda.manual_seed(SEED)\n", 85 | " print('Using: {}'.format(torch.cuda.get_device_name(0)))\n", 86 | "print_steps = 10" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": { 93 | "collapsed": true 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "# Utility functions\n", 98 | "def weights_init(m):\n", 99 | " '''\n", 100 | " Initialize the weights of each Conv2d layer using xavier_uniform\n", 101 | " (\"Understanding the difficulty of training deep feedforward\n", 102 | " neural networks\" - Glorot, X. & Bengio, Y. (2010))\n", 103 | " '''\n", 104 | " if isinstance(m, nn.Conv2d):\n", 105 | " nn.init.xavier_uniform(m.weight.data)\n", 106 | " elif isinstance(m, nn.ConvTranspose2d):\n", 107 | " nn.init.xavier_uniform(m.weight.data)\n", 108 | "\n", 109 | "def dice_loss(y_target, y_pred, smooth=0.0):\n", 110 | " y_target = y_target.view(-1)\n", 111 | " y_pred = y_pred.view(-1)\n", 112 | " intersection = (y_target * y_pred).sum()\n", 113 | " dice_coef = (2. 
* intersection + smooth) / (\n", 114 | " y_target.sum() + y_pred.sum() + smooth)\n", 115 | " return 1. - dice_coef" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "collapsed": true 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "class CellDataset(Dataset):\n", 127 | " def __init__(self, image_paths, target_paths, size, train=False):\n", 128 | " self.image_paths = image_paths\n", 129 | " self.target_paths = target_paths\n", 130 | " self.size = size\n", 131 | " resize_size = [s+10 for s in self.size]\n", 132 | " self.resize_image = transforms.Resize(\n", 133 | " size=resize_size, interpolation=Image.BILINEAR)\n", 134 | " self.resize_mask = transforms.Resize(\n", 135 | " size=resize_size, interpolation=Image.NEAREST)\n", 136 | " self.train = train\n", 137 | " \n", 138 | " def transform(self, image, mask):\n", 139 | " # Resize\n", 140 | " image = self.resize_image(image)\n", 141 | " mask = self.resize_mask(mask)\n", 142 | " \n", 143 | " # Perform data augmentation\n", 144 | " if self.train: \n", 145 | " # Random cropping\n", 146 | " i, j, h, w = transforms.RandomCrop.get_params(\n", 147 | " image, output_size=self.size)\n", 148 | " image = TF.crop(image, i, j, h, w)\n", 149 | " mask = TF.crop(mask, i, j, h, w)\n", 150 | " \n", 151 | " # Random horizontal flipping\n", 152 | " if random.random() > 0.5:\n", 153 | " image = TF.hflip(image)\n", 154 | " mask = TF.hflip(mask)\n", 155 | " \n", 156 | " # Random vertical flipping\n", 157 | " if random.random() > 0.5:\n", 158 | " image = TF.vflip(image)\n", 159 | " mask = TF.vflip(mask)\n", 160 | " else:\n", 161 | " center_crop = transforms.CenterCrop(self.size)\n", 162 | " image = center_crop(image)\n", 163 | " mask = center_crop(mask)\n", 164 | " \n", 165 | " # Transform to tensor\n", 166 | " image = TF.to_tensor(image)\n", 167 | " mask = TF.to_tensor(mask)\n", 168 | " return image, mask\n", 169 | "\n", 170 | " def __getitem__(self, index):\n", 171 | " image = Image.open(self.image_paths[index])\n", 172 | " mask = Image.open(self.target_paths[index])\n", 173 | " x, y = self.transform(image, mask)\n", 174 | " return x, y\n", 175 | "\n", 176 | " def __len__(self):\n", 177 | " return len(self.image_paths)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "collapsed": true 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "def get_random_sample(dataset):\n", 189 | " '''\n", 190 | " Get a random sample from the specified dataset.\n", 191 | " '''\n", 192 | " data, target = dataset[int(np.random.choice(len(dataset), 1))]\n", 193 | " data.unsqueeze_(0)\n", 194 | " target.unsqueeze_(0)\n", 195 | " if use_cuda:\n", 196 | " data = data.cuda()\n", 197 | " target = target.cuda()\n", 198 | " data = Variable(data)\n", 199 | " target = Variable(target)\n", 200 | " return data, target" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "collapsed": true 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "class BaseConv(nn.Module):\n", 212 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 213 | " stride):\n", 214 | " super(BaseConv, self).__init__()\n", 215 | "\n", 216 | " self.act = nn.ReLU()\n", 217 | "\n", 218 | " self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding,\n", 219 | " stride)\n", 220 | "\n", 221 | " self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,\n", 222 | " padding, stride)\n", 223 | " \n", 224 | " 
self.downsample = None\n", 225 | " if in_channels != out_channels:\n", 226 | " self.downsample = nn.Sequential(\n", 227 | " nn.Conv2d(\n", 228 | " in_channels, out_channels, kernel_size, padding, stride)\n", 229 | " )\n", 230 | "\n", 231 | " def forward(self, x):\n", 232 | " residual = x\n", 233 | " out = self.act(self.conv1(x))\n", 234 | " out = self.conv2(out)\n", 235 | " \n", 236 | " if self.downsample:\n", 237 | " residual = self.downsample(x)\n", 238 | " out += residual\n", 239 | " out = self.act(out)\n", 240 | " return out" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": { 247 | "collapsed": true 248 | }, 249 | "outputs": [], 250 | "source": [ 251 | "class DownConv(nn.Module):\n", 252 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 253 | " stride):\n", 254 | " super(DownConv, self).__init__()\n", 255 | "\n", 256 | " self.pool1 = nn.MaxPool2d(kernel_size=2)\n", 257 | " self.conv_block = BaseConv(in_channels, out_channels, kernel_size,\n", 258 | " padding, stride)\n", 259 | "\n", 260 | " def forward(self, x):\n", 261 | " x = self.pool1(x)\n", 262 | " x = self.conv_block(x)\n", 263 | " return x" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": { 270 | "collapsed": true 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "class UpConv(nn.Module):\n", 275 | " def __init__(self, in_channels, in_channels_skip, out_channels,\n", 276 | " kernel_size, padding, stride):\n", 277 | " super(UpConv, self).__init__()\n", 278 | "\n", 279 | " self.conv_trans1 = nn.ConvTranspose2d(\n", 280 | " in_channels, in_channels, kernel_size=2, padding=0, stride=2)\n", 281 | " self.conv_block = BaseConv(\n", 282 | " in_channels=in_channels + in_channels_skip,\n", 283 | " out_channels=out_channels,\n", 284 | " kernel_size=kernel_size,\n", 285 | " padding=padding,\n", 286 | " stride=stride)\n", 287 | "\n", 288 | " def forward(self, x, x_skip):\n", 289 | " x = self.conv_trans1(x)\n", 290 | " x = torch.cat((x, x_skip), dim=1)\n", 291 | " x = self.conv_block(x)\n", 292 | " return x" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": { 299 | "collapsed": true 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "class ResUNet(nn.Module):\n", 304 | " def __init__(self, in_channels, out_channels, kernel_size, padding,\n", 305 | " stride):\n", 306 | " super(ResUNet, self).__init__()\n", 307 | "\n", 308 | " self.init_conv = BaseConv(in_channels, out_channels, kernel_size, padding, stride)\n", 309 | "\n", 310 | " self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size,\n", 311 | " padding, stride)\n", 312 | "\n", 313 | " self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size,\n", 314 | " padding, stride)\n", 315 | "\n", 316 | " self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size,\n", 317 | " padding, stride)\n", 318 | "\n", 319 | " self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels,\n", 320 | " kernel_size, padding, stride)\n", 321 | "\n", 322 | " self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels,\n", 323 | " kernel_size, padding, stride)\n", 324 | "\n", 325 | " self.up1 = UpConv(2 * out_channels, out_channels, out_channels,\n", 326 | " kernel_size, padding, stride)\n", 327 | "\n", 328 | " self.out = nn.Conv2d(out_channels, 1, kernel_size, padding, stride)\n", 329 | "\n", 330 | " def forward(self, x):\n", 331 | " # Encoder\n", 332 | " x = 
self.init_conv(x)\n", 333 | " x1 = self.down1(x)\n", 334 | " x2 = self.down2(x1)\n", 335 | " x3 = self.down3(x2)\n", 336 | " # Decoder\n", 337 | " x_up = self.up3(x3, x2)\n", 338 | " x_up = self.up2(x_up, x1)\n", 339 | " x_up = self.up1(x_up, x)\n", 340 | " x_out = F.sigmoid(self.out(x_up))\n", 341 | " return x_out" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": { 348 | "collapsed": true 349 | }, 350 | "outputs": [], 351 | "source": [ 352 | "def train(epoch, visualize=False):\n", 353 | " '''\n", 354 | " Main training loop\n", 355 | " '''\n", 356 | " global win_loss\n", 357 | " global win_images\n", 358 | " # Set model to train mode\n", 359 | " model.train()\n", 360 | " # Iterate training set\n", 361 | " for batch_idx, (data, mask) in enumerate(train_loader):\n", 362 | " if use_cuda:\n", 363 | " data = data.cuda()\n", 364 | " mask = mask.cuda()\n", 365 | " data = Variable(data)\n", 366 | " mask = Variable(mask.squeeze())\n", 367 | " optimizer.zero_grad()\n", 368 | " output = model(data)\n", 369 | " loss_mask = criterion(output.squeeze(), mask)\n", 370 | " loss_dice = dice_loss(mask, output.squeeze())\n", 371 | " loss = loss_mask + loss_dice\n", 372 | " loss.backward()\n", 373 | " optimizer.step()\n", 374 | " \n", 375 | " if batch_idx % print_steps == 0:\n", 376 | " loss_mask_data = loss_mask.data[0]\n", 377 | " loss_dice_data = loss_dice.data[0]\n", 378 | " train_losses.append(loss_mask_data)\n", 379 | " print(\n", 380 | " 'Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tLoss(dice): {:.6f}'.\n", 381 | " format(epoch, batch_idx * len(data),\n", 382 | " len(train_loader.dataset), 100. * batch_idx / len(\n", 383 | " train_loader), loss_mask_data, loss_dice_data))\n", 384 | " \n", 385 | " x_idx = (epoch - 1) * len(train_loader) + batch_idx\n", 386 | " losses = [loss_mask_data, loss_dice_data]\n", 387 | " win_loss = visualize_losses(losses, x_idx, win_loss)\n", 388 | " \n", 389 | " if visualize:\n", 390 | " # Visualize some images in Visdom\n", 391 | " nb_images = 4\n", 392 | " images_pred = output.data[:nb_images].cpu()\n", 393 | " images_target = mask.data[:nb_images].cpu().unsqueeze(1)\n", 394 | " images_input = data.data[:nb_images].cpu()\n", 395 | " images = torch.zeros(3 * images_pred.size(0), *images_pred.size()[1:])\n", 396 | " images[::3] = images_input\n", 397 | " images[1::3] = images_pred\n", 398 | " images[2::3] = images_target\n", 399 | " # Resize images to fit in visdom\n", 400 | " images = resize_tensors(images)\n", 401 | " images = make_grid(images, nrow=3, pad_value=0.5)\n", 402 | " win_images = visualize_images(\n", 403 | " images.numpy(), win_images, title='Training: input - prediction - target')" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "metadata": { 410 | "collapsed": true 411 | }, 412 | "outputs": [], 413 | "source": [ 414 | "def validate():\n", 415 | " '''\n", 416 | " Validation loop\n", 417 | " '''\n", 418 | " global win_eval_images\n", 419 | " # Set model to eval mode\n", 420 | " model.eval()\n", 421 | " # Setup val_loss\n", 422 | " val_mask_loss = 0\n", 423 | " val_dice_loss = 0\n", 424 | " # Disable gradients (to save memory)\n", 425 | " with torch.no_grad():\n", 426 | " # Iterate validation set\n", 427 | " for data, mask in val_loader:\n", 428 | " if use_cuda:\n", 429 | " data = data.cuda()\n", 430 | " mask = mask.cuda()\n", 431 | " data = Variable(data)\n", 432 | " mask = Variable(mask.squeeze())\n", 433 | " output = model(data)\n", 434 | " val_mask_loss += 
F.binary_cross_entropy(output.squeeze(), mask).data[0]\n", 435 | " val_dice_loss += dice_loss(mask, output.squeeze()).data[0]\n", 436 | " # Calculate mean of validation loss\n", 437 | " val_mask_loss /= len(val_loader)\n", 438 | " val_dice_loss /= len(val_loader)\n", 439 | " val_losses.append(val_mask_loss)\n", 440 | " print('Validation\\tLoss: {:.6f}\\tLoss(dice): {:.6f}'.format(val_mask_loss, val_dice_loss))\n", 441 | " \n", 442 | " # Visualize some images in Visdom\n", 443 | " nb_images = 4\n", 444 | " images_pred = output.data[:nb_images].cpu()\n", 445 | " images_target = mask.data[:nb_images].cpu().unsqueeze(1)\n", 446 | " images_input = data.data[:nb_images].cpu()\n", 447 | " images = torch.zeros(3 * images_pred.size(0), *images_pred.size()[1:])\n", 448 | " images[::3] = images_input\n", 449 | " images[1::3] = images_pred\n", 450 | " images[2::3] = images_target\n", 451 | " # Resize images to fit in visdom\n", 452 | " images = resize_tensors(images)\n", 453 | " images = make_grid(images, nrow=3, pad_value=0.5)\n", 454 | " win_eval_images = visualize_images(\n", 455 | " images.numpy(), win_eval_images, title='Validation: input - prediction - target')" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": null, 461 | "metadata": { 462 | "collapsed": true 463 | }, 464 | "outputs": [], 465 | "source": [ 466 | "# Get train data folders and split to training / validation set\n", 467 | "with open(data_paths, 'r') as f:\n", 468 | " data_paths = f.readlines()\n", 469 | "image_paths = [line.split(',')[0].strip() for line in data_paths]\n", 470 | "target_paths = [line.split(',')[1].strip() for line in data_paths]\n", 471 | "\n", 472 | "# Split data into train/validation datasets\n", 473 | "im_path_train, im_path_val, tar_path_train, tar_path_val = split_data(\n", 474 | " image_paths, target_paths)" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": { 481 | "collapsed": true 482 | }, 483 | "outputs": [], 484 | "source": [ 485 | "# Create datasets\n", 486 | "train_dataset = CellDataset(\n", 487 | " image_paths=im_path_train,\n", 488 | " target_paths=tar_path_train,\n", 489 | " size=(96, 96),\n", 490 | " train=True\n", 491 | ")\n", 492 | "val_dataset = CellDataset(\n", 493 | " image_paths=im_path_val,\n", 494 | " target_paths=tar_path_val,\n", 495 | " size=(96, 96),\n", 496 | " train=False\n", 497 | ")\n", 498 | "\n", 499 | "# Wrap in DataLoader\n", 500 | "train_loader = DataLoader(\n", 501 | " dataset=train_dataset,\n", 502 | " batch_size=32,\n", 503 | " num_workers=12,\n", 504 | " shuffle=True\n", 505 | ")\n", 506 | "val_loader = DataLoader(\n", 507 | " dataset=val_dataset,\n", 508 | " batch_size=64,\n", 509 | " num_workers=12,\n", 510 | " shuffle=True\n", 511 | ")" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": null, 517 | "metadata": { 518 | "collapsed": true 519 | }, 520 | "outputs": [], 521 | "source": [ 522 | "# Creae model\n", 523 | "model = ResUNet(\n", 524 | " in_channels=1, out_channels=32, kernel_size=3, padding=1, stride=1)\n", 525 | "# Initialize weights\n", 526 | "model.apply(weights_init)\n", 527 | "# Push to GPU, if available\n", 528 | "if use_cuda:\n", 529 | " model.cuda()" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": null, 535 | "metadata": { 536 | "collapsed": true 537 | }, 538 | "outputs": [], 539 | "source": [ 540 | "# Create optimizer and scheduler\n", 541 | "optimizer = optim.SGD(model.parameters(), lr=1e-3)\n", 542 | "# Create criterion\n", 
543 | "criterion = nn.BCELoss()" 544 | ] 545 | }, 546 | { 547 | "cell_type": "markdown", 548 | "metadata": {}, 549 | "source": [ 550 | "#### Create visdom helper functions" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": null, 556 | "metadata": { 557 | "collapsed": true 558 | }, 559 | "outputs": [], 560 | "source": [ 561 | "from visdom import Visdom\n", 562 | "from torchvision.utils import make_grid\n", 563 | "# Setup visdom\n", 564 | "viz = Visdom(port=6006)\n", 565 | "win_loss = None\n", 566 | "win_images = None\n", 567 | "win_eval_loss = None\n", 568 | "win_eval_images = None\n", 569 | "\n", 570 | "def visualize_losses(losses, x_idx, win):\n", 571 | " if not win:\n", 572 | " win = viz.line(\n", 573 | " Y=np.column_stack(losses),\n", 574 | " X=np.column_stack([x_idx] * len(losses)),\n", 575 | " opts=dict(\n", 576 | " showlegend=True,\n", 577 | " xlabel='iteration',\n", 578 | " ylabel='BCELoss',\n", 579 | " ytype='log',\n", 580 | " title='Losses',\n", 581 | " legend=['Loss(mask)', 'Loss(dice)']))\n", 582 | " else:\n", 583 | " win = viz.line(\n", 584 | " Y=np.column_stack(losses),\n", 585 | " X=np.column_stack([x_idx] * len(losses)),\n", 586 | " opts=dict(showlegend=True),\n", 587 | " win=win,\n", 588 | " update='append')\n", 589 | " return win\n", 590 | "\n", 591 | "def visualize_images(images, win, title=''):\n", 592 | " if not win:\n", 593 | " win = viz.images(tensor=images, opts=dict(title=title))\n", 594 | " else:\n", 595 | " win = viz.images(tensor=images, win=win, opts=dict(title=title))\n", 596 | " return win\n", 597 | "\n", 598 | "def resize_tensors(tensors, size=(128, 128)):\n", 599 | " to_pil = transforms.ToPILImage()\n", 600 | " res = transforms.Resize(size=size)\n", 601 | " to_tensor = transforms.ToTensor()\n", 602 | " images = torch.stack([to_tensor(res(to_pil(t))) for t in tensors])\n", 603 | " return images" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": null, 609 | "metadata": { 610 | "collapsed": false 611 | }, 612 | "outputs": [], 613 | "source": [ 614 | "# Start training\n", 615 | "train_losses, val_losses = [], []\n", 616 | "epochs = 30\n", 617 | "for epoch in range(1, epochs):\n", 618 | " train(epoch, visualize=True)\n", 619 | " validate()" 620 | ] 621 | } 622 | ], 623 | "metadata": { 624 | "kernelspec": { 625 | "display_name": "Python 2", 626 | "language": "python", 627 | "name": "python2" 628 | }, 629 | "language_info": { 630 | "codemirror_mode": { 631 | "name": "ipython", 632 | "version": 2 633 | }, 634 | "file_extension": ".py", 635 | "mimetype": "text/x-python", 636 | "name": "python", 637 | "nbconvert_exporter": "python", 638 | "pygments_lexer": "ipython2", 639 | "version": "2.7.13" 640 | } 641 | }, 642 | "nbformat": 4, 643 | "nbformat_minor": 2 644 | } 645 | --------------------------------------------------------------------------------
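
A few supplementary sketches of the patterns used in the notebooks above (video1_4.ipynb and video1_5.ipynb). First, the paired image/mask augmentation from CellDataset.transform: one set of crop parameters and one flip decision are sampled and then applied to both the input and its segmentation mask, so the target stays aligned with the image. This is a minimal, standalone sketch; the blank PIL images are stand-ins for the cell images and masks that the notebooks load with Image.open.

```python
import random
from PIL import Image
from torchvision import transforms
import torchvision.transforms.functional as TF

# Placeholder image/mask pair (the notebooks load real cell images here)
image = Image.new('L', (106, 106))
mask = Image.new('L', (106, 106))

size = (96, 96)

# Sample ONE set of crop parameters and apply it to both,
# so the mask stays registered with the image
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
image = TF.crop(image, i, j, h, w)
mask = TF.crop(mask, i, j, h, w)

# The flip decision is likewise shared by the pair
if random.random() > 0.5:
    image = TF.hflip(image)
    mask = TF.hflip(mask)

image, mask = TF.to_tensor(image), TF.to_tensor(mask)
print(image.shape, mask.shape)  # torch.Size([1, 96, 96]) each
```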
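
The UpConv block first doubles the spatial resolution with a ConvTranspose2d (kernel_size=2, stride=2) and then concatenates the matching encoder feature map along the channel dimension, which is why its BaseConv receives in_channels + in_channels_skip channels. Below is a quick shape check using the sizes that occur at the deepest level of the ResUNet above (out_channels=32, 96x96 inputs); the random tensors are only stand-ins for the real feature maps.

```python
import torch
import torch.nn as nn

# Bottleneck feature map after down3: 8*32 = 256 channels at 12x12 (96 -> 48 -> 24 -> 12)
x = torch.randn(1, 256, 12, 12)
# Matching encoder feature map from down2: 4*32 = 128 channels at 24x24
x_skip = torch.randn(1, 128, 24, 24)

# kernel_size=2 with stride=2 doubles the spatial size: 12x12 -> 24x24
up = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)

x_up = up(x)
x_cat = torch.cat((x_up, x_skip), dim=1)

print(x_up.shape)   # torch.Size([1, 256, 24, 24])
print(x_cat.shape)  # torch.Size([1, 384, 24, 24]) -> in_channels + in_channels_skip
```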
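
The custom loss in video1_4.ipynb adds a soft Dice term, 1 - QS with QS = 2|X ∩ Y| / (|X| + |Y|), to the pixel-wise BCELoss. The sketch below reproduces the notebook's dice_loss and shows how the two terms are combined in train(); it uses plain tensors (PyTorch >= 0.4) rather than the Variable API of the original notebooks, and the random prediction/target pair is purely illustrative.

```python
import torch
import torch.nn.functional as F

def dice_loss(y_target, y_pred, smooth=0.0):
    # 1 - 2|X ∩ Y| / (|X| + |Y|), computed on the flattened soft masks
    y_target = y_target.view(-1)
    y_pred = y_pred.view(-1)
    intersection = (y_target * y_pred).sum()
    dice_coef = (2. * intersection + smooth) / (
        y_target.sum() + y_pred.sum() + smooth)
    return 1. - dice_coef

# Stand-ins for model(data).squeeze() and the binary target mask
pred = torch.sigmoid(torch.randn(1, 96, 96))
target = (torch.rand(1, 96, 96) > 0.5).float()

loss_mask = F.binary_cross_entropy(pred, target)  # "criterion" in the notebook
loss_dice = dice_loss(target, pred)
loss = loss_mask + loss_dice                      # combined objective backpropagated in train()
print(float(loss_mask), float(loss_dice), float(loss))
```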
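
Finally, the Visdom logging in video1_5.ipynb follows a create-then-append pattern: the first viz.line call opens a plot window and returns its id, and later calls pass win=... with update='append' to extend it. This is a minimal sketch of that pattern, assuming a Visdom server is already running (python -m visdom.server) on port 6006 as in the notebook; the decaying "loss" values are synthetic.

```python
import numpy as np
from visdom import Visdom

viz = Visdom(port=6006)  # connects to a running visdom server

win = None
for step in range(100):
    loss = float(np.exp(-step / 30.0) + 0.05 * np.random.rand())
    if win is None:
        # First call creates the plot window and returns its id
        win = viz.line(Y=np.array([loss]), X=np.array([step]),
                       opts=dict(xlabel='iteration', ylabel='loss', title='Losses'))
    else:
        # Subsequent calls append new points to the same window
        viz.line(Y=np.array([loss]), X=np.array([step]), win=win, update='append')
```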