├── .gitignore ├── Pytorch Introduction.ipynb └── images ├── intro_pytorch_gradient_descent.PNG ├── intro_pytorch_mnist.PNG ├── intro_pytorch_mnist_operations.PNG ├── intro_pytorch_negativeloglikelihood.PNG └── pytorhc_vs_tensorflow.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 
| # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /Pytorch Introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Introduction to Pytorch\n", 8 | "\n", 9 | "" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 29, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "import torch\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "from mpl_toolkits.mplot3d import Axes3D" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Recall Gradient Descent\n", 29 | "\n", 30 | "- [Formation au Deep Learning](https://www.youtube.com/playlist?list=PLpEPgC7cUJ4b1ARx8PyIQa_sdZRL2GXw5)\n", 31 | "- [Formation Tensorflow 2.0](https://www.youtube.com/playlist?list=PLpEPgC7cUJ4byTM5kGA0Te1jUeNwbSgfd)\n", 32 | "- [DESCENTE DE GRADIENT (GRADIENT DESCENT) - ML#4](https://www.youtube.com/watch?v=rcl_YRyoLIY&t=3s)\n", 33 | "- [La descente de gradient (stochastique) | Intelligence artificielle 42](https://www.youtube.com/watch?v=Q9-vDFvDdfg)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### A simple function to minimize" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 32, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stderr", 50 | "output_type": "stream", 51 | "text": [ 52 | "C:\\Users\\Thibault\\Anaconda3\\envs\\ai-3.7\\lib\\site-packages\\ipykernel_launcher.py:13: UserWarning: The following kwargs were not used by contour: 'color'\n", 53 | " del sys.path[0]\n" 54 | ] 55 | }, 56 | { 57 | "data": { 58 | "image/png": "\n", 59 | "text/plain": [ 60 | "
" 61 | ] 62 | }, 63 | "metadata": { 64 | "needs_background": "light" 65 | }, 66 | "output_type": "display_data" 67 | } 68 | ], 69 | "source": [ 70 | "def f(x, y):\n", 71 | " return x ** 2 + y ** 2\n", 72 | "\n", 73 | "x = np.linspace(-6, 6, 30)\n", 74 | "y = np.linspace(-6, 6, 30)\n", 75 | "\n", 76 | "X, Y = np.meshgrid(x, y)\n", 77 | "\n", 78 | "Z = f(X, Y)\n", 79 | "\n", 80 | "fig = plt.figure()\n", 81 | "ax = plt.axes(projection='3d')\n", 82 | "ax.contour(X, Y, Z, 50, cmap='binary', color=\"r\")\n", 83 | "ax.set_xlabel('parma1')\n", 84 | "ax.set_ylabel('param2')\n", 85 | "ax.set_zlabel('error')\n", 86 | "\n", 87 | "ax.view_init(30, 30)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "### Create the two variables" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 37, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "x = torch.tensor(3, dtype=torch.float, requires_grad=True)\n", 104 | "y = torch.tensor(3, dtype=torch.float, requires_grad=True)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "### Minimize the function by applying gradient descent\n", 112 | "\n", 113 | "" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 50, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "Result = >> tensor(11.5200, grad_fn=)\n", 126 | "Result = >> tensor(7.3728, grad_fn=)\n", 127 | "Result = >> tensor(4.7186, grad_fn=)\n", 128 | "Result = >> tensor(3.0199, grad_fn=)\n", 129 | "Result = >> tensor(1.9327, grad_fn=)\n", 130 | "Result = >> tensor(1.2370, grad_fn=)\n", 131 | "Result = >> tensor(0.7916, grad_fn=)\n", 132 | "Result = >> tensor(0.5067, grad_fn=)\n", 133 | "Result = >> tensor(0.3243, grad_fn=)\n", 134 | "Result = >> tensor(0.2075, grad_fn=)\n", 135 | "Result = >> tensor(0.1328, grad_fn=)\n", 136 | "Result = >> tensor(0.0850, 
grad_fn=)\n", 137 | "Result = >> tensor(0.0544, grad_fn=)\n", 138 | "Result = >> tensor(0.0348, grad_fn=)\n", 139 | "Result = >> tensor(0.0223, grad_fn=)\n", 140 | "Result = >> tensor(0.0143, grad_fn=)\n", 141 | "Result = >> tensor(0.0091, grad_fn=)\n", 142 | "Result = >> tensor(0.0058, grad_fn=)\n", 143 | "Result = >> tensor(0.0037, grad_fn=)\n", 144 | "Result = >> tensor(0.0024, grad_fn=)\n", 145 | "Result = >> tensor(0.0015, grad_fn=)\n", 146 | "Result = >> tensor(0.0010, grad_fn=)\n", 147 | "Result = >> tensor(0.0006, grad_fn=)\n", 148 | "Result = >> tensor(0.0004, grad_fn=)\n", 149 | "Result = >> tensor(0.0003, grad_fn=)\n", 150 | "Result = >> tensor(0.0002, grad_fn=)\n", 151 | "Result = >> tensor(0.0001, grad_fn=)\n", 152 | "Result = >> tensor(6.7346e-05, grad_fn=)\n", 153 | "Result = >> tensor(4.3101e-05, grad_fn=)\n", 154 | "Result = >> tensor(2.7585e-05, grad_fn=)\n", 155 | "Result = >> tensor(1.7654e-05, grad_fn=)\n", 156 | "Result = >> tensor(1.1299e-05, grad_fn=)\n", 157 | "Result = >> tensor(7.2312e-06, grad_fn=)\n", 158 | "Result = >> tensor(4.6280e-06, grad_fn=)\n", 159 | "Result = >> tensor(2.9619e-06, grad_fn=)\n", 160 | "Result = >> tensor(1.8956e-06, grad_fn=)\n", 161 | "Result = >> tensor(1.2132e-06, grad_fn=)\n", 162 | "Result = >> tensor(7.7645e-07, grad_fn=)\n", 163 | "Result = >> tensor(4.9693e-07, grad_fn=)\n", 164 | "Result = >> tensor(3.1803e-07, grad_fn=)\n", 165 | "Result = >> tensor(2.0354e-07, grad_fn=)\n", 166 | "Result = >> tensor(1.3027e-07, grad_fn=)\n", 167 | "Result = >> tensor(8.3370e-08, grad_fn=)\n", 168 | "Result = >> tensor(5.3357e-08, grad_fn=)\n", 169 | "Result = >> tensor(3.4148e-08, grad_fn=)\n", 170 | "Result = >> tensor(2.1855e-08, grad_fn=)\n", 171 | "Result = >> tensor(1.3987e-08, grad_fn=)\n", 172 | "Result = >> tensor(8.9518e-09, grad_fn=)\n", 173 | "Result = >> tensor(5.7292e-09, grad_fn=)\n", 174 | "Result = >> tensor(3.6667e-09, grad_fn=)\n", 175 | "Result = >> tensor(2.3467e-09, grad_fn=)\n", 176 | "Result = 
>> tensor(1.5019e-09, grad_fn=)\n", 177 | "Result = >> tensor(9.6119e-10, grad_fn=)\n", 178 | "Result = >> tensor(6.1516e-10, grad_fn=)\n", 179 | "Result = >> tensor(3.9371e-10, grad_fn=)\n", 180 | "Result = >> tensor(2.5197e-10, grad_fn=)\n", 181 | "Result = >> tensor(1.6126e-10, grad_fn=)\n", 182 | "Result = >> tensor(1.0321e-10, grad_fn=)\n", 183 | "Result = >> tensor(6.6053e-11, grad_fn=)\n", 184 | "Result = >> tensor(4.2274e-11, grad_fn=)\n", 185 | "Result = >> tensor(2.7055e-11, grad_fn=)\n", 186 | "Result = >> tensor(1.7315e-11, grad_fn=)\n", 187 | "Result = >> tensor(1.1082e-11, grad_fn=)\n", 188 | "Result = >> tensor(7.0924e-12, grad_fn=)\n", 189 | "Result = >> tensor(4.5391e-12, grad_fn=)\n", 190 | "Result = >> tensor(2.9050e-12, grad_fn=)\n", 191 | "Result = >> tensor(1.8592e-12, grad_fn=)\n", 192 | "Result = >> tensor(1.1899e-12, grad_fn=)\n", 193 | "Result = >> tensor(7.6154e-13, grad_fn=)\n", 194 | "Result = >> tensor(4.8738e-13, grad_fn=)\n", 195 | "Result = >> tensor(3.1193e-13, grad_fn=)\n", 196 | "Result = >> tensor(1.9963e-13, grad_fn=)\n", 197 | "Result = >> tensor(1.2776e-13, grad_fn=)\n", 198 | "Result = >> tensor(8.1769e-14, grad_fn=)\n", 199 | "Result = >> tensor(5.2332e-14, grad_fn=)\n", 200 | "Result = >> tensor(3.3493e-14, grad_fn=)\n", 201 | "Result = >> tensor(2.1435e-14, grad_fn=)\n", 202 | "Result = >> tensor(1.3719e-14, grad_fn=)\n", 203 | "Result = >> tensor(8.7799e-15, grad_fn=)\n", 204 | "Result = >> tensor(5.6191e-15, grad_fn=)\n", 205 | "Result = >> tensor(3.5963e-15, grad_fn=)\n", 206 | "Result = >> tensor(2.3016e-15, grad_fn=)\n", 207 | "Result = >> tensor(1.4730e-15, grad_fn=)\n", 208 | "Result = >> tensor(9.4274e-16, grad_fn=)\n", 209 | "Result = >> tensor(6.0335e-16, grad_fn=)\n", 210 | "Result = >> tensor(3.8614e-16, grad_fn=)\n", 211 | "Result = >> tensor(2.4713e-16, grad_fn=)\n", 212 | "Result = >> tensor(1.5816e-16, grad_fn=)\n", 213 | "Result = >> tensor(1.0123e-16, grad_fn=)\n", 214 | "Result = >> tensor(6.4784e-17, 
grad_fn=)\n", 215 | "Result = >> tensor(4.1462e-17, grad_fn=)\n", 216 | "Result = >> tensor(2.6536e-17, grad_fn=)\n", 217 | "Result = >> tensor(1.6983e-17, grad_fn=)\n", 218 | "Result = >> tensor(1.0869e-17, grad_fn=)\n", 219 | "Result = >> tensor(6.9562e-18, grad_fn=)\n", 220 | "Result = >> tensor(4.4519e-18, grad_fn=)\n", 221 | "Result = >> tensor(2.8492e-18, grad_fn=)\n", 222 | "Result = >> tensor(1.8235e-18, grad_fn=)\n", 223 | "Result = >> tensor(1.1671e-18, grad_fn=)\n", 224 | "Result = >> tensor(7.4691e-19, grad_fn=)\n" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "for i in range(100):\n", 230 | " # Compute the operation. We want to minimise the result. To do so we can compute\n", 231 | " # the gradient of each variable and apply the gradient descent formula.\n", 232 | " result = x**2 + y**2\n", 233 | " print(\"Result = >>\", result)\n", 234 | " \n", 235 | " # Compute the gradient for each operations made before\n", 236 | " result.backward()\n", 237 | " \n", 238 | " # Apply gradient descent without tracking the gradient.\n", 239 | " with torch.no_grad():\n", 240 | " x -= 0.1*x.grad\n", 241 | " y -= 0.1*y.grad\n", 242 | " \n", 243 | " x.grad.zero_()\n", 244 | " y.grad.zero_()" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 51, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "name": "stdout", 254 | "output_type": "stream", 255 | "text": [ 256 | "tensor(4.8889e-10, requires_grad=True) tensor(4.8889e-10, requires_grad=True)\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "print(x, y)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "## Simple neural network on MNIST" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 52, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "from torchvision import datasets, transforms\n", 278 | "import torch.nn.functional as F" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | 
"metadata": {}, 284 | "source": [ 285 | "### Import the dataset\n", 286 | "\n", 287 | "" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 53, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "# Transform each image into tensor and normalized with mean and std\n", 297 | "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n", 298 | "# Define the batch size used each time we go through the dataset\n", 299 | "batch_size = 32\n", 300 | "\n", 301 | "# Set the training loader\n", 302 | "train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True, transform=transform), batch_size=batch_size, shuffle=True)\n", 303 | "# Set the testing loader\n", 304 | "test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, download=True, transform=transform), batch_size=batch_size, shuffle=True)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "### Init the weights" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 56, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "## Init weights\n", 321 | "# 784 because there is 784 pixels in each image\n", 322 | "# 10 because there is 10 possible outputs : 0,1,2,3,4,5,6,7,8,9\n", 323 | "# Each pixel is linked to 10 outputs where each link is a weight to optimize\\\n", 324 | "# <=> Each class is linked to 784 pixel where each link is a weight to optimize\n", 325 | "weights = torch.randn(784, 10, requires_grad=True)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "### Compute the accuracy on the test set\n", 333 | "\n", 334 | "" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 83, 340 | "metadata": {}, 341 | "outputs": [ 342 | { 343 | "name": "stdout", 344 | "output_type": "stream", 345 | "text": [ 346 | " Accuracy on 
test set 0.0816\n" 347 | ] 348 | } 349 | ], 350 | "source": [ 351 | "def test(weights, test_loader):\n", 352 | " test_size = len(test_loader.dataset)\n", 353 | " correct = 0\n", 354 | "\n", 355 | " for batch_idx, (data, target) in enumerate(test_loader):\n", 356 | " #print(batch_idx, data.shape, target.shape)\n", 357 | " data = data.view((-1, 28*28))\n", 358 | " #print(batch_idx, data.shape, target.shape)\n", 359 | "\n", 360 | " outputs = torch.matmul(data, weights)\n", 361 | " softmax = F.softmax(outputs, dim=1)\n", 362 | " pred = softmax.argmax(dim=1, keepdim=True)\n", 363 | " n_correct = pred.eq(target.view_as(pred)).sum().item()\n", 364 | " correct += n_correct\n", 365 | "\n", 366 | " acc = correct / test_size\n", 367 | " print(\" Accuracy on test set\", acc)\n", 368 | " return\n", 369 | "\n", 370 | "test(weights, test_loader)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "metadata": {}, 376 | "source": [ 377 | "### Train the model\n", 378 | "\n", 379 | "" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 98, 385 | "metadata": { 386 | "scrolled": true 387 | }, 388 | "outputs": [ 389 | { 390 | "name": "stdout", 391 | "output_type": "stream", 392 | "text": [ 393 | "Loss shape: 5.5470104217529375 Accuracy on test set 0.6676\n", 394 | "Loss shape: 3.7371554374694824 Accuracy on test set 0.7678\n", 395 | "Loss shape: 3.7003250122070312 Accuracy on test set 0.7969\n", 396 | "Loss shape: 3.63990139961242685 Accuracy on test set 0.825\n", 397 | "Loss shape: 2.156468629837036787 Accuracy on test set 0.8363\n", 398 | "Loss shape: 1.54896616935729985 Accuracy on test set 0.8345\n", 399 | "Loss shape: 2.52677869796752935 Accuracy on test set 0.842\n", 400 | "Loss shape: 2.875357151031494713 Accuracy on test set 0.8573\n", 401 | "Loss shape: 0.84983670711517333 Accuracy on test set 0.8513\n", 402 | "Loss shape: 0.92615586519241333 Accuracy on test set 0.8582\n", 403 | "Loss shape: 1.82365214824676517 Accuracy on test set 
0.8635\n", 404 | "Loss shape: 0.69127744436264046 Accuracy on test set 0.8644\n", 405 | "Loss shape: 1.83464241027832034 Accuracy on test set 0.8613\n", 406 | "Loss shape: 0.3489245176315307666 Accuracy on test set 0.8659\n", 407 | "Loss shape: 1.96130001544952487 Accuracy on test set 0.8691\n", 408 | "Loss shape: 0.52960449457168586 Accuracy on test set 0.8714\n", 409 | "Loss shape: 2.4448368549346924447 Accuracy on test set 0.8688\n", 410 | "Loss shape: 0.0516575723886489944 Accuracy on test set 0.8669\n", 411 | "Loss shape: 0.0055376142263412476" 412 | ] 413 | } 414 | ], 415 | "source": [ 416 | "it = 0\n", 417 | "for batch_idx, (data, targets) in enumerate(train_loader):\n", 418 | " # Make sure each iteration starts with zeroed gradients (backward() accumulates into .grad)\n", 419 | " if weights.grad is not None:\n", 420 | " weights.grad.zero_()\n", 421 | " \n", 422 | " data = data.view((-1, 28*28))\n", 423 | " #print(\"batch_idx: {}, data.shape: {}, target.shape: {}\".format(batch_idx, data.shape, targets.shape))\n", 424 | " outputs = torch.matmul(data, weights)\n", 425 | " #print(\"outputs.shape: {}\".format(outputs.shape))\n", 426 | "\n", 427 | " log_softmax = F.log_softmax(outputs, dim=1)\n", 428 | " #print(\"Log softmax: {}\".format(log_softmax.shape))\n", 429 | "\n", 430 | " #print((-log_softmax[0][targets[0]] + -log_softmax[1][targets[1]] ) / 2 )\n", 431 | " #print(-log_softmax[0][targets[0]], targets[0])\n", 432 | " \n", 433 | " loss = F.nll_loss(log_softmax, targets)\n", 434 | " print(\"\\rLoss shape: {}\".format(loss), end=\"\")\n", 435 | " \n", 436 | " # Compute the gradient of the loss with respect to the weights\n", 437 | " loss.backward()\n", 438 | " \n", 439 | " with torch.no_grad():\n", 440 | " weights -= 0.1*weights.grad\n", 441 | " \n", 442 | " it += 1\n", 443 | " if it % 100 == 0:\n", 444 | " test(weights, test_loader)\n", 445 | " \n", 446 | " if it > 5000:\n", 447 | " break" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 111, 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 |
"data": { 457 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAUkUlEQVR4nO3de5BcdZnG8e9DQoBKghAgIQmQaCAiWEuEENiC0iACAZdLrMLiHlwwKLJAlXIpdBUXUUIpaK0ua9ywJtxBQEKBCoVi1gtCZDEGAhKpQALZDDclkYiGvPvH+Q3VDN2nJ32Z7szv+VRNTfd5z+Xtnnn6nD6nTx9FBGY2+G3R6QbMbGA47GaZcNjNMuGwm2XCYTfLhMNulgmHvUmSJkoKSUPT/R9JmjUAy71U0vUNTPegpDPb0dPmRtLXJJ2fbp8u6U1J6yS9r5/T/1TSXyX9It0fI2mZpK3a2Xejsgi7pBWS1qc/5BpJ/y1pRDuWFRFHRsT8fvb0kXb00C0kTZe0qtN9VCNpJ+A04LsVg38dESMiYlka5wRJT0n6s6QeSfMlbds7ckR8GPhUxf01wM+A2QPzKDZNFmFPjo6IEcC+wP7AF/qOoEJOz0l2erfAgNOBeyNifcnovwQOioh3Ae8BhgJfqbOIG4Czmu2zHbL7x46I54EfAe+HtzZrL5f0S+B14D2S3iVpnqTVkp6X9BVJQ9L4QyR9XdJLkp4BPlo5/76byZI+mTbt1kp6QtK+kq4DdgPuTlsbF6ZxD5T0K0l/kvQ7SdMr5vNuST9P87kf2LHscUo6VtJjkl6T9EdJM6qMMyltir6cHs8NkrarqF+UHv/atIY7NA2fJmlxmvcaSVdVmffw9DyPS49xnaRxkraQdHHq6WVJt0oalabpfUs0S9JzqafPV8yz5nIlHSPp8fTcPVi5KZ62oi6StAT4Swr8kcDPy57DiFgZES9VDHoT2L1sGuA3FP9DE+qMN/AiYtD/ACuAj6TbuwKPA5el+w8CzwF7U7xybwn8kGLzbjgwGngYOCuN/yngyTSfURSbbQEMrZjfmen28cDzFFsSovhHmdC3p3R/PPAycBTFi/Bh6f5Oqf5r4CpgK+CDwFrg+hqPdxrw5zSPLdK896zS3+5pnK2AnYBFwDdT7b3ASmBcuj8RmFTRy6np9gjgwBp9TAdW9Rl2PvAQsEta7neBmyqWEcD3gG2AfYA3gPeVLReYDPwlPZYtgQuB5cCwiuf6sfQ32yYNexHYv6Kv04FfVHkMB6fnMtIyDu9Tf8d0wBLgmE7/37/jsXS6gQF5kMUfex3wJ+BZ4D8q/ugPAv9WMe6Y9A+2TcWwE4Gfpds/BT5VUTuc2mH/CXBeSU+VYb8IuK7POD8BZlFsBWwAhlfUbqR22L8LXF2j9lZ/VWrHAf+bbu8O9AAfAbbsM94i4MvAjnWe92phXwYcWnF/LPB3ihfa3rDvUlF/GDihbLnAvwK3VtzfguJFdnrFc/3Pfab5O+kFMN2vGvaK+njgUmByn+HVwv5L4LRO/9/3/clpM/64iNguIiZExNnx9vdqKytuT6BYO6xOm4R/ogjP6FQf12f8Z0uWuSvwx372NwE4vneZabkHU4RhHPBqRPyllcuVNFrSzWlT/TXgetLbg4hYTrEWvhToSeONS5OeQbE2fVLSI5L+qZ+Psfdx3lnxGJdRbB6PqRjn/ypuv06xFi9b7jgqno+I2EjxNxpfMZ/KvxnAq8DI/jYdxdu/HwM392P0kRQrlq6SU9jLVJ76t5Jizb5jenHYLiK2jYi9U301RZh67VYy35XApH4ss3fc6yqWuV1EDI+IK9Iyt0/vg5tdbqWvpT7+ISK2BU6heLtRNBhxY0QcTBHQAOak4U9HxIkUL4BzgB/06a3WY+zt7cg+j3PrFKZSJct9IfUIFDtaKf5GlfPs28sSiheOTTGUOs9r2h+wO/
C7TZx32znsfUTEauA+4BuStk07lCZJ+lAa5VbgXEm7SNoeuLhkdv8FfE7SfmlP/+4VO27WUOzh7XU9cLSkI9JOwK1VHLraJSKeBRYDX5Y0TNLBwNEly50HfELSoan/8ZL2rDLeSNLbG0njgQt6C5LeK+nDKo4Z/xVYT7EGRtIpknZKa9DeNdibVea/BthB0rsqhv0ncHnv8yBpJ0nHljyWt5Qs91bgo+nxbgl8luIF+1cls7sX+FBJHUknS9ot/e0mAJcDD9RpcxqwIv3NuorDXt1pwDDgCYrNvR9QbE5DsfPoJxSv3I8Cd9SaSUTcRvEPciPFDrUfUuzUg2Kt+oW0Ofu5iFgJHAtcQrHzaCVF+Hr/RicBBwCvAF8CFpQs92HgE8DVFDuXfk7Fmq/ClykORf4ZuKfPY9kKuAJ4iWKzenTqDWAG8LikdcC3KN5T/7VKH08CNwHPpMc5Lo2/ELhP0lqKnXUH1HosfVRdbkQ8RbFV8u+p36MpDrX+rWReC4CjJG1TMs5eFC8Y6yjehz8FfLJOjydTvKB1HaUdCmbZkfRVoCcivinpVIp9M38D/jHSB2vqTH8/cCDwcEQcKmk0xQvrB6q9+HWaw26WCW/Gm2XCYTfLhMNulomh9UdpHUneQWDWZhGhasObWrNLmpFOkFguqex4s5l1WMN741WcBfYHipMPVgGPACdGxBMl03jNbtZm7VizTwOWR8Qz6cMLN1N8KMTMulAzYR/P208uWMXbTzwAQNLsdA7y4iaWZWZNamYHXbVNhXdspkfEXGAueDPerJOaWbOv4u1nf+1CcfaRmXWhZsL+CLCHiq9LGgacQHGCg5l1oYY34yNig6RzKM4AGwJcGxGPt6wzM2upAT0Rxu/ZzdqvLR+qMbPNh8NulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZaPj67ACSVgBrgTeBDRExtRVNmVnrNRX25JCIeKkF8zGzNvJmvFkmmg17APdJ+q2k2dVGkDRb0mJJi5tclpk1QRHR+MTSuIh4QdJo4H7gXyJiUcn4jS/MzPolIlRteFNr9oh4If3uAe4EpjUzPzNrn4bDLmm4pJG9t4HDgaWtaszMWquZvfFjgDsl9c7nxoj4cUu66oBRo0aV1idPnlyzdtJJJ7W6nbfZY489SutHHHFEzVr6+zTsoYceKq0vXLiwtH777bfXrC1fvrx02o0bN5bWbdM0HPaIeAbYp4W9mFkb+dCbWSYcdrNMOOxmmXDYzTLhsJtloqlP0G3ywtr4CboRI0aU1qdPn15a/+IXv1ha32+//Ta1Jauj3mG7s88+u7S+evXqVrYzaLTlE3Rmtvlw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmBs1x9nvuuae0PmPGjHYt2trk7rvvLq3PnDmztD6Q/9vdxMfZzTLnsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMDJrj7PW+driTx1zrLfuJJ55oav7z5s2rWVu/fn1T865n3LhxpfULLrigZm3rrbduatnnnXdeaf3b3/52U/PfXPk4u1nmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WiWYu2ZyVtWvX1qydc845pdM+99xzpfVFixY11NPmYMKECTVrp512WlPz3muvvZqaPjd11+ySrpXUI2lpxbBRku6X9HT6vX172zSzZvVnM/77QN+vebkYeCAi9gAeSPfNrIvVDXtELAJe6TP4WGB+uj0fOK7FfZlZizX6nn1MRKwGiIjVkkbXGlHSbGB2g8sxsxZp+w66iJgLzIX2nghjZuUaPfS2RtJYgPS7p3UtmVk7NBr2hcCsdHsWcFdr2jGzdqm7GS/pJmA6sKOkVcCXgCuAWyWdATwHHN/OJvtjw4
YNpfUhQ4aU1l9//fXS+uTJk2vWenry3bA56KCDSusnnXTSAHVi9dQNe0ScWKN0aIt7MbM28sdlzTLhsJtlwmE3y4TDbpYJh90sE4PmFNdjjjmmtH7mmWeW1q+88srSeq6H1yZNmlRaX7BgQWl96ND2/Yu9+OKLbZv3YOQ1u1kmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WiUFzyWZrzGGHHVZa/853vlNar3ccvhm33HJLab3eZyfqnbY8WPmSzWaZc9jNMuGwm2XCYTfLhMNulgmH3SwTDrtZJnycfTMwbNiw0vq5555bs1bvPP8DDjigtN7O89EXLlxYWj/hhBNK62+88UYr2xk0fJzdLHMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8vEoPne+MFs5syZpfU5c+YMUCettdtuu5XWP/axj5XW77rrrtJ6ruez11J3zS7pWkk9kpZWDLtU0vOSHks/R7W3TTNrVn82478PzKgy/OqImJJ+7m1tW2bWanXDHhGLgFcGoBcza6NmdtCdI2lJ2szfvtZIkmZLWixpcRPLMrMmNRr2a4BJwBRgNfCNWiNGxNyImBoRUxtclpm1QENhj4g1EfFmRGwEvgdMa21bZtZqDYVd0tiKuzOBpbXGNbPuUPc4u6SbgOnAjpJWAV8CpkuaAgSwAjirjT3aIDVlypTS+vXXX19av/nmm0vrl112Wc3ak08+WTrtYFQ37BFxYpXB89rQi5m1kT8ua5YJh90sEw67WSYcdrNMOOxmmfBXSW8G9t5779L6V7/61Zq1el8VXc/cuXNL6zvssENp/eSTT65ZGzlyZEM99VdPT0/N2vTp00unfeqpp1rczcDxV0mbZc5hN8uEw26WCYfdLBMOu1kmHHazTDjsZpnwcXZrq4kTJ9asHXLIIaXTXnnllaX1UaNGNdISADfeeGNpffbs2aX19evXN7zsdvNxdrPMOexmmXDYzTLhsJtlwmE3y4TDbpYJh90sEz7Obl2r7KugAS655JK2LbvZy0V3ko+zm2XOYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZqHucXdKuwAJgZ2AjMDciviVpFHALMJHiss0fj4hX68wry+PsW2xR/pp66qmnltb33HPP0vrll19es7Zu3brSabvZkCFDSutLliwprdd73so8+uijpfX999+/4Xm3WzPH2TcAn42I9wEHAp+RtBdwMfBAROwBPJDum1mXqhv2iFgdEY+m22uBZcB44FhgfhptPnBcu5o0s+Zt0nt2SROBDwC/AcZExGooXhCA0a1uzsxaZ2h/R5Q0ArgdOD8iXpOqvi2oNt1soPwLvcys7fq1Zpe0JUXQb4iIO9LgNZLGpvpYoOpV9CJibkRMjYiprWjYzBpTN+wqVuHzgGURcVVFaSEwK92eBXTvaUBm1q/N+IOAU4HfS3osDbsEuAK4VdIZwHPA8e1pcfO31VZbldbLLrkMsPPOO5fWyw6vlR2W63aHH354aX38+PFtW/a+++7btnl3St2wR8QvgFpv0A9tbTtm1i7+BJ1ZJhx2s0w47GaZcNjNMuGwm2XCYTfLhL9KugtMmTKltF7v8sGnnHJKzdqzzz5bOu1tt91WWp8zZ05p/Y033iitlxkxYkRpfdGiRaX1ffbZp+Fl17N06dKOLbtZ/ipps8w57GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTPs4+CMybN69m7fjjy79mYPjw4aX1DRs2NNRTf9T7arN6XyVdT1nv11xzTem09S4X/fLLLzfU00DwcXazzDnsZplw2M0y4bCbZcJhN8uEw26WCYfdLBM+zj7IXXjhhaX1Qw45pKn51zuve8yYMU3Nvxll58M3+7i7mY+zm2XOYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZqHucXdKuwAJgZ2AjMDciviXpUuCTwItp1Esi4t468/Jx9kFm7NixpfWy8+U//elPl067Zs2a0vodd9xRWi875/zVV18tnX
ZzVus4e93rswMbgM9GxKOSRgK/lXR/ql0dEV9vVZNm1j51wx4Rq4HV6fZaScuA8e1uzMxaa5Pes0uaCHwA+E0adI6kJZKulbR9jWlmS1osaXFTnZpZU/oddkkjgNuB8yPiNeAaYBIwhWLN/41q00XE3IiYGhFTW9CvmTWoX2GXtCVF0G+IiDsAImJNRLwZERuB7wHT2temmTWrbthVfAXoPGBZRFxVMbxyN+xMoPyyl2bWUf059HYw8D/A7ykOvQFcApxIsQkfwArgrLQzr2xePvRm1ma1Dr35fHazQcbns5tlzmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNM9OfbZVvpJeDZivs7pmHdqFt769a+wL01qpW9TahVGNDz2d+xcGlxt343Xbf21q19gXtr1ED15s14s0w47GaZ6HTY53Z4+WW6tbdu7QvcW6MGpLeOvmc3s4HT6TW7mQ0Qh90sEx0Ju6QZkp6StFzSxZ3ooRZJKyT9XtJjnb4+XbqGXo+kpRXDRkm6X9LT6XfVa+x1qLdLJT2fnrvHJB3Vod52lfQzScskPS7pvDS8o89dSV8D8rwN+Ht2SUOAPwCHAauAR4ATI+KJAW2kBkkrgKkR0fEPYEj6ILAOWBAR70/DrgReiYgr0gvl9hFxUZf0dimwrtOX8U5XKxpbeZlx4DjgdDr43JX09XEG4HnrxJp9GrA8Ip6JiL8BNwPHdqCPrhcRi4BX+gw+Fpifbs+n+GcZcDV66woRsToiHk231wK9lxnv6HNX0teA6ETYxwMrK+6voruu9x7AfZJ+K2l2p5upYkzvZbbS79Ed7qevupfxHkh9LjPeNc9dI5c/b1Ynwl7t0jTddPzvoIjYFzgS+EzaXLX+6ddlvAdKlcuMd4VGL3/erE6EfRWwa8X9XYAXOtBHVRHxQvrdA9xJ912Kek3vFXTT754O9/OWbrqMd7XLjNMFz10nL3/eibA/Auwh6d2ShgEnAAs70Mc7SBqedpwgaThwON13KeqFwKx0exZwVwd7eZtuuYx3rcuM0+HnruOXP4+IAf8BjqLYI/9H4POd6KFGX+8Bfpd+Hu90b8BNFJt1f6fYIjoD2AF4AHg6/R7VRb1dR3Fp7yUUwRrbod4OpnhruAR4LP0c1ennrqSvAXne/HFZs0z4E3RmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSb+H/5NMsfTFob+AAAAAElFTkSuQmCC\n", 458 | "text/plain": [ 459 | "
" 460 | ] 461 | }, 462 | "metadata": { 463 | "needs_background": "light" 464 | }, 465 | "output_type": "display_data" 466 | } 467 | ], 468 | "source": [ 469 | "import matplotlib.pyplot as plt\n", 470 | "\n", 471 | "batch_idx, (data, target) = next(enumerate(test_loader))\n", 472 | "data = data.view((-1, 28*28))\n", 473 | "\n", 474 | "outputs = torch.matmul(data, weights)\n", 475 | "softmax = F.softmax(outputs, dim=1)\n", 476 | "pred = softmax.argmax(dim=1, keepdim=True)\n", 477 | "\n", 478 | "plt.imshow(data[0].view(28, 28), cmap=\"gray\")\n", 479 | "plt.title(\"Predicted class {}\".format(pred[0]))\n", 480 | "plt.show()" 481 | ] 482 | } 483 | ], 484 | "metadata": { 485 | "kernelspec": { 486 | "display_name": "Python 3", 487 | "language": "python", 488 | "name": "python3" 489 | }, 490 | "language_info": { 491 | "codemirror_mode": { 492 | "name": "ipython", 493 | "version": 3 494 | }, 495 | "file_extension": ".py", 496 | "mimetype": "text/x-python", 497 | "name": "python", 498 | "nbconvert_exporter": "python", 499 | "pygments_lexer": "ipython3", 500 | "version": "3.7.3" 501 | } 502 | }, 503 | "nbformat": 4, 504 | "nbformat_minor": 2 505 | } 506 | -------------------------------------------------------------------------------- /images/intro_pytorch_gradient_descent.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thibo73800/pytorch_nlp/0ce992d674de7e4a133fb033d8ba69ac0b386762/images/intro_pytorch_gradient_descent.PNG -------------------------------------------------------------------------------- /images/intro_pytorch_mnist.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thibo73800/pytorch_nlp/0ce992d674de7e4a133fb033d8ba69ac0b386762/images/intro_pytorch_mnist.PNG -------------------------------------------------------------------------------- /images/intro_pytorch_mnist_operations.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thibo73800/pytorch_nlp/0ce992d674de7e4a133fb033d8ba69ac0b386762/images/intro_pytorch_mnist_operations.PNG -------------------------------------------------------------------------------- /images/intro_pytorch_negativeloglikelihood.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thibo73800/pytorch_nlp/0ce992d674de7e4a133fb033d8ba69ac0b386762/images/intro_pytorch_negativeloglikelihood.PNG -------------------------------------------------------------------------------- /images/pytorhc_vs_tensorflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thibo73800/pytorch_nlp/0ce992d674de7e4a133fb033d8ba69ac0b386762/images/pytorhc_vs_tensorflow.png --------------------------------------------------------------------------------