├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── image-classification
│   ├── .gitignore
│   ├── code
│   │   ├── .gitignore
│   │   ├── projectorTrans.py
│   │   ├── tokenizerTrans.py
│   │   └── vt-resnet-34.py
│   └── vt-resnet34-pytorch-sagemaker.ipynb
└── img
    └── vt.png
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | this software and associated documentation files (the "Software"), to deal in
5 | the Software without restriction, including without limitation the rights to
6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software is furnished to do so.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15 |
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Image Classification using Visual Transformers
2 | - [Overview](#image-classification-using-visual-transformers)
3 | - [Amazon SageMaker](#amazon-sagemaker)
4 | - [How to run the code in Amazon SageMaker Studio?](#how-to-run-the-code-in-amazon-sagemaker-studio)
5 | - [References](#references)
6 |
7 | In standard image classification algorithms like ResNet and InceptionNet, images are represented as pixel arrays on which a series of convolution operations are performed. Although these algorithms achieve great accuracy, the convolution operation is computationally expensive. Therefore, in this notebook we look at an alternative way to perform `Image Classification` using the ideas from the `Visual Transformers: Token-based Image Representation and Processing for Computer Vision` [research paper](https://arxiv.org/pdf/2006.03677.pdf).
8 |
9 |
10 |
11 | Diagram of a Visual Transformer (VT).
12 |
13 | For a given image, we first apply convolutional layers to extract low-level
14 | features. The output feature map is then fed to the VT. First, a tokenizer groups pixels into a small number of visual
15 | tokens. Second, a transformer models the relationships between tokens.
16 | Third, the visual tokens are either used directly for image classification or projected back to the feature map for semantic segmentation.
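The three VT stages described above can be sketched in NumPy for a single feature map. This is a minimal, single-head illustration under stated assumptions, not the repository's implementation: the weight matrices are random stand-ins for learned 1x1 convolutions, the tokenizer's value projection is omitted, and the projector reuses the transformer's query/key/value matrices for brevity (the repository code is multi-head and uses separate PyTorch layers).

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along one axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def visual_transformer_sketch(feature, L, rng):
    """Toy, single-head VT pass on one feature map of shape (C, H, W).

    All weights are random stand-ins (hypothetical), not trained parameters.
    """
    C, H, W = feature.shape
    X = feature.reshape(C, H * W).T                    # (HW, C): one row per pixel

    # 1) Tokenizer: spatial attention groups the HW pixels into L tokens.
    W_a = rng.standard_normal((C, L))                  # stand-in for a 1x1 conv
    A = softmax(X @ W_a / np.sqrt(C), axis=0)          # (HW, L), each column sums to 1
    T = A.T @ X                                        # (L, C) visual tokens

    # 2) Transformer: tokens attend to each other (one self-attention step).
    W_q, W_k, W_v = (rng.standard_normal((C, C)) / np.sqrt(C) for _ in range(3))
    attn = softmax((T @ W_q) @ (T @ W_k).T / np.sqrt(C), axis=1)       # (L, L)
    T_out = T + attn @ (T @ W_v)                       # residual connection

    # 3) Projector: every pixel attends to the tokens; the result is added back.
    coef = softmax((X @ W_q) @ (T_out @ W_k).T / np.sqrt(C), axis=1)   # (HW, L)
    X_out = X + coef @ (T_out @ W_v)                   # (HW, C)
    return T_out, X_out.T.reshape(C, H, W)


rng = np.random.default_rng(0)
feature = rng.standard_normal((32, 14, 14))            # hypothetical C=32, 14x14 map
tokens, projected = visual_transformer_sketch(feature, L=8, rng=rng)
print(tokens.shape, projected.shape)  # (8, 32) (32, 14, 14)
```

The point of the design is that the expensive attention work happens over `L` tokens (here 8) rather than over all `H*W` pixels, while the projector still returns a full-resolution feature map.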
17 |
18 | **Note**
19 | - Dataset used is **Intel Image Classification** from [Kaggle](https://www.kaggle.com/puneet6060/intel-image-classification).
20 | - The notebook is only an example and not to be used for production deployments.
21 | - Use `Python3 (PyTorch 1.6 Python 3.6 CPU Optimized)` kernel and `ml.m5.large (2 vCPU + 8 GiB)` for the notebook, if you are using Amazon SageMaker Studio.
22 | - The notebook uses ideas and some of the pseudocode from the `Visual Transformers: Token-based Image Representation and Processing for Computer Vision` [research paper](https://arxiv.org/pdf/2006.03677.pdf), but it does not reproduce the results reported in the paper.
23 |
24 | ## Amazon SageMaker
25 | ----
26 | Amazon SageMaker is a comprehensive, fully managed machine learning service. With SageMaker, data scientists and developers can quickly and easily build and train machine learning models, and then directly deploy them into a production-ready hosted environment. It provides an integrated Jupyter authoring notebook instance for easy access to your data sources for exploration and analysis, so you don't have to manage servers. It also provides common machine learning algorithms that are optimized to run efficiently against extremely large data in a distributed environment. With native support for bring-your-own-algorithms and frameworks, SageMaker offers flexible distributed training options that adjust to your specific workflows. You can deploy a model into a secure and scalable environment with a few clicks from SageMaker Studio or the SageMaker console. We use Amazon SageMaker Studio to run the code; for more details, see the [AWS documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/studio.html).
27 |
28 | ## How to run the code in Amazon SageMaker Studio?
29 | ----
30 | If you haven't used Amazon SageMaker Studio before, please follow the steps mentioned in [`Onboard to Amazon SageMaker Studio`](https://docs.aws.amazon.com/sagemaker/latest/dg/gs-studio-onboard.html).
31 |
32 | ### To log in from the SageMaker console
33 |
34 | - Onboard to Amazon SageMaker Studio. If you've already onboarded, skip to the next step.
35 | - Open the SageMaker console.
36 | - Choose Amazon SageMaker Studio.
37 | - The Amazon SageMaker Studio Control Panel opens.
38 | - In the Amazon SageMaker Studio Control Panel, you'll see a list of user names.
39 | - Next to your user name, choose Open Studio.
40 |
41 | ### Open a Studio notebook
42 | SageMaker Studio can only open notebooks listed in the Studio file browser. In this example we will `Clone a Git Repository in SageMaker Studio`.
43 |
44 | #### To clone the repo
45 |
46 | - In the left sidebar, choose the File Browser icon.
47 | - Choose the root folder or the folder you want to clone the repo into.
48 | - In the left sidebar, choose the Git icon.
49 | - Choose Clone a Repository.
50 | - Enter the URI for the repo https://github.com/aws-samples/amazon-sagemaker-visual-transformer.git.
51 | - Choose CLONE.
52 | - If the repo requires credentials, you are prompted to enter your username and password.
53 | - Wait for the download to finish. After the repo has been cloned, the File Browser opens to display the cloned repo.
54 | - Double-click the repo to open it.
55 | - Choose the Git icon to view the Git user interface which now tracks the examples repo.
56 | - To track a different repo, open the repo in the file browser and then choose the Git icon.
57 |
58 | ### To open a notebook
59 |
60 | - In the left sidebar, choose the File Browser icon to display the file browser.
61 | - Browse to a notebook file and double-click it to open the notebook in a new tab.
62 |
63 | ## References
64 | ----
65 | - Visual Transformers: Token-based Image Representation and Processing for
66 | Computer Vision (https://arxiv.org/pdf/2006.03677.pdf).
67 | - Kaggle notebook (https://www.kaggle.com/asollie/intel-image-multiclass-pytorch-94-test-acc).
68 | - Registry of Open Data from AWS (https://registry.opendata.aws/multimedia-commons/).
69 |
70 |
71 | ## Security
72 |
73 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
74 |
75 | ## License
76 |
77 | This library is licensed under the MIT-0 License. See the LICENSE file.
78 |
79 |
--------------------------------------------------------------------------------
/image-classification/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
--------------------------------------------------------------------------------
/image-classification/code/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
--------------------------------------------------------------------------------
/image-classification/code/projectorTrans.py:
--------------------------------------------------------------------------------
1 | # Projector module to fuse transformer output with the feature map.
2 |
3 | import torch
4 | import torchvision
5 | import torch.nn as nn
6 | import numpy as np
7 | import torch.nn.functional as F
8 |
9 | class Projector(nn.Module):
10 |     def __init__(self, CT, C, head=16, groups=16):
11 |         super(Projector, self).__init__()
12 |         self.proj_value_conv = nn.Conv1d(CT, C, kernel_size=1)
13 |         self.proj_key_conv = nn.Conv1d(CT, C, kernel_size=1)
14 |         self.proj_query_conv = nn.Conv2d(C, CT, kernel_size=1, groups=groups)
15 |         self.head = head
16 |
17 |     def forward(self, feature, token):
18 |         N, L, CT = token.shape
19 |         token = token.permute(0, 2, 1)  # N, CT, L: swap axes (a view would scramble channels)
20 |         h = self.head
21 |         proj_v = self.proj_value_conv(token).view(N, h, -1, L)
22 |         proj_k = self.proj_key_conv(token).view(N, h, -1, L)
23 |         proj_q = self.proj_query_conv(feature)
24 |         N, C, H, W = proj_q.shape
25 |         proj_q = proj_q.view(N, h, C // h, H * W).permute(0, 1, 3, 2)
26 |         proj_coef = F.softmax(torch.matmul(proj_q, proj_k) / np.sqrt(C / h), dim=3)
27 |         proj = torch.matmul(proj_v, proj_coef.permute(0, 1, 3, 2))
28 |         _, _, H, W = feature.shape
29 |         return feature + proj.view(N, -1, H, W), token
30 |
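The projector's `nn.Conv1d` layers expect tokens laid out as (N, CT, L), while the tokenizer produces (N, L, CT). Converting between the two needs an axis swap, not a reshape. The NumPy sketch below uses `reshape` and `transpose` as stand-ins for PyTorch's `Tensor.view` and `Tensor.permute` (an assumption of equivalent memory semantics for contiguous tensors) to show the difference on a tiny tensor:

```python
import numpy as np

# A toy token tensor laid out as (N, L, CT): batch of 1, 2 tokens, 3 channels.
tokens = np.arange(6).reshape(1, 2, 3)   # token 0 = [0, 1, 2], token 1 = [3, 4, 5]

# Reshaping to (N, CT, L) keeps the memory order and just re-slices it,
# so each "channel" row ends up mixing values from different tokens.
reshaped = tokens.reshape(1, 3, 2)

# Transposing the last two axes actually swaps tokens and channels:
# row c now holds channel c of every token.
transposed = tokens.transpose(0, 2, 1)

print(reshaped[0].tolist())    # [[0, 1], [2, 3], [4, 5]]
print(transposed[0].tolist())  # [[0, 3], [1, 4], [2, 5]]
```

Both results have shape (1, 3, 2), which is why the reshape version raises no error and the mistake is easy to miss.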
--------------------------------------------------------------------------------
/image-classification/code/tokenizerTrans.py:
--------------------------------------------------------------------------------
1 | # Tokenizer module to convert feature maps into visual tokens.
2 |
3 | import torch
4 | import torchvision
5 | import torch.nn as nn
6 | import numpy as np
7 | import torch.nn.functional as F
8 |
9 | class Tokenizer(nn.Module):
10 |     def __init__(self, L, CT, C, head=16, groups=16, dynamic=False, input_channels=256):
11 |         super(Tokenizer, self).__init__()
12 |         # 1x1 conv that maps input_channels to C when the incoming feature map has a different width
13 |         self.feature = nn.Conv2d(input_channels, C, kernel_size=1)
14 |         if not dynamic:
15 |             # use static weights to compute token coefficients.
16 |             self.conv_token_coef = nn.Conv2d(C, L, kernel_size=1)
17 |         else:
18 |             # use previous tokens to compute a query weight, which is
19 |             # then used to compute token coefficients.
20 |             self.conv_query = nn.Conv1d(CT, C, kernel_size=1)
21 |             self.conv_key = nn.Conv2d(C, C, kernel_size=1, groups=groups)
22 |         self.conv_value = nn.Conv2d(C, C, kernel_size=1, groups=groups)
23 |         self.head = head
24 |         self.dynamic = dynamic
25 |         self.C = C
26 |
27 |     def forward(self, feature, tokens=0):
28 |         N, C, H, W = feature.shape
29 |         if C != self.C:
30 |             feature = self.feature(feature)
31 |         # compute token coefficients
32 |         # feature: N, C, H, W, token: N, CT, L
33 |         if not self.dynamic:
34 |             token_coef = self.conv_token_coef(feature)
35 |             N, L, H, W = token_coef.shape
36 |             token_coef = token_coef.view(N, 1, L, H * W)
37 |             token_coef = token_coef.permute(0, 1, 3, 2)  # N, 1, HW, L
38 |             token_coef = token_coef / np.sqrt(feature.shape[1])
39 |         else:
40 |             L = tokens.shape[2]
41 |             # Split input tokens
42 |             T_a, T_b = tokens[:, :, :L // 2], tokens[:, :, L // 2:]
43 |             query = self.conv_query(T_a)
44 |             N, C, L_a = query.shape
45 |             query = query.view(N, self.head, C // self.head, L_a)
46 |             N, C, H, W = feature.shape
47 |             key = self.conv_key(feature).view(N, self.head, C // self.head, H * W)  # N, h, C//h, HW
48 |             # Compute token coefficients.
49 |             # N, h, HW, L_a
50 |             token_coef = torch.matmul(key.permute(0, 1, 3, 2), query)
51 |             token_coef = token_coef / np.sqrt(C / self.head)
52 |         N, C, H, W = feature.shape
53 |         token_coef = F.softmax(token_coef, dim=2)  # normalize over the HW spatial positions
54 |         value = self.conv_value(feature).view(N, self.head, C // self.head, H * W)  # N, h, C//h, HW
55 |         # extract tokens from the feature map
56 |         # static tokens: N, C, L. dynamic tokens: N, C, L_a
57 |         tokens = torch.matmul(value, token_coef).view(N, C, -1)
58 |         tokens = tokens.permute(0, 2, 1)  # N, L, C: swap axes (a view would scramble channels)
59 |         return feature, tokens
60 |
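The static branch of the tokenizer computes, for each of the `L` tokens, a softmax over the `H*W` spatial positions and then takes a weighted average of the pixel features. A single-example NumPy sketch of that idea, assuming random weights in place of `conv_token_coef` and omitting the `conv_value` projection and the multi-head split used in the PyTorch module:

```python
import numpy as np


def static_tokenizer(feature, W_a):
    """Static-tokenizer sketch for one feature map.

    feature: (C, H, W) array.
    W_a: (L, C) array standing in for the 1x1 conv `conv_token_coef` (hypothetical values).
    Returns tokens of shape (L, C) and spatial weights of shape (L, H*W).
    """
    C, H, W = feature.shape
    X = feature.reshape(C, H * W)                        # (C, HW): one column per pixel
    logits = W_a @ X / np.sqrt(C)                        # (L, HW), scaled by sqrt(C)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    A = np.exp(logits)
    A = A / A.sum(axis=1, keepdims=True)                 # softmax over the HW positions
    tokens = A @ X.T                                     # (L, C): weighted average of pixels
    return tokens, A


rng = np.random.default_rng(0)
feature = rng.standard_normal((16, 7, 7))                # hypothetical C=16, 7x7 map
tokens, A = static_tokenizer(feature, rng.standard_normal((4, 16)))  # L=4 tokens
print(tokens.shape, A.shape)  # (4, 16) (4, 49)
```

Because each row of `A` sums to 1, every token is a convex combination of pixel features, so each token channel stays within that channel's range over the feature map.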
--------------------------------------------------------------------------------
/image-classification/code/vt-resnet-34.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | from pathlib import Path
6 | import sys
7 | import torch, torchvision
8 | import torch.distributed as dist
9 | import torch.nn as nn
10 | import torchvision.transforms as T
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | import torch.utils.data
14 | import torch.utils.data.distributed
15 | from torchvision import datasets, transforms
16 | import tokenizerTrans as tokenizer
17 | import projectorTrans as projector
18 | import numpy as np
19 | import pandas as pd
20 | from torch.optim import lr_scheduler
21 | from torchvision.datasets import ImageFolder
22 | from torch.utils.data import DataLoader
23 | from torchvision import models
24 | import time
25 | from PIL import Image
26 | import requests
27 |
28 | logger = logging.getLogger(__name__)
29 | logger.setLevel(logging.DEBUG)
30 | logger.addHandler(logging.StreamHandler(sys.stdout))
31 |
32 | DATASETS = ['train', 'val']
33 |
34 | mean_nums = [0.485, 0.456, 0.406]
35 | std_nums = [0.229, 0.224, 0.225]
36 |
37 | transforms = {'train': T.Compose([
38 |     T.RandomResizedCrop(size=224),
39 |     T.RandomRotation(degrees=15),
40 |     T.RandomHorizontalFlip(),  # flip the PIL image before converting it to a tensor
41 |     T.ToTensor(),
42 |     T.Normalize(mean_nums, std_nums)
43 | ]), 'val': T.Compose([
44 |     T.Resize(size=224),
45 |     T.CenterCrop(size=224),
46 |     T.ToTensor(),
47 |     T.Normalize(mean_nums, std_nums)
48 | ]),}
49 |
50 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51 |
52 | # Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
53 | class VT(nn.Module):
54 |     # Encoder-only transformer: nn.Transformer with 0 decoder layers returns LayerNorm(tgt) and drops the encoder output
55 |     def __init__(self, L, CT, C):
56 |         super(VT, self).__init__()
57 |         self.bn = nn.BatchNorm2d(256)  # layer3 of ResNet-34 outputs 256 channels
58 |         self.tokenizer = tokenizer.Tokenizer(L=L, CT=CT, C=C)
59 |         self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=CT, nhead=16, dim_feedforward=2048, dropout=0.5, activation='relu'), num_layers=5)
60 |         self.projector = projector.Projector(CT=CT, C=C)
61 |
62 |     def forward(self, x):
63 |         x = self.bn(x)
64 |         x, token = self.tokenizer(x)  # token: N, L, CT
65 |         token = self.transformer(token.permute(1, 0, 2)).permute(1, 0, 2)  # encoder expects (L, N, CT)
66 |         out, token = self.projector(x, token)
67 |         return out
68 |
69 | def _get_data_loader(batch_size, training_dir):
70 |     logger.info("Get data loaders")
71 |     dataset = {
72 |         d: ImageFolder(f'{training_dir}/{d}', transforms[d]) for d in DATASETS
73 |     }
74 |     dataset_sizes = {d: len(dataset[d]) for d in DATASETS}
75 |     logger.info(f'dataset sizes: {dataset_sizes}')
76 |     data_loaders = {
77 |         d: DataLoader(dataset[d], batch_size=batch_size, shuffle=True) for d in DATASETS
78 |     }
79 |     return data_loaders, dataset_sizes
80 |
81 | def create_model(n_classes):
82 |     model = models.resnet34(pretrained=True)
83 |     model.layer4 = VT(L=8, CT=512, C=512)  # replace ResNet-34's last stage with the VT block
84 |     n_features = model.fc.in_features
85 |     model.fc = nn.Linear(n_features, n_classes)
86 |     return model.to(device)
87 |
88 | def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
89 |     model = model.train()  # switch to training mode
90 |     losses = []
91 |     correct_predictions = 0
92 |     for inputs, labels in data_loader:
93 |         inputs = inputs.to(device)  # move the batch to the GPU/CPU device
94 |         labels = labels.to(device)
95 |         outputs = model(inputs)  # raw class scores (logits)
96 |         _, preds = torch.max(outputs, dim=1)  # predicted class = highest logit
97 |         loss = loss_fn(outputs, labels)  # cross-entropy loss
98 |         correct_predictions += torch.sum(preds == labels)
99 |         losses.append(loss.item())
100 |         loss.backward()
101 |         optimizer.step()
102 |         optimizer.zero_grad()
103 |     scheduler.step()  # StepLR's step_size counts epochs, so step once per epoch, not per batch
104 |     return correct_predictions.double() / n_examples, np.mean(losses)
105 |
106 | def eval_model(model, data_loader, loss_fn, device, n_examples):
107 |     model = model.eval()  # evaluation mode
108 |     losses = []
109 |     correct_predictions = 0
110 |     with torch.no_grad():
111 |         for inputs, labels in data_loader:
112 |             inputs = inputs.to(device)
113 |             labels = labels.to(device)
114 |             outputs = model(inputs)
115 |             _, preds = torch.max(outputs, dim=1)
116 |             loss = loss_fn(outputs, labels)
117 |             correct_predictions += torch.sum(preds == labels)
118 |             losses.append(loss.item())
119 |     return correct_predictions.double() / n_examples, np.mean(losses)
120 |
121 | def save_model(model, model_dir):
122 |     path = os.path.join(model_dir, 'model.pth')
123 |     # recommended way from http://pytorch.org/docs/master/notes/serialization.html
124 |     torch.save(model.state_dict(), path)
125 |     logger.info(f"Checkpoint: Saved the best model: {path} \n")
126 |
127 |
128 | def train(args):
129 |     logger.info("Training using Visual Transformer ResNet34 model")
130 |     logger.debug("\n Number of gpus available - {}".format(args.num_gpus))
131 |     logger.debug(f"\n Device: {device}")
132 |     train_loader, dataset_sizes = _get_data_loader(args.batch_size, args.data_dir)
133 |     logger.info("Building Visual transformer (vt-resnet34) model from Resnet34 Pre-trained model. \n")
134 |     model = create_model(args.num_classes)
135 |     optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, nesterov=args.nesterov)
136 |     scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
137 |     loss_fn = nn.CrossEntropyLoss().to(device)
138 |     best_accuracy = 0
139 |     corresponding_loss = 0
140 |     corresponding_epoch = 0
141 |     start = time.time()
142 |     for epoch in range(args.epochs):
143 |         logger.info(f'\nEpoch {epoch + 1}/{args.epochs}')
144 |         logger.info('-' * 10)
145 |         train_acc, train_loss = train_epoch(model, train_loader['train'], loss_fn,
146 |                                             optimizer, device, scheduler, dataset_sizes['train'])
147 |         logger.info(f'Train_loss = {train_loss}; Train_accuracy = {train_acc};')
148 |         val_acc, val_loss = eval_model(model, train_loader['val'], loss_fn, device, dataset_sizes['val'])
149 |         logger.info(f'Valid_loss = {val_loss}; Valid_accuracy = {val_acc};')
150 |         if val_acc >= best_accuracy:
151 |             save_model(model, args.model_dir)
152 |             best_accuracy = val_acc
153 |             corresponding_loss = val_loss
154 |             corresponding_epoch = epoch + 1
155 |     end = time.time()
156 |     logger.info(f'Best val accuracy: {best_accuracy}')
157 |     logger.info(f'Corresponding loss: {corresponding_loss}')
158 |     logger.info(f'Corresponding epoch: {corresponding_epoch}')
159 |     logger.info(f'Runtime of the model is {round((end - start)/60, 2)} mins')
160 |
161 | def model_fn(model_dir):
162 |     logger.info('model_fn')
163 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
164 |     model = create_model(6)
165 |     with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
166 |         model.load_state_dict(torch.load(f, map_location=device))  # lets a CPU host load GPU-trained weights
167 |     return model.to(device)
168 |
169 | def input_fn(request_body, content_type='application/json'):
170 |     logger.info('Deserializing the input data.')
171 |     if content_type == 'application/json':
172 |         input_data = json.loads(request_body)
173 |         url = input_data['url']
174 |         logger.info(f'Image url: {url}')
175 |         image_data = Image.open(requests.get(url, stream=True).raw).convert('RGB')  # drop alpha/grayscale modes
176 |         image_transform = T.Compose([
177 |             T.Resize(size=256),
178 |             T.CenterCrop(size=224),
179 |             T.ToTensor(),
180 |             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
181 |         ])
182 |         return image_transform(image_data)
183 |     else:
184 |         logger.info('raising exception')
185 |         raise Exception(f'Requested unsupported ContentType in content_type {content_type}')
186 |
187 | def predict_fn(input_data, model):
188 |     logger.info('Generating prediction based on input parameters.')
189 |     if torch.cuda.is_available():
190 |         input_data = input_data.view(1, 3, 224, 224).cuda()
191 |     else:
192 |         input_data = input_data.view(1, 3, 224, 224)
193 |     with torch.no_grad():
194 |         model.eval()
195 |         out = model(input_data)
196 |         ps = torch.softmax(out, dim=1)  # the fc layer outputs logits, so exp() alone would not give probabilities
197 |     return ps
198 |
199 | def output_fn(prediction_output, accept='application/json'):
200 |     logger.info('Serializing the generated output.')
201 |     classes = {0: 'buildings', 1: 'forest', 2: 'glacier', 3: 'mountain', 4: 'sea', 5: 'street'}
202 |     topk, topclass = prediction_output.topk(3, dim=1)
203 |     result = []
204 |     for i in range(3):
205 |         pred = {'prediction': classes[topclass.cpu().numpy()[0][i]], 'score': f'{topk.cpu().numpy()[0][i] * 100}%'}
206 |         logger.info(f'Adding prediction: {pred}')
207 |         result.append(pred)
208 |     if accept == 'application/json':
209 |         return json.dumps(result), accept
210 |     raise Exception(f'Requested unsupported ContentType in Accept:{accept}')
211 |
212 |
213 | if __name__ == '__main__':
214 |     parser = argparse.ArgumentParser()
215 |
216 |     # Training hyperparameters
217 |     parser.add_argument('--batch-size', type=int, default=128, metavar='N',
218 |                         help='input batch size for training (default: 128)')
219 |     parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
220 |                         help='input batch size for testing (default: 64)')
221 |     parser.add_argument('--epochs', type=int, default=10, metavar='N',
222 |                         help='number of epochs to train (default: 10)')
223 |     parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
224 |                         help='learning rate (default: 0.01)')
225 |     parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
226 |                         help='SGD momentum (default: 0.9)')
227 |     parser.add_argument('--step_size', type=int, default=7,
228 |                         help='step size for StepLR scheduler (default: 7)')
229 |     parser.add_argument('--gamma', type=float, default=0.1,
230 |                         help='gamma for StepLR scheduler (default: 0.1)')
231 |     parser.add_argument('--nesterov', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True,
232 |                         help='nesterov for SGD optimizer, true/false (default: True)')
233 |     parser.add_argument('--seed', type=int, default=1, metavar='S',
234 |                         help='random seed (default: 1)')
235 |     parser.add_argument('--log-interval', type=int, default=100, metavar='N',
236 |                         help='how many batches to wait before logging training status')
237 |     parser.add_argument('--num_classes', type=int, default=6,
238 |                         help='number of classes (default: 6)')
239 |
240 |     # Container environment
241 |     parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
242 |     parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
243 |     parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
244 |     parser.add_argument('--data-dir', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
245 |     parser.add_argument('--num-gpus', type=int, default=os.environ['SM_NUM_GPUS'])
246 |
247 |     train(parser.parse_args())
248 |
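By default, `vt-resnet-34.py` configures `lr_scheduler.StepLR(step_size=7, gamma=0.1)`. PyTorch's `StepLR` multiplies the learning rate by `gamma` once every `step_size` calls to `scheduler.step()`, so calling it once per batch instead of once per epoch changes the schedule drastically. Below is a pure-Python sketch of the closed-form rule. The figure of 110 batches per epoch is an assumption (roughly 14k training images at batch size 128), not a measured value.

```python
def step_lr(initial_lr, step_size, gamma, num_steps):
    """Learning rate after `num_steps` calls to StepLR.step().

    Closed form of PyTorch's StepLR: lr = initial_lr * gamma ** (num_steps // step_size).
    """
    return initial_lr * gamma ** (num_steps // step_size)


# Script defaults: lr=0.01, step_size=7, gamma=0.1.
# Stepping once per epoch: the rate is 0.01 for the first 7 epochs, then 0.001.
per_epoch = [step_lr(0.01, 7, 0.1, epoch) for epoch in range(10)]

# Stepping once per batch (assumed ~110 batches per epoch): after a single epoch
# the rate has already been multiplied by 0.1 fifteen times (110 // 7 == 15).
batches_per_epoch = 110
per_batch_after_one_epoch = step_lr(0.01, 7, 0.1, batches_per_epoch)

print(per_epoch[6], per_epoch[7])
print(per_batch_after_one_epoch)
```

With per-batch stepping, the learning rate falls to about 1e-17 within the first epoch, which effectively stops training. Stepping once per epoch keeps `step_size` measured in epochs, as the `--step_size` help text implies.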
--------------------------------------------------------------------------------
/image-classification/vt-resnet34-pytorch-sagemaker.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Image Classification using Visual Transformers\n",
8 | "\n",
9 | "## Contents\n",
10 | "1. [Overview](#Overview)\n",
11 | "2. [Setup](#Setup)\n",
12 | "3. [Image pre-processing](#Image-Pre-processing)\n",
13 | "4. [Build and Train the Visual Transformer Model](#build-train-model)\n",
14 | "5. [Deploy VT-ResNet34 Model to SageMaker Endpoint](#Deploy-VT-ResNet34-Model-to-SageMaker-Endpoint)\n",
15 | "6. [References](#References)\n",
16 | "\n",
17 | "Note: \n",
18 | "- Dataset used is **Intel Image Classification** from [Kaggle](https://www.kaggle.com/puneet6060/intel-image-classification).\n",
19 | "- The notebook is only an example and not to be used for production deployments.\n",
20 | "- Use `Python3 (PyTorch 1.6 Python 3.6 CPU Optimized)` kernel and `ml.m5.large (2 vCPU + 8 GiB)` for the notebook."
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | "## Overview \n",
28 | "\n",
29 | "**_Important_**: This notebook uses ideas and some of the pseudocode from the `Visual Transformers: Token-based Image Representation and Processing for Computer Vision` [research paper](https://arxiv.org/pdf/2006.03677.pdf), but it does not reproduce all the results reported in the paper. \n",
30 | "\n",
31 | "In standard image classification algorithms like ResNet and InceptionNet, images are represented as pixel arrays on which a series of convolution operations are performed. Although these algorithms achieve great accuracy, the convolution operation is computationally expensive. Therefore, in this notebook we look at an alternative way to perform `Image Classification` using the ideas from the research paper. \n",
32 | "\n",
33 | "\n",
34 | "\n",
35 | "Diagram of a Visual Transformer (VT). For a given image, we first apply convolutional layers to extract low-level\n",
36 | "features. The output feature map is then fed to the Visual Transformer (VT). First, a tokenizer groups pixels into a small number of visual tokens. Second, a transformer models the relationships between tokens.\n",
37 | "Third, the visual tokens are either used directly for image classification or projected back to the feature map for semantic segmentation.\n",
38 | " "
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {},
44 | "source": [
45 | "## Setup \n",
46 | "To start, let's import some Python libraries and initialize a SageMaker session, an S3 bucket and prefix, and an IAM role."
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "!pip install seaborn"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "# import Python libraries and frameworks\n",
65 | "from pathlib import Path\n",
66 | "import numpy as np\n",
67 | "import cv2\n",
68 | "import PIL.Image as Image\n",
69 | "import seaborn as sns\n",
70 | "from pylab import rcParams\n",
71 | "import matplotlib.pyplot as plt\n",
72 | "from matplotlib import rc\n",
73 | "from matplotlib.ticker import MaxNLocator\n",
74 | "from glob import glob\n",
75 | "import shutil\n",
76 | "\n",
77 | "import torch, torchvision\n",
78 | "from torch import nn, optim\n",
79 | "import torchvision.transforms as T\n",
80 | "from torchvision.datasets import ImageFolder\n",
81 | "from torch.utils.data import DataLoader\n",
82 | "\n",
83 | "import sagemaker\n",
84 | "\n",
85 | "%matplotlib inline\n",
86 | "\n",
87 | "sns.set(style='whitegrid', palette='muted', font_scale=1.2)\n",
88 | "\n",
89 | "COLORS_PALETTE=[\"#01BEFE\",\"#FFDD00\",\"#FF7D00\",\"#FF006D\",\"#ADFF02\",\"#8F00FF\"]\n",
90 | "\n",
91 | "sns.set_palette(sns.color_palette(COLORS_PALETTE))\n",
92 | "\n",
93 | "rcParams['figure.figsize'] = 15, 10\n",
94 | "\n",
95 | "RANDOM_SEED = 42\n",
96 | "np.random.seed(RANDOM_SEED)\n",
97 | "torch.manual_seed(RANDOM_SEED)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "sagemaker_session = sagemaker.Session()\n",
107 | "bucket = sagemaker_session.default_bucket()\n",
108 | "prefix = \"sagemaker/pytorch-vt-resnet34\"\n",
109 | "role = sagemaker.get_execution_role()"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "## Image Pre-processing \n",
117 | "The dataset used in the notebook is `Intel Image Classification` downloaded from kaggle.com. \n",
118 | "It contains around 25k images of size 150x150 distributed under 6 categories.\n",
119 | "```\n",
120 | "{'buildings' -> 0,\n",
121 | "'forest' -> 1,\n",
122 | "'glacier' -> 2,\n",
123 | "'mountain' -> 3,\n",
124 | "'sea' -> 4,\n",
125 | "'street' -> 5 }\n",
126 | "```\n",
127 | "The zip file contains separate `train`, `test` and `prediction` folders. There are around 14k images in Train, 3k in Test and 7k in Prediction.\n",
128 | "You can download the dataset from [here](https://www.kaggle.com/puneet6060/intel-image-classification/download), rename the zip file to `data1.zip`, upload it to JupyterLab inside the `image-classification` folder, and then follow the steps below. \n"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "# extracting files may take 5-10mins.\n",
138 | "from zipfile import ZipFile\n",
139 | "with ZipFile('data1.zip', 'r') as zipObj:\n",
140 | " # Extract all the contents of data1.zip into the data1 directory\n",
141 | " zipObj.extractall('data1')"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "# Store location for train, test and prediction dataset. \n",
151 | "train_set = './data1/seg_train/seg_train'\n",
152 | "test_set = './data1/seg_test/seg_test'\n",
153 | "pred_set = './data1/seg_pred/seg_pred'"
154 | ]
155 | },
156 | {
157 | "cell_type": "markdown",
158 | "metadata": {},
159 | "source": [
160 | "**Each label has its own folder, so we expect six folders.** Each folder corresponds to one class in the dataset, as shown below:\n",
161 | "\n",
162 | "* buildings = 0 \n",
163 | "* forest = 1\n",
164 | "* glacier = 2\n",
165 | "* mountain = 3\n",
166 | "* sea = 4\n",
167 | "* street = 5 "
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "class_names = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']\n",
177 | "class_indices = [0,1,2,3,4,5]"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {},
184 | "outputs": [],
185 | "source": [
186 | "# count the class folders in the training set; with 6 classes we expect 6\n",
187 | "train_folders = sorted(glob(train_set + '/*'))\n",
188 | "len(train_folders)"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "metadata": {},
194 | "source": [
195 | "### Defining the helper functions to load and view images."
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "def load_image(img_path, resize=True):\n",
205 | "    # OpenCV reads images as BGR; convert to RGB for display\n",
206 | "    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)\n",
207 | "    if resize:\n",
208 | "        img = cv2.resize(img, (64, 64), interpolation=cv2.INTER_AREA)\n",
209 | "    return img\n",
210 | "\n",
211 | "def show_image(img_path):\n",
212 | "    img = load_image(img_path)\n",
213 | "    plt.imshow(img)\n",
214 | "    plt.axis('off')\n",
215 | "\n",
216 | "def show_sign_grid(image_paths):\n",
217 | "    # load every image as an HxWxC uint8 array and stack them into NxHxWxC\n",
218 | "    images = torch.as_tensor(np.stack([load_image(img) for img in image_paths]))\n",
219 | "    # make_grid expects NxCxHxW, so move the channel axis forward\n",
220 | "    images = images.permute(0, 3, 1, 2)\n",
221 | "    grid_img = torchvision.utils.make_grid(images, nrow=11)\n",
222 | "    plt.figure(figsize=(24, 12))\n",
223 | "    plt.imshow(grid_img.permute(1, 2, 0))\n",
224 | "    plt.axis('off')"
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "metadata": {},
230 | "source": [
231 | "**Display sample images for all 6 classes in the dataset.**"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": null,
237 | "metadata": {},
238 | "outputs": [],
239 | "source": [
240 | "sample_images = [np.random.choice(glob(f'{tf}/*jpg')) for tf in train_folders]\n",
241 | "show_sign_grid(sample_images)"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "**We copy all the images into a new directory with a layout that the `torchvision` dataset helpers (such as `ImageFolder`) can use directly.**\n",
249 | "The new directory structure looks like this, with one subfolder per class inside both `train` and `val`: \n",
250 | "```\n",
251 | "|- data\n",
252 | "|----train\n",
253 | "|----val\n",
254 | "```"
255 | ]
256 | },
257 | {
258 | "cell_type": "markdown",
259 | "metadata": {},
260 | "source": [
261 | "**We are going to reserve 80% for train and 20% for validation for each class, then copy them to the `data` folder.**"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": null,
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "# remove any previous copy of data/, then create empty data/<split>/<class> folders\n",
271 | "!rm -rf data\n",
272 | "\n",
273 | "DATA_DIR = Path('data')\n",
274 | "\n",
275 | "DATASETS = ['train', 'val']\n",
276 | "\n",
277 | "for ds in DATASETS:\n",
278 | "    for cls in class_names:\n",
279 | "        (DATA_DIR / ds / cls).mkdir(parents=True, exist_ok=True)"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": null,
285 | "metadata": {},
286 | "outputs": [],
287 | "source": [
288 | "# shuffle each class, split it 80/20 into train/val and copy the files. This may take 1-2 mins.\n",
289 | "for i, cls_index in enumerate(class_indices):\n",
290 | "    image_paths = np.array(glob(f'{train_folders[cls_index]}/*jpg'))\n",
291 | "    class_name = class_names[i]\n",
292 | "    print(f'{class_name}: {len(image_paths)}')\n",
293 | "    np.random.shuffle(image_paths)\n",
294 | "    # one split index gives two parts: the first 80% (train) and the remaining 20% (val)\n",
295 | "    ds_split = np.split(\n",
296 | "        image_paths,\n",
297 | "        indices_or_sections=[int(.8*len(image_paths))]\n",
298 | "    )\n",
299 | "\n",
300 | "    dataset_data = zip(DATASETS, ds_split)\n",
301 | "    for ds, images in dataset_data:\n",
302 | "        for img_path in images:\n",
303 | "            shutil.copy(img_path, f'{DATA_DIR}/{ds}/{class_name}/')"
304 | ]
305 | },
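{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check (not part of the original workflow), the next cell counts the files copied into each split. It assumes the `DATA_DIR` and `class_names` variables defined above. With an 80/20 split, the validation share printed for each class should be close to 0.20."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional check: count the copied .jpg files per class in data/train and data/val\n",
"for cls in class_names:\n",
"    n_train = len(list((DATA_DIR / 'train' / cls).glob('*.jpg')))\n",
"    n_val = len(list((DATA_DIR / 'val' / cls).glob('*.jpg')))\n",
"    print(f'{cls}: train={n_train}, val={n_val}, val share={n_val / (n_train + n_val):.2f}')"
]
},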
306 | {
307 | "cell_type": "markdown",
308 | "metadata": {},
309 | "source": [
310 | "**The classes are reasonably balanced: no class has many more images than the others.**\n",
311 | "\n",
312 | "_We add some transformations to artificially increase the size of the dataset, particularly random resized crops, rotations and horizontal flips. We then normalize the tensors using preset per-channel values (the ImageNet mean and standard deviation expected by the pretrained ResNet34)._"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "metadata": {},
319 | "outputs": [],
320 | "source": [
321 | "mean_nums = [0.485, 0.456, 0.406]\n",
322 | "std_nums = [0.229, 0.224, 0.225]\n",
323 | "\n",
324 | "transforms = {'train': T.Compose([\n",
325 | " T.RandomResizedCrop(size=224),\n",
326 | " T.RandomRotation(degrees=15),\n",
327 | " T.RandomHorizontalFlip(),\n",
328 | " T.ToTensor(),\n",
329 | " T.Normalize(mean_nums, std_nums)\n",
330 | "]), 'val': T.Compose([\n",
331 | " T.Resize(size=224),\n",
332 | " T.CenterCrop(size=224),\n",
333 | " T.ToTensor(),\n",
334 | " T.Normalize(mean_nums, std_nums)\n",
335 | "]),}"
336 | ]
337 | },
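{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the model actually receives, the next cell applies the `train` pipeline to a single training image. This is an illustrative check that is not part of the original notebook, and it assumes `train_set`, `glob` and `transforms` from the cells above. `T.RandomResizedCrop(size=224)` resizes a random crop of each 150x150 image to 224x224, so the result should be a 3x224x224 float tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative check: run one training image through the train transforms\n",
"from PIL import Image\n",
"\n",
"sample_path = sorted(glob(f'{train_set}/buildings/*.jpg'))[0]\n",
"x = transforms['train'](Image.open(sample_path).convert('RGB'))\n",
"print(x.shape, x.dtype)  # expected: torch.Size([3, 224, 224]) torch.float32"
]
},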
338 | {
339 | "cell_type": "markdown",
340 | "metadata": {},
341 | "source": [
342 | "Let's create the PyTorch `DataLoader`s from the `data` directory."
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": null,
348 | "metadata": {},
349 | "outputs": [],
350 | "source": [
351 | "image_datasets = {\n",
352 | " d: ImageFolder(f'{DATA_DIR}/{d}', transforms[d]) for d in DATASETS\n",
353 | "}\n",
354 | "\n",
355 | "data_loaders = {\n",
356 | " d: DataLoader(image_datasets[d], batch_size=16, shuffle=True, num_workers=4) for d in DATASETS\n",
357 | "}"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": null,
363 | "metadata": {},
364 | "outputs": [],
365 | "source": [
366 | "# count the images in each dataset created above\n",
367 | "dataset_sizes = {d: len(image_datasets[d]) for d in DATASETS}\n",
368 | "class_names = image_datasets['train'].classes\n",
369 | "dataset_sizes"
370 | ]
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": null,
375 | "metadata": {},
376 | "outputs": [],
377 | "source": [
378 | "def imshow(inp, title=None):\n",
379 | "    inp = inp.numpy().transpose((1, 2, 0))\n",
380 | "    mean = np.array([mean_nums])\n",
381 | "    std = np.array([std_nums])\n",
382 | "    inp = std * inp + mean  # undo T.Normalize so colors display correctly\n",
383 | "    inp = np.clip(inp, 0, 1)\n",
384 | "    plt.imshow(inp)\n",
385 | "    if title is not None:\n",
386 | "        plt.title(title)\n",
387 | "    plt.axis('off')\n",
388 | "\n",
389 | "# show one batch of augmented training images with their labels\n",
390 | "inputs, classes = next(iter(data_loaders['train']))\n",
391 | "out = torchvision.utils.make_grid(inputs)\n",
392 | "imshow(out, title=[class_names[x] for x in classes])"
393 | ]
394 | },
395 | {
396 | "cell_type": "markdown",
397 | "metadata": {},
398 | "source": [
399 | "**The grid above shows a batch of training images with all the transformations applied.**"
400 | ]
401 | },
402 | {
403 | "cell_type": "markdown",
404 | "metadata": {},
405 | "source": [
406 | "### Uploading the data to S3\n",
407 | "We use the `sagemaker.Session.upload_data` function to upload our datasets to an S3 location. The returned `input_path` is the S3 URI of the uploaded data, which we pass to the training job later. The upload might take a few minutes."
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": null,
413 | "metadata": {},
414 | "outputs": [],
415 | "source": [
416 | "# upload data to S3; this might take a few minutes\n",
417 | "input_path = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix=prefix)\n",
418 | "print('input specification (in this case, just an S3 path): {}'.format(input_path))"
419 | ]
420 | },
421 | {
422 | "cell_type": "markdown",
423 | "metadata": {},
424 | "source": [
425 | "## Build and Train the model \n",
426 | "We use a pretrained ResNet34 model and replace its last layer with a custom visual transformer to classify the images.\n",
427 | "\n",
428 | "1. Create a Visual Transformer class (replacing the last layer with a transformer layer).\n",
429 | "2. Import ResNet34 pretrained model.\n",
430 | "3. Convert it into training mode.\n",
431 | "4. Train the model on new data.\n",
432 | "5. Evaluate model performance on `validation loss`, `validation accuracy` and `execution time`. \n",
433 | "\n",
434 | "All the above steps are performed in the training script. \n",
435 | "\n",
436 | "### Training script\n",
437 | "The `vt-resnet-34.py` script provides all the code we need for training and hosting a SageMaker model (`model_fn` function to load a model).\n",
438 | "The training script is very similar to a training script you might run outside of SageMaker, but you can access useful properties about the training environment through various environment variables, such as:\n",
439 | "\n",
440 | "* `SM_MODEL_DIR`: A string representing the path to the directory to write model artifacts to.\n",
441 | " These artifacts are uploaded to S3 for model hosting.\n",
442 | "* `SM_NUM_GPUS`: The number of GPUs available in the current container.\n",
443 | "* `SM_CURRENT_HOST`: The name of the current container on the container network.\n",
444 | "* `SM_HOSTS`: A JSON-encoded list containing all the hosts.\n",
445 | "\n",
446 | "Supposing one input channel, 'training', was used in the call to the PyTorch estimator's `fit()` method, the following will be set, with the format `SM_CHANNEL_[channel_name]`:\n",
447 | "\n",
448 | "* `SM_CHANNEL_TRAINING`: A string representing the path to the directory containing data in the 'training' channel.\n",
449 | "\n",
450 | "For more information about training environment variables, please visit [SageMaker Containers](https://github.com/aws/sagemaker-containers).\n",
451 | "\n",
452 | "A typical training script loads data from the input channels, configures training with hyperparameters, trains a model, and saves a model to `model_dir` so that it can be hosted later. Hyperparameters are passed to your script as arguments and can be retrieved with an `argparse.ArgumentParser` instance.\n",
453 | "\n",
454 | "Because SageMaker imports the training script, you should put your training code in a main guard (``if __name__=='__main__':``) if you are using the same script to host your model, as we do in this example. This prevents SageMaker from inadvertently running your training code at the wrong point in execution.\n",
455 | "\n",
456 | "For example, the script run by this notebook:"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": null,
462 | "metadata": {},
463 | "outputs": [],
464 | "source": [
465 | "!pygmentize code/vt-resnet-34.py"
466 | ]
467 | },
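{
"cell_type": "markdown",
"metadata": {},
"source": [
"The hyperparameters passed to the estimator arrive as command-line arguments, and the environment variables described above supply the data and model locations. The next cell is a minimal, hypothetical sketch of how a script can combine the two with `argparse`. It is **not** the contents of `vt-resnet-34.py`. The argument names `--model-dir`, `--data-dir`, `--num-gpus` and `--hosts`, and the local fallback defaults, are assumptions made for illustration. Passing an empty list to `parse_args` keeps it from reading the notebook kernel's own arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical sketch of a SageMaker-style argument parser (not the actual vt-resnet-34.py)\n",
"import argparse\n",
"import json\n",
"import os\n",
"\n",
"def parse_args(argv=None):\n",
"    parser = argparse.ArgumentParser()\n",
"    # hyperparameters: SageMaker passes the estimator's hyperparameters dict as --key value pairs\n",
"    parser.add_argument('--epochs', type=int, default=5)\n",
"    parser.add_argument('--num_classes', type=int, default=6)\n",
"    parser.add_argument('--batch-size', type=int, default=256)\n",
"    # locations and cluster info come from environment variables; the fallbacks are for local runs\n",
"    parser.add_argument('--model-dir', type=str, default=os.environ.get('SM_MODEL_DIR', './model'))\n",
"    parser.add_argument('--data-dir', type=str, default=os.environ.get('SM_CHANNEL_TRAINING', './data'))\n",
"    parser.add_argument('--num-gpus', type=int, default=int(os.environ.get('SM_NUM_GPUS', '0')))\n",
"    # argparse applies type=json.loads to the string default as well\n",
"    parser.add_argument('--hosts', type=json.loads, default=os.environ.get('SM_HOSTS', '[]'))\n",
"    return parser.parse_args(argv)\n",
"\n",
"args = parse_args([])  # empty list: ignore the notebook kernel's own command-line arguments\n",
"print(args)"
]
},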
468 | {
469 | "cell_type": "markdown",
470 | "metadata": {},
471 | "source": [
472 | "## Train on Amazon SageMaker\n",
473 | "\n",
474 | "We use Amazon SageMaker to train and deploy a model using our custom PyTorch code. The Amazon SageMaker Python SDK makes it easier to run a PyTorch script in Amazon SageMaker using its PyTorch estimator. After that, we can use the SageMaker Python SDK to deploy the trained model and run predictions. For more information on how to use this SDK with PyTorch, see [the SageMaker Python SDK documentation](https://sagemaker.readthedocs.io/en/stable/using_pytorch.html).\n",
475 | "\n",
476 | "To start, we use the `PyTorch` estimator class to train our model. When creating our estimator, we make sure to specify a few things:\n",
477 | "\n",
478 | "* `entry_point`: the name of our PyTorch script. It contains our training script, which loads data from the input channels, configures training with hyperparameters, trains a model, and saves a model. It also contains code to load and run the model during inference.\n",
479 | "* `source_dir`: the directory containing our training script and its helper modules. If it also contains a `requirements.txt` file, the packages listed there are installed before the script runs.\n",
480 | "* `framework_version`: the PyTorch version we want to use.\n",
481 | "\n",
482 | "The PyTorch estimator supports both single-machine and multi-machine distributed training. Our training script only supports single-machine, multi-GPU data-parallel training, so we set `instance_count` to one and use a multi-GPU instance.\n",
483 | "\n",
484 | "After creating the estimator, we call `fit()`, which launches a training job, passing the Amazon S3 URI where we uploaded the training data earlier."
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": null,
490 | "metadata": {},
491 | "outputs": [],
492 | "source": [
493 | "from datetime import datetime\n",
494 | "from sagemaker.pytorch import PyTorch\n",
495 | "\n",
496 | "now = datetime.now()\n",
497 | "timestr = now.strftime(\"%m-%d-%Y-%H-%M-%S\")\n",
498 | "vt_training_job_name = \"vt-training-{}\".format(timestr)\n",
499 | "print(vt_training_job_name)\n",
500 | "\n",
501 | "estimator = PyTorch(\n",
502 | " entry_point=\"vt-resnet-34.py\",\n",
503 | " source_dir=\"code\",\n",
504 | " role=role,\n",
505 | " framework_version=\"1.6.0\",\n",
506 | " py_version=\"py3\",\n",
507 | " instance_count=1, # this script only supports single-instance, multi-GPU data-parallel training.\n",
508 | " instance_type=\"ml.p3.16xlarge\", # this instance has 8 GPUs; change it if you want to train on a bigger or smaller instance.\n",
509 | " use_spot_instances=False, # set it to True to train on Spot Instances, which are more cost-effective but might take additional time.\n",
510 | "# max_run=3600, # uncomment it, if use_spot_instances = True\n",
511 | "# max_wait=3600, # uncomment it, if use_spot_instances = True\n",
512 | " debugger_hook_config=False,\n",
513 | " hyperparameters={\n",
514 | " \"epochs\": 5,\n",
515 | " \"num_classes\": 6,\n",
516 | " \"batch-size\": 256,\n",
517 | " },\n",
518 | " metric_definitions=[\n",
519 | " {'Name': 'validation:loss', 'Regex': 'Valid_loss = ([0-9\\\\.]+);'},\n",
520 | " {'Name': 'validation:accuracy', 'Regex': 'Valid_accuracy = ([0-9\\\\.]+);'},\n",
521 | " {'Name': 'train:accuracy', 'Regex': 'Train_accuracy = ([0-9\\\\.]+);'},\n",
522 | " {'Name': 'train:loss', 'Regex': 'Train_loss = ([0-9\\\\.]+);'},\n",
523 | " ]\n",
524 | ")\n",
525 | "estimator.fit({\"training\": input_path}, wait=True, job_name=vt_training_job_name)"
526 | ]
527 | },
528 | {
529 | "cell_type": "code",
530 | "execution_count": null,
531 | "metadata": {},
532 | "outputs": [],
533 | "source": [
534 | "vt_training_job_name = estimator.latest_training_job.name\n",
535 | "print(\"Visual Transformer training job name: \", vt_training_job_name)"
536 | ]
537 | },
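{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training script prints its metrics to the logs, and the `metric_definitions` regular expressions above extract them as training-job metrics. As an optional step that is not in the original notebook, the next cell uses the SageMaker Python SDK's `estimator.training_job_analytics` property to load those metrics into a pandas DataFrame. It assumes the training job has completed; the metrics may take a few minutes to become available."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional: fetch the metrics captured by metric_definitions (requires a completed training job)\n",
"metrics_df = estimator.training_job_analytics.dataframe()\n",
"metrics_df.head(20)"
]
},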
538 | {
539 | "cell_type": "markdown",
540 | "metadata": {},
541 | "source": [
542 | "## Deploy VT-ResNet34 Model to SageMaker Endpoint "
543 | ]
544 | },
545 | {
546 | "cell_type": "code",
547 | "execution_count": null,
548 | "metadata": {},
549 | "outputs": [],
550 | "source": [
551 | "# deploy the trained model to a real-time endpoint on one ml.p3.2xlarge (single-GPU) instance\n",
552 | "ENDPOINT_NAME='pytorch-inference-{}'.format(timestr)\n",
553 | "predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.p3.2xlarge', endpoint_name=ENDPOINT_NAME)"
554 | ]
555 | },
556 | {
557 | "cell_type": "code",
558 | "execution_count": null,
559 | "metadata": {},
560 | "outputs": [],
561 | "source": [
562 | "import json\n",
563 | "import requests\n",
564 | "import boto3\n",
565 | "import numpy as np\n",
566 | "\n",
567 | "# `runtime` sends inference requests to the endpoint;\n",
568 | "# `client` calls SageMaker control-plane APIs such as describe_endpoint\n",
569 | "runtime = boto3.client('runtime.sagemaker')\n",
570 | "client = boto3.client('sagemaker')\n",
571 | "\n",
572 | "endpoint_desc = client.describe_endpoint(EndpointName=ENDPOINT_NAME)\n",
573 | "print(endpoint_desc)\n",
574 | "print('---'*60)"
575 | ]
576 | },
577 | {
578 | "cell_type": "markdown",
579 | "metadata": {},
580 | "source": [
581 | "## Making Predictions with VT-ResNet34 Model using SageMaker Endpoint\n",
582 | "\n",
583 | "The images used for inference come from the Multimedia Commons dataset in the `Registry of Open Data on AWS` (https://registry.opendata.aws/multimedia-commons/)."
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": null,
589 | "metadata": {},
590 | "outputs": [],
591 | "source": [
592 | "payload = '[{\"url\":\"https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/019/1390196df443f2cf614f2255ae75fcf8.jpg\"},\\\n",
593 | "{\"url\":\"https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/015/1390157d4caaf290962de5c5fb4c42.jpg\"},\\\n",
594 | "{\"url\":\"https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/020/1390207be327f4c4df1259c7266473.jpg\"},\\\n",
595 | "{\"url\":\"https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/021/139021f9aed9896831bf88f349fcec.jpg\"},\\\n",
596 | "{\"url\":\"https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/028/139028d865bafa3de66568eeb499f4a6.jpg\"},\\\n",
597 | "{\"url\":\"https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/030/13903090f3c8c7a708ca69c8d5d68b2.jpg\"},\\\n",
598 | "{\"url\":\"https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/002/010/00201099c5bf0d794c9a951b74390.jpg\"},\\\n",
599 | "{\"url\":\"https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/136/139136bb43e41df8949f873fb44af.jpg\"},\\\n",
600 | "{\"url\":\"https://multimedia-commons.s3-us-west-2.amazonaws.com/data/images/139/145/1391457e4a2e25557cbf956aaee4345.jpg\"}]'\n",
601 | "\n",
602 | "import json\n",
603 | "import requests\n",
604 | "from PIL import Image\n",
605 | "from IPython.display import display\n",
606 | "\n",
607 | "payload = json.loads(payload)\n",
608 | "for item in payload:\n",
609 | "    # send one image URL per request, serialized as JSON\n",
610 | "    body = json.dumps(item)\n",
611 | "    response = runtime.invoke_endpoint(EndpointName=ENDPOINT_NAME,\n",
612 | "                                       ContentType='application/json',\n",
613 | "                                       Body=body)\n",
614 | "    result = json.loads(response['Body'].read())\n",
615 | "    print('predicted:', result[0]['prediction'])\n",
616 | "\n",
617 | "    # download the same image and display it at 250x250 below its prediction\n",
618 | "    im = Image.open(requests.get(item['url'], stream=True).raw)\n",
619 | "    display(im.resize((250, 250)))"
623 | ]
624 | },
625 | {
626 | "cell_type": "markdown",
627 | "metadata": {},
628 | "source": [
629 | "# Cleanup\n",
630 | "\n",
631 | "Lastly, please remember to delete the Amazon SageMaker endpoint to avoid charges. Run the following cell, which calls `predictor.delete_endpoint()`, to do so. "
632 | ]
633 | },
634 | {
635 | "cell_type": "code",
636 | "execution_count": null,
637 | "metadata": {},
638 | "outputs": [],
639 | "source": [
640 | "predictor.delete_endpoint()"
641 | ]
642 | },
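{
"cell_type": "markdown",
"metadata": {},
"source": [
"Deleting the endpoint leaves the SageMaker model resource created by `deploy()` in place. If you no longer need it, the optional cell below removes it as well, using `Predictor.delete_model()` from the SageMaker Python SDK v2. The training data and model artifacts stored in S3 are not affected."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional: also delete the SageMaker model created by estimator.deploy()\n",
"predictor.delete_model()"
]
},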
643 | {
644 | "cell_type": "markdown",
645 | "metadata": {},
646 | "source": [
647 | "## References \n",
648 | "- [1] Visual Transformers: Token-based Image Representation and Processing for Computer Vision (https://arxiv.org/pdf/2006.03677.pdf)\n",
649 | "- [2] Kaggle notebook (https://www.kaggle.com/asollie/intel-image-multiclass-pytorch-94-test-acc)\n",
650 | "- [3] Registry of Open Data from AWS (https://registry.opendata.aws/multimedia-commons/)"
651 | ]
652 | }
653 | ],
654 | "metadata": {
655 | "instance_type": "ml.g4dn.xlarge",
656 | "kernelspec": {
657 | "display_name": "Python 3 (PyTorch 1.6 Python 3.6 GPU Optimized)",
658 | "language": "python",
659 | "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/pytorch-1.6-gpu-py36-cu110-ubuntu18.04-v3"
660 | },
661 | "language_info": {
662 | "codemirror_mode": {
663 | "name": "ipython",
664 | "version": 3
665 | },
666 | "file_extension": ".py",
667 | "mimetype": "text/x-python",
668 | "name": "python",
669 | "nbconvert_exporter": "python",
670 | "pygments_lexer": "ipython3",
671 | "version": "3.6.10"
672 | }
673 | },
674 | "nbformat": 4,
675 | "nbformat_minor": 4
676 | }
--------------------------------------------------------------------------------
/img/vt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-visual-transformer/2917e7f1abb4eb4acb0fd8e70947d5443d2d761c/img/vt.png
--------------------------------------------------------------------------------