├── README.md
├── code
│   ├── inference.py
│   ├── model_def.py
│   ├── train.py
│   └── unit_tests.py
├── histopathology_mil_train.ipynb
└── infer_pytorch.ipynb

/README.md:
--------------------------------------------------------------------------------
1 | # Attention-Based Deep Multiple Instance Learning
2 | ## for prostate cancer diagnosis using PyTorch and AWS SageMaker data parallelism
3 | 
4 | [Medium article](https://jmg764.medium.com/attention-based-deep-multiple-instance-learning-1bb3df857e24)
5 | 
--------------------------------------------------------------------------------
/code/inference.py:
--------------------------------------------------------------------------------
1 | 
2 | # Licensed to the Apache Software Foundation (ASF) under one
3 | # or more contributor license agreements. See the NOTICE file
4 | # distributed with this work for additional information
5 | # regarding copyright ownership. The ASF licenses this file
6 | # to you under the Apache License, Version 2.0 (the
7 | # "License"); you may not use this file except in compliance
8 | # with the License. You may obtain a copy of the License at
9 | #
10 | #   http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing,
13 | # software distributed under the License is distributed on an
14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 | # KIND, either express or implied. See the License for the
16 | # specific language governing permissions and limitations
17 | # under the License.
18 | 
19 | from __future__ import print_function
20 | 
21 | import os
22 | import torch
23 | 
24 | # Network definition
25 | from model_def import Attention
26 | 
27 | # SageMaker hosting calls model_fn once when the serving container starts;
28 | # the default input_fn/predict_fn/output_fn handlers of the SageMaker
29 | # PyTorch inference toolkit are used for everything else.
30 | def model_fn(model_dir):
31 |     print("In model_fn. Model directory is -")
32 |     print(model_dir)
33 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34 |     model = Attention()
35 |     with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
36 |         print("Loading the histopathology mil model")
37 |         model.load_state_dict(torch.load(f, map_location=device))
38 |     return model
39 | 
40 | 
--------------------------------------------------------------------------------
/code/model_def.py:
--------------------------------------------------------------------------------
1 | # Licensed to the Apache Software Foundation (ASF) under one
2 | # or more contributor license agreements. See the NOTICE file
3 | # distributed with this work for additional information
4 | # regarding copyright ownership. The ASF licenses this file
5 | # to you under the Apache License, Version 2.0 (the
6 | # "License"); you may not use this file except in compliance
7 | # with the License. You may obtain a copy of the License at
8 | #
9 | #   http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing,
12 | # software distributed under the License is distributed on an
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 | # KIND, either express or implied. See the License for the
15 | # specific language governing permissions and limitations
16 | # under the License.
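17 | 
18 | # Architecture note: this module implements the attention-based MIL pooling
19 | # of Ilse et al. (2018). A bag of N tile embeddings h_1..h_N (each of
20 | # length L) is reduced to a single bag embedding M = sum_k a_k * h_k, with
21 | # attention weights a_k = softmax_k(w^T tanh(V h_k)), where V (D x L) and
22 | # w (D x 1) are the learned layers in self.attention below. Example shapes
23 | # for a bag x of 16 tiles of 128 x 128 pixels:
24 | #   Y_prob, Y_hat, A = Attention()(x)  # x: (16, 3, 128, 128)
25 | #   # Y_prob: (1, 1) bag probability; Y_hat: (1, 1) label; A: (1, 16) weights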
26 | 
27 | import torch
28 | import torch.nn.functional as F
29 | import torch.nn as nn
30 | 
31 | class Attention(nn.Module):
32 |     def __init__(self):
33 |         super(Attention, self).__init__()
34 |         self.L = 512  # 512 node fully connected layer
35 |         self.D = 128  # 128 node attention layer
36 |         self.K = 1
37 | 
38 |         self.feature_extractor_part1 = nn.Sequential(
39 |             nn.Conv2d(3, 36, kernel_size=4),
40 |             nn.ReLU(),
41 |             nn.MaxPool2d(2, stride=2),
42 |             nn.Conv2d(36, 48, kernel_size=3),
43 |             nn.ReLU(),
44 |             nn.MaxPool2d(2, stride=2)
45 |         )
46 | 
47 |         # 128 x 128 input tiles leave 48 feature maps of size 30 x 30
48 |         # after the conv/pool stack above, hence the input size below
49 |         self.feature_extractor_part2 = nn.Sequential(
50 |             nn.Linear(48 * 30 * 30, self.L),
51 |             nn.ReLU(),
52 |             nn.Dropout(),
53 |             nn.Linear(self.L, self.L),
54 |             nn.ReLU(),
55 |             nn.Dropout()
56 |         )
57 | 
58 |         self.attention = nn.Sequential(
59 |             nn.Linear(self.L, self.D),
60 |             nn.Tanh(),
61 |             nn.Linear(self.D, self.K)
62 |         )
63 | 
64 |         self.classifier = nn.Sequential(
65 |             nn.Linear(self.L * self.K, 1),
66 |             nn.Sigmoid()
67 |         )
68 | 
69 |     def forward(self, x):
70 |         x = x.squeeze(0)
71 | 
72 |         H = self.feature_extractor_part1(x)
73 |         H = H.view(-1, 48 * 30 * 30)
74 |         H = self.feature_extractor_part2(H)
75 | 
76 |         A = self.attention(H)         # NxK
77 |         A = torch.transpose(A, 1, 0)  # KxN
78 |         A = F.softmax(A, dim=1)       # softmax over N
79 | 
80 |         M = torch.mm(A, H)            # attention-weighted bag embedding, KxL
81 | 
82 |         Y_prob = self.classifier(M)
83 |         Y_hat = torch.ge(Y_prob, 0.5).float()
84 | 
85 |         # Keep the attention weights as floats in (0, 1); casting them to an
86 |         # integer dtype such as byte would truncate every weight to 0
87 |         return Y_prob, Y_hat, A
88 | 
89 |     def calculate_classification_error(self, X, Y):
90 |         Y = Y.float()
91 |         _, Y_hat, _ = self.forward(X)
92 |         error = 1. - Y_hat.eq(Y).cpu().float().mean().item()
93 | 
94 |         return error, Y_hat
95 | 
96 |     def calculate_objective(self, X, Y):
97 |         Y = Y.float()
98 |         Y_prob, _, A = self.forward(X)
99 |         Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
100 |         neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))
101 | 
102 |         return neg_log_likelihood, A
103 | 
104 | 
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"). You
4 | # may not use this file except in compliance with the License. A copy of
5 | # the License is located at
6 | #
7 | #     http://aws.amazon.com/apache2.0/
8 | #
9 | # or in the "license" file accompanying this file. This file is
10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11 | # ANY KIND, either express or implied. See the License for the specific
12 | # language governing permissions and limitations under the License.
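13 | 
14 | # Runtime context: this script is launched by the SageMaker PyTorch
15 | # estimator in histopathology_mil_train.ipynb with the smdistributed
16 | # dataparallel backend enabled. SageMaker mounts the 'training' channel
17 | # at /opt/ml/input/data/training and passes the model output directory
18 | # to the script through the SM_MODEL_DIR environment variable.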
19 | 
20 | from __future__ import print_function
21 | import os
22 | import argparse
23 | import time
24 | import numpy as np
25 | import pandas as pd
26 | 
27 | import torch
28 | import torch.nn.functional as F
29 | import torch.optim as optim
30 | import torch.nn as nn
31 | from torchvision import datasets, transforms
32 | from torch.optim.lr_scheduler import StepLR
33 | import torch.utils.data as data_utils
34 | import PIL
35 | from sklearn.model_selection import train_test_split
36 | from sklearn.metrics import accuracy_score
37 | 
38 | # Import SMDataParallel PyTorch modules
39 | from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
40 | import smdistributed.dataparallel.torch.distributed as dist
41 | 
42 | # Network definition
43 | from model_def import Attention
44 | 
45 | dist.init_process_group()
46 | 
47 | class TileDataset(data_utils.Dataset):
48 |     """
49 |     Custom dataset class for tiles.
50 |     __getitem__ returns a bag of num_tiles tiles along with its isup_grade.
51 |     """
52 |     def __init__(self, img_path, dataframe, num_tiles, transform=None):
53 |         """
54 |         img_path: Directory where the images are stored
55 |         dataframe: The train.csv dataframe
56 |         num_tiles: How many tiles the dataset should return per sample
57 |         transform: The function to apply to the image. Usually data augmentation. DO NOT DO NORMALIZATION here.
58 |         """
59 |         self.img_path = img_path
60 |         self.df = dataframe
61 |         self.num_tiles = num_tiles
62 |         self.img_list = list(self.df['image_id'])
63 |         self.transform = transform
64 | 
65 |     def __getitem__(self, idx):
66 |         img_id = self.img_list[idx]
67 | 
68 |         tiles = ['/' + img_id + '_' + str(i) + '.png' for i in range(0, self.num_tiles)]
69 |         image_tiles = []
70 | 
71 |         for tile in tiles:
72 |             image = PIL.Image.open(self.img_path + tile)
73 | 
74 |             if self.transform is not None:
75 |                 image = self.transform(image)
76 | 
77 |             # Invert the mostly-white tiles, then normalize with the
78 |             # channel statistics of the inverted training set
79 |             image = 1 - image
80 |             image = transforms.Normalize([1.0-0.90949707, 1.0-0.8188697, 1.0-0.87795304], [0.1279171, 0.24528177, 0.16098117])(image)
81 |             image_tiles.append(image)
82 | 
83 |         # torch.stack already returns a tensor, so there is no need to
84 |         # re-wrap it in torch.tensor()
85 |         image_tiles = torch.stack(image_tiles, dim=0)
86 | 
87 |         return image_tiles, torch.tensor(self.df.iloc[idx]['isup_grade'])
88 | 
89 |     def __len__(self):
90 |         return len(self.img_list)
91 | 
92 | def train(model, device, train_loader, optimizer, epoch):
93 |     model.train()
94 |     train_loss = 0.
95 |     train_error = 0.
96 |     predictions = []
97 |     labels = []
98 |     for batch_idx, (data, label) in enumerate(train_loader):
99 |         bag_label = label.float()
100 |         data = torch.squeeze(data)  # (1, 16, 3, 128, 128) -> (16, 3, 128, 128)
101 |         data, bag_label = data.to(device), bag_label.to(device)
102 | 
103 |         # Reset gradients
104 |         optimizer.zero_grad()
105 |         # A single forward pass through the DDP wrapper, so that the
106 |         # backward pass below synchronizes gradients across workers
107 |         Y_prob, Y_hat, _ = model(data)
108 |         # Negative log-likelihood, as in Attention.calculate_objective
109 |         Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
110 |         loss = -1. * (bag_label * torch.log(Y_prob) + (1. - bag_label) * torch.log(1. - Y_prob))
111 |         train_loss += loss.item()
112 |         # Classification error, as in Attention.calculate_classification_error
113 |         error = 1. - Y_hat.eq(bag_label).cpu().float().mean().item()
114 |         train_error += error
115 | 
116 |         # Keep track of predictions and labels to calculate accuracy after each epoch
117 |         predictions.append(int(Y_hat))
118 |         labels.append(int(bag_label))
119 |         # Backward pass
120 |         loss.backward()
121 |         # Update model weights
122 |         optimizer.step()
123 | 
124 |     # Calculate loss and error for epoch
125 |     train_loss /= len(train_loader)
126 |     train_error /= len(train_loader)
127 | 
128 |     print('Train Set, Epoch: {}, Loss: {:.4f}, Error: {:.4f}, Accuracy: {:.2f}%'.format(
129 |         epoch, train_loss, train_error, accuracy_score(labels, predictions) * 100))
130 | 
131 | 
132 | def test(model, device, test_loader):
133 |     model.eval()
134 |     test_loss = 0.
135 |     test_error = 0.
136 |     with torch.no_grad():
137 |         for batch_idx, (data, label) in enumerate(test_loader):
138 |             bag_label = label
139 |             data = torch.squeeze(data)
140 |             data, bag_label = data.to(device), bag_label.to(device)
141 | 
142 |             loss, attention_weights = model.calculate_objective(data, bag_label)
143 |             test_loss += loss.item()
144 |             error, predicted_label = model.calculate_classification_error(data, bag_label)
145 |             test_error += error
146 | 
147 |     test_error /= len(test_loader)
148 |     test_loss /= len(test_loader)
149 | 
150 |     print('\nTest Set, Loss: {:.4f}, Error: {:.4f}'.format(test_loss, test_error))
151 | 
152 | 
153 | def get_csv(directory, df, num):
154 |     """
155 |     Given the image directory and panda_dataset.csv
156 |     (tiles in the image directory are a subset of those in panda_dataset.csv),
157 |     returns a DataFrame consisting of the first num/2 benign and num/2 malignant
158 |     slides available in the directory along with their isup_grade and gleason_score.
159 |     """
160 | 
161 |     # Getting tiles that are in S3
162 |     tiles_list = []
163 |     for image in os.listdir(directory):
164 |         tiles_list.append(image.split('_')[0])
165 | 
166 |     # Creating dataframe containing labels for each tile in S3
167 |     tiles_df = pd.DataFrame(columns=['image_id', 'data_provider', 'isup_grade', 'gleason_score'])
168 |     for i in range(len(tiles_list)):
169 |         tiles_df = tiles_df.append(df.loc[df['image_id'] == tiles_list[i]])
170 | 
171 |     # Drop duplicates first
172 |     tiles_df = tiles_df.drop_duplicates()
173 | 
174 |     # Select the first num/2 benign and num/2 malignant slides
175 |     benign = tiles_df[tiles_df.isup_grade == 0][:int(num/2)]
176 |     malignant = tiles_df[tiles_df.isup_grade == 1][:int(num/2)]
177 | 
178 |     tiles_df = pd.concat([benign, malignant])
179 | 
180 |     return tiles_df
181 | 
182 | def main():
183 |     # Training settings
184 |     parser = argparse.ArgumentParser(description='Histopathology MIL')
185 |     parser.add_argument('--batch-size', type=int, default=64, metavar='N',
186 |                         help='input batch size for training (default: 64)')
187 |     parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
188 |                         help='input batch size for testing (default: 1000)')
189 |     parser.add_argument('--epochs', type=int, default=14, metavar='N',
190 |                         help='number of epochs to train (default: 14)')
191 |     parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
192 |                         help='learning rate (default: 0.0001)')
193 |     parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
194 |                         help='Learning rate step gamma (default: 0.7)')
195 |     parser.add_argument('--seed', type=int, default=1, metavar='S',
196 |                         help='random seed (default: 1)')
197 |     parser.add_argument('--log-interval', type=int, default=10, metavar='N',
198 |                         help='how many batches to wait before logging training status')
199 |     parser.add_argument('--save-model', action='store_true', default=False,
200 |                         help='For Saving the current Model')
201 |     parser.add_argument('--verbose', action='store_true', default=False,
202 |                         help='For displaying SMDataParallel-specific logs')
203 |     parser.add_argument('--data-path', type=str, default='/tmp/data',
204 |                         help='Path to the training data')
205 | 
206 |     # Model checkpoint location
207 |     parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
208 | 
209 |     args = parser.parse_args()
210 |     args.world_size = dist.get_world_size()
211 |     args.rank = rank = dist.get_rank()
212 |     args.local_rank = local_rank = dist.get_local_rank()
213 |     # Scale the per-GPU batch size by the number of GPUs per node (8 on ml.p3.16xlarge)
214 |     args.batch_size //= args.world_size // 8
215 |     args.batch_size = max(args.batch_size, 1)
216 |     data_path = args.data_path
217 | 
218 |     if args.verbose:
219 |         print('Hello from rank', rank, 'of local_rank',
220 |               local_rank, 'in world size of', args.world_size)
221 | 
222 |     if not torch.cuda.is_available():
223 |         raise Exception("Must run SMDataParallel on CUDA-capable devices.")
224 | 
225 |     torch.manual_seed(args.seed)
226 | 
227 |     # Pin this process to its local GPU before creating any CUDA tensors
228 |     torch.cuda.set_device(local_rank)
229 |     device = torch.device("cuda")
230 | 
231 |     bucket = 'sagemaker-us-east-1-318322629142'
232 | 
233 |     train_dir = '/opt/ml/input/data/training'
234 | 
235 |     dataset_csv_key = 'panda_dataset.csv'
236 |     dataset_csv_dir = 's3://{}/{}'.format(bucket, dataset_csv_key)
237 | 
238 |     # Load panda_dataset.csv
239 |     df = pd.read_csv(dataset_csv_dir)
240 | 
241 |     # Replace isup_grade scores of 1, 2, 3, 4, & 5 with 1 to indicate malignant
242 |     df['isup_grade'] = df['isup_grade'].replace([1,2,3,4,5], 1)
243 | 
244 |     tiles_df = get_csv(train_dir, df, 624)
245 |     train_df, test_df = train_test_split(tiles_df)
246 | 
247 |     # Save dataframes to s3 bucket (test_df is used in infer_pytorch.ipynb)
248 |     train_df.to_csv('s3://{}/{}'.format(bucket, 'train_df'))
249 |     test_df.to_csv('s3://{}/{}'.format(bucket, 'test_df'))
250 | 
251 |     transform_train = transforms.Compose([transforms.RandomHorizontalFlip(0.5),
252 |                                           transforms.RandomVerticalFlip(0.5),
253 |                                           transforms.ToTensor()])
254 | 
255 |     train_set = TileDataset(train_dir, train_df, 16, transform=transform_train)
256 | 
257 |     # Shard the training set across workers so each rank sees a distinct subset
258 |     train_sampler = torch.utils.data.distributed.DistributedSampler(
259 |         train_set, num_replicas=args.world_size, rank=rank)
260 |     train_loader = data_utils.DataLoader(train_set, batch_size=1, shuffle=False,
261 |                                          num_workers=0, sampler=train_sampler)
262 | 
263 |     model = DDP(Attention().to(device))
264 |     optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=0.0005)
265 | 
266 |     print('Start Training')
267 |     for epoch in range(1, args.epochs + 1):
268 |         train(model, device, train_loader, optimizer, epoch)
269 | 
270 |     if rank == 0:
271 |         print("Saving the model...")
272 |         # Save the unwrapped weights under the file name inference.py expects
273 |         torch.save(model.module.state_dict(), os.path.join(args.model_dir, 'model.pth'))
274 | 
275 | 
276 | if __name__ == '__main__':
277 |     main()
278 | 
279 | 
--------------------------------------------------------------------------------
/code/unit_tests.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torch.utils.data as data_utils
4 | import model_def
5 | 
6 | class RunTests(unittest.TestCase):
7 | 
8 |     # Ensure shape of the dataset is correct; call manually with a
9 |     # DataLoader built around a TileDataset
10 |     def check_dataset_shape(self, train_loader):
11 |         data, _ = next(iter(train_loader))
12 |         self.assertEqual(torch.Size((16, 3, 128, 128)), data.shape)
13 | 
14 |     # Testing if the dataset works with the dataloader; call manually
15 |     # with a TileDataset instance
16 |     def check_single_process_dataloader(self, data):
17 |         with self.subTest(split='train'):
18 |             self._check_dataloader(data, shuffle=True, num_workers=0)
19 |         with self.subTest(split='test'):
20 |             self._check_dataloader(data, shuffle=False, num_workers=0)
21 | 
22 |     def _check_dataloader(self, data, shuffle, num_workers):
23 |         loader = data_utils.DataLoader(data, batch_size=1, shuffle=shuffle, num_workers=num_workers)
24 |         for _ in loader:
25 |             pass
26 | 
27 |     # Ensure that there aren't any dead sub-graphs (i.e. any learnable parameters that aren't used
28 |     # in the forward pass, backward pass, or both)
29 |     def test_all_parameters_updated(self):
30 |         net = model_def.Attention()
31 |         optim = torch.optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0005)
32 | 
33 |         # One bag of 16 random 128 x 128 RGB tiles with a positive label
34 |         loss, _ = net.calculate_objective(torch.randn(16, 3, 128, 128), torch.tensor([[1]]))
35 | 
36 |         loss.backward()
37 |         optim.step()
38 | 
39 |         for param_name, param in net.named_parameters():
40 |             if param.requires_grad:
41 |                 with self.subTest(name=param_name):
42 |                     self.assertIsNotNone(param.grad)
43 |                     self.assertNotEqual(0., torch.sum(param.grad ** 2).item())
44 | 
45 | 
46 | if __name__ == '__main__':
47 |     unittest.main()
--------------------------------------------------------------------------------
/histopathology_mil_train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |   "nbformat": 4,
3 |   "nbformat_minor": 0,
4 |   "metadata": {
5 |     "colab": {
6 |       "name": "histopathology_mil_train.ipynb",
7 |       "provenance": [],
8 |       "collapsed_sections": []
9 |     },
10 |     "kernelspec": {
11 |       "name": "python3",
12 |       "display_name": "Python 3"
13 |     },
14 |     "language_info": {
15 |       "name": "python"
16 |     }
17 |   },
18 |   "cells": [
19 |     {
20 |       "cell_type": "code",
21 |       "metadata": {
22 |         "id": "91Z1sh2Tvp0n"
23 |       },
24 |       "source": [
25 |         "!pip install sagemaker --upgrade"
26 |       ],
27 |       "execution_count": null,
28 |       "outputs": []
29 |     },
30 |     {
31 |       "cell_type": "code",
32 |       "metadata": {
33 |         "id": "54NwFLHm1fhd"
34 |       },
35 |       "source": [
36 |         "import sagemaker\n",
37 |         "\n",
38 |         "sagemaker_session = sagemaker.Session()\n",
39 |         "role = sagemaker.get_execution_role()"
40 |       ],
41 |       "execution_count": null,
42 |       "outputs": []
43 |     },
44 |     {
45 |       "cell_type": "code",
46 |       "metadata": {
47 |         "id": "eCmOiMC11lc0"
48 |       },
49 |       "source": [
50 |         "from sagemaker.pytorch import PyTorch\n",
51 |         "estimator = PyTorch(base_job_name='pytorch-smdataparallel-histopathology-mil',\n",
52 |         "                    source_dir='code',\n",
53 |         "                    entry_point='train.py',\n",
54 |         "                    role=role,\n",
55 |         "                    framework_version='1.8.1',\n",
56 |         "                    py_version='py36',\n",
57 |         "                    instance_count=2,\n",
58 |         "                    instance_type= 'ml.p3.16xlarge',\n",
59 |         "                    sagemaker_session=sagemaker_session,\n",
60 |         "                    distribution={'smdistributed':{\n",
61 |         "                        'dataparallel':{\n",
62 |         "                            
'enabled': True\n",
63 |         "                        }\n",
64 |         "                    }\n",
65 |         "                },\n",
66 |         "                debugger_hook_config=False,\n",
67 |         "                volume_size=40)"
68 |       ],
69 |       "execution_count": null,
70 |       "outputs": []
71 |     },
72 |     {
73 |       "cell_type": "code",
74 |       "metadata": {
75 |         "id": "RWv47F7WfPKV"
76 |       },
77 |       "source": [
78 |         "channels = {\n",
79 |         "    'training': 's3://sagemaker-us-east-1-318322629142/tiles/',\n",
80 |         "}"
81 |       ],
82 |       "execution_count": null,
83 |       "outputs": []
84 |     },
85 |     {
86 |       "cell_type": "code",
87 |       "metadata": {
88 |         "id": "Yq_dA5ArvI_r"
89 |       },
90 |       "source": [
91 |         "estimator.fit(inputs=channels)"
92 |       ],
93 |       "execution_count": null,
94 |       "outputs": []
95 |     }
96 |   ]
97 | }
98 | 
--------------------------------------------------------------------------------
/infer_pytorch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |   "cells": [
3 |     {
4 |       "cell_type": "markdown",
5 |       "metadata": {
6 |         "id": "kVG5nrRlFAkR"
7 |       },
8 |       "source": [
9 |         "## SageMaker endpoint\n",
10 |         "To deploy the model you previously trained, you need to create a SageMaker endpoint. This is a hosted prediction service that you can use to perform inference."
11 |       ]
12 |     },
13 |     {
14 |       "cell_type": "code",
15 |       "execution_count": null,
16 |       "metadata": {},
17 |       "outputs": [],
18 |       "source": [
19 |         "import io\n",
20 |         "from PIL import Image\n",
21 |         "\n",
22 |         "import torch\n",
23 |         "import torch.utils.data as data_utils\n",
24 |         "from torchvision import datasets, transforms\n",
25 |         "from torch.utils.data import DataLoader\n",
26 |         "\n",
27 |         "import boto3\n",
28 |         "import pandas as pd\n",
29 |         "from sklearn.metrics import accuracy_score\n\n# S3 resource handle used by image_from_s3 and get_csv below\ns3_resource = boto3.resource('s3')"
30 |       ]
31 |     },
32 |     {
33 |       "cell_type": "markdown",
34 |       "metadata": {
35 |         "id": "MbFVlkdCE6xi"
36 |       },
37 |       "source": [
38 |         "### Finding the model\n",
39 |         "This notebook uses a stored model if it exists. If you recently ran a training example that uses the %store magic, it will be restored in the next cell.\n",
40 |         "\n",
41 |         "Otherwise, you can pass the URI to the model file (a .tar.gz file) in the model_data variable.\n",
42 |         "\n",
43 |         "You can find your model files through the SageMaker console by choosing Training > Training jobs in the left navigation pane. Find your recent training job, choose it, and then look for the s3:// link in the Output pane. Update the model_data line in the next cell to manually set the model's URI."
44 |       ]
45 |     },
46 |     {
47 |       "cell_type": "code",
48 |       "execution_count": 78,
49 |       "metadata": {
50 |         "id": "ArIXc850D18u"
51 |       },
52 |       "outputs": [
53 |         {
54 |           "name": "stdout",
55 |           "output_type": "stream",
56 |           "text": [
57 |             "no stored variable or alias model_data\n",
58 |             "Using this model: s3://sagemaker-us-east-1-318322629142/pytorch-smdataparallel-histopathology-m-2021-05-18-18-59-25-802/output/model.tar.gz\n"
59 |           ]
60 |         }
61 |       ],
62 |       "source": [
63 |         "# Retrieve a saved model from a previous notebook run's stored variable\n",
64 |         "%store -r model_data\n",
65 |         "\n",
66 |         "# If no model was found, set it manually here.\n",
67 |         "model_data = 's3://sagemaker-us-east-1-318322629142/pytorch-smdataparallel-histopathology-m-2021-05-18-18-59-25-802/output/model.tar.gz'\n",
68 |         "\n",
69 |         "print(\"Using this model: {}\".format(model_data))"
70 |       ]
71 |     },
72 |     {
73 |       "cell_type": "markdown",
74 |       "metadata": {
75 |         "id": "RDJ6ENagFNu-"
76 |       },
77 |       "source": [
78 |         "### Create a model object\n",
79 |         "You define the model object by using the SageMaker SDK's PyTorchModel, passing in the model artifact (model_data) from the training job and the entry_point. The endpoint's entry point for inference is defined by model_fn, as seen in the following code block that prints out inference.py. The function loads the model and sets it to use a GPU, if available."
80 |       ]
81 |     },
82 |     {
83 |       "cell_type": "code",
84 |       "execution_count": 79,
85 |       "metadata": {
86 |         "id": "ox9u5cCVEHCH"
87 |       },
88 |       "outputs": [
89 |         {
90 |           "name": "stdout",
91 |           "output_type": "stream",
92 |           "text": [
93 |             "\u001b[37m# Licensed to the Apache Software Foundation (ASF) under one\u001b[39;49;00m\r\n",
94 |             "\u001b[37m# or more contributor license agreements. See the NOTICE file\u001b[39;49;00m\r\n",
95 |             "\u001b[37m# distributed with this work for additional information\u001b[39;49;00m\r\n",
96 |             "\u001b[37m# regarding copyright ownership. The ASF licenses this file\u001b[39;49;00m\r\n",
97 |             "\u001b[37m# to you under the Apache License, Version 2.0 (the\u001b[39;49;00m\r\n",
98 |             "\u001b[37m# \"License\"); you may not use this file except in compliance\u001b[39;49;00m\r\n",
99 |             "\u001b[37m# with the License. You may obtain a copy of the License at\u001b[39;49;00m\r\n",
100 |             "\u001b[37m#\u001b[39;49;00m\r\n",
101 |             "\u001b[37m#   http://www.apache.org/licenses/LICENSE-2.0\u001b[39;49;00m\r\n",
102 |             "\u001b[37m#\u001b[39;49;00m\r\n",
103 |             "\u001b[37m# Unless required by applicable law or agreed to in writing,\u001b[39;49;00m\r\n",
104 |             "\u001b[37m# software distributed under the License is distributed on an\u001b[39;49;00m\r\n",
105 |             "\u001b[37m# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\u001b[39;49;00m\r\n",
106 |             "\u001b[37m# KIND, either express or implied. 
See the License for the\u001b[39;49;00m\r\n",
107 |             "\u001b[37m# specific language governing permissions and limitations\u001b[39;49;00m\r\n",
108 |             "\u001b[37m# under the License.\u001b[39;49;00m\r\n",
109 |             "\r\n",
110 |             "\u001b[34mfrom\u001b[39;49;00m \u001b[04m\u001b[36m__future__\u001b[39;49;00m \u001b[34mimport\u001b[39;49;00m print_function\r\n",
111 |             "\r\n",
112 |             "\u001b[34mimport\u001b[39;49;00m \u001b[04m\u001b[36mos\u001b[39;49;00m\r\n",
113 |             "\u001b[34mimport\u001b[39;49;00m \u001b[04m\u001b[36mtorch\u001b[39;49;00m\r\n",
114 |             "\r\n",
115 |             "\u001b[37m# Network definition\u001b[39;49;00m\r\n",
116 |             "\u001b[34mfrom\u001b[39;49;00m \u001b[04m\u001b[36mmodel_def\u001b[39;49;00m \u001b[34mimport\u001b[39;49;00m Attention\r\n",
117 |             "\r\n",
118 |             "\u001b[34mdef\u001b[39;49;00m \u001b[32mmodel_fn\u001b[39;49;00m(model_dir):\r\n",
119 |             "    \u001b[36mprint\u001b[39;49;00m(\u001b[33m\"\u001b[39;49;00m\u001b[33mIn model_fn. Model directory is -\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\r\n",
120 |             "    \u001b[36mprint\u001b[39;49;00m(model_dir)\r\n",
121 |             "    device = torch.device(\u001b[33m\"\u001b[39;49;00m\u001b[33mcuda\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m \u001b[34mif\u001b[39;49;00m torch.cuda.is_available() \u001b[34melse\u001b[39;49;00m \u001b[33m\"\u001b[39;49;00m\u001b[33mcpu\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\r\n",
122 |             "    model = Attention()\r\n",
123 |             "    \u001b[34mwith\u001b[39;49;00m \u001b[36mopen\u001b[39;49;00m(os.path.join(model_dir, \u001b[33m'\u001b[39;49;00m\u001b[33mmodel.pth\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m), \u001b[33m'\u001b[39;49;00m\u001b[33mrb\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m) \u001b[34mas\u001b[39;49;00m f:\r\n",
124 |             "        \u001b[36mprint\u001b[39;49;00m(\u001b[33m\"\u001b[39;49;00m\u001b[33mLoading the histopathology mil model\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\r\n",
125 |             "        model.load_state_dict(torch.load(f, map_location=device))\r\n",
126 |             "    \u001b[34mreturn\u001b[39;49;00m model\r\n"
127 |           ]
128 |         }
129 |       ],
130 |       "source": [
131 |         "!pygmentize code/inference.py"
132 |       ]
133 |     },
134 |     {
135 |       "cell_type": "code",
136 |       "execution_count": 80,
137 |       "metadata": {
138 |         "id": "lrBb96KOEKT0"
139 |       },
140 |       "outputs": [],
141 |       "source": [
142 |         "import sagemaker\n",
143 |         "role = sagemaker.get_execution_role()\n",
144 |         "\n",
145 |         "from sagemaker.pytorch import PyTorchModel\n",
146 |         "model = PyTorchModel(model_data=model_data, source_dir='code',\n",
147 |         "                     entry_point='inference.py', role=role, framework_version='1.6.0', py_version='py3')"
148 |       ]
149 |     },
150 |     {
151 |       "cell_type": "markdown",
152 |       "metadata": {
153 |         "id": "6u6HzY_lFTBv"
154 |       },
155 |       "source": [
156 |         "#### Deploy the model on an endpoint\n",
157 |         "You create a predictor by using the model.deploy function. You can optionally change both the instance count and instance type."
158 |       ]
159 |     },
160 |     {
161 |       "cell_type": "code",
162 |       "execution_count": 81,
163 |       "metadata": {
164 |         "id": "fmSumXLREaMV"
165 |       },
166 |       "outputs": [
167 |         {
168 |           "name": "stdout",
169 |           "output_type": "stream",
170 |           "text": [
171 |             "------------------!"
172 |           ]
173 |         }
174 |       ],
175 |       "source": [
176 |         "predictor = model.deploy(initial_instance_count=1, instance_type='ml.m5.24xlarge')"
177 |       ]
178 |     },
179 |     {
180 |       "cell_type": "markdown",
181 |       "metadata": {
182 |         "id": "pOn7-dTgFW58"
183 |       },
184 |       "source": [
185 |         "### Test the model\n",
186 |         "You can test the deployed model using samples from the test set."
187 |       ]
188 |     },
189 |     {
190 |       "cell_type": "code",
191 |       "execution_count": 83,
192 |       "metadata": {},
193 |       "outputs": [],
194 |       "source": [
195 |         "def image_from_s3(bucket, key):\n",
196 |         "    bucket = s3_resource.Bucket(bucket)\n",
197 |         "    image = bucket.Object(key)\n",
198 |         "    img_data = image.get().get('Body').read()\n",
199 |         "\n",
200 |         "    return Image.open(io.BytesIO(img_data))"
201 |       ]
202 |     },
203 |     {
204 |       "cell_type": "code",
205 |       "execution_count": 86,
206 |       "metadata": {},
207 |       "outputs": [],
208 |       "source": [
209 |         "class TileDataset(data_utils.Dataset):\n",
210 |         "\n",
211 |         "    def __init__(self, img_path, folder_num, dataframe, num_tiles, transform=None):\n",
212 |         "        \"\"\"\n",
213 |         "        img_path: The S3 bucket where the images are stored\n        folder_num: Which test_<n> folder in S3 the tiles are read from\n",
214 |         "        dataframe: The train.csv dataframe\n",
215 |         "        num_tiles: How many tiles should the dataset return per sample\n",
216 |         "        transform: The function to apply to the image. Usually data augmentation. Do not do normalization here.\n",
217 |         "        \"\"\"\n",
218 |         "        self.img_path = img_path\n",
219 |         "        self.folder_num = folder_num\n",
220 |         "        self.df = dataframe\n",
221 |         "        self.num_tiles = num_tiles\n",
222 |         "        self.img_list = list(self.df['image_id'])\n",
223 |         "        self.transform = transform\n",
224 |         "\n",
225 |         "    def __getitem__(self, idx):\n",
226 |         "        img_id = self.img_list[idx]\n",
227 |         "\n",
228 |         "        tiles = ['test_'+str(self.folder_num)+'/'+img_id + '_' + str(i) + '.png' for i in range(0, self.num_tiles)]\n",
229 |         "        image_tiles = []\n",
230 |         "        \n",
231 |         "\n",
232 |         "        for tile in tiles:\n",
233 |         "            image = image_from_s3(self.img_path, tile)\n",
234 |         "\n",
235 |         "            if self.transform is not None:\n",
236 |         "                image = self.transform(image)\n",
237 |         "\n",
238 |         "            image = 1 - image\n",
239 |         "            image = transforms.Normalize([1.0-0.90949707, 1.0-0.8188697, 1.0-0.87795304], [0.1279171 , 0.24528177, 0.16098117])(image)\n",
240 |         "            image_tiles.append(image)\n",
241 |         "\n",
242 |         "        image_tiles = torch.stack(image_tiles, dim=0)\n",
243 |         "\n",
244 |         "        return torch.tensor(image_tiles), torch.tensor(self.df.iloc[idx]['isup_grade'])\n",
245 |         "\n",
246 |         "    def __len__(self):\n",
247 |         "        return len(self.img_list)"
248 |       ]
249 |     },
250 |     {
251 |       "cell_type": "code",
252 |       "execution_count": 87,
253 |       "metadata": {},
254 |       "outputs": [],
255 |       "source": [
256 |         "def get_csv(bucket, folder_num, df):\n",
257 |         "    # Getting tiles that are in S3\n",
258 |         "    print('Collecting list of tiles')\n",
259 |         "    tiles_set = set()\n",
260 |         "    bucket = s3_resource.Bucket('sagemaker-us-east-1-318322629142')\n",
261 |         "    for key in bucket.objects.all():\n",
262 |         "        if 'test_'+str(folder_num) in key.key:\n",
263 |         "            tiles_set.add(key.key.split('/')[1].split('_')[0])\n",
264 |         "    tiles_list = list(tiles_set)\n",
265 |         "    \n",
266 |         "    print('Creating dataframe')\n",
267 |         "    # Creating dataframe containing labels for each tile in S3\n",
268 |         "    tiles_df = pd.DataFrame(columns=['image_id', 'data_provider', 'isup_grade', 'gleason_score'])\n",
269 |         "    for i in range(len(tiles_list)):\n",
270 |         "        tiles_df = tiles_df.append(df.loc[df['image_id'] == tiles_list[i]])\n",
271 |         "    \n",
272 |         "    tiles_df = tiles_df.drop_duplicates()\n",
273 |         "    return tiles_df"
274 |       ]
275 |     },
276 |     {
277 |       "cell_type": "code",
278 |       "execution_count": 108,
279 |       "metadata": {
280 |         "id": "cRjZBmrCFdfk",
281 |         "scrolled": true
282 |       },
283 |       "outputs": [
284 |         {
285 |           "name": "stdout",
286 |           "output_type": "stream",
287 |           "text": [
288 |             "Collecting list of tiles\n",
289 |             "Creating dataframe\n",
290 |             "                              image_id 
data_provider isup_grade gleason_score\n", 291 | "1732 2b730c057bde4c56e79f693e3d577138 radboud 1 4+5\n", 292 | "1727 2b4d629c0b0a02ddfb05cc41c0c8dc65 karolinska 1 4+4\n", 293 | "1707 2ac5f9c41e6b9a004fc0cecf6c3083be karolinska 0 3+3\n", 294 | "1774 2c8fd1d0ab8640342f6d10a0a54e5279 karolinska 0 0+0\n", 295 | "1254 1fc49bfab631583981f96f285ec0c94d karolinska 1 4+5\n", 296 | "... ... ... ... ...\n", 297 | "1709 2ad0f2857a4552a25127205fd04a5e9f radboud 0 3+3\n", 298 | "1725 2b340c9844077ddcdf641adac5f116e3 radboud 0 negative\n", 299 | "1680 2a1c3373688904fcabbdeb4a177972f8 radboud 0 3+3\n", 300 | "1249 1fb65315d7ded63d688194863a1b123e karolinska 1 5+5\n", 301 | "1257 1fe0cfea7347950a76bcbdafa0ad96ab karolinska 0 3+4\n", 302 | "\n", 303 | "[125 rows x 4 columns]\n", 304 | "Creating data loader\n" 305 | ] 306 | }, 307 | { 308 | "name": "stderr", 309 | "output_type": "stream", 310 | "text": [ 311 | "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/__main__.py:44: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n" 312 | ] 313 | }, 314 | { 315 | "name": "stdout", 316 | "output_type": "stream", 317 | "text": [ 318 | "batch_idx = 0\n", 319 | "batch_idx = 1\n", 320 | "batch_idx = 2\n", 321 | "batch_idx = 3\n", 322 | "batch_idx = 4\n", 323 | "batch_idx = 5\n", 324 | "batch_idx = 6\n", 325 | "batch_idx = 7\n", 326 | "batch_idx = 8\n", 327 | "batch_idx = 9\n", 328 | "batch_idx = 10\n", 329 | "batch_idx = 11\n", 330 | "batch_idx = 12\n", 331 | "batch_idx = 13\n", 332 | "batch_idx = 14\n", 333 | "batch_idx = 15\n", 334 | "batch_idx = 16\n", 335 | "batch_idx = 17\n", 336 | "batch_idx = 18\n", 337 | "batch_idx = 19\n", 338 | "batch_idx = 20\n", 339 | "batch_idx = 21\n", 340 | "batch_idx = 22\n", 341 | "batch_idx = 23\n", 342 | "batch_idx = 24\n", 343 | "batch_idx = 25\n", 344 | "batch_idx = 26\n", 345 | "batch_idx = 27\n", 346 | "batch_idx = 28\n", 347 | "batch_idx = 29\n", 348 | "batch_idx = 30\n", 349 | "batch_idx = 31\n", 350 | "batch_idx = 32\n", 351 | "batch_idx = 33\n", 352 | "batch_idx = 34\n", 353 | "batch_idx = 35\n", 354 | "batch_idx = 36\n", 355 | "batch_idx = 37\n", 356 | "batch_idx = 38\n", 357 | "batch_idx = 39\n", 358 | "batch_idx = 40\n", 359 | "batch_idx = 41\n", 360 | "batch_idx = 42\n", 361 | "batch_idx = 43\n", 362 | "batch_idx = 44\n", 363 | "batch_idx = 45\n", 364 | "batch_idx = 46\n", 365 | "batch_idx = 47\n", 366 | "batch_idx = 48\n", 367 | "batch_idx = 49\n", 368 | "batch_idx = 50\n", 369 | "batch_idx = 51\n", 370 | "batch_idx = 52\n", 371 | "batch_idx = 53\n", 372 | "batch_idx = 54\n", 373 | "batch_idx = 55\n", 374 | "batch_idx = 56\n", 375 | "batch_idx = 57\n", 376 | "batch_idx = 58\n", 377 | "batch_idx = 59\n", 378 | "batch_idx = 60\n", 379 | "batch_idx = 61\n", 380 | "batch_idx = 62\n", 381 | "batch_idx = 63\n", 382 | "batch_idx = 64\n", 383 | "batch_idx = 65\n", 384 | "batch_idx = 66\n", 385 | "batch_idx = 67\n", 386 | "batch_idx = 68\n", 387 | "batch_idx = 69\n", 388 | "batch_idx = 70\n", 389 | "batch_idx = 71\n", 390 | "batch_idx = 72\n", 391 | "batch_idx = 73\n", 392 | "batch_idx = 74\n", 393 | "batch_idx = 75\n", 394 | "batch_idx = 76\n", 395 | "batch_idx = 77\n", 396 | "batch_idx = 78\n", 397 | "batch_idx = 79\n", 398 | "batch_idx = 80\n", 399 | "batch_idx = 81\n", 400 | "batch_idx = 82\n", 401 | "batch_idx = 83\n", 402 | "batch_idx = 84\n", 403 | "batch_idx = 85\n", 404 | "batch_idx = 86\n", 405 
| "batch_idx = 87\n", 406 | "batch_idx = 88\n", 407 | "batch_idx = 89\n", 408 | "batch_idx = 90\n", 409 | "batch_idx = 91\n", 410 | "batch_idx = 92\n", 411 | "batch_idx = 93\n", 412 | "batch_idx = 94\n", 413 | "batch_idx = 95\n", 414 | "batch_idx = 96\n", 415 | "batch_idx = 97\n", 416 | "batch_idx = 98\n", 417 | "batch_idx = 99\n", 418 | "batch_idx = 100\n", 419 | "batch_idx = 101\n", 420 | "batch_idx = 102\n", 421 | "batch_idx = 103\n", 422 | "batch_idx = 104\n", 423 | "batch_idx = 105\n", 424 | "batch_idx = 106\n", 425 | "batch_idx = 107\n", 426 | "batch_idx = 108\n", 427 | "batch_idx = 109\n", 428 | "batch_idx = 110\n", 429 | "batch_idx = 111\n", 430 | "batch_idx = 112\n", 431 | "batch_idx = 113\n", 432 | "batch_idx = 114\n", 433 | "batch_idx = 115\n", 434 | "batch_idx = 116\n", 435 | "batch_idx = 117\n", 436 | "batch_idx = 118\n", 437 | "batch_idx = 119\n", 438 | "batch_idx = 120\n", 439 | "batch_idx = 121\n", 440 | "batch_idx = 122\n", 441 | "batch_idx = 123\n", 442 | "batch_idx = 124\n" 443 | ] 444 | } 445 | ], 446 | "source": [ 447 | "bucket = 'sagemaker-us-east-1-318322629142'\n", 448 | "\n", 449 | "dataset_csv_key = 'panda_dataset.csv'\n", 450 | "dataset_csv_dir = 's3://{}/{}'.format(bucket, dataset_csv_key)\n", 451 | "df = pd.read_csv(dataset_csv_dir)\n", 452 | "\n", 453 | "df['isup_grade'] = df['isup_grade'].replace([1,2], 0)\n", 454 | "df['isup_grade'] = df['isup_grade'].replace([3,4,5], 1)\n", 455 | "\n", 456 | "test_df = get_csv(bucket, 1, df)\n", 457 | "print(test_df)\n", 458 | "\n", 459 | "transform_train = transforms.Compose([transforms.RandomHorizontalFlip(0.5),\n", 460 | " transforms.RandomVerticalFlip(0.5),\n", 461 | " transforms.ToTensor()])\n", 462 | "\n", 463 | "\n", 464 | "print('Creating data loader')\n", 465 | "test_set = TileDataset(bucket, 1, test_df, 16, transform=transform_train)\n", 466 | "\n", 467 | "batch_size = 1\n", 468 | "test_loader = data_utils.DataLoader(test_set, batch_size, shuffle=False, num_workers=0)\n", 469 | "\n", 470 | "predictions = []\n", 471 | "true_labels = []\n", 472 | "for batch_idx, (data, label) in enumerate(test_loader):\n", 473 | " print('batch_idx = ', batch_idx)\n", 474 | " _, Y_hat, _ = predictor.predict(data)\n", 475 | " predictions.append(int(Y_hat))\n", 476 | " true_labels.append(int(label))\n", 477 | "\n" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": 255, 483 | "metadata": {}, 484 | "outputs": [ 485 | { 486 | "name": "stdout", 487 | "output_type": "stream", 488 | "text": [ 489 | "0.672\n" 490 | ] 491 | } 492 | ], 493 | "source": [ 494 | "print(accuracy_score(true_labels, predictions))" 495 | ] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "metadata": { 500 | "id": "KqoA8ZlXFgqW" 501 | }, 502 | "source": [ 503 | "#### Cleanup\n", 504 | "If you don't intend on trying out inference or to do anything else with the endpoint, you should delete it." 
505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": null, 510 | "metadata": { 511 | "id": "rY8U_J9HFiNy" 512 | }, 513 | "outputs": [], 514 | "source": [ 515 | "predictor.delete_endpoint()" 516 | ] 517 | } 518 | ], 519 | "metadata": { 520 | "colab": { 521 | "name": "infer_pytorch.ipynb", 522 | "provenance": [] 523 | }, 524 | "kernelspec": { 525 | "display_name": "conda_pytorch_p36", 526 | "language": "python", 527 | "name": "conda_pytorch_p36" 528 | }, 529 | "language_info": { 530 | "codemirror_mode": { 531 | "name": "ipython", 532 | "version": 3 533 | }, 534 | "file_extension": ".py", 535 | "mimetype": "text/x-python", 536 | "name": "python", 537 | "nbconvert_exporter": "python", 538 | "pygments_lexer": "ipython3", 539 | "version": "3.6.13" 540 | } 541 | }, 542 | "nbformat": 4, 543 | "nbformat_minor": 1 544 | } 545 | --------------------------------------------------------------------------------