├── .gitignore ├── README.md ├── docker ├── Dockerfile.gpu ├── Dockerfile.habana ├── Dockerfile.xlagpu ├── README.md ├── start_gpu_pytorch2 ├── start_habana_pytorch2 └── start_xla_pytorch2 ├── pytorch-compile-blogpost └── torch-compile-under-the-hood.ipynb ├── pytorch-graph-optimization ├── benchmark_torch-compile_resnet.ipynb ├── graph_optimization_torch_compile.ipynb └── inspecting_torch_compile.ipynb └── pytorch-intro-torch-compile ├── 1-toy-benchmarks.ipynb ├── 2-torch-compile-intro.ipynb ├── 3-inspecting-compiler-stack.ipynb └── 4-nn-example.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | torch_compile_debug/ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-examples 2 | A repository of PyTorch examples 3 | -------------------------------------------------------------------------------- /docker/Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-base-ubuntu22.04 2 | 3 | # Remove any third-party apt sources to avoid issues with expiring keys. 4 | RUN rm -f /etc/apt/sources.list.d/*.list 5 | 6 | # Install some basic utilities. 7 | RUN apt-get update && apt-get install -y \ 8 | build-essential \ 9 | graphviz \ 10 | autoconf \ 11 | automake \ 12 | gdb \ 13 | libffi-dev \ 14 | zlib1g-dev \ 15 | libssl-dev \ 16 | libsndfile1 \ 17 | curl \ 18 | wget \ 19 | vim \ 20 | ca-certificates \ 21 | git \ 22 | bzip2 \ 23 | libx11-6 \ 24 | && rm -rf /var/lib/apt/lists/* 25 | 26 | # Download and install Miniconda. 
27 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh 28 | RUN bash /tmp/miniconda.sh -b -p /opt/conda 29 | ENV PATH=/opt/conda/bin:$PATH 30 | RUN conda init 31 | RUN pip install \ 32 | torch \ 33 | torchvision \ 34 | torchaudio \ 35 | jupyterlab \ 36 | ipykernel \ 37 | matplotlib \ 38 | ipywidgets \ 39 | pydot \ 40 | huggingface \ 41 | transformers \ 42 | datasets \ 43 | diffusers["torch"] \ 44 | accelerate \ 45 | soundfile \ 46 | librosa \ 47 | torchdiffeq 48 | RUN pip install \ 49 | xgboost \ 50 | scikit-learn \ 51 | tabulate \ 52 | bokeh 53 | 54 | RUN jupyter labextension disable "@jupyterlab/apputils-extension:announcements" 55 | #ENTRYPOINT ["jupyter", "lab", "--ip='*'", "--NotebookApp.token=''", "--NotebookApp.password=''","--allow-root"] 56 | ENTRYPOINT ["jupyter", "lab", "--ip='*'", "--allow-root"] 57 | 58 | 59 | -------------------------------------------------------------------------------- /docker/Dockerfile.habana: -------------------------------------------------------------------------------- 1 | FROM vault.habana.ai/gaudi-docker/1.8.0/ubuntu20.04/habanalabs/pytorch-installer-1.13.1:latest 2 | # Refresh the package index and install in the same layer so a cached, stale index is never reused; 3 | # the apt lists are removed afterwards to keep the layer small. 4 | RUN apt-get update && apt-get install -y \ 5 | autoconf \ 6 | automake \ 7 | build-essential \ 8 | gdb \ 9 | git \ 10 | libffi-dev \ 11 | libsndfile1 \ 12 | libssl-dev \ 13 | zlib1g-dev \ 14 | && rm -rf /var/lib/apt/lists/* 15 | RUN pip install jupyterlab ipykernel matplotlib ipywidgets 16 | RUN pip install huggingface transformers datasets 17 | RUN pip install diffusers["torch"] accelerate soundfile librosa 18 | WORKDIR /pytorch-habana 19 | # JupyterLab is launched with an empty token and password: only expose this container on a trusted network. 20 | ENTRYPOINT ["jupyter", "lab", "--ip='*'", "--NotebookApp.token=''", "--NotebookApp.password=''","--allow-root"] 21 | -------------------------------------------------------------------------------- /docker/Dockerfile.xlagpu: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-base-ubuntu22.04 2 | 3 | # Remove any third-party apt sources to avoid issues with expiring keys. 
4 | RUN rm -f /etc/apt/sources.list.d/*.list 5 | 6 | # Install some basic utilities. 7 | RUN apt-get update && apt-get install -y \ 8 | build-essential \ 9 | graphviz \ 10 | autoconf \ 11 | automake \ 12 | gdb \ 13 | libffi-dev \ 14 | zlib1g-dev \ 15 | libssl-dev \ 16 | libsndfile1 \ 17 | curl \ 18 | wget \ 19 | vim \ 20 | ca-certificates \ 21 | git \ 22 | bzip2 \ 23 | libx11-6 \ 24 | && rm -rf /var/lib/apt/lists/* 25 | 26 | # Download and install Miniconda. 27 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py38_23.1.0-1-Linux-x86_64.sh -O /tmp/miniconda.sh 28 | RUN bash /tmp/miniconda.sh -b -p /opt/conda 29 | ENV PATH=/opt/conda/bin:$PATH 30 | 31 | # Make RUN commands use the new environment: 32 | RUN pip install \ 33 | cloud-tpu-client==0.10 \ 34 | torch==2.0.0 \ 35 | torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-2.0-cp38-cp38-linux_x86_64.whl 36 | RUN pip install \ 37 | jupyterlab \ 38 | ipykernel \ 39 | matplotlib \ 40 | ipywidgets \ 41 | pydot \ 42 | huggingface \ 43 | transformers \ 44 | datasets \ 45 | diffusers["torch"] \ 46 | accelerate \ 47 | soundfile \ 48 | librosa \ 49 | torchdiffeq 50 | RUN pip install \ 51 | xgboost \ 52 | scikit-learn \ 53 | bokeh 54 | 55 | RUN jupyter labextension disable "@jupyterlab/apputils-extension:announcements" 56 | ENTRYPOINT ["jupyter", "lab", "--ip='*'", "--NotebookApp.token=''", "--NotebookApp.password=''","--allow-root"] 57 | 58 | 59 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | ### Dockerfiles for environments used in examples -------------------------------------------------------------------------------- /docker/start_gpu_pytorch2: -------------------------------------------------------------------------------- 1 | [ ! 
-z "$(docker ps -a -q)" ] && docker rm $(docker ps -a -q) 2 | callpath=$PWD 3 | cd $(dirname "$0") 4 | docker build -t pytorch2:gpu -f Dockerfile.gpu . 5 | cd $callpath 6 | [ ! -z "$(docker images -f dangling=true -q)" ] && docker rmi $(docker images -f dangling=true -q) 7 | docker run -it --rm \ 8 | --network=host \ 9 | --gpus all \ 10 | -v $PWD/:/workspace \ 11 | -v ~/.cache/:/root/.cache \ 12 | --workdir /workspace \ 13 | --name pytorch \ 14 | pytorch2:gpu 15 | -------------------------------------------------------------------------------- /docker/start_habana_pytorch2: -------------------------------------------------------------------------------- 1 | docker build -t pytorch-habana:latest -f docker/Dockerfile . 2 | docker rmi $(docker images -f dangling=true -q) 3 | docker run -it --network=host -v $PWD/:/pytorch-habana --workdir /pytorch-habana pytorch-habana:latest 4 | -------------------------------------------------------------------------------- /docker/start_xla_pytorch2: -------------------------------------------------------------------------------- 1 | [ ! -z "$(docker ps -a -q)" ] && docker rm $(docker ps -a -q) 2 | docker build -t pytorch2:xla -f docker/Dockerfile.xlagpu . 3 | [ ! 
-z "$(docker images -f dangling=true -q)" ] && docker rmi $(docker images -f dangling=true -q) 4 | docker run -it --rm \ 5 | --network=host \ 6 | --gpus all \ 7 | -v $PWD/:/pytorch-examples \ 8 | -v ~/.cache/:/root/.cache \ 9 | --workdir /pytorch-examples \ 10 | pytorch2:xla 11 | -------------------------------------------------------------------------------- /pytorch-compile-blogpost/torch-compile-under-the-hood.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c0d1d55d-42e3-4445-b6c1-dfd95f7999a7", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!rm -rf torch_compile_debug\n", 13 | "!rm *.svg" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "79b09805-e61c-4be4-aaf0-2dc148b0b736", 20 | "metadata": { 21 | "tags": [] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import torch\n", 26 | "import math\n", 27 | "import os\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import torch._dynamo\n", 30 | "from torchvision import models\n", 31 | "from torch.fx.passes.graph_drawer import FxGraphDrawer\n", 32 | "from IPython.display import Markdown as md\n", 33 | "\n", 34 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\"" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "a51c7031-33d1-44ee-8398-3137b5f7d085", 41 | "metadata": { 42 | "tags": [] 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "def f(x):\n", 47 | " return torch.sin(x)**2 + torch.cos(x)**2\n", 48 | "\n", 49 | "md('''\n", 50 | "# $ y = f(x) = sin^2(x) + cos^2(x)$\n", 51 | "''')" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "82568cbb-750b-4ed9-baf2-1cf2c6028c0d", 58 | "metadata": { 59 | "tags": [] 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "md('''\n", 64 | "## Optimization problem:\n", 65 | 
"### find $w^*$ that $\\displaystyle \\min_{w} f(w)$\n", 66 | "\n", 67 | "## Gradient update\n", 68 | "### $ w_{i+1} = w_{i} - g(\\\\nabla f(w))$\n", 69 | "## For SGD\n", 70 | "### $ g(\\\\nabla f(w)) = \\\\alpha*\\\\nabla f(w)$\n", 71 | "## Which makes the update for SGD:\n", 72 | "### $ w_{i+1} = w_{i} - \\\\alpha*\\\\nabla f(w)$\n", 73 | "\n", 74 | "## Loss function\n", 75 | "### $loss(w): loss(model(w,batch_{inputs}), batch_{outputs})$ \n", 76 | "''')" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "4b4c1717-2fbe-43f6-86ad-1f12d397d33d", 83 | "metadata": { 84 | "tags": [] 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "md('''\n", 89 | "# **Forward graph:** $f(x) = sin^2(x)+cos^2(x)$ \\n\n", 90 | "# **Backward graph:** $\\\\frac {df(x)}{d\\\\vec{w}} = f\\'(x) = 2sin(x)cos(x) + 2cos(x)(-sin(x))$\n", 91 | "''')" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "0ba8ff94-c457-41fd-b7ff-91ac36872bee", 98 | "metadata": { 99 | "tags": [] 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "torch.manual_seed(0)\n", 104 | "x = torch.rand(1000, requires_grad=True).to(device)\n", 105 | "torch.nn.functional.mse_loss(f(x),torch.ones_like(x)) < 1e-10" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "f5f80ceb-7e1a-4140-a17c-b6827e6e7461", 112 | "metadata": { 113 | "tags": [] 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "torch.manual_seed(0)\n", 118 | "x = torch.rand(1000, requires_grad=True).to(device)\n", 119 | "\n", 120 | "compiled_f = torch.compile(f)\n", 121 | "torch.nn.functional.mse_loss(compiled_f(x),torch.ones_like(x)) < 1e-10" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "b3a9cf8b-37ef-4b00-9c1b-d21e9c0ed6af", 128 | "metadata": { 129 | "tags": [] 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "def inspect_backend(gm, sample_inputs):\n", 134 | " code = 
gm.print_readable()\n", 135 | " with open(\"forward.svg\", \"wb\") as file:\n", 136 | " file.write(FxGraphDrawer(gm,'f').get_dot_graph().create_svg())\n", 137 | " return gm.forward\n", 138 | "\n", 139 | "torch._dynamo.reset()\n", 140 | "compiled_f = torch.compile(f, backend=inspect_backend)\n", 141 | "\n", 142 | "x = torch.rand(1000, requires_grad=True).to(device)\n", 143 | "out = compiled_f(x)\n", 144 | "\n", 145 | "md(f'''\n", 146 | "### Graph\n", 147 | "![]({'forward.svg'})\n", 148 | "''')" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "7ec2518d-1024-4aeb-a991-67de9dc51090", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "id": "619928ed-ed3f-45a4-8ef0-783c70bafaed", 162 | "metadata": {}, 163 | "source": [ 164 | "# AOTAutograd and Aten IR" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "id": "b900c2c1-b83e-44ba-8eb2-c3d71c3bd53e", 171 | "metadata": { 172 | "tags": [] 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "import torch._dynamo\n", 177 | "from torch.fx.passes.graph_drawer import FxGraphDrawer\n", 178 | "from functorch.compile import make_boxed_func\n", 179 | "from torch._functorch.aot_autograd import aot_module_simplified\n", 180 | "\n", 181 | "def f(x):\n", 182 | " return torch.sin(x)**2 + torch.cos(x)**2\n", 183 | "\n", 184 | "def inspect_backend(gm, sample_inputs): \n", 185 | " # Forward compiler capture\n", 186 | " def fw(gm, sample_inputs):\n", 187 | " gm.print_readable()\n", 188 | " g = FxGraphDrawer(gm, 'fn')\n", 189 | " with open(\"forward_aot.svg\", \"wb\") as file:\n", 190 | " file.write(g.get_dot_graph().create_svg())\n", 191 | " return make_boxed_func(gm.forward)\n", 192 | " \n", 193 | " # Backward compiler capture\n", 194 | " def bw(gm, sample_inputs):\n", 195 | " gm.print_readable()\n", 196 | " g = FxGraphDrawer(gm, 'fn')\n", 197 | " with open(\"backward_aot.svg\", \"wb\") as 
file:\n", 198 | " file.write(g.get_dot_graph().create_svg())\n", 199 | " return make_boxed_func(gm.forward)\n", 200 | " \n", 201 | " # Call AOTAutograd\n", 202 | " gm_forward = aot_module_simplified(gm,sample_inputs,\n", 203 | " fw_compiler=fw,\n", 204 | " bw_compiler=bw)\n", 205 | "\n", 206 | " return gm_forward\n", 207 | "\n", 208 | "torch.manual_seed(0)\n", 209 | "x = torch.rand(1000, requires_grad=True).to(device)\n", 210 | "y = torch.ones_like(x)\n", 211 | "\n", 212 | "torch._dynamo.reset()\n", 213 | "compiled_f = torch.compile(f, backend=inspect_backend)\n", 214 | "out = torch.nn.functional.mse_loss(compiled_f(x), y).backward()" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "264ad8cc-6fc2-4e58-98c7-3dcff496ab59", 221 | "metadata": { 222 | "tags": [] 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "md(f'''\n", 227 | "|![]({'forward_aot.svg'}) | < Forward graph


Backward graph >|![]({'backward_aot.svg'})|\n", 228 | "|---|---|---|\n", 229 | "''')" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "id": "a5a4f76f-1461-43d8-a8d4-1d24a7f6d1a8", 235 | "metadata": {}, 236 | "source": [ 237 | "# Decomposition to Core Aten IR" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "bf9f0343-65cc-4d7e-8c53-55292b46c87d", 244 | "metadata": { 245 | "tags": [] 246 | }, 247 | "outputs": [], 248 | "source": [ 249 | "import torch._dynamo\n", 250 | "from torch.fx.passes.graph_drawer import FxGraphDrawer\n", 251 | "from functorch.compile import make_boxed_func\n", 252 | "from torch._functorch.aot_autograd import aot_module_simplified\n", 253 | "from torch._decomp import core_aten_decompositions\n", 254 | "\n", 255 | "def f_loss(x, y):\n", 256 | " f_x = torch.sin(x)**2 + torch.cos(x)**2\n", 257 | " return torch.nn.functional.mse_loss(f_x, y)\n", 258 | "\n", 259 | "# decompositions = core_aten_decompositions() # Use decomposition to Core Aten IR\n", 260 | "decompositions = {} # Don't use decomposition to Core Aten IR\n", 261 | "\n", 262 | "def inspect_backend(gm, sample_inputs): \n", 263 | " def fw(gm, sample_inputs):\n", 264 | " gm.print_readable()\n", 265 | " g = FxGraphDrawer(gm, 'fn')\n", 266 | " with open(\"forward_decomp.svg\", \"wb\") as file:\n", 267 | " file.write(g.get_dot_graph().create_svg())\n", 268 | " return make_boxed_func(gm.forward)\n", 269 | " \n", 270 | " def bw(gm, sample_inputs):\n", 271 | " gm.print_readable()\n", 272 | " g = FxGraphDrawer(gm, 'fn')\n", 273 | " with open(\"backward_decomp.svg\", \"wb\") as file:\n", 274 | " file.write(g.get_dot_graph().create_svg())\n", 275 | " return make_boxed_func(gm.forward)\n", 276 | "\n", 277 | " # Invoke AOTAutograd\n", 278 | " return aot_module_simplified(\n", 279 | " gm,\n", 280 | " sample_inputs,\n", 281 | " fw_compiler=fw,\n", 282 | " bw_compiler=bw,\n", 283 | " decompositions=decompositions\n", 284 | " )\n", 285 | "\n", 
286 | "torch.manual_seed(0)\n", 287 | "x = torch.rand(1000, requires_grad=True).to(device)\n", 288 | "y = torch.ones_like(x)\n", 289 | "\n", 290 | "torch._dynamo.reset()\n", 291 | "compiled_f = torch.compile(f_loss, backend=inspect_backend)\n", 292 | "out = compiled_f(x,y).backward()\n", 293 | "\n", 294 | "\n", 295 | "md('''\n", 296 | "# $MSE = (\\\\frac{1}{n})(\\\\vec{y}-\\\\vec{x})^2$\n", 297 | "\n", 298 | "''')" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "id": "2eb86a77-d050-4f80-96b0-c23ac12bd1ee", 305 | "metadata": { 306 | "tags": [] 307 | }, 308 | "outputs": [], 309 | "source": [ 310 | "md(f'''\n", 311 | "|![]({'forward_decomp.svg'}) | < Forward graph


Backward graph >|![]({'backward_decomp.svg'})|\n", 312 | "|---|---|---|\n", 313 | "''')" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "id": "20a397ac-69a2-4cbe-af0d-9d3c50e00ece", 319 | "metadata": {}, 320 | "source": [ 321 | "# Decomposition to prim IR" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "id": "19e71d40-6815-4167-862e-50253913e168", 328 | "metadata": { 329 | "tags": [] 330 | }, 331 | "outputs": [], 332 | "source": [ 333 | "import torch._dynamo\n", 334 | "from torch.fx.passes.graph_drawer import FxGraphDrawer\n", 335 | "from functorch.compile import make_boxed_func\n", 336 | "from torch._functorch.aot_autograd import aot_module_simplified\n", 337 | "from torch._decomp import core_aten_decompositions\n", 338 | "\n", 339 | "def f_loss(x, y):\n", 340 | " f_x = torch.sin(x)**2 + torch.cos(x)**2\n", 341 | " return torch.nn.functional.mse_loss(f_x, y)\n", 342 | "\n", 343 | "decompositions = core_aten_decompositions()\n", 344 | "decompositions.update(\n", 345 | " torch._decomp.get_decompositions([\n", 346 | " torch.ops.aten.sin,\n", 347 | " torch.ops.aten.cos,\n", 348 | " torch.ops.aten.add,\n", 349 | " torch.ops.aten.sub,\n", 350 | " torch.ops.aten.mul,\n", 351 | " torch.ops.aten.sum,\n", 352 | " torch.ops.aten.mean,\n", 353 | " torch.ops.aten.pow.Tensor_Scalar,\n", 354 | " ])\n", 355 | ")\n", 356 | "\n", 357 | "def inspect_backend(gm, sample_inputs): \n", 358 | " def fw(gm, sample_inputs):\n", 359 | " gm.print_readable()\n", 360 | " g = FxGraphDrawer(gm, 'fn')\n", 361 | " with open(\"forward_decomp_prims.svg\", \"wb\") as f:\n", 362 | " f.write(g.get_dot_graph().create_svg())\n", 363 | " return make_boxed_func(gm.forward)\n", 364 | " \n", 365 | " def bw(gm, sample_inputs):\n", 366 | " gm.print_readable()\n", 367 | " g = FxGraphDrawer(gm, 'fn')\n", 368 | " with open(\"backward_decomp_prims.svg\", \"wb\") as f:\n", 369 | " f.write(g.get_dot_graph().create_svg())\n", 370 | " return 
make_boxed_func(gm.forward)\n", 371 | "\n", 372 | " # Invoke AOTAutograd\n", 373 | " return aot_module_simplified(\n", 374 | " gm,\n", 375 | " sample_inputs,\n", 376 | " fw_compiler=fw,\n", 377 | " bw_compiler=bw,\n", 378 | " decompositions=decompositions\n", 379 | " )\n", 380 | "\n", 381 | "torch.manual_seed(0)\n", 382 | "x = torch.rand(1000, requires_grad=True).to(device)\n", 383 | "y = torch.ones_like(x)\n", 384 | "\n", 385 | "torch._dynamo.reset()\n", 386 | "compiled_f = torch.compile(f_loss, backend=inspect_backend)\n", 387 | "out = compiled_f(x,y).backward()\n" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "id": "b6b79bad-edc7-413d-94be-84ba286841b5", 394 | "metadata": { 395 | "tags": [] 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "md(f'''\n", 400 | "|![]({'forward_decomp_prims.svg'}) | < Forward graph


Backward graph >|![]({'backward_decomp_prims.svg'})|\n", 401 | "|---|---|---|\n", 402 | "''')" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "id": "91ff887d-46de-48d0-85a6-f1888c04e136", 409 | "metadata": { 410 | "tags": [] 411 | }, 412 | "outputs": [], 413 | "source": [ 414 | "def f(x):\n", 415 | " return torch.sin(x)**2 + torch.cos(x)**2 \n", 416 | "\n", 417 | "torch._dynamo.reset()\n", 418 | "compiled_f = torch.compile(f, backend='inductor',\n", 419 | " options={'trace.enabled':True,\n", 420 | " 'trace.graph_diagram':True})\n", 421 | "\n", 422 | "\n", 423 | "# device = 'cpu'\n", 424 | "device = 'cuda'\n", 425 | "\n", 426 | "torch.manual_seed(0)\n", 427 | "x = torch.rand(1000, requires_grad=True).to(device)\n", 428 | "y = torch.ones_like(x)\n", 429 | "\n", 430 | "out = torch.nn.functional.mse_loss(compiled_f(x),y).backward()" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "id": "a06362e9-7783-46bf-8498-bce83bc749ef", 437 | "metadata": { 438 | "tags": [] 439 | }, 440 | "outputs": [], 441 | "source": [ 442 | "import glob\n", 443 | "fwd = glob.glob('torch_compile_debug/run_*/aot_torchinductor/*forward*/graph_diagram.svg')[-1]\n", 444 | "bwd = glob.glob('torch_compile_debug/run_*/aot_torchinductor/*backward*/graph_diagram.svg')[-1]\n", 445 | "\n", 446 | "md(f'''\n", 447 | "|![]({fwd}) | < Forward graph


