├── code_snippets
│   ├── torch_sle_loss.py
│   ├── torch_autodiff_grad_hess.py
│   ├── torch_autodiff_output.ipynb
│   ├── torch_autodiff_housing_output.ipynb
│   └── jax_autodiff_housing_output.ipynb
├── README.md
├── LICENSE
└── benchmark.ipynb
/code_snippets/torch_sle_loss.py:
--------------------------------------------------------------------------------
def torch_sle_loss(y_true: torch.Tensor, y_pred: torch.Tensor):
    """Return the element-wise Squared Log Error between targets and predictions.

    Each element of the result is ``0.5 * (log1p(y_pred) - log1p(y_true)) ** 2``;
    no reduction is applied, so the output has the broadcast shape of the inputs.
    """
    # Residual measured in log(1 + x) space.
    log_residual = torch.log1p(y_pred) - torch.log1p(y_true)
    return 0.5 * log_residual ** 2
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JAX vs PyTorch: Automatic Differentiation for XGBoost
2 | Code for the [corresponding Medium article](https://medium.com/@1danielr/jax-vs-pytorch-automatic-differentiation-for-xgboost-10222e1404ec).
3 |
4 | ## Benchmark
5 | Please see `benchmark.ipynb` for a working copy of the entire codebase
6 | together with the code used to generate the benchmark.
7 |
--------------------------------------------------------------------------------
/code_snippets/torch_autodiff_grad_hess.py:
--------------------------------------------------------------------------------
def torch_autodiff_grad_hess(
    loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    y_true: np.ndarray, y_pred: np.ndarray
):
    """Perform automatic differentiation to get the per-sample
    Gradient and Hessian of `loss_function` with respect to `y_pred`.

    `loss_function` is expected to be element-wise, i.e. the i-th loss value
    depends only on `y_true[i]` and `y_pred[i]`. For such a loss the Hessian
    of the summed loss is diagonal, so its diagonal equals the
    Hessian-vector product with a vector of ones. That product is computed
    here with a second backward pass in O(n) time and memory instead of
    materialising the full n x n Hessian matrix. For a loss that couples
    different samples, the second return value is the row sums of the
    Hessian rather than its diagonal.

    Args:
        loss_function: Callable taking `(y_true, y_pred)` tensors and
            returning the un-reduced loss tensor.
        y_true: Target values; converted to a float32 tensor.
        y_pred: Predicted values; converted to a float32 tensor.

    Returns:
        A `(grad, hess)` tuple of float32 tensors shaped like `y_pred`,
        both detached from the autograd graph.
    """
    y_true = torch.tensor(y_true, dtype=torch.float, requires_grad=False)
    y_pred = torch.tensor(y_pred, dtype=torch.float, requires_grad=True)

    total_loss = loss_function(y_true, y_pred).sum()

    # create_graph=True keeps the gradient itself differentiable so that a
    # second backward pass can produce the second derivatives.
    (grad,) = torch.autograd.grad(total_loss, y_pred, create_graph=True)

    hess = None
    if grad.requires_grad:
        # d/dy_pred of sum(grad) is the Hessian applied to a vector of ones.
        (hess,) = torch.autograd.grad(grad.sum(), y_pred, allow_unused=True)
    if hess is None:
        # The gradient does not depend on y_pred (e.g. a loss linear in
        # y_pred), so every second derivative is zero.
        hess = torch.zeros_like(grad.detach())

    # Detach so callers can convert the result to NumPy without errors.
    return grad.detach(), hess
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Daniel Reedstone
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/code_snippets/torch_autodiff_output.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "b698ec43-1036-487e-9f34-6b7ba11a4682",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "data": {
11 | "text/plain": [
12 | "(tensor([ 0.3466, 0.0000, -0.1014]), tensor([0.0767, 0.1111, 0.0878]))"
13 | ]
14 | },
15 | "execution_count": 1,
16 | "metadata": {},
17 | "output_type": "execute_result"
18 | }
19 | ],
20 | "source": [
21 | "y_true = np.array([0, 2, 5])\n",
22 | "y_pred = np.array([1, 2, 3])\n",
23 | "torch_autodiff_grad_hess(torch_sle_loss, y_true, y_pred)"
24 | ]
25 | }
26 | ],
27 | "metadata": {
28 | "kernelspec": {
29 | "display_name": "Python 3 (ipykernel)",
30 | "language": "python",
31 | "name": "python3"
32 | },
33 | "language_info": {
34 | "codemirror_mode": {
35 | "name": "ipython",
36 | "version": 3
37 | },
38 | "file_extension": ".py",
39 | "mimetype": "text/x-python",
40 | "name": "python",
41 | "nbconvert_exporter": "python",
42 | "pygments_lexer": "ipython3",
43 | "version": "3.9.12"
44 | }
45 | },
46 | "nbformat": 4,
47 | "nbformat_minor": 5
48 | }
49 |
--------------------------------------------------------------------------------
/code_snippets/torch_autodiff_housing_output.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 24,
6 | "id": "87710599-4884-481b-b687-e90fab4156dc",
7 | "metadata": {
8 | "colab": {
9 | "base_uri": "https://localhost:8080/"
10 | },
11 | "id": "IztnR9BhvGLQ",
12 | "outputId": "0d75c392-c958-4494-e3f9-461c57c610f3"
13 | },
14 | "outputs": [
15 | {
16 | "name": "stdout",
17 | "output_type": "stream",
18 | "text": [
19 | "Train Data: 4128 examples, 8 features\n"
20 | ]
21 | }
22 | ],
23 | "source": [
24 | "X, y = sklearn.datasets.fetch_california_housing(return_X_y=True)\n",
25 | "X_train, X_test, y_train, y_test = train_test_split(\n",
26 | " X, y, train_size=0.2, random_state=0)\n",
27 | "print(f\"Train Data: {X_train.shape[0]} examples, {X_train.shape[1]} features\")"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 26,
33 | "id": "7nrq9G_tA7D6",
34 | "metadata": {
35 | "id": "7nrq9G_tA7D6"
36 | },
37 | "outputs": [
38 | {
39 | "name": "stdout",
40 | "output_type": "stream",
41 | "text": [
42 | "14.6 s ± 286 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
43 | ]
44 | }
45 | ],
46 | "source": [
47 | "%%timeit\n",
48 | "torch_objective = partial(torch_autodiff_grad_hess, torch_sle_loss)\n",
49 | "reg = XGBRegressor(objective=torch_objective, n_estimators=100)\n",
50 | "reg.fit(X_train, y_train)"
51 | ]
52 | }
53 | ],
54 | "metadata": {
55 | "kernelspec": {
56 | "display_name": "Python 3 (ipykernel)",
57 | "language": "python",
58 | "name": "python3"
59 | },
60 | "language_info": {
61 | "codemirror_mode": {
62 | "name": "ipython",
63 | "version": 3
64 | },
65 | "file_extension": ".py",
66 | "mimetype": "text/x-python",
67 | "name": "python",
68 | "nbconvert_exporter": "python",
69 | "pygments_lexer": "ipython3",
70 | "version": "3.9.12"
71 | }
72 | },
73 | "nbformat": 4,
74 | "nbformat_minor": 5
75 | }
--------------------------------------------------------------------------------
/code_snippets/jax_autodiff_housing_output.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 24,
6 | "id": "IztnR9BhvGLQ",
7 | "metadata": {
8 | "colab": {
9 | "base_uri": "https://localhost:8080/"
10 | },
11 | "id": "IztnR9BhvGLQ",
12 | "outputId": "0d75c392-c958-4494-e3f9-461c57c610f3"
13 | },
14 | "outputs": [
15 | {
16 | "name": "stdout",
17 | "output_type": "stream",
18 | "text": [
19 | "Train Data: 4128 examples, 8 features\n"
20 | ]
21 | }
22 | ],
23 | "source": [
24 | "X, y = sklearn.datasets.fetch_california_housing(return_X_y=True)\n",
25 | "X_train, X_test, y_train, y_test = train_test_split(\n",
26 | " X, y, train_size=0.2, random_state=0)\n",
27 | "print(f\"Train Data: {X_train.shape[0]} examples, {X_train.shape[1]} features\")"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 25,
33 | "id": "5LzZ3S0ZA2YX",
34 | "metadata": {
35 | "colab": {
36 | "base_uri": "https://localhost:8080/"
37 | },
38 | "id": "5LzZ3S0ZA2YX",
39 | "outputId": "d4b04c11-e9e4-431b-917c-467807471453"
40 | },
41 | "outputs": [
42 | {
43 | "name": "stdout",
44 | "output_type": "stream",
45 | "text": [
46 | "6.89 s ± 119 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "%%timeit\n",
52 | "jax_objective = jax.jit(partial(jax_autodiff_grad_hess, jax_sle_loss))\n",
53 | "reg = XGBRegressor(objective=jax_objective, n_estimators=100)\n",
54 | "reg.fit(X_train, y_train)"
55 | ]
56 | }
57 | ],
58 | "metadata": {
59 | "kernelspec": {
60 | "display_name": "Python 3 (ipykernel)",
61 | "language": "python",
62 | "name": "python3"
63 | },
64 | "language_info": {
65 | "codemirror_mode": {
66 | "name": "ipython",
67 | "version": 3
68 | },
69 | "file_extension": ".py",
70 | "mimetype": "text/x-python",
71 | "name": "python",
72 | "nbconvert_exporter": "python",
73 | "pygments_lexer": "ipython3",
74 | "version": "3.9.12"
75 | }
76 | },
77 | "nbformat": 4,
78 | "nbformat_minor": 5
79 | }
80 |
--------------------------------------------------------------------------------
/benchmark.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "57e52aee-3006-4140-8f6e-668dde48e5b5",
6 | "metadata": {
7 | "id": "0e09e47d-4b93-489d-89af-807d56d55519"
8 | },
9 | "source": [
10 | "### Imports"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "04ffda5f-69b6-4f46-83c7-a06fd02e7dca",
17 | "metadata": {
18 | "id": "5ec61494-67ed-4bec-a39a-d1991d6650c3"
19 | },
20 | "outputs": [],
21 | "source": [
22 | "import numpy as np\n",
23 | "import pandas as pd\n",
24 | "\n",
25 | "import sklearn.datasets\n",
26 | "from sklearn.model_selection import train_test_split\n",
27 | "from sklearn.utils import resample\n",
28 | "\n",
29 | "from xgboost.sklearn import XGBRegressor\n",
30 | "\n",
31 | "import torch\n",
32 | "\n",
33 | "import jax\n",
34 | "import jax.numpy as jnp\n",
35 | "\n",
36 | "from tqdm.notebook import tqdm\n",
37 | "\n",
38 | "import plotly.express as px\n",
39 | "\n",
40 | "import timeit\n",
41 | "from functools import partial\n",
42 | "from typing import Callable"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "id": "60b50c07-69e9-4e10-9aff-b85932a18480",
48 | "metadata": {},
49 | "source": [
50 | "## Automatic Differentiation"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "id": "f9c1bc1e-239a-46dd-9b84-2d47852772e9",
56 | "metadata": {},
57 | "source": [
58 | "### PyTorch"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 2,
64 | "id": "78f7dddc-899c-439b-bae0-4c294c933222",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "def torch_sle_loss(y_true: torch.Tensor, y_pred: torch.Tensor):\n",
69 | " \"\"\"Calculate the Squared Log Error loss.\"\"\"\n",
70 | " return 1/2 * (torch.log1p(y_pred) - torch.log1p(y_true))**2\n",
71 | "\n",
72 | "def torch_autodiff_grad_hess(\n",
73 | " loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
74 | " y_true: np.ndarray, y_pred: np.ndarray\n",
75 | "):\n",
76 | " \"\"\"Perform automatic differentiation to get the\n",
77 | " Gradient and the Hessian of `loss_function`.\"\"\"\n",
78 | " y_true = torch.tensor(y_true, dtype=torch.float, requires_grad=False)\n",
79 | " y_pred = torch.tensor(y_pred, dtype=torch.float, requires_grad=True)\n",
80 | " loss_function_sum = lambda y_pred: loss_function(y_true, y_pred).sum()\n",
81 | "\n",
82 | " loss_function_sum(y_pred).backward()\n",
83 | " grad = y_pred.grad\n",
84 | "\n",
85 | " hess_matrix = torch.autograd.functional.hessian(loss_function_sum, y_pred, vectorize=True)\n",
86 | " hess = torch.diagonal(hess_matrix)\n",
87 | "\n",
88 | " return grad, hess"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "id": "112b872a-37bd-4fdc-a384-fdfceb969eea",
94 | "metadata": {},
95 | "source": [
96 | "### JAX"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 3,
102 | "id": "781a05ba-0bac-44f6-a0eb-39d530952737",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "def jax_sle_loss(y_true: np.ndarray, y_pred: np.ndarray):\n",
107 | " \"\"\"Calculate the Squared Log Error loss.\"\"\"\n",
108 | " return (1/2 * (jnp.log1p(y_pred) - jnp.log1p(y_true))**2)\n",
109 | "\n",
110 | "def hvp(f, inputs, vectors):\n",
111 | " \"\"\"Hessian-vector product.\"\"\"\n",
112 | " return jax.jvp(jax.grad(f), inputs, vectors)[1]\n",
113 | "\n",
114 | "def jax_autodiff_grad_hess(\n",
115 | " loss_function: Callable[[np.ndarray, np.ndarray], np.ndarray],\n",
116 | " y_true: np.ndarray, y_pred: np.ndarray\n",
117 | "):\n",
118 | " \"\"\"Perform automatic differentiation to get the\n",
119 | " Gradient and the Hessian of `loss_function`.\"\"\"\n",
120 | " loss_function_sum = lambda y_pred: loss_function(y_true, y_pred).sum()\n",
121 | "\n",
122 | " grad_fn = jax.grad(loss_function_sum)\n",
123 | " grad = grad_fn(y_pred)\n",
124 | "\n",
125 | " hess = hvp(loss_function_sum, (y_pred,), (jnp.ones_like(y_pred), ))\n",
126 | "\n",
127 | " return grad, hess"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "id": "955fbb3f-4f52-40ab-9abf-9a68f05529ac",
133 | "metadata": {},
134 | "source": [
135 | "# Benchmark"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "id": "cb12f0e5-07bd-40ff-bb57-3902aea0b258",
141 | "metadata": {},
142 | "source": [
143 | "### Basic Benchmark"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 4,
149 | "id": "38a094ab-aa20-4745-8d24-b57b03ab4282",
150 | "metadata": {},
151 | "outputs": [
152 | {
153 | "name": "stdout",
154 | "output_type": "stream",
155 | "text": [
156 | "Train Data: 1032 examples, 8 features\n"
157 | ]
158 | }
159 | ],
160 | "source": [
161 | "X, y = sklearn.datasets.fetch_california_housing(return_X_y=True)\n",
162 | "X_train, X_test, y_train, y_test = train_test_split(\n",
163 | " X, y, train_size=0.05, random_state=0)\n",
164 | "print(f\"Train Data: {X_train.shape[0]} examples, {X_train.shape[1]} features\")"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 5,
170 | "id": "631498c2-b23f-475c-af37-921103e10e29",
171 | "metadata": {},
172 | "outputs": [
173 | {
174 | "name": "stderr",
175 | "output_type": "stream",
176 | "text": [
177 | "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
178 | ]
179 | },
180 | {
181 | "name": "stdout",
182 | "output_type": "stream",
183 | "text": [
184 | "1.79 s ± 26.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
185 | ]
186 | }
187 | ],
188 | "source": [
189 | "%%timeit\n",
190 | "jax_objective = jax.jit(partial(jax_autodiff_grad_hess, jax_sle_loss))\n",
191 | "reg = XGBRegressor(objective=jax_objective, n_estimators=100)\n",
192 | "reg.fit(X_train, y_train)"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 6,
198 | "id": "55eccdf6-4b75-4f1d-b45c-9e2b8a33394b",
199 | "metadata": {},
200 | "outputs": [
201 | {
202 | "name": "stdout",
203 | "output_type": "stream",
204 | "text": [
205 | "1.33 s ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
206 | ]
207 | }
208 | ],
209 | "source": [
210 | "%%timeit\n",
211 | "torch_objective = partial(torch_autodiff_grad_hess, torch_sle_loss)\n",
212 | "reg = XGBRegressor(objective=torch_objective, n_estimators=100)\n",
213 | "reg.fit(X_train, y_train)"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "id": "5cf179d8-af2a-414f-9bcc-56bfb926d9e8",
219 | "metadata": {},
220 | "source": [
221 | "## Full Benchmark"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 7,
227 | "id": "8dfa24b4-11ca-4fce-977e-bad2efc02c15",
228 | "metadata": {},
229 | "outputs": [],
230 | "source": [
231 | "torch_objective = partial(torch_autodiff_grad_hess, torch_sle_loss)\n",
232 | "jax_objective = jax.jit(partial(jax_autodiff_grad_hess, jax_sle_loss))"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": 8,
238 | "id": "f58f411f-1316-4663-ab6d-02d641f519c0",
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "def run_benchmark():\n",
243 | " VERBOSE = 0\n",
244 | " TIMEIT_NUMBER = 5\n",
245 | " N_ESTIMATORS = 10\n",
246 | " N_SAMPLES_OPTIONS = [100, 1000, 2000, 5000, 10000, 20000]\n",
247 | "\n",
248 | " def run_torch():\n",
249 | " print(f\"Running torch with {len(X_train)} training examples.\") if VERBOSE else None\n",
250 | " reg = XGBRegressor(objective=torch_objective, n_estimators=N_ESTIMATORS)\n",
251 | " reg.fit(X_train, y_train)\n",
252 | " \n",
253 | " def run_jax():\n",
254 | " print(f\"Running jax with {len(X_train)} training examples.\") if VERBOSE else None\n",
255 | " reg = XGBRegressor(objective=jax_objective, n_estimators=N_ESTIMATORS)\n",
256 | " reg.fit(X_train, y_train)\n",
257 | "\n",
258 | " def run_manual():\n",
259 | " reg = XGBRegressor(objective='reg:squaredlogerror', n_estimators=N_ESTIMATORS)\n",
260 | " reg.fit(X_train, y_train)\n",
261 | "\n",
262 | " entries = []\n",
263 | " for n_samples in tqdm(N_SAMPLES_OPTIONS):\n",
264 | " entry = {'n_samples': n_samples}\n",
265 | " X_train, y_train = resample(X, y, n_samples=n_samples)\n",
266 | " \n",
267 | " entry['torch'] = timeit.timeit(run_torch, number=TIMEIT_NUMBER) / TIMEIT_NUMBER\n",
268 | " entry['jax'] = timeit.timeit(run_jax, number=TIMEIT_NUMBER) / TIMEIT_NUMBER\n",
269 | " entry['manual'] = timeit.timeit(run_manual, number=TIMEIT_NUMBER) / TIMEIT_NUMBER\n",
270 | "\n",
271 | " entries.append(entry)\n",
272 | " entries = pd.DataFrame(entries)\n",
273 | " return entries"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 9,
279 | "id": "ec9b7475-c2fc-46ed-b30e-d7fe1f6f468a",
280 | "metadata": {},
281 | "outputs": [
282 | {
283 | "data": {
284 | "application/vnd.jupyter.widget-view+json": {
285 | "model_id": "98b8541d9d0143529dc155ae5007a82c",
286 | "version_major": 2,
287 | "version_minor": 0
288 | },
289 | "text/plain": [
290 | " 0%| | 0/6 [00:00, ?it/s]"
291 | ]
292 | },
293 | "metadata": {},
294 | "output_type": "display_data"
295 | }
296 | ],
297 | "source": [
298 | "benchmark_results_df = run_benchmark()"
299 | ]
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "id": "fd7a9d34-d2b1-4193-a3a6-b0902605a089",
304 | "metadata": {
305 | "tags": []
306 | },
307 | "source": [
308 | "### Plotting"
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "execution_count": 16,
314 | "id": "2dc3c0fc-5e08-416e-ad73-2b5563b87b59",
315 | "metadata": {},
316 | "outputs": [],
317 | "source": [
318 | "plot_df = benchmark_results_df.melt(id_vars=['n_samples'], value_vars=['torch', 'jax', 'manual'],\n",
319 | " value_name='time', var_name='implementation')\n",
320 | "plot_df = plot_df[(plot_df['implementation'] != 'manual') & (plot_df['n_samples'] >= 2000)]\n",
321 | "plot_df = plot_df.replace({'implementation': {'torch': 'PyTorch', 'jax': 'JAX', 'manual': 'Manual'}})"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": 17,
327 | "id": "uriY3ZKooldE",
328 | "metadata": {
329 | "colab": {
330 | "base_uri": "https://localhost:8080/",
331 | "height": 542
332 | },
333 | "id": "uriY3ZKooldE",
334 | "outputId": "946d2878-ae00-4dc2-b843-731d74bfa586"
335 | },
336 | "outputs": [
337 | {
338 | "data": {
339 | "application/vnd.plotly.v1+json": {
340 | "config": {
341 | "plotlyServerURL": "https://plot.ly"
342 | },
343 | "data": [
344 | {
345 | "alignmentgroup": "True",
346 | "hovertemplate": "Implementation=PyTorch<br>Number of Data Points=%{x}<br>Time (s)=%{text}<extra></extra>",
347 | "legendgroup": "PyTorch",
348 | "marker": {
349 | "color": "#636efa",
350 | "pattern": {
351 | "shape": ""
352 | }
353 | },
354 | "name": "PyTorch",
355 | "offsetgroup": "PyTorch",
356 | "orientation": "v",
357 | "showlegend": true,
358 | "text": [
359 | 0.35879376000000035,
360 | 1.9320429399999994,
361 | 7.311690239999999,
362 | 28.3964026
363 | ],
364 | "textposition": "outside",
365 | "texttemplate": "%{y:.2}",
366 | "type": "bar",
367 | "x": [
368 | 2000,
369 | 5000,
370 | 10000,
371 | 20000
372 | ],
373 | "xaxis": "x",
374 | "y": [
375 | 0.35879376000000035,
376 | 1.9320429399999994,
377 | 7.311690239999999,
378 | 28.3964026
379 | ],
380 | "yaxis": "y"
381 | },
382 | {
383 | "alignmentgroup": "True",
384 | "hovertemplate": "Implementation=JAX<br>Number of Data Points=%{x}<br>Time (s)=%{text}<extra></extra>",
385 | "legendgroup": "JAX",
386 | "marker": {
387 | "color": "#EF553B",
388 | "pattern": {
389 | "shape": ""
390 | }
391 | },
392 | "name": "JAX",
393 | "offsetgroup": "JAX",
394 | "orientation": "v",
395 | "showlegend": true,
396 | "text": [
397 | 0.3323473000000007,
398 | 0.8418071999999995,
399 | 1.4339120000000007,
400 | 2.587041499999998
401 | ],
402 | "textposition": "outside",
403 | "texttemplate": "%{y:.2}",
404 | "type": "bar",
405 | "x": [
406 | 2000,
407 | 5000,
408 | 10000,
409 | 20000
410 | ],
411 | "xaxis": "x",
412 | "y": [
413 | 0.3323473000000007,
414 | 0.8418071999999995,
415 | 1.4339120000000007,
416 | 2.587041499999998
417 | ],
418 | "yaxis": "y"
419 | }
420 | ],
421 | "layout": {
422 | "autosize": true,
423 | "barmode": "group",
424 | "font": {
425 | "size": 24
426 | },
427 | "legend": {
428 | "title": {
429 | "text": "Implementation"
430 | },
431 | "tracegroupgap": 0
432 | },
433 | "margin": {
434 | "t": 70
435 | },
436 | "template": {
437 | "data": {
438 | "bar": [
439 | {
440 | "error_x": {
441 | "color": "#2a3f5f"
442 | },
443 | "error_y": {
444 | "color": "#2a3f5f"
445 | },
446 | "marker": {
447 | "line": {
448 | "color": "#E5ECF6",
449 | "width": 0.5
450 | },
451 | "pattern": {
452 | "fillmode": "overlay",
453 | "size": 10,
454 | "solidity": 0.2
455 | }
456 | },
457 | "type": "bar"
458 | }
459 | ],
460 | "barpolar": [
461 | {
462 | "marker": {
463 | "line": {
464 | "color": "#E5ECF6",
465 | "width": 0.5
466 | },
467 | "pattern": {
468 | "fillmode": "overlay",
469 | "size": 10,
470 | "solidity": 0.2
471 | }
472 | },
473 | "type": "barpolar"
474 | }
475 | ],
476 | "carpet": [
477 | {
478 | "aaxis": {
479 | "endlinecolor": "#2a3f5f",
480 | "gridcolor": "white",
481 | "linecolor": "white",
482 | "minorgridcolor": "white",
483 | "startlinecolor": "#2a3f5f"
484 | },
485 | "baxis": {
486 | "endlinecolor": "#2a3f5f",
487 | "gridcolor": "white",
488 | "linecolor": "white",
489 | "minorgridcolor": "white",
490 | "startlinecolor": "#2a3f5f"
491 | },
492 | "type": "carpet"
493 | }
494 | ],
495 | "choropleth": [
496 | {
497 | "colorbar": {
498 | "outlinewidth": 0,
499 | "ticks": ""
500 | },
501 | "type": "choropleth"
502 | }
503 | ],
504 | "contour": [
505 | {
506 | "colorbar": {
507 | "outlinewidth": 0,
508 | "ticks": ""
509 | },
510 | "colorscale": [
511 | [
512 | 0,
513 | "#0d0887"
514 | ],
515 | [
516 | 0.1111111111111111,
517 | "#46039f"
518 | ],
519 | [
520 | 0.2222222222222222,
521 | "#7201a8"
522 | ],
523 | [
524 | 0.3333333333333333,
525 | "#9c179e"
526 | ],
527 | [
528 | 0.4444444444444444,
529 | "#bd3786"
530 | ],
531 | [
532 | 0.5555555555555556,
533 | "#d8576b"
534 | ],
535 | [
536 | 0.6666666666666666,
537 | "#ed7953"
538 | ],
539 | [
540 | 0.7777777777777778,
541 | "#fb9f3a"
542 | ],
543 | [
544 | 0.8888888888888888,
545 | "#fdca26"
546 | ],
547 | [
548 | 1,
549 | "#f0f921"
550 | ]
551 | ],
552 | "type": "contour"
553 | }
554 | ],
555 | "contourcarpet": [
556 | {
557 | "colorbar": {
558 | "outlinewidth": 0,
559 | "ticks": ""
560 | },
561 | "type": "contourcarpet"
562 | }
563 | ],
564 | "heatmap": [
565 | {
566 | "colorbar": {
567 | "outlinewidth": 0,
568 | "ticks": ""
569 | },
570 | "colorscale": [
571 | [
572 | 0,
573 | "#0d0887"
574 | ],
575 | [
576 | 0.1111111111111111,
577 | "#46039f"
578 | ],
579 | [
580 | 0.2222222222222222,
581 | "#7201a8"
582 | ],
583 | [
584 | 0.3333333333333333,
585 | "#9c179e"
586 | ],
587 | [
588 | 0.4444444444444444,
589 | "#bd3786"
590 | ],
591 | [
592 | 0.5555555555555556,
593 | "#d8576b"
594 | ],
595 | [
596 | 0.6666666666666666,
597 | "#ed7953"
598 | ],
599 | [
600 | 0.7777777777777778,
601 | "#fb9f3a"
602 | ],
603 | [
604 | 0.8888888888888888,
605 | "#fdca26"
606 | ],
607 | [
608 | 1,
609 | "#f0f921"
610 | ]
611 | ],
612 | "type": "heatmap"
613 | }
614 | ],
615 | "heatmapgl": [
616 | {
617 | "colorbar": {
618 | "outlinewidth": 0,
619 | "ticks": ""
620 | },
621 | "colorscale": [
622 | [
623 | 0,
624 | "#0d0887"
625 | ],
626 | [
627 | 0.1111111111111111,
628 | "#46039f"
629 | ],
630 | [
631 | 0.2222222222222222,
632 | "#7201a8"
633 | ],
634 | [
635 | 0.3333333333333333,
636 | "#9c179e"
637 | ],
638 | [
639 | 0.4444444444444444,
640 | "#bd3786"
641 | ],
642 | [
643 | 0.5555555555555556,
644 | "#d8576b"
645 | ],
646 | [
647 | 0.6666666666666666,
648 | "#ed7953"
649 | ],
650 | [
651 | 0.7777777777777778,
652 | "#fb9f3a"
653 | ],
654 | [
655 | 0.8888888888888888,
656 | "#fdca26"
657 | ],
658 | [
659 | 1,
660 | "#f0f921"
661 | ]
662 | ],
663 | "type": "heatmapgl"
664 | }
665 | ],
666 | "histogram": [
667 | {
668 | "marker": {
669 | "pattern": {
670 | "fillmode": "overlay",
671 | "size": 10,
672 | "solidity": 0.2
673 | }
674 | },
675 | "type": "histogram"
676 | }
677 | ],
678 | "histogram2d": [
679 | {
680 | "colorbar": {
681 | "outlinewidth": 0,
682 | "ticks": ""
683 | },
684 | "colorscale": [
685 | [
686 | 0,
687 | "#0d0887"
688 | ],
689 | [
690 | 0.1111111111111111,
691 | "#46039f"
692 | ],
693 | [
694 | 0.2222222222222222,
695 | "#7201a8"
696 | ],
697 | [
698 | 0.3333333333333333,
699 | "#9c179e"
700 | ],
701 | [
702 | 0.4444444444444444,
703 | "#bd3786"
704 | ],
705 | [
706 | 0.5555555555555556,
707 | "#d8576b"
708 | ],
709 | [
710 | 0.6666666666666666,
711 | "#ed7953"
712 | ],
713 | [
714 | 0.7777777777777778,
715 | "#fb9f3a"
716 | ],
717 | [
718 | 0.8888888888888888,
719 | "#fdca26"
720 | ],
721 | [
722 | 1,
723 | "#f0f921"
724 | ]
725 | ],
726 | "type": "histogram2d"
727 | }
728 | ],
729 | "histogram2dcontour": [
730 | {
731 | "colorbar": {
732 | "outlinewidth": 0,
733 | "ticks": ""
734 | },
735 | "colorscale": [
736 | [
737 | 0,
738 | "#0d0887"
739 | ],
740 | [
741 | 0.1111111111111111,
742 | "#46039f"
743 | ],
744 | [
745 | 0.2222222222222222,
746 | "#7201a8"
747 | ],
748 | [
749 | 0.3333333333333333,
750 | "#9c179e"
751 | ],
752 | [
753 | 0.4444444444444444,
754 | "#bd3786"
755 | ],
756 | [
757 | 0.5555555555555556,
758 | "#d8576b"
759 | ],
760 | [
761 | 0.6666666666666666,
762 | "#ed7953"
763 | ],
764 | [
765 | 0.7777777777777778,
766 | "#fb9f3a"
767 | ],
768 | [
769 | 0.8888888888888888,
770 | "#fdca26"
771 | ],
772 | [
773 | 1,
774 | "#f0f921"
775 | ]
776 | ],
777 | "type": "histogram2dcontour"
778 | }
779 | ],
780 | "mesh3d": [
781 | {
782 | "colorbar": {
783 | "outlinewidth": 0,
784 | "ticks": ""
785 | },
786 | "type": "mesh3d"
787 | }
788 | ],
789 | "parcoords": [
790 | {
791 | "line": {
792 | "colorbar": {
793 | "outlinewidth": 0,
794 | "ticks": ""
795 | }
796 | },
797 | "type": "parcoords"
798 | }
799 | ],
800 | "pie": [
801 | {
802 | "automargin": true,
803 | "type": "pie"
804 | }
805 | ],
806 | "scatter": [
807 | {
808 | "fillpattern": {
809 | "fillmode": "overlay",
810 | "size": 10,
811 | "solidity": 0.2
812 | },
813 | "type": "scatter"
814 | }
815 | ],
816 | "scatter3d": [
817 | {
818 | "line": {
819 | "colorbar": {
820 | "outlinewidth": 0,
821 | "ticks": ""
822 | }
823 | },
824 | "marker": {
825 | "colorbar": {
826 | "outlinewidth": 0,
827 | "ticks": ""
828 | }
829 | },
830 | "type": "scatter3d"
831 | }
832 | ],
833 | "scattercarpet": [
834 | {
835 | "marker": {
836 | "colorbar": {
837 | "outlinewidth": 0,
838 | "ticks": ""
839 | }
840 | },
841 | "type": "scattercarpet"
842 | }
843 | ],
844 | "scattergeo": [
845 | {
846 | "marker": {
847 | "colorbar": {
848 | "outlinewidth": 0,
849 | "ticks": ""
850 | }
851 | },
852 | "type": "scattergeo"
853 | }
854 | ],
855 | "scattergl": [
856 | {
857 | "marker": {
858 | "colorbar": {
859 | "outlinewidth": 0,
860 | "ticks": ""
861 | }
862 | },
863 | "type": "scattergl"
864 | }
865 | ],
866 | "scattermapbox": [
867 | {
868 | "marker": {
869 | "colorbar": {
870 | "outlinewidth": 0,
871 | "ticks": ""
872 | }
873 | },
874 | "type": "scattermapbox"
875 | }
876 | ],
877 | "scatterpolar": [
878 | {
879 | "marker": {
880 | "colorbar": {
881 | "outlinewidth": 0,
882 | "ticks": ""
883 | }
884 | },
885 | "type": "scatterpolar"
886 | }
887 | ],
888 | "scatterpolargl": [
889 | {
890 | "marker": {
891 | "colorbar": {
892 | "outlinewidth": 0,
893 | "ticks": ""
894 | }
895 | },
896 | "type": "scatterpolargl"
897 | }
898 | ],
899 | "scatterternary": [
900 | {
901 | "marker": {
902 | "colorbar": {
903 | "outlinewidth": 0,
904 | "ticks": ""
905 | }
906 | },
907 | "type": "scatterternary"
908 | }
909 | ],
910 | "surface": [
911 | {
912 | "colorbar": {
913 | "outlinewidth": 0,
914 | "ticks": ""
915 | },
916 | "colorscale": [
917 | [
918 | 0,
919 | "#0d0887"
920 | ],
921 | [
922 | 0.1111111111111111,
923 | "#46039f"
924 | ],
925 | [
926 | 0.2222222222222222,
927 | "#7201a8"
928 | ],
929 | [
930 | 0.3333333333333333,
931 | "#9c179e"
932 | ],
933 | [
934 | 0.4444444444444444,
935 | "#bd3786"
936 | ],
937 | [
938 | 0.5555555555555556,
939 | "#d8576b"
940 | ],
941 | [
942 | 0.6666666666666666,
943 | "#ed7953"
944 | ],
945 | [
946 | 0.7777777777777778,
947 | "#fb9f3a"
948 | ],
949 | [
950 | 0.8888888888888888,
951 | "#fdca26"
952 | ],
953 | [
954 | 1,
955 | "#f0f921"
956 | ]
957 | ],
958 | "type": "surface"
959 | }
960 | ],
961 | "table": [
962 | {
963 | "cells": {
964 | "fill": {
965 | "color": "#EBF0F8"
966 | },
967 | "line": {
968 | "color": "white"
969 | }
970 | },
971 | "header": {
972 | "fill": {
973 | "color": "#C8D4E3"
974 | },
975 | "line": {
976 | "color": "white"
977 | }
978 | },
979 | "type": "table"
980 | }
981 | ]
982 | },
983 | "layout": {
984 | "annotationdefaults": {
985 | "arrowcolor": "#2a3f5f",
986 | "arrowhead": 0,
987 | "arrowwidth": 1
988 | },
989 | "autotypenumbers": "strict",
990 | "coloraxis": {
991 | "colorbar": {
992 | "outlinewidth": 0,
993 | "ticks": ""
994 | }
995 | },
996 | "colorscale": {
997 | "diverging": [
998 | [
999 | 0,
1000 | "#8e0152"
1001 | ],
1002 | [
1003 | 0.1,
1004 | "#c51b7d"
1005 | ],
1006 | [
1007 | 0.2,
1008 | "#de77ae"
1009 | ],
1010 | [
1011 | 0.3,
1012 | "#f1b6da"
1013 | ],
1014 | [
1015 | 0.4,
1016 | "#fde0ef"
1017 | ],
1018 | [
1019 | 0.5,
1020 | "#f7f7f7"
1021 | ],
1022 | [
1023 | 0.6,
1024 | "#e6f5d0"
1025 | ],
1026 | [
1027 | 0.7,
1028 | "#b8e186"
1029 | ],
1030 | [
1031 | 0.8,
1032 | "#7fbc41"
1033 | ],
1034 | [
1035 | 0.9,
1036 | "#4d9221"
1037 | ],
1038 | [
1039 | 1,
1040 | "#276419"
1041 | ]
1042 | ],
1043 | "sequential": [
1044 | [
1045 | 0,
1046 | "#0d0887"
1047 | ],
1048 | [
1049 | 0.1111111111111111,
1050 | "#46039f"
1051 | ],
1052 | [
1053 | 0.2222222222222222,
1054 | "#7201a8"
1055 | ],
1056 | [
1057 | 0.3333333333333333,
1058 | "#9c179e"
1059 | ],
1060 | [
1061 | 0.4444444444444444,
1062 | "#bd3786"
1063 | ],
1064 | [
1065 | 0.5555555555555556,
1066 | "#d8576b"
1067 | ],
1068 | [
1069 | 0.6666666666666666,
1070 | "#ed7953"
1071 | ],
1072 | [
1073 | 0.7777777777777778,
1074 | "#fb9f3a"
1075 | ],
1076 | [
1077 | 0.8888888888888888,
1078 | "#fdca26"
1079 | ],
1080 | [
1081 | 1,
1082 | "#f0f921"
1083 | ]
1084 | ],
1085 | "sequentialminus": [
1086 | [
1087 | 0,
1088 | "#0d0887"
1089 | ],
1090 | [
1091 | 0.1111111111111111,
1092 | "#46039f"
1093 | ],
1094 | [
1095 | 0.2222222222222222,
1096 | "#7201a8"
1097 | ],
1098 | [
1099 | 0.3333333333333333,
1100 | "#9c179e"
1101 | ],
1102 | [
1103 | 0.4444444444444444,
1104 | "#bd3786"
1105 | ],
1106 | [
1107 | 0.5555555555555556,
1108 | "#d8576b"
1109 | ],
1110 | [
1111 | 0.6666666666666666,
1112 | "#ed7953"
1113 | ],
1114 | [
1115 | 0.7777777777777778,
1116 | "#fb9f3a"
1117 | ],
1118 | [
1119 | 0.8888888888888888,
1120 | "#fdca26"
1121 | ],
1122 | [
1123 | 1,
1124 | "#f0f921"
1125 | ]
1126 | ]
1127 | },
1128 | "colorway": [
1129 | "#636efa",
1130 | "#EF553B",
1131 | "#00cc96",
1132 | "#ab63fa",
1133 | "#FFA15A",
1134 | "#19d3f3",
1135 | "#FF6692",
1136 | "#B6E880",
1137 | "#FF97FF",
1138 | "#FECB52"
1139 | ],
1140 | "font": {
1141 | "color": "#2a3f5f"
1142 | },
1143 | "geo": {
1144 | "bgcolor": "white",
1145 | "lakecolor": "white",
1146 | "landcolor": "#E5ECF6",
1147 | "showlakes": true,
1148 | "showland": true,
1149 | "subunitcolor": "white"
1150 | },
1151 | "hoverlabel": {
1152 | "align": "left"
1153 | },
1154 | "hovermode": "closest",
1155 | "mapbox": {
1156 | "style": "light"
1157 | },
1158 | "paper_bgcolor": "white",
1159 | "plot_bgcolor": "#E5ECF6",
1160 | "polar": {
1161 | "angularaxis": {
1162 | "gridcolor": "white",
1163 | "linecolor": "white",
1164 | "ticks": ""
1165 | },
1166 | "bgcolor": "#E5ECF6",
1167 | "radialaxis": {
1168 | "gridcolor": "white",
1169 | "linecolor": "white",
1170 | "ticks": ""
1171 | }
1172 | },
1173 | "scene": {
1174 | "xaxis": {
1175 | "backgroundcolor": "#E5ECF6",
1176 | "gridcolor": "white",
1177 | "gridwidth": 2,
1178 | "linecolor": "white",
1179 | "showbackground": true,
1180 | "ticks": "",
1181 | "zerolinecolor": "white"
1182 | },
1183 | "yaxis": {
1184 | "backgroundcolor": "#E5ECF6",
1185 | "gridcolor": "white",
1186 | "gridwidth": 2,
1187 | "linecolor": "white",
1188 | "showbackground": true,
1189 | "ticks": "",
1190 | "zerolinecolor": "white"
1191 | },
1192 | "zaxis": {
1193 | "backgroundcolor": "#E5ECF6",
1194 | "gridcolor": "white",
1195 | "gridwidth": 2,
1196 | "linecolor": "white",
1197 | "showbackground": true,
1198 | "ticks": "",
1199 | "zerolinecolor": "white"
1200 | }
1201 | },
1202 | "shapedefaults": {
1203 | "line": {
1204 | "color": "#2a3f5f"
1205 | }
1206 | },
1207 | "ternary": {
1208 | "aaxis": {
1209 | "gridcolor": "white",
1210 | "linecolor": "white",
1211 | "ticks": ""
1212 | },
1213 | "baxis": {
1214 | "gridcolor": "white",
1215 | "linecolor": "white",
1216 | "ticks": ""
1217 | },
1218 | "bgcolor": "#E5ECF6",
1219 | "caxis": {
1220 | "gridcolor": "white",
1221 | "linecolor": "white",
1222 | "ticks": ""
1223 | }
1224 | },
1225 | "title": {
1226 | "x": 0.05
1227 | },
1228 | "xaxis": {
1229 | "automargin": true,
1230 | "gridcolor": "white",
1231 | "linecolor": "white",
1232 | "ticks": "",
1233 | "title": {
1234 | "standoff": 15
1235 | },
1236 | "zerolinecolor": "white",
1237 | "zerolinewidth": 2
1238 | },
1239 | "yaxis": {
1240 | "automargin": true,
1241 | "gridcolor": "white",
1242 | "linecolor": "white",
1243 | "ticks": "",
1244 | "title": {
1245 | "standoff": 15
1246 | },
1247 | "zerolinecolor": "white",
1248 | "zerolinewidth": 2
1249 | }
1250 | }
1251 | },
1252 | "title": {
1253 | "text": "XGBRegressor with `n_estimators=10`, Custom SLE Loss Benchmark"
1254 | },
1255 | "xaxis": {
1256 | "anchor": "y",
1257 | "autorange": true,
1258 | "domain": [
1259 | 0,
1260 | 1
1261 | ],
1262 | "range": [
1263 | -0.5,
1264 | 3.5
1265 | ],
1266 | "title": {
1267 | "standoff": 30,
1268 | "text": "Number of Data Points"
1269 | },
1270 | "type": "category"
1271 | },
1272 | "yaxis": {
1273 | "anchor": "x",
1274 | "domain": [
1275 | 0,
1276 | 1
1277 | ],
1278 | "range": [
1279 | 0,
1280 | 34
1281 | ],
1282 | "title": {
1283 | "standoff": 50,
1284 | "text": "Time (s)"
1285 | },
1286 | "type": "linear"
1287 | }
1288 | }
1289 | },
1290 | "image/png": "",
1291 | "text/html": [
1292 | "
"
1317 | ]
1318 | },
1319 | "metadata": {},
1320 | "output_type": "display_data"
1321 | }
1322 | ],
1323 | "source": [
1324 | "fig = px.bar(plot_df, x='n_samples', y='time', color='implementation',\n",
1325 | " barmode='group', text='time', text_auto='.2',\n",
1326 | " labels=dict(\n",
1327 | " time='Time (s)',\n",
1328 | " implementation='Implementation',\n",
1329 | " n_samples='Number of Data Points'))\n",
1330 | "fig.update_xaxes(type='category', title_standoff=30)\n",
1331 | "fig.update_yaxes(title_standoff=50)\n",
1332 | "fig.update_traces(textposition=\"outside\")\n",
1333 | "fig.update_layout(font=dict(size=24), yaxis_range=[0, 34], margin=dict(t=70))\n",
1334 | "fig.update_layout(title_text='XGBRegressor with `n_estimators=10`, Custom SLE Loss Benchmark')"
1335 | ]
1336 | }
1337 | ],
1338 | "metadata": {
1339 | "kernelspec": {
1340 | "display_name": "Python 3 (ipykernel)",
1341 | "language": "python",
1342 | "name": "python3"
1343 | },
1344 | "language_info": {
1345 | "codemirror_mode": {
1346 | "name": "ipython",
1347 | "version": 3
1348 | },
1349 | "file_extension": ".py",
1350 | "mimetype": "text/x-python",
1351 | "name": "python",
1352 | "nbconvert_exporter": "python",
1353 | "pygments_lexer": "ipython3",
1354 | "version": "3.9.12"
1355 | }
1356 | },
1357 | "nbformat": 4,
1358 | "nbformat_minor": 5
1359 | }
1360 |
--------------------------------------------------------------------------------