├── Dockerfile ├── LICENSE ├── README.md ├── Singularity └── workspace ├── jupyter_notebook ├── 01_introduction.ipynb ├── 02_pytorch_mnist.ipynb ├── 03_data_transfer.ipynb ├── 04_tensor_core_util.ipynb ├── 05_summary.ipynb ├── images │ ├── Baseline.jpg │ ├── Built-in_NVTX.jpg │ ├── NVTX_annotations.jpg │ ├── NumberOfWorkers.jpg │ ├── Optimization2.jpg │ ├── Optimization3.jpg │ ├── Optimization_workflow.jpg │ ├── PageableVsPinned.jpg │ ├── ShowInEventsView.jpg │ ├── StarvationDuringDataLoading.jpg │ ├── TensorCoreUsage.jpg │ ├── Training_program.jpg │ ├── amp.jpg │ ├── architecture_tensor_cores.jpg │ ├── browse_port.jpg │ ├── cudaProfilerApi.jpg │ ├── cudaprofilestart.png │ ├── fp16_opt22.jpg │ ├── fp16_opt33.jpg │ ├── gpu_kernel_view.jpg │ ├── memory_view_menu.jpg │ ├── memory_view_opt11.jpg │ ├── memory_view_opt22.jpg │ ├── mnist.jpg │ ├── module_view.jpg │ ├── module_view_opt1.jpg │ ├── nsight_flow.png │ ├── nsys_CUDA_api.png │ ├── nsys_GPU_view.png │ ├── nsys_blocked_state.png │ ├── nsys_blocked_state_1.png │ ├── nsys_cpu_api_thread.png │ ├── nsys_cpu_summary.png │ ├── nsys_search.png │ ├── nsys_timeline.png │ ├── nvtx_annotation.jpg │ ├── operator_view.jpg │ ├── operator_view_code_trace.jpg │ ├── operator_view_tbl.jpg │ ├── performance_recommendation.jpg │ ├── performance_recommendation_opt1.jpg │ ├── performance_recommendation_opt2.jpg │ ├── port_forwarding.jpg │ ├── profile_summary.jpg │ ├── profile_summary_opt1.jpg │ ├── profile_summary_opt11.jpg │ ├── profile_summary_opt11_memcpy.jpg │ ├── profile_summary_opt2.jpg │ ├── profile_summary_opt22.jpg │ ├── profile_summary_opt22_tensor_core.jpg │ ├── profile_summary_opt33.jpg │ ├── report_02_1.jpg │ ├── report_02_2.jpg │ ├── report_02_3.jpg │ ├── report_02_4.jpg │ ├── report_activate_tensor.jpg │ ├── report_api_call.jpg │ ├── report_baseline.jpg │ ├── report_final_time.jpg │ ├── report_firstoptimization.jpg │ ├── report_gpu_starvation.jpg │ ├── report_kernel_coverage.jpg │ ├── report_nvtx.png │ ├── 
report_pinned_memory.jpg │ ├── report_show_current_timeline.jpg │ ├── report_show_in_events_view.jpg │ ├── side_menu.jpg │ ├── side_menu_list.jpg │ ├── step_time_breakdown.jpg │ ├── step_time_breakdown_opt1.jpg │ ├── step_time_breakdown_opt11.jpg │ ├── step_time_breakdown_opt2.jpg │ ├── step_time_breakdown_opt33.jpg │ ├── tb_kernel_view.png │ ├── tb_main.jpg │ ├── tb_main.png │ ├── tb_overview.png │ ├── tb_train.png │ ├── tensor_core_util_opt22.jpg │ ├── tensor_cores.jpg │ ├── trace_view.jpg │ ├── trace_view1.jpg │ ├── trace_view_opt1.jpg │ ├── trace_view_opt11.jpg │ ├── trace_view_opt2.jpg │ ├── trace_view_opt33.jpg │ └── train.png ├── tb01_introduction.ipynb ├── tb02_pytorch_mnist.ipynb ├── tb03_data_transfer.ipynb ├── tb04_tensor_core_util.ipynb └── tb05_summary.ipynb ├── log └── mnist_0 │ ├── dgx01_2422087.1661161405176.pt.trace.json │ └── dgx01_2422087.1661161405859.pt.trace.json ├── reports ├── baseline.nsys-rep ├── baseline_nvtx.nsys-rep ├── firstOptimization.nsys-rep ├── secondOptimization.nsys-rep └── thirdOptimization.nsys-rep ├── source_code ├── LICENSE.txt ├── Nsight_Systems_User_Guide_2022.2.1.31-5fe97ab.pdf ├── data │ └── MNIST │ │ ├── LICENSE.txt │ │ ├── processed │ │ ├── test.pt │ │ └── training.pt │ │ └── raw │ │ ├── t10k-images-idx3-ubyte │ │ ├── t10k-images-idx3-ubyte.gz │ │ ├── t10k-labels-idx1-ubyte │ │ ├── t10k-labels-idx1-ubyte.gz │ │ ├── train-images-idx3-ubyte │ │ ├── train-images-idx3-ubyte.gz │ │ ├── train-labels-idx1-ubyte │ │ └── train-labels-idx1-ubyte.gz ├── main_baseline.py ├── main_baseline_nvtx.py ├── main_opt1.py ├── main_opt2.py ├── main_opt3.py ├── tb_main_baseline_profiler.py ├── tb_main_opt1.py ├── tb_main_opt2.py └── tb_main_opt3.py └── start_here.ipynb /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:22.04-py3 2 | 3 | RUN apt-get -y update 4 | RUN pip3 install jupyterlab 5 | RUN pip3 install ipywidgets 6 | RUN pip3 install torch_tb_profiler 7 | 8 | 
##### 9 | # Read https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772 10 | RUN apt-get update -y && \ 11 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 12 | apt-transport-https \ 13 | ca-certificates \ 14 | gnupg \ 15 | wget && \ 16 | rm -rf /var/lib/apt/lists/* 17 | RUN wget -qO - https://developer.download.nvidia.com/devtools/repos/ubuntu2004/amd64/nvidia.pub | apt-key add - && \ 18 | echo "deb https://developer.download.nvidia.com/devtools/repos/ubuntu2004/amd64/ /" >> /etc/apt/sources.list.d/nsight.list && \ 19 | apt-get update -y && \ 20 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 21 | nsight-systems-2022.1.1 && \ 22 | rm -rf /var/lib/apt/lists/* 23 | 24 | 25 | ################################################# 26 | ENV LD_LIBRARY_PATH="/usr/local/lib:/usr/local/lib/python3.8/dist-packages:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" 27 | ENV PATH="/opt/nvidia/nsight-systems/2022.1.1/bin:/usr/local/bin:/bin:/usr/local/cuda/bin:/usr/bin${PATH:+:${PATH}}" 28 | 29 | # TO COPY the data 30 | #COPY workspace/ /workspace 31 | 32 | 33 | #CMD jupyter-lab --no-browser --allow-root --ip=0.0.0.0 --port=8888 --NotebookApp.token="" --notebook-dir=/workspace 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimizing a Deep Neural Network (DNN) training program 2 | 3 | This folder contains material for profiling an AI training program using: 4 | 5 | - NVIDIA Nsight Systems 6 | - PyTorch Profiler with TensorBoard Plugin 7 | - TensorBoard Visualization 8 | - Optimization Techniques 9 | 10 | ## Prerequisites 11 | To run this tutorial you will need a machine with an NVIDIA GPU.
12 | 13 | - Install the latest [Docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker) or [Singularity](https://sylabs.io/docs/). 14 | - To be able to see the profiler output, please download NVIDIA Nsight Systems' latest version from [here](https://developer.nvidia.com/nsight-systems). 15 | - Linux (Ubuntu) OS 16 | 17 | 18 | ## Running on containers 19 | To start with, you will have to build a Docker or Singularity container. 20 | 21 | ### Docker Container 22 | To build a docker container, run: 23 | `sudo docker build --network=host -t IMAGE_NAME:TAG .` 24 | 25 | For instance: 26 | `sudo docker build -t pytorch:1.0 .` 27 | 28 | The code labs have been written using Jupyter notebooks and a Dockerfile has been built to simplify deployment. In order to serve the docker instance for a student, it is necessary to expose port 8888 from the container. For instance, the following command would expose port 8888 inside the container as port 8888 on the lab machine: 29 | 30 | `sudo docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -it --rm --network=host -v ~/ai_profiler/workspace:/workspace pytorch:1.0 jupyter-lab --no-browser --allow-root --ip=0.0.0.0 --port=8888 --NotebookApp.token="" --notebook-dir=/workspace` 31 | 32 | The `--gpus` flag is used to enable `all` NVIDIA GPUs during container runtime. The `--rm` flag automatically removes the container when it exits. The `-it` flag enables killing the jupyter server with ctrl-c. 33 | 34 | The `--ipc=host --ulimit memlock=-1 --ulimit stack=67108864` flags enable sufficient memory allocation to run pytorch within the docker environment. 35 | 36 | The `jupyter-lab --no-browser --allow-root --ip=0.0.0.0 --port=8888 --NotebookApp.token="" --notebook-dir=/workspace` command launches the jupyter notebook inside the container.
The `-v` flag maps the working directory on your local machine, `~/ai_profiler/workspace`, to the `/workspace` directory inside the container. 37 | 38 | Then, open the jupyter notebook in a browser: http://localhost:8888 39 | Start working on the lab by clicking on the `start_here.ipynb` notebook. 40 | 41 | ### Singularity Container 42 | 43 | To build the singularity container, run: 44 | `sudo singularity build --fakeroot IMAGE_NAME.simg Singularity` 45 | 46 | For example: 47 | `singularity build --fakeroot pytorch.simg Singularity` 48 | 49 | Then, run the container: 50 | `singularity run --nv --bind ~/ai_profiler/workspace:/workspace pytorch.simg jupyter-lab --no-browser --allow-root --ip=0.0.0.0 --port=8888 --NotebookApp.token="" --notebook-dir=/workspace` 51 | 52 | The `--nv` flag is used to enable `all` NVIDIA GPUs during container runtime. The `--bind` flag maps the working directory on your local machine, `~/ai_profiler/workspace`, to the `/workspace` directory inside the container. 53 | 54 | Then, open the jupyter notebook in a browser: http://localhost:8888 55 | Start working on the lab by clicking on the `start_here.ipynb` notebook. 56 | 57 | 58 | ## Running on Local Machine 59 | 60 | - Install PyTorch [here](https://pytorch.org/get-started/locally/) 61 | - Install essentials: 62 | ``` 63 | pip3 install jupyterlab 64 | pip3 install ipywidgets 65 | pip3 install torch_tb_profiler 66 | ``` 67 | - Install NVIDIA Nsight Systems version 2022.1.1 from [here](https://developer.nvidia.com/nsight-systems) and add its `bin` directory to your `PATH`. Please run `nsys --version` from the terminal to ensure you are using version 2022.1.1 or above. 68 | 69 | 70 | 71 | # Tutorial Duration 72 | 73 | The total bootcamp material takes approximately 2 hours.
74 | -------------------------------------------------------------------------------- /Singularity: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: nvcr.io/nvidia/pytorch:22.04-py3 3 | 4 | %runscript 5 | 6 | "$@" 7 | 8 | %post 9 | 10 | apt-get -y update 11 | pip3 install jupyterlab 12 | pip3 install ipywidgets 13 | pip3 install torch_tb_profiler 14 | 15 | # Read https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772 16 | apt-get update -y && \ 17 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 18 | apt-transport-https \ 19 | ca-certificates \ 20 | gnupg \ 21 | wget && \ 22 | rm -rf /var/lib/apt/lists/* 23 | wget -qO - https://developer.download.nvidia.com/devtools/repos/ubuntu2004/amd64/nvidia.pub | apt-key add - && \ 24 | echo "deb https://developer.download.nvidia.com/devtools/repos/ubuntu2004/amd64/ /" >> /etc/apt/sources.list.d/nsight.list && \ 25 | apt-get update -y && \ 26 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 27 | nsight-systems-2022.1.1 && \ 28 | rm -rf /var/lib/apt/lists/* 29 | 30 | %files 31 | 32 | #../English/* /workspace/ 33 | 34 | %environment 35 | export XDG_RUNTIME_DIR= 36 | export LD_LIBRARY_PATH="/usr/local/lib:/usr/local/lib/python3.8/dist-packages:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" 37 | export PATH="/opt/nvidia/nsight-systems/2022.1.1/bin:/usr/local/bin:/bin:/usr/local/cuda/bin:/usr/bin${PATH:+:${PATH}}" 38 | 39 | %labels 40 | 41 | AUTHOR Tosin 42 | -------------------------------------------------------------------------------- /workspace/jupyter_notebook/01_introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

Home Page

\n", 8 | "\n", 9 | " \n", 10 | "
\n", 11 | " \n", 12 | " 1\n", 13 | " 2\n", 14 | " 3\n", 15 | " 4\n", 16 | " 5\n", 17 | " \n", 18 | " Next Notebook\n", 19 | "
\n" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Part 1: Using NVIDIA® Nsight™ Systems\n", 27 | "---" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "This lab gives an overview of the NVIDIA developer tools and the steps to profile an application. The focus of this lab is to familiarize you with commonly used features of the NVIDIA Nsight Systems graphical user interface (GUI).\n", 35 | "\n", 36 | "\n", 37 | "## What is profiling\n", 38 | "\n", 39 | "Profiling is the first step in optimizing and tuning your application. Profiling an application helps us understand where most of the execution time is spent, providing an understanding of its performance characteristics and identifying parts of the code that present opportunities for improvement. Finding hotspots and bottlenecks in your application can help you decide where to focus your optimization efforts." 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## NVIDIA Developer Tools\n", 47 | "\n", 48 | "NVIDIA developer tools (Nsight Systems, Nsight Compute, Nsight Graphics) are a collection of applications, spanning desktop and mobile targets, which enable developers to build, debug, profile, and develop class-leading and cutting-edge software that utilizes the latest visual computing hardware from NVIDIA.\n", 49 | "\n", 50 | "Your profiling workflow will change to reflect the individual NVIDIA developer tool selected. Start with `Nsight Systems` to get a system-level overview of the workload, eliminate any system-level bottlenecks, such as unnecessary thread synchronization or data movement, and improve the system-level parallelism of your algorithms to scale efficiently across any number or size of central processing units (CPUs) and GPUs.
Once you have done that, proceed to `Nsight Compute` or `Nsight Graphics` to optimize the most significant NVIDIA CUDA® kernels or graphics workloads, respectively. Periodically return to Nsight Systems to ensure that you remain focused on the largest bottleneck; otherwise, the bottleneck may have shifted and your kernel-level optimizations may not achieve as high an improvement as expected.\n", 51 | "\n", 52 | "- Nsight Systems analyzes application algorithms system-wide\n", 53 | "- Nsight Compute debugs and optimizes CUDA kernels\n", 54 | "- Nsight Graphics debugs and optimizes graphics workloads" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "\n", 69 | "\n", 70 | "## Nsight Systems\n", 71 | "\n", 72 | "NVIDIA [Nsight Systems](https://developer.nvidia.com/nsight-systems) offers system-wide performance analysis in order to visualize an application’s algorithms, help identify optimization opportunities, and improve the performance of applications running on a system consisting of multiple CPUs and GPUs.\n", 73 | "\n", 74 | "The typical optimization workflow using NVIDIA Nsight Systems looks like the following:\n", 75 | "\n", 76 | "\n", 77 | "\n", 78 | "It is an iterative process with 3 main steps:\n", 79 | "1. Profile the application\n", 80 | "2. Inspect and analyze the profile to identify any bottlenecks\n", 81 | "3. Optimize the application to address the bottlenecks\n", 82 | "\n" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "### Brief Highlight of NVIDIA Nsight Systems (GUI) Timeline View \n", 90 | "\n", 91 | "This is a profile of the DeepStream reference application .
The main features of the timeline view consist of the following:\n" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "- **CPU view, Application Programming Interfaces (APIs) Traced, Thread Utilization, Core, Thread State, and CPU Sampling Points**\n", 99 | "\n", 100 | "At the top of the timeline view, you can see the CPU view, which shows how the application is utilizing the CPU cores on the system, the processes, and the operating system (OS) threads running through the application. On each OS thread, you can view the `trace` of all the `API` calls made on that thread. NVIDIA Nsight Systems has the ability to trace 20 different APIs such as `CUDA`, `cuDNN`, `cuBLAS`, `NVTX`, and `OS runtime libraries` such as calls to `pThread libraries`, `file I/O`, etc. For each OS thread, you can view the thread's state changes and its migrations across CPU cores.\n", 101 | "\n", 102 | "" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "- **CPU Sampling Summary**\n", 110 | "\n", 111 | "At the bottom of the timeline view, you can see a statistical summary of the CPU sampling data, which helps to quickly identify hot functions (functions that consume the most time) on the CPU. \n", 112 | "\n", 113 | "" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "- **Blocked State Backtrace**\n", 121 | "\n", 122 | "For long-running calls into OS runtime libraries, Nsight Systems captures `Backtraces`, which help identify problematic parts of the code that are causing the threads to block.\n", 123 | "\n", 124 | "" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "\n", 132 | "- **GPU View**\n", 133 | "\n", 134 | "Near the bottom part of the timeline view, you can see how the application is utilizing various GPUs on the system, the `kernels`, and the `memory` operations that transpired on the GPU.
When you hover your mouse over any of these events, a pop-up box will appear to display detailed information about the event.\n", 135 | "\n", 136 | "" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "- **Search within a Timeline Row**\n", 144 | "\n", 145 | "To the right of the `Events View`, you can `search` within the timeline row.\n", 146 | "\n", 147 | "" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "- **CUDA API Backtrace**\n", 155 | "\n", 156 | "For long-running CUDA API calls, you can capture `Backtraces`, which can help identify problematic parts of the code.\n", 157 | "\n", 158 | "" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "The next notebook explains how to start profiling a simple DNN training program. Please click on the `Next Notebook` link at the bottom to get started." 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "## Acknowledgments\n", 173 | "\n", 174 | "Images used in this notebook were extracted from `Sneha Kottapalli's`, April 16, 2021 presentation on `NVIDIA NSIGHT SYSTEMS` at the GTC.\n", 175 | "\n", 176 | "\n", 177 | "## Links and Resources\n", 178 | "\n", 179 | "\n", 180 | "[NVIDIA Nsight Systems](https://docs.nvidia.com/nsight-systems/)\n", 181 | "\n", 182 | "\n", 183 | "**NOTE**: To be able to see the profiler output, please download the latest version of NVIDIA Nsight Systems from [here](https://developer.nvidia.com/nsight-systems).\n", 184 | "\n", 185 | "\n", 186 | "You can also get resources from [Open Hackathons technical resource page](https://www.openhackathons.org/s/technical-resources)\n", 187 | "\n", 188 | "\n", 189 | "--- \n", 190 | "\n", 191 | "## Licensing \n", 192 | "\n", 193 | "Copyright © 2022 OpenACC-Standard.org.
This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply." 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "
\n", 201 | " \n", 202 | " 1\n", 203 | " 2\n", 204 | " 3\n", 205 | " 4\n", 206 | " 5\n", 207 | " \n", 208 | " Next Notebook\n", 209 | "
\n", 210 | "\n", 211 | "
\n", 212 | "

Home Page

" 213 | ] 214 | } 215 | ], 216 | "metadata": { 217 | "kernelspec": { 218 | "display_name": "Python 3", 219 | "language": "python", 220 | "name": "python3" 221 | }, 222 | "language_info": { 223 | "codemirror_mode": { 224 | "name": "ipython", 225 | "version": 3 226 | }, 227 | "file_extension": ".py", 228 | "mimetype": "text/x-python", 229 | "name": "python", 230 | "nbconvert_exporter": "python", 231 | "pygments_lexer": "ipython3", 232 | "version": "3.8.8" 233 | } 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 4 237 | } 238 | -------------------------------------------------------------------------------- /workspace/jupyter_notebook/03_data_transfer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

Home Page

\n", 8 | "\n", 9 | " \n", 10 | "
\n", 11 | " \n", 12 | " 1\n", 13 | " 2\n", 14 | " 3\n", 15 | " 4\n", 16 | " 5\n", 17 | " \n", 18 | " Next Notebook\n", 19 | "
" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Part 1: Data Transfers Between Host (CPU) and Device (GPU)\n", 27 | "---" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "The objective of this notebook is to optimize data transfer between Host and Device. \n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Analyze the Report\n", 42 | "\n", 43 | "Let's analyze the data transfers between host and graphics processing unit (GPU) in the report `firstOptimization.nsys-rep` from the first optimization step. Open the report in the NVIDIA® Nsight™ Systems graphical user interface (GUI). Expand the `NVIDIA CUDA® device row` by clicking on the tiny triangle in front of it. Select the `Memory` row and right-click to choose `Show in Events View` option as shown below.\n", 44 | "\n", 45 | "\n" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "This populates the `Events View` window with the memory operations listed in chronological order. Click on the `Duration` column header to sort the table in the Events View by duration so that the longest memory operation shows up first. Right-click on the first entry in the table and select \"Show Current on Timeline\" as illustrated below.\n", 53 | "\n", 54 | "\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "This zooms into the event on the timeline and the teal highlights help you find the CUDA API call, `cudaMemcpyAsync`, that initiated the memory operation on the GPU (see the image below). 
Note: You may have to zoom out and/or scroll up to find the CUDA API call on the CPU thread.\n", 62 | "\n", 63 | "" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "\n", 71 | "You notice the following from the timeline:\n", 72 | "- All Host-to-Device (HtoD) memcopies are using pageable memory which is:\n", 73 | " - slower and, \n", 74 | " - causes the `cudaMemcpyAsync` API call on the CPU thread to block until the operation completes on the GPU.\n", 75 | "- The longest memcpy operation takes ~385 microseconds to complete on the GPU.\n", 76 | "- The CUDA API call (`cudaMemcpyAsync`) corresponding to the longest memcpy operation is almost 0.5ms long." 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## Optimize the Application to Use Pinned Memory\n", 84 | "\n", 85 | "Host (CPU) memory allocations are pageable by default. The GPU cannot access data directly from pageable host memory. When a data transfer is invoked from pageable host memory to device memory, the CUDA driver must first allocate a temporary page-locked (or “pinned”) host array, copy the host data to the pinned array, and then transfer the data from the pinned array to the device memory. The pinned memory is used as a staging area for transfers from the host to the device. By directly allocating our host data to pinned memory, we can avoid this extra step and its overhead. See the blog [post](https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/) for more details.\n", 86 | "\n", 87 | "\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "The settings used for the data loader [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) in our application relies on the default value of `pin_memory: False`. Execute the cell below to see the code change made `(in green color)` to use pinned memory." 
95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 1, 100 | "metadata": { 101 | "scrolled": true, 102 | "tags": [] 103 | }, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "\u001b[1m--- ../source_code/main_opt1.py\t2022-08-18 03:51:54.178624632 +0900\u001b[0m\n", 110 | "\u001b[1m+++ ../source_code/main_opt2.py\t2022-08-19 23:25:01.750840799 +0900\u001b[0m\n", 111 | "\u001b[36m@@ -160,8 +160,9 @@\u001b[0m\n", 112 | " test_kwargs = {'batch_size': args.test_batch_size}\n", 113 | " if use_cuda:\n", 114 | " #multiprocessing.cpu_count()\n", 115 | " cuda_kwargs = {'num_workers': 2,\n", 116 | "\u001b[32m+ 'pin_memory': True,\u001b[0m\n", 117 | " 'shuffle': True}\n", 118 | " train_kwargs.update(cuda_kwargs)\n", 119 | " test_kwargs.update(cuda_kwargs)\n", 120 | " \n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "!diff -U4 --color=always ../source_code/main_opt1.py ../source_code/main_opt2.py" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "## Profile Again to Verify Optimization\n", 133 | "Profile again by executing the cell given below to verify if the code change addresses the problem with host-to-device memory transfers after setting `pin_memory: True`." 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "scrolled": true, 141 | "tags": [] 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "!nsys profile --trace cuda,osrt,nvtx \\\n", 146 | "--capture-range cudaProfilerApi \\\n", 147 | "--gpu-metrics-device=all \\\n", 148 | "--output ../reports/secondOptimization \\\n", 149 | "--force-overwrite true \\\n", 150 | "python3 ../source_code/main_opt2.py" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "Open the report (secondOptimization.nsys-rep) in the GUI. 
Similar to how we navigated the timeline previously, expand the `CUDA device` row, select the `Memory` row, and right-click to choose `Show in Events View`. Sort the table in the `Events View` by duration to list the longest memory operation first. Right-click on the topmost event to select `Show current on timeline`. You should see the view as shown below.\n", 158 | "\n", 159 | "\n" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "In the profile collected after optimization, we observed that:\n", 167 | "- All HtoD memcopies now use pinned memory,\n", 168 | "- The longest memcpy is now only `183µs` compared to `~385µs` before optimization, and\n", 169 | "- The `cudaMemcpyAsync` API call corresponding to the longest memcpy is now reduced from `490µs` to `~36µs`.\n", 170 | "\n", 171 | "Now that we have addressed a bottleneck with memory transfers, let's identify the next performance bottleneck by clicking on the `Next Notebook` link below." 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "## Links and Resources\n", 179 | "\n", 180 | "\n", 181 | "[NVIDIA Nsight Systems](https://docs.nvidia.com/nsight-systems/)\n", 182 | "\n", 183 | "\n", 184 | "**NOTE**: To be able to see the profiler output, please download the latest version of NVIDIA Nsight Systems from [here](https://developer.nvidia.com/nsight-systems).\n", 185 | "\n", 186 | "\n", 187 | "You can also get resources from [Open Hackathons technical resource page](https://www.openhackathons.org/s/technical-resources)\n", 188 | "\n", 189 | "\n", 190 | "--- \n", 191 | "\n", 192 | "## Licensing \n", 193 | "\n", 194 | "Copyright © 2022 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0).
These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply." 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "
\n", 202 | " \n", 203 | " 1\n", 204 | " 2\n", 205 | " 3\n", 206 | " 4\n", 207 | " 5\n", 208 | " \n", 209 | " Next Notebook\n", 210 | "
\n", 211 | "\n", 212 | "
\n", 213 | "

Home Page

" 214 | ] 215 | } 216 | ], 217 | "metadata": { 218 | "kernelspec": { 219 | "display_name": "Python 3", 220 | "language": "python", 221 | "name": "python3" 222 | }, 223 | "language_info": { 224 | "codemirror_mode": { 225 | "name": "ipython", 226 | "version": 3 227 | }, 228 | "file_extension": ".py", 229 | "mimetype": "text/x-python", 230 | "name": "python", 231 | "nbconvert_exporter": "python", 232 | "pygments_lexer": "ipython3", 233 | "version": "3.8.8" 234 | } 235 | }, 236 | "nbformat": 4, 237 | "nbformat_minor": 4 238 | } 239 | -------------------------------------------------------------------------------- /workspace/jupyter_notebook/04_tensor_core_util.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

Home Page

\n", 8 | "\n", 9 | " \n", 10 | "
\n", 11 | " \n", 12 | " 1\n", 13 | " 2\n", 14 | " 3\n", 15 | " 4\n", 16 | " 5\n", 17 | " \n", 18 | " Next Notebook\n", 19 | "
" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Part 1: Tensor Cores \n", 27 | "---" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "The goal of this notebook is to show how to enable mixed precision (FP32/FP16) on Tensor Cores to further optimize our application.\n", 35 | "\n", 36 | "## Tensor Core Usage\n", 37 | "\n", 38 | "Tensor Cores are specialized processing units designed to accelerate tensor/matrix multiplication. Tensor Cores enable mixed-precision computing, dynamically adapting calculations to accelerate throughput while preserving accuracy. Our application runs on an `NVIDIA® A100` GPU (Ampere architecture) in an `NVIDIA® DGX™ A100` system. You can also run the application on other GPU architectures, for example, the `NVIDIA Turing™ architecture`, which also has Tensor Cores.\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "
" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "The screenshot below shows a table of NVIDIA GPU architectures and supported Tensor Core precisions." 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "
\n", 60 | "
Source: NVIDIA website
" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "## Analyze the Profile Report\n", 68 | "To verify whether the application uses Tensor Cores, we will use a new feature in NVIDIA Nsight™ Systems: **GPU performance metrics sampling**. Note that in the previous notebook, when profiling the application after the second optimization, we used the Nsight Systems `--gpu-metrics-device=all` CLI option. This option enables GPU metrics sampling, which is intended to measure the utilization of different GPU subsystems. Hardware counters within the GPU are periodically read and used to generate performance metrics." 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "Let's analyze the application's Tensor Core usage by examining the report `(secondOptimization.nsys-rep)` in the Nsight Systems GUI. Scroll down to the bottom of the timeline until you see the rows for GPU metrics. Expand the `SM instructions` timeline row to see the `Tensor Active` row, which represents the ratio of `cycles the SM tensor pipes or FP16x2 pipes were active issuing tensor instructions` to `the number of cycles in the sample period` as a percentage." 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "As shown in the screenshot above, the percentage graph shows an `average of 5.7%` and a `maximum of 45%`, so the application already uses the Tensor Cores on the A100 GPU. However, this is not the case for other architectures. For example, when examining secondOptimization.nsys-rep collected on an NVIDIA Turing™ GPU, the percentage graph is zero in the `Tensor Active/FP16 Active` row. Therefore, Tensor Core utilization has to be explicitly enabled using, for example, `automatic mixed precision (AMP)`."
90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "\n", 97 | "\n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "## Automatic Mixed Precision (AMP)\n", 105 | "\n", 106 | "Mixed precision is the combined use of different numerical formats `(single and half-precision computation)` in the training of a deep neural network.\n", 107 | "- single precision: FP32\n", 108 | "- half precision: FP16\n", 109 | "\n", 110 | "Mixed precision is supported on NVIDIA GPU architectures such as `Ampere`, `Volta™`, and `Turing`. The benefits include:\n", 111 | "\n", 112 | "- faster math-intensive operations using Tensor Cores,\n", 113 | "- lower memory bandwidth requirements, which speeds up data transfer operations, and\n", 114 | "- lower memory requirements, which makes it possible to train and deploy larger neural networks.\n", 115 | "\n", 116 | "AMP automates mixed-precision training in deep neural network (DNN) frameworks. PyTorch has the [Automatic Mixed Precision (AMP)](https://pytorch.org/docs/stable/amp.html) package, which provides a simple way for users to convert existing FP32 training scripts to mixed FP32 and FP16 precision. This unlocks faster computation with Tensor Cores on NVIDIA GPUs. The screenshot below highlights `(in a green frame)` the code changes made to use the AMP package in PyTorch." 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "Profile the application again and verify that the code change enables Tensor Core usage on the Turing GPU.
" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": { 137 | "scrolled": true, 138 | "tags": [] 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "!nsys profile --trace cuda,osrt,nvtx \\\n", 143 | "--capture-range cudaProfilerApi \\\n", 144 | "--gpu-metrics-device=all \\\n", 145 | "--output ../reports/thirdOptimization_env \\\n", 146 | "--force-overwrite true \\\n", 147 | "python3 ../source_code/main_opt3.py" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "Open the report (thirdOptimization.nsys-rep) in the GUI. Scroll down to view the `Tensor Active / FP16 Active` timeline row.\n", 155 | "\n", 156 | "\n", 157 | "\n", 158 | "Now, we can see the Tensor Cores usage on the Turing GPU. Note that the main contribution of AMP is that it reduces the kernel time using Tensor Cores thereby achieving a speedup.\n", 159 | "\n", 160 | "## Compare the Performance Before and After the Optimizations\n", 161 | "Now that we have addressed three different performance problems, we will time the application [main_opt3.py](../source_code/main_opt3.py)." 
162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": { 168 | "scrolled": true, 169 | "tags": [] 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "!cd ../source_code && time python3 main_opt3.py" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "**Expected output on A100 GPUs**:\n", 181 | "\n", 182 | "```python\n", 183 | "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.308961\n", 184 | "\n", 185 | "Test set: Average loss: 0.1024, Accuracy: 9683/10000 (97%)\n", 186 | "\n", 187 | "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.154755\n", 188 | "\n", 189 | "Test set: Average loss: 0.0608, Accuracy: 9814/10000 (98%)\n", 190 | "\n", 191 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.110753\n", 192 | "\n", 193 | "Test set: Average loss: 0.0535, Accuracy: 9827/10000 (98%)\n", 194 | "\n", 195 | "----------------------------------------------------------\n", 196 | "\n", 197 | "Train Epoch: 9 [0/60000 (0%)]\tLoss: 0.059646\n", 198 | "\n", 199 | "Test set: Average loss: 0.0370, Accuracy: 9865/10000 (99%)\n", 200 | "\n", 201 | "Train Epoch: 10 [0/60000 (0%)]\tLoss: 0.055018\n", 202 | "\n", 203 | "Test set: Average loss: 0.0365, Accuracy: 9865/10000 (99%)\n", 204 | "\n", 205 | "\n", 206 | "real\t1m24.619s\n", 207 | "user\t2m35.069s\n", 208 | "sys\t 0m7.676s\n", 209 | "\n", 210 | "```\n" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "Comparing the time taken to run our baseline code [main_baseline.py](../source_code/main_baseline.py) from [notebook 2](02_pytorch_mnist.ipynb) with the time taken by the code after applying all three optimizations, [main_opt3.py](../source_code/main_opt3.py), we see that the overall run time has been reduced, as shown in the table below.\n", 218 | "\n", 219 | "\n", 220 | "|Training code|Time|Speedup|\n", 221 | "|--|--|--|\n", 222 | "|baseline|113s|-|\n", 223 | "|optimized|~85s|1.3x|\n" 224 | ] 225 | }, 226 | { 227 |
"cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "## Links and Resources\n", 231 | "\n", 232 | "\n", 233 | "[NVIDIA Nsight Systems](https://docs.nvidia.com/nsight-systems/)\n", 234 | "\n", 235 | "\n", 236 | "**NOTE**: To be able to see the profiler output, please download the latest version of NVIDIA Nsight Systems from [here](https://developer.nvidia.com/nsight-systems).\n", 237 | "\n", 238 | "\n", 239 | "You can also find more resources on the [Open Hackathons technical resource page](https://www.openhackathons.org/s/technical-resources).\n", 240 | "\n", 241 | "\n", 242 | "--- \n", 243 | "\n", 244 | "## Licensing \n", 245 | "\n", 246 | "Copyright © 2022 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply." 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "
\n", 254 | " \n", 255 | " 1\n", 256 | " 2\n", 257 | " 3\n", 258 | " 4\n", 259 | " 5\n", 260 | " \n", 261 | " Next Notebook\n", 262 | "
\n", 263 | "\n", 264 | "
\n", 265 | "

Home Page

" 266 | ] 267 | } 268 | ], 269 | "metadata": { 270 | "kernelspec": { 271 | "display_name": "Python 3", 272 | "language": "python", 273 | "name": "python3" 274 | }, 275 | "language_info": { 276 | "codemirror_mode": { 277 | "name": "ipython", 278 | "version": 3 279 | }, 280 | "file_extension": ".py", 281 | "mimetype": "text/x-python", 282 | "name": "python", 283 | "nbconvert_exporter": "python", 284 | "pygments_lexer": "ipython3", 285 | "version": "3.8.8" 286 | } 287 | }, 288 | "nbformat": 4, 289 | "nbformat_minor": 4 290 | } 291 | -------------------------------------------------------------------------------- /workspace/jupyter_notebook/05_summary.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

Home Page

\n", 8 | "\n", 9 | " \n", 10 | "
\n", 11 | " \n", 12 | " 1\n", 13 | " 2\n", 14 | " 3\n", 15 | " 4\n", 16 | " 5\n", 17 | " \n", 18 | "
" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "Part 1 SUMMARY\n", 26 | "---\n", 27 | "\n", 28 | "In the previous notebooks, you learned how to:\n", 29 | "\n", 30 | "- Profile a sample application using NVIDIA® Nsight™ Systems CLI commands. \n", 31 | "- Apply NVIDIA Tools Extension SDK (NVTX) annotations within the application code.\n", 32 | "- Interpret the profile report on the timeline provided by NVIDIA Nsight Systems.\n", 33 | "- Identify performance problems, apply optimization strategies, and confirm the performance improvement gained.\n", 34 | "- Enable automatic mixed precision (AMP) and track Tensor Core usage.\n", 35 | "\n", 36 | "Going forward, we hope you apply the techniques you learned to optimize your own applications.\n", 37 | "\n", 38 | "Run the cell below to learn more about Nsight Systems CLI commands from the NVIDIA Nsight Systems 2022.2.1 user guide. " 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 1, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "text/html": [ 49 | "\n", 50 | " \n", 57 | " " 58 | ], 59 | "text/plain": [ 60 | "" 61 | ] 62 | }, 63 | "execution_count": 1, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "from IPython.display import IFrame\n", 70 | "IFrame(\"../source_code/Nsight_Systems_User_Guide_2022.2.1.31-5fe97ab.pdf\", 960,900)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "---\n", 78 | "Please click the `Start Part 2` link below to get started with the PyTorch Profiler with TensorBoard section. Note that this part is optional and not required for evaluation; it is provided for learning purposes only. \n", 79 | "\n", 80 | "##

[Start Part 2](tb01_introduction.ipynb) (optional)
\n", 81 | "\n", 82 | "---" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "\n", 90 | "## Acknowledgement\n", 91 | "\n", 92 | "**This material is adapted from:**\n", 93 | "- GTC 2021 Hands-on labs, Nsight Developer Tools Training Contents (https://github.com/NVIDIA/nsight-training) licensed under BSD 3-Clause \n", 94 | "- HPC nways on Nsight Systems (https://github.com/openhackathons-org/gpubootcamp/tree/master/hpc) licensed under the Creative Commons Attribution 4.0 International (CC BY 4.0)\n", 95 | "\n", 96 | "## Links and Resources\n", 97 | "\n", 98 | "If you are interested in learning more about the tool, here are some resources:\n", 99 | "- Download Nsight Systems for free from http://developer.nvidia.com/nsight-systems\n", 100 | "- Documentation is at https://docs.nvidia.com/nsight-systems/index.html\n", 101 | "- Blog posts: https://developer.nvidia.com/blog/tag/nsight-systems/\n", 102 | "\n", 103 | "\n", 104 | "[NVIDIA Nsight Systems](https://docs.nvidia.com/nsight-systems/)\n", 105 | "\n", 106 | "\n", 107 | "**NOTE**: To be able to see the profiler output, please download the latest version of NVIDIA Nsight Systems from [here](https://developer.nvidia.com/nsight-systems).\n", 108 | "\n", 109 | "\n", 110 | "You can also get resources from [Open Hackathons technical resource page](https://www.openhackathons.org/s/technical-resources)\n", 111 | "\n", 112 | "\n", 113 | "--- \n", 114 | "\n", 115 | "## Licensing \n", 116 | "\n", 117 | "Copyright © 2022 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply." 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "
\n", 125 | " \n", 126 | " 1\n", 127 | " 2\n", 128 | " 3\n", 129 | " 4\n", 130 | " 5\n", 131 | " \n", 132 | "
\n", 133 | "\n", 134 | "
\n", 135 | "

Home Page

" 136 | ] 137 | } 138 | ], 139 | "metadata": { 140 | "kernelspec": { 141 | "display_name": "Python 3", 142 | "language": "python", 143 | "name": "python3" 144 | }, 145 | "language_info": { 146 | "codemirror_mode": { 147 | "name": "ipython", 148 | "version": 3 149 | }, 150 | "file_extension": ".py", 151 | "mimetype": "text/x-python", 152 | "name": "python", 153 | "nbconvert_exporter": "python", 154 | "pygments_lexer": "ipython3", 155 | "version": "3.8.8" 156 | } 157 | }, 158 | "nbformat": 4, 159 | "nbformat_minor": 4 160 | } 161 | -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/Baseline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/Baseline.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/Built-in_NVTX.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/Built-in_NVTX.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/NVTX_annotations.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/NVTX_annotations.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/NumberOfWorkers.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/NumberOfWorkers.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/Optimization2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/Optimization2.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/Optimization3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/Optimization3.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/Optimization_workflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/Optimization_workflow.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/PageableVsPinned.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/PageableVsPinned.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/ShowInEventsView.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/ShowInEventsView.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/StarvationDuringDataLoading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/StarvationDuringDataLoading.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/TensorCoreUsage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/TensorCoreUsage.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/Training_program.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/Training_program.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/amp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/amp.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/architecture_tensor_cores.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/architecture_tensor_cores.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/browse_port.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/browse_port.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/cudaProfilerApi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/cudaProfilerApi.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/cudaprofilestart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/cudaprofilestart.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/fp16_opt22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/fp16_opt22.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/fp16_opt33.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/fp16_opt33.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/gpu_kernel_view.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/gpu_kernel_view.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/memory_view_menu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/memory_view_menu.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/memory_view_opt11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/memory_view_opt11.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/memory_view_opt22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/memory_view_opt22.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/mnist.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/mnist.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/module_view.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/module_view.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/module_view_opt1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/module_view_opt1.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nsight_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nsight_flow.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nsys_CUDA_api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nsys_CUDA_api.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nsys_GPU_view.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nsys_GPU_view.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nsys_blocked_state.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nsys_blocked_state.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nsys_blocked_state_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nsys_blocked_state_1.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nsys_cpu_api_thread.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nsys_cpu_api_thread.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nsys_cpu_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nsys_cpu_summary.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nsys_search.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nsys_search.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nsys_timeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nsys_timeline.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/nvtx_annotation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/nvtx_annotation.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/operator_view.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/operator_view.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/operator_view_code_trace.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/operator_view_code_trace.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/operator_view_tbl.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/operator_view_tbl.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/performance_recommendation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/performance_recommendation.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/performance_recommendation_opt1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/performance_recommendation_opt1.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/performance_recommendation_opt2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/performance_recommendation_opt2.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/port_forwarding.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/port_forwarding.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/profile_summary.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/profile_summary.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/profile_summary_opt1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/profile_summary_opt1.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/profile_summary_opt11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/profile_summary_opt11.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/profile_summary_opt11_memcpy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/profile_summary_opt11_memcpy.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/profile_summary_opt2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/profile_summary_opt2.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/profile_summary_opt22.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/profile_summary_opt22.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/profile_summary_opt22_tensor_core.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/profile_summary_opt22_tensor_core.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/profile_summary_opt33.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/profile_summary_opt33.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_02_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_02_1.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_02_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_02_2.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_02_3.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_02_3.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_02_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_02_4.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_activate_tensor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_activate_tensor.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_api_call.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_api_call.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_baseline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_baseline.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_final_time.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_final_time.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_firstoptimization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_firstoptimization.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_gpu_starvation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_gpu_starvation.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_kernel_coverage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_kernel_coverage.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_nvtx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_nvtx.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_pinned_memory.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_pinned_memory.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_show_current_timeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_show_current_timeline.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/report_show_in_events_view.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/report_show_in_events_view.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/side_menu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/side_menu.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/side_menu_list.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/side_menu_list.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/step_time_breakdown.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/step_time_breakdown.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/step_time_breakdown_opt1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/step_time_breakdown_opt1.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/step_time_breakdown_opt11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/step_time_breakdown_opt11.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/step_time_breakdown_opt2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/step_time_breakdown_opt2.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/step_time_breakdown_opt33.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/step_time_breakdown_opt33.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/tb_kernel_view.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/tb_kernel_view.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/tb_main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/tb_main.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/tb_main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/tb_main.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/tb_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/tb_overview.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/tb_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/tb_train.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/tensor_core_util_opt22.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/tensor_core_util_opt22.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/tensor_cores.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/tensor_cores.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/trace_view.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/trace_view.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/trace_view1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/trace_view1.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/trace_view_opt1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/trace_view_opt1.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/trace_view_opt11.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/trace_view_opt11.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/trace_view_opt2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/trace_view_opt2.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/trace_view_opt33.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/trace_view_opt33.jpg -------------------------------------------------------------------------------- /workspace/jupyter_notebook/images/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/jupyter_notebook/images/train.png -------------------------------------------------------------------------------- /workspace/jupyter_notebook/tb01_introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

Home Page

\n", 8 | "\n", 9 | " \n", 10 | "
\n", 11 | " \n", 12 | " 1\n", 13 | " 2\n", 14 | " 3\n", 15 | " 4\n", 16 | " 5\n", 17 | " \n", 18 | " Next Notebook\n", 19 | "
\n" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Part 2: Using PyTorch Profiler with TensorBoard plugin\n", 27 | "---" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "This section is an overview of the PyTorch Profiler. The goal is to familiarize you with commonly used features of the TensorBoard visualization toolkit. To get started, we will recap what profiling is, as covered in Part 1 of this lab.\n", 35 | "\n", 36 | "## What is profiling?\n", 37 | "\n", 38 | "Profiling is the first step in optimizing and tuning your application. Profiling an application helps you understand where most of the execution time is spent, providing insight into performance characteristics and identifying parts of the code that present opportunities for improvement. Finding hotspots and bottlenecks in your application can help you decide where to focus your optimization efforts." 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "## PyTorch Profiler\n", 46 | "\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "The PyTorch Profiler tool enables the profiling of deep neural network (DNN) training programs through the collection of performance metrics that include execution time, memory costs, stack traces, device kernels, and more. This is done through the profiler's context manager application programming interface (API). The PyTorch Profiler API assists in identifying the most expensive operators, i.e., the hot spots within the application. 
The Profiler also supports multithreaded models.\n", 54 | "\n", 55 | "To profile the application using PyTorch Profiler, you must import the essential modules as shown below: \n", 56 | "\n", 57 | "```python\n", 58 | "from torch.profiler import profile, record_function, ProfilerActivity\n", 59 | "from torch.profiler import schedule\n", 60 | "```\n" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "You can achieve the following tasks with the PyTorch profiler:\n", 68 | "\n", 69 | "- **Analyze Execution Time**\n", 70 | "\n", 71 | "The context manager activates the PyTorch Profiler to measure the execution time of host `(CPU)` and device kernel `(CUDA)` activities, selected with `ProfilerActivity`. With `record_shapes=True`, it also records the shapes of the operator inputs, and `record_function` labels a user-defined code range (here, `running_model`) in the trace. Below are the lines of code needed to achieve this.\n", 72 | "\n", 73 | "\n", 74 | "```python\n", 75 | "\n", 76 | "with profile(activities=[ ProfilerActivity.CUDA, ProfilerActivity.CPU], \n", 77 | " record_shapes=True) as p:\n", 78 | " with record_function(\"running_model\"):\n", 79 | " model(data)\n", 80 | "```\n", 81 | "\n", 82 | "\n", 83 | "- **Analyze Memory Consumption and Examine Stack Traces**\n", 84 | "\n", 85 | "\n", 86 | "The PyTorch Profiler reveals the amount of memory allocated or released by the model’s tensors during the execution of the model’s operators. In addition, it can record the call stacks (source file and line number) of the operators. 
To capture these, you must include `profile_memory=True` and `with_stack=True` in the `profile` call as shown below:\n", 87 | "\n", 88 | "```python\n", 89 | "with profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], \n", 90 | " record_shapes=True, profile_memory=True, with_stack=True) as p:\n", 91 | " model(data)\n", 92 | "```\n" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "- **Analyze Long-running Jobs**\n", 100 | "\n", 101 | "Long-running jobs refer to model training that runs within loops (step or batch iterations). In this case, profiling every step can take a long time and generate a very large trace file. Therefore, PyTorch Profiler offers the `torch.profiler.schedule` API to specify which steps are traced and when profiling results are returned.\n", 102 | "\n", 103 | "```python\n", 104 | "my_schedule = schedule(skip_first=5, wait=1, warmup=1, active=3)\n", 105 | "```\n", 106 | "In the above statement, `skip_first=5` tells the profiler to ignore the first 5 steps, `wait=1` means the profiler is idle for 1 step, `warmup=1` means the profiler starts tracing for 1 step as a warm-up but discards the results because of the overhead incurred at the beginning of tracing, and `active=3` tells the profiler to actively trace and collect performance metrics for the next 3 steps. We can combine the schedule with the profiling statement as given below:\n", 107 | "\n", 108 | "```python\n", 109 | "with profile( activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], \n", 110 | " schedule=torch.profiler.schedule(wait=1, warmup=1,active=2), \n", 111 | " on_trace_ready=trace_handler) as p:\n", 112 | " for step, data in enumerate(train_loader):\n", 113 | " train(data)\n", 114 | " p.step() \n", 115 | "```\n", 116 | "Calling `p.step()` at the end of each iteration tells the profiler that a step has finished, and `on_trace_ready` takes a user-defined callback (here, `trace_handler`) that is invoked each time a trace is ready. For more information, please visit the [PyTorch Profiler page](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)."
117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## TensorBoard" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "TensorBoard is TensorFlow's visualization toolkit; it provides the tooling needed for tracking and visualizing performance metrics in a machine learning workflow. Through the TensorBoard plugin, the PyTorch Profiler trace can be visualized and analyzed in TensorBoard. The TensorBoard visualization can be viewed at the `http://localhost:6006/#pytorch_profiler` URL in the Google Chrome or Microsoft Edge browser.\n", 131 | "\n", 132 | "\n", 133 | "\n", 134 | "Further details will be provided in the next notebook.\n" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "### Brief Highlight of TensorBoard visualization features\n", 142 | "\n", 143 | "The main features include:\n", 144 | "\n", 145 | "- Overview\n", 146 | "- Operator\n", 147 | "- GPU Kernel\n", 148 | "- Trace\n", 149 | "- Module\n", 150 | "\n", 151 | "\n" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "- **Overview**\n", 159 | "\n", 160 | "The `Overview` view gives a summary of the training model's performance metrics, including `Configuration`, `GPU Summary`, `Execution Summary`, `Step Time Breakdown`, and `Performance Recommendation`. In the `Configuration` panel, you can see the `Number of Worker(s)` and the `Device Type`, which indicates the device used, such as a GPU. GPU features like GPU index and name, compute capability, utilization, estimated SM (streaming multiprocessor) efficiency, estimated achieved occupancy, and kernel time using Tensor Cores are found on the `GPU Summary` panel. The `Execution Summary` displays a statistical summary of time duration and percentage of use for important processes like Kernel, Memcpy, Memset, Runtime, CPU execution, and others. 
Meanwhile, the `Step Time Breakdown` shows the time spent in each of the processes from the `Execution Summary` panel for every step, as a stacked bar graph in microseconds (μs). In the `Performance Recommendation` panel, you will find suggestions that help identify bottlenecks as well as steps for resolving them and improving the performance of the profiled code." 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "\n", 168 | "
view source here
" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "- **Operator**\n", 176 | "\n", 177 | "The `Operator` view provides a graphical representation of the performance of PyTorch operators executed on the CPU and GPU. The time statistics (in μs) and percentages on the host and device are shown below. At the bottom of the `Operator` view, a table lists the statistics displayed in graphical form above. In the right-most column of the table, you will see `View CallStack`, which displays the `View call frames` link. This provides links to the line numbers in your code that trigger the PyTorch operators. Clicking on a line number in the call stack opens that section of the code in VSCode (if installed); otherwise, you have to manually open the code and search for the line number within your preferred IDE (Integrated Development Environment)." 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "- **GPU Kernel**\n", 206 | "\n", 207 | "In the `Views` selection dropdown from the `NORMAL` panel on the left side of the visualization, select `GPU Kernel` to display the `Kernel` view. This displays the percentage statistics of all kernels launched in your training model. More importantly, it shows the percentage statistics of Tensor Core usage (kernels using vs. not using Tensor Cores). At the bottom of the graphical chart is a table that lists the kernels running within the application." 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "\n", 215 | "
View source here
\n" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "\n", 223 | "- **Trace**\n", 224 | "\n", 225 | "The `Trace` view shows a timeline of profiled operators for each thread, including the code line numbers. The GPU utilization and estimated SM efficiency can also be seen on the `Trace` view timeline. You can zoom in and out of the timeline using the arrow toolbar on the right side of the timeline view." 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": {}, 231 | "source": [ 232 | "" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "- **Module**\n", 240 | "\n", 241 | "The `Module` view shows your application summary, displaying the name of each module/class running within your application model, the number of `occurrences` (the number of times it was called) and `operators`, and the total amount of time spent on the CPU and GPU.\n", 242 | "\n", 243 | "" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "We are now ready to start profiling a simple DNN training program. In the rest of the notebooks, the expression `DNN training program` will be used interchangeably with the word 'application'. Please click on the `Next Notebook` link at the bottom to get started." 
251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "## Links and Resources\n", 258 | "\n", 259 | "\n", 260 | "[NVIDIA Nsight™ Systems](https://docs.nvidia.com/nsight-systems/)\n", 261 | "\n", 262 | "\n", 263 | "**NOTE**: To be able to see the profiler output, please download the latest version of NVIDIA Nsight Systems from [here](https://developer.nvidia.com/nsight-systems).\n", 264 | "\n", 265 | "\n", 266 | "You can also get resources from [Open Hackathons technical resource page](https://www.openhackathons.org/s/technical-resources)\n", 267 | "\n", 268 | "## References\n", 269 | "\n", 270 | "- https://pytorch.org/tutorials/beginner/profiler.html\n", 271 | "- https://pytorch.org/docs/stable/profiler.html\n", 272 | "- https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html\n", 273 | "- https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html\n", 274 | "- https://www.tensorflow.org/tensorboard/get_started\n", 275 | "---\n", 276 | "## Licensing \n", 277 | "\n", 278 | "Copyright © 2022 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply." 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": {}, 284 | "source": [ 285 | "
\n", 286 | " \n", 287 | " 1\n", 288 | " 2\n", 289 | " 3\n", 290 | " 4\n", 291 | " 5\n", 292 | " \n", 293 | " Next Notebook\n", 294 | "
\n", 295 | "\n", 296 | "
\n", 297 | "

Home Page

" 298 | ] 299 | } 300 | ], 301 | "metadata": { 302 | "kernelspec": { 303 | "display_name": "Python 3", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.8.8" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 4 322 | } 323 | -------------------------------------------------------------------------------- /workspace/jupyter_notebook/tb03_data_transfer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

Home Page

\n", 8 | "\n", 9 | " \n", 10 | "
\n", 11 | " \n", 12 | " 1\n", 13 | " 2\n", 14 | " 3\n", 15 | " 4\n", 16 | " 5\n", 17 | " \n", 18 | " Next Notebook\n", 19 | "
" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Part 2: Memory Operations\n", 27 | "---" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "The goal of this notebook is to optimize memory operations that reflect data transfers between the host and the device. " 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Analyze the Previous Profile Trace \n", 42 | "Let's analyze the data transfers between the host and device from the first optimization step in the previous notebook. In the screenshot below, you can see the time taken by `Memcpy`. The first goal is to reduce this time, but before that can happen, we need to investigate what is happening in the device's memory." 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "To see the memory view, you must ensure that `profile_memory=True` is set in the PyTorch Profiler. The `Memory` view records and displays all memory allocations and reservations on the GPU and CPU in a curve graph. The view also shows the memory events and statistics in tabular form. First, on the left side of `TensorBoard` under `NORMAL`, you will see the “Views” dropdown selection box. Please select `Memory` as shown below." 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "Since we are interested in GPU memory usage, select GPU under `Device` at the top left corner of the `Memory` view panel. In the memory curve graph screenshot below, we see almost zero memory usage on the time axis from 8ms to 35ms. 
This implies that only a small amount of GPU memory is being utilized, and that host-to-device `memcpy` operations are slowing down data transfer. The remedy is to enable pinned (page-locked) host memory, which speeds up data transfers from the host to the GPU, as explained below." 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "## Enable Pinned Memory\n", 86 | "\n", 87 | "Host (CPU) memory allocations are pageable by default. The GPU cannot access data directly from pageable host memory. When a data transfer is invoked from pageable host memory to device memory, the NVIDIA® CUDA® driver must first allocate a temporary page-locked (or “pinned”) host array, copy the host data to the pinned array, and then transfer the data from the pinned array to the device memory. The pinned memory is used as a staging area for transfers from the host to the device. By directly allocating our host data to pinned memory, we can avoid this extra step and its overhead. See the following blog [post](https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/) for more details.\n", 88 | " \n", 89 | "Source: https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc

\n", 90 | "\n", 91 | "We add `pin_memory: True` to the settings used for the data loader in our program. Run the cell below to see the code change (shown in green) made to use pinned memory." 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 1, 97 | "metadata": { 98 | "scrolled": true, 99 | "tags": [] 100 | }, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "\u001b[1m--- ../source_code/main_opt1.py\t2022-08-18 03:51:54.178624632 +0900\u001b[0m\n", 107 | "\u001b[1m+++ ../source_code/main_opt2.py\t2022-08-19 23:25:01.750840799 +0900\u001b[0m\n", 108 | "\u001b[36m@@ -160,8 +160,9 @@\u001b[0m\n", 109 | " test_kwargs = {'batch_size': args.test_batch_size}\n", 110 | " if use_cuda:\n", 111 | " #multiprocessing.cpu_count()\n", 112 | " cuda_kwargs = {'num_workers': 2,\n", 113 | "\u001b[32m+ 'pin_memory': True,\u001b[0m\n", 114 | " 'shuffle': True}\n", 115 | " train_kwargs.update(cuda_kwargs)\n", 116 | " test_kwargs.update(cuda_kwargs)\n", 117 | " \n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "!diff -U4 --color=always ../source_code/tb_main_opt1.py ../source_code/tb_main_opt2.py" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "## Profile Again to Verify Optimization\n", 130 | "Profile again by executing the cell given below to verify if our code change addresses the problem with host-to-device memory transfers after setting `pin_memory: True`." 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": { 137 | "scrolled": true, 138 | "tags": [] 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "!python3 ../source_code/tb_main_opt2.py" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "When you are done with the profiling, run the cell below to visualize the profile trace in the `TensorBoard`." 
150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "!tensorboard --logdir=../log" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "If you are working on a remote machine, remember to set up port forwarding as described in the previous notebook before opening the browser at `localhost:6006/`." 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "\n" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "The following are noticeable changes in TensorBoard: \n", 187 | "- In the `Memory` view, we can see that the memory usage has increased compared to before.\n", 188 | "- On the `Execution Summary` panel, the `memcpy` time has dropped from `346µs` to `173µs`.\n", 189 | "- Also, the time taken by the `DataLoader` has been reduced to `4,638µs (~29%)` compared to `7,770µs (48.3%)` in the previous notebook.\n", 190 | "\n", 191 | "Now that we have addressed a bottleneck with memory transfers, let's identify the next performance bottleneck by clicking on the `Next Notebook` link below." 
192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | " ## Links and Resources\n", 199 | "\n", 200 | "[NVIDIA Nsight™ Systems](https://docs.nvidia.com/nsight-systems/)\n", 201 | "\n", 202 | "\n", 203 | "**NOTE**: To be able to see the profiler output, please download the latest version of NVIDIA Nsight Systems from [here](https://developer.nvidia.com/nsight-systems).\n", 204 | "\n", 205 | "You can also get resources from [openhackathons technical resource page](https://www.openhackathons.org/s/technical-resources)\n", 206 | " \n", 207 | " ---\n", 208 | " ## Licensing\n", 209 | " \n", 210 | "Copyright © 2022 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply." 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "
\n", 218 | " \n", 219 | " 1\n", 220 | " 2\n", 221 | " 3\n", 222 | " 4\n", 223 | " 5\n", 224 | " \n", 225 | " Next Notebook\n", 226 | "
\n", 227 | "\n", 228 | "
\n", 229 | "

Home Page

" 230 | ] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.8.8" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 4 254 | } 255 | -------------------------------------------------------------------------------- /workspace/jupyter_notebook/tb04_tensor_core_util.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

Home Page

\n", 8 | "\n", 9 | " \n", 10 | "
\n", 11 | " \n", 12 | " 1\n", 13 | " 2\n", 14 | " 3\n", 15 | " 4\n", 16 | " 5\n", 17 | " \n", 18 | " Next Notebook\n", 19 | "
" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Part 2: Tensor Cores \n", 27 | "---" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "The goal of this notebook is to show how to enable mixed precision to make better use of Tensor Cores and further optimize our application.\n", 35 | "\n", 36 | "## Tensor Core Usage\n", 37 | "\n", 38 | "Tensor Cores are specialized processing units designed to accelerate tensor/matrix multiplication. Tensor Cores enable mixed-precision computing, dynamically adapting calculations to accelerate throughput while preserving accuracy. Our application runs on an `NVIDIA® DGX™ A100` system, which is built with A100 Tensor Core GPUs. You can also run the application on other GPU architectures, for example, the `Turing™ GPU architecture`, which also has Tensor Cores." 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "
" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "The screenshot below shows a table of NVIDIA GPU architectures and supported Tensor Core precisions." 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "
\n", 60 | "
Source here
" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "## Analyze the TensorBoard Visualization\n", 68 | "\n", 69 | "To verify Tensor Core usage, check the `GPU Summary` frame. You can also verify it through the `Kernel View` by selecting `GPU Kernel` from the Views dropdown menu. As shown below, the `Kernel Time using Tensor Cores` is `30.7%`, and the `Tensor Core Utilization` in the `Kernel View` reports the same value. We are able to see this because Tensor Core utilization is automatic with Ampere architecture-based GPUs such as those within the DGX A100. However, this may not be the case if you are running the lab on other GPU architectures. Our aim is to introduce `Automatic Mixed Precision (AMP)` so that more operations can run on the Tensor Cores." 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "\n", 77 | "" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## Automatic Mixed Precision (AMP)\n", 85 | "\n", 86 | "Mixed precision is the combined use of different numerical formats (single- and half-precision computation) in the training of a deep neural network.\n", 87 | "- Single precision: FP32\n", 88 | "- Half precision: FP16\n", 89 | "\n", 90 | "Mixed precision is supported on NVIDIA GPU architectures such as `Ampere`, `Volta™`, and `Turing`. The benefits include:\n", 91 | "\n", 92 | "- Speeds up math-intensive operations by using Tensor Cores.\n", 93 | "- Requires less memory bandwidth, which speeds up data transfer operations.\n", 94 | "- Requires less memory, enabling the training and deployment of larger neural networks.\n", 95 | "\n", 96 | "\n", 97 | "Automatic Mixed Precision (AMP) automates the process of training with mixed precision in DNN frameworks. PyTorch has an [AMP](https://pytorch.org/docs/stable/amp.html) package which provides a simple way for users to convert existing FP32 training scripts to mixed FP32 and FP16 precision. 
This unlocks faster computation with Tensor Cores on NVIDIA GPUs." 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "Before we introduce `AMP` into our application code, our first step is to search the `Kernel View` for `fp16` operations running on the Tensor Cores. The result below shows that there are none.\n", 105 | "\n", 106 | "" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "In the screenshot below, you will see the code changes (highlighted in green) made to use the AMP package in PyTorch.\n", 114 | "\n", 115 | "" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "Profile again and verify that the code change enables `fp16` operations on the Tensor Cores." 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": { 129 | "scrolled": true, 130 | "tags": [] 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "!python3 ../source_code/tb_main_opt3.py" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "Next, run the cell below to visualize the profile in TensorBoard. If you are working on a remote machine, remember to do `port-forward` before opening the browser at `localhost:6006/`." 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "!tensorboard --logdir=../log" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "\n" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "Now you can see `fp16` ops under the `Name` column, and `Yes` under the `Tensor Cores Used` column confirms that these operations run on the Tensor Cores. 
The impact on our application is that less time is spent on Tensor Core computation, because `AMP` speeds up math-intensive operations. This is verified in the `GPU Summary` frame shown below." 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "The following changes were found after `AMP` was activated:\n", 193 | "\n", 194 | "- GPU utilization increased from `37%` to `52%`.\n", 195 | "- Time spent in the `DataLoader` was further reduced from `4,638µs` (~27.9%) to `948µs` (10.7%)." 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "## Compare the Performance Before and After the Optimizations\n", 203 | "Now that three different performance problems have been addressed, let us time the optimized application [main_opt3.py](../source_code/main_opt3.py)."
204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": { 210 | "scrolled": true, 211 | "tags": [] 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "!cd ../source_code && time python3 main_opt3.py" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "**Expected output on DGX A100**:\n", 223 | "\n", 224 | "```python\n", 225 | "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.308961\n", 226 | "\n", 227 | "Test set: Average loss: 0.1024, Accuracy: 9683/10000 (97%)\n", 228 | "\n", 229 | "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.154755\n", 230 | "\n", 231 | "Test set: Average loss: 0.0608, Accuracy: 9814/10000 (98%)\n", 232 | "\n", 233 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.110753\n", 234 | "\n", 235 | "Test set: Average loss: 0.0535, Accuracy: 9827/10000 (98%)\n", 236 | "\n", 237 | "----------------------------------------------------------\n", 238 | "\n", 239 | "Train Epoch: 9 [0/60000 (0%)]\tLoss: 0.059646\n", 240 | "\n", 241 | "Test set: Average loss: 0.0370, Accuracy: 9865/10000 (99%)\n", 242 | "\n", 243 | "Train Epoch: 10 [0/60000 (0%)]\tLoss: 0.055018\n", 244 | "\n", 245 | "Test set: Average loss: 0.0365, Accuracy: 9865/10000 (99%)\n", 246 | "\n", 247 | "\n", 248 | "real\t1m24.619s\n", 249 | "user\t2m35.069s\n", 250 | "sys\t 0m7.676s\n", 251 | "\n", 252 | "```\n" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "Comparing the time taken to run our baseline code [main_baseline.py](../source_code/main_baseline.py) from [notebook 2](tb02_pytorch_mnist.ipynb) with the code after applying the three optimizations [main_opt3.py](../source_code/main_opt3.py), we see that the overall time taken was reduced, as shown in the table below.\n", 260 | "\n", 261 | "\n", 262 | "|Training code|Time|Speedup|\n", 263 | "|--|--|--|\n", 264 | "|baseline| 113s|-|\n", 265 | "|optimized|~85s|1.3x|\n" 266 | ] 267 | }, 268 | { 269 |
"cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | " ## Links and Resources\n", 273 | "\n", 274 | "[NVIDIA Nsight™ Systems](https://docs.nvidia.com/nsight-systems/)\n", 275 | "\n", 276 | "\n", 277 | "**NOTE**: To be able to see the profiler output, please download the latest version of NVIDIA Nsight Systems from [here](https://developer.nvidia.com/nsight-systems).\n", 278 | "\n", 279 | "You can also get resources from [openhackathons technical resource page](https://www.openhackathons.org/s/technical-resources)\n", 280 | " \n", 281 | " ---\n", 282 | " ## Licensing\n", 283 | " \n", 284 | "Copyright © 2022 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply." 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": {}, 290 | "source": [ 291 | "
\n", 292 | " \n", 293 | " 1\n", 294 | " 2\n", 295 | " 3\n", 296 | " 4\n", 297 | " 5\n", 298 | " \n", 299 | " Next Notebook\n", 300 | "
\n", 301 | "\n", 302 | "
\n", 303 | "

Home Page

" 304 | ] 305 | } 306 | ], 307 | "metadata": { 308 | "kernelspec": { 309 | "display_name": "Python 3", 310 | "language": "python", 311 | "name": "python3" 312 | }, 313 | "language_info": { 314 | "codemirror_mode": { 315 | "name": "ipython", 316 | "version": 3 317 | }, 318 | "file_extension": ".py", 319 | "mimetype": "text/x-python", 320 | "name": "python", 321 | "nbconvert_exporter": "python", 322 | "pygments_lexer": "ipython3", 323 | "version": "3.8.8" 324 | } 325 | }, 326 | "nbformat": 4, 327 | "nbformat_minor": 4 328 | } 329 | -------------------------------------------------------------------------------- /workspace/jupyter_notebook/tb05_summary.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

Home Page

\n", 8 | "\n", 9 | " \n", 10 | "
\n", 11 | " \n", 12 | " 1\n", 13 | " 2\n", 14 | " 3\n", 15 | " 4\n", 16 | " 5\n", 17 | " \n", 18 | "
" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "Part 2 SUMMARY\n", 26 | "---\n", 27 | "In this section of the labs, you learned how to:\n", 28 | "\n", 29 | "- Run the sample application.\n", 30 | "- Apply the PyTorch Profiler within our application code and visualize the profile trace log in TensorBoard.\n", 31 | "- Interpret the timeline, Overview, Memory View, GPU View, Performance recommendation, and GPU utilization.\n", 32 | "- Identify performance problems through the Performance Recommendation suggestions, and confirm the performance improvement gained from the optimizations.\n", 33 | "- Enable pinned memory.\n", 34 | "- Activate the use of Automatic Mixed Precision (AMP) and track Tensor Core usage.\n", 35 | "\n", 36 | "We hope that you enjoyed the lab and learned useful techniques and strategies for using the PyTorch Profiler with TensorBoard to optimize applications so that they use system resources efficiently.\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "
\n", 44 | "
\n", 45 | "
\n" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "We recommend answering a few multiple-choice questions to test your understanding after completing Part 2 of this lab." 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## Acknowledgement\n", 60 | "\n", 61 | "**This material is adapted from:**\n", 62 | "\n", 63 | "- NVIDIA GPU Technology Conference (GTC) 2021 Hands-on labs, NVIDIA Nsight™ Developer Tools Training content (https://github.com/NVIDIA/nsight-training) licensed under BSD 3-Clause.\n", 64 | "- HPC N-Ways to GPU Programming on Nsight Systems (https://github.com/openhackathons-org/gpubootcamp/tree/master/hpc) licensed under the Creative Commons Attribution 4.0 International (CC BY 4.0).\n" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "## Links and Resources\n", 72 | "\n", 73 | "[NVIDIA Nsight™ Systems](https://docs.nvidia.com/nsight-systems/)\n", 74 | "\n", 75 | "\n", 76 | "**NOTE**: To be able to see the profiler output, please download the latest version of NVIDIA Nsight Systems from [here](https://developer.nvidia.com/nsight-systems).\n", 77 | "\n", 78 | "You can also get resources from [openhackathons technical resource page](https://www.openhackathons.org/s/technical-resources)\n", 79 | " \n", 80 | " ---\n", 81 | " ## Licensing\n", 82 | " \n", 83 | "Copyright © 2022 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply." 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "
\n", 91 | " \n", 92 | " 1\n", 93 | " 2\n", 94 | " 3\n", 95 | " 4\n", 96 | " 5\n", 97 | " \n", 98 | "
\n", 99 | "\n", 100 | "
\n", 101 | "

Home Page

" 102 | ] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python 3", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.8.8" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 4 126 | } 127 | -------------------------------------------------------------------------------- /workspace/reports/baseline.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/reports/baseline.nsys-rep -------------------------------------------------------------------------------- /workspace/reports/baseline_nvtx.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/reports/baseline_nvtx.nsys-rep -------------------------------------------------------------------------------- /workspace/reports/firstOptimization.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/reports/firstOptimization.nsys-rep -------------------------------------------------------------------------------- /workspace/reports/secondOptimization.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/reports/secondOptimization.nsys-rep 
-------------------------------------------------------------------------------- /workspace/reports/thirdOptimization.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/reports/thirdOptimization.nsys-rep -------------------------------------------------------------------------------- /workspace/source_code/LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, NVIDIA CORPORATION 4 | Copyright (c) 2017, 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /workspace/source_code/Nsight_Systems_User_Guide_2022.2.1.31-5fe97ab.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/Nsight_Systems_User_Guide_2022.2.1.31-5fe97ab.pdf -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Yann LeCun (Courant Institute, NYU) and Corinna Cortes (Google Labs, New York) 2 | hold the copyright of MNIST dataset, which is a derivative work from original 3 | NIST datasets. MNIST dataset is made available under the terms of the Creative 4 | Commons Attribution-Share Alike 3.0 license. 
5 | -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/processed/test.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/data/MNIST/processed/test.pt -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/processed/training.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/data/MNIST/processed/training.pt -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/raw/t10k-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/data/MNIST/raw/t10k-images-idx3-ubyte -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/raw/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/data/MNIST/raw/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/raw/t10k-labels-idx1-ubyte: -------------------------------------------------------------------------------- 1 | '                                                                                                                                                                                                                                     
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/raw/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/data/MNIST/raw/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/raw/train-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/data/MNIST/raw/train-images-idx3-ubyte -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/raw/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/data/MNIST/raw/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- 
/workspace/source_code/data/MNIST/raw/train-labels-idx1-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/data/MNIST/raw/train-labels-idx1-ubyte -------------------------------------------------------------------------------- /workspace/source_code/data/MNIST/raw/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openhackathons-org/AI-Profiler/709cacd4f1292ed7b40fe293c20edc735c60d7d5/workspace/source_code/data/MNIST/raw/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /workspace/source_code/main_baseline.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION 4 | # Copyright (c) 2017, 5 | # All rights reserved. 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of the copyright holder nor the names of its 18 | # contributors may be used to endorse or promote products derived from 19 | # this software without specific prior written permission. 
20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | from __future__ import print_function 33 | import argparse 34 | import os 35 | import torch 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | import torch.optim as optim 39 | from torchvision import datasets, transforms 40 | from torch.optim.lr_scheduler import StepLR 41 | 42 | 43 | class Net(nn.Module): 44 | def __init__(self): 45 | super(Net, self).__init__() 46 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 47 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 48 | self.dropout1 = nn.Dropout(0.25) 49 | self.dropout2 = nn.Dropout(0.5) 50 | self.fc1 = nn.Linear(9216, 128) 51 | self.fc2 = nn.Linear(128, 10) 52 | 53 | def forward(self, x): 54 | x = self.conv1(x) 55 | x = F.relu(x) 56 | x = self.conv2(x) 57 | x = F.relu(x) 58 | x = F.max_pool2d(x, 2) 59 | x = self.dropout1(x) 60 | x = torch.flatten(x, 1) 61 | x = self.fc1(x) 62 | x = F.relu(x) 63 | x = self.dropout2(x) 64 | x = self.fc2(x) 65 | output = F.log_softmax(x, dim=1) 66 | return output 67 | 68 | 69 | def train(args, model, device, train_loader, optimizer, epoch): 70 | model.train() 71 | for batch_idx, (data, target) in enumerate(train_loader): 72 | data, target = data.to(device), 
target.to(device) 73 | optimizer.zero_grad() 74 | output = model(data) 75 | loss = F.nll_loss(output, target) 76 | loss.backward() 77 | optimizer.step() 78 | if batch_idx % args.log_interval == 0: 79 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 80 | epoch, batch_idx * len(data), len(train_loader.dataset), 81 | 100. * batch_idx / len(train_loader), loss.item())) 82 | if args.dry_run: 83 | break 84 | 85 | 86 | def test(model, device, test_loader): 87 | model.eval() 88 | test_loss = 0 89 | correct = 0 90 | with torch.no_grad(): 91 | for data, target in test_loader: 92 | data, target = data.to(device), target.to(device) 93 | output = model(data) 94 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 95 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 96 | correct += pred.eq(target.view_as(pred)).sum().item() 97 | 98 | test_loss /= len(test_loader.dataset) 99 | 100 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 101 | test_loss, correct, len(test_loader.dataset), 102 | 100. 
* correct / len(test_loader.dataset))) 103 | 104 | 105 | def main(): 106 | # Training settings 107 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 108 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N', 109 | help='input batch size for training (default: 64)') 110 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 111 | help='input batch size for testing (default: 1000)') 112 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 113 | help='number of epochs to train (default: 14)') 114 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 115 | help='learning rate (default: 1.0)') 116 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 117 | help='Learning rate step gamma (default: 0.7)') 118 | parser.add_argument('--no-cuda', action='store_true', default=False, 119 | help='disables CUDA training') 120 | parser.add_argument('--dry-run', action='store_true', default=False, 121 | help='quickly check a single pass') 122 | parser.add_argument('--seed', type=int, default=1, metavar='S', 123 | help='random seed (default: 1)') 124 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 125 | help='how many batches to wait before logging training status') 126 | parser.add_argument('--save-model', action='store_true', default=False, 127 | help='For Saving the current Model') 128 | args = parser.parse_args() 129 | use_cuda = not args.no_cuda and torch.cuda.is_available() 130 | 131 | torch.manual_seed(args.seed) 132 | 133 | device = torch.device("cuda" if use_cuda else "cpu") 134 | 135 | train_kwargs = {'batch_size': args.batch_size} 136 | test_kwargs = {'batch_size': args.test_batch_size} 137 | if use_cuda: 138 | cuda_kwargs = {'num_workers': 1, 139 | 'shuffle': True} 140 | train_kwargs.update(cuda_kwargs) 141 | test_kwargs.update(cuda_kwargs) 142 | 143 | transform=transforms.Compose([ 144 | transforms.ToTensor(), 145 | 
transforms.Normalize((0.1307,), (0.3081,)) 146 | ]) 147 | scriptPath = os.path.dirname(os.path.realpath(__file__)) 148 | dataDir = os.path.join(scriptPath, 'data') 149 | dataset1 = datasets.MNIST(dataDir, train=True, download=True, 150 | transform=transform) 151 | dataset2 = datasets.MNIST(dataDir, train=False, 152 | transform=transform) 153 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 154 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 155 | 156 | model = Net().to(device) 157 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 158 | 159 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 160 | for epoch in range(1, args.epochs + 1): 161 | # Start profiling from 2nd epoch 162 | if epoch == 2: 163 | torch.cuda.cudart().cudaProfilerStart() 164 | train(args, model, device, train_loader, optimizer, epoch) 165 | test(model, device, test_loader) 166 | scheduler.step() 167 | # Stop profiling at the end of 2nd epoch 168 | if epoch == 2: 169 | torch.cuda.cudart().cudaProfilerStop() 170 | 171 | if args.save_model: 172 | torch.save(model.state_dict(), "mnist_cnn.pt") 173 | 174 | 175 | if __name__ == '__main__': 176 | main() 177 | -------------------------------------------------------------------------------- /workspace/source_code/main_baseline_nvtx.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION 4 | # Copyright (c) 2017, 5 | # All rights reserved. 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 
12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of the copyright holder nor the names of its 18 | # contributors may be used to endorse or promote products derived from 19 | # this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | 32 | from __future__ import print_function 33 | import argparse 34 | import os 35 | import torch 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | import torch.optim as optim 39 | from torchvision import datasets, transforms 40 | from torch.optim.lr_scheduler import StepLR 41 | from torch.cuda import nvtx 42 | 43 | class Net(nn.Module): 44 | def __init__(self): 45 | super(Net, self).__init__() 46 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 47 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 48 | self.dropout1 = nn.Dropout(0.25) 49 | self.dropout2 = nn.Dropout(0.5) 50 | self.fc1 = nn.Linear(9216, 128) 51 | self.fc2 = nn.Linear(128, 10) 52 | 53 | def forward(self, x): 54 | x = self.conv1(x) 55 | x = F.relu(x) 56 | x = self.conv2(x) 57 | x = F.relu(x) 58 | x = F.max_pool2d(x, 2) 59 | x = self.dropout1(x) 60 | x = torch.flatten(x, 1) 61 | x = self.fc1(x) 62 | x = F.relu(x) 63 | x = self.dropout2(x) 64 | x = self.fc2(x) 65 | output = F.log_softmax(x, dim=1) 66 | return output 67 | 68 | 69 | def train(args, model, device, train_loader, optimizer, epoch): 70 | model.train() 71 | with torch.autograd.profiler.emit_nvtx(): 72 | nvtx.range_push("Data loading"); 73 | for batch_idx, (data, target) in enumerate(train_loader): 74 | nvtx.range_pop();# Data loading 75 | nvtx.range_push("Batch " + str(batch_idx)) 76 | 77 | nvtx.range_push("Copy to device") 78 | data, target = data.to(device), target.to(device) 79 | nvtx.range_pop() # Copy to device 80 | 81 | nvtx.range_push("Forward pass") 82 | optimizer.zero_grad() 83 | output = model(data) 84 | loss = F.nll_loss(output, target) 85 | nvtx.range_pop() # Forward pass 86 | 87 | nvtx.range_push("Backward pass") 88 | loss.backward() 89 | optimizer.step() 90 | nvtx.range_pop() # Backward pass 91 | 92 | nvtx.range_pop() # Batch 93 | if batch_idx % args.log_interval == 0: 94 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 95 | epoch, batch_idx * len(data), len(train_loader.dataset), 96 | 100. 
* batch_idx / len(train_loader), loss.item())) 97 | if args.dry_run: 98 | break 99 | nvtx.range_push("Data loading"); 100 | nvtx.range_pop(); # Data loading 101 | 102 | 103 | def test(model, device, test_loader): 104 | model.eval() 105 | test_loss = 0 106 | correct = 0 107 | with torch.no_grad(): 108 | for data, target in test_loader: 109 | nvtx.range_push("Copy to device") 110 | data, target = data.to(device), target.to(device) 111 | nvtx.range_pop(); # Copy to device 112 | 113 | nvtx.range_push("Test forward pass") 114 | output = model(data) 115 | nvtx.range_pop() # Test forward pass 116 | 117 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 118 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 119 | correct += pred.eq(target.view_as(pred)).sum().item() 120 | 121 | test_loss /= len(test_loader.dataset) 122 | 123 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 124 | test_loss, correct, len(test_loader.dataset), 125 | 100. 
* correct / len(test_loader.dataset))) 126 | 127 | 128 | def main(): 129 | # Training settings 130 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 131 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N', 132 | help='input batch size for training (default: 64)') 133 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 134 | help='input batch size for testing (default: 1000)') 135 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 136 | help='number of epochs to train (default: 14)') 137 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 138 | help='learning rate (default: 1.0)') 139 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 140 | help='Learning rate step gamma (default: 0.7)') 141 | parser.add_argument('--no-cuda', action='store_true', default=False, 142 | help='disables CUDA training') 143 | parser.add_argument('--dry-run', action='store_true', default=False, 144 | help='quickly check a single pass') 145 | parser.add_argument('--seed', type=int, default=1, metavar='S', 146 | help='random seed (default: 1)') 147 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 148 | help='how many batches to wait before logging training status') 149 | parser.add_argument('--save-model', action='store_true', default=False, 150 | help='For Saving the current Model') 151 | args = parser.parse_args() 152 | use_cuda = not args.no_cuda and torch.cuda.is_available() 153 | 154 | torch.manual_seed(args.seed) 155 | 156 | device = torch.device("cuda" if use_cuda else "cpu") 157 | 158 | train_kwargs = {'batch_size': args.batch_size} 159 | test_kwargs = {'batch_size': args.test_batch_size} 160 | if use_cuda: 161 | cuda_kwargs = {'num_workers': 1, 162 | 'shuffle': True} 163 | train_kwargs.update(cuda_kwargs) 164 | test_kwargs.update(cuda_kwargs) 165 | 166 | transform=transforms.Compose([ 167 | transforms.ToTensor(), 168 | 
transforms.Normalize((0.1307,), (0.3081,)) 169 | ]) 170 | scriptPath = os.path.dirname(os.path.realpath(__file__)) 171 | dataDir = os.path.join(scriptPath, 'data') 172 | dataset1 = datasets.MNIST(dataDir, train=True, download=True, 173 | transform=transform) 174 | dataset2 = datasets.MNIST(dataDir, train=False, 175 | transform=transform) 176 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 177 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 178 | 179 | model = Net().to(device) 180 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 181 | 182 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 183 | for epoch in range(1, args.epochs + 1): 184 | # Start profiling from 2nd epoch 185 | if epoch == 2: 186 | torch.cuda.cudart().cudaProfilerStart() 187 | 188 | nvtx.range_push("Epoch " + str(epoch)) 189 | nvtx.range_push("Train") 190 | train(args, model, device, train_loader, optimizer, epoch) 191 | nvtx.range_pop() # Train 192 | 193 | nvtx.range_push("Test") 194 | test(model, device, test_loader) 195 | nvtx.range_pop() # Test 196 | 197 | scheduler.step() 198 | nvtx.range_pop() # Epoch 199 | # Stop profiling at the end of 2nd epoch 200 | if epoch == 2: 201 | torch.cuda.cudart().cudaProfilerStop() 202 | 203 | if args.save_model: 204 | torch.save(model.state_dict(), "mnist_cnn.pt") 205 | 206 | 207 | if __name__ == '__main__': 208 | main() 209 | -------------------------------------------------------------------------------- /workspace/source_code/main_opt1.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION 4 | # Copyright (c) 2017, 5 | # All rights reserved. 
6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of the copyright holder nor the names of its 18 | # contributors may be used to endorse or promote products derived from 19 | # this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | 32 | from __future__ import print_function 33 | import argparse 34 | import os 35 | import multiprocessing 36 | import torch 37 | import torch.nn as nn 38 | import torch.nn.functional as F 39 | import torch.optim as optim 40 | from torchvision import datasets, transforms 41 | from torch.optim.lr_scheduler import StepLR 42 | from torch.cuda import nvtx 43 | 44 | class Net(nn.Module): 45 | def __init__(self): 46 | super(Net, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 48 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 49 | self.dropout1 = nn.Dropout(0.25) 50 | self.dropout2 = nn.Dropout(0.5) 51 | self.fc1 = nn.Linear(9216, 128) 52 | self.fc2 = nn.Linear(128, 10) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = F.relu(x) 57 | x = self.conv2(x) 58 | x = F.relu(x) 59 | x = F.max_pool2d(x, 2) 60 | x = self.dropout1(x) 61 | x = torch.flatten(x, 1) 62 | x = self.fc1(x) 63 | x = F.relu(x) 64 | x = self.dropout2(x) 65 | x = self.fc2(x) 66 | output = F.log_softmax(x, dim=1) 67 | return output 68 | 69 | 70 | def train(args, model, device, train_loader, optimizer, epoch): 71 | model.train() 72 | with torch.autograd.profiler.emit_nvtx(): 73 | nvtx.range_push("Data loading"); 74 | for batch_idx, (data, target) in enumerate(train_loader): 75 | nvtx.range_pop();# Data loading 76 | nvtx.range_push("Batch " + str(batch_idx)) 77 | 78 | nvtx.range_push("Copy to device") 79 | data, target = data.to(device), target.to(device) 80 | nvtx.range_pop() # Copy to device 81 | 82 | nvtx.range_push("Forward pass") 83 | optimizer.zero_grad() 84 | output = model(data) 85 | loss = F.nll_loss(output, target) 86 | nvtx.range_pop() # Forward pass 87 | 88 | nvtx.range_push("Backward pass") 89 | loss.backward() 90 | optimizer.step() 91 | nvtx.range_pop() # Backward pass 92 | 93 | nvtx.range_pop() # Batch 94 | if batch_idx % args.log_interval == 0: 95 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 96 | epoch, batch_idx * len(data), len(train_loader.dataset), 97 
| 100. * batch_idx / len(train_loader), loss.item())) 98 | if args.dry_run: 99 | break 100 | nvtx.range_push("Data loading"); 101 | nvtx.range_pop(); # Data loading 102 | 103 | 104 | def test(model, device, test_loader): 105 | model.eval() 106 | test_loss = 0 107 | correct = 0 108 | with torch.no_grad(): 109 | for data, target in test_loader: 110 | nvtx.range_push("Copy to device") 111 | data, target = data.to(device), target.to(device) 112 | nvtx.range_pop(); # Copy to device 113 | 114 | nvtx.range_push("Test forward pass") 115 | output = model(data) 116 | nvtx.range_pop() # Test forward pass 117 | 118 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 119 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 120 | correct += pred.eq(target.view_as(pred)).sum().item() 121 | 122 | test_loss /= len(test_loader.dataset) 123 | 124 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 125 | test_loss, correct, len(test_loader.dataset), 126 | 100. 
* correct / len(test_loader.dataset))) 127 | 128 | 129 | def main(): 130 | # Training settings 131 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 132 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N', 133 | help='input batch size for training (default: 64)') 134 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 135 | help='input batch size for testing (default: 1000)') 136 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 137 | help='number of epochs to train (default: 14)') 138 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 139 | help='learning rate (default: 1.0)') 140 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 141 | help='Learning rate step gamma (default: 0.7)') 142 | parser.add_argument('--no-cuda', action='store_true', default=False, 143 | help='disables CUDA training') 144 | parser.add_argument('--dry-run', action='store_true', default=False, 145 | help='quickly check a single pass') 146 | parser.add_argument('--seed', type=int, default=1, metavar='S', 147 | help='random seed (default: 1)') 148 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 149 | help='how many batches to wait before logging training status') 150 | parser.add_argument('--save-model', action='store_true', default=False, 151 | help='For Saving the current Model') 152 | args = parser.parse_args() 153 | use_cuda = not args.no_cuda and torch.cuda.is_available() 154 | 155 | torch.manual_seed(args.seed) 156 | 157 | device = torch.device("cuda" if use_cuda else "cpu") 158 | 159 | train_kwargs = {'batch_size': args.batch_size} 160 | test_kwargs = {'batch_size': args.test_batch_size} 161 | if use_cuda: 162 | #multiprocessing.cpu_count() 163 | cuda_kwargs = {'num_workers': 2, 164 | 'shuffle': True} 165 | train_kwargs.update(cuda_kwargs) 166 | test_kwargs.update(cuda_kwargs) 167 | 168 | transform=transforms.Compose([ 169 | 
transforms.ToTensor(), 170 | transforms.Normalize((0.1307,), (0.3081,)) 171 | ]) 172 | scriptPath = os.path.dirname(os.path.realpath(__file__)) 173 | dataDir = os.path.join(scriptPath, 'data') 174 | dataset1 = datasets.MNIST(dataDir, train=True, download=True, 175 | transform=transform) 176 | dataset2 = datasets.MNIST(dataDir, train=False, 177 | transform=transform) 178 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 179 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 180 | 181 | model = Net().to(device) 182 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 183 | 184 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 185 | for epoch in range(1, args.epochs + 1): 186 | # Start profiling from 2nd epoch 187 | if epoch == 2: 188 | torch.cuda.cudart().cudaProfilerStart() 189 | 190 | nvtx.range_push("Epoch " + str(epoch)) 191 | nvtx.range_push("Train") 192 | train(args, model, device, train_loader, optimizer, epoch) 193 | nvtx.range_pop() # Train 194 | 195 | nvtx.range_push("Test") 196 | test(model, device, test_loader) 197 | nvtx.range_pop() # Test 198 | 199 | scheduler.step() 200 | nvtx.range_pop() # Epoch 201 | # Stop profiling at the end of 2nd epoch 202 | if epoch == 2: 203 | torch.cuda.cudart().cudaProfilerStop() 204 | 205 | if args.save_model: 206 | torch.save(model.state_dict(), "mnist_cnn.pt") 207 | 208 | 209 | if __name__ == '__main__': 210 | main() 211 | -------------------------------------------------------------------------------- /workspace/source_code/main_opt2.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION 4 | # Copyright (c) 2017, 5 | # All rights reserved. 
6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of the copyright holder nor the names of its 18 | # contributors may be used to endorse or promote products derived from 19 | # this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | 32 | from __future__ import print_function 33 | import argparse 34 | import os 35 | import multiprocessing 36 | import torch 37 | import torch.nn as nn 38 | import torch.nn.functional as F 39 | import torch.optim as optim 40 | from torchvision import datasets, transforms 41 | from torch.optim.lr_scheduler import StepLR 42 | from torch.cuda import nvtx 43 | 44 | class Net(nn.Module): 45 | def __init__(self): 46 | super(Net, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 48 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 49 | self.dropout1 = nn.Dropout(0.25) 50 | self.dropout2 = nn.Dropout(0.5) 51 | self.fc1 = nn.Linear(9216, 128) 52 | self.fc2 = nn.Linear(128, 10) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = F.relu(x) 57 | x = self.conv2(x) 58 | x = F.relu(x) 59 | x = F.max_pool2d(x, 2) 60 | x = self.dropout1(x) 61 | x = torch.flatten(x, 1) 62 | x = self.fc1(x) 63 | x = F.relu(x) 64 | x = self.dropout2(x) 65 | x = self.fc2(x) 66 | output = F.log_softmax(x, dim=1) 67 | return output 68 | 69 | 70 | def train(args, model, device, train_loader, optimizer, epoch): 71 | model.train() 72 | with torch.autograd.profiler.emit_nvtx(): 73 | nvtx.range_push("Data loading"); 74 | for batch_idx, (data, target) in enumerate(train_loader): 75 | nvtx.range_pop();# Data loading 76 | nvtx.range_push("Batch " + str(batch_idx)) 77 | 78 | nvtx.range_push("Copy to device") 79 | data, target = data.to(device), target.to(device) 80 | nvtx.range_pop() # Copy to device 81 | 82 | nvtx.range_push("Forward pass") 83 | optimizer.zero_grad() 84 | output = model(data) 85 | loss = F.nll_loss(output, target) 86 | nvtx.range_pop() # Forward pass 87 | 88 | nvtx.range_push("Backward pass") 89 | loss.backward() 90 | optimizer.step() 91 | nvtx.range_pop() # Backward pass 92 | 93 | nvtx.range_pop() # Batch 94 | if batch_idx % args.log_interval == 0: 95 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 96 | epoch, batch_idx * len(data), len(train_loader.dataset), 97 
| 100. * batch_idx / len(train_loader), loss.item())) 98 | if args.dry_run: 99 | break 100 | nvtx.range_push("Data loading"); 101 | nvtx.range_pop(); # Data loading 102 | 103 | 104 | def test(model, device, test_loader): 105 | model.eval() 106 | test_loss = 0 107 | correct = 0 108 | with torch.no_grad(): 109 | for data, target in test_loader: 110 | nvtx.range_push("Copy to device") 111 | data, target = data.to(device), target.to(device) 112 | nvtx.range_pop(); # Copy to device 113 | 114 | nvtx.range_push("Test forward pass") 115 | output = model(data) 116 | nvtx.range_pop() # Test forward pass 117 | 118 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 119 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 120 | correct += pred.eq(target.view_as(pred)).sum().item() 121 | 122 | test_loss /= len(test_loader.dataset) 123 | 124 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 125 | test_loss, correct, len(test_loader.dataset), 126 | 100. 
* correct / len(test_loader.dataset))) 127 | 128 | 129 | def main(): 130 | # Training settings 131 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 132 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N', 133 | help='input batch size for training (default: 64)') 134 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 135 | help='input batch size for testing (default: 1000)') 136 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 137 | help='number of epochs to train (default: 14)') 138 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 139 | help='learning rate (default: 1.0)') 140 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 141 | help='Learning rate step gamma (default: 0.7)') 142 | parser.add_argument('--no-cuda', action='store_true', default=False, 143 | help='disables CUDA training') 144 | parser.add_argument('--dry-run', action='store_true', default=False, 145 | help='quickly check a single pass') 146 | parser.add_argument('--seed', type=int, default=1, metavar='S', 147 | help='random seed (default: 1)') 148 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 149 | help='how many batches to wait before logging training status') 150 | parser.add_argument('--save-model', action='store_true', default=False, 151 | help='For Saving the current Model') 152 | args = parser.parse_args() 153 | use_cuda = not args.no_cuda and torch.cuda.is_available() 154 | 155 | torch.manual_seed(args.seed) 156 | 157 | device = torch.device("cuda" if use_cuda else "cpu") 158 | 159 | train_kwargs = {'batch_size': args.batch_size} 160 | test_kwargs = {'batch_size': args.test_batch_size} 161 | if use_cuda: 162 | #multiprocessing.cpu_count() 163 | cuda_kwargs = {'num_workers': 2, 164 | 'pin_memory': True, 165 | 'shuffle': True} 166 | train_kwargs.update(cuda_kwargs) 167 | test_kwargs.update(cuda_kwargs) 168 | 169 | 
transform=transforms.Compose([ 170 | transforms.ToTensor(), 171 | transforms.Normalize((0.1307,), (0.3081,)) 172 | ]) 173 | scriptPath = os.path.dirname(os.path.realpath(__file__)) 174 | dataDir = os.path.join(scriptPath, 'data') 175 | dataset1 = datasets.MNIST(dataDir, train=True, download=True, 176 | transform=transform) 177 | dataset2 = datasets.MNIST(dataDir, train=False, 178 | transform=transform) 179 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 180 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 181 | 182 | model = Net().to(device) 183 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 184 | 185 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 186 | for epoch in range(1, args.epochs + 1): 187 | # Start profiling from 2nd epoch 188 | if epoch == 2: 189 | torch.cuda.cudart().cudaProfilerStart() 190 | 191 | nvtx.range_push("Epoch " + str(epoch)) 192 | nvtx.range_push("Train") 193 | train(args, model, device, train_loader, optimizer, epoch) 194 | nvtx.range_pop() # Train 195 | 196 | nvtx.range_push("Test") 197 | test(model, device, test_loader) 198 | nvtx.range_pop() # Test 199 | 200 | scheduler.step() 201 | nvtx.range_pop() # Epoch 202 | # Stop profiling at the end of 2nd epoch 203 | if epoch == 2: 204 | torch.cuda.cudart().cudaProfilerStop() 205 | 206 | if args.save_model: 207 | torch.save(model.state_dict(), "mnist_cnn.pt") 208 | 209 | 210 | if __name__ == '__main__': 211 | main() 212 | -------------------------------------------------------------------------------- /workspace/source_code/main_opt3.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION 4 | # Copyright (c) 2017, 5 | # All rights reserved. 
6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of the copyright holder nor the names of its 18 | # contributors may be used to endorse or promote products derived from 19 | # this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | 32 | from __future__ import print_function 33 | import argparse 34 | import os 35 | import multiprocessing 36 | import torch 37 | import torch.nn as nn 38 | import torch.nn.functional as F 39 | import torch.optim as optim 40 | from torchvision import datasets, transforms 41 | from torch.optim.lr_scheduler import StepLR 42 | from torch.cuda import nvtx 43 | 44 | class Net(nn.Module): 45 | def __init__(self): 46 | super(Net, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 48 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 49 | self.dropout1 = nn.Dropout(0.25) 50 | self.dropout2 = nn.Dropout(0.5) 51 | self.fc1 = nn.Linear(9216, 128) 52 | self.fc2 = nn.Linear(128, 10) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = F.relu(x) 57 | x = self.conv2(x) 58 | x = F.relu(x) 59 | x = F.max_pool2d(x, 2) 60 | x = self.dropout1(x) 61 | x = torch.flatten(x, 1) 62 | x = self.fc1(x) 63 | x = F.relu(x) 64 | x = self.dropout2(x) 65 | x = self.fc2(x) 66 | output = F.log_softmax(x, dim=1) 67 | return output 68 | 69 | 70 | def train(args, model, device, train_loader, optimizer, epoch): 71 | model.train() 72 | with torch.autograd.profiler.emit_nvtx(): 73 | nvtx.range_push("Data loading"); 74 | for batch_idx, (data, target) in enumerate(train_loader): 75 | nvtx.range_pop();# Data loading 76 | nvtx.range_push("Batch " + str(batch_idx)) 77 | 78 | nvtx.range_push("Copy to device") 79 | data, target = data.to(device), target.to(device) 80 | nvtx.range_pop() # Copy to device 81 | 82 | nvtx.range_push("Forward pass") 83 | optimizer.zero_grad() 84 | 85 | # Enables autocasting for the forward pass 86 | with torch.cuda.amp.autocast(enabled=True): 87 | output = model(data) 88 | loss = F.nll_loss(output, target) 89 | nvtx.range_pop() # Forward pass 90 | 91 | nvtx.range_push("Backward pass") 92 | loss.backward() 93 | optimizer.step() 94 | nvtx.range_pop() # Backward pass 95 | 96 | nvtx.range_pop() # Batch 97 | if batch_idx % args.log_interval == 0: 98 | print('Train Epoch: {} 
[{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 99 | epoch, batch_idx * len(data), len(train_loader.dataset), 100 | 100. * batch_idx / len(train_loader), loss.item())) 101 | if args.dry_run: 102 | break 103 | nvtx.range_push("Data loading"); 104 | nvtx.range_pop(); # Data loading 105 | 106 | 107 | def test(model, device, test_loader): 108 | model.eval() 109 | test_loss = 0 110 | correct = 0 111 | with torch.no_grad(): 112 | for data, target in test_loader: 113 | nvtx.range_push("Copy to device") 114 | data, target = data.to(device), target.to(device) 115 | nvtx.range_pop(); # Copy to device 116 | 117 | nvtx.range_push("Test forward pass") 118 | output = model(data) 119 | nvtx.range_pop() # Test forward pass 120 | 121 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 122 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 123 | correct += pred.eq(target.view_as(pred)).sum().item() 124 | 125 | test_loss /= len(test_loader.dataset) 126 | 127 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 128 | test_loss, correct, len(test_loader.dataset), 129 | 100. 
* correct / len(test_loader.dataset))) 130 | 131 | 132 | def main(): 133 | # Training settings 134 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 135 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N', 136 | help='input batch size for training (default: 64)') 137 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 138 | help='input batch size for testing (default: 1000)') 139 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 140 | help='number of epochs to train (default: 14)') 141 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 142 | help='learning rate (default: 1.0)') 143 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 144 | help='Learning rate step gamma (default: 0.7)') 145 | parser.add_argument('--no-cuda', action='store_true', default=False, 146 | help='disables CUDA training') 147 | parser.add_argument('--dry-run', action='store_true', default=False, 148 | help='quickly check a single pass') 149 | parser.add_argument('--seed', type=int, default=1, metavar='S', 150 | help='random seed (default: 1)') 151 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 152 | help='how many batches to wait before logging training status') 153 | parser.add_argument('--save-model', action='store_true', default=False, 154 | help='For Saving the current Model') 155 | args = parser.parse_args() 156 | use_cuda = not args.no_cuda and torch.cuda.is_available() 157 | 158 | torch.manual_seed(args.seed) 159 | 160 | device = torch.device("cuda" if use_cuda else "cpu") 161 | 162 | train_kwargs = {'batch_size': args.batch_size} 163 | test_kwargs = {'batch_size': args.test_batch_size} 164 | if use_cuda: 165 | #multiprocessing.cpu_count() 166 | cuda_kwargs = {'num_workers': 2, 167 | 'pin_memory': True, 168 | 'shuffle': True} 169 | train_kwargs.update(cuda_kwargs) 170 | test_kwargs.update(cuda_kwargs) 171 | 172 | 
transform=transforms.Compose([ 173 | transforms.ToTensor(), 174 | transforms.Normalize((0.1307,), (0.3081,)) 175 | ]) 176 | scriptPath = os.path.dirname(os.path.realpath(__file__)) 177 | dataDir = os.path.join(scriptPath, 'data') 178 | dataset1 = datasets.MNIST(dataDir, train=True, download=True, 179 | transform=transform) 180 | dataset2 = datasets.MNIST(dataDir, train=False, 181 | transform=transform) 182 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 183 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 184 | 185 | model = Net().to(device) 186 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 187 | 188 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 189 | for epoch in range(1, args.epochs + 1): 190 | # Start profiling from 2nd epoch 191 | if epoch == 2: 192 | torch.cuda.cudart().cudaProfilerStart() 193 | 194 | nvtx.range_push("Epoch " + str(epoch)) 195 | nvtx.range_push("Train") 196 | train(args, model, device, train_loader, optimizer, epoch) 197 | nvtx.range_pop() # Train 198 | 199 | nvtx.range_push("Test") 200 | test(model, device, test_loader) 201 | nvtx.range_pop() # Test 202 | 203 | scheduler.step() 204 | nvtx.range_pop() # Epoch 205 | # Stop profiling at the end of 2nd epoch 206 | if epoch == 2: 207 | torch.cuda.cudart().cudaProfilerStop() 208 | 209 | if args.save_model: 210 | torch.save(model.state_dict(), "mnist_cnn.pt") 211 | 212 | 213 | if __name__ == '__main__': 214 | main() 215 | -------------------------------------------------------------------------------- /workspace/source_code/tb_main_baseline_profiler.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION 4 | # Copyright (c) 2017, 5 | # All rights reserved. 
6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of the copyright holder nor the names of its 18 | # contributors may be used to endorse or promote products derived from 19 | # this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | 32 | from __future__ import print_function 33 | import argparse 34 | import os 35 | import torch 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | import torch.optim as optim 39 | from torchvision import datasets, transforms 40 | from torch.optim.lr_scheduler import StepLR 41 | import torch.profiler 42 | 43 | 44 | class Net(nn.Module): 45 | def __init__(self): 46 | super(Net, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 48 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 49 | self.dropout1 = nn.Dropout(0.25) 50 | self.dropout2 = nn.Dropout(0.5) 51 | self.fc1 = nn.Linear(9216, 128) 52 | self.fc2 = nn.Linear(128, 10) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = F.relu(x) 57 | x = self.conv2(x) 58 | x = F.relu(x) 59 | x = F.max_pool2d(x, 2) 60 | x = self.dropout1(x) 61 | x = torch.flatten(x, 1) 62 | x = self.fc1(x) 63 | x = F.relu(x) 64 | x = self.dropout2(x) 65 | x = self.fc2(x) 66 | output = F.log_softmax(x, dim=1) 67 | return output 68 | 69 | 70 | def train(args, model, device, train_loader, optimizer, epoch,prof): 71 | model.train() 72 | for batch_idx, (data, target) in enumerate(train_loader): 73 | data, target = data.to(device), target.to(device) 74 | optimizer.zero_grad() 75 | output = model(data) 76 | 77 | loss = F.nll_loss(output, target) 78 | loss.backward() 79 | optimizer.step() 80 | prof.step() 81 | if batch_idx % args.log_interval == 0: 82 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 83 | epoch, batch_idx * len(data), len(train_loader.dataset), 84 | 100. 
* batch_idx / len(train_loader), loss.item())) 85 | if args.dry_run: 86 | break 87 | 88 | 89 | def test(model, device, test_loader): 90 | model.eval() 91 | test_loss = 0 92 | correct = 0 93 | with torch.no_grad(): 94 | for data, target in test_loader: 95 | data, target = data.to(device), target.to(device) 96 | output = model(data) 97 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 98 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 99 | correct += pred.eq(target.view_as(pred)).sum().item() 100 | 101 | test_loss /= len(test_loader.dataset) 102 | 103 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 104 | test_loss, correct, len(test_loader.dataset), 105 | 100. * correct / len(test_loader.dataset))) 106 | 107 | 108 | def main(): 109 | # Training settings 110 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 111 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N', 112 | help='input batch size for training (default: 64)') 113 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 114 | help='input batch size for testing (default: 1000)') 115 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 116 | help='number of epochs to train (default: 14)') 117 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 118 | help='learning rate (default: 1.0)') 119 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 120 | help='Learning rate step gamma (default: 0.7)') 121 | parser.add_argument('--no-cuda', action='store_true', default=False, 122 | help='disables CUDA training') 123 | parser.add_argument('--dry-run', action='store_true', default=False, 124 | help='quickly check a single pass') 125 | parser.add_argument('--seed', type=int, default=1, metavar='S', 126 | help='random seed (default: 1)') 127 | parser.add_argument('--log-interval', type=int, 
default=100, metavar='N', 128 | help='how many batches to wait before logging training status') 129 | parser.add_argument('--save-model', action='store_true', default=False, 130 | help='For Saving the current Model') 131 | args = parser.parse_args() 132 | use_cuda = not args.no_cuda and torch.cuda.is_available() 133 | 134 | torch.manual_seed(args.seed) 135 | 136 | device = torch.device("cuda" if use_cuda else "cpu") 137 | 138 | train_kwargs = {'batch_size': args.batch_size} 139 | test_kwargs = {'batch_size': args.test_batch_size} 140 | if use_cuda: 141 | cuda_kwargs = {'num_workers': 1, 142 | 'shuffle': True} 143 | train_kwargs.update(cuda_kwargs) 144 | test_kwargs.update(cuda_kwargs) 145 | 146 | transform=transforms.Compose([ 147 | transforms.ToTensor(), 148 | transforms.Normalize((0.1307,), (0.3081,)) 149 | ]) 150 | scriptPath = os.path.dirname(os.path.realpath(__file__)) 151 | dataDir = os.path.join(scriptPath, 'data') 152 | dataset1 = datasets.MNIST(dataDir, train=True, download=True, 153 | transform=transform) 154 | dataset2 = datasets.MNIST(dataDir, train=False, 155 | transform=transform) 156 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 157 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 158 | 159 | model = Net().to(device) 160 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 161 | 162 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 163 | 164 | prof = torch.profiler.profile( 165 | schedule=torch.profiler.schedule( wait=1, warmup=1, active=3, repeat=2), 166 | 167 | on_trace_ready=torch.profiler.tensorboard_trace_handler('../log/mnist'), 168 | record_shapes=True,with_stack=True) 169 | prof.start() 170 | 171 | for epoch in range(1, args.epochs + 1): 172 | train(args, model, device, train_loader, optimizer, epoch, prof) 173 | test(model, device, test_loader) 174 | 175 | scheduler.step() 176 | 177 | prof.stop() 178 | if args.save_model: 179 | torch.save(model.state_dict(), "mnist_cnn.pt") 180 
| 181 | 182 | if __name__ == '__main__': 183 | main() 184 | -------------------------------------------------------------------------------- /workspace/source_code/tb_main_opt1.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION 4 | # Copyright (c) 2017, 5 | # All rights reserved. 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of the copyright holder nor the names of its 18 | # contributors may be used to endorse or promote products derived from 19 | # this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | 32 | from __future__ import print_function 33 | import argparse 34 | import os 35 | import torch 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | import torch.optim as optim 39 | from torchvision import datasets, transforms 40 | from torch.optim.lr_scheduler import StepLR 41 | import torch.profiler 42 | 43 | 44 | class Net(nn.Module): 45 | def __init__(self): 46 | super(Net, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 48 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 49 | self.dropout1 = nn.Dropout(0.25) 50 | self.dropout2 = nn.Dropout(0.5) 51 | self.fc1 = nn.Linear(9216, 128) 52 | self.fc2 = nn.Linear(128, 10) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = F.relu(x) 57 | x = self.conv2(x) 58 | x = F.relu(x) 59 | x = F.max_pool2d(x, 2) 60 | x = self.dropout1(x) 61 | x = torch.flatten(x, 1) 62 | x = self.fc1(x) 63 | x = F.relu(x) 64 | x = self.dropout2(x) 65 | x = self.fc2(x) 66 | output = F.log_softmax(x, dim=1) 67 | return output 68 | 69 | 70 | def train(args, model, device, train_loader, optimizer, epoch,prof): 71 | model.train() 72 | for batch_idx, (data, target) in enumerate(train_loader): 73 | data, target = data.to(device), target.to(device) 74 | optimizer.zero_grad() 75 | output = model(data) 76 | 77 | loss = F.nll_loss(output, target) 78 | loss.backward() 79 | optimizer.step() 80 | prof.step() 81 | if batch_idx % args.log_interval == 0: 82 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 83 | epoch, batch_idx * len(data), len(train_loader.dataset), 84 | 100. 
* batch_idx / len(train_loader), loss.item())) 85 | if args.dry_run: 86 | break 87 | 88 | 89 | def test(model, device, test_loader): 90 | model.eval() 91 | test_loss = 0 92 | correct = 0 93 | with torch.no_grad(): 94 | for data, target in test_loader: 95 | data, target = data.to(device), target.to(device) 96 | output = model(data) 97 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 98 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 99 | correct += pred.eq(target.view_as(pred)).sum().item() 100 | 101 | test_loss /= len(test_loader.dataset) 102 | 103 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 104 | test_loss, correct, len(test_loader.dataset), 105 | 100. * correct / len(test_loader.dataset))) 106 | 107 | 108 | def main(): 109 | # Training settings 110 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 111 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N', 112 | help='input batch size for training (default: 64)') 113 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 114 | help='input batch size for testing (default: 1000)') 115 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 116 | help='number of epochs to train (default: 14)') 117 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 118 | help='learning rate (default: 1.0)') 119 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 120 | help='Learning rate step gamma (default: 0.7)') 121 | parser.add_argument('--no-cuda', action='store_true', default=False, 122 | help='disables CUDA training') 123 | parser.add_argument('--dry-run', action='store_true', default=False, 124 | help='quickly check a single pass') 125 | parser.add_argument('--seed', type=int, default=1, metavar='S', 126 | help='random seed (default: 1)') 127 | parser.add_argument('--log-interval', type=int, 
default=100, metavar='N', 128 | help='how many batches to wait before logging training status') 129 | parser.add_argument('--save-model', action='store_true', default=False, 130 | help='For Saving the current Model') 131 | args = parser.parse_args() 132 | use_cuda = not args.no_cuda and torch.cuda.is_available() 133 | 134 | torch.manual_seed(args.seed) 135 | 136 | device = torch.device("cuda" if use_cuda else "cpu") 137 | 138 | train_kwargs = {'batch_size': args.batch_size} 139 | test_kwargs = {'batch_size': args.test_batch_size} 140 | if use_cuda: 141 | cuda_kwargs = {'num_workers': 2, 142 | 'shuffle': True} 143 | train_kwargs.update(cuda_kwargs) 144 | test_kwargs.update(cuda_kwargs) 145 | 146 | transform=transforms.Compose([ 147 | transforms.ToTensor(), 148 | transforms.Normalize((0.1307,), (0.3081,)) 149 | ]) 150 | scriptPath = os.path.dirname(os.path.realpath(__file__)) 151 | dataDir = os.path.join(scriptPath, 'data') 152 | dataset1 = datasets.MNIST(dataDir, train=True, download=True, 153 | transform=transform) 154 | dataset2 = datasets.MNIST(dataDir, train=False, 155 | transform=transform) 156 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 157 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 158 | 159 | model = Net().to(device) 160 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 161 | 162 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 163 | prof = torch.profiler.profile( 164 | schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), 165 | on_trace_ready=torch.profiler.tensorboard_trace_handler('../log/mnist_1'), 166 | record_shapes=True, 167 | profile_memory=True, 168 | with_stack=True) 169 | prof.start() 170 | 171 | for epoch in range(1, args.epochs + 1): 172 | 173 | train(args, model, device, train_loader, optimizer, epoch, prof) 174 | 175 | test(model, device, test_loader) 176 | scheduler.step() 177 | 178 | prof.stop() 179 | if args.save_model: 180 | 
torch.save(model.state_dict(), "mnist_cnn.pt") 181 | 182 | 183 | if __name__ == '__main__': 184 | main() 185 | -------------------------------------------------------------------------------- /workspace/source_code/tb_main_opt2.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION 4 | # Copyright (c) 2017, 5 | # All rights reserved. 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of the copyright holder nor the names of its 18 | # contributors may be used to endorse or promote products derived from 19 | # this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | from __future__ import print_function 33 | import argparse 34 | import os 35 | import torch 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | import torch.optim as optim 39 | from torchvision import datasets, transforms 40 | from torch.optim.lr_scheduler import StepLR 41 | import torch.profiler 42 | 43 | 44 | class Net(nn.Module): 45 | def __init__(self): 46 | super(Net, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 48 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 49 | self.dropout1 = nn.Dropout(0.25) 50 | self.dropout2 = nn.Dropout(0.5) 51 | self.fc1 = nn.Linear(9216, 128) 52 | self.fc2 = nn.Linear(128, 10) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = F.relu(x) 57 | x = self.conv2(x) 58 | x = F.relu(x) 59 | x = F.max_pool2d(x, 2) 60 | x = self.dropout1(x) 61 | x = torch.flatten(x, 1) 62 | x = self.fc1(x) 63 | x = F.relu(x) 64 | x = self.dropout2(x) 65 | x = self.fc2(x) 66 | output = F.log_softmax(x, dim=1) 67 | return output 68 | 69 | 70 | def train(args, model, device, train_loader, optimizer, epoch,prof): 71 | model.train() 72 | for batch_idx, (data, target) in enumerate(train_loader): 73 | data, target = data.to(device), target.to(device) 74 | optimizer.zero_grad() 75 | output = model(data) 76 | 77 | loss = F.nll_loss(output, target) 78 | loss.backward() 79 | optimizer.step() 80 | prof.step() 81 | if batch_idx % args.log_interval == 0: 82 | print('Train Epoch: {} 
[{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 83 | epoch, batch_idx * len(data), len(train_loader.dataset), 84 | 100. * batch_idx / len(train_loader), loss.item())) 85 | if args.dry_run: 86 | break 87 | 88 | 89 | def test(model, device, test_loader): 90 | model.eval() 91 | test_loss = 0 92 | correct = 0 93 | with torch.no_grad(): 94 | for data, target in test_loader: 95 | data, target = data.to(device), target.to(device) 96 | output = model(data) 97 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 98 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 99 | correct += pred.eq(target.view_as(pred)).sum().item() 100 | 101 | test_loss /= len(test_loader.dataset) 102 | 103 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 104 | test_loss, correct, len(test_loader.dataset), 105 | 100. * correct / len(test_loader.dataset))) 106 | 107 | 108 | def main(): 109 | # Training settings 110 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 111 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N', 112 | help='input batch size for training (default: 64)') 113 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 114 | help='input batch size for testing (default: 1000)') 115 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 116 | help='number of epochs to train (default: 14)') 117 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 118 | help='learning rate (default: 1.0)') 119 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 120 | help='Learning rate step gamma (default: 0.7)') 121 | parser.add_argument('--no-cuda', action='store_true', default=False, 122 | help='disables CUDA training') 123 | parser.add_argument('--dry-run', action='store_true', default=False, 124 | help='quickly check a single pass') 125 | parser.add_argument('--seed', type=int, default=1, 
metavar='S', 126 | help='random seed (default: 1)') 127 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 128 | help='how many batches to wait before logging training status') 129 | parser.add_argument('--save-model', action='store_true', default=False, 130 | help='For Saving the current Model') 131 | args = parser.parse_args() 132 | use_cuda = not args.no_cuda and torch.cuda.is_available() 133 | 134 | torch.manual_seed(args.seed) 135 | 136 | device = torch.device("cuda" if use_cuda else "cpu") 137 | 138 | train_kwargs = {'batch_size': args.batch_size} 139 | test_kwargs = {'batch_size': args.test_batch_size} 140 | if use_cuda: 141 | cuda_kwargs = {'num_workers': 2, 142 | 'pin_memory': True, 143 | 'shuffle': True} 144 | train_kwargs.update(cuda_kwargs) 145 | test_kwargs.update(cuda_kwargs) 146 | 147 | transform=transforms.Compose([ 148 | transforms.ToTensor(), 149 | transforms.Normalize((0.1307,), (0.3081,)) 150 | ]) 151 | scriptPath = os.path.dirname(os.path.realpath(__file__)) 152 | dataDir = os.path.join(scriptPath, 'data') 153 | dataset1 = datasets.MNIST(dataDir, train=True, download=True, 154 | transform=transform) 155 | dataset2 = datasets.MNIST(dataDir, train=False, 156 | transform=transform) 157 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 158 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 159 | 160 | model = Net().to(device) 161 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 162 | 163 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 164 | prof = torch.profiler.profile( 165 | schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), 166 | on_trace_ready=torch.profiler.tensorboard_trace_handler('../log/mnist_2'), 167 | record_shapes=True, 168 | profile_memory=True, 169 | with_stack=True) 170 | prof.start() 171 | 172 | for epoch in range(1, args.epochs + 1): 173 | 174 | train(args, model, device, train_loader, optimizer, epoch, prof) 175 | 176 | 
test(model, device, test_loader) 177 | scheduler.step() 178 | 179 | prof.stop() 180 | if args.save_model: 181 | torch.save(model.state_dict(), "mnist_cnn.pt") 182 | 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /workspace/source_code/tb_main_opt3.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION 4 | # Copyright (c) 2017, 5 | # All rights reserved. 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # * Redistributions of source code must retain the above copyright notice, this 11 | # list of conditions and the following disclaimer. 12 | # 13 | # * Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # * Neither the name of the copyright holder nor the names of its 18 | # contributors may be used to endorse or promote products derived from 19 | # this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | from __future__ import print_function 33 | import argparse 34 | import os 35 | import torch 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | import torch.optim as optim 39 | from torchvision import datasets, transforms 40 | from torch.optim.lr_scheduler import StepLR 41 | import torch.profiler 42 | 43 | 44 | class Net(nn.Module): 45 | def __init__(self): 46 | super(Net, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 48 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 49 | self.dropout1 = nn.Dropout(0.25) 50 | self.dropout2 = nn.Dropout(0.5) 51 | self.fc1 = nn.Linear(9216, 128) 52 | self.fc2 = nn.Linear(128, 10) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = F.relu(x) 57 | x = self.conv2(x) 58 | x = F.relu(x) 59 | x = F.max_pool2d(x, 2) 60 | x = self.dropout1(x) 61 | x = torch.flatten(x, 1) 62 | x = self.fc1(x) 63 | x = F.relu(x) 64 | x = self.dropout2(x) 65 | x = self.fc2(x) 66 | output = F.log_softmax(x, dim=1) 67 | return output 68 | 69 | 70 | def train(args, model, device, train_loader, optimizer, epoch,prof): 71 | model.train() 72 | for batch_idx, (data, target) in enumerate(train_loader): 73 | data, target = data.to(device), target.to(device) 74 | optimizer.zero_grad() 75 | 76 | # Enables autocasting for the forward pass 77 | with torch.cuda.amp.autocast(enabled=True): 78 | output = model(data) 79 | loss = F.nll_loss(output, target) 80 | 81 | loss.backward() 82 | 
optimizer.step() 83 | prof.step() 84 | if batch_idx % args.log_interval == 0: 85 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 86 | epoch, batch_idx * len(data), len(train_loader.dataset), 87 | 100. * batch_idx / len(train_loader), loss.item())) 88 | if args.dry_run: 89 | break 90 | 91 | 92 | def test(model, device, test_loader): 93 | model.eval() 94 | test_loss = 0 95 | correct = 0 96 | with torch.no_grad(): 97 | for data, target in test_loader: 98 | data, target = data.to(device), target.to(device) 99 | output = model(data) 100 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 101 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 102 | correct += pred.eq(target.view_as(pred)).sum().item() 103 | 104 | test_loss /= len(test_loader.dataset) 105 | 106 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 107 | test_loss, correct, len(test_loader.dataset), 108 | 100. * correct / len(test_loader.dataset))) 109 | 110 | 111 | def main(): 112 | # Training settings 113 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 114 | parser.add_argument('--batch-size', type=int, default=1024, metavar='N', 115 | help='input batch size for training (default: 64)') 116 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 117 | help='input batch size for testing (default: 1000)') 118 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 119 | help='number of epochs to train (default: 14)') 120 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 121 | help='learning rate (default: 1.0)') 122 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 123 | help='Learning rate step gamma (default: 0.7)') 124 | parser.add_argument('--no-cuda', action='store_true', default=False, 125 | help='disables CUDA training') 126 | parser.add_argument('--dry-run', action='store_true', 
default=False, 127 | help='quickly check a single pass') 128 | parser.add_argument('--seed', type=int, default=1, metavar='S', 129 | help='random seed (default: 1)') 130 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 131 | help='how many batches to wait before logging training status') 132 | parser.add_argument('--save-model', action='store_true', default=False, 133 | help='For Saving the current Model') 134 | args = parser.parse_args() 135 | use_cuda = not args.no_cuda and torch.cuda.is_available() 136 | 137 | torch.manual_seed(args.seed) 138 | 139 | device = torch.device("cuda" if use_cuda else "cpu") 140 | 141 | train_kwargs = {'batch_size': args.batch_size} 142 | test_kwargs = {'batch_size': args.test_batch_size} 143 | if use_cuda: 144 | cuda_kwargs = {'num_workers': 2, 145 | 'pin_memory': True, 146 | 'shuffle': True} 147 | train_kwargs.update(cuda_kwargs) 148 | test_kwargs.update(cuda_kwargs) 149 | 150 | transform=transforms.Compose([ 151 | transforms.ToTensor(), 152 | transforms.Normalize((0.1307,), (0.3081,)) 153 | ]) 154 | scriptPath = os.path.dirname(os.path.realpath(__file__)) 155 | dataDir = os.path.join(scriptPath, 'data') 156 | dataset1 = datasets.MNIST(dataDir, train=True, download=True, 157 | transform=transform) 158 | dataset2 = datasets.MNIST(dataDir, train=False, 159 | transform=transform) 160 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 161 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 162 | 163 | model = Net().to(device) 164 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 165 | 166 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 167 | prof = torch.profiler.profile( 168 | schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), 169 | on_trace_ready=torch.profiler.tensorboard_trace_handler('../log/mnist_3'), 170 | record_shapes=True, 171 | profile_memory = True, 172 | with_stack=True) 173 | prof.start() 174 | 175 | for epoch in 
range(1, args.epochs + 1): 176 | train(args, model, device, train_loader, optimizer, epoch, prof) 177 | test(model, device, test_loader) 178 | 179 | scheduler.step() 180 | 181 | prof.stop() 182 | if args.save_model: 183 | torch.save(model.state_dict(), "mnist_cnn.pt") 184 | 185 | 186 | if __name__ == '__main__': 187 | main() 188 | -------------------------------------------------------------------------------- /workspace/start_here.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Optimizing a Deep Neural Network (DNN) Training Program\n", 8 | "---" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "This lab is presented in two parts: part 1 and 2. Part 1 discusses profiling using NVIDIA® Nsight™ Systems, while part 2 explains the application of PyTorch Profiler with TensorBoard plugin. Both parts focus on steps to optimize a deep neural network (DNN) training program that detects handwritten digits using a PyTorch Modified National Institute of Standards and Technology (MNIST) dataset. 
The techniques and strategies discussed in this lab will translate to optimizing any application that uses NVIDIA's graphics processing units (GPUs).\n", 16 | "\n", 17 | "In this lab, you will learn how to do the following:\n", 18 | "- Run the sample application,\n", 19 | "- Use NVIDIA Nsight Systems to profile the application,\n", 20 | "- Use PyTorch Profiler to profile the application and visualize it on TensorBoard,\n", 21 | "- Interpret the timeline provided by NVIDIA Nsight Systems and understand the application's use of the system resources,\n", 22 | "- Use the TensorBoard execution summary, step time breakdown, and performance recommendations to understand the application's use of the system resources,\n", 23 | "- Identify performance problems in the application, apply optimization strategies, and confirm the performance improvement gained from the optimizations.\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "Below is the agenda to get us started with the optimization process of a simple DNN training program.\n", 31 | "\n", 32 | "## Table of Contents\n", 33 | "1. Part 1 (Profiling with NVIDIA Nsight Systems) [**required for evaluation**]\n", 34 | " 1. [Start the NVIDIA Nsight Systems lab](jupyter_notebook/01_introduction.ipynb)\n", 35 | " 1. [PyTorch MNIST and Optimization Workflow](jupyter_notebook/02_pytorch_mnist.ipynb)\n", 36 | " 1. [Data Transfers between Host and GPU](jupyter_notebook/03_data_transfer.ipynb)\n", 37 | " 1. [Tensor Core](jupyter_notebook/04_tensor_core_util.ipynb)\n", 38 | " 1. [Summary](jupyter_notebook/05_summary.ipynb)\n", 39 | "1. Part 2 (PyTorch Profiler with TensorBoard) [**optional**]\n", 40 | " 1. [Start PyTorch Profiler with TensorBoard plugin](jupyter_notebook/tb01_introduction.ipynb)\n", 41 | " 1. [PyTorch MNIST Optimization from TensorBoard Visualization](jupyter_notebook/tb02_pytorch_mnist.ipynb)\n", 42 | " 1. 
[Memory Operations](jupyter_notebook/tb03_data_transfer.ipynb)\n", 43 | " 1. [Tensor Core](jupyter_notebook/tb04_tensor_core_util.ipynb)\n", 44 | " 1. [Summary](jupyter_notebook/tb05_summary.ipynb)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "Execute the `nvidia-smi` command in the cell below to display information about the NVIDIA CUDA® driver and the GPUs available on the server. To run the cell, click on it with your mouse and press `Ctrl+Enter`, or press the play button in the toolbar above. If the command executes successfully, you should see its output below the grey cell." 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "!nvidia-smi" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "### Tutorial Duration\n", 68 | "The lab material will be presented in a 2-hour session.\n", 69 | "\n", 70 | "### Content Level\n", 71 | "Beginner, Intermediate\n", 72 | "\n", 73 | "### Target Audience and Prerequisites\n", 74 | "The target audience for this lab is prospective mentors who plan to mentor at artificial intelligence (AI)-based hackathons. Background knowledge of Python programming is expected."
75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## Links and Resources\n", 82 | "\n", 83 | "\n", 84 | "[NVIDIA Nsight Systems](https://docs.nvidia.com/nsight-systems/)\n", 85 | "\n", 86 | "\n", 87 | "**NOTE**: To be able to see the profiler output, please download the latest version of NVIDIA Nsight Systems from [here](https://developer.nvidia.com/nsight-systems).\n", 88 | "\n", 89 | "\n", 90 | "You can also get resources from [Open Hackathons technical resource page](https://www.openhackathons.org/s/technical-resources)\n", 91 | "\n", 92 | "\n", 93 | "--- \n", 94 | "\n", 95 | "## Licensing \n", 96 | "\n", 97 | "Copyright © 2022 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials may include references to hardware and software developed by other entities; all applicable licensing and copyrights apply.\n" 98 | ] 99 | } 100 | ], 101 | "metadata": { 102 | "kernelspec": { 103 | "display_name": "Python 3", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.8.8" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 4 122 | } 123 | --------------------------------------------------------------------------------