├── 01-tensor_tutorial.ipynb ├── 02-space_stretching.ipynb ├── 03-autograd_tutorial.ipynb ├── 04-spiral_classification.ipynb ├── 05-regression.ipynb ├── 06-convnet.ipynb ├── 07-listening_to_kernels.ipynb ├── 08-seq_classification.ipynb ├── 09-echo_data.ipynb ├── 10-autoencoder.ipynb ├── 11-VAE.ipynb ├── 12-regularization.ipynb ├── 13-bayesian_nn..ipynb ├── 14-truck_backer-upper.ipynb ├── 15-transformer.ipynb ├── 16-gated_GCN.ipynb ├── LICENSE.md ├── README.md └── res ├── plot_lib.py ├── sequential_tasks.py ├── test └── win_xp_shutdown.wav /03-autograd_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Traduction en français du notebook *03* du cours ***Deep Learning*** d'Alfredo Canziani, professeur assistant à la *New York University* : \n", 8 | "https://github.com/Atcold/pytorch-Deep-Learning/blob/master/03-autograd_tutorial.ipynb" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "# Autograd : différenciation automatique\n", 23 | "\n", 24 | "Le package ``autograd`` fournit une différenciation automatique pour toutes les opérations sur les tenseurs. 
Il s'agit d'un cadre défini à l'exécution (*define-by-run*), ce qui veut dire que votre rétropropagation est\n", 25 | "définie par la façon dont votre code est exécuté, et que chaque itération peut être différente.\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 1, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import torch" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "Créer un tenseur :" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "tensor([[1., 2.],\n", 54 | " [3., 4.]], requires_grad=True)\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "# Créer un tenseur 2x2 avec des capacités d'accumulation de gradients\n", 60 | "x = torch.tensor([[1, 2], [3, 4]], requires_grad=True, dtype=torch.float32)\n", 61 | "print(x)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "Effectuer une opération sur le tenseur :" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "tensor([[-1., 0.],\n", 81 | " [ 1., 2.]], grad_fn=)\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "# Soustraire 2 à tous les éléments\n", 87 | "y = x - 2\n", 88 | "print(y)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "``y`` a été créé à la suite d'une opération, il a donc un ``grad_fn``" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "print(y.grad_fn)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 
119 | "outputs": [], 120 | "source": [ 121 | "# Qu'est-ce qui se passe ici ?\n", 122 | "print(x.grad_fn)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 6, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "data": { 132 | "text/plain": [ 133 | "" 134 | ] 135 | }, 136 | "execution_count": 6, 137 | "metadata": {}, 138 | "output_type": "execute_result" 139 | } 140 | ], 141 | "source": [ 142 | "# Creusons un peu plus loin...\n", 143 | "y.grad_fn" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 7, 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "data": { 153 | "text/plain": [ 154 | "" 155 | ] 156 | }, 157 | "execution_count": 7, 158 | "metadata": {}, 159 | "output_type": "execute_result" 160 | } 161 | ], 162 | "source": [ 163 | "y.grad_fn.next_functions[0][0]" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 8, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "data": { 173 | "text/plain": [ 174 | "tensor([[1., 2.],\n", 175 | " [3., 4.]], requires_grad=True)" 176 | ] 177 | }, 178 | "execution_count": 8, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "y.grad_fn.next_functions[0][0].variable" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 9, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "tensor([[ 3., 0.],\n", 197 | " [ 3., 12.]], grad_fn=)\n", 198 | "tensor(4.5000, grad_fn=)\n" 199 | ] 200 | } 201 | ], 202 | "source": [ 203 | "# Faire plus d'opérations sur y\n", 204 | "z = y * y * 3\n", 205 | "a = z.mean() # moyenne\n", 206 | "\n", 207 | "print(z)\n", 208 | "print(a)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 10, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "# # Visualisons le graphique de calcul ! 
(Alfredo tient à remercier ici @szagoruyko)\n", 218 | "from torchviz import make_dot" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 21, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "make_dot(a)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "## Gradients\n", 235 | "\n", 236 | "Rétropropageons maintenant avec `a.backward()`. Cela équivaut à faire `a.backward(torch.tensor(1.))`." 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 12, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "# Rétropropagation\n", 246 | "a.backward()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "Affichage des gradients $\\frac{\\text{d}a}{\\text{d}x}$.\n", 254 | "\n", 255 | "\n" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 13, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "name": "stdout", 265 | "output_type": "stream", 266 | "text": [ 267 | "tensor([[-1.5000, 0.0000],\n", 268 | " [ 1.5000, 3.0000]])\n" 269 | ] 270 | } 271 | ], 272 | "source": [ 273 | "# Calculez-le à la main AVANT de l'exécuter\n", 274 | "print(x.grad)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "metadata": {}, 280 | "source": [ 281 | "Vous pouvez faire beaucoup de choses avec autograd !\n", 282 | "> Avec une grande *flexibilité* vient une grande responsabilité" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 14, 288 | "metadata": {}, 289 | "outputs": [ 290 | { 291 | "name": "stdout", 292 | "output_type": "stream", 293 | "text": [ 294 | "tensor([ 12.5663, -995.8140, 188.6319], grad_fn=)\n" 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "# Graphes dynamiques!\n", 300 | "x = torch.randn(3, requires_grad=True)\n", 301 | "\n", 302 | "y = x * 2\n", 303 | "i = 0\n", 304 | "while y.data.norm() < 1000:\n", 305 | " y = y * 2\n", 306 | "
i += 1\n", 307 | "print(y)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 15, 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "name": "stdout", 317 | "output_type": "stream", 318 | "text": [ 319 | "tensor([5.1200e+01, 5.1200e+02, 5.1200e-02])\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "# # Si nous ne faisons pas le *backward* sur un scalaire, nous devons spécifier le *grad_output*\n", 325 | "gradients = torch.FloatTensor([0.1, 1.0, 0.0001])\n", 326 | "y.backward(gradients)\n", 327 | "\n", 328 | "print(x.grad)" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 16, 334 | "metadata": {}, 335 | "outputs": [ 336 | { 337 | "name": "stdout", 338 | "output_type": "stream", 339 | "text": [ 340 | "8\n" 341 | ] 342 | } 343 | ], 344 | "source": [ 345 | "# AVANT d'exécuter la cellule, pouvez-vous dire ce qui sera affiché ?\n", 346 | "print(i)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "## Inférence" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 17, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "# Cette variable détermine l'étendue du tenseur en dessous\n", 363 | "n = 3" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 18, 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "name": "stdout", 373 | "output_type": "stream", 374 | "text": [ 375 | "tensor([1., 1., 1.])\n", 376 | "tensor([1., 2., 3.])\n" 377 | ] 378 | } 379 | ], 380 | "source": [ 381 | "# x et w permettent l'accumulation du gradient\n", 382 | "x = torch.arange(1., n + 1, requires_grad=True)\n", 383 | "w = torch.ones(n, requires_grad=True)\n", 384 | "z = w @ x\n", 385 | "z.backward()\n", 386 | "print(x.grad, w.grad, sep='\\n')" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 19, 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "name": "stdout", 396 | "output_type": 
"stream", 397 | "text": [ 398 | "None\n", 399 | "tensor([1., 2., 3.])\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "# Seul w permet l'accumulation du gradient\n", 405 | "x = torch.arange(1., n + 1)\n", 406 | "w = torch.ones(n, requires_grad=True)\n", 407 | "z = w @ x\n", 408 | "z.backward()\n", 409 | "print(x.grad, w.grad, sep='\\n')" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 20, 415 | "metadata": {}, 416 | "outputs": [ 417 | { 418 | "name": "stdout", 419 | "output_type": "stream", 420 | "text": [ 421 | "RuntimeError!!! >:[\n", 422 | "element 0 of tensors does not require grad and does not have a grad_fn\n" 423 | ] 424 | } 425 | ], 426 | "source": [ 427 | "x = torch.arange(1., n + 1)\n", 428 | "w = torch.ones(n, requires_grad=True)\n", 429 | "\n", 430 | "# Quoi que vous fassiez dans ce contexte, aucun tenseur torch n'aura d'accumulation de gradient\n", 431 | "with torch.no_grad():\n", 432 | " z = w @ x\n", 433 | "\n", 434 | "try:\n", 435 | " z.backward() # PyTorch va renvoyer une erreur ici, puisque z n'a pas de grad accum.\n", 436 | "except RuntimeError as e:\n", 437 | " print('RuntimeError!!! >:[')\n", 438 | " print(e)" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": {}, 444 | "source": [ 445 | "## Plus de choses\n", 446 | "\n", 447 | "La documentation relative au paquet de différenciation automatique se trouve à l'adresse suivante :\n", 448 | "https://pytorch.org/docs/stable/autograd.html." 
449 | ] 450 | } 451 | ], 452 | "metadata": { 453 | "kernelspec": { 454 | "display_name": "Python 3", 455 | "language": "python", 456 | "name": "python3" 457 | }, 458 | "language_info": { 459 | "codemirror_mode": { 460 | "name": "ipython", 461 | "version": 3 462 | }, 463 | "file_extension": ".py", 464 | "mimetype": "text/x-python", 465 | "name": "python", 466 | "nbconvert_exporter": "python", 467 | "pygments_lexer": "ipython3", 468 | "version": "3.6.5" 469 | } 470 | }, 471 | "nbformat": 4, 472 | "nbformat_minor": 4 473 | } 474 | -------------------------------------------------------------------------------- /09-echo_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Traduction en français du notebook *09* du cours ***Deep Learning*** d'Alfredo Canziani, professeur assistant à la *New York University* : \n", 8 | "https://github.com/Atcold/pytorch-Deep-Learning/blob/master/09-echo_data.ipynb" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "# Écho du signal\n", 23 | "\n", 24 | "Reproduire (faire l'écho d')un signal avec un décalage de `n` pas de temps est un exemple de tâche synchronisée many-to-many (plusieurs à plusieurs)."
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "from res.sequential_tasks import EchoData\n", 34 | "import torch\n", 35 | "import torch.nn as nn\n", 36 | "# import torch.nn.functional as F\n", 37 | "import torch.optim as optim\n", 38 | "\n", 39 | "torch.manual_seed(1);" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "batch_size = 5\n", 49 | "echo_step = 3\n", 50 | "series_length = 20_000\n", 51 | "BPTT_T = 20\n", 52 | "\n", 53 | "train_data = EchoData(\n", 54 | " echo_step=echo_step,\n", 55 | " batch_size=batch_size,\n", 56 | " series_length=series_length,\n", 57 | " truncated_length=BPTT_T,\n", 58 | ")\n", 59 | "train_size = len(train_data)\n", 60 | "\n", 61 | "test_data = EchoData(\n", 62 | " echo_step=echo_step,\n", 63 | " batch_size=batch_size,\n", 64 | " series_length=series_length,\n", 65 | " truncated_length=BPTT_T,\n", 66 | ")\n", 67 | "test_size = len(test_data)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "(Premiere sequence d entrée) x: 0 0 1 1 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 1 ... \n", 80 | "(Premiere sequence cible) y: 0 0 0 0 0 1 1 0 0 0 1 0 0 0 1 1 0 0 0 0 ... \n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "# Affichons les 20 premiers pas des premières séquences pour voir les données d'écho :\n", 86 | "print('(Premiere sequence d entrée) x:', *train_data.x_batch[0, :20], '... ')\n", 87 | "print('(Premiere sequence cible) y:', *train_data.y_batch[0, :20], '... 
')" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "x_batch:\n", 100 | "0 0 1 1 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 1 ...\n", 101 | "1 0 1 0 1 1 1 1 1 0 0 0 1 0 1 1 1 1 0 0 ...\n", 102 | "0 1 1 0 1 0 0 0 1 1 1 1 0 0 1 0 1 0 0 0 ...\n", 103 | "1 0 0 0 1 1 1 1 0 0 1 0 0 1 1 0 1 0 1 0 ...\n", 104 | "0 0 1 0 0 0 0 0 0 1 0 1 1 0 0 0 0 1 0 1 ...\n", 105 | "x_batch de taille: (5, 20000)\n", 106 | "\n", 107 | "y_batch:\n", 108 | "0 0 0 0 0 1 1 0 0 0 1 0 0 0 1 1 0 0 0 0 ...\n", 109 | "0 0 0 1 0 1 0 1 1 1 1 1 0 0 0 1 0 1 1 1 ...\n", 110 | "0 0 0 0 1 1 0 1 0 0 0 1 1 1 1 0 0 1 0 1 ...\n", 111 | "0 0 0 1 0 0 0 1 1 1 1 0 0 1 0 0 1 1 0 1 ...\n", 112 | "0 0 0 0 0 1 0 0 0 0 0 0 1 0 1 1 0 0 0 0 ...\n", 113 | "y_batch de taille: (5, 20000)\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "# des séquences différentes de batch_size sont créées :\n", 119 | "print('x_batch:', *(str(d)[1:-1] + ' ...' for d in train_data.x_batch[:, :20]), sep='\\n')\n", 120 | "print('x_batch de taille:', train_data.x_batch.shape)\n", 121 | "print()\n", 122 | "print('y_batch:', *(str(d)[1:-1] + ' ...' 
for d in train_data.y_batch[:, :20]), sep='\\n')\n", 123 | "print('y_batch de taille:', train_data.y_batch.shape)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "x_chunk:\n", 136 | "[0 0 1 1 0 0 0 1 0 0 0 1 1 0 0 0 0 1 1 1]\n", 137 | "[1 0 1 0 1 1 1 1 1 0 0 0 1 0 1 1 1 1 0 0]\n", 138 | "[0 1 1 0 1 0 0 0 1 1 1 1 0 0 1 0 1 0 0 0]\n", 139 | "[1 0 0 0 1 1 1 1 0 0 1 0 0 1 1 0 1 0 1 0]\n", 140 | "[0 0 1 0 0 0 0 0 0 1 0 1 1 0 0 0 0 1 0 1]\n", 141 | "Premier x_chunk de taille: (5, 20, 1)\n", 142 | "\n", 143 | "y_chunk:\n", 144 | "[0 0 0 0 0 1 1 0 0 0 1 0 0 0 1 1 0 0 0 0]\n", 145 | "[0 0 0 1 0 1 0 1 1 1 1 1 0 0 0 1 0 1 1 1]\n", 146 | "[0 0 0 0 1 1 0 1 0 0 0 1 1 1 1 0 0 1 0 1]\n", 147 | "[0 0 0 1 0 0 0 1 1 1 1 0 0 1 0 0 1 1 0 1]\n", 148 | "[0 0 0 0 0 1 0 0 0 0 0 0 1 0 1 1 0 0 0 0]\n", 149 | "Premier y_chunk de taille: (5, 20, 1)\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "# Pour utiliser les RNN, les données sont organisées en\n", 155 | "# morceaux de taille [batch_size, T, feature_dim]\n", 156 | "print('x_chunk:', *train_data.x_chunks[0].squeeze(), sep='\\n')\n", 157 | "print('Premier x_chunk de taille:', train_data.x_chunks[0].shape)\n", 158 | "print()\n", 159 | "print('y_chunk:', *train_data.y_chunks[0].squeeze(), sep='\\n')\n", 160 | "print('Premier y_chunk de taille:', train_data.y_chunks[0].shape)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 7, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "class SimpleRNN(nn.Module):\n", 170 | " def __init__(self, input_size, rnn_hidden_size, output_size):\n", 171 | " super().__init__()\n", 172 | " self.rnn_hidden_size = rnn_hidden_size\n", 173 | " self.rnn = torch.nn.RNN(\n", 174 | " input_size=input_size,\n", 175 | " hidden_size=rnn_hidden_size,\n", 176 | " num_layers=1,\n", 177 | " nonlinearity='relu',\n", 178 | "
" batch_first=True\n", 179 | " )\n", 180 | " self.linear = torch.nn.Linear(\n", 181 | " in_features=rnn_hidden_size,\n", 182 | " out_features=output_size\n", 183 | " )\n", 184 | "\n", 185 | " def forward(self, x, hidden):\n", 186 | " x, hidden = self.rnn(x, hidden) \n", 187 | " x = self.linear(x)\n", 188 | " return x, hidden" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 8, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 9, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "def train(hidden):\n", 207 | " model.train()\n", 208 | " \n", 209 | " correct = 0\n", 210 | " for batch_idx in range(train_size):\n", 211 | " data, target = train_data[batch_idx]\n", 212 | " data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).float().to(device)\n", 213 | " optimizer.zero_grad()\n", 214 | " if hidden is not None: hidden.detach_()\n", 215 | " logits, hidden = model(data, hidden)\n", 216 | " loss = criterion(logits, target)\n", 217 | " loss.backward()\n", 218 | " optimizer.step()\n", 219 | " \n", 220 | " pred = (torch.sigmoid(logits) > 0.5)\n", 221 | " correct += (pred == target.byte()).int().sum().item()\n", 222 | " \n", 223 | " return correct, loss.item(), hidden" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 10, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "def test(hidden):\n", 233 | " model.eval() \n", 234 | " correct = 0\n", 235 | " with torch.no_grad():\n", 236 | " for batch_idx in range(test_size):\n", 237 | " data, target = test_data[batch_idx]\n", 238 | " data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).float().to(device)\n", 239 | " logits, hidden = model(data, hidden)\n", 240 | " \n", 241 | " pred = (torch.sigmoid(logits) > 0.5)\n", 242 | "
correct += (pred == target.byte()).int().sum().item()\n", 243 | "\n", 244 | " return correct" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 11, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "feature_dim = 1 #puisque nous avons une série scalaire\n", 254 | "h_units = 4\n", 255 | "\n", 256 | "model = SimpleRNN(\n", 257 | " input_size=feature_dim,\n", 258 | " rnn_hidden_size=h_units,\n", 259 | " output_size=feature_dim\n", 260 | ").to(device)\n", 261 | "hidden = None\n", 262 | " \n", 263 | "criterion = torch.nn.BCEWithLogitsLoss()\n", 264 | "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 12, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "Train Epoch: 1/5, loss: 0.538, accuracy 59.5%\n", 277 | "Train Epoch: 2/5, loss: 0.082, accuracy 87.1%\n", 278 | "Train Epoch: 3/5, loss: 0.001, accuracy 100.0%\n", 279 | "Train Epoch: 4/5, loss: 0.000, accuracy 100.0%\n", 280 | "Train Epoch: 5/5, loss: 0.000, accuracy 100.0%\n", 281 | "Test accuracy: 100.0%\n" 282 | ] 283 | } 284 | ], 285 | "source": [ 286 | "n_epochs = 5\n", 287 | "epoch = 0\n", 288 | "\n", 289 | "while epoch < n_epochs:\n", 290 | " correct, loss, hidden = train(hidden)\n", 291 | " epoch += 1\n", 292 | " train_accuracy = 100. * correct / (train_size * batch_size * BPTT_T)\n", 293 | " print(f'Train Epoch: {epoch}/{n_epochs}, loss: {loss:.3f}, accuracy {train_accuracy:.1f}%')\n", 294 | "\n", 295 | "#test \n", 296 | "correct = test(hidden)\n", 297 | "test_accuracy = 100. * correct / (test_size * batch_size * BPTT_T)\n", 298 | "print(f'Test accuracy: {test_accuracy:.1f}%')" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 13, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "name": "stdout", 308 | "output_type": "stream", 309 | "text": [ 310 | "tensor([[1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1,
1,\n", 311 | " 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1,\n", 312 | " 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0,\n", 313 | " 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0,\n", 314 | " 1, 1, 0, 0]], dtype=torch.uint8)\n", 315 | "tensor([[ True, True, False, False, True, True, True, True, False, False,\n", 316 | " True, True, True, True, True, True, False, True, True, False,\n", 317 | " False, True, False, False, True, True, True, False, True, False,\n", 318 | " False, True, True, False, True, True, True, True, False, False,\n", 319 | " True, True, False, False, False, False, True, True, True, False,\n", 320 | " True, False, False, True, True, False, True, True, False, True,\n", 321 | " False, False, True, True, True, False, True, True, False, True,\n", 322 | " True, True, True, False, False, False, False, False, True, True,\n", 323 | " True, True, True, True, True, False, False, False, False, False,\n", 324 | " False, True, True, True, True, True, True, False, False, True]])\n" 325 | ] 326 | } 327 | ], 328 | "source": [ 329 | "# Essayons un peu d'écho\n", 330 | "my_input = torch.empty(1, 100, 1).random_(2).to(device)\n", 331 | "hidden = None\n", 332 | "my_out, _ = model(my_input, hidden)\n", 333 | "print(my_input.view(1, -1).byte(), (my_out > 0).view(1, -1), sep='\\n')" 334 | ] 335 | } 336 | ], 337 | "metadata": { 338 | "kernelspec": { 339 | "display_name": "Python 3", 340 | "language": "python", 341 | "name": "python3" 342 | }, 343 | "language_info": { 344 | "codemirror_mode": { 345 | "name": "ipython", 346 | "version": 3 347 | }, 348 | "file_extension": ".py", 349 | "mimetype": "text/x-python", 350 | "name": "python", 351 | "nbconvert_exporter": "python", 352 | "pygments_lexer": "ipython3", 353 | "version": "3.6.5" 354 | } 355 | }, 356 | "nbformat": 4, 357 | "nbformat_minor": 4 358 | } 359 | -------------------------------------------------------------------------------- 
/14-truck_backer-upper.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Traduction en français du notebook *14* du cours ***Deep Learning*** d'Alfredo Canziani, professeur assistant à la *New York University* :\n", 8 | "https://github.com/Atcold/pytorch-Deep-Learning/blob/master/14-truck_backer-upper.ipynb" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from matplotlib.pylab import *\n", 25 | "from matplotlib.patches import Rectangle\n", 26 | "from matplotlib.collections import PatchCollection\n", 27 | "from matplotlib.lines import Line2D\n", 28 | "π = pi" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "style.use(['dark_background', 'bmh'])\n", 38 | "%matplotlib notebook" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "Diagramme voiture-remorque (image inversée `res/car-trailer-k.png` disponible également) :\n", 46 | "![voiture-remorque](car-trailer-w.png)\n", 47 | "\n", 48 | "Équations de la voiture-remorque :\n", 49 | "\\begin{align}\n", 50 | "\\dot x &= s \\cos \\theta_0 \\\\\n", 51 | "\\dot y &= s \\sin \\theta_0 \\\\\n", 52 | "\\dot \\theta_0 &= \\frac{s}{L} \\tan \\phi \\\\\n", 53 | "\\dot \\theta_1 &= \\frac{s}{d_1} \\sin(\\theta_0 - \\theta_1)\n", 54 | "\\end{align}\n", 55 | "où $s$ : vitesse signée, $\\phi$ : angle de braquage négatif, $L$ : longueur de la voiture, $d_1$ : longueur de la remorque." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "class Truck:\n", 65 | " def __init__(self, display=False):\n", 66 | "\n", 67 | " self.W = 1 # 
largeur de la voiture et de la remorque, pour le dessin uniquement\n", 68 | " self.L = 1 * self.W # longueur de la voiture\n", 69 | " self.d = 4 * self.L # d_1\n", 70 | " self.s = -0.1 # vitesse\n", 71 | " self.display = display\n", 72 | " \n", 73 | " self.box = [0, 40, -10, 10]\n", 74 | " if self.display:\n", 75 | " self.f = figure(figsize=(10, 5), num='The truck backer-upper', facecolor='none')\n", 76 | " self.ax = self.f.add_axes([0.01, 0.01, 0.98, 0.98], facecolor='black')\n", 77 | " self.patches = list()\n", 78 | " \n", 79 | " self.ax.axis('equal')\n", 80 | " b = self.box\n", 81 | " self.ax.axis([b[0] - 1, b[1], b[2], b[3]])\n", 82 | " self.ax.set_xticks([], []); self.ax.set_yticks([], [])\n", 83 | " self.ax.axhline(); self.ax.axvline()\n", 84 | "\n", 85 | " self.reset()\n", 86 | " \n", 87 | " def reset(self, ϕ=0):\n", 88 | " self.ϕ = ϕ # angle de braquage initial de la voiture\n", 89 | " \n", 90 | " # self.θ0 = deg2rad(30) # direction initiale de la voiture\n", 91 | " # self.θ1 = deg2rad(-30) # direction initiale de la remorque\n", 92 | " # self.x, self.y = 20, -5 # les coordonnées initiales de la voiture\n", 93 | " \n", 94 | " self.θ0 = random() * 2 * π # 0 <= ϑ₀ < 2π\n", 95 | " self.θ1 = (random() - 0.5) * π / 2 + self.θ0 # -π/4 <= ϑ₁ - ϑ₀ < π/4\n", 96 | " self.x = (random() * .75 + 0.25) * self.box[1]\n", 97 | " self.y = (random() - 0.5) * (self.box[3] - self.box[2])\n", 98 | " \n", 99 | " # En cas de mauvaise initialisation, réinitialiser\n", 100 | " if not self.valid():\n", 101 | " self.reset(ϕ)\n", 102 | " \n", 103 | " # Dessine, si display vaut True\n", 104 | " if self.display: self.draw()\n", 105 | " \n", 106 | " def step(self, ϕ=0, dt=1):\n", 107 | " \n", 108 | " # Contrôle des conditions illégales\n", 109 | " if self.is_jackknifed():\n", 110 | " print('Le camion est en portefeuille !')\n", 111 | " return\n", 112 | " \n", 113 | " if self.is_offscreen():\n", 114 | " print('La voiture ou la remorque est hors champ')\n", 115 | " return\n", 116 | " \n", 
117 | " self.ϕ = ϕ\n", 118 | " x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n", 119 | " \n", 120 | " # Effectuer la mise à jour de l'état\n", 121 | " self.x += s * cos(θ0) * dt\n", 122 | " self.y += s * sin(θ0) * dt\n", 123 | " self.θ0 += s / L * tan(ϕ) * dt\n", 124 | " self.θ1 += s / d * sin(θ0 - θ1) * dt\n", 125 | " \n", 126 | " return (self.x, self.y, self.θ0, *self._traler_xy(), self.θ1)\n", 127 | " \n", 128 | " def state(self):\n", 129 | " return (self.x, self.y, self.θ0, *self._traler_xy(), self.θ1)\n", 130 | " \n", 131 | " def _get_atributes(self):\n", 132 | " return (\n", 133 | " self.x, self.y, self.W, self.L, self.d, self.s,\n", 134 | " self.θ0, self.θ1, self.ϕ\n", 135 | " )\n", 136 | " \n", 137 | " def _traler_xy(self):\n", 138 | " x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n", 139 | " return x - d * cos(θ1), y - d * sin(θ1)\n", 140 | " \n", 141 | " def is_jackknifed(self):\n", 142 | " x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n", 143 | " return abs(θ0 - θ1) * 180 / π > 90\n", 144 | " \n", 145 | " def is_offscreen(self):\n", 146 | " x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n", 147 | " \n", 148 | " x1, y1 = x + 1.5 * L * cos(θ0), y + 1.5 * L * sin(θ0)\n", 149 | " x2, y2 = self._traler_xy()\n", 150 | " \n", 151 | " b = self.box\n", 152 | " return not (\n", 153 | " b[0] <= x1 <= b[1] and b[2] <= y1 <= b[3] and\n", 154 | " b[0] <= x2 <= b[1] and b[2] <= y2 <= b[3]\n", 155 | " )\n", 156 | " \n", 157 | " def valid(self):\n", 158 | " return not self.is_jackknifed() and not self.is_offscreen()\n", 159 | " \n", 160 | " def draw(self):\n", 161 | " if not self.display: return\n", 162 | " if self.patches: self.clear()\n", 163 | " self._draw_car()\n", 164 | " self._draw_trailer()\n", 165 | " self.f.canvas.draw()\n", 166 | " \n", 167 | " def clear(self):\n", 168 | " for p in self.patches:\n", 169 | " p.remove()\n", 170 | " self.patches = list()\n", 171 | " \n", 172 | " def _draw_car(self):\n", 173 | " x, y, W, L, d, s, θ0, θ1, ϕ = 
self._get_atributes()\n", 174 | " ax = self.ax\n", 175 | " \n", 176 | " x1, y1 = x + L / 2 * cos(θ0), y + L / 2 * sin(θ0)\n", 177 | " bar = Line2D((x, x1), (y, y1), lw=5, color='C2', alpha=0.8)\n", 178 | " ax.add_line(bar)\n", 179 | "\n", 180 | " car = Rectangle(\n", 181 | " (x1, y1 - W / 2), L, W, 0, color='C2', alpha=0.8, transform=\n", 182 | " matplotlib.transforms.Affine2D().rotate_deg_around(x1, y1, θ0 * 180 / π) +\n", 183 | " ax.transData\n", 184 | " )\n", 185 | " ax.add_patch(car)\n", 186 | "\n", 187 | " x2, y2 = x1 + L / 2 ** 0.5 * cos(θ0 + π / 4), y1 + L / 2 ** 0.5 * sin(θ0 + π / 4)\n", 188 | " left_wheel = Line2D(\n", 189 | " (x2 - L / 4 * cos(θ0 + ϕ), x2 + L / 4 * cos(θ0 + ϕ)),\n", 190 | " (y2 - L / 4 * sin(θ0 + ϕ), y2 + L / 4 * sin(θ0 + ϕ)),\n", 191 | " lw=3, color='C5', alpha=1)\n", 192 | " ax.add_line(left_wheel)\n", 193 | "\n", 194 | " x3, y3 = x1 + L / 2 ** 0.5 * cos(π / 4 - θ0), y1 - L / 2 ** 0.5 * sin(π / 4 - θ0)\n", 195 | " right_wheel = Line2D(\n", 196 | " (x3 - L / 4 * cos(θ0 + ϕ), x3 + L / 4 * cos(θ0 + ϕ)),\n", 197 | " (y3 - L / 4 * sin(θ0 + ϕ), y3 + L / 4 * sin(θ0 + ϕ)),\n", 198 | " lw=3, color='C5', alpha=1)\n", 199 | " ax.add_line(right_wheel)\n", 200 | " \n", 201 | " self.patches += [car, bar, left_wheel, right_wheel]\n", 202 | " \n", 203 | " def _draw_trailer(self):\n", 204 | " x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n", 205 | " ax = self.ax\n", 206 | " \n", 207 | " x, y = x - d * cos(θ1), y - d * sin(θ1) - W / 2\n", 208 | " trailer = Rectangle(\n", 209 | " (x, y), d, W, 0, color='C0', alpha=0.8, transform=\n", 210 | " matplotlib.transforms.Affine2D().rotate_deg_around(x, y + W/2, θ1 * 180 / π) +\n", 211 | " ax.transData\n", 212 | " )\n", 213 | " ax.add_patch(trailer)\n", 214 | " \n", 215 | " self.patches += [trailer]" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 4, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "application/javascript": [ 226 | "/* Put everything inside 
the global mpl namespace */\n", 227 | "window.mpl = {};\n", 228 | "\n", 229 | "\n", 230 | "mpl.get_websocket_type = function() {\n", 231 | " if (typeof(WebSocket) !== 'undefined') {\n", 232 | " return WebSocket;\n", 233 | " } else if (typeof(MozWebSocket) !== 'undefined') {\n", 234 | " return MozWebSocket;\n", 235 | " } else {\n", 236 | " alert('Your browser does not have WebSocket support. ' +\n", 237 | " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", 238 | " 'Firefox 4 and 5 are also supported but you ' +\n", 239 | " 'have to enable WebSockets in about:config.');\n", 240 | " };\n", 241 | "}\n", 242 | "\n", 243 | "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", 244 | " this.id = figure_id;\n", 245 | "\n", 246 | " this.ws = websocket;\n", 247 | "\n", 248 | " this.supports_binary = (this.ws.binaryType != undefined);\n", 249 | "\n", 250 | " if (!this.supports_binary) {\n", 251 | " var warnings = document.getElementById(\"mpl-warnings\");\n", 252 | " if (warnings) {\n", 253 | " warnings.style.display = 'block';\n", 254 | " warnings.textContent = (\n", 255 | " \"This browser does not support binary websocket messages. \" +\n", 256 | " \"Performance may be slow.\");\n", 257 | " }\n", 258 | " }\n", 259 | "\n", 260 | " this.imageObj = new Image();\n", 261 | "\n", 262 | " this.context = undefined;\n", 263 | " this.message = undefined;\n", 264 | " this.canvas = undefined;\n", 265 | " this.rubberband_canvas = undefined;\n", 266 | " this.rubberband_context = undefined;\n", 267 | " this.format_dropdown = undefined;\n", 268 | "\n", 269 | " this.image_mode = 'full';\n", 270 | "\n", 271 | " this.root = $('
');\n", 272 | " this._root_extra_style(this.root)\n", 273 | " this.root.attr('style', 'display: inline-block');\n", 274 | "\n", 275 | " $(parent_element).append(this.root);\n", 276 | "\n", 277 | " this._init_header(this);\n", 278 | " this._init_canvas(this);\n", 279 | " this._init_toolbar(this);\n", 280 | "\n", 281 | " var fig = this;\n", 282 | "\n", 283 | " this.waiting = false;\n", 284 | "\n", 285 | " this.ws.onopen = function () {\n", 286 | " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", 287 | " fig.send_message(\"send_image_mode\", {});\n", 288 | " if (mpl.ratio != 1) {\n", 289 | " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", 290 | " }\n", 291 | " fig.send_message(\"refresh\", {});\n", 292 | " }\n", 293 | "\n", 294 | " this.imageObj.onload = function() {\n", 295 | " if (fig.image_mode == 'full') {\n", 296 | " // Full images could contain transparency (where diff images\n", 297 | " // almost always do), so we need to clear the canvas so that\n", 298 | " // there is no ghosting.\n", 299 | " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", 300 | " }\n", 301 | " fig.context.drawImage(fig.imageObj, 0, 0);\n", 302 | " };\n", 303 | "\n", 304 | " this.imageObj.onunload = function() {\n", 305 | " fig.ws.close();\n", 306 | " }\n", 307 | "\n", 308 | " this.ws.onmessage = this._make_on_message_function(this);\n", 309 | "\n", 310 | " this.ondownload = ondownload;\n", 311 | "}\n", 312 | "\n", 313 | "mpl.figure.prototype._init_header = function() {\n", 314 | " var titlebar = $(\n", 315 | " '
');\n", 317 | " var titletext = $(\n", 318 | " '
');\n", 320 | " titlebar.append(titletext)\n", 321 | " this.root.append(titlebar);\n", 322 | " this.header = titletext[0];\n", 323 | "}\n", 324 | "\n", 325 | "\n", 326 | "\n", 327 | "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", 328 | "\n", 329 | "}\n", 330 | "\n", 331 | "\n", 332 | "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", 333 | "\n", 334 | "}\n", 335 | "\n", 336 | "mpl.figure.prototype._init_canvas = function() {\n", 337 | " var fig = this;\n", 338 | "\n", 339 | " var canvas_div = $('
');\n", 340 | "\n", 341 | " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", 342 | "\n", 343 | " function canvas_keyboard_event(event) {\n", 344 | " return fig.key_event(event, event['data']);\n", 345 | " }\n", 346 | "\n", 347 | " canvas_div.keydown('key_press', canvas_keyboard_event);\n", 348 | " canvas_div.keyup('key_release', canvas_keyboard_event);\n", 349 | " this.canvas_div = canvas_div\n", 350 | " this._canvas_extra_style(canvas_div)\n", 351 | " this.root.append(canvas_div);\n", 352 | "\n", 353 | " var canvas = $('');\n", 354 | " canvas.addClass('mpl-canvas');\n", 355 | " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", 356 | "\n", 357 | " this.canvas = canvas[0];\n", 358 | " this.context = canvas[0].getContext(\"2d\");\n", 359 | "\n", 360 | " var backingStore = this.context.backingStorePixelRatio ||\n", 361 | "\tthis.context.webkitBackingStorePixelRatio ||\n", 362 | "\tthis.context.mozBackingStorePixelRatio ||\n", 363 | "\tthis.context.msBackingStorePixelRatio ||\n", 364 | "\tthis.context.oBackingStorePixelRatio ||\n", 365 | "\tthis.context.backingStorePixelRatio || 1;\n", 366 | "\n", 367 | " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", 368 | "\n", 369 | " var rubberband = $('');\n", 370 | " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", 371 | "\n", 372 | " var pass_mouse_events = true;\n", 373 | "\n", 374 | " canvas_div.resizable({\n", 375 | " start: function(event, ui) {\n", 376 | " pass_mouse_events = false;\n", 377 | " },\n", 378 | " resize: function(event, ui) {\n", 379 | " fig.request_resize(ui.size.width, ui.size.height);\n", 380 | " },\n", 381 | " stop: function(event, ui) {\n", 382 | " pass_mouse_events = true;\n", 383 | " fig.request_resize(ui.size.width, ui.size.height);\n", 384 | " },\n", 385 | " });\n", 386 | "\n", 387 | " function mouse_event_fn(event) {\n", 388 | " if (pass_mouse_events)\n", 389 | " return fig.mouse_event(event, 
event['data']);\n", 390 | " }\n", 391 | "\n", 392 | " rubberband.mousedown('button_press', mouse_event_fn);\n", 393 | " rubberband.mouseup('button_release', mouse_event_fn);\n", 394 | " // Throttle sequential mouse events to 1 every 20ms.\n", 395 | " rubberband.mousemove('motion_notify', mouse_event_fn);\n", 396 | "\n", 397 | " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", 398 | " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", 399 | "\n", 400 | " canvas_div.on(\"wheel\", function (event) {\n", 401 | " event = event.originalEvent;\n", 402 | " event['data'] = 'scroll'\n", 403 | " if (event.deltaY < 0) {\n", 404 | " event.step = 1;\n", 405 | " } else {\n", 406 | " event.step = -1;\n", 407 | " }\n", 408 | " mouse_event_fn(event);\n", 409 | " });\n", 410 | "\n", 411 | " canvas_div.append(canvas);\n", 412 | " canvas_div.append(rubberband);\n", 413 | "\n", 414 | " this.rubberband = rubberband;\n", 415 | " this.rubberband_canvas = rubberband[0];\n", 416 | " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", 417 | " this.rubberband_context.strokeStyle = \"#000000\";\n", 418 | "\n", 419 | " this._resize_canvas = function(width, height) {\n", 420 | " // Keep the size of the canvas, canvas container, and rubber band\n", 421 | " // canvas in synch.\n", 422 | " canvas_div.css('width', width)\n", 423 | " canvas_div.css('height', height)\n", 424 | "\n", 425 | " canvas.attr('width', width * mpl.ratio);\n", 426 | " canvas.attr('height', height * mpl.ratio);\n", 427 | " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", 428 | "\n", 429 | " rubberband.attr('width', width);\n", 430 | " rubberband.attr('height', height);\n", 431 | " }\n", 432 | "\n", 433 | " // Set the figure to an initial 600x600px, this will subsequently be updated\n", 434 | " // upon first draw.\n", 435 | " this._resize_canvas(600, 600);\n", 436 | "\n", 437 | " // Disable right mouse context menu.\n", 438 | " 
$(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", 439 | " return false;\n", 440 | " });\n", 441 | "\n", 442 | " function set_focus () {\n", 443 | " canvas.focus();\n", 444 | " canvas_div.focus();\n", 445 | " }\n", 446 | "\n", 447 | " window.setTimeout(set_focus, 100);\n", 448 | "}\n", 449 | "\n", 450 | "mpl.figure.prototype._init_toolbar = function() {\n", 451 | " var fig = this;\n", 452 | "\n", 453 | " var nav_element = $('
');\n", 454 | " nav_element.attr('style', 'width: 100%');\n", 455 | " this.root.append(nav_element);\n", 456 | "\n", 457 | " // Define a callback function for later on.\n", 458 | " function toolbar_event(event) {\n", 459 | " return fig.toolbar_button_onclick(event['data']);\n", 460 | " }\n", 461 | " function toolbar_mouse_event(event) {\n", 462 | " return fig.toolbar_button_onmouseover(event['data']);\n", 463 | " }\n", 464 | "\n", 465 | " for(var toolbar_ind in mpl.toolbar_items) {\n", 466 | " var name = mpl.toolbar_items[toolbar_ind][0];\n", 467 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", 468 | " var image = mpl.toolbar_items[toolbar_ind][2];\n", 469 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", 470 | "\n", 471 | " if (!name) {\n", 472 | " // put a spacer in here.\n", 473 | " continue;\n", 474 | " }\n", 475 | " var button = $('