├── .gitignore
├── README.md
├── docker
├── Dockerfile.gpu
├── Dockerfile.habana
├── Dockerfile.xlagpu
├── README.md
├── start_gpu_pytorch2
├── start_habana_pytorch2
└── start_xla_pytorch2
├── pytorch-compile-blogpost
└── torch-compile-under-the-hood.ipynb
├── pytorch-graph-optimization
├── benchmark_torch-compile_resnet.ipynb
├── graph_optimization_torch_compile.ipynb
└── inspecting_torch_compile.ipynb
└── pytorch-intro-torch-compile
├── 1-toy-benchmarks.ipynb
├── 2-torch-compile-intro.ipynb
├── 3-inspecting-compiler-stack.ipynb
└── 4-nn-example.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | torch_compile_debug/
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-examples
2 | A repository of PyTorch examples.
3 |
--------------------------------------------------------------------------------
/docker/Dockerfile.gpu:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.8.0-base-ubuntu22.04
2 |
3 | # Remove any third-party apt sources to avoid issues with expiring keys.
4 | RUN rm -f /etc/apt/sources.list.d/*.list
5 |
6 | # Install some basic utilities; the apt package index is deleted in the same layer.
7 | RUN apt-get update && apt-get install -y \
8 |     build-essential \
9 |     graphviz \
10 |     autoconf \
11 |     automake \
12 |     gdb \
13 |     libffi-dev \
14 |     zlib1g-dev \
15 |     libssl-dev \
16 |     libsndfile1 \
17 |     curl \
18 |     wget \
19 |     vim \
20 |     ca-certificates \
21 |     git \
22 |     bzip2 \
23 |     libx11-6 \
24 |     && rm -rf /var/lib/apt/lists/*
25 |
26 | # Download and install Miniconda into /opt/conda; the installer is deleted in the same layer so it is not kept in the image.
27 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh \
28 |     && bash /tmp/miniconda.sh -b -p /opt/conda && rm -f /tmp/miniconda.sh
29 | ENV PATH=/opt/conda/bin:$PATH
30 | RUN conda init
31 | # Python packages for the notebooks. --no-cache-dir keeps pip's download cache out of the image; the extras spec is quoted so the shell cannot glob-expand the brackets.
32 | RUN pip install --no-cache-dir \
33 |     torch \
34 |     torchvision \
35 |     torchaudio \
36 |     jupyterlab \
37 |     ipykernel \
38 |     matplotlib \
39 |     ipywidgets \
40 |     pydot \
41 |     huggingface \
42 |     transformers \
43 |     datasets \
44 |     "diffusers[torch]" \
45 |     accelerate \
46 |     soundfile \
47 |     librosa \
48 |     torchdiffeq
49 | RUN pip install --no-cache-dir \
50 |     xgboost \
51 |     scikit-learn \
52 |     tabulate \
53 |     bokeh
54 |
55 | # Disable the JupyterLab announcements extension.
56 | RUN jupyter labextension disable "@jupyterlab/apputils-extension:announcements"
57 | #ENTRYPOINT ["jupyter", "lab", "--ip='*'", "--NotebookApp.token=''", "--NotebookApp.password=''","--allow-root"]
58 | # Start JupyterLab with --ip='*'; --allow-root lets the server run as root (no USER is set in this file).
59 | ENTRYPOINT ["jupyter", "lab", "--ip='*'", "--allow-root"]
--------------------------------------------------------------------------------
/docker/Dockerfile.habana:
--------------------------------------------------------------------------------
1 | FROM vault.habana.ai/gaudi-docker/1.8.0/ubuntu20.04/habanalabs/pytorch-installer-1.13.1:latest
2 | # apt-get update and install share one layer so a cached, stale package index is never reused; the index is deleted afterwards.
3 | RUN apt-get update && apt-get install -y \
4 |     build-essential autoconf automake gdb git libffi-dev zlib1g-dev libssl-dev libsndfile1 \
5 |     && rm -rf /var/lib/apt/lists/*
6 | RUN pip install --no-cache-dir jupyterlab ipykernel matplotlib ipywidgets
7 | RUN pip install --no-cache-dir huggingface transformers datasets
8 | RUN pip install --no-cache-dir "diffusers[torch]" accelerate soundfile librosa
9 | WORKDIR /pytorch-habana
10 | ENTRYPOINT ["jupyter", "lab", "--ip='*'", "--NotebookApp.token=''", "--NotebookApp.password=''","--allow-root"]
--------------------------------------------------------------------------------
/docker/Dockerfile.xlagpu:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.8.0-base-ubuntu22.04
2 |
3 | # Remove any third-party apt sources to avoid issues with expiring keys.
4 | RUN rm -f /etc/apt/sources.list.d/*.list
5 |
6 | # Install some basic utilities; the apt package index is deleted in the same layer.
7 | RUN apt-get update && apt-get install -y \
8 |     build-essential \
9 |     graphviz \
10 |     autoconf \
11 |     automake \
12 |     gdb \
13 |     libffi-dev \
14 |     zlib1g-dev \
15 |     libssl-dev \
16 |     libsndfile1 \
17 |     curl \
18 |     wget \
19 |     vim \
20 |     ca-certificates \
21 |     git \
22 |     bzip2 \
23 |     libx11-6 \
24 |     && rm -rf /var/lib/apt/lists/*
25 |
26 | # Download and install the pinned Python 3.8 Miniconda into /opt/conda; the installer is deleted in the same layer so it is not kept in the image.
27 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py38_23.1.0-1-Linux-x86_64.sh -O /tmp/miniconda.sh \
28 |     && bash /tmp/miniconda.sh -b -p /opt/conda && rm -f /tmp/miniconda.sh
29 | ENV PATH=/opt/conda/bin:$PATH
30 |
31 | # Install pinned PyTorch 2.0 plus the torch_xla 2.0 wheel (cp38 build, URL path cuda/118) with the pip found first on PATH (/opt/conda/bin).
32 | RUN pip install --no-cache-dir \
33 |     cloud-tpu-client==0.10 \
34 |     torch==2.0.0 \
35 |     torchvision==0.15.1 \
36 |     https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
37 | RUN pip install --no-cache-dir \
38 |     jupyterlab \
39 |     ipykernel \
40 |     matplotlib \
41 |     ipywidgets \
42 |     pydot \
43 |     huggingface \
44 |     transformers \
45 |     datasets \
46 |     "diffusers[torch]" \
47 |     accelerate \
48 |     soundfile \
49 |     librosa \
50 |     torchdiffeq
51 | RUN pip install --no-cache-dir \
52 |     xgboost \
53 |     scikit-learn \
54 |     bokeh
55 |
56 | # Disable the JupyterLab announcements extension.
57 | RUN jupyter labextension disable "@jupyterlab/apputils-extension:announcements"
58 | # Start JupyterLab with an empty token and password (no login prompt) — only expose this on a trusted network.
59 | ENTRYPOINT ["jupyter", "lab", "--ip='*'", "--NotebookApp.token=''", "--NotebookApp.password=''","--allow-root"]
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | ### Dockerfiles for environments used in examples
--------------------------------------------------------------------------------
/docker/start_gpu_pytorch2:
--------------------------------------------------------------------------------
1 | # Build the pytorch2:gpu image from the Dockerfile.gpu next to this script, then run it with the calling directory mounted at /workspace.
2 | # Remove leftover containers first. NOTE: this targets every container on the host (docker rm without -f refuses running ones).
3 | [ -n "$(docker ps -a -q)" ] && docker rm $(docker ps -a -q)
4 | # Use the script's own directory as build context so the script works from any working directory, including paths with spaces.
5 | docker build -t pytorch2:gpu -f "$(dirname "$0")/Dockerfile.gpu" "$(dirname "$0")"
6 | # Delete dangling (untagged) images, typically left behind by rebuilds.
7 | [ -n "$(docker images -f dangling=true -q)" ] && docker rmi $(docker images -f dangling=true -q)
8 | docker run -it --rm \
9 |     --network=host \
10 |     --gpus all \
11 |     -v "$PWD/":/workspace \
12 |     -v ~/.cache/:/root/.cache \
13 |     --workdir /workspace \
14 |     --name pytorch \
15 |     pytorch2:gpu
--------------------------------------------------------------------------------
/docker/start_habana_pytorch2:
--------------------------------------------------------------------------------
1 | # Build and run the Habana image. Run from the repository root: the Dockerfile path and the bind mount are relative to the current directory.
2 | docker build -t pytorch-habana:latest -f docker/Dockerfile.habana .
3 | [ -n "$(docker images -f dangling=true -q)" ] && docker rmi $(docker images -f dangling=true -q)
4 | docker run -it --network=host -v "$PWD/":/pytorch-habana --workdir /pytorch-habana pytorch-habana:latest
--------------------------------------------------------------------------------
/docker/start_xla_pytorch2:
--------------------------------------------------------------------------------
1 | # Build pytorch2:xla from docker/Dockerfile.xlagpu (run from the repository root) and start it with the current directory mounted at /pytorch-examples.
2 | [ -n "$(docker ps -a -q)" ] && docker rm $(docker ps -a -q)
3 | docker build -t pytorch2:xla -f docker/Dockerfile.xlagpu .
4 | [ -n "$(docker images -f dangling=true -q)" ] && docker rmi $(docker images -f dangling=true -q)
5 | docker run -it --rm \
6 |     --network=host \
7 |     --gpus all \
8 |     -v "$PWD/":/pytorch-examples \
9 |     -v ~/.cache/:/root/.cache \
10 |     --workdir /pytorch-examples \
11 |     pytorch2:xla
--------------------------------------------------------------------------------
/pytorch-compile-blogpost/torch-compile-under-the-hood.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "c0d1d55d-42e3-4445-b6c1-dfd95f7999a7",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "!rm -rf torch_compile_debug\n",
13 | "!rm *.svg"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "id": "79b09805-e61c-4be4-aaf0-2dc148b0b736",
20 | "metadata": {
21 | "tags": []
22 | },
23 | "outputs": [],
24 | "source": [
25 | "import torch\n",
26 | "import math\n",
27 | "import os\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "import torch._dynamo\n",
30 | "from torchvision import models\n",
31 | "from torch.fx.passes.graph_drawer import FxGraphDrawer\n",
32 | "from IPython.display import Markdown as md\n",
33 | "\n",
34 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\""
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "id": "a51c7031-33d1-44ee-8398-3137b5f7d085",
41 | "metadata": {
42 | "tags": []
43 | },
44 | "outputs": [],
45 | "source": [
46 | "def f(x):\n",
47 | " return torch.sin(x)**2 + torch.cos(x)**2\n",
48 | "\n",
49 | "md('''\n",
50 | "# $ y = f(x) = sin^2(x) + cos^2(x)$\n",
51 | "''')"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "id": "82568cbb-750b-4ed9-baf2-1cf2c6028c0d",
58 | "metadata": {
59 | "tags": []
60 | },
61 | "outputs": [],
62 | "source": [
63 | "md('''\n",
64 | "## Optimization problem:\n",
65 | "### find $w^*$ that $\\displaystyle \\min_{w} f(w)$\n",
66 | "\n",
67 | "## Gradient update\n",
68 | "### $ w_{i+1} = w_{i} - g(\\\\nabla f(w))$\n",
69 | "## For SGD\n",
70 | "### $ g(\\\\nabla f(w)) = \\\\alpha*\\\\nabla f(w)$\n",
71 | "## Which makes the update for SGD:\n",
72 | "### $ w_{i+1} = w_{i} - \\\\alpha*\\\\nabla f(w)$\n",
73 | "\n",
74 | "## Loss function\n",
75 | "### $loss(w): loss(model(w,batch_{inputs}), batch_{outputs})$ \n",
76 | "''')"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "id": "4b4c1717-2fbe-43f6-86ad-1f12d397d33d",
83 | "metadata": {
84 | "tags": []
85 | },
86 | "outputs": [],
87 | "source": [
88 | "md('''\n",
89 | "# **Forward graph:** $f(x) = sin^2(x)+cos^2(x)$ \\n\n",
90 | "# **Backward graph:** $\\\\frac {df(x)}{dx} = f\\'(x) = 2sin(x)cos(x) + 2cos(x)(-sin(x))$\n",
91 | "''')"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "id": "0ba8ff94-c457-41fd-b7ff-91ac36872bee",
98 | "metadata": {
99 | "tags": []
100 | },
101 | "outputs": [],
102 | "source": [
103 | "torch.manual_seed(0)\n",
104 | "x = torch.rand(1000, requires_grad=True).to(device)\n",
105 | "torch.nn.functional.mse_loss(f(x),torch.ones_like(x)) < 1e-10"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "id": "f5f80ceb-7e1a-4140-a17c-b6827e6e7461",
112 | "metadata": {
113 | "tags": []
114 | },
115 | "outputs": [],
116 | "source": [
117 | "torch.manual_seed(0)\n",
118 | "x = torch.rand(1000, requires_grad=True).to(device)\n",
119 | "\n",
120 | "compiled_f = torch.compile(f)\n",
121 | "torch.nn.functional.mse_loss(compiled_f(x),torch.ones_like(x)) < 1e-10"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "id": "b3a9cf8b-37ef-4b00-9c1b-d21e9c0ed6af",
128 | "metadata": {
129 | "tags": []
130 | },
131 | "outputs": [],
132 | "source": [
133 | "def inspect_backend(gm, sample_inputs):\n",
134 | " code = gm.print_readable()\n",
135 | " with open(\"forward.svg\", \"wb\") as file:\n",
136 | " file.write(FxGraphDrawer(gm,'f').get_dot_graph().create_svg())\n",
137 | " return gm.forward\n",
138 | "\n",
139 | "torch._dynamo.reset()\n",
140 | "compiled_f = torch.compile(f, backend=inspect_backend)\n",
141 | "\n",
142 | "x = torch.rand(1000, requires_grad=True).to(device)\n",
143 | "out = compiled_f(x)\n",
144 | "\n",
145 | "md(f'''\n",
146 | "### Graph\n",
147 | "\n",
148 | "''')"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "id": "7ec2518d-1024-4aeb-a991-67de9dc51090",
155 | "metadata": {},
156 | "outputs": [],
157 | "source": []
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "id": "619928ed-ed3f-45a4-8ef0-783c70bafaed",
162 | "metadata": {},
163 | "source": [
164 | "# AOTAutograd and Aten IR"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "id": "b900c2c1-b83e-44ba-8eb2-c3d71c3bd53e",
171 | "metadata": {
172 | "tags": []
173 | },
174 | "outputs": [],
175 | "source": [
176 | "import torch._dynamo\n",
177 | "from torch.fx.passes.graph_drawer import FxGraphDrawer\n",
178 | "from functorch.compile import make_boxed_func\n",
179 | "from torch._functorch.aot_autograd import aot_module_simplified\n",
180 | "\n",
181 | "def f(x):\n",
182 | " return torch.sin(x)**2 + torch.cos(x)**2\n",
183 | "\n",
184 | "def inspect_backend(gm, sample_inputs): \n",
185 | " # Forward compiler capture\n",
186 | " def fw(gm, sample_inputs):\n",
187 | " gm.print_readable()\n",
188 | " g = FxGraphDrawer(gm, 'fn')\n",
189 | " with open(\"forward_aot.svg\", \"wb\") as file:\n",
190 | " file.write(g.get_dot_graph().create_svg())\n",
191 | " return make_boxed_func(gm.forward)\n",
192 | " \n",
193 | " # Backward compiler capture\n",
194 | " def bw(gm, sample_inputs):\n",
195 | " gm.print_readable()\n",
196 | " g = FxGraphDrawer(gm, 'fn')\n",
197 | " with open(\"backward_aot.svg\", \"wb\") as file:\n",
198 | " file.write(g.get_dot_graph().create_svg())\n",
199 | " return make_boxed_func(gm.forward)\n",
200 | " \n",
201 | " # Call AOTAutograd\n",
202 | " gm_forward = aot_module_simplified(gm,sample_inputs,\n",
203 | " fw_compiler=fw,\n",
204 | " bw_compiler=bw)\n",
205 | "\n",
206 | " return gm_forward\n",
207 | "\n",
208 | "torch.manual_seed(0)\n",
209 | "x = torch.rand(1000, requires_grad=True).to(device)\n",
210 | "y = torch.ones_like(x)\n",
211 | "\n",
212 | "torch._dynamo.reset()\n",
213 | "compiled_f = torch.compile(f, backend=inspect_backend)\n",
214 | "out = torch.nn.functional.mse_loss(compiled_f(x), y).backward()"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "id": "264ad8cc-6fc2-4e58-98c7-3dcff496ab59",
221 | "metadata": {
222 | "tags": []
223 | },
224 | "outputs": [],
225 | "source": [
226 | "md(f'''\n",
227 | "| | < Forward graph
Backward graph >||\n",
228 | "|---|---|---|\n",
229 | "''')"
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "id": "a5a4f76f-1461-43d8-a8d4-1d24a7f6d1a8",
235 | "metadata": {},
236 | "source": [
237 | "# Decomposition to Core Aten IR"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "id": "bf9f0343-65cc-4d7e-8c53-55292b46c87d",
244 | "metadata": {
245 | "tags": []
246 | },
247 | "outputs": [],
248 | "source": [
249 | "import torch._dynamo\n",
250 | "from torch.fx.passes.graph_drawer import FxGraphDrawer\n",
251 | "from functorch.compile import make_boxed_func\n",
252 | "from torch._functorch.aot_autograd import aot_module_simplified\n",
253 | "from torch._decomp import core_aten_decompositions\n",
254 | "\n",
255 | "def f_loss(x, y):\n",
256 | " f_x = torch.sin(x)**2 + torch.cos(x)**2\n",
257 | " return torch.nn.functional.mse_loss(f_x, y)\n",
258 | "\n",
259 | "# decompositions = core_aten_decompositions() # Use decomposition to Core Aten IR\n",
260 | "decompositions = {} # Don't use decomposition to Core Aten IR\n",
261 | "\n",
262 | "def inspect_backend(gm, sample_inputs): \n",
263 | " def fw(gm, sample_inputs):\n",
264 | " gm.print_readable()\n",
265 | " g = FxGraphDrawer(gm, 'fn')\n",
266 | " with open(\"forward_decomp.svg\", \"wb\") as file:\n",
267 | " file.write(g.get_dot_graph().create_svg())\n",
268 | " return make_boxed_func(gm.forward)\n",
269 | " \n",
270 | " def bw(gm, sample_inputs):\n",
271 | " gm.print_readable()\n",
272 | " g = FxGraphDrawer(gm, 'fn')\n",
273 | " with open(\"backward_decomp.svg\", \"wb\") as file:\n",
274 | " file.write(g.get_dot_graph().create_svg())\n",
275 | " return make_boxed_func(gm.forward)\n",
276 | "\n",
277 | " # Invoke AOTAutograd\n",
278 | " return aot_module_simplified(\n",
279 | " gm,\n",
280 | " sample_inputs,\n",
281 | " fw_compiler=fw,\n",
282 | " bw_compiler=bw,\n",
283 | " decompositions=decompositions\n",
284 | " )\n",
285 | "\n",
286 | "torch.manual_seed(0)\n",
287 | "x = torch.rand(1000, requires_grad=True).to(device)\n",
288 | "y = torch.ones_like(x)\n",
289 | "\n",
290 | "torch._dynamo.reset()\n",
291 | "compiled_f = torch.compile(f_loss, backend=inspect_backend)\n",
292 | "out = compiled_f(x,y).backward()\n",
293 | "\n",
294 | "\n",
295 | "md('''\n",
296 | "# $MSE = (\\\\frac{1}{n})(\\\\vec{y}-\\\\vec{x})^2$\n",
297 | "\n",
298 | "''')"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": null,
304 | "id": "2eb86a77-d050-4f80-96b0-c23ac12bd1ee",
305 | "metadata": {
306 | "tags": []
307 | },
308 | "outputs": [],
309 | "source": [
310 | "md(f'''\n",
311 | "| | < Forward graph
Backward graph >||\n",
312 | "|---|---|---|\n",
313 | "''')"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "id": "20a397ac-69a2-4cbe-af0d-9d3c50e00ece",
319 | "metadata": {},
320 | "source": [
321 | "# Decomposition to prim IR"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": null,
327 | "id": "19e71d40-6815-4167-862e-50253913e168",
328 | "metadata": {
329 | "tags": []
330 | },
331 | "outputs": [],
332 | "source": [
333 | "import torch._dynamo\n",
334 | "from torch.fx.passes.graph_drawer import FxGraphDrawer\n",
335 | "from functorch.compile import make_boxed_func\n",
336 | "from torch._functorch.aot_autograd import aot_module_simplified\n",
337 | "from torch._decomp import core_aten_decompositions\n",
338 | "\n",
339 | "def f_loss(x, y):\n",
340 | " f_x = torch.sin(x)**2 + torch.cos(x)**2\n",
341 | " return torch.nn.functional.mse_loss(f_x, y)\n",
342 | "\n",
343 | "decompositions = core_aten_decompositions()\n",
344 | "decompositions.update(\n",
345 | " torch._decomp.get_decompositions([\n",
346 | " torch.ops.aten.sin,\n",
347 | " torch.ops.aten.cos,\n",
348 | " torch.ops.aten.add,\n",
349 | " torch.ops.aten.sub,\n",
350 | " torch.ops.aten.mul,\n",
351 | " torch.ops.aten.sum,\n",
352 | " torch.ops.aten.mean,\n",
353 | " torch.ops.aten.pow.Tensor_Scalar,\n",
354 | " ])\n",
355 | ")\n",
356 | "\n",
357 | "def inspect_backend(gm, sample_inputs): \n",
358 | " def fw(gm, sample_inputs):\n",
359 | " gm.print_readable()\n",
360 | " g = FxGraphDrawer(gm, 'fn')\n",
361 | " with open(\"forward_decomp_prims.svg\", \"wb\") as f:\n",
362 | " f.write(g.get_dot_graph().create_svg())\n",
363 | " return make_boxed_func(gm.forward)\n",
364 | " \n",
365 | " def bw(gm, sample_inputs):\n",
366 | " gm.print_readable()\n",
367 | " g = FxGraphDrawer(gm, 'fn')\n",
368 | " with open(\"backward_decomp_prims.svg\", \"wb\") as f:\n",
369 | " f.write(g.get_dot_graph().create_svg())\n",
370 | " return make_boxed_func(gm.forward)\n",
371 | "\n",
372 | " # Invoke AOTAutograd\n",
373 | " return aot_module_simplified(\n",
374 | " gm,\n",
375 | " sample_inputs,\n",
376 | " fw_compiler=fw,\n",
377 | " bw_compiler=bw,\n",
378 | " decompositions=decompositions\n",
379 | " )\n",
380 | "\n",
381 | "torch.manual_seed(0)\n",
382 | "x = torch.rand(1000, requires_grad=True).to(device)\n",
383 | "y = torch.ones_like(x)\n",
384 | "\n",
385 | "torch._dynamo.reset()\n",
386 | "compiled_f = torch.compile(f_loss, backend=inspect_backend)\n",
387 | "out = compiled_f(x,y).backward()\n"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": null,
393 | "id": "b6b79bad-edc7-413d-94be-84ba286841b5",
394 | "metadata": {
395 | "tags": []
396 | },
397 | "outputs": [],
398 | "source": [
399 | "md(f'''\n",
400 | "| | < Forward graph
Backward graph >||\n",
401 | "|---|---|---|\n",
402 | "''')"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": null,
408 | "id": "91ff887d-46de-48d0-85a6-f1888c04e136",
409 | "metadata": {
410 | "tags": []
411 | },
412 | "outputs": [],
413 | "source": [
414 | "def f(x):\n",
415 | " return torch.sin(x)**2 + torch.cos(x)**2 \n",
416 | "\n",
417 | "torch._dynamo.reset()\n",
418 | "compiled_f = torch.compile(f, backend='inductor',\n",
419 | " options={'trace.enabled':True,\n",
420 | " 'trace.graph_diagram':True})\n",
421 | "\n",
422 | "\n",
423 | "# device = 'cpu'\n",
424 | "device = 'cuda'\n",
425 | "\n",
426 | "torch.manual_seed(0)\n",
427 | "x = torch.rand(1000, requires_grad=True).to(device)\n",
428 | "y = torch.ones_like(x)\n",
429 | "\n",
430 | "out = torch.nn.functional.mse_loss(compiled_f(x),y).backward()"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": null,
436 | "id": "a06362e9-7783-46bf-8498-bce83bc749ef",
437 | "metadata": {
438 | "tags": []
439 | },
440 | "outputs": [],
441 | "source": [
442 | "import glob\n",
443 | "fwd = glob.glob('torch_compile_debug/run_*/aot_torchinductor/*forward*/graph_diagram.svg')[-1]\n",
444 | "bwd = glob.glob('torch_compile_debug/run_*/aot_torchinductor/*backward*/graph_diagram.svg')[-1]\n",
445 | "\n",
446 | "md(f'''\n",
447 | "| | < Forward graph
Backward graph >||\n",
448 | "|---|---|---|\n",
449 | "''')"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": null,
455 | "id": "fb68e773-dce1-431e-8070-e2825d255bb1",
456 | "metadata": {},
457 | "outputs": [],
458 | "source": []
459 | }
460 | ],
461 | "metadata": {
462 | "kernelspec": {
463 | "display_name": "Python 3 (ipykernel)",
464 | "language": "python",
465 | "name": "python3"
466 | },
467 | "language_info": {
468 | "codemirror_mode": {
469 | "name": "ipython",
470 | "version": 3
471 | },
472 | "file_extension": ".py",
473 | "mimetype": "text/x-python",
474 | "name": "python",
475 | "nbconvert_exporter": "python",
476 | "pygments_lexer": "ipython3",
477 | "version": "3.10.9"
478 | }
479 | },
480 | "nbformat": 4,
481 | "nbformat_minor": 5
482 | }
483 |
--------------------------------------------------------------------------------
/pytorch-graph-optimization/benchmark_torch-compile_resnet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 19,
6 | "id": "b9fe79ed-0970-4b13-a6ac-4fc4f3c2e628",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "import torch.utils.benchmark as benchmark\n",
13 | "import torch\n",
14 | "from torchvision.models import resnet\n",
15 | "import torch._dynamo\n",
16 | "\n",
17 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\""
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 20,
23 | "id": "3bbdc444-058e-4439-9ee7-121fd962af1a",
24 | "metadata": {
25 | "tags": []
26 | },
27 | "outputs": [],
28 | "source": [
29 | "def run_batch_inference(model, batch=1):\n",
30 | " x = torch.randn(batch, 3, 224, 224).to(device)\n",
31 | " model(x)\n",
32 | "\n",
33 | "def run_batch_train(model, optimizer, batch=16):\n",
34 | " x = torch.randn(batch, 3, 224, 224).to(device)\n",
35 | " optimizer.zero_grad()\n",
36 | " out = model(x)\n",
37 | " out.sum().backward()\n",
38 | " optimizer.step()\n",
39 | " \n",
40 | "model = resnet.resnet18(weights=resnet.ResNet18_Weights.IMAGENET1K_V1).to(device)"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 22,
46 | "id": "7dfab6b2-57da-4bad-a8f1-af5df0c9e49f",
47 | "metadata": {
48 | "tags": []
49 | },
50 | "outputs": [
51 | {
52 | "name": "stderr",
53 | "output_type": "stream",
54 | "text": [
55 | "[2023-03-21 07:56:49,304] torch._inductor.debug: [WARNING] model__15_forward_25 debug trace: /pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_49_09_212386-pid_25659/aot_torchinductor/model__15_forward_25.10\n"
56 | ]
57 | },
58 | {
59 | "name": "stdout",
60 | "output_type": "stream",
61 | "text": [
62 | "Inference speedup: 1.96%\n"
63 | ]
64 | }
65 | ],
66 | "source": [
67 | "batch = 1\n",
68 | "torch._dynamo.reset()\n",
69 | "compiled_model = torch.compile(model, options={'triton.cudagraphs': False,\n",
70 | " 'trace.enabled':True})\n",
71 | "\n",
72 | "t_model = benchmark.Timer(\n",
73 | " stmt='run_batch_inference(model, batch)',\n",
74 | " setup='from __main__ import run_batch_inference',\n",
75 | " globals={'model': model, 'batch':batch})\n",
76 | "\n",
77 | "t_compiled_model = benchmark.Timer(\n",
78 | " stmt='run_batch_inference(model, batch)',\n",
79 | " setup='from __main__ import run_batch_inference',\n",
80 | " globals={'model': compiled_model, 'batch':batch})\n",
81 | "\n",
82 | "t_model_runs = t_model.timeit(100)\n",
83 | "t_compiled_model_runs = t_compiled_model.timeit(100)\n",
84 | "\n",
85 | "print(f\"Inference speedup: {100*(t_model_runs.mean - t_compiled_model_runs.mean) / t_model_runs.mean: .2f}%\")"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 4,
91 | "id": "bdf99896-0fe4-4b65-a1ed-13296a81c221",
92 | "metadata": {
93 | "tags": []
94 | },
95 | "outputs": [
96 | {
97 | "name": "stdout",
98 | "output_type": "stream",
99 | "text": [
100 | "\n",
101 | "run_batch_train(model, optimizer, batch)\n",
102 | "setup: from __main__ import run_batch_train\n",
103 | " 37.30 ms\n",
104 | " 1 measurement, 100 runs , 1 thread\n",
105 | "\n",
106 | "run_batch_train(model, optimizer, batch)\n",
107 | "setup: from __main__ import run_batch_train\n",
108 | " 34.57 ms\n",
109 | " 1 measurement, 100 runs , 1 thread\n",
110 | "Training speedup: 7.32%\n"
111 | ]
112 | }
113 | ],
114 | "source": [
115 | "batch = 32\n",
116 | "torch._dynamo.reset()\n",
117 | "compiled_model = torch.compile(model)\n",
118 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
119 | "\n",
120 | "t_model = benchmark.Timer(\n",
121 | " stmt='run_batch_train(model, optimizer, batch)',\n",
122 | " setup='from __main__ import run_batch_train',\n",
123 | " globals={'model': model,'optimizer':optimizer, 'batch':batch})\n",
124 | "\n",
125 | "t_compiled_model = benchmark.Timer(\n",
126 | " stmt='run_batch_train(model, optimizer, batch)',\n",
127 | " setup='from __main__ import run_batch_train',\n",
128 | " globals={'model': compiled_model, 'optimizer':optimizer, 'batch':batch})\n",
129 | "\n",
130 | "t_model_runs = t_model.timeit(100)\n",
131 | "t_compiled_model_runs = t_compiled_model.timeit(100)\n",
132 | "\n",
133 | "print(t_model_runs)\n",
134 | "print(t_compiled_model_runs)\n",
135 | "\n",
136 | "print(f\"Training speedup: {100*(t_model_runs.mean - t_compiled_model_runs.mean) / t_model_runs.mean: .2f}%\")"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "id": "bf10e32d-b29f-4df1-af18-6dadcccc9a98",
143 | "metadata": {},
144 | "outputs": [],
145 | "source": []
146 | }
147 | ],
148 | "metadata": {
149 | "kernelspec": {
150 | "display_name": "Python 3 (ipykernel)",
151 | "language": "python",
152 | "name": "python3"
153 | },
154 | "language_info": {
155 | "codemirror_mode": {
156 | "name": "ipython",
157 | "version": 3
158 | },
159 | "file_extension": ".py",
160 | "mimetype": "text/x-python",
161 | "name": "python",
162 | "nbconvert_exporter": "python",
163 | "pygments_lexer": "ipython3",
164 | "version": "3.10.9"
165 | }
166 | },
167 | "nbformat": 4,
168 | "nbformat_minor": 5
169 | }
170 |
--------------------------------------------------------------------------------
/pytorch-graph-optimization/graph_optimization_torch_compile.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "afb2eda5-99df-42b0-bbd3-f711f4d086b0",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "!rm -rf torch_compile_debug"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "id": "8f6b7210-882a-4c36-8060-4f324a85cccd",
19 | "metadata": {
20 | "tags": []
21 | },
22 | "outputs": [],
23 | "source": [
24 | "import torch\n",
25 | "import matplotlib.pyplot as plt\n",
26 | "from matplotlib import cm\n",
27 | "from matplotlib.ticker import LinearLocator\n",
28 | "import numpy as np\n",
29 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\""
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "id": "1cd9c0e4-bdc7-46d6-b0c4-d545ff5d2a29",
36 | "metadata": {
37 | "tags": []
38 | },
39 | "outputs": [],
40 | "source": [
41 | "def fn(x,y):\n",
42 | " # x*exp(−(x^2+y^2))+(x^2+y^2)/20\n",
43 | " return x*torch.exp(-x**2-y**2) + (x**2+y**2)/20 "
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 7,
49 | "id": "96b96737-a9e3-4c0a-ac13-502102cad039",
50 | "metadata": {
51 | "tags": []
52 | },
53 | "outputs": [
54 | {
55 | "data": {
56 | "image/png": "",
57 | "text/plain": [
58 | ""
59 | ]
60 | },
61 | "metadata": {},
62 | "output_type": "display_data"
63 | }
64 | ],
65 | "source": [
66 | "# Make data.\n",
67 | "x = torch.arange(-2, 2, 0.2,requires_grad=True)\n",
68 | "y = torch.arange(-2, 2, 0.2,requires_grad=True)\n",
69 | "x, y = torch.meshgrid(x, y, indexing='ij')\n",
70 | "r = fn(x,y)\n",
71 | "z = torch.sin(r)\n",
72 | "\n",
73 | "fig, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n",
74 | "\n",
75 | "# Plot the surface.\n",
76 | "surf = ax.plot_surface(x.detach().numpy(), y.detach().numpy(), z.detach().numpy(), cmap=cm.coolwarm,\n",
77 | " linewidth=0, antialiased=False)\n",
78 | "# Customize the z axis.\n",
79 | "ax.set_zlim(-0.5, 1.01)\n",
80 | "ax.zaxis.set_major_locator(LinearLocator(10))\n",
81 | "# A StrMethodFormatter is used automatically\n",
82 | "ax.zaxis.set_major_formatter('{x:.02f}')\n",
83 | "ax.view_init(10, 60)\n",
84 | "plt.show()"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 7,
90 | "id": "e5d0f05c-6990-4404-8013-8b323bfa7c6c",
91 | "metadata": {
92 | "tags": []
93 | },
94 | "outputs": [
95 | {
96 | "name": "stderr",
97 | "output_type": "stream",
98 | "text": [
99 | "[2023-03-21 07:04:09,637] torch._inductor.debug: [WARNING] model__1_forward_4 debug trace: /pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_03_48_329556-pid_24339/aot_torchinductor/model__1_forward_4.2\n",
100 | "[2023-03-21 07:04:09,677] torch._inductor.debug: [WARNING] model__1_backward_5 debug trace: /pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_03_48_329556-pid_24339/aot_torchinductor/model__1_backward_5.3\n"
101 | ]
102 | },
103 | {
104 | "name": "stdout",
105 | "output_type": "stream",
106 | "text": [
107 | "Writing FX graph to file: /pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_03_48_329556-pid_24339/aot_torchinductor/model__1_forward_4.2/graph_diagram.svg\n",
108 | "Writing FX graph to file: /pytorch-examples/pytorch-graph-optim/torch_compile_debug/run_2023_03_21_07_03_48_329556-pid_24339/aot_torchinductor/model__1_backward_5.3/graph_diagram.svg\n"
109 | ]
110 | }
111 | ],
112 | "source": [
113 | "fn_compiled = torch.compile(fn, backend=\"inductor\", \n",
114 | " options={'trace.graph_diagram':True,\n",
115 | " 'trace.enabled':True})\n",
116 | "\n",
117 | "out = fn_compiled(x.to(device), y.to(device)).sum().backward()"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "id": "d687a68c-afab-462e-b3cb-1d6f723603bb",
124 | "metadata": {},
125 | "outputs": [],
126 | "source": []
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "id": "59d094a0-84fa-430d-b8e1-94c30e310283",
132 | "metadata": {
133 | "tags": []
134 | },
135 | "outputs": [
136 | {
137 | "name": "stdout",
138 | "output_type": "stream",
139 | "text": [
140 | "\n",
141 | "\n",
142 | "\n",
143 | "def forward(self, x, y):\n",
144 | " pow_1 = torch.pow(x, 2)\n",
145 | " neg = -pow_1; pow_1 = None\n",
146 | " pow_2 = torch.pow(y, 2)\n",
147 | " sub = neg - pow_2; neg = pow_2 = None\n",
148 | " exp = torch.exp(sub); sub = None\n",
149 | " mul = x * exp; exp = None\n",
150 | " pow_3 = torch.pow(x, 2); x = None\n",
151 | " pow_4 = torch.pow(y, 2); y = None\n",
152 | " add = pow_3 + pow_4; pow_3 = pow_4 = None\n",
153 | " truediv = add / 20; add = None\n",
154 | " add_1 = mul + truediv; mul = truediv = None\n",
155 | " return add_1\n",
156 | " \n"
157 | ]
158 | }
159 | ],
160 | "source": [
161 | "from torch.fx import passes, symbolic_trace\n",
162 | "model = symbolic_trace(fn)\n",
163 | "\n",
164 | "g = passes.graph_drawer.FxGraphDrawer(model, 'fn')\n",
165 | "with open(\"unoptimized_graph1.svg\", \"wb\") as f:\n",
166 | " f.write(g.get_dot_graph().create_svg())\n",
167 | " \n",
168 | "print(model.code)"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 17,
174 | "id": "090bfe7c-8737-40eb-b870-9a87bbaa09b0",
175 | "metadata": {
176 | "tags": []
177 | },
178 | "outputs": [
179 | {
180 | "name": "stdout",
181 | "output_type": "stream",
182 | "text": [
183 | "\n",
184 | "\n",
185 | "\n",
186 | "def forward(self, x_1, y_1):\n",
187 | " pow_1 = torch.ops.aten.pow.Tensor_Scalar(x_1, 2)\n",
188 | " neg = torch.ops.aten.neg.default(pow_1); pow_1 = None\n",
189 | " pow_2 = torch.ops.aten.pow.Tensor_Scalar(y_1, 2)\n",
190 | " sub = torch.ops.aten.sub.Tensor(neg, pow_2); neg = pow_2 = None\n",
191 | " exp = torch.ops.aten.exp.default(sub); sub = None\n",
192 | " detach = torch.ops.aten.detach.default(exp)\n",
193 | " mul = torch.ops.aten.mul.Tensor(x_1, exp); exp = None\n",
194 | " pow_3 = torch.ops.aten.pow.Tensor_Scalar(x_1, 2); x_1 = None\n",
195 | " pow_4 = torch.ops.aten.pow.Tensor_Scalar(y_1, 2); y_1 = None\n",
196 | " add = torch.ops.aten.add.Tensor(pow_3, pow_4); pow_3 = pow_4 = None\n",
197 | " div = torch.ops.aten.div.Tensor(add, 20); add = None\n",
198 | " add_1 = torch.ops.aten.add.Tensor(mul, div); mul = div = None\n",
199 | " return add_1\n",
200 | " \n"
201 | ]
202 | }
203 | ],
204 | "source": [
205 | "from functorch import make_fx\n",
206 | "g = make_fx(fn)(x, y)\n",
207 | "print(g.code)"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": null,
213 | "id": "6351c611-cf45-40e0-8e78-5b3707cb8a0e",
214 | "metadata": {},
215 | "outputs": [],
216 | "source": []
217 | }
218 | ],
219 | "metadata": {
220 | "kernelspec": {
221 | "display_name": "Python 3 (ipykernel)",
222 | "language": "python",
223 | "name": "python3"
224 | },
225 | "language_info": {
226 | "codemirror_mode": {
227 | "name": "ipython",
228 | "version": 3
229 | },
230 | "file_extension": ".py",
231 | "mimetype": "text/x-python",
232 | "name": "python",
233 | "nbconvert_exporter": "python",
234 | "pygments_lexer": "ipython3",
235 | "version": "3.10.9"
236 | }
237 | },
238 | "nbformat": 4,
239 | "nbformat_minor": 5
240 | }
241 |
--------------------------------------------------------------------------------
/pytorch-graph-optimization/inspecting_torch_compile.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "2c42e096-2d57-4e1d-976b-003cade9a05c",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "import torch\n",
13 | "import torch._dynamo\n",
14 | "from torch import nn"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 3,
20 | "id": "0a16cbc2-0baa-44f4-beb1-030dcb7c1433",
21 | "metadata": {
22 | "tags": []
23 | },
24 | "outputs": [],
25 | "source": [
26 | "class MLP(nn.Module):\n",
27 | " def __init__(self):\n",
28 | " super().__init__()\n",
29 | " self.fc1 = nn.Linear(32, 64)\n",
30 | "\n",
31 | " def forward(self, x):\n",
32 | " x = self.fc1(x)\n",
33 | " x = torch.nn.functional.gelu(x)\n",
34 | " return x\n",
35 | "\n",
36 | "model = MLP()\n",
37 | "\n",
38 | "batch_size = 8\n",
39 | "x = torch.randn(batch_size, 32)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "id": "43846e32-0973-4e62-ab70-563ec780bcf5",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "from torch.fx import passes, symbolic_trace\n",
50 | "model = symbolic_trace(fn)\n",
51 | "\n",
52 | "g = passes.graph_drawer.FxGraphDrawer(model, 'fn')\n",
53 | "with open(\"unoptimized_graph.svg\", \"wb\") as f:\n",
54 | " f.write(g.get_dot_graph().create_svg())"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 21,
60 | "id": "0a7036c0-cf13-44fb-ac96-79ad76fad291",
61 | "metadata": {
62 | "tags": []
63 | },
64 | "outputs": [
65 | {
66 | "name": "stdout",
67 | "output_type": "stream",
68 | "text": [
69 | "Writing FX graph to file: forward.svg\n",
70 | "Writing FX graph to file: backward.svg\n"
71 | ]
72 | },
73 | {
74 | "name": "stderr",
75 | "output_type": "stream",
76 | "text": [
77 | "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\n",
78 | " warnings.warn(\n",
79 | "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\n",
80 | " warnings.warn(\n"
81 | ]
82 | }
83 | ],
84 | "source": [
85 | "import torch._dynamo\n",
86 | "from torch._functorch.aot_autograd import aot_module_simplified\n",
87 | "from functorch.compile import compiled_function, draw_graph\n",
88 | "\n",
89 | "\n",
90 | "def toy_backend(gm, sample_inputs): \n",
91 | " def fw(gm, sample_inputs):\n",
92 | " draw_graph(gm, \"forward.svg\")\n",
93 | " return gm.forward\n",
94 | " \n",
95 | " def bw(gm, sample_inputs):\n",
96 | " draw_graph(gm, \"backward.svg\")\n",
97 | " return gm.forward\n",
98 | "\n",
99 | " # Invoke AOTAutograd\n",
100 | " return aot_module_simplified(\n",
101 | " gm,\n",
102 | " sample_inputs,\n",
103 | " fw_compiler=fw,\n",
104 | " bw_compiler=bw\n",
105 | " )\n",
106 | "\n",
107 | "def fn(x):\n",
108 | " return x**2\n",
109 | "\n",
110 | "model = fn\n",
111 | "x = torch.tensor(5., requires_grad=True)\n",
112 | "\n",
113 | "torch._dynamo.reset()\n",
114 | "cmodel = torch.compile(model, backend=toy_backend, dynamic=True)\n",
115 | "\n",
116 | "# triggers compilation of forward graph on the first run\n",
117 | "out = cmodel(x).sum().backward()"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 5,
123 | "id": "259b9c86-5f23-4f61-b178-15e71a6e2197",
124 | "metadata": {
125 | "tags": []
126 | },
127 | "outputs": [
128 | {
129 | "name": "stdout",
130 | "output_type": "stream",
131 | "text": [
132 | "\n",
133 | "\n",
134 | "\n",
135 | "def forward(self, x):\n",
136 | " param = self.param\n",
137 | " add = x + param; x = param = None\n",
138 | " linear = self.linear(add); add = None\n",
139 | " clamp = linear.clamp(min = 0.0, max = 1.0); linear = None\n",
140 | " return clamp\n",
141 | " \n"
142 | ]
143 | }
144 | ],
145 | "source": [
146 | "print(symtraced.code)"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "id": "8125d380-aaa2-4685-95a4-a035449cacfb",
153 | "metadata": {},
154 | "outputs": [],
155 | "source": []
156 | }
157 | ],
158 | "metadata": {
159 | "kernelspec": {
160 | "display_name": "Python 3 (ipykernel)",
161 | "language": "python",
162 | "name": "python3"
163 | },
164 | "language_info": {
165 | "codemirror_mode": {
166 | "name": "ipython",
167 | "version": 3
168 | },
169 | "file_extension": ".py",
170 | "mimetype": "text/x-python",
171 | "name": "python",
172 | "nbconvert_exporter": "python",
173 | "pygments_lexer": "ipython3",
174 | "version": "3.10.9"
175 | }
176 | },
177 | "nbformat": 4,
178 | "nbformat_minor": 5
179 | }
180 |
--------------------------------------------------------------------------------
/pytorch-intro-torch-compile/1-toy-benchmarks.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "b33cbc31-dc42-4987-a9e2-513e31dd92de",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "import torch\n",
13 | "import time\n",
14 | "import os\n",
15 | "\n",
16 | "from torch import nn\n",
17 | "import torchvision.models as models\n",
18 | "from triton.testing import do_bench\n",
19 | "import torch._dynamo"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "id": "9d742cd2-8f4b-40d5-a51e-36dacb9a11a0",
26 | "metadata": {
27 | "tags": []
28 | },
29 | "outputs": [],
30 | "source": [
31 | "torch.set_float32_matmul_precision('high')\n",
32 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 3,
38 | "id": "98c588e4-db0c-4c73-8000-c51edad58d89",
39 | "metadata": {
40 | "tags": []
41 | },
42 | "outputs": [],
43 | "source": [
44 | "def run_benchmark(fn):\n",
45 | " exec_time, prctl20, prctl80 = do_bench(fn,warmup=100,rep=1000)\n",
46 | " print(f\"Exec time (median): {exec_time}\")\n",
47 | " print(f\"Exec time (20th percentile): {prctl20}\")\n",
48 | " print(f\"Exec time (80th percentile): {prctl80}\\n\")\n",
49 | " return exec_time"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "id": "cbe00915-2f6e-4a45-b33e-b309dd549dae",
55 | "metadata": {},
56 | "source": [
57 | "## 1. ResNet50 Speedup on NVIDIA A10G"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 4,
63 | "id": "306dae3d-ad53-442a-a62f-70c18d8567fe",
64 | "metadata": {
65 | "tags": []
66 | },
67 | "outputs": [],
68 | "source": [
69 | "def run_batch(model, optimizer):\n",
70 | " x = torch.randn(16, 3, 224, 224).to(device)\n",
71 | " optimizer.zero_grad()\n",
72 | " out = model(x)\n",
73 | " out.sum().backward()\n",
74 | " optimizer.step()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 5,
80 | "id": "6e3d508d-75db-406b-92e2-6bedffd1510d",
81 | "metadata": {
82 | "tags": []
83 | },
84 | "outputs": [],
85 | "source": [
86 | "model = models.resnet50().to(device)"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 6,
92 | "id": "9d483ad8-2882-4ce5-a491-cb464b3e3bb5",
93 | "metadata": {
94 | "tags": []
95 | },
96 | "outputs": [
97 | {
98 | "name": "stdout",
99 | "output_type": "stream",
100 | "text": [
101 | "Resnet50 Eager mode\n",
102 | "Exec time (median): 50.456573486328125\n",
103 | "Exec time (20th percentile): 50.4313850402832\n",
104 | "Exec time (80th percentile): 50.50531768798828\n",
105 | "\n",
106 | "Resnet50 Compiled defaults\n",
107 | "Exec time (median): 46.913536071777344\n",
108 | "Exec time (20th percentile): 46.89653778076172\n",
109 | "Exec time (80th percentile): 46.9372673034668\n",
110 | "\n",
111 | "speedup: 7.55%\n"
112 | ]
113 | },
114 | {
115 | "name": "stderr",
116 | "output_type": "stream",
117 | "text": [
118 | "Process ForkProcess-4:\n",
119 | "Process ForkProcess-1:\n",
120 | "Process ForkProcess-3:\n",
121 | "Process ForkProcess-2:\n"
122 | ]
123 | }
124 | ],
125 | "source": [
126 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
127 | "\n",
128 | "# Benchmark Eager\n",
129 | "print(\"Resnet50 Eager mode\")\n",
130 | "exec_time = run_benchmark(lambda: run_batch(model, optimizer))\n",
131 | "\n",
132 | "# Benchmark torch.compile defaults\n",
133 | "print(\"Resnet50 Compiled defaults\")\n",
134 | "opt_model = torch.compile(model)\n",
135 | "opt_exec_time = run_benchmark(lambda: run_batch(opt_model, optimizer))\n",
136 | "\n",
137 | "# Print speedups\n",
138 | "print(f\"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%\")"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "id": "fcd3add7-7c03-4559-9109-8443700febdf",
144 | "metadata": {},
145 | "source": [
146 | "## 2. Custom model Speedup on NVIDIA A10G"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 7,
152 | "id": "8ca833c0-d383-45ce-9f31-06b126539376",
153 | "metadata": {
154 | "tags": []
155 | },
156 | "outputs": [],
157 | "source": [
158 | "class MLP(nn.Module):\n",
159 | " def __init__(self):\n",
160 | " super().__init__()\n",
161 | " self.fc1 = nn.Linear(1024, 1024)\n",
162 | " self.fc2 = nn.Linear(1024, 1024)\n",
163 | " \n",
164 | " def forward(self, x):\n",
165 | " x = self.fc1(x).relu() ** 2\n",
166 | " return self.fc2(x).relu() ** 2"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 8,
172 | "id": "bc293b4a-6634-4e66-95f0-1675ab363114",
173 | "metadata": {
174 | "tags": []
175 | },
176 | "outputs": [],
177 | "source": [
178 | "model = MLP().to(device)\n",
179 | "x = torch.randn(1024, 1024).to(device)"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": 9,
185 | "id": "020901c4-9054-4652-acaf-0232a31d0011",
186 | "metadata": {
187 | "tags": []
188 | },
189 | "outputs": [
190 | {
191 | "name": "stdout",
192 | "output_type": "stream",
193 | "text": [
194 | "Exec time (median): 0.7147520184516907\n",
195 | "Exec time (20th percentile): 0.7127040028572083\n",
196 | "Exec time (80th percentile): 0.7157760262489319\n",
197 | "\n",
198 | "Exec time (median): 0.6010879874229431\n",
199 | "Exec time (20th percentile): 0.6000639796257019\n",
200 | "Exec time (80th percentile): 0.6021119952201843\n",
201 | "\n",
202 | "speedup: 18.91%\n"
203 | ]
204 | }
205 | ],
206 | "source": [
207 | "# Benchmark Eager\n",
208 | "exec_time = run_benchmark(lambda: model(x).sum().backward())\n",
209 | "\n",
210 | "torch._dynamo.reset()\n",
211 | "# Benchmark torch.compile defaults\n",
212 | "cmodel = torch.compile(model, backend='inductor')\n",
213 | "opt_exec_time = run_benchmark(lambda: cmodel(x).sum().backward())\n",
214 | "\n",
215 | "# Print speedups\n",
216 | "print(f\"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%\")"
217 | ]
218 | },
219 | {
220 | "cell_type": "markdown",
221 | "id": "49a00c44-c096-4ae8-afe8-7937db196933",
222 | "metadata": {},
223 | "source": [
224 | "## 3. HuggingFace model Speedup on NVIDIA A10G"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 10,
230 | "id": "d7a4eaae-8bc3-4a21-8147-a2d75e5fba3a",
231 | "metadata": {
232 | "tags": []
233 | },
234 | "outputs": [],
235 | "source": [
236 | "from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC\n",
237 | "from datasets import load_dataset\n",
238 | "\n",
239 | "def run_inference(model, input_values):\n",
240 | " \n",
241 | " # retrieve logits\n",
242 | " logits = model(input_values).logits\n",
243 | " \n",
244 | " # take argmax and decode\n",
245 | " predicted_ids = torch.argmax(logits, dim=-1)\n",
246 | " transcription = processor.batch_decode(predicted_ids)"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 11,
252 | "id": "76d00a8f-4503-4fa7-90cd-263f693097da",
253 | "metadata": {
254 | "tags": []
255 | },
256 | "outputs": [
257 | {
258 | "data": {
259 | "application/vnd.jupyter.widget-view+json": {
260 | "model_id": "7baf2003e8ab49cca9c52c88ba8e14c5",
261 | "version_major": 2,
262 | "version_minor": 0
263 | },
264 | "text/plain": [
265 | "Downloading (…)rocessor_config.json: 0%| | 0.00/158 [00:00, ?B/s]"
266 | ]
267 | },
268 | "metadata": {},
269 | "output_type": "display_data"
270 | },
271 | {
272 | "data": {
273 | "application/vnd.jupyter.widget-view+json": {
274 | "model_id": "95ce1004d67747d885b5390635fff5ed",
275 | "version_major": 2,
276 | "version_minor": 0
277 | },
278 | "text/plain": [
279 | "Downloading (…)okenizer_config.json: 0%| | 0.00/162 [00:00, ?B/s]"
280 | ]
281 | },
282 | "metadata": {},
283 | "output_type": "display_data"
284 | },
285 | {
286 | "data": {
287 | "application/vnd.jupyter.widget-view+json": {
288 | "model_id": "4c46f57ef4424b63a165423425872e10",
289 | "version_major": 2,
290 | "version_minor": 0
291 | },
292 | "text/plain": [
293 | "Downloading (…)lve/main/config.json: 0%| | 0.00/1.61k [00:00, ?B/s]"
294 | ]
295 | },
296 | "metadata": {},
297 | "output_type": "display_data"
298 | },
299 | {
300 | "data": {
301 | "application/vnd.jupyter.widget-view+json": {
302 | "model_id": "5234e1aa6f504a91b1d66cc10a2c5346",
303 | "version_major": 2,
304 | "version_minor": 0
305 | },
306 | "text/plain": [
307 | "Downloading (…)olve/main/vocab.json: 0%| | 0.00/291 [00:00, ?B/s]"
308 | ]
309 | },
310 | "metadata": {},
311 | "output_type": "display_data"
312 | },
313 | {
314 | "data": {
315 | "application/vnd.jupyter.widget-view+json": {
316 | "model_id": "6c075c6c9c63424781a7006dee4b604f",
317 | "version_major": 2,
318 | "version_minor": 0
319 | },
320 | "text/plain": [
321 | "Downloading (…)cial_tokens_map.json: 0%| | 0.00/85.0 [00:00, ?B/s]"
322 | ]
323 | },
324 | "metadata": {},
325 | "output_type": "display_data"
326 | },
327 | {
328 | "data": {
329 | "application/vnd.jupyter.widget-view+json": {
330 | "model_id": "27eaa34b5cc5424faa543fb470ab76ca",
331 | "version_major": 2,
332 | "version_minor": 0
333 | },
334 | "text/plain": [
335 | "Downloading pytorch_model.bin: 0%| | 0.00/1.26G [00:00, ?B/s]"
336 | ]
337 | },
338 | "metadata": {},
339 | "output_type": "display_data"
340 | },
341 | {
342 | "name": "stderr",
343 | "output_type": "stream",
344 | "text": [
345 | "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']\n",
346 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
347 | ]
348 | },
349 | {
350 | "data": {
351 | "application/vnd.jupyter.widget-view+json": {
352 | "model_id": "88e2f484ed74461c892b44cf2ab8eac2",
353 | "version_major": 2,
354 | "version_minor": 0
355 | },
356 | "text/plain": [
357 | "Downloading builder script: 0%| | 0.00/5.16k [00:00, ?B/s]"
358 | ]
359 | },
360 | "metadata": {},
361 | "output_type": "display_data"
362 | },
363 | {
364 | "name": "stdout",
365 | "output_type": "stream",
366 | "text": [
367 | "Downloading and preparing dataset librispeech_asr_dummy/clean to /root/.cache/huggingface/datasets/patrickvonplaten___librispeech_asr_dummy/clean/2.1.0/f2c70a4d03ab4410954901bde48c54b85ca1b7f9bf7d616e7e2a72b5ee6ddbfc...\n"
368 | ]
369 | },
370 | {
371 | "data": {
372 | "application/vnd.jupyter.widget-view+json": {
373 | "model_id": "f3a8c6b33a3f40d7951adb0a22d345dc",
374 | "version_major": 2,
375 | "version_minor": 0
376 | },
377 | "text/plain": [
378 | "Downloading data files: 0%| | 0/1 [00:00, ?it/s]"
379 | ]
380 | },
381 | "metadata": {},
382 | "output_type": "display_data"
383 | },
384 | {
385 | "data": {
386 | "application/vnd.jupyter.widget-view+json": {
387 | "model_id": "a044e0c9a4a24f4e9b8182b3150cf21a",
388 | "version_major": 2,
389 | "version_minor": 0
390 | },
391 | "text/plain": [
392 | "Downloading data: 0%| | 0.00/9.08M [00:00, ?B/s]"
393 | ]
394 | },
395 | "metadata": {},
396 | "output_type": "display_data"
397 | },
398 | {
399 | "data": {
400 | "application/vnd.jupyter.widget-view+json": {
401 | "model_id": "23f4accd58444806abf78bc4d19ab800",
402 | "version_major": 2,
403 | "version_minor": 0
404 | },
405 | "text/plain": [
406 | "Extracting data files: 0%| | 0/1 [00:00, ?it/s]"
407 | ]
408 | },
409 | "metadata": {},
410 | "output_type": "display_data"
411 | },
412 | {
413 | "data": {
414 | "application/vnd.jupyter.widget-view+json": {
415 | "model_id": "",
416 | "version_major": 2,
417 | "version_minor": 0
418 | },
419 | "text/plain": [
420 | "Generating validation split: 0 examples [00:00, ? examples/s]"
421 | ]
422 | },
423 | "metadata": {},
424 | "output_type": "display_data"
425 | },
426 | {
427 | "name": "stdout",
428 | "output_type": "stream",
429 | "text": [
430 | "Dataset librispeech_asr_dummy downloaded and prepared to /root/.cache/huggingface/datasets/patrickvonplaten___librispeech_asr_dummy/clean/2.1.0/f2c70a4d03ab4410954901bde48c54b85ca1b7f9bf7d616e7e2a72b5ee6ddbfc. Subsequent calls will reuse this data.\n"
431 | ]
432 | },
433 | {
434 | "name": "stderr",
435 | "output_type": "stream",
436 | "text": [
437 | "It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
438 | ]
439 | }
440 | ],
441 | "source": [
442 | "# load model and processor\n",
443 | "processor = Wav2Vec2Processor.from_pretrained(\"facebook/wav2vec2-large-960h-lv60-self\")\n",
444 | "model = Wav2Vec2ForCTC.from_pretrained(\"facebook/wav2vec2-large-960h-lv60-self\").cuda()\n",
445 | "\n",
446 | "# load dummy dataset and read soundfiles\n",
447 | "ds = load_dataset(\"patrickvonplaten/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
448 | "\n",
449 | "# tokenize\n",
450 | "input_values = processor(ds[0][\"audio\"][\"array\"], return_tensors=\"pt\", padding=\"longest\").input_values.cuda()"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": 12,
456 | "id": "92d5578e-e793-48c2-b043-0a77d25d71c4",
457 | "metadata": {
458 | "tags": []
459 | },
460 | "outputs": [
461 | {
462 | "name": "stdout",
463 | "output_type": "stream",
464 | "text": [
465 | "Exec time (median): 30.511903762817383\n",
466 | "Exec time (20th percentile): 30.503211975097656\n",
467 | "Exec time (80th percentile): 30.539411544799805\n",
468 | "\n"
469 | ]
470 | },
471 | {
472 | "name": "stderr",
473 | "output_type": "stream",
474 | "text": [
475 | "AUTOTUNE bias_addmm(544x1024, 544x512, 512x1024)\n",
476 | " triton_mm_1 0.0276s 100.0%\n",
477 | " triton_mm_3 0.0276s 100.0%\n",
478 | " triton_mm_2 0.0287s 96.4%\n",
479 | " triton_mm_4 0.0287s 96.4%\n",
480 | " bias_addmm 0.0287s 96.4%\n",
481 | " triton_mm_0 0.0297s 93.1%\n",
482 | " triton_mm_8 0.0328s 84.4%\n",
483 | " triton_mm_10 0.0328s 84.4%\n",
484 | " triton_mm_11 0.0348s 79.4%\n",
485 | " triton_mm_5 0.0389s 71.1%\n",
486 | "AUTOTUNE bias_addmm(544x1024, 544x1024, 1024x1024)\n",
487 | " triton_mm_13 0.0481s 100.0%\n",
488 | " triton_mm_15 0.0481s 100.0%\n",
489 | " triton_mm_14 0.0492s 97.9%\n",
490 | " triton_mm_16 0.0492s 97.9%\n",
491 | " triton_mm_12 0.0532s 90.4%\n",
492 | " triton_mm_20 0.0532s 90.4%\n",
493 | " bias_addmm 0.0543s 88.7%\n",
494 | " triton_mm_22 0.0543s 88.7%\n",
495 | " triton_mm_23 0.0594s 81.0%\n",
496 | " triton_mm_18 0.0635s 75.8%\n",
497 | "AUTOTUNE bmm(16x544x64, 16x64x544)\n",
498 | " triton_bmm_56 0.0553s 100.0%\n",
499 | " triton_bmm_48 0.0553s 100.0%\n",
500 | " triton_bmm_58 0.0573s 96.4%\n",
501 | " triton_bmm_59 0.0584s 94.7%\n",
502 | " triton_bmm_52 0.0604s 91.5%\n",
503 | " triton_bmm_49 0.0604s 91.5%\n",
504 | " triton_bmm_50 0.0604s 91.5%\n",
505 | " triton_bmm_51 0.0614s 90.0%\n",
506 | " bmm 0.0614s 90.0%\n",
507 | " triton_bmm_55 0.0676s 81.8%\n",
508 | "AUTOTUNE bmm(16x544x544, 16x544x64)\n",
509 | " triton_bmm_60 0.0594s 100.0%\n",
510 | " triton_bmm_70 0.0604s 98.3%\n",
511 | " triton_bmm_62 0.0614s 96.7%\n",
512 | " triton_bmm_64 0.0614s 96.7%\n",
513 | " triton_bmm_68 0.0625s 95.1%\n",
514 | " triton_bmm_66 0.0625s 95.1%\n",
515 | " triton_bmm_65 0.0635s 93.5%\n",
516 | " bmm 0.0635s 93.5%\n",
517 | " triton_bmm_61 0.0645s 92.1%\n",
518 | " triton_bmm_63 0.0645s 92.1%\n",
519 | "AUTOTUNE bias_addmm(544x4096, 544x1024, 1024x4096)\n",
520 | " bias_addmm 0.1720s 100.0%\n",
521 | " triton_mm_88 0.1772s 97.1%\n",
522 | " triton_mm_86 0.1782s 96.6%\n",
523 | " triton_mm_84 0.1782s 96.6%\n",
524 | " triton_mm_94 0.1792s 96.0%\n",
525 | " triton_mm_87 0.1833s 93.9%\n",
526 | " triton_mm_85 0.1833s 93.9%\n",
527 | " triton_mm_91 0.1966s 87.5%\n",
528 | " triton_mm_92 0.2058s 83.6%\n",
529 | " addmm 0.2130s 80.8%\n",
530 | "AUTOTUNE bias_addmm(544x1024, 544x4096, 4096x1024)\n",
531 | " triton_mm_97 0.1629s 100.0%\n",
532 | " triton_mm_99 0.1638s 99.4%\n",
533 | " triton_mm_98 0.1679s 97.0%\n",
534 | " triton_mm_100 0.1679s 97.0%\n",
535 | " bias_addmm 0.1700s 95.8%\n",
536 | " triton_mm_104 0.1792s 90.9%\n",
537 | " addmm 0.1884s 86.4%\n",
538 | " triton_mm_106 0.1956s 83.3%\n",
539 | " triton_mm_96 0.1966s 82.8%\n",
540 | " triton_mm_102 0.2365s 68.9%\n",
541 | "AUTOTUNE bias_addmm(544x32, 544x1024, 1024x32)\n",
542 | " bias_addmm 0.0164s 100.0%\n",
543 | " triton_mm_2321 0.0184s 88.9%\n",
544 | " triton_mm_2325 0.0184s 88.9%\n",
545 | " triton_mm_2322 0.0184s 88.9%\n",
546 | " addmm 0.0215s 76.2%\n",
547 | " triton_mm_2324 0.0266s 61.5%\n",
548 | " triton_mm_2316 0.0389s 42.1%\n",
549 | " triton_mm_2326 0.0420s 39.0%\n",
550 | " triton_mm_2327 0.0451s 36.4%\n",
551 | " triton_mm_2317 0.0451s 36.4%\n"
552 | ]
553 | },
554 | {
555 | "name": "stdout",
556 | "output_type": "stream",
557 | "text": [
558 | "Exec time (median): 27.311792373657227\n",
559 | "Exec time (20th percentile): 27.307104110717773\n",
560 | "Exec time (80th percentile): 27.3222713470459\n",
561 | "\n",
562 | "speedup: 11.72%\n"
563 | ]
564 | }
565 | ],
566 | "source": [
567 | "exec_time = run_benchmark(lambda: run_inference(model, input_values))\n",
568 | "\n",
569 | "torch._dynamo.reset()\n",
570 | "model = torch.compile(model, mode=\"max-autotune\")\n",
571 | "opt_exec_time = run_benchmark(lambda: run_inference(model, input_values))\n",
572 | "\n",
573 | "# Print speedups\n",
574 | "print(f\"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%\")"
575 | ]
576 | },
577 | {
578 | "cell_type": "code",
579 | "execution_count": 13,
580 | "id": "6abb2e5a-966e-4e55-852b-a4c9f7eb03e9",
581 | "metadata": {},
582 | "outputs": [
583 | {
584 | "data": {
585 | "text/plain": [
586 | "['aot_ts_nvfuser',\n",
587 | " 'cudagraphs',\n",
588 | " 'inductor',\n",
589 | " 'ipex',\n",
590 | " 'nvprims_nvfuser',\n",
591 | " 'onnxrt',\n",
592 | " 'tvm']"
593 | ]
594 | },
595 | "execution_count": 13,
596 | "metadata": {},
597 | "output_type": "execute_result"
598 | }
599 | ],
600 | "source": [
601 | "torch._dynamo.list_backends()"
602 | ]
603 | },
604 | {
605 | "cell_type": "code",
606 | "execution_count": null,
607 | "id": "b35a7fc2-571d-4557-9744-43fb2b13f3fa",
608 | "metadata": {},
609 | "outputs": [],
610 | "source": []
611 | }
612 | ],
613 | "metadata": {
614 | "kernelspec": {
615 | "display_name": "Python 3 (ipykernel)",
616 | "language": "python",
617 | "name": "python3"
618 | },
619 | "language_info": {
620 | "codemirror_mode": {
621 | "name": "ipython",
622 | "version": 3
623 | },
624 | "file_extension": ".py",
625 | "mimetype": "text/x-python",
626 | "name": "python",
627 | "nbconvert_exporter": "python",
628 | "pygments_lexer": "ipython3",
629 | "version": "3.10.9"
630 | }
631 | },
632 | "nbformat": 4,
633 | "nbformat_minor": 5
634 | }
635 |
--------------------------------------------------------------------------------
/pytorch-intro-torch-compile/2-torch-compile-intro.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 5,
6 | "id": "a2e14a76-d475-4ae7-b481-3bf850e900d1",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "import torch\n",
13 | "from torch import nn"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 6,
19 | "id": "c5309297-5989-4dca-b61e-1dffe44992cd",
20 | "metadata": {
21 | "tags": []
22 | },
23 | "outputs": [],
24 | "source": [
25 | "def fn(x, y):\n",
26 | " a = torch.sin(x)\n",
27 | " b = torch.cos(y)\n",
28 | " return a + b"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 7,
34 | "id": "c1457bc7-f536-4c3e-a3c0-3a9c62d622b7",
35 | "metadata": {
36 | "tags": []
37 | },
38 | "outputs": [
39 | {
40 | "name": "stdout",
41 | "output_type": "stream",
42 | "text": [
43 | "Writing FX graph to file: /pytorch-examples/torch_compile_debug/run_2023_03_09_21_59_43_229833-pid_8972/aot_torchinductor/model__2_forward_7.4/graph_diagram.svg\n"
44 | ]
45 | },
46 | {
47 | "name": "stderr",
48 | "output_type": "stream",
49 | "text": [
50 | "[2023-03-09 23:31:43,033] torch._inductor.debug: [WARNING] model__2_forward_7 debug trace: /pytorch-examples/torch_compile_debug/run_2023_03_09_21_59_43_229833-pid_8972/aot_torchinductor/model__2_forward_7.4\n"
51 | ]
52 | },
53 | {
54 | "name": "stdout",
55 | "output_type": "stream",
56 | "text": [
57 | "Writing FX graph to file: /pytorch-examples/torch_compile_debug/run_2023_03_09_21_59_43_229833-pid_8972/aot_torchinductor/model__2_backward_8.5/graph_diagram.svg\n"
58 | ]
59 | },
60 | {
61 | "name": "stderr",
62 | "output_type": "stream",
63 | "text": [
64 | "[2023-03-09 23:31:44,253] torch._inductor.debug: [WARNING] model__2_backward_8 debug trace: /pytorch-examples/torch_compile_debug/run_2023_03_09_21_59_43_229833-pid_8972/aot_torchinductor/model__2_backward_8.5\n"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "new_fn = torch.compile(fn, backend=\"inductor\", \n",
70 | " options={'trace.graph_diagram':True,\n",
71 | " 'trace.enabled':True})\n",
72 | "\n",
73 | "input_tensor = torch.randn(10000, requires_grad=True)\n",
74 | "out = new_fn(input_tensor, input_tensor).sum().backward()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 8,
80 | "id": "d3dd98c3-9d6e-4031-b4c8-75b462057ebe",
81 | "metadata": {
82 | "tags": []
83 | },
84 | "outputs": [],
85 | "source": [
86 | "from torch.fx import passes, symbolic_trace\n",
87 | "model = symbolic_trace(fn)\n",
88 | "\n",
89 | "g = passes.graph_drawer.FxGraphDrawer(model, 'fn')\n",
90 | "with open(\"unoptimized_graph.svg\", \"wb\") as f:\n",
91 | " f.write(g.get_dot_graph().create_svg())"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "id": "05970749-293f-4060-b266-2c6a45a75a5a",
98 | "metadata": {},
99 | "outputs": [],
100 | "source": []
101 | }
102 | ],
103 | "metadata": {
104 | "kernelspec": {
105 | "display_name": "Python 3 (ipykernel)",
106 | "language": "python",
107 | "name": "python3"
108 | },
109 | "language_info": {
110 | "codemirror_mode": {
111 | "name": "ipython",
112 | "version": 3
113 | },
114 | "file_extension": ".py",
115 | "mimetype": "text/x-python",
116 | "name": "python",
117 | "nbconvert_exporter": "python",
118 | "pygments_lexer": "ipython3",
119 | "version": "3.10.9"
120 | }
121 | },
122 | "nbformat": 4,
123 | "nbformat_minor": 5
124 | }
125 |
--------------------------------------------------------------------------------
/pytorch-intro-torch-compile/3-inspecting-compiler-stack.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "acdb00b0-7032-4133-9cd4-37dddc40a33f",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "import torch\n",
13 | "import torch._dynamo\n",
14 | "from torch import nn"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "id": "92c8357d-f9a1-4a7f-be01-1aae1c8a76c5",
21 | "metadata": {
22 | "tags": []
23 | },
24 | "outputs": [],
25 | "source": [
26 | "class MLP(nn.Module):\n",
27 | " def __init__(self):\n",
28 | " super().__init__()\n",
29 | " self.fc1 = nn.Linear(32, 64)\n",
30 | "\n",
31 | " def forward(self, x):\n",
32 | " x = self.fc1(x)\n",
33 | " x = torch.nn.functional.gelu(x)\n",
34 | " return x\n",
35 | "\n",
36 | "model = MLP()\n",
37 | "\n",
38 | "batch_size = 8\n",
39 | "input = torch.randn(batch_size, 32)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "id": "e49cf111-80bc-4074-95a5-19f329d495d5",
46 | "metadata": {
47 | "tags": []
48 | },
49 | "outputs": [],
50 | "source": [
51 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "id": "296572f4-0eeb-4ae8-9977-d9d24e1bb7a1",
57 | "metadata": {},
58 | "source": [
59 | "### Invoking `torch.compile` produces an FX graph in Torch IR"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 4,
65 | "id": "a78f233e-7295-4e7b-9439-7903f3a5a3c5",
66 | "metadata": {
67 | "tags": []
68 | },
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "Dynamo produced a fx Graph in Torch IR:\n",
75 | "class GraphModule(torch.nn.Module):\n",
76 | " def forward(self, x : torch.Tensor):\n",
77 | " # File: /tmp/ipykernel_8670/1490842273.py:7, code: x = self.fc1(x)\n",
78 | " self_fc1 = self.self_fc1(x); x = None\n",
79 | " \n",
80 | " # File: /tmp/ipykernel_8670/1490842273.py:8, code: x = torch.nn.functional.gelu(x)\n",
81 | " gelu = torch._C._nn.gelu(self_fc1); self_fc1 = None\n",
82 | " return (gelu,)\n",
83 | " \n",
84 | "Notice that sample_inputs is a list of flattened FakeTensor:\n",
85 | "[FakeTensor(FakeTensor(..., device='meta', size=(s0, 32)), cpu)]\n"
86 | ]
87 | }
88 | ],
89 | "source": [
90 | "def toy_backend(gm, sample_inputs):\n",
91 | " print(\"Dynamo produced a fx Graph in Torch IR:\")\n",
92 | " gm.print_readable()\n",
93 | "\n",
94 | " print(\"Notice that sample_inputs is a list of flattened FakeTensor:\")\n",
95 | " print(sample_inputs)\n",
96 | " return gm.forward\n",
97 | "\n",
98 | "torch._dynamo.reset()\n",
99 | "cmodel = torch.compile(model, backend=toy_backend, dynamic=True)\n",
100 | "\n",
101 | "# triggers compilation of forward graph on the first run\n",
102 | "out = cmodel(input)"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "id": "cf9d8eba-34d4-42a1-9914-32f428403042",
108 | "metadata": {},
109 | "source": [
110 | "## Invoking AOTAutograd produces forward + backward FX graphs in Aten IR\n",
111 | "* Captures the forward + backward graphs\n",
112 | "* Lowers from Torch IR to Aten/Prims IR"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "id": "cd68d6c5-f52a-471c-bff6-c5127cea5723",
118 | "metadata": {},
119 | "source": [
120 | "### Core Aten IR (https://pytorch.org/docs/master/ir.html#core-aten-ir)\n",
121 | "\n",
122 | "* A strict subset of aten operators (< 250) after decompositions\n",
123 | "* Purely functional (no input mutations)\n",
124 | "* Guaranteed metadata information, e.g. dtype and shape propagation"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 5,
130 | "id": "19a66032-61e8-4902-b4b2-056cd1a6eb35",
131 | "metadata": {
132 | "tags": []
133 | },
134 | "outputs": [
135 | {
136 | "name": "stdout",
137 | "output_type": "stream",
138 | "text": [
139 | "AOTAutograd produced a fx Graph in Aten IR:\n",
140 | "class GraphModule(torch.nn.Module):\n",
141 | " def forward(self, primals_1: f32[64, 32], primals_2: f32[64], primals_3: f32[s0, 32]):\n",
142 | " # File: /tmp/ipykernel_8670/1490842273.py:7, code: x = self.fc1(x)\n",
143 | " t: f32[32, 64] = torch.ops.aten.t.default(primals_1); primals_1 = None\n",
144 | " addmm: f32[s0, 64] = torch.ops.aten.addmm.default(primals_2, primals_3, t); primals_2 = t = None\n",
145 | " \n",
146 | " # File: /tmp/ipykernel_8670/1490842273.py:8, code: x = torch.nn.functional.gelu(x)\n",
147 | " gelu: f32[s0, 64] = torch.ops.aten.gelu.default(addmm)\n",
148 | " return [gelu, addmm, primals_3]\n",
149 | " \n"
150 | ]
151 | },
152 | {
153 | "name": "stderr",
154 | "output_type": "stream",
155 | "text": [
156 | "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\n",
157 | " warnings.warn(\n"
158 | ]
159 | }
160 | ],
161 | "source": [
162 | "import torch._dynamo\n",
163 | "from torch._functorch.aot_autograd import aot_module_simplified\n",
164 | "\n",
165 | "def toy_backend(gm, sample_inputs): \n",
166 | " def my_compiler(gm, sample_inputs):\n",
167 | " # \n",
168 | " print(\"AOTAutograd produced a fx Graph in Aten IR:\")\n",
169 | " gm.print_readable()\n",
170 | " return gm.forward\n",
171 | "\n",
172 | " # Invoke AOTAutograd\n",
173 | " return aot_module_simplified(\n",
174 | " gm,\n",
175 | " sample_inputs,\n",
176 | " fw_compiler=my_compiler\n",
177 | " )\n",
178 | "\n",
179 | "torch._dynamo.reset()\n",
180 | "cmodel = torch.compile(model, backend=toy_backend, dynamic=True)\n",
181 | "\n",
182 | "# triggers compilation of forward graph on the first run\n",
183 | "out = cmodel(input)"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 6,
189 | "id": "199f5ca0-0743-4f91-859d-c495a723280c",
190 | "metadata": {
191 | "tags": []
192 | },
193 | "outputs": [
194 | {
195 | "name": "stdout",
196 | "output_type": "stream",
197 | "text": [
198 | "Decomposed fx Graph in Aten IR:\n",
199 | "class GraphModule(torch.nn.Module):\n",
200 | " def forward(self, primals_1: f32[64, 32], primals_2: f32[64], primals_3: f32[s0, 32]):\n",
201 | " # File: /tmp/ipykernel_8670/1490842273.py:7, code: x = self.fc1(x)\n",
202 | " permute: f32[32, 64] = torch.ops.aten.permute.default(primals_1, [1, 0]); primals_1 = None\n",
203 | " addmm: f32[s0, 64] = torch.ops.aten.addmm.default(primals_2, primals_3, permute); primals_2 = permute = None\n",
204 | " \n",
205 | " # File: /tmp/ipykernel_8670/1490842273.py:8, code: x = torch.nn.functional.gelu(x)\n",
206 | " mul: f32[s0, 64] = torch.ops.aten.mul.Tensor(addmm, 0.5)\n",
207 | " mul_1: f32[s0, 64] = torch.ops.aten.mul.Tensor(addmm, 0.7071067811865476)\n",
208 | " erf: f32[s0, 64] = torch.ops.aten.erf.default(mul_1); mul_1 = None\n",
209 | " add: f32[s0, 64] = torch.ops.aten.add.Tensor(erf, 1); erf = None\n",
210 | " mul_2: f32[s0, 64] = torch.ops.aten.mul.Tensor(mul, add); mul = add = None\n",
211 | " return [mul_2, addmm, primals_3]\n",
212 | " \n"
213 | ]
214 | },
215 | {
216 | "name": "stderr",
217 | "output_type": "stream",
218 | "text": [
219 | "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1251: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\n",
220 | " warnings.warn(\n"
221 | ]
222 | }
223 | ],
224 | "source": [
225 | "from torch._inductor.decomposition import decompositions as default_decompositions\n",
226 | "\n",
227 | "decompositions = default_decompositions.copy()\n",
228 | "\n",
229 | "def toy_backend(gm, sample_inputs):\n",
230 | " def my_compiler(gm, sample_inputs):\n",
231 | " # \n",
232 | " print(\"Decomposed fx Graph in Aten IR:\")\n",
233 | " gm.print_readable()\n",
234 | " return gm\n",
235 | "\n",
236 | " # Invoke AOTAutograd\n",
237 | " return aot_module_simplified(\n",
238 | " gm,\n",
239 | " sample_inputs,\n",
240 | " decompositions=decompositions,\n",
241 | " fw_compiler=my_compiler\n",
242 | " )\n",
243 | "\n",
244 | "torch._dynamo.reset()\n",
245 | "cmodel = torch.compile(model, backend=toy_backend, dynamic=True)\n",
246 | "\n",
247 | "# triggers compilation of forward graph on the first run\n",
248 | "out = cmodel(input)"
249 | ]
250 | },
251 | {
252 | "cell_type": "markdown",
253 | "id": "08c92c33-0803-4ce1-906c-e00faa1a620f",
254 | "metadata": {},
255 | "source": [
256 | "### Prims IR (https://pytorch.org/docs/master/ir.html#prims-ir)\n",
257 | "\n",
258 | "* Explicit type promotion and broadcasting\n",
259 | "* prims.convert_element_type\n",
260 | "* prims.broadcast_in_dim\n",
261 | "* Intended for backends with a powerful compiler that can reclaim the performance through fusion, e.g. nvFuser"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 7,
267 | "id": "236f9011-fca0-45d9-a832-4f07962d8ef4",
268 | "metadata": {
269 | "tags": []
270 | },
271 | "outputs": [
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "Further decomposed fx Graph in Prims IR:\n",
277 | "class (torch.nn.Module):\n",
278 | " def forward(self, arg0_1: f32[3], arg1_1: f16[3, 3]):\n",
279 | " # File: /tmp/ipykernel_8670/2178452752.py:7, code: return a + b\n",
280 | " _to_copy: f32[3, 3] = torch.ops.aten._to_copy.default(arg1_1, dtype = torch.float32); arg1_1 = None\n",
281 | " broadcast_in_dim: f32[3, 3] = torch.ops.prims.broadcast_in_dim.default(arg0_1, [3, 3], [1]); arg0_1 = None\n",
282 | " add: f32[3, 3] = torch.ops.prims.add.default(broadcast_in_dim, _to_copy); broadcast_in_dim = _to_copy = None\n",
283 | " return (add,)\n",
284 | " \n"
285 | ]
286 | }
287 | ],
288 | "source": [
289 | "prims_decomp = torch._decomp.get_decompositions([\n",
290 | " torch.ops.aten.add,\n",
291 | " torch.ops.aten.expand.default,\n",
292 | "])\n",
293 | "\n",
294 | "def fn(a, b):\n",
295 | " return a + b\n",
296 | "\n",
297 | "def toy_backend(gm, sample_inputs):\n",
298 | " def my_compiler(gm, sample_inputs):\n",
299 | " # \n",
300 | " print(\"Further decomposed fx Graph in Prims IR:\")\n",
301 | " gm.print_readable()\n",
302 | " return gm\n",
303 | "\n",
304 | " # Invoke AOTAutograd\n",
305 | " return aot_module_simplified(\n",
306 | " gm,\n",
307 | " sample_inputs,\n",
308 | " decompositions=prims_decomp,\n",
309 | " fw_compiler=my_compiler\n",
310 | " )\n",
311 | "\n",
312 | "torch._dynamo.reset()\n",
313 | "fn = torch.compile(backend=toy_backend)(fn)\n",
314 | "out = fn(torch.rand(3, dtype=torch.float), torch.rand(3, 3, dtype=torch.half))"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "id": "0c70a43d-7a90-409e-862c-7169c6fd088f",
321 | "metadata": {},
322 | "outputs": [],
323 | "source": []
324 | }
325 | ],
326 | "metadata": {
327 | "kernelspec": {
328 | "display_name": "Python 3 (ipykernel)",
329 | "language": "python",
330 | "name": "python3"
331 | },
332 | "language_info": {
333 | "codemirror_mode": {
334 | "name": "ipython",
335 | "version": 3
336 | },
337 | "file_extension": ".py",
338 | "mimetype": "text/x-python",
339 | "name": "python",
340 | "nbconvert_exporter": "python",
341 | "pygments_lexer": "ipython3",
342 | "version": "3.10.9"
343 | }
344 | },
345 | "nbformat": 4,
346 | "nbformat_minor": 5
347 | }
348 |
--------------------------------------------------------------------------------
/pytorch-intro-torch-compile/4-nn-example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "dd7ed31a-9899-4489-a7a1-bc2bb1d90ee7",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "import argparse\n",
13 | "import json\n",
14 | "import logging\n",
15 | "import os\n",
16 | "import sys\n",
17 | "\n",
18 | "import torch\n",
19 | "import torch.nn as nn\n",
20 | "import torch.nn.functional as F\n",
21 | "import torch.optim as optim\n",
22 | "import torch.utils.data\n",
23 | "import torchvision\n",
24 | "\n",
25 | "from torchvision import datasets, transforms\n",
26 | "\n",
27 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
28 | "torch.set_float32_matmul_precision('high') #Uses TF32 when available"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "id": "37effd54-1f4c-42b2-bd08-1d524908fdc3",
35 | "metadata": {
36 | "tags": []
37 | },
38 | "outputs": [],
39 | "source": [
40 | "def _get_model():\n",
41 | " class Net(nn.Module):\n",
42 | " def __init__(self):\n",
43 | " super(Net, self).__init__()\n",
44 | " self.conv1 = nn.Conv2d(3, 6, 5)\n",
45 | " self.pool = nn.MaxPool2d(2, 2)\n",
46 | " self.conv2 = nn.Conv2d(6, 16, 5)\n",
47 | " self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
48 | " self.fc2 = nn.Linear(120, 84)\n",
49 | " self.fc3 = nn.Linear(84, 10)\n",
50 | "\n",
51 | " def forward(self, x):\n",
52 | " x = self.pool(F.relu(self.conv1(x)))\n",
53 | " x = self.pool(F.relu(self.conv2(x)))\n",
54 | " x = x.view(-1, 16 * 5 * 5)\n",
55 | " x = F.relu(self.fc1(x))\n",
56 | " x = F.relu(self.fc2(x))\n",
57 | " x = self.fc3(x)\n",
58 | " return x\n",
59 | " return Net()"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 3,
65 | "id": "55a1d996-acc3-45ac-8113-3344bc7ddc28",
66 | "metadata": {
67 | "tags": []
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# Define data augmentation\n",
72 | "def _get_transforms():\n",
73 | " transform = transforms.Compose([\n",
74 | " transforms.RandomCrop(32, padding=4),\n",
75 | " transforms.RandomHorizontalFlip(),\n",
76 | " transforms.ToTensor(),\n",
77 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
78 | " ])\n",
79 | " return transform"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 4,
85 | "id": "f8cc60fa-faba-40d6-a2cc-7d5f29e682c6",
86 | "metadata": {
87 | "tags": []
88 | },
89 | "outputs": [],
90 | "source": [
91 | "def _get_dataloaders(batch_size):\n",
92 | " trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
93 | " download=True, transform=_get_transforms())\n",
94 | " testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
95 | " download=True, transform=_get_transforms())\n",
96 | " trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
97 | " shuffle=True)\n",
98 | " testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
99 | " shuffle=False)\n",
100 | " return trainloader, testloader"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 5,
106 | "id": "c7bda30f-df68-4a3b-925a-3153ba1f6809",
107 | "metadata": {
108 | "tags": []
109 | },
110 | "outputs": [],
111 | "source": [
112 | "def test(model, test_loader, device):\n",
113 | " test_loss = 0\n",
114 | " correct = 0\n",
115 | " with torch.no_grad():\n",
116 | " for data, target in test_loader:\n",
117 | " data, target = data.to(device), target.to(device)\n",
118 | " output = model(data)\n",
119 | " test_loss += F.nll_loss(output, target, reduction='mean').item() # sum up batch loss\n",
120 | " pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
121 | " correct += pred.eq(target.view_as(pred)).sum().item()\n",
122 | "\n",
123 | " test_loss /= len(test_loader.dataset)\n",
124 | " print(f\"Test set: Average loss: {test_loss}, Accuracy: {correct / len(test_loader.dataset)}\\n\")"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 6,
130 | "id": "96633f85-1875-4dc9-a2ee-9784083e63c9",
131 | "metadata": {
132 | "scrolled": true,
133 | "tags": []
134 | },
135 | "outputs": [],
136 | "source": [
137 | "import time\n",
138 | "def train(model, batch_size, epochs):\n",
139 | " torch.manual_seed(0)\n",
140 | " lr = 0.01\n",
141 | " momentum=0.9\n",
142 | " train_loader, test_loader = _get_dataloaders(batch_size)\n",
143 | "\n",
144 | " criterion = nn.CrossEntropyLoss()\n",
145 | " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)\n",
146 | "\n",
147 | " for epoch in range(1, epochs + 1):\n",
148 | " start_time = time.time()\n",
149 | " for batch_idx, (data, target) in enumerate(train_loader, 1):\n",
150 | " data, target = data.to(device), target.to(device)\n",
151 | "\n",
152 | " # zero the parameter gradients\n",
153 | " optimizer.zero_grad()\n",
154 | "\n",
155 | " # forward + backward + optimize\n",
156 | " output = model(data)\n",
157 | " loss = criterion(output, target)\n",
158 | " loss.backward()\n",
159 | " optimizer.step()\n",
160 | "\n",
161 | " print(f\"Train Epoch: {epoch} Epoch time: {time.time()-start_time:0.4f} Loss: {loss.item()}\")\n",
162 | " test(model, test_loader, device)"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 7,
168 | "id": "4c4c51e7-a0ac-4a7f-88a3-9505bcd8f248",
169 | "metadata": {
170 | "scrolled": true,
171 | "tags": []
172 | },
173 | "outputs": [
174 | {
175 | "name": "stdout",
176 | "output_type": "stream",
177 | "text": [
178 | "Train in eager mode on CIFAR-10\n",
179 | "Files already downloaded and verified\n",
180 | "Files already downloaded and verified\n",
181 | "Train Epoch: 1 Epoch time: 18.4492 Loss: 1.5229963064193726\n",
182 | "Test set: Average loss: -0.11479381905794143, Accuracy: 0.3574\n",
183 | "\n",
184 | "Train Epoch: 2 Epoch time: 17.1039 Loss: 1.255014419555664\n",
185 | "Test set: Average loss: -0.12258006700873375, Accuracy: 0.4207\n",
186 | "\n",
187 | "Train Epoch: 3 Epoch time: 17.0540 Loss: 1.2103943824768066\n",
188 | "Test set: Average loss: -0.14219269587993622, Accuracy: 0.4306\n",
189 | "\n",
190 | "Train Epoch: 4 Epoch time: 17.1195 Loss: 1.6141940355300903\n",
191 | "Test set: Average loss: -0.1223116991698742, Accuracy: 0.4078\n",
192 | "\n",
193 | "Train Epoch: 5 Epoch time: 17.1593 Loss: 1.7167662382125854\n",
194 | "Test set: Average loss: -0.1371232023358345, Accuracy: 0.427\n",
195 | "\n",
196 | "Train Epoch: 6 Epoch time: 17.0450 Loss: 1.6442832946777344\n",
197 | "Test set: Average loss: -0.11765576581060887, Accuracy: 0.4195\n",
198 | "\n",
199 | "Train Epoch: 7 Epoch time: 17.2778 Loss: 1.1586203575134277\n",
200 | "Test set: Average loss: -0.12454334303736686, Accuracy: 0.4175\n",
201 | "\n",
202 | "Train Epoch: 8 Epoch time: 17.1429 Loss: 2.3159003257751465\n",
203 | "Test set: Average loss: -0.14869595084786416, Accuracy: 0.4542\n",
204 | "\n",
205 | "Train Epoch: 9 Epoch time: 17.0725 Loss: 1.63433039188385\n",
206 | "Test set: Average loss: -0.13449764105677606, Accuracy: 0.4416\n",
207 | "\n",
208 | "Train Epoch: 10 Epoch time: 17.1152 Loss: 0.9652916193008423\n",
209 | "Test set: Average loss: -0.14638293540477754, Accuracy: 0.4566\n",
210 | "\n",
211 | "Train Epoch: 11 Epoch time: 17.0312 Loss: 1.4068830013275146\n",
212 | "Test set: Average loss: -0.13832889040112495, Accuracy: 0.4099\n",
213 | "\n",
214 | "Train Epoch: 12 Epoch time: 17.1650 Loss: 1.6365031003952026\n",
215 | "Test set: Average loss: -0.1226353962123394, Accuracy: 0.427\n",
216 | "\n",
217 | "Train Epoch: 13 Epoch time: 17.1685 Loss: 1.2202277183532715\n",
218 | "Test set: Average loss: -0.16288705285787583, Accuracy: 0.4464\n",
219 | "\n",
220 | "Train Epoch: 14 Epoch time: 17.1915 Loss: 1.5230098962783813\n",
221 | "Test set: Average loss: -0.13382687171697616, Accuracy: 0.4507\n",
222 | "\n",
223 | "Train Epoch: 15 Epoch time: 17.1395 Loss: 1.6405322551727295\n",
224 | "Test set: Average loss: -0.11657781254947186, Accuracy: 0.418\n",
225 | "\n",
226 | "Train Epoch: 16 Epoch time: 17.1124 Loss: 1.435512661933899\n",
227 | "Test set: Average loss: -0.16411116136312484, Accuracy: 0.4536\n",
228 | "\n",
229 | "Train Epoch: 17 Epoch time: 17.0854 Loss: 1.3877995014190674\n",
230 | "Test set: Average loss: -0.13008506304621698, Accuracy: 0.4505\n",
231 | "\n",
232 | "Train Epoch: 18 Epoch time: 17.0801 Loss: 1.3181973695755005\n",
233 | "Test set: Average loss: -0.12966735948324204, Accuracy: 0.4444\n",
234 | "\n",
235 | "Train Epoch: 19 Epoch time: 17.1122 Loss: 1.3785901069641113\n",
236 | "Test set: Average loss: -0.16746452817320823, Accuracy: 0.4708\n",
237 | "\n",
238 | "Train Epoch: 20 Epoch time: 17.0873 Loss: 2.601522922515869\n",
239 | "Test set: Average loss: -0.1661388152360916, Accuracy: 0.455\n",
240 | "\n",
241 | "Train Epoch: 21 Epoch time: 17.1412 Loss: 1.387449026107788\n",
242 | "Test set: Average loss: -0.13323016968667506, Accuracy: 0.4078\n",
243 | "\n",
244 | "Train Epoch: 22 Epoch time: 17.0955 Loss: 0.9796938300132751\n",
245 | "Test set: Average loss: -0.12620635756850243, Accuracy: 0.4348\n",
246 | "\n",
247 | "Train Epoch: 23 Epoch time: 17.1453 Loss: 1.7141715288162231\n",
248 | "Test set: Average loss: -0.14345559119582177, Accuracy: 0.4657\n",
249 | "\n",
250 | "Train Epoch: 24 Epoch time: 17.1091 Loss: 1.4967659711837769\n",
251 | "Test set: Average loss: -0.12276398456394673, Accuracy: 0.4001\n",
252 | "\n",
253 | "Train Epoch: 25 Epoch time: 17.0620 Loss: 1.9967564344406128\n",
254 | "Test set: Average loss: -0.15137818831205369, Accuracy: 0.4591\n",
255 | "\n",
256 | "Train Epoch: 26 Epoch time: 17.0726 Loss: 1.2827167510986328\n",
257 | "Test set: Average loss: -0.14104595474600792, Accuracy: 0.4339\n",
258 | "\n",
259 | "Train Epoch: 27 Epoch time: 17.1130 Loss: 1.6564006805419922\n",
260 | "Test set: Average loss: -0.15558898612856864, Accuracy: 0.457\n",
261 | "\n",
262 | "Train Epoch: 28 Epoch time: 17.0406 Loss: 1.5395925045013428\n",
263 | "Test set: Average loss: -0.15071935195326805, Accuracy: 0.4298\n",
264 | "\n",
265 | "Train Epoch: 29 Epoch time: 17.0579 Loss: 2.140176773071289\n",
266 | "Test set: Average loss: -0.15748710156083107, Accuracy: 0.4236\n",
267 | "\n",
268 | "Train Epoch: 30 Epoch time: 17.0842 Loss: 1.7239465713500977\n",
269 | "Test set: Average loss: -0.14595504192709924, Accuracy: 0.4115\n",
270 | "\n",
271 | "Train Epoch: 31 Epoch time: 17.1039 Loss: 1.7769050598144531\n",
272 | "Test set: Average loss: -0.1290455259978771, Accuracy: 0.3998\n",
273 | "\n",
274 | "Train Epoch: 32 Epoch time: 17.1451 Loss: 1.662497639656067\n",
275 | "Test set: Average loss: -0.15571664265990257, Accuracy: 0.4357\n",
276 | "\n",
277 | "Train Epoch: 33 Epoch time: 17.1315 Loss: 1.3852906227111816\n",
278 | "Test set: Average loss: -0.1613605972111225, Accuracy: 0.4172\n",
279 | "\n",
280 | "Train Epoch: 34 Epoch time: 17.0798 Loss: 1.7368806600570679\n",
281 | "Test set: Average loss: -0.12401523492336274, Accuracy: 0.3968\n",
282 | "\n",
283 | "Train Epoch: 35 Epoch time: 17.1884 Loss: 1.463071584701538\n",
284 | "Test set: Average loss: -0.13350237726569175, Accuracy: 0.4211\n",
285 | "\n",
286 | "Train Epoch: 36 Epoch time: 17.0499 Loss: 1.222204327583313\n",
287 | "Test set: Average loss: -0.16071371198892592, Accuracy: 0.4386\n",
288 | "\n",
289 | "Train Epoch: 37 Epoch time: 17.0795 Loss: 1.59268319606781\n",
290 | "Test set: Average loss: -0.1448448682665825, Accuracy: 0.4142\n",
291 | "\n",
292 | "Train Epoch: 38 Epoch time: 17.0917 Loss: 1.8278884887695312\n",
293 | "Test set: Average loss: -0.13484705570042133, Accuracy: 0.4107\n",
294 | "\n",
295 | "Train Epoch: 39 Epoch time: 17.1323 Loss: 1.2747076749801636\n",
296 | "Test set: Average loss: -0.14489886118769646, Accuracy: 0.3949\n",
297 | "\n",
298 | "Train Epoch: 40 Epoch time: 17.1330 Loss: 1.6077611446380615\n",
299 | "Test set: Average loss: -0.14588061140179634, Accuracy: 0.4289\n",
300 | "\n",
301 | "Train Epoch: 41 Epoch time: 17.2133 Loss: 1.6331448554992676\n",
302 | "Test set: Average loss: -0.09960648607611657, Accuracy: 0.3682\n",
303 | "\n",
304 | "Train Epoch: 42 Epoch time: 17.1127 Loss: 1.2302451133728027\n",
305 | "Test set: Average loss: -0.14067820476591586, Accuracy: 0.4283\n",
306 | "\n",
307 | "Train Epoch: 43 Epoch time: 17.1447 Loss: 1.6978567838668823\n",
308 | "Test set: Average loss: -0.16264084021151065, Accuracy: 0.4108\n",
309 | "\n",
310 | "Train Epoch: 44 Epoch time: 17.0487 Loss: 1.2845810651779175\n",
311 | "Test set: Average loss: -0.13672254542708398, Accuracy: 0.444\n",
312 | "\n",
313 | "Train Epoch: 45 Epoch time: 17.1562 Loss: 1.1537847518920898\n",
314 | "Test set: Average loss: -0.14150477679371834, Accuracy: 0.4273\n",
315 | "\n",
316 | "Train Epoch: 46 Epoch time: 17.1326 Loss: 1.6309105157852173\n",
317 | "Test set: Average loss: -0.13888773401975632, Accuracy: 0.3973\n",
318 | "\n",
319 | "Train Epoch: 47 Epoch time: 17.1038 Loss: 1.398644208908081\n",
320 | "Test set: Average loss: -0.15300521407723428, Accuracy: 0.4311\n",
321 | "\n",
322 | "Train Epoch: 48 Epoch time: 17.1457 Loss: 1.8404948711395264\n",
323 | "Test set: Average loss: -0.1579221700131893, Accuracy: 0.4233\n",
324 | "\n",
325 | "Train Epoch: 49 Epoch time: 17.1388 Loss: 1.3540527820587158\n",
326 | "Test set: Average loss: -0.1460476183593273, Accuracy: 0.4527\n",
327 | "\n",
328 | "Train Epoch: 50 Epoch time: 17.1742 Loss: 1.765850305557251\n",
329 | "Test set: Average loss: -0.12122882234454155, Accuracy: 0.4141\n",
330 | "\n",
331 | "CPU times: user 16min 39s, sys: 6.23 s, total: 16min 45s\n",
332 | "Wall time: 16min 47s\n"
333 | ]
334 | }
335 | ],
336 | "source": [
337 | "%%time\n",
338 | "print(\"Train in eager mode on CIFAR-10\")\n",
339 | "model = _get_model().to(device)\n",
340 | "\n",
341 | "train(model, batch_size=16, epochs=50)"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": 11,
347 | "id": "f5330b33-eb7b-4a18-939a-44564dcf4441",
348 | "metadata": {
349 | "tags": []
350 | },
351 | "outputs": [
352 | {
353 | "name": "stdout",
354 | "output_type": "stream",
355 | "text": [
356 | "Generating forward and backward graphs\n",
357 | "CPU times: user 1.01 s, sys: 12 ms, total: 1.02 s\n",
358 | "Wall time: 1.02 s\n"
359 | ]
360 | }
361 | ],
362 | "source": [
363 | "%%time\n",
364 | "import torch._inductor.config\n",
365 | "model = _get_model().to(device)\n",
366 | "model = torch.compile(model, backend=\"inductor\",\n",
367 | " mode=\"max-autotune\")\n",
368 | "\n",
369 | "randinput = torch.randn(16,3,32,32).to(device)\n",
370 | "randoutput = torch.randn(16,10).to(device)\n",
371 | "\n",
372 | "print('Generating forward and backward graphs')\n",
373 | "out = model(randinput)\n",
374 | "nn.CrossEntropyLoss()(out, randoutput).backward()"
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": 12,
380 | "id": "b5ef0dbd-2d0c-4f4c-bd93-b2d4c6765829",
381 | "metadata": {
382 | "scrolled": true,
383 | "tags": []
384 | },
385 | "outputs": [
386 | {
387 | "name": "stdout",
388 | "output_type": "stream",
389 | "text": [
390 | "Train in compiled mode on CIFAR-10\n",
391 | "Files already downloaded and verified\n",
392 | "Files already downloaded and verified\n",
393 | "Train Epoch: 1 Epoch time: 16.8979 Loss: 1.1281245946884155\n",
394 | "Test set: Average loss: -0.10141951135396958, Accuracy: 0.3753\n",
395 | "\n",
396 | "Train Epoch: 2 Epoch time: 16.8175 Loss: 1.5165073871612549\n",
397 | "Test set: Average loss: -0.11379602966904641, Accuracy: 0.4128\n",
398 | "\n",
399 | "Train Epoch: 3 Epoch time: 16.8435 Loss: 1.3197925090789795\n",
400 | "Test set: Average loss: -0.14007682649493217, Accuracy: 0.4195\n",
401 | "\n",
402 | "Train Epoch: 4 Epoch time: 17.1099 Loss: 1.7537651062011719\n",
403 | "Test set: Average loss: -0.12210166681408882, Accuracy: 0.444\n",
404 | "\n",
405 | "Train Epoch: 5 Epoch time: 16.9783 Loss: 1.5357666015625\n",
406 | "Test set: Average loss: -0.15047866759300232, Accuracy: 0.4326\n",
407 | "\n",
408 | "Train Epoch: 6 Epoch time: 16.7821 Loss: 1.6223543882369995\n",
409 | "Test set: Average loss: -0.143238749986887, Accuracy: 0.4263\n",
410 | "\n",
411 | "Train Epoch: 7 Epoch time: 16.8694 Loss: 1.7723138332366943\n",
412 | "Test set: Average loss: -0.12494028667807579, Accuracy: 0.4065\n",
413 | "\n",
414 | "Train Epoch: 8 Epoch time: 16.8882 Loss: 1.8077551126480103\n",
415 | "Test set: Average loss: -0.1637685509085655, Accuracy: 0.4561\n",
416 | "\n",
417 | "Train Epoch: 9 Epoch time: 16.8302 Loss: 1.4599609375\n",
418 | "Test set: Average loss: -0.1516111361026764, Accuracy: 0.4451\n",
419 | "\n",
420 | "Train Epoch: 10 Epoch time: 16.8638 Loss: 1.1293983459472656\n",
421 | "Test set: Average loss: -0.16525690048933028, Accuracy: 0.4697\n",
422 | "\n",
423 | "Train Epoch: 11 Epoch time: 16.8586 Loss: 0.9164922833442688\n",
424 | "Test set: Average loss: -0.15961521873474122, Accuracy: 0.4466\n",
425 | "\n",
426 | "Train Epoch: 12 Epoch time: 16.7826 Loss: 2.045623779296875\n",
427 | "Test set: Average loss: -0.10865869052410125, Accuracy: 0.4195\n",
428 | "\n",
429 | "Train Epoch: 13 Epoch time: 16.8195 Loss: 1.2999463081359863\n",
430 | "Test set: Average loss: -0.17005301181674004, Accuracy: 0.4713\n",
431 | "\n",
432 | "Train Epoch: 14 Epoch time: 16.8872 Loss: 1.5532177686691284\n",
433 | "Test set: Average loss: -0.17997907676696778, Accuracy: 0.4673\n",
434 | "\n",
435 | "Train Epoch: 15 Epoch time: 16.7638 Loss: 1.6674301624298096\n",
436 | "Test set: Average loss: -0.13775004784464837, Accuracy: 0.4458\n",
437 | "\n",
438 | "Train Epoch: 16 Epoch time: 16.8711 Loss: 1.7492315769195557\n",
439 | "Test set: Average loss: -0.17715689504146576, Accuracy: 0.4702\n",
440 | "\n",
441 | "Train Epoch: 17 Epoch time: 16.8545 Loss: 1.4167125225067139\n",
442 | "Test set: Average loss: -0.13974063433408737, Accuracy: 0.4377\n",
443 | "\n",
444 | "Train Epoch: 18 Epoch time: 16.8798 Loss: 1.9101390838623047\n",
445 | "Test set: Average loss: -0.1401785773575306, Accuracy: 0.4311\n",
446 | "\n",
447 | "Train Epoch: 19 Epoch time: 16.8081 Loss: 1.99967622756958\n",
448 | "Test set: Average loss: -0.1516777255654335, Accuracy: 0.4445\n",
449 | "\n",
450 | "Train Epoch: 20 Epoch time: 16.7739 Loss: 2.2190351486206055\n",
451 | "Test set: Average loss: -0.1643785246193409, Accuracy: 0.4545\n",
452 | "\n",
453 | "Train Epoch: 21 Epoch time: 17.1348 Loss: 1.7950257062911987\n",
454 | "Test set: Average loss: -0.17522918835878373, Accuracy: 0.4574\n",
455 | "\n",
456 | "Train Epoch: 22 Epoch time: 16.9389 Loss: 1.5520081520080566\n",
457 | "Test set: Average loss: -0.15325675513744355, Accuracy: 0.4569\n",
458 | "\n",
459 | "Train Epoch: 23 Epoch time: 16.8428 Loss: 1.663689374923706\n",
460 | "Test set: Average loss: -0.17498237799406052, Accuracy: 0.4514\n",
461 | "\n",
462 | "Train Epoch: 24 Epoch time: 16.8972 Loss: 2.0213215351104736\n",
463 | "Test set: Average loss: -0.12325149633288383, Accuracy: 0.4142\n",
464 | "\n",
465 | "Train Epoch: 25 Epoch time: 16.8106 Loss: 2.0409820079803467\n",
466 | "Test set: Average loss: -0.12809701528549194, Accuracy: 0.4128\n",
467 | "\n",
468 | "Train Epoch: 26 Epoch time: 16.7998 Loss: 1.821916103363037\n",
469 | "Test set: Average loss: -0.13387244796156883, Accuracy: 0.4204\n",
470 | "\n",
471 | "Train Epoch: 27 Epoch time: 16.8039 Loss: 1.3090702295303345\n",
472 | "Test set: Average loss: -0.15662841452360154, Accuracy: 0.4604\n",
473 | "\n",
474 | "Train Epoch: 28 Epoch time: 16.9014 Loss: 1.2767252922058105\n",
475 | "Test set: Average loss: -0.17367174760699272, Accuracy: 0.4426\n",
476 | "\n",
477 | "Train Epoch: 29 Epoch time: 16.7807 Loss: 1.9910852909088135\n",
478 | "Test set: Average loss: -0.138916358846426, Accuracy: 0.3752\n",
479 | "\n",
480 | "Train Epoch: 30 Epoch time: 16.7389 Loss: 1.2423659563064575\n",
481 | "Test set: Average loss: -0.14196027348041534, Accuracy: 0.4424\n",
482 | "\n",
483 | "Train Epoch: 31 Epoch time: 16.8019 Loss: 1.6152451038360596\n",
484 | "Test set: Average loss: -0.15072507199048996, Accuracy: 0.4381\n",
485 | "\n",
486 | "Train Epoch: 32 Epoch time: 16.8864 Loss: 1.6289972066879272\n",
487 | "Test set: Average loss: -0.11902655415534973, Accuracy: 0.4208\n",
488 | "\n",
489 | "Train Epoch: 33 Epoch time: 16.7390 Loss: 1.3837391138076782\n",
490 | "Test set: Average loss: -0.15387935359477997, Accuracy: 0.4218\n",
491 | "\n",
492 | "Train Epoch: 34 Epoch time: 16.7903 Loss: 1.516900658607483\n",
493 | "Test set: Average loss: -0.1460586778521538, Accuracy: 0.4429\n",
494 | "\n",
495 | "Train Epoch: 35 Epoch time: 16.7645 Loss: 1.8662484884262085\n",
496 | "Test set: Average loss: -0.1500231746673584, Accuracy: 0.4308\n",
497 | "\n",
498 | "Train Epoch: 36 Epoch time: 16.7263 Loss: 1.570525884628296\n",
499 | "Test set: Average loss: -0.17256923723816872, Accuracy: 0.4522\n",
500 | "\n",
501 | "Train Epoch: 37 Epoch time: 16.7440 Loss: 1.7509129047393799\n",
502 | "Test set: Average loss: -0.1331960899591446, Accuracy: 0.4025\n",
503 | "\n",
504 | "Train Epoch: 38 Epoch time: 16.6792 Loss: 1.6465048789978027\n",
505 | "Test set: Average loss: -0.17207880718111992, Accuracy: 0.463\n",
506 | "\n",
507 | "Train Epoch: 39 Epoch time: 16.7337 Loss: 1.4757962226867676\n",
508 | "Test set: Average loss: -0.1252164648413658, Accuracy: 0.435\n",
509 | "\n",
510 | "Train Epoch: 40 Epoch time: 16.7789 Loss: 2.02087140083313\n",
511 | "Test set: Average loss: -0.12817251023054124, Accuracy: 0.4488\n",
512 | "\n",
513 | "Train Epoch: 41 Epoch time: 16.7497 Loss: 1.4727895259857178\n",
514 | "Test set: Average loss: -0.1616751208484173, Accuracy: 0.4522\n",
515 | "\n",
516 | "Train Epoch: 42 Epoch time: 16.8264 Loss: 1.214515209197998\n",
517 | "Test set: Average loss: -0.09364713019430637, Accuracy: 0.3973\n",
518 | "\n",
519 | "Train Epoch: 43 Epoch time: 16.6972 Loss: 1.874068021774292\n",
520 | "Test set: Average loss: -0.17249812202453613, Accuracy: 0.4564\n",
521 | "\n",
522 | "Train Epoch: 44 Epoch time: 16.9201 Loss: 1.4941129684448242\n",
523 | "Test set: Average loss: -0.14445521401762962, Accuracy: 0.4388\n",
524 | "\n",
525 | "Train Epoch: 45 Epoch time: 16.7735 Loss: 1.3479551076889038\n",
526 | "Test set: Average loss: -0.1443574243813753, Accuracy: 0.4296\n",
527 | "\n",
528 | "Train Epoch: 46 Epoch time: 16.7243 Loss: 1.9566048383712769\n",
529 | "Test set: Average loss: -0.1257232101917267, Accuracy: 0.3808\n",
530 | "\n",
531 | "Train Epoch: 47 Epoch time: 16.7802 Loss: 1.851891279220581\n",
532 | "Test set: Average loss: -0.09098090907037258, Accuracy: 0.3395\n",
533 | "\n",
534 | "Train Epoch: 48 Epoch time: 16.7901 Loss: 1.5767955780029297\n",
535 | "Test set: Average loss: -0.1268459755897522, Accuracy: 0.4005\n",
536 | "\n",
537 | "Train Epoch: 49 Epoch time: 16.7760 Loss: 1.6413697004318237\n",
538 | "Test set: Average loss: -0.13277969150543212, Accuracy: 0.4223\n",
539 | "\n",
540 | "Train Epoch: 50 Epoch time: 16.6827 Loss: 1.7723995447158813\n",
541 | "Test set: Average loss: -0.1290963588565588, Accuracy: 0.3967\n",
542 | "\n",
543 | "CPU times: user 16min 20s, sys: 5.69 s, total: 16min 25s\n",
544 | "Wall time: 16min 27s\n"
545 | ]
546 | }
547 | ],
548 | "source": [
549 | "%%time\n",
550 | "print(\"Train in compiled mode on CIFAR-10\")\n",
551 | "model = _get_model().to(device)\n",
552 | "train(model, batch_size=16, epochs=50)"
553 | ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "execution_count": null,
558 | "id": "63ed4c0a-e375-4bc4-b321-fa107a9341b7",
559 | "metadata": {},
560 | "outputs": [],
561 | "source": []
562 | }
563 | ],
564 | "metadata": {
565 | "kernelspec": {
566 | "display_name": "Python 3 (ipykernel)",
567 | "language": "python",
568 | "name": "python3"
569 | },
570 | "language_info": {
571 | "codemirror_mode": {
572 | "name": "ipython",
573 | "version": 3
574 | },
575 | "file_extension": ".py",
576 | "mimetype": "text/x-python",
577 | "name": "python",
578 | "nbconvert_exporter": "python",
579 | "pygments_lexer": "ipython3",
580 | "version": "3.10.9"
581 | }
582 | },
583 | "nbformat": 4,
584 | "nbformat_minor": 5
585 | }
586 |
--------------------------------------------------------------------------------