Backward graph >|![]({bwd})|\n", 448 | "|---|---|---|\n", 449 | "''')" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "id": "fb68e773-dce1-431e-8070-e2825d255bb1", 456 | "metadata": {}, 457 | "outputs": [], 458 | "source": [] 459 | } 460 | ], 461 | "metadata": { 462 | "kernelspec": { 463 | "display_name": "Python 3 (ipykernel)", 464 | "language": "python", 465 | "name": "python3" 466 | }, 467 | "language_info": { 468 | "codemirror_mode": { 469 | "name": "ipython", 470 | "version": 3 471 | }, 472 | "file_extension": ".py", 473 | "mimetype": "text/x-python", 474 | "name": "python", 475 | "nbconvert_exporter": "python", 476 | "pygments_lexer": "ipython3", 477 | "version": "3.10.9" 478 | } 479 | }, 480 | "nbformat": 4, 481 | "nbformat_minor": 5 482 | } 483 | -------------------------------------------------------------------------------- /pytorch-graph-optimization/benchmark_torch-compile_resnet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 19, 6 | "id": "b9fe79ed-0970-4b13-a6ac-4fc4f3c2e628", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import torch.utils.benchmark as benchmark\n", 13 | "import torch\n", 14 | "from torchvision.models import resnet\n", 15 | "import torch._dynamo\n", 16 | "\n", 17 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\"" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 20, 23 | "id": "3bbdc444-058e-4439-9ee7-121fd962af1a", 24 | "metadata": { 25 | "tags": [] 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "def run_batch_inference(model, batch=1):\n", 30 | " x = torch.randn(batch, 3, 224, 224).to(device)\n", 31 | " model(x)\n", 32 | "\n", 33 | "def run_batch_train(model, optimizer, batch=16):\n", 34 | " x = torch.randn(batch, 3, 224, 224).to(device)\n", 35 | " optimizer.zero_grad()\n", 36 | " out = 
model(x)\n", 37 | " out.sum().backward()\n", 38 | " optimizer.step()\n", 39 | " \n", 40 | "model = resnet.resnet18(weights=resnet.ResNet18_Weights.IMAGENET1K_V1).to(device)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 22, 46 | "id": "7dfab6b2-57da-4bad-a8f1-af5df0c9e49f", 47 | "metadata": { 48 | "tags": [] 49 | }, 50 | "outputs": [ 51 | { 52 | "name": "stderr", 53 | "output_type": "stream", 54 | "text": [ 55 | "[2023-03-21 07:56:49,304] torch._inductor.debug: [WARNING] model__15_forward_25 debug trace: /pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_49_09_212386-pid_25659/aot_torchinductor/model__15_forward_25.10\n" 56 | ] 57 | }, 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Inference speedup: 1.96%\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "batch = 1\n", 68 | "torch._dynamo.reset()\n", 69 | "compiled_model = torch.compile(model, options={'triton.cudagraphs': False,\n", 70 | " 'trace.enabled':True})\n", 71 | "\n", 72 | "t_model = benchmark.Timer(\n", 73 | " stmt='run_batch_inference(model, batch)',\n", 74 | " setup='from __main__ import run_batch_inference',\n", 75 | " globals={'model': model, 'batch':batch})\n", 76 | "\n", 77 | "t_compiled_model = benchmark.Timer(\n", 78 | " stmt='run_batch_inference(model, batch)',\n", 79 | " setup='from __main__ import run_batch_inference',\n", 80 | " globals={'model': compiled_model, 'batch':batch})\n", 81 | "\n", 82 | "t_model_runs = t_model.timeit(100)\n", 83 | "t_compiled_model_runs = t_compiled_model.timeit(100)\n", 84 | "\n", 85 | "print(f\"Inference speedup: {100*(t_model_runs.mean - t_compiled_model_runs.mean) / t_model_runs.mean: .2f}%\")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "id": "bdf99896-0fe4-4b65-a1ed-13296a81c221", 92 | "metadata": { 93 | "tags": [] 94 | }, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "\n", 101 | 
"run_batch_train(model, optimizer, batch)\n", 102 | "setup: from __main__ import run_batch_train\n", 103 | " 37.30 ms\n", 104 | " 1 measurement, 100 runs , 1 thread\n", 105 | "\n", 106 | "run_batch_train(model, optimizer, batch)\n", 107 | "setup: from __main__ import run_batch_train\n", 108 | " 34.57 ms\n", 109 | " 1 measurement, 100 runs , 1 thread\n", 110 | "Training speedup: 7.32%\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "batch = 32\n", 116 | "torch._dynamo.reset()\n", 117 | "compiled_model = torch.compile(model)\n", 118 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", 119 | "\n", 120 | "t_model = benchmark.Timer(\n", 121 | " stmt='run_batch_train(model, optimizer, batch)',\n", 122 | " setup='from __main__ import run_batch_train',\n", 123 | " globals={'model': model,'optimizer':optimizer, 'batch':batch})\n", 124 | "\n", 125 | "t_compiled_model = benchmark.Timer(\n", 126 | " stmt='run_batch_train(model, optimizer, batch)',\n", 127 | " setup='from __main__ import run_batch_train',\n", 128 | " globals={'model': compiled_model, 'optimizer':optimizer, 'batch':batch})\n", 129 | "\n", 130 | "t_model_runs = t_model.timeit(100)\n", 131 | "t_compiled_model_runs = t_compiled_model.timeit(100)\n", 132 | "\n", 133 | "print(t_model_runs)\n", 134 | "print(t_compiled_model_runs)\n", 135 | "\n", 136 | "print(f\"Training speedup: {100*(t_model_runs.mean - t_compiled_model_runs.mean) / t_model_runs.mean: .2f}%\")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "bf10e32d-b29f-4df1-af18-6dadcccc9a98", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3 (ipykernel)", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "codemirror_mode": { 156 | "name": "ipython", 157 | "version": 3 158 | }, 159 | "file_extension": ".py", 160 | "mimetype": "text/x-python", 161 | "name": 
"python", 162 | "nbconvert_exporter": "python", 163 | "pygments_lexer": "ipython3", 164 | "version": "3.10.9" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 5 169 | } 170 | -------------------------------------------------------------------------------- /pytorch-graph-optimization/graph_optimization_torch_compile.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "afb2eda5-99df-42b0-bbd3-f711f4d086b0", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!rm -rf torch_compile_debug" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "8f6b7210-882a-4c36-8060-4f324a85cccd", 19 | "metadata": { 20 | "tags": [] 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import torch\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "from matplotlib import cm\n", 27 | "from matplotlib.ticker import LinearLocator\n", 28 | "import numpy as np\n", 29 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\"" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "1cd9c0e4-bdc7-46d6-b0c4-d545ff5d2a29", 36 | "metadata": { 37 | "tags": [] 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "def fn(x,y):\n", 42 | " # x*exp(−(x^2+y^2))+(x^2+y^2)/20\n", 43 | " return x*torch.exp(-x**2-y**2) + (x**2+y**2)/20 " 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 7, 49 | "id": "96b96737-a9e3-4c0a-ac13-502102cad039", 50 | "metadata": { 51 | "tags": [] 52 | }, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "image/png": "", 57 | "text/plain": [ 58 | "
" 59 | ] 60 | }, 61 | "metadata": {}, 62 | "output_type": "display_data" 63 | } 64 | ], 65 | "source": [ 66 | "# Make data.\n", 67 | "x = torch.arange(-2, 2, 0.2,requires_grad=True)\n", 68 | "y = torch.arange(-2, 2, 0.2,requires_grad=True)\n", 69 | "x, y = torch.meshgrid(x, y, indexing='ij')\n", 70 | "r = fn(x,y)\n", 71 | "z = torch.sin(r)\n", 72 | "\n", 73 | "fig, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n", 74 | "\n", 75 | "# Plot the surface.\n", 76 | "surf = ax.plot_surface(x.detach().numpy(), y.detach().numpy(), z.detach().numpy(), cmap=cm.coolwarm,\n", 77 | " linewidth=0, antialiased=False)\n", 78 | "# Customize the z axis.\n", 79 | "ax.set_zlim(-0.5, 1.01)\n", 80 | "ax.zaxis.set_major_locator(LinearLocator(10))\n", 81 | "# A StrMethodFormatter is used automatically\n", 82 | "ax.zaxis.set_major_formatter('{x:.02f}')\n", 83 | "ax.view_init(10, 60)\n", 84 | "plt.show()" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 7, 90 | "id": "e5d0f05c-6990-4404-8013-8b323bfa7c6c", 91 | "metadata": { 92 | "tags": [] 93 | }, 94 | "outputs": [ 95 | { 96 | "name": "stderr", 97 | "output_type": "stream", 98 | "text": [ 99 | "[2023-03-21 07:04:09,637] torch._inductor.debug: [WARNING] model__1_forward_4 debug trace: /pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_03_48_329556-pid_24339/aot_torchinductor/model__1_forward_4.2\n", 100 | "[2023-03-21 07:04:09,677] torch._inductor.debug: [WARNING] model__1_backward_5 debug trace: /pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_03_48_329556-pid_24339/aot_torchinductor/model__1_backward_5.3\n" 101 | ] 102 | }, 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "Writing FX graph to file: /pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_03_48_329556-pid_24339/aot_torchinductor/model__1_forward_4.2/graph_diagram.svg\n", 108 | "Writing FX graph to file: 
/pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_03_48_329556-pid_24339/aot_torchinductor/model__1_backward_5.3/graph_diagram.svg\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "fn_compiled = torch.compile(fn, backend=\"inductor\", \n", 114 | " options={'trace.graph_diagram':True,\n", 115 | " 'trace.enabled':True})\n", 116 | "\n", 117 | "out = fn_compiled(x.to(device), y.to(device)).sum().backward()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "d687a68c-afab-462e-b3cb-1d6f723603bb", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "59d094a0-84fa-430d-b8e1-94c30e310283", 132 | "metadata": { 133 | "tags": [] 134 | }, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "\n", 141 | "\n", 142 | "\n", 143 | "def forward(self, x, y):\n", 144 | " pow_1 = torch.pow(x, 2)\n", 145 | " neg = -pow_1; pow_1 = None\n", 146 | " pow_2 = torch.pow(y, 2)\n", 147 | " sub = neg - pow_2; neg = pow_2 = None\n", 148 | " exp = torch.exp(sub); sub = None\n", 149 | " mul = x * exp; exp = None\n", 150 | " pow_3 = torch.pow(x, 2); x = None\n", 151 | " pow_4 = torch.pow(y, 2); y = None\n", 152 | " add = pow_3 + pow_4; pow_3 = pow_4 = None\n", 153 | " truediv = add / 20; add = None\n", 154 | " add_1 = mul + truediv; mul = truediv = None\n", 155 | " return add_1\n", 156 | " \n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "from torch.fx import passes, symbolic_trace\n", 162 | "model = symbolic_trace(fn)\n", 163 | "\n", 164 | "g = passes.graph_drawer.FxGraphDrawer(model, 'fn')\n", 165 | "with open(\"unoptimized_graph1.svg\", \"wb\") as f:\n", 166 | " f.write(g.get_dot_graph().create_svg())\n", 167 | " \n", 168 | "print(model.code)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 17, 174 | "id": "090bfe7c-8737-40eb-b870-9a87bbaa09b0", 
175 | "metadata": { 176 | "tags": [] 177 | }, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "\n", 184 | "\n", 185 | "\n", 186 | "def forward(self, x_1, y_1):\n", 187 | " pow_1 = torch.ops.aten.pow.Tensor_Scalar(x_1, 2)\n", 188 | " neg = torch.ops.aten.neg.default(pow_1); pow_1 = None\n", 189 | " pow_2 = torch.ops.aten.pow.Tensor_Scalar(y_1, 2)\n", 190 | " sub = torch.ops.aten.sub.Tensor(neg, pow_2); neg = pow_2 = None\n", 191 | " exp = torch.ops.aten.exp.default(sub); sub = None\n", 192 | " detach = torch.ops.aten.detach.default(exp)\n", 193 | " mul = torch.ops.aten.mul.Tensor(x_1, exp); exp = None\n", 194 | " pow_3 = torch.ops.aten.pow.Tensor_Scalar(x_1, 2); x_1 = None\n", 195 | " pow_4 = torch.ops.aten.pow.Tensor_Scalar(y_1, 2); y_1 = None\n", 196 | " add = torch.ops.aten.add.Tensor(pow_3, pow_4); pow_3 = pow_4 = None\n", 197 | " div = torch.ops.aten.div.Tensor(add, 20); add = None\n", 198 | " add_1 = torch.ops.aten.add.Tensor(mul, div); mul = div = None\n", 199 | " return add_1\n", 200 | " \n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "from functorch import make_fx\n", 206 | "g = make_fx(fn)(x, y)\n", 207 | "print(g.code)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "6351c611-cf45-40e0-8e78-5b3707cb8a0e", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [] 217 | } 218 | ], 219 | "metadata": { 220 | "kernelspec": { 221 | "display_name": "Python 3 (ipykernel)", 222 | "language": "python", 223 | "name": "python3" 224 | }, 225 | "language_info": { 226 | "codemirror_mode": { 227 | "name": "ipython", 228 | "version": 3 229 | }, 230 | "file_extension": ".py", 231 | "mimetype": "text/x-python", 232 | "name": "python", 233 | "nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.10.9" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 5 240 | } 241 | 
-------------------------------------------------------------------------------- /pytorch-graph-optimization/inspecting_torch_compile.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "2c42e096-2d57-4e1d-976b-003cade9a05c", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import torch\n", 13 | "import torch._dynamo\n", 14 | "from torch import nn" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 3, 20 | "id": "0a16cbc2-0baa-44f4-beb1-030dcb7c1433", 21 | "metadata": { 22 | "tags": [] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "class MLP(nn.Module):\n", 27 | " def __init__(self):\n", 28 | " super().__init__()\n", 29 | " self.fc1 = nn.Linear(32, 64)\n", 30 | "\n", 31 | " def forward(self, x):\n", 32 | " x = self.fc1(x)\n", 33 | " x = torch.nn.functional.gelu(x)\n", 34 | " return x\n", 35 | "\n", 36 | "model = MLP()\n", 37 | "\n", 38 | "batch_size = 8\n", 39 | "x = torch.randn(batch_size, 32)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "43846e32-0973-4e62-ab70-563ec780bcf5", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "from torch.fx import passes, symbolic_trace\n", 50 | "model = symbolic_trace(fn)\n", 51 | "\n", 52 | "g = passes.graph_drawer.FxGraphDrawer(model, 'fn')\n", 53 | "with open(\"unoptimized_graph.svg\", \"wb\") as f:\n", 54 | " f.write(g.get_dot_graph().create_svg())" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 21, 60 | "id": "0a7036c0-cf13-44fb-ac96-79ad76fad291", 61 | "metadata": { 62 | "tags": [] 63 | }, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Writing FX graph to file: forward.svg\n", 70 | "Writing FX graph to file: backward.svg\n" 71 | ] 72 | }, 73 | { 74 | "name": "stderr", 75 | "output_type": "stream", 76 | "text": [ 
77 | "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\n", 78 | " warnings.warn(\n", 79 | "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\n", 80 | " warnings.warn(\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "import torch._dynamo\n", 86 | "from torch._functorch.aot_autograd import aot_module_simplified\n", 87 | "from functorch.compile import compiled_function, draw_graph\n", 88 | "\n", 89 | "\n", 90 | "def toy_backend(gm, sample_inputs): \n", 91 | " def fw(gm, sample_inputs):\n", 92 | " draw_graph(gm, \"forward.svg\")\n", 93 | " return gm.forward\n", 94 | " \n", 95 | " def bw(gm, sample_inputs):\n", 96 | " draw_graph(gm, \"backward.svg\")\n", 97 | " return gm.forward\n", 98 | "\n", 99 | " # Invoke AOTAutograd\n", 100 | " return aot_module_simplified(\n", 101 | " gm,\n", 102 | " sample_inputs,\n", 103 | " fw_compiler=fw,\n", 104 | " bw_compiler=bw\n", 105 | " )\n", 106 | "\n", 107 | "def fn(x):\n", 108 | " return x**2\n", 109 | "\n", 110 | "model = fn\n", 111 | "x = torch.tensor(5., requires_grad=True)\n", 112 | "\n", 113 | "torch._dynamo.reset()\n", 114 | "cmodel = torch.compile(model, backend=toy_backend, dynamic=True)\n", 115 | "\n", 116 | "# triggers compilation of forward graph on the first run\n", 117 | "out = cmodel(x).sum().backward()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 5, 123 | "id": 
"259b9c86-5f23-4f61-b178-15e71a6e2197", 124 | "metadata": { 125 | "tags": [] 126 | }, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "\n", 133 | "\n", 134 | "\n", 135 | "def forward(self, x):\n", 136 | " param = self.param\n", 137 | " add = x + param; x = param = None\n", 138 | " linear = self.linear(add); add = None\n", 139 | " clamp = linear.clamp(min = 0.0, max = 1.0); linear = None\n", 140 | " return clamp\n", 141 | " \n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "print(symtraced.code)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "id": "8125d380-aaa2-4685-95a4-a035449cacfb", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [] 156 | } 157 | ], 158 | "metadata": { 159 | "kernelspec": { 160 | "display_name": "Python 3 (ipykernel)", 161 | "language": "python", 162 | "name": "python3" 163 | }, 164 | "language_info": { 165 | "codemirror_mode": { 166 | "name": "ipython", 167 | "version": 3 168 | }, 169 | "file_extension": ".py", 170 | "mimetype": "text/x-python", 171 | "name": "python", 172 | "nbconvert_exporter": "python", 173 | "pygments_lexer": "ipython3", 174 | "version": "3.10.9" 175 | } 176 | }, 177 | "nbformat": 4, 178 | "nbformat_minor": 5 179 | } 180 | -------------------------------------------------------------------------------- /pytorch-intro-torch-compile/1-toy-benchmarks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "b33cbc31-dc42-4987-a9e2-513e31dd92de", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import torch\n", 13 | "import time\n", 14 | "import os\n", 15 | "\n", 16 | "from torch import nn\n", 17 | "import torchvision.models as models\n", 18 | "from triton.testing import do_bench\n", 19 | "import torch._dynamo" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | 
"execution_count": 2, 25 | "id": "9d742cd2-8f4b-40d5-a51e-36dacb9a11a0", 26 | "metadata": { 27 | "tags": [] 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "torch.set_float32_matmul_precision('high')\n", 32 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "id": "98c588e4-db0c-4c73-8000-c51edad58d89", 39 | "metadata": { 40 | "tags": [] 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "def run_benchmark(fn):\n", 45 | " exec_time, prctl20, prctl80 = do_bench(fn,warmup=100,rep=1000)\n", 46 | " print(f\"Exec time (median): {exec_time}\")\n", 47 | " print(f\"Exec time (20th percentile): {prctl20}\")\n", 48 | " print(f\"Exec time (80th percentile): {prctl80}\\n\")\n", 49 | " return exec_time" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "id": "cbe00915-2f6e-4a45-b33e-b309dd549dae", 55 | "metadata": {}, 56 | "source": [ 57 | "## 1. ResNet50 Speedup on NVIDIA A10G" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "id": "306dae3d-ad53-442a-a62f-70c18d8567fe", 64 | "metadata": { 65 | "tags": [] 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "def run_batch(model, optimizer):\n", 70 | " x = torch.randn(16, 3, 224, 224).to(device)\n", 71 | " optimizer.zero_grad()\n", 72 | " out = model(x)\n", 73 | " out.sum().backward()\n", 74 | " optimizer.step()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "id": "6e3d508d-75db-406b-92e2-6bedffd1510d", 81 | "metadata": { 82 | "tags": [] 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "model = models.resnet50().to(device)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 6, 92 | "id": "9d483ad8-2882-4ce5-a491-cb464b3e3bb5", 93 | "metadata": { 94 | "tags": [] 95 | }, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "Resnet50 Eager mode\n", 102 | "Exec time (median): 
50.456573486328125\n", 103 | "Exec time (20th percentile): 50.4313850402832\n", 104 | "Exec time (80th percentile): 50.50531768798828\n", 105 | "\n", 106 | "Resnet50 Compiled defaults\n", 107 | "Exec time (median): 46.913536071777344\n", 108 | "Exec time (20th percentile): 46.89653778076172\n", 109 | "Exec time (80th percentile): 46.9372673034668\n", 110 | "\n", 111 | "speedup: 7.55%\n" 112 | ] 113 | }, 114 | { 115 | "name": "stderr", 116 | "output_type": "stream", 117 | "text": [ 118 | "Process ForkProcess-4:\n", 119 | "Process ForkProcess-1:\n", 120 | "Process ForkProcess-3:\n", 121 | "Process ForkProcess-2:\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", 127 | "\n", 128 | "# Benchmark Eager\n", 129 | "print(\"Resnet50 Eager mode\")\n", 130 | "exec_time = run_benchmark(lambda: run_batch(model, optimizer))\n", 131 | "\n", 132 | "# Benchmark torch.compile defaults\n", 133 | "print(\"Resnet50 Compiled defaults\")\n", 134 | "opt_model = torch.compile(model)\n", 135 | "opt_exec_time = run_benchmark(lambda: run_batch(opt_model, optimizer))\n", 136 | "\n", 137 | "# Print speedups\n", 138 | "print(f\"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%\")" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "id": "fcd3add7-7c03-4559-9109-8443700febdf", 144 | "metadata": {}, 145 | "source": [ 146 | "## 2. 
Custom model Speedup on NVIDIA A10G" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 7, 152 | "id": "8ca833c0-d383-45ce-9f31-06b126539376", 153 | "metadata": { 154 | "tags": [] 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "class MLP(nn.Module):\n", 159 | " def __init__(self):\n", 160 | " super().__init__()\n", 161 | " self.fc1 = nn.Linear(1024, 1024)\n", 162 | " self.fc2 = nn.Linear(1024, 1024)\n", 163 | " \n", 164 | " def forward(self, x):\n", 165 | " x = self.fc1(x).relu() ** 2\n", 166 | " return self.fc2(x).relu() ** 2" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 8, 172 | "id": "bc293b4a-6634-4e66-95f0-1675ab363114", 173 | "metadata": { 174 | "tags": [] 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "model = MLP().to(device)\n", 179 | "x = torch.randn(1024, 1024).to(device)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 9, 185 | "id": "020901c4-9054-4652-acaf-0232a31d0011", 186 | "metadata": { 187 | "tags": [] 188 | }, 189 | "outputs": [ 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "Exec time (median): 0.7147520184516907\n", 195 | "Exec time (20th percentile): 0.7127040028572083\n", 196 | "Exec time (80th percentile): 0.7157760262489319\n", 197 | "\n", 198 | "Exec time (median): 0.6010879874229431\n", 199 | "Exec time (20th percentile): 0.6000639796257019\n", 200 | "Exec time (80th percentile): 0.6021119952201843\n", 201 | "\n", 202 | "speedup: 18.91%\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "# Benchmark Eager\n", 208 | "exec_time = run_benchmark(lambda: model(x).sum().backward())\n", 209 | "\n", 210 | "torch._dynamo.reset()\n", 211 | "# Benchmark torch.compile defaults\n", 212 | "cmodel = torch.compile(model, backend='inductor')\n", 213 | "opt_exec_time = run_benchmark(lambda: cmodel(x).sum().backward())\n", 214 | "\n", 215 | "# Print speedups\n", 216 | "print(f\"speedup: {100*(exec_time-opt_exec_time) / 
opt_exec_time: .2f}%\")" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "id": "49a00c44-c096-4ae8-afe8-7937db196933", 222 | "metadata": {}, 223 | "source": [ 224 | "## 3. HuggingFace model Speedup on NVIDIA A10G" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 10, 230 | "id": "d7a4eaae-8bc3-4a21-8147-a2d75e5fba3a", 231 | "metadata": { 232 | "tags": [] 233 | }, 234 | "outputs": [], 235 | "source": [ 236 | "from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC\n", 237 | "from datasets import load_dataset\n", 238 | "\n", 239 | "def run_inference(model, input_values):\n", 240 | " \n", 241 | " # retrieve logits\n", 242 | " logits = model(input_values).logits\n", 243 | " \n", 244 | " # take argmax and decode\n", 245 | " predicted_ids = torch.argmax(logits, dim=-1)\n", 246 | " transcription = processor.batch_decode(predicted_ids)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 11, 252 | "id": "76d00a8f-4503-4fa7-90cd-263f693097da", 253 | "metadata": { 254 | "tags": [] 255 | }, 256 | "outputs": [ 257 | { 258 | "data": { 259 | "application/vnd.jupyter.widget-view+json": { 260 | "model_id": "7baf2003e8ab49cca9c52c88ba8e14c5", 261 | "version_major": 2, 262 | "version_minor": 0 263 | }, 264 | "text/plain": [ 265 | "Downloading (…)rocessor_config.json: 0%| | 0.00/158 [00:00\n", 168 | " print(\"AOTAutograd produced a fx Graph in Aten IR:\")\n", 169 | " gm.print_readable()\n", 170 | " return gm.forward\n", 171 | "\n", 172 | " # Invoke AOTAutograd\n", 173 | " return aot_module_simplified(\n", 174 | " gm,\n", 175 | " sample_inputs,\n", 176 | " fw_compiler=my_compiler\n", 177 | " )\n", 178 | "\n", 179 | "torch._dynamo.reset()\n", 180 | "cmodel = torch.compile(model, backend=toy_backend, dynamic=True)\n", 181 | "\n", 182 | "# triggers compilation of forward graph on the first run\n", 183 | "out = cmodel(input)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 6, 189 | "id": 
"199f5ca0-0743-4f91-859d-c495a723280c", 190 | "metadata": { 191 | "tags": [] 192 | }, 193 | "outputs": [ 194 | { 195 | "name": "stdout", 196 | "output_type": "stream", 197 | "text": [ 198 | "Decomposed fx Graph in Aten IR:\n", 199 | "class GraphModule(torch.nn.Module):\n", 200 | " def forward(self, primals_1: f32[64, 32], primals_2: f32[64], primals_3: f32[s0, 32]):\n", 201 | " # File: /tmp/ipykernel_8670/1490842273.py:7, code: x = self.fc1(x)\n", 202 | " permute: f32[32, 64] = torch.ops.aten.permute.default(primals_1, [1, 0]); primals_1 = None\n", 203 | " addmm: f32[s0, 64] = torch.ops.aten.addmm.default(primals_2, primals_3, permute); primals_2 = permute = None\n", 204 | " \n", 205 | " # File: /tmp/ipykernel_8670/1490842273.py:8, code: x = torch.nn.functional.gelu(x)\n", 206 | " mul: f32[s0, 64] = torch.ops.aten.mul.Tensor(addmm, 0.5)\n", 207 | " mul_1: f32[s0, 64] = torch.ops.aten.mul.Tensor(addmm, 0.7071067811865476)\n", 208 | " erf: f32[s0, 64] = torch.ops.aten.erf.default(mul_1); mul_1 = None\n", 209 | " add: f32[s0, 64] = torch.ops.aten.add.Tensor(erf, 1); erf = None\n", 210 | " mul_2: f32[s0, 64] = torch.ops.aten.mul.Tensor(mul, add); mul = add = None\n", 211 | " return [mul_2, addmm, primals_3]\n", 212 | " \n" 213 | ] 214 | }, 215 | { 216 | "name": "stderr", 217 | "output_type": "stream", 218 | "text": [ 219 | "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. 
See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\n", 220 | " warnings.warn(\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "from torch._inductor.decomposition import decompositions as default_decompositions\n", 226 | "\n", 227 | "decompositions = default_decompositions.copy()\n", 228 | "\n", 229 | "def toy_backend(gm, sample_inputs):\n", 230 | " def my_compiler(gm, sample_inputs):\n", 231 | " # \n", 232 | " print(\"Decomposed fx Graph in Aten IR:\")\n", 233 | " gm.print_readable()\n", 234 | " return gm\n", 235 | "\n", 236 | " # Invoke AOTAutograd\n", 237 | " return aot_module_simplified(\n", 238 | " gm,\n", 239 | " sample_inputs,\n", 240 | " decompositions=decompositions,\n", 241 | " fw_compiler=my_compiler\n", 242 | " )\n", 243 | "\n", 244 | "torch._dynamo.reset()\n", 245 | "cmodel = torch.compile(model, backend=toy_backend, dynamic=True)\n", 246 | "\n", 247 | "# triggers compilation of forward graph on the first run\n", 248 | "out = cmodel(input)" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "id": "08c92c33-0803-4ce1-906c-e00faa1a620f", 254 | "metadata": {}, 255 | "source": [ 256 | "### Prims IR (https://pytorch.org/docs/master/ir.html#prims-ir)\n", 257 | "\n", 258 | "* Explicit type promotion and broadcasting\n", 259 | "* prims.convert_element_type\n", 260 | "* prims.broadcast_in_dim\n", 261 | "* For backends with powerful compiler that can reclaim the performance by fusion, e.g. 
nvFuser" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 7, 267 | "id": "236f9011-fca0-45d9-a832-4f07962d8ef4", 268 | "metadata": { 269 | "tags": [] 270 | }, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "Further decomposed fx Graph in Prims IR:\n", 277 | "class (torch.nn.Module):\n", 278 | " def forward(self, arg0_1: f32[3], arg1_1: f16[3, 3]):\n", 279 | " # File: /tmp/ipykernel_8670/2178452752.py:7, code: return a + b\n", 280 | " _to_copy: f32[3, 3] = torch.ops.aten._to_copy.default(arg1_1, dtype = torch.float32); arg1_1 = None\n", 281 | " broadcast_in_dim: f32[3, 3] = torch.ops.prims.broadcast_in_dim.default(arg0_1, [3, 3], [1]); arg0_1 = None\n", 282 | " add: f32[3, 3] = torch.ops.prims.add.default(broadcast_in_dim, _to_copy); broadcast_in_dim = _to_copy = None\n", 283 | " return (add,)\n", 284 | " \n" 285 | ] 286 | } 287 | ], 288 | "source": [ 289 | "prims_decomp = torch._decomp.get_decompositions([\n", 290 | " torch.ops.aten.add,\n", 291 | " torch.ops.aten.expand.default,\n", 292 | "])\n", 293 | "\n", 294 | "def fn(a, b):\n", 295 | " return a + b\n", 296 | "\n", 297 | "def toy_backend(gm, sample_inputs):\n", 298 | " def my_compiler(gm, sample_inputs):\n", 299 | " # \n", 300 | " print(\"Further decomposed fx Graph in Prims IR:\")\n", 301 | " gm.print_readable()\n", 302 | " return gm\n", 303 | "\n", 304 | " # Invoke AOTAutograd\n", 305 | " return aot_module_simplified(\n", 306 | " gm,\n", 307 | " sample_inputs,\n", 308 | " decompositions=prims_decomp,\n", 309 | " fw_compiler=my_compiler\n", 310 | " )\n", 311 | "\n", 312 | "torch._dynamo.reset()\n", 313 | "fn = torch.compile(backend=toy_backend)(fn)\n", 314 | "out = fn(torch.rand(3, dtype=torch.float), torch.rand(3, 3, dtype=torch.half))" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "id": "0c70a43d-7a90-409e-862c-7169c6fd088f", 321 | "metadata": {}, 322 | "outputs": [], 323 | 
"source": [] 324 | } 325 | ], 326 | "metadata": { 327 | "kernelspec": { 328 | "display_name": "Python 3 (ipykernel)", 329 | "language": "python", 330 | "name": "python3" 331 | }, 332 | "language_info": { 333 | "codemirror_mode": { 334 | "name": "ipython", 335 | "version": 3 336 | }, 337 | "file_extension": ".py", 338 | "mimetype": "text/x-python", 339 | "name": "python", 340 | "nbconvert_exporter": "python", 341 | "pygments_lexer": "ipython3", 342 | "version": "3.10.9" 343 | } 344 | }, 345 | "nbformat": 4, 346 | "nbformat_minor": 5 347 | } 348 | -------------------------------------------------------------------------------- /pytorch-intro-torch-compile/4-nn-example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "dd7ed31a-9899-4489-a7a1-bc2bb1d90ee7", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import argparse\n", 13 | "import json\n", 14 | "import logging\n", 15 | "import os\n", 16 | "import sys\n", 17 | "\n", 18 | "import torch\n", 19 | "import torch.nn as nn\n", 20 | "import torch.nn.functional as F\n", 21 | "import torch.optim as optim\n", 22 | "import torch.utils.data\n", 23 | "import torchvision\n", 24 | "\n", 25 | "from torchvision import datasets, transforms\n", 26 | "\n", 27 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 28 | "torch.set_float32_matmul_precision('high') #Uses TF32 when available" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "37effd54-1f4c-42b2-bd08-1d524908fdc3", 35 | "metadata": { 36 | "tags": [] 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "def _get_model():\n", 41 | " class Net(nn.Module):\n", 42 | " def __init__(self):\n", 43 | " super(Net, self).__init__()\n", 44 | " self.conv1 = nn.Conv2d(3, 6, 5)\n", 45 | " self.pool = nn.MaxPool2d(2, 2)\n", 46 | " self.conv2 = nn.Conv2d(6, 16, 5)\n", 47 | " 
self.fc1 = nn.Linear(16 * 5 * 5, 120)\n", 48 | " self.fc2 = nn.Linear(120, 84)\n", 49 | " self.fc3 = nn.Linear(84, 10)\n", 50 | "\n", 51 | " def forward(self, x):\n", 52 | " x = self.pool(F.relu(self.conv1(x)))\n", 53 | " x = self.pool(F.relu(self.conv2(x)))\n", 54 | " x = x.view(-1, 16 * 5 * 5)\n", 55 | " x = F.relu(self.fc1(x))\n", 56 | " x = F.relu(self.fc2(x))\n", 57 | " x = self.fc3(x)\n", 58 | " return x\n", 59 | " return Net()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "id": "55a1d996-acc3-45ac-8113-3344bc7ddc28", 66 | "metadata": { 67 | "tags": [] 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# Define data augmentation\n", 72 | "def _get_transforms():\n", 73 | " transform = transforms.Compose([\n", 74 | " transforms.RandomCrop(32, padding=4),\n", 75 | " transforms.RandomHorizontalFlip(),\n", 76 | " transforms.ToTensor(),\n", 77 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", 78 | " ])\n", 79 | " return transform" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "id": "f8cc60fa-faba-40d6-a2cc-7d5f29e682c6", 86 | "metadata": { 87 | "tags": [] 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "def _get_dataloaders(batch_size):\n", 92 | " trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n", 93 | " download=True, transform=_get_transforms())\n", 94 | " testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n", 95 | " download=True, transform=_get_transforms())\n", 96 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n", 97 | " shuffle=True)\n", 98 | " testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n", 99 | " shuffle=False)\n", 100 | " return trainloader, testloader" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 5, 106 | "id": "c7bda30f-df68-4a3b-925a-3153ba1f6809", 107 | "metadata": { 108 | "tags": [] 109 | }, 110 | "outputs": [], 111 | 
"source": [ 112 | "def test(model, test_loader, device): # report mean per-sample cross-entropy loss and accuracy over test_loader\n", 113 | " test_loss = 0\n", 114 | " correct = 0\n", 115 | " with torch.no_grad():\n", 116 | " for data, target in test_loader:\n", 117 | " data, target = data.to(device), target.to(device)\n", 118 | " output = model(data)\n", 119 | " test_loss += F.cross_entropy(output, target, reduction='sum').item() # model returns raw logits (no log_softmax); sum per-sample losses over the batch\n", 120 | " pred = output.max(1, keepdim=True)[1] # get the index of the max logit\n", 121 | " correct += pred.eq(target.view_as(pred)).sum().item()\n", 122 | "\n", 123 | " test_loss /= len(test_loader.dataset) # summed loss / number of samples = mean loss per sample\n", 124 | " print(f\"Test set: Average loss: {test_loss}, Accuracy: {correct / len(test_loader.dataset)}\\n\")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 6, 130 | "id": "96633f85-1875-4dc9-a2ee-9784083e63c9", 131 | "metadata": { 132 | "scrolled": true, 133 | "tags": [] 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "import time\n", 138 | "def train(model, batch_size, epochs):\n", 139 | " torch.manual_seed(0)\n", 140 | " lr = 0.01\n", 141 | " momentum=0.9\n", 142 | " train_loader, test_loader = _get_dataloaders(batch_size)\n", 143 | "\n", 144 | " criterion = nn.CrossEntropyLoss()\n", 145 | " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)\n", 146 | "\n", 147 | " for epoch in range(1, epochs + 1):\n", 148 | " start_time = time.time()\n", 149 | " for batch_idx, (data, target) in enumerate(train_loader, 1):\n", 150 | " data, target = data.to(device), target.to(device)\n", 151 | "\n", 152 | " # zero the parameter gradients\n", 153 | " optimizer.zero_grad()\n", 154 | "\n", 155 | " # forward + backward + optimize\n", 156 | " output = model(data)\n", 157 | " loss = criterion(output, target)\n", 158 | " loss.backward()\n", 159 | " optimizer.step()\n", 160 | "\n", 161 | " print(f\"Train Epoch: {epoch} Epoch time: {time.time()-start_time:0.4f} Loss: {loss.item()}\")\n", 162 | " test(model, test_loader, device)" 163 | ] 164 | }, 
165 | { 166 | "cell_type": "code", 167 | "execution_count": 7, 168 | "id": "4c4c51e7-a0ac-4a7f-88a3-9505bcd8f248", 169 | "metadata": { 170 | "scrolled": true, 171 | "tags": [] 172 | }, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "Train in eager mode on CIFAR-10\n", 179 | "Files already downloaded and verified\n", 180 | "Files already downloaded and verified\n", 181 | "Train Epoch: 1 Epoch time: 18.4492 Loss: 1.5229963064193726\n", 182 | "Test set: Average loss: -0.11479381905794143, Accuracy: 0.3574\n", 183 | "\n", 184 | "Train Epoch: 2 Epoch time: 17.1039 Loss: 1.255014419555664\n", 185 | "Test set: Average loss: -0.12258006700873375, Accuracy: 0.4207\n", 186 | "\n", 187 | "Train Epoch: 3 Epoch time: 17.0540 Loss: 1.2103943824768066\n", 188 | "Test set: Average loss: -0.14219269587993622, Accuracy: 0.4306\n", 189 | "\n", 190 | "Train Epoch: 4 Epoch time: 17.1195 Loss: 1.6141940355300903\n", 191 | "Test set: Average loss: -0.1223116991698742, Accuracy: 0.4078\n", 192 | "\n", 193 | "Train Epoch: 5 Epoch time: 17.1593 Loss: 1.7167662382125854\n", 194 | "Test set: Average loss: -0.1371232023358345, Accuracy: 0.427\n", 195 | "\n", 196 | "Train Epoch: 6 Epoch time: 17.0450 Loss: 1.6442832946777344\n", 197 | "Test set: Average loss: -0.11765576581060887, Accuracy: 0.4195\n", 198 | "\n", 199 | "Train Epoch: 7 Epoch time: 17.2778 Loss: 1.1586203575134277\n", 200 | "Test set: Average loss: -0.12454334303736686, Accuracy: 0.4175\n", 201 | "\n", 202 | "Train Epoch: 8 Epoch time: 17.1429 Loss: 2.3159003257751465\n", 203 | "Test set: Average loss: -0.14869595084786416, Accuracy: 0.4542\n", 204 | "\n", 205 | "Train Epoch: 9 Epoch time: 17.0725 Loss: 1.63433039188385\n", 206 | "Test set: Average loss: -0.13449764105677606, Accuracy: 0.4416\n", 207 | "\n", 208 | "Train Epoch: 10 Epoch time: 17.1152 Loss: 0.9652916193008423\n", 209 | "Test set: Average loss: -0.14638293540477754, Accuracy: 0.4566\n", 210 | "\n", 211 | 
"Train Epoch: 11 Epoch time: 17.0312 Loss: 1.4068830013275146\n", 212 | "Test set: Average loss: -0.13832889040112495, Accuracy: 0.4099\n", 213 | "\n", 214 | "Train Epoch: 12 Epoch time: 17.1650 Loss: 1.6365031003952026\n", 215 | "Test set: Average loss: -0.1226353962123394, Accuracy: 0.427\n", 216 | "\n", 217 | "Train Epoch: 13 Epoch time: 17.1685 Loss: 1.2202277183532715\n", 218 | "Test set: Average loss: -0.16288705285787583, Accuracy: 0.4464\n", 219 | "\n", 220 | "Train Epoch: 14 Epoch time: 17.1915 Loss: 1.5230098962783813\n", 221 | "Test set: Average loss: -0.13382687171697616, Accuracy: 0.4507\n", 222 | "\n", 223 | "Train Epoch: 15 Epoch time: 17.1395 Loss: 1.6405322551727295\n", 224 | "Test set: Average loss: -0.11657781254947186, Accuracy: 0.418\n", 225 | "\n", 226 | "Train Epoch: 16 Epoch time: 17.1124 Loss: 1.435512661933899\n", 227 | "Test set: Average loss: -0.16411116136312484, Accuracy: 0.4536\n", 228 | "\n", 229 | "Train Epoch: 17 Epoch time: 17.0854 Loss: 1.3877995014190674\n", 230 | "Test set: Average loss: -0.13008506304621698, Accuracy: 0.4505\n", 231 | "\n", 232 | "Train Epoch: 18 Epoch time: 17.0801 Loss: 1.3181973695755005\n", 233 | "Test set: Average loss: -0.12966735948324204, Accuracy: 0.4444\n", 234 | "\n", 235 | "Train Epoch: 19 Epoch time: 17.1122 Loss: 1.3785901069641113\n", 236 | "Test set: Average loss: -0.16746452817320823, Accuracy: 0.4708\n", 237 | "\n", 238 | "Train Epoch: 20 Epoch time: 17.0873 Loss: 2.601522922515869\n", 239 | "Test set: Average loss: -0.1661388152360916, Accuracy: 0.455\n", 240 | "\n", 241 | "Train Epoch: 21 Epoch time: 17.1412 Loss: 1.387449026107788\n", 242 | "Test set: Average loss: -0.13323016968667506, Accuracy: 0.4078\n", 243 | "\n", 244 | "Train Epoch: 22 Epoch time: 17.0955 Loss: 0.9796938300132751\n", 245 | "Test set: Average loss: -0.12620635756850243, Accuracy: 0.4348\n", 246 | "\n", 247 | "Train Epoch: 23 Epoch time: 17.1453 Loss: 1.7141715288162231\n", 248 | "Test set: Average loss: 
-0.14345559119582177, Accuracy: 0.4657\n", 249 | "\n", 250 | "Train Epoch: 24 Epoch time: 17.1091 Loss: 1.4967659711837769\n", 251 | "Test set: Average loss: -0.12276398456394673, Accuracy: 0.4001\n", 252 | "\n", 253 | "Train Epoch: 25 Epoch time: 17.0620 Loss: 1.9967564344406128\n", 254 | "Test set: Average loss: -0.15137818831205369, Accuracy: 0.4591\n", 255 | "\n", 256 | "Train Epoch: 26 Epoch time: 17.0726 Loss: 1.2827167510986328\n", 257 | "Test set: Average loss: -0.14104595474600792, Accuracy: 0.4339\n", 258 | "\n", 259 | "Train Epoch: 27 Epoch time: 17.1130 Loss: 1.6564006805419922\n", 260 | "Test set: Average loss: -0.15558898612856864, Accuracy: 0.457\n", 261 | "\n", 262 | "Train Epoch: 28 Epoch time: 17.0406 Loss: 1.5395925045013428\n", 263 | "Test set: Average loss: -0.15071935195326805, Accuracy: 0.4298\n", 264 | "\n", 265 | "Train Epoch: 29 Epoch time: 17.0579 Loss: 2.140176773071289\n", 266 | "Test set: Average loss: -0.15748710156083107, Accuracy: 0.4236\n", 267 | "\n", 268 | "Train Epoch: 30 Epoch time: 17.0842 Loss: 1.7239465713500977\n", 269 | "Test set: Average loss: -0.14595504192709924, Accuracy: 0.4115\n", 270 | "\n", 271 | "Train Epoch: 31 Epoch time: 17.1039 Loss: 1.7769050598144531\n", 272 | "Test set: Average loss: -0.1290455259978771, Accuracy: 0.3998\n", 273 | "\n", 274 | "Train Epoch: 32 Epoch time: 17.1451 Loss: 1.662497639656067\n", 275 | "Test set: Average loss: -0.15571664265990257, Accuracy: 0.4357\n", 276 | "\n", 277 | "Train Epoch: 33 Epoch time: 17.1315 Loss: 1.3852906227111816\n", 278 | "Test set: Average loss: -0.1613605972111225, Accuracy: 0.4172\n", 279 | "\n", 280 | "Train Epoch: 34 Epoch time: 17.0798 Loss: 1.7368806600570679\n", 281 | "Test set: Average loss: -0.12401523492336274, Accuracy: 0.3968\n", 282 | "\n", 283 | "Train Epoch: 35 Epoch time: 17.1884 Loss: 1.463071584701538\n", 284 | "Test set: Average loss: -0.13350237726569175, Accuracy: 0.4211\n", 285 | "\n", 286 | "Train Epoch: 36 Epoch time: 17.0499 Loss: 
1.222204327583313\n", 287 | "Test set: Average loss: -0.16071371198892592, Accuracy: 0.4386\n", 288 | "\n", 289 | "Train Epoch: 37 Epoch time: 17.0795 Loss: 1.59268319606781\n", 290 | "Test set: Average loss: -0.1448448682665825, Accuracy: 0.4142\n", 291 | "\n", 292 | "Train Epoch: 38 Epoch time: 17.0917 Loss: 1.8278884887695312\n", 293 | "Test set: Average loss: -0.13484705570042133, Accuracy: 0.4107\n", 294 | "\n", 295 | "Train Epoch: 39 Epoch time: 17.1323 Loss: 1.2747076749801636\n", 296 | "Test set: Average loss: -0.14489886118769646, Accuracy: 0.3949\n", 297 | "\n", 298 | "Train Epoch: 40 Epoch time: 17.1330 Loss: 1.6077611446380615\n", 299 | "Test set: Average loss: -0.14588061140179634, Accuracy: 0.4289\n", 300 | "\n", 301 | "Train Epoch: 41 Epoch time: 17.2133 Loss: 1.6331448554992676\n", 302 | "Test set: Average loss: -0.09960648607611657, Accuracy: 0.3682\n", 303 | "\n", 304 | "Train Epoch: 42 Epoch time: 17.1127 Loss: 1.2302451133728027\n", 305 | "Test set: Average loss: -0.14067820476591586, Accuracy: 0.4283\n", 306 | "\n", 307 | "Train Epoch: 43 Epoch time: 17.1447 Loss: 1.6978567838668823\n", 308 | "Test set: Average loss: -0.16264084021151065, Accuracy: 0.4108\n", 309 | "\n", 310 | "Train Epoch: 44 Epoch time: 17.0487 Loss: 1.2845810651779175\n", 311 | "Test set: Average loss: -0.13672254542708398, Accuracy: 0.444\n", 312 | "\n", 313 | "Train Epoch: 45 Epoch time: 17.1562 Loss: 1.1537847518920898\n", 314 | "Test set: Average loss: -0.14150477679371834, Accuracy: 0.4273\n", 315 | "\n", 316 | "Train Epoch: 46 Epoch time: 17.1326 Loss: 1.6309105157852173\n", 317 | "Test set: Average loss: -0.13888773401975632, Accuracy: 0.3973\n", 318 | "\n", 319 | "Train Epoch: 47 Epoch time: 17.1038 Loss: 1.398644208908081\n", 320 | "Test set: Average loss: -0.15300521407723428, Accuracy: 0.4311\n", 321 | "\n", 322 | "Train Epoch: 48 Epoch time: 17.1457 Loss: 1.8404948711395264\n", 323 | "Test set: Average loss: -0.1579221700131893, Accuracy: 0.4233\n", 324 | "\n", 
325 | "Train Epoch: 49 Epoch time: 17.1388 Loss: 1.3540527820587158\n", 326 | "Test set: Average loss: -0.1460476183593273, Accuracy: 0.4527\n", 327 | "\n", 328 | "Train Epoch: 50 Epoch time: 17.1742 Loss: 1.765850305557251\n", 329 | "Test set: Average loss: -0.12122882234454155, Accuracy: 0.4141\n", 330 | "\n", 331 | "CPU times: user 16min 39s, sys: 6.23 s, total: 16min 45s\n", 332 | "Wall time: 16min 47s\n" 333 | ] 334 | } 335 | ], 336 | "source": [ 337 | "%%time\n", 338 | "print(\"Train in eager mode on CIFAR-10\")\n", 339 | "model = _get_model().to(device)\n", 340 | "\n", 341 | "train(model, batch_size=16, epochs=50)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 11, 347 | "id": "f5330b33-eb7b-4a18-939a-44564dcf4441", 348 | "metadata": { 349 | "tags": [] 350 | }, 351 | "outputs": [ 352 | { 353 | "name": "stdout", 354 | "output_type": "stream", 355 | "text": [ 356 | "Generating forward and backward graphs\n", 357 | "CPU times: user 1.01 s, sys: 12 ms, total: 1.02 s\n", 358 | "Wall time: 1.02 s\n" 359 | ] 360 | } 361 | ], 362 | "source": [ 363 | "%%time\n", 364 | "import torch._inductor.config\n", 365 | "model = _get_model().to(device)\n", 366 | "model = torch.compile(model, backend=\"inductor\",\n", 367 | " mode=\"max-autotune\")\n", 368 | "\n", 369 | "randinput = torch.randn(16,3,32,32).to(device)\n", 370 | "randoutput = torch.randn(16,10).to(device)\n", 371 | "\n", 372 | "print('Generating forward and backward graphs')\n", 373 | "out = model(randinput)\n", 374 | "nn.CrossEntropyLoss()(out, randoutput).backward()" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 12, 380 | "id": "b5ef0dbd-2d0c-4f4c-bd93-b2d4c6765829", 381 | "metadata": { 382 | "scrolled": true, 383 | "tags": [] 384 | }, 385 | "outputs": [ 386 | { 387 | "name": "stdout", 388 | "output_type": "stream", 389 | "text": [ 390 | "Train in compiled mode on CIFAR-10\n", 391 | "Files already downloaded and verified\n", 392 | "Files already 
downloaded and verified\n", 393 | "Train Epoch: 1 Epoch time: 16.8979 Loss: 1.1281245946884155\n", 394 | "Test set: Average loss: -0.10141951135396958, Accuracy: 0.3753\n", 395 | "\n", 396 | "Train Epoch: 2 Epoch time: 16.8175 Loss: 1.5165073871612549\n", 397 | "Test set: Average loss: -0.11379602966904641, Accuracy: 0.4128\n", 398 | "\n", 399 | "Train Epoch: 3 Epoch time: 16.8435 Loss: 1.3197925090789795\n", 400 | "Test set: Average loss: -0.14007682649493217, Accuracy: 0.4195\n", 401 | "\n", 402 | "Train Epoch: 4 Epoch time: 17.1099 Loss: 1.7537651062011719\n", 403 | "Test set: Average loss: -0.12210166681408882, Accuracy: 0.444\n", 404 | "\n", 405 | "Train Epoch: 5 Epoch time: 16.9783 Loss: 1.5357666015625\n", 406 | "Test set: Average loss: -0.15047866759300232, Accuracy: 0.4326\n", 407 | "\n", 408 | "Train Epoch: 6 Epoch time: 16.7821 Loss: 1.6223543882369995\n", 409 | "Test set: Average loss: -0.143238749986887, Accuracy: 0.4263\n", 410 | "\n", 411 | "Train Epoch: 7 Epoch time: 16.8694 Loss: 1.7723138332366943\n", 412 | "Test set: Average loss: -0.12494028667807579, Accuracy: 0.4065\n", 413 | "\n", 414 | "Train Epoch: 8 Epoch time: 16.8882 Loss: 1.8077551126480103\n", 415 | "Test set: Average loss: -0.1637685509085655, Accuracy: 0.4561\n", 416 | "\n", 417 | "Train Epoch: 9 Epoch time: 16.8302 Loss: 1.4599609375\n", 418 | "Test set: Average loss: -0.1516111361026764, Accuracy: 0.4451\n", 419 | "\n", 420 | "Train Epoch: 10 Epoch time: 16.8638 Loss: 1.1293983459472656\n", 421 | "Test set: Average loss: -0.16525690048933028, Accuracy: 0.4697\n", 422 | "\n", 423 | "Train Epoch: 11 Epoch time: 16.8586 Loss: 0.9164922833442688\n", 424 | "Test set: Average loss: -0.15961521873474122, Accuracy: 0.4466\n", 425 | "\n", 426 | "Train Epoch: 12 Epoch time: 16.7826 Loss: 2.045623779296875\n", 427 | "Test set: Average loss: -0.10865869052410125, Accuracy: 0.4195\n", 428 | "\n", 429 | "Train Epoch: 13 Epoch time: 16.8195 Loss: 1.2999463081359863\n", 430 | "Test set: Average 
loss: -0.17005301181674004, Accuracy: 0.4713\n", 431 | "\n", 432 | "Train Epoch: 14 Epoch time: 16.8872 Loss: 1.5532177686691284\n", 433 | "Test set: Average loss: -0.17997907676696778, Accuracy: 0.4673\n", 434 | "\n", 435 | "Train Epoch: 15 Epoch time: 16.7638 Loss: 1.6674301624298096\n", 436 | "Test set: Average loss: -0.13775004784464837, Accuracy: 0.4458\n", 437 | "\n", 438 | "Train Epoch: 16 Epoch time: 16.8711 Loss: 1.7492315769195557\n", 439 | "Test set: Average loss: -0.17715689504146576, Accuracy: 0.4702\n", 440 | "\n", 441 | "Train Epoch: 17 Epoch time: 16.8545 Loss: 1.4167125225067139\n", 442 | "Test set: Average loss: -0.13974063433408737, Accuracy: 0.4377\n", 443 | "\n", 444 | "Train Epoch: 18 Epoch time: 16.8798 Loss: 1.9101390838623047\n", 445 | "Test set: Average loss: -0.1401785773575306, Accuracy: 0.4311\n", 446 | "\n", 447 | "Train Epoch: 19 Epoch time: 16.8081 Loss: 1.99967622756958\n", 448 | "Test set: Average loss: -0.1516777255654335, Accuracy: 0.4445\n", 449 | "\n", 450 | "Train Epoch: 20 Epoch time: 16.7739 Loss: 2.2190351486206055\n", 451 | "Test set: Average loss: -0.1643785246193409, Accuracy: 0.4545\n", 452 | "\n", 453 | "Train Epoch: 21 Epoch time: 17.1348 Loss: 1.7950257062911987\n", 454 | "Test set: Average loss: -0.17522918835878373, Accuracy: 0.4574\n", 455 | "\n", 456 | "Train Epoch: 22 Epoch time: 16.9389 Loss: 1.5520081520080566\n", 457 | "Test set: Average loss: -0.15325675513744355, Accuracy: 0.4569\n", 458 | "\n", 459 | "Train Epoch: 23 Epoch time: 16.8428 Loss: 1.663689374923706\n", 460 | "Test set: Average loss: -0.17498237799406052, Accuracy: 0.4514\n", 461 | "\n", 462 | "Train Epoch: 24 Epoch time: 16.8972 Loss: 2.0213215351104736\n", 463 | "Test set: Average loss: -0.12325149633288383, Accuracy: 0.4142\n", 464 | "\n", 465 | "Train Epoch: 25 Epoch time: 16.8106 Loss: 2.0409820079803467\n", 466 | "Test set: Average loss: -0.12809701528549194, Accuracy: 0.4128\n", 467 | "\n", 468 | "Train Epoch: 26 Epoch time: 16.7998 Loss: 
1.821916103363037\n", 469 | "Test set: Average loss: -0.13387244796156883, Accuracy: 0.4204\n", 470 | "\n", 471 | "Train Epoch: 27 Epoch time: 16.8039 Loss: 1.3090702295303345\n", 472 | "Test set: Average loss: -0.15662841452360154, Accuracy: 0.4604\n", 473 | "\n", 474 | "Train Epoch: 28 Epoch time: 16.9014 Loss: 1.2767252922058105\n", 475 | "Test set: Average loss: -0.17367174760699272, Accuracy: 0.4426\n", 476 | "\n", 477 | "Train Epoch: 29 Epoch time: 16.7807 Loss: 1.9910852909088135\n", 478 | "Test set: Average loss: -0.138916358846426, Accuracy: 0.3752\n", 479 | "\n", 480 | "Train Epoch: 30 Epoch time: 16.7389 Loss: 1.2423659563064575\n", 481 | "Test set: Average loss: -0.14196027348041534, Accuracy: 0.4424\n", 482 | "\n", 483 | "Train Epoch: 31 Epoch time: 16.8019 Loss: 1.6152451038360596\n", 484 | "Test set: Average loss: -0.15072507199048996, Accuracy: 0.4381\n", 485 | "\n", 486 | "Train Epoch: 32 Epoch time: 16.8864 Loss: 1.6289972066879272\n", 487 | "Test set: Average loss: -0.11902655415534973, Accuracy: 0.4208\n", 488 | "\n", 489 | "Train Epoch: 33 Epoch time: 16.7390 Loss: 1.3837391138076782\n", 490 | "Test set: Average loss: -0.15387935359477997, Accuracy: 0.4218\n", 491 | "\n", 492 | "Train Epoch: 34 Epoch time: 16.7903 Loss: 1.516900658607483\n", 493 | "Test set: Average loss: -0.1460586778521538, Accuracy: 0.4429\n", 494 | "\n", 495 | "Train Epoch: 35 Epoch time: 16.7645 Loss: 1.8662484884262085\n", 496 | "Test set: Average loss: -0.1500231746673584, Accuracy: 0.4308\n", 497 | "\n", 498 | "Train Epoch: 36 Epoch time: 16.7263 Loss: 1.570525884628296\n", 499 | "Test set: Average loss: -0.17256923723816872, Accuracy: 0.4522\n", 500 | "\n", 501 | "Train Epoch: 37 Epoch time: 16.7440 Loss: 1.7509129047393799\n", 502 | "Test set: Average loss: -0.1331960899591446, Accuracy: 0.4025\n", 503 | "\n", 504 | "Train Epoch: 38 Epoch time: 16.6792 Loss: 1.6465048789978027\n", 505 | "Test set: Average loss: -0.17207880718111992, Accuracy: 0.463\n", 506 | "\n", 507 
| "Train Epoch: 39 Epoch time: 16.7337 Loss: 1.4757962226867676\n", 508 | "Test set: Average loss: -0.1252164648413658, Accuracy: 0.435\n", 509 | "\n", 510 | "Train Epoch: 40 Epoch time: 16.7789 Loss: 2.02087140083313\n", 511 | "Test set: Average loss: -0.12817251023054124, Accuracy: 0.4488\n", 512 | "\n", 513 | "Train Epoch: 41 Epoch time: 16.7497 Loss: 1.4727895259857178\n", 514 | "Test set: Average loss: -0.1616751208484173, Accuracy: 0.4522\n", 515 | "\n", 516 | "Train Epoch: 42 Epoch time: 16.8264 Loss: 1.214515209197998\n", 517 | "Test set: Average loss: -0.09364713019430637, Accuracy: 0.3973\n", 518 | "\n", 519 | "Train Epoch: 43 Epoch time: 16.6972 Loss: 1.874068021774292\n", 520 | "Test set: Average loss: -0.17249812202453613, Accuracy: 0.4564\n", 521 | "\n", 522 | "Train Epoch: 44 Epoch time: 16.9201 Loss: 1.4941129684448242\n", 523 | "Test set: Average loss: -0.14445521401762962, Accuracy: 0.4388\n", 524 | "\n", 525 | "Train Epoch: 45 Epoch time: 16.7735 Loss: 1.3479551076889038\n", 526 | "Test set: Average loss: -0.1443574243813753, Accuracy: 0.4296\n", 527 | "\n", 528 | "Train Epoch: 46 Epoch time: 16.7243 Loss: 1.9566048383712769\n", 529 | "Test set: Average loss: -0.1257232101917267, Accuracy: 0.3808\n", 530 | "\n", 531 | "Train Epoch: 47 Epoch time: 16.7802 Loss: 1.851891279220581\n", 532 | "Test set: Average loss: -0.09098090907037258, Accuracy: 0.3395\n", 533 | "\n", 534 | "Train Epoch: 48 Epoch time: 16.7901 Loss: 1.5767955780029297\n", 535 | "Test set: Average loss: -0.1268459755897522, Accuracy: 0.4005\n", 536 | "\n", 537 | "Train Epoch: 49 Epoch time: 16.7760 Loss: 1.6413697004318237\n", 538 | "Test set: Average loss: -0.13277969150543212, Accuracy: 0.4223\n", 539 | "\n", 540 | "Train Epoch: 50 Epoch time: 16.6827 Loss: 1.7723995447158813\n", 541 | "Test set: Average loss: -0.1290963588565588, Accuracy: 0.3967\n", 542 | "\n", 543 | "CPU times: user 16min 20s, sys: 5.69 s, total: 16min 25s\n", 544 | "Wall time: 16min 27s\n" 545 | ] 546 | } 547 
| ], 548 | "source": [ 549 | "%%time\n", 550 | "print(\"Train in compiled mode on CIFAR-10\")\n", 551 | "model = torch.compile(_get_model().to(device), backend=\"inductor\", mode=\"max-autotune\")  # compile it: a bare _get_model() would just train in eager mode again\n", 552 | "train(model, batch_size=16, epochs=50)" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "id": "63ed4c0a-e375-4bc4-b321-fa107a9341b7", 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [] 562 | } 563 | ], 564 | "metadata": { 565 | "kernelspec": { 566 | "display_name": "Python 3 (ipykernel)", 567 | "language": "python", 568 | "name": "python3" 569 | }, 570 | "language_info": { 571 | "codemirror_mode": { 572 | "name": "ipython", 573 | "version": 3 574 | }, 575 | "file_extension": ".py", 576 | "mimetype": "text/x-python", 577 | "name": "python", 578 | "nbconvert_exporter": "python", 579 | "pygments_lexer": "ipython3", 580 | "version": "3.10.9" 581 | } 582 | }, 583 | "nbformat": 4, 584 | "nbformat_minor": 5 585 | } 586 | --------------------------------------------------------------------------------