├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── sm_augmentation.ipynb
├── src
│   ├── requirements.txt
│   └── sm_augmentation_train-script.py
├── util_debugger.py
└── util_train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contributing Guidelines

Thank you for your interest in contributing to our project.
Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.

Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
information to effectively respond to your bug report or contribution.


## Reporting Bugs/Feature Requests

We welcome you to use the GitHub issue tracker to report bugs or suggest features.

When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:

* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment


## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:

1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.

To send us a pull request, please:

1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.

GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).


## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.


## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.


## Security issue notifications
If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.


## Licensing

See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright Amazon.com, Inc. or its affiliates.
All Rights Reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# sagemaker-cv-preprocessing-training-performance

This repository contains an [Amazon SageMaker](https://aws.amazon.com/sagemaker/) training implementation with data pre-processing (decoding + augmentations) on both GPUs and CPUs for computer vision, allowing you to compare and reduce training time by addressing CPU bottlenecks caused by an increasing data pre-processing load. This is achieved by GPU-accelerated JPEG image decoding and offloading of augmentation to GPUs using [NVIDIA DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/). Performance bottlenecks and system utilization metrics are compared using [Amazon SageMaker Debugger](https://docs.aws.amazon.com/sagemaker/latest/dg/train-debugger.html).

## Module Description:

- `util_train.py`: Launch [Amazon SageMaker PyTorch training](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html) jobs with your custom training script.
- `src/sm_augmentation_train-script.py`: Custom training script to train models of different complexities (`RESNET-18`, `RESNET-50`, `RESNET-152`) with data pre-processing implementations for:
  - JPEG decoding and augmentation on CPUs using the PyTorch DataLoader
  - JPEG decoding and augmentation on CPUs & GPUs using NVIDIA DALI
- `util_debugger.py`: Extract system utilization metrics with [SageMaker Debugger](https://sagemaker.readthedocs.io/en/stable/amazon_sagemaker_debugger.html).
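
For example, once a job launched through `util_train.py` finishes, the Debugger helper can be called on the returned estimator (a minimal sketch; `train_estimator` is the estimator returned by `aug_exp_train`, shown in the next section, and an `ml.p3.2xlarge` instance provides 8 vCPUs and 1 GPU):
```
from util_debugger import get_sys_metric

heatmap, metric_hist, sys_util_df = get_sys_metric(train_estimator,
                                                   num_cpu=8,
                                                   num_gpu=1)
print(sys_util_df)  # p50/p95/p99 utilization for CPU and GPU
```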

## Run SageMaker training job with decoding and augmentation on GPU:
- Parameters such as the training data path, S3 bucket, epochs, and other training hyperparameters can be adapted in `util_train.py`.
- The custom training script used is `src/sm_augmentation_train-script.py`.
```
from util_debugger import get_sys_metric
from util_train import aug_exp_train

train_job_id, train_estimator = aug_exp_train(model_arch='RESNET50',
                                              batch_size=32,
                                              aug_operator='dali-gpu',
                                              aug_load_factor=12,
                                              instance_type='ml.p3.2xlarge',
                                              sm_role='to-be-added')
```
- Note that this implementation is currently optimized for single-GPU training to address multi-core CPU bottlenecks. The [DALI Decoder operation](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/supported_ops.html#nvidia.dali.fn.decoders.image) can be updated with improved usage of `device_memory_padding` and `host_memory_padding` for larger multi-GPU instances.

## Experiment to compare bottlenecks:

- Create an Amazon S3 bucket called `sm-aug-test` and upload the [Imagenette dataset](https://github.com/fastai/imagenette) ([download link](https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz)).
- Update your SageMaker execution role in the notebook `sm_augmentation.ipynb` and run the notebook to compare seconds/epoch and system utilization for training jobs by toggling the following parameters:
  - `instance_type` (default: `ml.p3.2xlarge`)
  - `model_arch` (default: `RESNET18`)
  - `batch_size` (default: `32`)
  - `aug_load_factor` (default: `12`)
  - `AUGMENTATION_APPROACHES` (default: `['pytorch-cpu', 'dali-gpu']`)
- Comparison results using the above default parameter setup:
  - Seconds/epoch improved by `72.59%` in the Amazon SageMaker training job when JPEG decoding and heavy augmentation were offloaded to the GPU, addressing the data pre-processing bottleneck and improving the performance-cost ratio.
  - Using the above strategy, the training time improvement is higher for lighter models like `RESNET-18` (which cause more CPU bottlenecks) than for heavier models such as `RESNET-152` as `aug_load_factor` is increased while keeping a low batch size of `32`.
  - System utilization histograms and CPU bottleneck heatmaps are generated with SageMaker Debugger in the notebook. The Profiler Report and other interactive visuals are available in SageMaker Studio.
  - Further detailed results (based on different augmentation loads, batch sizes, and model complexities for training on 8 CPUs and 1 GPU) are available on request.
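
For reference, the core of the `dali-gpu` path is a DALI pipeline that decodes JPEGs with the `mixed` backend (nvJPEG, using the dedicated hardware decoder where available) and keeps the augmentation operators on the GPU. A condensed sketch of `create_dali_pipeline` from `src/sm_augmentation_train-script.py`, with most parameters and the final crop/normalization omitted:
```
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import pipeline_def

@pipeline_def
def gpu_pipeline_sketch(data_dir):
    images, labels = fn.readers.file(file_root=data_dir, random_shuffle=True, name="Reader")
    images = fn.decoders.image(images, device="mixed", output_type=types.RGB)  # nvJPEG decode on GPU
    images = fn.flip(images, device="gpu", horizontal=1, vertical=1)           # augmentation stays on GPU
    images = fn.random_resized_crop(images, size=256, device="gpu")
    return images.gpu(), labels.gpu()
```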

## Security

See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.

## License

This library is licensed under the MIT-0 License. See the LICENSE file.

--------------------------------------------------------------------------------
/sm_augmentation.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fabulous-number",
   "metadata": {},
   "outputs": [],
   "source": [
    "from util_debugger import get_sys_metric\n",
    "from util_train import aug_exp_train\n",
    "import pprint"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "shared-wheat",
   "metadata": {},
   "source": [
    "### Setting up required parameters for the experiment to compare training time for augmentation on CPU and GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "selective-count",
   "metadata": {},
   "outputs": [],
   "source": [
    "curr_sm_role = # Please update your SageMaker Execution Role\n",
    " \n",
    "# 'pytorch-cpu': JPEG decoding and augmentation on CPUs using the PyTorch DataLoader\n",
    "# 'dali-cpu': JPEG decoding and augmentation on CPUs using NVIDIA DALI\n",
    "# 'dali-gpu': JPEG decoding and augmentation on GPUs using NVIDIA DALI\n",
    "AUGMENTATION_APPROACHES = ['pytorch-cpu', 'dali-gpu']\n",
    "\n",
    "instance_type = 'ml.p3.2xlarge'\n",
    "# Required for plotting system utilization\n",
    "num_cpu = 8\n",
    "num_gpu = 1\n",
    "\n",
    "# Training script supports: 'RESNET50', 'RESNET18', and 'RESNET152'\n",
    "model_arch = 'RESNET18'\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "# Factor by which to repeat augmentation operations for increasing the data pre-processing load\n",
    "aug_load_factor = 12\n",
    "\n",
    "# You can change other parameters such as the training data path, S3 bucket, epochs, and training hyperparameters in util_train.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "intensive-party",
   "metadata": {},
   "source": [
    "### Launching training jobs and fetching system utilization for data pre-processing on CPUs vs on GPUs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "reserved-president",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_data = {}\n",
    "trial = 0\n",
    "pp = pprint.PrettyPrinter()\n",
    "\n",
    "for aug_operator in AUGMENTATION_APPROACHES:\n",
    "    \n",
    "    trial = trial + 1\n",
    "    trial_data = dict.fromkeys(['train_job_id', 'model_arch', 'instance_type', 'batch_size', 'aug_load_factor', 'aug_operator', 'sys_util_df'])\n",
    "    \n",
    "    # Launch Amazon SageMaker PyTorch training jobs with your custom training script\n",
    "    train_job_id, train_estimator = aug_exp_train(model_arch, \n",
    "                                                  batch_size, \n",
    "                                                  aug_operator, \n",
    "                                                  aug_load_factor, \n",
    "                                                  instance_type, \n",
    "                                                  curr_sm_role)\n",
    "    \n",
    "    # Extract system utilization metrics with SageMaker Debugger\n",
    "    heatmap, metric_hist, sys_util_df = get_sys_metric(train_estimator, \n",
    "                                                       num_cpu,\n",
    "                                                       num_gpu)\n",
    "    \n",
    "    # Print parameter and result summary for the current training job run\n",
    "    trial_data['train_job_id'] = train_job_id\n",
    "    trial_data['model_arch'] = model_arch\n",
    "    trial_data['instance_type'] = instance_type\n",
    "    trial_data['batch_size'] = batch_size\n",
    "    trial_data['aug_load_factor'] = aug_load_factor\n",
    "    trial_data['aug_operator'] = aug_operator\n",
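    "    # sys_util_df: p50/p95/p99 CPU and GPU utilization for this trial, as reported by SageMaker Debugger\n",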
    "    trial_data['sys_util_df'] = sys_util_df\n",
    "    \n",
    "    pp.pprint(trial_data) \n",
    "    exp_data.update({'trial-'+str(trial): trial_data})\n",
    "    \n",
    "pp.pprint(exp_data) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "disabled-aaron",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv_dev",
   "language": "python",
   "name": ".venv_dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

--------------------------------------------------------------------------------
/src/requirements.txt:
--------------------------------------------------------------------------------
nvidia-pyindex
--extra-index-url https://developer.download.nvidia.com/compute/redist
nvidia-dali-cuda110

--------------------------------------------------------------------------------
/src/sm_augmentation_train-script.py:
--------------------------------------------------------------------------------
# coding=utf-8
import argparse
import json
import logging
import os
import sys
import copy
import time
from PIL import Image

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
from torchvision import models, datasets, transforms

from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))

Image.MAX_IMAGE_PIXELS = None

"""
Method to augment and load data on CPU with PyTorch DataLoaders
"""


def augmentation_pytorch(train_dir, batch_size, workers, is_distributed, use_cuda, aug_load_factor):
    print("Image augmentation using PyTorch DataLoaders on CPUs")
    aug_ops = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(5)
    ]
    crop_norm_ops = [
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]

    train_aug_ops = []
    # Repeating augmentation to influence the bottleneck
    for iteration in range(aug_load_factor):
        train_aug_ops = train_aug_ops + aug_ops

    data_transforms = {
        'train': transforms.Compose(train_aug_ops + crop_norm_ops),
        'val': transforms.Compose(crop_norm_ops),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(train_dir, x),
                                              data_transforms[x])
                      for x in ['train', 'val']}
    train_sampler = torch.utils.data.distributed.DistributedSampler(image_datasets['train']) if is_distributed else None
    dataloaders = {x: torch.utils.data.DataLoader(dataset=image_datasets[x],
                                                  batch_size=batch_size,
                                                  sampler=train_sampler if x == 'train' else None,
                                                  shuffle=(x == 'train' and train_sampler is None),
                                                  num_workers=workers,
                                                  pin_memory=True if use_cuda else False)
                   for x in ['train', 'val']}

    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    return dataloaders, dataset_sizes
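
# For illustration (not part of the original pipeline): with aug_load_factor=2 the
# composed training transform above expands to
#   [RandomHorizontalFlip, RandomVerticalFlip, RandomRotation,
#    RandomHorizontalFlip, RandomVerticalFlip, RandomRotation,
#    RandomResizedCrop, ToTensor, Normalize]
# so each extra repetition adds CPU work per image, which is what creates the
# data pre-processing bottleneck being measured.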

"""
Method to augment and load data on CPU or GPU with NVIDIA DALI
"""


@pipeline_def
def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu, is_training, aug_load_factor):
    images, labels = fn.readers.file(file_root=data_dir,
                                     shard_id=shard_id,
                                     num_shards=num_shards,
                                     random_shuffle=is_training,
                                     pad_last_batch=True,
                                     name="Reader")
    """
    For jpeg images, the "mixed" backend uses the nvJPEG library. If hardware is available, the operator will use
    the dedicated hardware decoder. For jpeg images, the "cpu" backend uses libjpeg-turbo. Other image formats are
    decoded with OpenCV or other specific libraries, such as libtiff.
    """
    dali_device = 'cpu' if dali_cpu else 'gpu'
    decoder_device = 'cpu' if dali_cpu else 'mixed'

    images = fn.decoders.image(images,
                               device=decoder_device,
                               output_type=types.RGB,
                               memory_stats=True)
    if is_training:
        # Repeating augmentation to influence the bottleneck
        for x in range(aug_load_factor):
            images = fn.flip(images, device=dali_device, horizontal=1, vertical=1)
            images = fn.rotate(images, angle=5, device=dali_device)

        images = fn.random_resized_crop(images, size=size, device=dali_device)
    images = fn.crop_mirror_normalize(images,
                                      dtype=types.FLOAT,
                                      output_layout="CHW",
                                      crop=(crop, crop),
                                      mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                      std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    images = images.gpu()
    labels = labels.gpu()

    return images, labels
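
# Note: batch_size, num_threads, device_id, and seed are not in the signature above;
# the @pipeline_def decorator adds them as construction-time arguments. An illustrative
# construction (mirroring the calls in augmentation_dali below; the data path is hypothetical):
#   pipe = create_dali_pipeline(batch_size=32, num_threads=8, device_id=0, seed=42,
#                               data_dir='/path/to/train/', crop=224, size=256,
#                               dali_cpu=False, shard_id=0, num_shards=1,
#                               is_training=True, aug_load_factor=12)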

def augmentation_dali(train_dir, batch_size, workers, host_rank, world_size, seed, aug_load_factor, dali_cpu):
    if dali_cpu:
        print("Image augmentation using DALI pipelines on CPUs")
    else:
        print("Image augmentation using DALI pipelines on GPUs")

    """
    Augmentation with DALI is not implemented for distributed training at the moment. Refer to:
    https://github.com/NVIDIA/DALI/blob/c4e86b55dccba083ae944cf00a478678b7e906cc/docs/examples/use_cases/pytorch/resnet50/main.py
    """

    dataloaders = {}
    dataset_sizes = {}

    train_path = train_dir + '/train/'
    dataset_sizes['train'] = sum([len(files) for r, d, files in os.walk(train_path)])
    train_pipe = create_dali_pipeline(batch_size=batch_size,
                                      num_threads=workers,
                                      device_id=host_rank,
                                      seed=seed,
                                      data_dir=train_path,
                                      crop=224,
                                      size=256,
                                      dali_cpu=dali_cpu,
                                      shard_id=host_rank,
                                      num_shards=world_size,
                                      is_training=True,
                                      aug_load_factor=aug_load_factor)
    train_pipe.build()
    dataloaders['train'] = DALIClassificationIterator(train_pipe,
                                                      reader_name="Reader",
                                                      last_batch_policy=LastBatchPolicy.PARTIAL)

    val_path = train_dir + '/val/'
    dataset_sizes['val'] = sum([len(files) for r, d, files in os.walk(val_path)])
    val_pipe = create_dali_pipeline(batch_size=batch_size,
                                    num_threads=workers,
                                    device_id=host_rank,
                                    seed=seed,
                                    data_dir=val_path,
                                    crop=224,
                                    size=256,
                                    dali_cpu=dali_cpu,
                                    shard_id=host_rank,
                                    num_shards=world_size,
                                    is_training=False,
                                    aug_load_factor=aug_load_factor)
    val_pipe.build()
    dataloaders['val'] = DALIClassificationIterator(val_pipe,
                                                    reader_name="Reader",
                                                    last_batch_policy=LastBatchPolicy.PARTIAL)
    return dataloaders, dataset_sizes


"""
Method to train models for a number of epochs
"""


def run_training_epochs(model_ft, num_epochs, criterion, optimizer_ft, dataloaders, dataset_sizes, device, USE_PYTORCH):
    best_model_wts = copy.deepcopy(model_ft.state_dict())
    best_acc = 0.0

    total_epoch_time = 0
    for epoch in range(num_epochs):
        print('Running Epoch {}/{}'.format(epoch + 1, num_epochs))

        epoch_start_time = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:

            if phase == 'train':
                model_ft.train()
            else:
                model_ft.eval()

            running_loss = 0.0
            running_corrects = 0

            # Data iteration if using DALI pipelines for loading the augmented data
            if not USE_PYTORCH:

                for i, data in enumerate(dataloaders[phase]):
                    inputs = data[0]["data"]
                    labels = data[0]["label"].squeeze(-1).long()

                    optimizer_ft.zero_grad()
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model_ft(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)
                        if phase == 'train':
                            loss.backward()
                            optimizer_ft.step()
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)

            # Data iteration if using the PyTorch DataLoader for loading the augmented data
            else:

                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    optimizer_ft.zero_grad()
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model_ft(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)
                        if phase == 'train':
                            loss.backward()
                            optimizer_ft.step()
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print('{}-loss: {:.4f} {}-acc: {:.4f}'.format(
                phase, epoch_loss, phase, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model_ft.state_dict())

        epoch_time_elapsed = time.time() - epoch_start_time
        print('Epoch completed in {:.2f}s'.format(epoch_time_elapsed))
        total_epoch_time = total_epoch_time + epoch_time_elapsed

    # Calculating seconds/epoch: the metric used for comparing performance across the experiments
    print('-' * 25)
    print('Seconds per Epoch: {:.2f}'.format(total_epoch_time / num_epochs))

    model_ft.load_state_dict(best_model_wts)
    return model_ft, best_acc
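
# Example of the reported metric (illustrative numbers): if two epochs take 310.4s and
# 289.6s, total_epoch_time / num_epochs = (310.4 + 289.6) / 2 = 300.00 seconds per epoch.
# This is the figure compared between the 'pytorch-cpu' and 'dali-gpu' runs.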

def training(args):
    num_gpus = args.num_gpus
    hosts = args.hosts
    current_host = args.current_host
    backend = args.backend
    seed = args.seed

    is_distributed = len(hosts) > 1 and backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))
    use_cuda = num_gpus > 0
    logger.debug("Number of gpus available - {}".format(num_gpus))
    device = torch.device("cuda" if use_cuda else "cpu")

    world_size = len(hosts)
    os.environ['WORLD_SIZE'] = str(world_size)
    host_rank = hosts.index(current_host)

    if is_distributed:
        # Initialize the distributed environment.
        dist.init_process_group(backend=backend, rank=host_rank, world_size=world_size)
        logger.info('Initialized the distributed environment: \'{}\' backend on {} nodes. '.format(
            backend, dist.get_world_size()) + 'Current host rank is {}. Number of gpus: {}'.format(
            dist.get_rank(), num_gpus))
    # Set the seed for generating random numbers
    torch.manual_seed(seed)

    if use_cuda:
        torch.cuda.manual_seed(seed)

    # Loading training and validation data
    batch_size = args.batch_size
    train_dir = args.train_dir

    # Use the available CPUs as workers; very large worker counts can hit file-system concurrency limits on large-CPU instances
    workers = os.cpu_count() if use_cuda else 0

    # Factor by which augmentation is repeated to influence the bottleneck
    aug_load_factor = args.aug_load_factor

    # Deciding on the augmentation approach to use
    USE_PYTORCH = False
    USE_DALI_CPU = False
    if args.aug == 'pytorch-cpu':
        USE_PYTORCH = True
    if args.aug == 'dali-cpu':
        USE_DALI_CPU = True
    if USE_PYTORCH:
        dataloaders, dataset_sizes = augmentation_pytorch(train_dir,
                                                          batch_size,
                                                          workers,
                                                          is_distributed,
                                                          use_cuda,
                                                          aug_load_factor)
    else:
        dataloaders, dataset_sizes = augmentation_dali(train_dir,
                                                       batch_size,
                                                       workers,
                                                       host_rank,
                                                       world_size,
                                                       seed,
                                                       aug_load_factor,
                                                       dali_cpu=USE_DALI_CPU)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Deciding on the model to use
    if args.model_type == 'RESNET18':
        model_ft = models.resnet18(pretrained=False)
    elif args.model_type == 'RESNET50':
        model_ft = models.resnet50(pretrained=False)
    elif args.model_type == 'RESNET152':
        model_ft = models.resnet152(pretrained=False)
    else:
        sys.exit('Requested model not found')

    model_ft = model_ft.to(device)

    if is_distributed and use_cuda:
        model_ft = torch.nn.parallel.DistributedDataParallel(model_ft)
    else:
        model_ft = torch.nn.DataParallel(model_ft)

    num_epochs = args.epochs
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), args.lr, args.momentum)

    # Running model training
    since = time.time()

    # Not using the trained model or accuracy score for this experiment
    model_ft, best_acc = run_training_epochs(model_ft,
                                             num_epochs,
                                             criterion,
                                             optimizer_ft,
                                             dataloaders,
                                             dataset_sizes,
                                             device,
                                             USE_PYTORCH)
    time_elapsed = time.time() - since

    print('-' * 25)
    print("Model — ", args.model_type)
    print("Augmentation Approach — ", args.aug)
    print("Batch Size — ", batch_size)
    print("Augmentation Load Factor — ", aug_load_factor)
    print('-' * 25)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--model-type', type=str, default=None,
                        help='Model architecture to train')
    parser.add_argument('--aug', type=str, default='dali-gpu',
                        help='Augmentation approach to use: pytorch-cpu, dali-cpu, or dali-gpu (default: dali-gpu)')
    parser.add_argument('--aug-load-factor', type=int, default=1,
                        help='Factor by which augmentation should be repeated to create a bottleneck (default: 1)')
    parser.add_argument('--backend', type=str, default=None,
                        help='backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)')

    parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    # Cast here: argparse does not apply `type` to defaults, and SM_NUM_GPUS is a string
    parser.add_argument('--num-gpus', type=int, default=int(os.environ['SM_NUM_GPUS']))
    parser.add_argument('--train_dir', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('--val_dir', type=str, default=os.environ['SM_CHANNEL_VAL'])

    training(parser.parse_args())

--------------------------------------------------------------------------------
/util_debugger.py:
--------------------------------------------------------------------------------
import boto3
import json
import pandas as pd
from smdebug.profiler.system_metrics_reader import S3SystemMetricsReader
from smdebug.profiler.algorithm_metrics_reader import S3AlgorithmMetricsReader
from smdebug.profiler.analysis.notebook_utils.metrics_histogram import MetricsHistogram
from smdebug.profiler.analysis.notebook_utils.heatmap import Heatmap

"""
Method to get system metrics for the training job
"""


def get_sys_metric(train_estimator, num_cpu, num_gpu):
    path = train_estimator.latest_job_profiler_artifacts_path()
    system_metrics_reader = S3SystemMetricsReader(path)
    framework_metrics_reader = S3AlgorithmMetricsReader(path)

    """
    Metric histograms and heatmaps of system usage
    """

    dim_to_plot = ["CPU", "GPU"]
    events_to_plot = []

    for x in range(num_cpu):
        events_to_plot.append("cpu" + str(x))
    for x in range(num_gpu):
        events_to_plot.append("gpu" + str(x))

    system_metrics_reader.refresh_event_file_list()
    framework_metrics_reader.refresh_event_file_list()

    resultant_heatmap = Heatmap(
        system_metrics_reader,
        framework_metrics_reader,
        select_dimensions=dim_to_plot,
        select_events=events_to_plot,
        plot_height=400
    )

    system_metrics_reader.refresh_event_file_list()
    resultant_metric_hist = MetricsHistogram(system_metrics_reader).plot(
        select_dimensions=dim_to_plot,
        select_events=events_to_plot
    )

    """
    Fetching system statistics from the profiler report
    """

    profiler_report_name = [
        rule["RuleConfigurationName"]
        for rule in train_estimator.latest_training_job.rule_job_summary()
        if "Profiler" in rule["RuleConfigurationName"]][0]

    rule_output_path = train_estimator.output_path + '/' + train_estimator.latest_training_job.job_name + "/rule-output"
    s3_sys_usage_json_path = rule_output_path + "/" + profiler_report_name + "/profiler-output/profiler-reports/OverallSystemUsage.json"
    print("Fetching data from: ", s3_sys_usage_json_path)

    path_without_s3_pre = s3_sys_usage_json_path.split("//", 1)[1]
    s3_bucket = path_without_s3_pre.split("/", 1)[0]
    s3_prefix = path_without_s3_pre.split("/", 1)[1]

    s3 = boto3.resource('s3')
    content_object = s3.Object(s3_bucket, s3_prefix)
    json_content = content_object.get()['Body'].read().decode('utf-8')
    sys_usage = json.loads(json_content)

    cpu_usage = sys_usage["Details"]["CPU"]["algo-1"]
    gpu_usage = sys_usage["Details"]["GPU"]["algo-1"]
    sys_util_data = [['CPU', num_cpu, cpu_usage["p50"], cpu_usage["p95"], cpu_usage["p99"]],
                     ['GPU', num_gpu, gpu_usage["p50"], gpu_usage["p95"], gpu_usage["p99"]],
                     ]
    sys_util_df = pd.DataFrame(sys_util_data, columns=['Metric', '#', 'p50', 'p95', 'p99'])

    return resultant_heatmap, resultant_metric_hist, sys_util_df

--------------------------------------------------------------------------------
/util_train.py:
--------------------------------------------------------------------------------
import sagemaker
from sagemaker.pytorch import PyTorch
from datetime import datetime


def aug_exp_train(model_arch, batch_size, aug_operator, aug_load_factor, instance_type, sm_role):

    # Amazon S3 bucket and prefix for fetching and storing data:
    BUCKET = 'sm-aug-test'

    # Full-size download of https://github.com/fastai/imagenette
    # 1.3GB — 13,395 images (9,469 train, 3,925 val) from 10 classes
    train_data_s3 = 's3://{}/{}'.format(BUCKET, 'imagenette2')
    src_code_s3 = 's3://{}/{}'.format(BUCKET, 'training_jobs')
    training_job_output_s3 = 's3://{}/{}'.format(BUCKET, 'training_jobs_output')

    # Encapsulate training on SageMaker with PyTorch:
    train_estimator = PyTorch(entry_point='sm_augmentation_train-script.py',
                              source_dir='./src',
                              role=sm_role,
                              framework_version='1.8.1',
                              py_version='py3',
                              debugger_hook_config=False,

                              instance_count=1,
                              instance_type=instance_type,

                              output_path=training_job_output_s3,
                              code_location=src_code_s3,

                              hyperparameters={'epochs': 2,
                                               'backend': 'nccl',
                                               'model-type': model_arch,
                                               'lr': 0.001,
                                               'batch-size': batch_size,
                                               'aug': aug_operator,
                                               'aug-load-factor': aug_load_factor
                                               }
                              )

    # Setting up the data channels (File input mode) to import data from S3
    data_channels = {'train': sagemaker.inputs.TrainingInput(
                        s3_data_type='S3Prefix',
                        s3_data=train_data_s3,
                        content_type='image/jpeg',
                        input_mode='File'),
                     'val': sagemaker.inputs.TrainingInput(
                        s3_data_type='S3Prefix',
                        s3_data=train_data_s3,
                        content_type='image/jpeg',
                        input_mode='File')
                     }

    # Launching the SageMaker training job
    train_job_id = 'sm-aug-' + datetime.now().strftime("%H-%M-%S")
    train_estimator.fit(inputs=data_channels, job_name=train_job_id)

    return train_job_id, train_estimator

--------------------------------------------------------------------------------