├── README.md ├── Pytorch-3-Neural Nets.ipynb ├── Pytorch-1-Introduction-to-tensors.ipynb ├── Pytorch-5-Training-Loop.ipynb └── Pytorch-2-Datasets.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Tutorials 2 | 3 | ## Videos 4 | 5 | All the videos are recorded in Persian and available on [YouTube](https://www.youtube.com/watch?v=MCxIN9Ujlx0&list=PLNlYDjW0sqoWrRWNH0aMVrqyOqlVTzBTj). 6 | 7 | ## Notebooks 8 | 9 | 1. [Introduction to Tensors](https://github.com/AINT-TV/pytorch/blob/main/Pytorch-1-Introduction-to-tensors.ipynb) 10 | 2. [Datasets](https://github.com/AINT-TV/pytorch/blob/main/Pytorch-2-Datasets.ipynb) 11 | 3. [Neural Net Intro.](https://github.com/AINT-TV/pytorch/blob/main/Pytorch-3-Neural%20Nets.ipynb) 12 | 4. [Automatic Differentiation](https://github.com/AINT-TV/pytorch/blob/main/Pytorch_4_AutoGrad.ipynb) 13 | 5. [The Training Loop](https://github.com/AINT-TV/pytorch/blob/main/Pytorch-5-Training-Loop.ipynb) 14 | 15 | More notebooks will be added soon! 
16 | -------------------------------------------------------------------------------- /Pytorch-3-Neural Nets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2e57d0c7", 6 | "metadata": {}, 7 | "source": [ 8 | "# Importing the Necessary Libraries" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 52, 14 | "id": "eda2431b", 15 | "metadata": { 16 | "ExecuteTime": { 17 | "end_time": "2022-05-27T14:26:46.850430Z", 18 | "start_time": "2022-05-27T14:26:46.737013Z" 19 | } 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "import torch" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "bd8eb6c1", 29 | "metadata": {}, 30 | "source": [ 31 | "# Getting a device to work with" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 53, 37 | "id": "b214b511", 38 | "metadata": { 39 | "ExecuteTime": { 40 | "end_time": "2022-05-27T14:27:31.756293Z", 41 | "start_time": "2022-05-27T14:27:31.722395Z" 42 | } 43 | }, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "'cpu'" 49 | ] 50 | }, 51 | "execution_count": 53, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 58 | "device" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "id": "a1afe545", 64 | "metadata": {}, 65 | "source": [ 66 | "# A Simple Neural Net" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 54, 72 | "id": "ba4cadf6", 73 | "metadata": { 74 | "ExecuteTime": { 75 | "end_time": "2022-05-27T14:32:32.893333Z", 76 | "start_time": "2022-05-27T14:32:32.865121Z" 77 | } 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "class my_neural_net(torch.nn.Module):\n", 82 | " def __init__(self):\n", 83 | " # first of all, initialize the original class's object\n", 84 | " super(my_neural_net, self).__init__() \n", 85 | " # second, 
define layers of the NN\n", 86 | " self.first_layer = torch.nn.Sequential( \n", 87 | " torch.nn.Linear(10000, 100), # this layer first maps a 10000-long input to a 100-d vector\n", 88 | " torch.nn.Softmax(dim=1) # and then applies the softmax on the 100-d vector\n", 89 | " )\n", 90 | " self.flatten = torch.nn.Flatten()\n", 91 | " def forward(self, x):\n", 92 | " # here, we define how we want our neural network to operate on a given input\n", 93 | " x = self.flatten(x)\n", 94 | " output = self.first_layer(x)\n", 95 | " return output" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "c54242a9", 101 | "metadata": {}, 102 | "source": [ 103 | "Make an object of that class" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 55, 109 | "id": "7477d194", 110 | "metadata": { 111 | "ExecuteTime": { 112 | "end_time": "2022-05-27T14:32:36.150878Z", 113 | "start_time": "2022-05-27T14:32:36.045299Z" 114 | } 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "simple_nn = my_neural_net()" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "41288f10", 124 | "metadata": {}, 125 | "source": [ 126 | "Move the object to the device" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 56, 132 | "id": "3f87c9d4", 133 | "metadata": { 134 | "ExecuteTime": { 135 | "end_time": "2022-05-27T14:33:05.197634Z", 136 | "start_time": "2022-05-27T14:33:05.161537Z" 137 | } 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "simple_nn = simple_nn.to(device)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 57, 147 | "id": "4439fcd3", 148 | "metadata": { 149 | "ExecuteTime": { 150 | "end_time": "2022-05-27T14:33:12.759337Z", 151 | "start_time": "2022-05-27T14:33:12.735981Z" 152 | } 153 | }, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "text/plain": [ 158 | "my_neural_net(\n", 159 | " (first_layer): Sequential(\n", 160 | " (0): Linear(in_features=10000, out_features=100, bias=True)\n", 
161 | " (1): Softmax(dim=1)\n", 162 | " )\n", 163 | " (flatten): Flatten(start_dim=1, end_dim=-1)\n", 164 | ")" 165 | ] 166 | }, 167 | "execution_count": 57, 168 | "metadata": {}, 169 | "output_type": "execute_result" 170 | } 171 | ], 172 | "source": [ 173 | "simple_nn" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "091e6f10", 179 | "metadata": {}, 180 | "source": [ 181 | "## Exploring the Flatten Layer" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 58, 187 | "id": "2fdba93c", 188 | "metadata": { 189 | "ExecuteTime": { 190 | "end_time": "2022-05-27T14:33:36.253989Z", 191 | "start_time": "2022-05-27T14:33:36.232322Z" 192 | } 193 | }, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "text/plain": [ 198 | "" 199 | ] 200 | }, 201 | "execution_count": 58, 202 | "metadata": {}, 203 | "output_type": "execute_result" 204 | } 205 | ], 206 | "source": [ 207 | "simple_nn.flatten.parameters" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "id": "0f577dfc", 213 | "metadata": {}, 214 | "source": [ 215 | "## Exploring the First Layer" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 59, 221 | "id": "cc6ce9d3", 222 | "metadata": { 223 | "ExecuteTime": { 224 | "end_time": "2022-05-27T14:33:48.915474Z", 225 | "start_time": "2022-05-27T14:33:48.893975Z" 226 | } 227 | }, 228 | "outputs": [ 229 | { 230 | "data": { 231 | "text/plain": [ 232 | "" 236 | ] 237 | }, 238 | "execution_count": 59, 239 | "metadata": {}, 240 | "output_type": "execute_result" 241 | } 242 | ], 243 | "source": [ 244 | "simple_nn.first_layer.parameters" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 60, 250 | "id": "2e164dad", 251 | "metadata": { 252 | "ExecuteTime": { 253 | "end_time": "2022-05-27T14:34:19.311411Z", 254 | "start_time": "2022-05-27T14:34:19.167189Z" 255 | } 256 | }, 257 | "outputs": [ 258 | { 259 | "data": { 260 | "text/plain": [ 261 | "Parameter containing:\n", 262 | 
"tensor([[-0.0011, 0.0041, 0.0034, ..., 0.0078, -0.0096, -0.0040],\n", 263 | " [-0.0090, 0.0009, -0.0025, ..., -0.0088, 0.0071, -0.0042],\n", 264 | " [ 0.0058, -0.0024, 0.0058, ..., 0.0085, -0.0051, 0.0041],\n", 265 | " ...,\n", 266 | " [ 0.0022, -0.0050, -0.0042, ..., -0.0089, -0.0062, 0.0081],\n", 267 | " [ 0.0003, -0.0025, -0.0011, ..., 0.0094, -0.0006, 0.0034],\n", 268 | " [-0.0009, 0.0035, 0.0073, ..., 0.0039, 0.0008, 0.0028]],\n", 269 | " requires_grad=True)" 270 | ] 271 | }, 272 | "execution_count": 60, 273 | "metadata": {}, 274 | "output_type": "execute_result" 275 | } 276 | ], 277 | "source": [ 278 | "simple_nn.first_layer[0].weight # weights of the linear layer" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 61, 284 | "id": "55304de4", 285 | "metadata": { 286 | "ExecuteTime": { 287 | "end_time": "2022-05-27T14:34:22.251731Z", 288 | "start_time": "2022-05-27T14:34:22.232984Z" 289 | } 290 | }, 291 | "outputs": [ 292 | { 293 | "data": { 294 | "text/plain": [ 295 | "torch.Size([100, 10000])" 296 | ] 297 | }, 298 | "execution_count": 61, 299 | "metadata": {}, 300 | "output_type": "execute_result" 301 | } 302 | ], 303 | "source": [ 304 | "simple_nn.first_layer[0].weight.shape" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 62, 310 | "id": "d64d221e", 311 | "metadata": { 312 | "ExecuteTime": { 313 | "end_time": "2022-05-27T14:34:49.590060Z", 314 | "start_time": "2022-05-27T14:34:49.563129Z" 315 | } 316 | }, 317 | "outputs": [ 318 | { 319 | "data": { 320 | "text/plain": [ 321 | "Parameter containing:\n", 322 | "tensor([-1.3796e-03, 3.2914e-03, 3.9477e-03, -5.9671e-03, -4.8569e-03,\n", 323 | " -6.1478e-03, -1.2100e-04, -4.0389e-03, -2.3757e-03, 3.1359e-04,\n", 324 | " -5.5810e-03, -2.0566e-03, -8.5154e-03, 9.4019e-03, -7.9394e-03,\n", 325 | " -5.3935e-03, 2.7543e-04, 9.4502e-03, 8.5511e-03, -9.9402e-03,\n", 326 | " 4.0279e-03, 9.2692e-03, 4.6177e-03, -1.7801e-03, -8.2470e-03,\n", 327 | " 1.8742e-05, 
4.5681e-03, 4.6123e-04, -2.1608e-03, -7.9991e-03,\n", 328 | " -6.1995e-03, -4.5583e-03, -2.0524e-03, -6.1808e-03, 5.2230e-03,\n", 329 | " -8.2621e-03, -4.7253e-03, 8.3447e-03, -5.7234e-03, 9.3964e-03,\n", 330 | " -9.2948e-03, 1.0814e-03, -3.5466e-03, -4.9322e-03, 9.5777e-03,\n", 331 | " -3.4608e-03, 9.2246e-03, -4.7672e-03, 3.4156e-03, 1.3527e-03,\n", 332 | " 2.6197e-03, 3.7030e-03, -9.8795e-03, 8.8686e-03, -8.4393e-03,\n", 333 | " 2.8033e-03, 1.3842e-04, -2.9049e-03, -2.5941e-03, -9.3400e-03,\n", 334 | " 3.0445e-03, 9.1432e-03, 7.4289e-04, 7.9431e-03, 9.7155e-03,\n", 335 | " -7.7544e-03, -5.7786e-03, -7.6822e-03, 6.0548e-03, -7.3356e-03,\n", 336 | " 7.6280e-03, 9.1774e-03, -7.5331e-03, -9.0172e-03, -5.5731e-03,\n", 337 | " 4.7288e-03, -4.0011e-03, 7.2871e-03, -8.9371e-03, -9.7504e-03,\n", 338 | " -5.4372e-03, -7.0768e-03, -6.3307e-03, -6.2958e-03, -5.8411e-03,\n", 339 | " 3.1092e-03, 6.5856e-03, -7.3472e-03, 1.1687e-03, 6.9618e-03,\n", 340 | " 7.4277e-04, 5.5266e-03, -7.3987e-03, 8.7863e-03, 3.8270e-05,\n", 341 | " -4.7482e-03, -4.5880e-03, -9.2353e-03, 1.5453e-03, -6.9686e-03],\n", 342 | " requires_grad=True)" 343 | ] 344 | }, 345 | "execution_count": 62, 346 | "metadata": {}, 347 | "output_type": "execute_result" 348 | } 349 | ], 350 | "source": [ 351 | "simple_nn.first_layer[0].bias # biases of the linear layer" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 63, 357 | "id": "6cf4dcd6", 358 | "metadata": { 359 | "ExecuteTime": { 360 | "end_time": "2022-05-27T14:35:00.486075Z", 361 | "start_time": "2022-05-27T14:35:00.463769Z" 362 | } 363 | }, 364 | "outputs": [ 365 | { 366 | "data": { 367 | "text/plain": [ 368 | "torch.Size([100])" 369 | ] 370 | }, 371 | "execution_count": 63, 372 | "metadata": {}, 373 | "output_type": "execute_result" 374 | } 375 | ], 376 | "source": [ 377 | "simple_nn.first_layer[0].bias.shape" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 64, 383 | "id": "dda3255f", 384 | 
"metadata": { 385 | "ExecuteTime": { 386 | "end_time": "2022-05-27T14:35:14.256129Z", 387 | "start_time": "2022-05-27T14:35:14.222219Z" 388 | } 389 | }, 390 | "outputs": [ 391 | { 392 | "data": { 393 | "text/plain": [ 394 | "Softmax(dim=1)" 395 | ] 396 | }, 397 | "execution_count": 64, 398 | "metadata": {}, 399 | "output_type": "execute_result" 400 | } 401 | ], 402 | "source": [ 403 | "simple_nn.first_layer[1]" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "id": "17715978", 409 | "metadata": {}, 410 | "source": [ 411 | "# Testing the Model" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 65, 417 | "id": "1f96b6df", 418 | "metadata": { 419 | "ExecuteTime": { 420 | "end_time": "2022-05-27T14:35:33.232183Z", 421 | "start_time": "2022-05-27T14:35:33.192717Z" 422 | } 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "sample_input = torch.rand(1,100,100)" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 66, 432 | "id": "975ec027", 433 | "metadata": { 434 | "ExecuteTime": { 435 | "end_time": "2022-05-27T14:35:38.460595Z", 436 | "start_time": "2022-05-27T14:35:38.452435Z" 437 | } 438 | }, 439 | "outputs": [ 440 | { 441 | "data": { 442 | "text/plain": [ 443 | "torch.Size([1, 100, 100])" 444 | ] 445 | }, 446 | "execution_count": 66, 447 | "metadata": {}, 448 | "output_type": "execute_result" 449 | } 450 | ], 451 | "source": [ 452 | "sample_input.shape" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 68, 458 | "id": "ecff0780", 459 | "metadata": { 460 | "ExecuteTime": { 461 | "end_time": "2022-05-27T14:35:47.453455Z", 462 | "start_time": "2022-05-27T14:35:47.437416Z" 463 | } 464 | }, 465 | "outputs": [ 466 | { 467 | "data": { 468 | "text/plain": [ 469 | "tensor([[0.0046, 0.0100, 0.0096, 0.0064, 0.0126, 0.0100, 0.0185, 0.0066, 0.0085,\n", 470 | " 0.0089, 0.0053, 0.0076, 0.0111, 0.0177, 0.0065, 0.0250, 0.0116, 0.0110,\n", 471 | " 0.0061, 0.0096, 0.0083, 0.0120, 0.0083, 
0.0117, 0.0116, 0.0131, 0.0118,\n", 472 | " 0.0027, 0.0072, 0.0191, 0.0071, 0.0127, 0.0101, 0.0100, 0.0118, 0.0108,\n", 473 | " 0.0087, 0.0112, 0.0131, 0.0107, 0.0079, 0.0091, 0.0090, 0.0085, 0.0091,\n", 474 | " 0.0106, 0.0112, 0.0103, 0.0106, 0.0118, 0.0085, 0.0073, 0.0072, 0.0056,\n", 475 | " 0.0072, 0.0097, 0.0160, 0.0075, 0.0131, 0.0082, 0.0071, 0.0091, 0.0057,\n", 476 | " 0.0046, 0.0103, 0.0103, 0.0066, 0.0067, 0.0063, 0.0094, 0.0039, 0.0104,\n", 477 | " 0.0070, 0.0029, 0.0083, 0.0074, 0.0131, 0.0137, 0.0138, 0.0087, 0.0061,\n", 478 | " 0.0096, 0.0083, 0.0164, 0.0186, 0.0059, 0.0119, 0.0150, 0.0090, 0.0077,\n", 479 | " 0.0093, 0.0239, 0.0153, 0.0109, 0.0107, 0.0100, 0.0106, 0.0073, 0.0125,\n", 480 | " 0.0080]], grad_fn=)" 481 | ] 482 | }, 483 | "execution_count": 68, 484 | "metadata": {}, 485 | "output_type": "execute_result" 486 | } 487 | ], 488 | "source": [ 489 | "simple_nn(sample_input)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 69, 495 | "id": "68e46e8a", 496 | "metadata": { 497 | "ExecuteTime": { 498 | "end_time": "2022-05-27T14:35:49.243090Z", 499 | "start_time": "2022-05-27T14:35:49.227907Z" 500 | } 501 | }, 502 | "outputs": [ 503 | { 504 | "data": { 505 | "text/plain": [ 506 | "torch.Size([1, 100])" 507 | ] 508 | }, 509 | "execution_count": 69, 510 | "metadata": {}, 511 | "output_type": "execute_result" 512 | } 513 | ], 514 | "source": [ 515 | "simple_nn(sample_input).shape" 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": 70, 521 | "id": "fad62c5d", 522 | "metadata": { 523 | "ExecuteTime": { 524 | "end_time": "2022-05-27T14:36:21.832366Z", 525 | "start_time": "2022-05-27T14:36:21.801473Z" 526 | } 527 | }, 528 | "outputs": [ 529 | { 530 | "data": { 531 | "text/plain": [ 532 | "tensor([[0.5143, 0.3422, 0.5147, ..., 0.1284, 0.7533, 0.9834]])" 533 | ] 534 | }, 535 | "execution_count": 70, 536 | "metadata": {}, 537 | "output_type": "execute_result" 538 | } 539 | ], 540 | "source": [ 541 | 
"simple_nn.flatten(sample_input)" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 71, 547 | "id": "608b759c", 548 | "metadata": { 549 | "ExecuteTime": { 550 | "end_time": "2022-05-27T14:36:25.055947Z", 551 | "start_time": "2022-05-27T14:36:25.049781Z" 552 | } 553 | }, 554 | "outputs": [ 555 | { 556 | "data": { 557 | "text/plain": [ 558 | "torch.Size([1, 10000])" 559 | ] 560 | }, 561 | "execution_count": 71, 562 | "metadata": {}, 563 | "output_type": "execute_result" 564 | } 565 | ], 566 | "source": [ 567 | "simple_nn.flatten(sample_input).shape" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": 72, 573 | "id": "97ae755e", 574 | "metadata": { 575 | "ExecuteTime": { 576 | "end_time": "2022-05-27T14:36:38.168236Z", 577 | "start_time": "2022-05-27T14:36:38.139423Z" 578 | } 579 | }, 580 | "outputs": [ 581 | { 582 | "data": { 583 | "text/plain": [ 584 | "tensor([[-0.7693, 0.0030, -0.0334, -0.4434, 0.2405, 0.0029, 0.6232, -0.4082,\n", 585 | " -0.1580, -0.1069, -0.6261, -0.2655, 0.1094, 0.5806, -0.4161, 0.9235,\n", 586 | " 0.1572, 0.1070, -0.4924, -0.0356, -0.1770, 0.1885, -0.1733, 0.1669,\n", 587 | " 0.1561, 0.2752, 0.1711, -1.3161, -0.3260, 0.6552, -0.3396, 0.2472,\n", 588 | " 0.0213, 0.0052, 0.1735, 0.0885, -0.1368, 0.1171, 0.2802, 0.0724,\n", 589 | " -0.2337, -0.0848, -0.1031, -0.1537, -0.0895, 0.0614, 0.1227, 0.0357,\n", 590 | " 0.0637, 0.1727, -0.1588, -0.3097, -0.3173, -0.5774, -0.3230, -0.0222,\n", 591 | " 0.4743, -0.2770, 0.2789, -0.1854, -0.3404, -0.0821, -0.5541, -0.7767,\n", 592 | " 0.0378, 0.0400, -0.4057, -0.3930, -0.4567, -0.0502, -0.9299, 0.0438,\n", 593 | " -0.3472, -1.2186, -0.1734, -0.2879, 0.2751, 0.3204, 0.3268, -0.1279,\n", 594 | " -0.4940, -0.0317, -0.1765, 0.5049, 0.6267, -0.5123, 0.1804, 0.4135,\n", 595 | " -0.0935, -0.2549, -0.0659, 0.8785, 0.4327, 0.0927, 0.0779, 0.0092,\n", 596 | " 0.0678, -0.3014, 0.2304, -0.2191]], grad_fn=)" 597 | ] 598 | }, 599 | "execution_count": 72, 600 | "metadata": 
{}, 601 | "output_type": "execute_result" 602 | } 603 | ], 604 | "source": [ 605 | "simple_nn.first_layer[0](simple_nn.flatten(sample_input))" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": 73, 611 | "id": "569a4c27", 612 | "metadata": { 613 | "ExecuteTime": { 614 | "end_time": "2022-05-27T14:36:50.398726Z", 615 | "start_time": "2022-05-27T14:36:50.386953Z" 616 | } 617 | }, 618 | "outputs": [ 619 | { 620 | "data": { 621 | "text/plain": [ 622 | "torch.Size([1, 100])" 623 | ] 624 | }, 625 | "execution_count": 73, 626 | "metadata": {}, 627 | "output_type": "execute_result" 628 | } 629 | ], 630 | "source": [ 631 | "simple_nn.first_layer[0](simple_nn.flatten(sample_input)).shape" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 74, 637 | "id": "3a7a69dd", 638 | "metadata": { 639 | "ExecuteTime": { 640 | "end_time": "2022-05-27T14:37:01.546207Z", 641 | "start_time": "2022-05-27T14:37:01.493345Z" 642 | } 643 | }, 644 | "outputs": [ 645 | { 646 | "data": { 647 | "text/plain": [ 648 | "-6.2187476" 649 | ] 650 | }, 651 | "execution_count": 74, 652 | "metadata": {}, 653 | "output_type": "execute_result" 654 | } 655 | ], 656 | "source": [ 657 | "simple_nn.first_layer[0](simple_nn.flatten(sample_input)).detach().numpy().sum()" 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": 75, 663 | "id": "0f840cef", 664 | "metadata": { 665 | "ExecuteTime": { 666 | "end_time": "2022-05-27T14:37:48.963899Z", 667 | "start_time": "2022-05-27T14:37:48.932936Z" 668 | } 669 | }, 670 | "outputs": [ 671 | { 672 | "data": { 673 | "text/plain": [ 674 | "tensor([[0.0046, 0.0100, 0.0096, 0.0064, 0.0126, 0.0100, 0.0185, 0.0066, 0.0085,\n", 675 | " 0.0089, 0.0053, 0.0076, 0.0111, 0.0177, 0.0065, 0.0250, 0.0116, 0.0110,\n", 676 | " 0.0061, 0.0096, 0.0083, 0.0120, 0.0083, 0.0117, 0.0116, 0.0131, 0.0118,\n", 677 | " 0.0027, 0.0072, 0.0191, 0.0071, 0.0127, 0.0101, 0.0100, 0.0118, 0.0108,\n", 678 | " 0.0087, 0.0112, 0.0131, 
0.0107, 0.0079, 0.0091, 0.0090, 0.0085, 0.0091,\n", 679 | " 0.0106, 0.0112, 0.0103, 0.0106, 0.0118, 0.0085, 0.0073, 0.0072, 0.0056,\n", 680 | " 0.0072, 0.0097, 0.0160, 0.0075, 0.0131, 0.0082, 0.0071, 0.0091, 0.0057,\n", 681 | " 0.0046, 0.0103, 0.0103, 0.0066, 0.0067, 0.0063, 0.0094, 0.0039, 0.0104,\n", 682 | " 0.0070, 0.0029, 0.0083, 0.0074, 0.0131, 0.0137, 0.0138, 0.0087, 0.0061,\n", 683 | " 0.0096, 0.0083, 0.0164, 0.0186, 0.0059, 0.0119, 0.0150, 0.0090, 0.0077,\n", 684 | " 0.0093, 0.0239, 0.0153, 0.0109, 0.0107, 0.0100, 0.0106, 0.0073, 0.0125,\n", 685 | " 0.0080]], grad_fn=)" 686 | ] 687 | }, 688 | "execution_count": 75, 689 | "metadata": {}, 690 | "output_type": "execute_result" 691 | } 692 | ], 693 | "source": [ 694 | "simple_nn.first_layer[1](simple_nn.first_layer[0](simple_nn.flatten(sample_input)))" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": 76, 700 | "id": "7ad15bc2", 701 | "metadata": { 702 | "ExecuteTime": { 703 | "end_time": "2022-05-27T14:38:01.306679Z", 704 | "start_time": "2022-05-27T14:38:01.288618Z" 705 | } 706 | }, 707 | "outputs": [ 708 | { 709 | "data": { 710 | "text/plain": [ 711 | "torch.Size([1, 100])" 712 | ] 713 | }, 714 | "execution_count": 76, 715 | "metadata": {}, 716 | "output_type": "execute_result" 717 | } 718 | ], 719 | "source": [ 720 | "simple_nn.first_layer[1](simple_nn.first_layer[0](simple_nn.flatten(sample_input))).shape" 721 | ] 722 | }, 723 | { 724 | "cell_type": "code", 725 | "execution_count": 77, 726 | "id": "0a29e708", 727 | "metadata": { 728 | "ExecuteTime": { 729 | "end_time": "2022-05-27T14:38:05.947777Z", 730 | "start_time": "2022-05-27T14:38:05.936800Z" 731 | } 732 | }, 733 | "outputs": [ 734 | { 735 | "data": { 736 | "text/plain": [ 737 | "0.9999999" 738 | ] 739 | }, 740 | "execution_count": 77, 741 | "metadata": {}, 742 | "output_type": "execute_result" 743 | } 744 | ], 745 | "source": [ 746 | 
"simple_nn.first_layer[1](simple_nn.first_layer[0](simple_nn.flatten(sample_input))).detach().numpy().sum()" 747 | ] 748 | } 749 | ], 750 | "metadata": { 751 | "kernelspec": { 752 | "display_name": "Python [conda env:torch] *", 753 | "language": "python", 754 | "name": "conda-env-torch-py" 755 | }, 756 | "language_info": { 757 | "codemirror_mode": { 758 | "name": "ipython", 759 | "version": 3 760 | }, 761 | "file_extension": ".py", 762 | "mimetype": "text/x-python", 763 | "name": "python", 764 | "nbconvert_exporter": "python", 765 | "pygments_lexer": "ipython3", 766 | "version": "3.9.12" 767 | } 768 | }, 769 | "nbformat": 4, 770 | "nbformat_minor": 5 771 | } 772 | -------------------------------------------------------------------------------- /Pytorch-1-Introduction-to-tensors.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Pytorch-1-Introduction-to-tensors.ipynb","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyOUp4YX98auyrqVK2K7K8WS"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":1,"metadata":{"id":"9nl3YOC2nylG","executionInfo":{"status":"ok","timestamp":1650021338953,"user_tz":-270,"elapsed":2488,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"outputs":[],"source":["import torch"]},{"cell_type":"markdown","source":["# Intro"],"metadata":{"id":"LGvhpZnBwfZh"}},{"cell_type":"markdown","source":["A tensor is like a multi-dimensional array (a generalization of a matrix). You can make a tensor from a list. 
This will be a 1-dimensional tensor (similar to a vector)."],"metadata":{"id":"bIQonPXpp2d3"}},{"cell_type":"code","source":["my_list = [1,2,3,4]"],"metadata":{"id":"YtWPBUVFn_Os","executionInfo":{"status":"ok","timestamp":1650021510308,"user_tz":-270,"elapsed":399,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":2,"outputs":[]},{"cell_type":"code","source":["my_tensor = torch.Tensor(my_list)"],"metadata":{"id":"AqylsQvYtxD7","executionInfo":{"status":"ok","timestamp":1650021519385,"user_tz":-270,"elapsed":448,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":3,"outputs":[]},{"cell_type":"code","source":["my_tensor"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9N3DBY1mt32q","executionInfo":{"status":"ok","timestamp":1650021522852,"user_tz":-270,"elapsed":4,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"185cdf12-4bf6-4c7b-b907-73e432696841"},"execution_count":4,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([1., 2., 3., 4.])"]},"metadata":{},"execution_count":4}]},{"cell_type":"code","source":["my_tensor[2]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tqD6SYMrt5ap","executionInfo":{"status":"ok","timestamp":1650021536501,"user_tz":-270,"elapsed":446,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"9aa22337-9918-495d-c892-aa5c63d4752e"},"execution_count":5,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor(3.)"]},"metadata":{},"execution_count":5}]},{"cell_type":"markdown","source":["Wanna convert it to a NumPy array? 
No problem at all!"],"metadata":{"id":"XzfxgigBt8pn"}},{"cell_type":"code","source":["my_tensor.numpy()"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9orKGc5Xt7YM","executionInfo":{"status":"ok","timestamp":1650021580034,"user_tz":-270,"elapsed":479,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"38fc494c-e414-422d-ac1d-4123bc39d7a8"},"execution_count":6,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([1., 2., 3., 4.], dtype=float32)"]},"metadata":{},"execution_count":6}]},{"cell_type":"markdown","source":["Wanna convert it back to a tensor? That's so easy!"],"metadata":{"id":"9ebLmktSuDph"}},{"cell_type":"code","source":["torch.Tensor(my_tensor.numpy())"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"wCTX6HQGuCkS","executionInfo":{"status":"ok","timestamp":1650021598736,"user_tz":-270,"elapsed":477,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"6573013f-15ae-404e-c8fd-dbaf5c5f8fb0"},"execution_count":7,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([1., 2., 3., 4.])"]},"metadata":{},"execution_count":7}]},{"cell_type":"markdown","source":["Get its shape"],"metadata":{"id":"NjDQpriYud_-"}},{"cell_type":"code","source":["torch.Tensor(my_tensor.numpy()).shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"gEeqH4KPuSND","executionInfo":{"status":"ok","timestamp":1650021613555,"user_tz":-270,"elapsed":470,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"5ff66139-26d7-43b1-92a3-7b888b9edd11"},"execution_count":8,"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.Size([4])"]},"metadata":{},"execution_count":8}]},{"cell_type":"markdown","source":["# Making a Tensor from a Numpy Array"],"metadata":{"id":"oc8kSkQLwiiv"}},{"cell_type":"markdown","source":["How about making a Tensor from a numpy array? 
Let's go!"],"metadata":{"id":"9s2DGQJ1vEMn"}},{"cell_type":"code","source":["import numpy as np"],"metadata":{"id":"3Ft974jbuygL","executionInfo":{"status":"ok","timestamp":1650021645979,"user_tz":-270,"elapsed":413,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":9,"outputs":[]},{"cell_type":"markdown","source":["First, we make a Numpy array"],"metadata":{"id":"aoImHjS4wYgD"}},{"cell_type":"code","source":["my_np_array = np.array([ [1,2,3,4] , [5,6,7,8] ])"],"metadata":{"id":"gyMMvemIvICn","executionInfo":{"status":"ok","timestamp":1650021671483,"user_tz":-270,"elapsed":386,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":10,"outputs":[]},{"cell_type":"markdown","source":["Now we can get its shape"],"metadata":{"id":"R4GK8WPowbQE"}},{"cell_type":"code","source":["my_np_array.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_ysIYsQlwUPT","executionInfo":{"status":"ok","timestamp":1650021675660,"user_tz":-270,"elapsed":403,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"b3ffa3a3-53a6-40fe-839c-f31056062484"},"execution_count":11,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(2, 4)"]},"metadata":{},"execution_count":11}]},{"cell_type":"code","source":["my_np_array.dtype"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ZGGvmCQv9YuJ","executionInfo":{"status":"ok","timestamp":1650021734655,"user_tz":-270,"elapsed":606,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"9a38b345-73bc-42a2-f831-546789e4d4d2"},"execution_count":16,"outputs":[{"output_type":"execute_result","data":{"text/plain":["dtype('int64')"]},"metadata":{},"execution_count":16}]},{"cell_type":"code","source":["my_tensor = 
torch.Tensor(my_np_array)"],"metadata":{"id":"Ndi2jFXfwVg_","executionInfo":{"status":"ok","timestamp":1650021687242,"user_tz":-270,"elapsed":460,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["my_tensor"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"MPfsHgcTx1lv","executionInfo":{"status":"ok","timestamp":1650021688633,"user_tz":-270,"elapsed":5,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"b59087dc-da06-41a2-d522-69cd09ffd01c"},"execution_count":13,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[1., 2., 3., 4.],\n"," [5., 6., 7., 8.]])"]},"metadata":{},"execution_count":13}]},{"cell_type":"code","source":["my_tensor.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"aVJJCfi99RE3","executionInfo":{"status":"ok","timestamp":1650021705074,"user_tz":-270,"elapsed":490,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"eb784d67-7b33-4651-bb86-b7e63e969bf2"},"execution_count":14,"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.Size([2, 4])"]},"metadata":{},"execution_count":14}]},{"cell_type":"code","source":["my_tensor.dtype"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NgBZIKTtD5xJ","executionInfo":{"status":"ok","timestamp":1650021716775,"user_tz":-270,"elapsed":471,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"c27a45e0-41db-4092-f67f-e27a5532864b"},"execution_count":15,"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.float32"]},"metadata":{},"execution_count":15}]},{"cell_type":"markdown","source":["Another way of making a Tensor from a Numpy array..."],"metadata":{"id":"TyJXMwqBDsAd"}},{"cell_type":"code","source":["my_tensor = 
torch.from_numpy(my_np_array)"],"metadata":{"id":"yEEGE7klDP4i","executionInfo":{"status":"ok","timestamp":1650021749099,"user_tz":-270,"elapsed":461,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":17,"outputs":[]},{"cell_type":"code","source":["my_tensor"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"LJISIhtfDxaA","executionInfo":{"status":"ok","timestamp":1650021752157,"user_tz":-270,"elapsed":475,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"9947bd4e-3795-424b-a8af-289ea4f024e7"},"execution_count":18,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[1, 2, 3, 4],\n"," [5, 6, 7, 8]])"]},"metadata":{},"execution_count":18}]},{"cell_type":"code","source":["my_tensor.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"WjiNn8SDx23m","executionInfo":{"status":"ok","timestamp":1650021759483,"user_tz":-270,"elapsed":441,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"1d82fdf1-80c1-4269-9229-a9246b69bfe4"},"execution_count":19,"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.Size([2, 4])"]},"metadata":{},"execution_count":19}]},{"cell_type":"code","source":["my_tensor.dtype"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"0m6uvhrxD8XC","executionInfo":{"status":"ok","timestamp":1650021761657,"user_tz":-270,"elapsed":384,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"d139396c-a10a-4b6c-8f27-f114453cd50c"},"execution_count":20,"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.int64"]},"metadata":{},"execution_count":20}]},{"cell_type":"markdown","source":["Even the shape can be converted to a 
number"],"metadata":{"id":"6wrSSzH40NMW"}},{"cell_type":"code","source":["my_tensor.shape[0]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"a4WYYLZXx303","executionInfo":{"status":"ok","timestamp":1650021787523,"user_tz":-270,"elapsed":400,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"1bd83994-72a3-448b-d90e-5e058a3afdb4"},"execution_count":21,"outputs":[{"output_type":"execute_result","data":{"text/plain":["2"]},"metadata":{},"execution_count":21}]},{"cell_type":"markdown","source":["# Making a Random Tensor"],"metadata":{"id":"HK1G3QDqEKKH"}},{"cell_type":"markdown","source":["Uniform Distribution $[0,1)$"],"metadata":{"id":"FE5TzIZqEZXz"}},{"cell_type":"code","source":["torch.rand((3,5))"],"metadata":{"id":"rnOt5aXTr1Sr","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1650021965498,"user_tz":-270,"elapsed":464,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"f65e0166-d68d-4e78-e0b6-e2408dfa940f"},"execution_count":22,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.3368, 0.3964, 0.8355, 0.4981, 0.6181],\n"," [0.9403, 0.1545, 0.1922, 0.2865, 0.0816],\n"," [0.2312, 0.5274, 0.9766, 0.9085, 0.2159]])"]},"metadata":{},"execution_count":22}]},{"cell_type":"markdown","source":["Integers between 0 (inclusive) and 100 (exclusive)"],"metadata":{"id":"Bao3LXjmEfP0"}},{"cell_type":"code","source":["torch.randint(0,100,(3,5))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"EvWJf9uzEUHR","executionInfo":{"status":"ok","timestamp":1650021996734,"user_tz":-270,"elapsed":581,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"5503639b-70a6-408d-cd1d-49328e555cbe"},"execution_count":23,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[30, 71, 10, 7, 87],\n"," [98, 5, 93, 64, 20],\n"," [21, 63, 69, 97, 
14]])"]},"metadata":{},"execution_count":23}]},{"cell_type":"markdown","source":["Numbers drawn from $\\mathcal{N}(0,1)$ are generated with `torch.randn((3,5))`. Note that the cell below calls `torch.rand`, which samples uniformly from $[0,1)$, so its output is not normally distributed."],"metadata":{"id":"30lV_yZlFdyU"}},{"cell_type":"code","source":["torch.rand((3,5))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"JSJu_LEqFKnN","executionInfo":{"status":"ok","timestamp":1650022017301,"user_tz":-270,"elapsed":564,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"050cd8a3-c470-41c1-ca72-0df77e6ac342"},"execution_count":24,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.0522, 0.2935, 0.2521, 0.6683, 0.9789],\n"," [0.8658, 0.6981, 0.9624, 0.3947, 0.9429],\n"," [0.6837, 0.5939, 0.1641, 0.0439, 0.8904]])"]},"metadata":{},"execution_count":24}]},{"cell_type":"markdown","source":["# Special Tensors"],"metadata":{"id":"J2gLzneqF6NC"}},{"cell_type":"markdown","source":["Filled with zeros"],"metadata":{"id":"B3CM5QGTF9hB"}},{"cell_type":"code","source":["torch.zeros((3,5))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"HU1_F5k1Fcqk","executionInfo":{"status":"ok","timestamp":1650022052836,"user_tz":-270,"elapsed":519,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"b8097f68-386c-4a80-d5e6-02fc3d82007c"},"execution_count":25,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0., 0., 0., 0., 0.],\n"," [0., 0., 0., 0., 0.],\n"," [0., 0., 0., 0., 0.]])"]},"metadata":{},"execution_count":25}]},{"cell_type":"markdown","source":["Filled with ones"],"metadata":{"id":"atEOXTBbGD06"}},{"cell_type":"code","source":["torch.ones((3,5))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"fR3e1yDCGCPJ","executionInfo":{"status":"ok","timestamp":1650022061060,"user_tz":-270,"elapsed":499,"user":{"displayName":"Arman 
Malekzadeh","userId":"13206128678526609465"}},"outputId":"5b222f73-6eb6-4535-a862-d22aa76db41e"},"execution_count":26,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[1., 1., 1., 1., 1.],\n"," [1., 1., 1., 1., 1.],\n"," [1., 1., 1., 1., 1.]])"]},"metadata":{},"execution_count":26}]},{"cell_type":"markdown","source":["# Tensors like other tensors!"],"metadata":{"id":"FHiKJtdyGbHi"}},{"cell_type":"code","source":["t = torch.rand((2,3))"],"metadata":{"id":"KfIM3mMfGFs7","executionInfo":{"status":"ok","timestamp":1650022119869,"user_tz":-270,"elapsed":412,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":27,"outputs":[]},{"cell_type":"code","source":["u = torch.rand_like(t)"],"metadata":{"id":"Jv1p392YGkAX","executionInfo":{"status":"ok","timestamp":1650022133519,"user_tz":-270,"elapsed":443,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":28,"outputs":[]},{"cell_type":"code","source":["t"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Cp5hc23RGoQs","executionInfo":{"status":"ok","timestamp":1650022134866,"user_tz":-270,"elapsed":4,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"2db6f992-2acf-4241-82de-6e1be8fe1349"},"execution_count":29,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.0785, 0.3488, 0.4077],\n"," [0.2882, 0.7558, 0.6523]])"]},"metadata":{},"execution_count":29}]},{"cell_type":"code","source":["u"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"JVDrPfs3GpB0","executionInfo":{"status":"ok","timestamp":1650022143253,"user_tz":-270,"elapsed":517,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"f8c8c462-58d7-4003-91aa-72d7de048a5c"},"execution_count":30,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 
0.1031]])"]},"metadata":{},"execution_count":30}]},{"cell_type":"code","source":["v = torch.ones_like(t)"],"metadata":{"id":"tb0nXKHXGpu8","executionInfo":{"status":"ok","timestamp":1650022165155,"user_tz":-270,"elapsed":401,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":31,"outputs":[]},{"cell_type":"code","source":["v"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"oF7oO5EtGt-c","executionInfo":{"status":"ok","timestamp":1650022167371,"user_tz":-270,"elapsed":435,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"01175dfb-e462-4b20-dfca-a78619728e71"},"execution_count":32,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[1., 1., 1.],\n"," [1., 1., 1.]])"]},"metadata":{},"execution_count":32}]},{"cell_type":"code","source":["w = torch.zeros_like(t)"],"metadata":{"id":"J83IjL8vGuX3","executionInfo":{"status":"ok","timestamp":1650022173774,"user_tz":-270,"elapsed":475,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":33,"outputs":[]},{"cell_type":"code","source":["w"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"rDzZpegsGxar","executionInfo":{"status":"ok","timestamp":1650022176856,"user_tz":-270,"elapsed":395,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"b5441613-44df-43c1-f6b8-ede3cf081ba8"},"execution_count":34,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0., 0., 0.],\n"," [0., 0., 0.]])"]},"metadata":{},"execution_count":34}]},{"cell_type":"markdown","source":["# Indexing"],"metadata":{"id":"lqUOFMStHJWZ"}},{"cell_type":"code","source":["u"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"hHdl4_OlGyei","executionInfo":{"status":"ok","timestamp":1650022198663,"user_tz":-270,"elapsed":424,"user":{"displayName":"Arman 
Malekzadeh","userId":"13206128678526609465"}},"outputId":"7e26be5b-5484-4b3b-8654-3c10d68e2b82"},"execution_count":35,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 0.1031]])"]},"metadata":{},"execution_count":35}]},{"cell_type":"markdown","source":["First row, Second Column"],"metadata":{"id":"OmtkjxpfHgzY"}},{"cell_type":"code","source":["u[0][1]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"mJ8848pHHN5g","executionInfo":{"status":"ok","timestamp":1650022213998,"user_tz":-270,"elapsed":480,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"36aadc71-60a2-4dc9-d404-a8c82779e49c"},"execution_count":36,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor(0.7695)"]},"metadata":{},"execution_count":36}]},{"cell_type":"markdown","source":["Second Column"],"metadata":{"id":"HOM0R0bUHeBT"}},{"cell_type":"code","source":["u[:,1]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"hvRsDqm2HPCu","executionInfo":{"status":"ok","timestamp":1650022306762,"user_tz":-270,"elapsed":660,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"781127c6-50ca-48ee-8849-b895d7937ee2"},"execution_count":39,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([0.7695, 0.6689])"]},"metadata":{},"execution_count":39}]},{"cell_type":"markdown","source":["First row"],"metadata":{"id":"Vv8wkP0BHkSB"}},{"cell_type":"code","source":["u[0,:]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"sUWfLSc-HX6e","executionInfo":{"status":"ok","timestamp":1650022297550,"user_tz":-270,"elapsed":497,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"8ddde5c7-60f7-4bb8-a9e3-811d3a653027"},"execution_count":38,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([0.8722, 0.7695, 
0.4542])"]},"metadata":{},"execution_count":38}]},{"cell_type":"markdown","source":["# Concatenation"],"metadata":{"id":"rcSDMBPDHsYy"}},{"cell_type":"code","source":["u"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"e9LZEG9SH3aR","executionInfo":{"status":"ok","timestamp":1650022335819,"user_tz":-270,"elapsed":417,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"b93af822-6c46-4a84-d974-5795387157fe"},"execution_count":40,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 0.1031]])"]},"metadata":{},"execution_count":40}]},{"cell_type":"code","source":["t"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"kX-i1VHVH3sQ","executionInfo":{"status":"ok","timestamp":1650022338016,"user_tz":-270,"elapsed":449,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"7ecb2761-018f-4fef-c54d-4dd5021a90ae"},"execution_count":41,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.0785, 0.3488, 0.4077],\n"," [0.2882, 0.7558, 0.6523]])"]},"metadata":{},"execution_count":41}]},{"cell_type":"code","source":["torch.cat([u,t])"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"7eRzuhEiHpI-","executionInfo":{"status":"ok","timestamp":1650022352963,"user_tz":-270,"elapsed":489,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"b5ef7448-2c90-4bec-c6dd-3a36a7a27c6b"},"execution_count":42,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 0.1031],\n"," [0.0785, 0.3488, 0.4077],\n"," [0.2882, 0.7558, 0.6523]])"]},"metadata":{},"execution_count":42}]},{"cell_type":"markdown","source":["Concatenate by joining the rows"],"metadata":{"id":"fs-kcdQ_IFbI"}},{"cell_type":"code","source":["torch.cat([u,t], 
dim=0)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"iGz-NB66IAD5","executionInfo":{"status":"ok","timestamp":1650022365703,"user_tz":-270,"elapsed":6,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"a6f5cd5c-d344-4d4a-ea50-979e89cbb044"},"execution_count":43,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 0.1031],\n"," [0.0785, 0.3488, 0.4077],\n"," [0.2882, 0.7558, 0.6523]])"]},"metadata":{},"execution_count":43}]},{"cell_type":"markdown","source":["Concatenate by joining the columns"],"metadata":{"id":"SIMWXHoJIHX9"}},{"cell_type":"code","source":["torch.cat([u,t], dim=1)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"qQyHw27jHy4s","executionInfo":{"status":"ok","timestamp":1650022381554,"user_tz":-270,"elapsed":420,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"3bb4a390-c170-4daf-e2dc-ecc84587d089"},"execution_count":44,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542, 0.0785, 0.3488, 0.4077],\n"," [0.8422, 0.6689, 0.1031, 0.2882, 0.7558, 0.6523]])"]},"metadata":{},"execution_count":44}]},{"cell_type":"markdown","source":["# Basic Arithmetic Operations"],"metadata":{"id":"Ao3sxTsMIc3d"}},{"cell_type":"markdown","source":["Element-wise Multiplication"],"metadata":{"id":"Codm03eFIisH"}},{"cell_type":"code","source":["u"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"QSKP5WERImkC","executionInfo":{"status":"ok","timestamp":1650022423515,"user_tz":-270,"elapsed":442,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"e1957614-1975-43ee-cc58-b05dcfcc0a92"},"execution_count":45,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 
0.1031]])"]},"metadata":{},"execution_count":45}]},{"cell_type":"code","source":["v"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XlmBz2rcIm3B","executionInfo":{"status":"ok","timestamp":1650022425182,"user_tz":-270,"elapsed":6,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"fef17485-344f-4205-8688-d8c51228e5dc"},"execution_count":46,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[1., 1., 1.],\n"," [1., 1., 1.]])"]},"metadata":{},"execution_count":46}]},{"cell_type":"code","source":["t"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"oNH3G1A1Ipip","executionInfo":{"status":"ok","timestamp":1650022427462,"user_tz":-270,"elapsed":410,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"2bd7812d-0561-4e4c-f4df-5391d68e3012"},"execution_count":47,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.0785, 0.3488, 0.4077],\n"," [0.2882, 0.7558, 0.6523]])"]},"metadata":{},"execution_count":47}]},{"cell_type":"code","source":["u*v"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"C-f1vPlFH9Ho","executionInfo":{"status":"ok","timestamp":1650022444154,"user_tz":-270,"elapsed":448,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"36f6ca75-4da2-4c78-871d-2e118959d377"},"execution_count":48,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 0.1031]])"]},"metadata":{},"execution_count":48}]},{"cell_type":"code","source":["u*t"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"nfiHaiIFIg0x","executionInfo":{"status":"ok","timestamp":1650022465694,"user_tz":-270,"elapsed":407,"user":{"displayName":"Arman 
Malekzadeh","userId":"13206128678526609465"}},"outputId":"e419083a-623e-406c-b389-70f30469ab80"},"execution_count":49,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.0685, 0.2684, 0.1851],\n"," [0.2427, 0.5056, 0.0672]])"]},"metadata":{},"execution_count":49}]},{"cell_type":"code","source":["0.7695*0.3488"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"wc0QtIm0IsZH","executionInfo":{"status":"ok","timestamp":1650022505099,"user_tz":-270,"elapsed":439,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"cc911667-92dd-4516-90d3-8ee55098e02b"},"execution_count":50,"outputs":[{"output_type":"execute_result","data":{"text/plain":["0.26840159999999996"]},"metadata":{},"execution_count":50}]},{"cell_type":"code","source":["torch.mul(u,t)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"wiA7036BKjRx","executionInfo":{"status":"ok","timestamp":1650022525502,"user_tz":-270,"elapsed":501,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"f342d455-5114-41c4-dc22-230d04b6b811"},"execution_count":51,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.0685, 0.2684, 0.1851],\n"," [0.2427, 0.5056, 0.0672]])"]},"metadata":{},"execution_count":51}]},{"cell_type":"markdown","source":["Matrix Multiplication"],"metadata":{"id":"wAcysJ-CI_n8"}},{"cell_type":"code","source":["u"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"SCyK3S2KI6Hi","executionInfo":{"status":"ok","timestamp":1650022548690,"user_tz":-270,"elapsed":399,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"17dff320-79be-47d0-b9cd-9d48113fa57d"},"execution_count":52,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 
0.1031]])"]},"metadata":{},"execution_count":52}]},{"cell_type":"code","source":["t"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"2foNpR2zJEkL","executionInfo":{"status":"ok","timestamp":1650022549991,"user_tz":-270,"elapsed":5,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"50207f19-7fb5-4649-a6a2-374667b4f1fb"},"execution_count":53,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.0785, 0.3488, 0.4077],\n"," [0.2882, 0.7558, 0.6523]])"]},"metadata":{},"execution_count":53}]},{"cell_type":"code","source":["torch.matmul(u,t.T) # Note that u is 2*3 and t.T is 3*2. So the result is 2*2."],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"lgI6PLgbJFDO","executionInfo":{"status":"ok","timestamp":1650022576991,"user_tz":-270,"elapsed":437,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"a1d110a5-ae50-4c5b-b444-7e22a0666a02"},"execution_count":54,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.5220, 1.1292],\n"," [0.3415, 0.8156]])"]},"metadata":{},"execution_count":54}]},{"cell_type":"markdown","source":["Getting the Absolute Value"],"metadata":{"id":"AbMMa-cLJsuG"}},{"cell_type":"code","source":["s = torch.Tensor([-10,-20,30])"],"metadata":{"id":"0NHzpWdpJxrx","executionInfo":{"status":"ok","timestamp":1650022587458,"user_tz":-270,"elapsed":409,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":55,"outputs":[]},{"cell_type":"code","source":["torch.abs(s)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"pus8uezUJ165","executionInfo":{"status":"ok","timestamp":1650022593694,"user_tz":-270,"elapsed":475,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"a428c30e-54b7-42c3-db1e-1e084efe385d"},"execution_count":56,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([10., 20., 
30.])"]},"metadata":{},"execution_count":56}]},{"cell_type":"markdown","source":["Addition (Element-wise)"],"metadata":{"id":"rgJhZ5hIKFcI"}},{"cell_type":"code","source":["u"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"C-RLAZUqKCbH","executionInfo":{"status":"ok","timestamp":1650022602808,"user_tz":-270,"elapsed":5,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"3aa58b98-52e4-4f5a-f405-858b007b4850"},"execution_count":57,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 0.1031]])"]},"metadata":{},"execution_count":57}]},{"cell_type":"code","source":["u+1"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"lLMyNoWiLqrB","executionInfo":{"status":"ok","timestamp":1650022605783,"user_tz":-270,"elapsed":426,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"a1eaee4e-3e03-48d9-d6fa-d36dd627bd3d"},"execution_count":58,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[1.8722, 1.7695, 1.4542],\n"," [1.8422, 1.6689, 1.1031]])"]},"metadata":{},"execution_count":58}]},{"cell_type":"code","source":["t"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"g_PHbcAmKGhc","executionInfo":{"status":"ok","timestamp":1650022614105,"user_tz":-270,"elapsed":515,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"57a0efe2-c835-4626-fe22-381dc44a9ed0"},"execution_count":59,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.0785, 0.3488, 0.4077],\n"," [0.2882, 0.7558, 0.6523]])"]},"metadata":{},"execution_count":59}]},{"cell_type":"code","source":["u+t"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"HMrE_5-4KG20","executionInfo":{"status":"ok","timestamp":1650022616760,"user_tz":-270,"elapsed":398,"user":{"displayName":"Arman 
Malekzadeh","userId":"13206128678526609465"}},"outputId":"34b64762-1744-4d19-a10d-1a3600f281f6"},"execution_count":60,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.9508, 1.1183, 0.8618],\n"," [1.1304, 1.4248, 0.7553]])"]},"metadata":{},"execution_count":60}]},{"cell_type":"code","source":["torch.add(u,t)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ZB-8ZuuWKH-b","executionInfo":{"status":"ok","timestamp":1650022624447,"user_tz":-270,"elapsed":432,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"ea8f6b45-2616-4b3a-f7ca-6a5298746d4d"},"execution_count":61,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.9508, 1.1183, 0.8618],\n"," [1.1304, 1.4248, 0.7553]])"]},"metadata":{},"execution_count":61}]},{"cell_type":"markdown","source":["Subtraction (Element-wise)"],"metadata":{"id":"al176z4eKO-C"}},{"cell_type":"code","source":["u-t"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"urGZtVSNKM2u","executionInfo":{"status":"ok","timestamp":1650022630250,"user_tz":-270,"elapsed":450,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"bf1046a0-e443-4a5f-ba95-ff140b0a436b"},"execution_count":62,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[ 0.7937, 0.4207, 0.0465],\n"," [ 0.5540, -0.0869, -0.5492]])"]},"metadata":{},"execution_count":62}]},{"cell_type":"code","source":["torch.sub(u,t)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"JGDGll8vKSBe","executionInfo":{"status":"ok","timestamp":1650022632779,"user_tz":-270,"elapsed":432,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"20c19024-ef11-496a-ef1d-9e25546a8b52"},"execution_count":63,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[ 0.7937, 0.4207, 0.0465],\n"," [ 0.5540, -0.0869, 
-0.5492]])"]},"metadata":{},"execution_count":63}]},{"cell_type":"markdown","source":["Division (Element-wise)"],"metadata":{"id":"uU9v8eB_KWrx"}},{"cell_type":"code","source":["u/t"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XKwRn1p6KaaC","executionInfo":{"status":"ok","timestamp":1650022637519,"user_tz":-270,"elapsed":487,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"e9dbeed1-f180-42ac-81b1-c249485c0ddc"},"execution_count":64,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[11.1055, 2.2061, 1.1141],\n"," [ 2.9224, 0.8850, 0.1580]])"]},"metadata":{},"execution_count":64}]},{"cell_type":"code","source":["torch.div(u,t)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Q5DyGWOBKTVJ","executionInfo":{"status":"ok","timestamp":1650022641426,"user_tz":-270,"elapsed":428,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"0a442424-c7dc-4e2d-f0c0-93f25c7f5d4a"},"execution_count":65,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[11.1055, 2.2061, 1.1141],\n"," [ 2.9224, 0.8850, 0.1580]])"]},"metadata":{},"execution_count":65}]},{"cell_type":"markdown","source":["Negation"],"metadata":{"id":"rbBizr9vKobN"}},{"cell_type":"code","source":["u"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"oPyHb5IFKpTZ","executionInfo":{"status":"ok","timestamp":1650022649399,"user_tz":-270,"elapsed":445,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"57cfb2e6-f3c1-4f35-dfc0-12b4b843f3d5"},"execution_count":66,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[0.8722, 0.7695, 0.4542],\n"," [0.8422, 0.6689, 
0.1031]])"]},"metadata":{},"execution_count":66}]},{"cell_type":"code","source":["torch.neg(u)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"bXy65Vs8Kp-R","executionInfo":{"status":"ok","timestamp":1650022653147,"user_tz":-270,"elapsed":374,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"6e5133d5-357c-466d-a54e-80e686492df5"},"execution_count":67,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[-0.8722, -0.7695, -0.4542],\n"," [-0.8422, -0.6689, -0.1031]])"]},"metadata":{},"execution_count":67}]},{"cell_type":"code","source":["-u"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9TTubhrFKrEx","executionInfo":{"status":"ok","timestamp":1650022662240,"user_tz":-270,"elapsed":420,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"b7490c99-37a9-4bb6-a95a-9ca230569bdf"},"execution_count":68,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[-0.8722, -0.7695, -0.4542],\n"," [-0.8422, -0.6689, -0.1031]])"]},"metadata":{},"execution_count":68}]},{"cell_type":"markdown","source":["Raising to a Power (element-wise)"],"metadata":{"id":"v6rTFk3OKuFs"}},{"cell_type":"code","source":["r = torch.ones((2,3))*2"],"metadata":{"id":"rH23w8NSKrgu","executionInfo":{"status":"ok","timestamp":1650022680525,"user_tz":-270,"elapsed":475,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":69,"outputs":[]},{"cell_type":"code","source":["r"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"lR0H55tHLOUz","executionInfo":{"status":"ok","timestamp":1650022682469,"user_tz":-270,"elapsed":457,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"0c1764f8-29de-4e37-c1ad-039c9c0cf7c2"},"execution_count":70,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[2., 2., 2.],\n"," [2., 2., 
2.]])"]},"metadata":{},"execution_count":70}]},{"cell_type":"code","source":["torch.pow(r,5)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vjCFBXnrLOlW","executionInfo":{"status":"ok","timestamp":1650022692078,"user_tz":-270,"elapsed":422,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"eb410e65-adc8-435b-a933-8526fc04527f"},"execution_count":71,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([[32., 32., 32.],\n"," [32., 32., 32.]])"]},"metadata":{},"execution_count":71}]}]} -------------------------------------------------------------------------------- /Pytorch-5-Training-Loop.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e0f8d36d", 6 | "metadata": {}, 7 | "source": [ 8 | "# Preliminaries" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 72, 14 | "id": "f12a45a8", 15 | "metadata": { 16 | "ExecuteTime": { 17 | "end_time": "2022-06-07T12:53:12.043999Z", 18 | "start_time": "2022-06-07T12:53:11.982975Z" 19 | } 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "import torch\n", 24 | "from torch import nn\n", 25 | "from torch.utils.data import Dataset\n", 26 | "from torch.utils.data import DataLoader\n", 27 | "import numpy as np\n", 28 | "import pandas as pd\n", 29 | "import os\n", 30 | "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n", 31 | "import spacy # install command: pip install spacy\n", 32 | "import string\n", 33 | "import nltk # install command: pip install nltk\n", 34 | "from nltk.stem import WordNetLemmatizer\n", 35 | "import re\n", 36 | "from sklearn.model_selection import train_test_split\n", 37 | "import onnxruntime # install command: pip install onnxruntime\n", 38 | "import torch.onnx as onnx" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 73, 44 | "id": "6d5226a8", 45 | "metadata": { 46 
| "ExecuteTime": { 47 | "end_time": "2022-06-07T12:53:12.790788Z", 48 | "start_time": "2022-06-07T12:53:12.786937Z" 49 | } 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "# !python3 -m spacy download en_core_web_sm" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 74, 59 | "id": "3dbbc15b", 60 | "metadata": { 61 | "ExecuteTime": { 62 | "end_time": "2022-06-07T12:53:13.332616Z", 63 | "start_time": "2022-06-07T12:53:13.325569Z" 64 | } 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "# nltk.download('wordnet')" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 75, 74 | "id": "9a01d18d", 75 | "metadata": { 76 | "ExecuteTime": { 77 | "end_time": "2022-06-07T12:53:14.178352Z", 78 | "start_time": "2022-06-07T12:53:14.174293Z" 79 | } 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "# nltk.download('omw-1.4')" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 76, 89 | "id": "2ae09d3d", 90 | "metadata": { 91 | "ExecuteTime": { 92 | "end_time": "2022-06-07T12:53:15.087167Z", 93 | "start_time": "2022-06-07T12:53:15.071639Z" 94 | } 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "config = {\n", 99 | " 'max_features': 1000,\n", 100 | " 'num_epochs':200,\n", 101 | " 'learning_rate':1e-1,\n", 102 | " 'batch_size':32,\n", 103 | " 'train_percentage':90\n", 104 | "}" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "id": "f8317505", 110 | "metadata": {}, 111 | "source": [ 112 | "# Dataset" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "1a6fbc2a", 118 | "metadata": {}, 119 | "source": [ 120 | "Note that the dataset is taken from https://www.kaggle.com/datasets/yasserh/twitter-tweets-sentiment-dataset" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "0b5ded9e", 126 | "metadata": {}, 127 | "source": [ 128 | "## Exploration" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 78, 134 | "id": "c94cc999", 135 | "metadata": { 136 | 
"ExecuteTime": { 137 | "end_time": "2022-06-07T12:54:13.585581Z", 138 | "start_time": "2022-06-07T12:54:13.466482Z" 139 | } 140 | }, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "text/html": [ 145 | "
\n", 146 | "\n", 159 | "\n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | "
textIDtextselected_textsentiment
0cb774db0d1I`d have responded, if I were goingI`d have responded, if I were goingneutral
1549e992a42Sooo SAD I will miss you here in San Diego!!!Sooo SADnegative
2088c60f138my boss is bullying me...bullying menegative
39642c003efwhat interview! leave me aloneleave me alonenegative
4358bd9e861Sons of ****, why couldn`t they put them on t...Sons of ****,negative
\n", 207 | "
" 208 | ], 209 | "text/plain": [ 210 | " textID text \\\n", 211 | "0 cb774db0d1 I`d have responded, if I were going \n", 212 | "1 549e992a42 Sooo SAD I will miss you here in San Diego!!! \n", 213 | "2 088c60f138 my boss is bullying me... \n", 214 | "3 9642c003ef what interview! leave me alone \n", 215 | "4 358bd9e861 Sons of ****, why couldn`t they put them on t... \n", 216 | "\n", 217 | " selected_text sentiment \n", 218 | "0 I`d have responded, if I were going neutral \n", 219 | "1 Sooo SAD negative \n", 220 | "2 bullying me negative \n", 221 | "3 leave me alone negative \n", 222 | "4 Sons of ****, negative " 223 | ] 224 | }, 225 | "execution_count": 78, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "df = pd.read_csv('tweets.csv')\n", 232 | "df.head()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 79, 238 | "id": "d2360b30", 239 | "metadata": { 240 | "ExecuteTime": { 241 | "end_time": "2022-06-07T12:55:16.030122Z", 242 | "start_time": "2022-06-07T12:55:15.999686Z" 243 | } 244 | }, 245 | "outputs": [ 246 | { 247 | "data": { 248 | "text/plain": [ 249 | "array(['neutral', 'negative', 'positive'], dtype=object)" 250 | ] 251 | }, 252 | "execution_count": 79, 253 | "metadata": {}, 254 | "output_type": "execute_result" 255 | } 256 | ], 257 | "source": [ 258 | "df['sentiment'].unique()" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 80, 264 | "id": "87fb728c", 265 | "metadata": { 266 | "ExecuteTime": { 267 | "end_time": "2022-06-07T12:55:37.625606Z", 268 | "start_time": "2022-06-07T12:55:37.584586Z" 269 | } 270 | }, 271 | "outputs": [], 272 | "source": [ 273 | "df = df.loc[df['sentiment']!='neutral']" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 81, 279 | "id": "7c1824c4", 280 | "metadata": { 281 | "ExecuteTime": { 282 | "end_time": "2022-06-07T12:55:52.530116Z", 283 | "start_time": "2022-06-07T12:55:52.501840Z" 284 | } 285 | 
}, 286 | "outputs": [ 287 | { 288 | "data": { 289 | "text/html": [ 290 | "
\n", 291 | "\n", 304 | "\n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | "
textsentiment
1Sooo SAD I will miss you here in San Diego!!!negative
2my boss is bullying me...negative
3what interview! leave me alonenegative
4Sons of ****, why couldn`t they put them on t...negative
62am feedings for the baby are fun when he is a...positive
\n", 340 | "
" 341 | ], 342 | "text/plain": [ 343 | " text sentiment\n", 344 | "1 Sooo SAD I will miss you here in San Diego!!! negative\n", 345 | "2 my boss is bullying me... negative\n", 346 | "3 what interview! leave me alone negative\n", 347 | "4 Sons of ****, why couldn`t they put them on t... negative\n", 348 | "6 2am feedings for the baby are fun when he is a... positive" 349 | ] 350 | }, 351 | "execution_count": 81, 352 | "metadata": {}, 353 | "output_type": "execute_result" 354 | } 355 | ], 356 | "source": [ 357 | "df = df[['text','sentiment']]\n", 358 | "df.head()" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 82, 364 | "id": "f34a98e2", 365 | "metadata": { 366 | "ExecuteTime": { 367 | "end_time": "2022-06-07T12:56:06.896776Z", 368 | "start_time": "2022-06-07T12:56:06.877288Z" 369 | } 370 | }, 371 | "outputs": [ 372 | { 373 | "data": { 374 | "text/plain": [ 375 | "16363" 376 | ] 377 | }, 378 | "execution_count": 82, 379 | "metadata": {}, 380 | "output_type": "execute_result" 381 | } 382 | ], 383 | "source": [ 384 | "df['text'].count()" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "id": "22995c40", 390 | "metadata": {}, 391 | "source": [ 392 | "## Preprocessing" 393 | ] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "id": "89835b9d", 398 | "metadata": {}, 399 | "source": [ 400 | "### Lemmatization" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 83, 406 | "id": "1d8ff9a7", 407 | "metadata": { 408 | "ExecuteTime": { 409 | "end_time": "2022-06-07T12:58:23.801740Z", 410 | "start_time": "2022-06-07T12:58:23.792486Z" 411 | } 412 | }, 413 | "outputs": [], 414 | "source": [ 415 | "lemmatizer = WordNetLemmatizer()" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 84, 421 | "id": "8222a654", 422 | "metadata": { 423 | "ExecuteTime": { 424 | "end_time": "2022-06-07T12:58:33.702709Z", 425 | "start_time": "2022-06-07T12:58:33.688971Z" 426 | } 427 | }, 428 | "outputs": [ 
429 | { 430 | "data": { 431 | "text/plain": [ 432 | "'setting'" 433 | ] 434 | }, 435 | "execution_count": 84, 436 | "metadata": {}, 437 | "output_type": "execute_result" 438 | } 439 | ], 440 | "source": [ 441 | "lemmatizer.lemmatize(\"settings\")" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "id": "b631ef75", 447 | "metadata": {}, 448 | "source": [ 449 | "### Stopword Removal" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": 85, 455 | "id": "32254da3", 456 | "metadata": { 457 | "ExecuteTime": { 458 | "end_time": "2022-06-07T12:59:45.630515Z", 459 | "start_time": "2022-06-07T12:59:44.892831Z" 460 | } 461 | }, 462 | "outputs": [], 463 | "source": [ 464 | "en = spacy.load('en_core_web_sm')\n", 465 | "stopwords = en.Defaults.stop_words" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 86, 471 | "id": "1bbfe590", 472 | "metadata": { 473 | "ExecuteTime": { 474 | "end_time": "2022-06-07T13:00:04.741846Z", 475 | "start_time": "2022-06-07T13:00:04.733390Z" 476 | } 477 | }, 478 | "outputs": [ 479 | { 480 | "data": { 481 | "text/plain": [ 482 | "set" 483 | ] 484 | }, 485 | "execution_count": 86, 486 | "metadata": {}, 487 | "output_type": "execute_result" 488 | } 489 | ], 490 | "source": [ 491 | "type(stopwords)" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 87, 497 | "id": "72beacae", 498 | "metadata": { 499 | "ExecuteTime": { 500 | "end_time": "2022-06-07T13:00:06.931762Z", 501 | "start_time": "2022-06-07T13:00:06.926317Z" 502 | } 503 | }, 504 | "outputs": [ 505 | { 506 | "data": { 507 | "text/plain": [ 508 | "326" 509 | ] 510 | }, 511 | "execution_count": 87, 512 | "metadata": {}, 513 | "output_type": "execute_result" 514 | } 515 | ], 516 | "source": [ 517 | "len(stopwords)" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": 88, 523 | "id": "cc2b505b", 524 | "metadata": { 525 | "ExecuteTime": { 526 | "end_time": "2022-06-07T13:00:32.026822Z", 527 
| "start_time": "2022-06-07T13:00:32.019442Z" 528 | } 529 | }, 530 | "outputs": [ 531 | { 532 | "data": { 533 | "text/plain": [ 534 | "['elsewhere',\n", 535 | " 'sixty',\n", 536 | " 'by',\n", 537 | " 'toward',\n", 538 | " 'cannot',\n", 539 | " 'perhaps',\n", 540 | " 'over',\n", 541 | " 'down',\n", 542 | " 'any',\n", 543 | " 'thence']" 544 | ] 545 | }, 546 | "execution_count": 88, 547 | "metadata": {}, 548 | "output_type": "execute_result" 549 | } 550 | ], 551 | "source": [ 552 | "list(stopwords)[:10]" 553 | ] 554 | }, 555 | { 556 | "cell_type": "markdown", 557 | "id": "2819073b", 558 | "metadata": {}, 559 | "source": [ 560 | "### Preprocessing Pipeline" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": 89, 566 | "id": "aefb6702", 567 | "metadata": { 568 | "ExecuteTime": { 569 | "end_time": "2022-06-07T13:03:28.469718Z", 570 | "start_time": "2022-06-07T13:03:28.448570Z" 571 | } 572 | }, 573 | "outputs": [], 574 | "source": [ 575 | "def preprocess(txt):\n", 576 | " txt = txt.lower()\n", 577 | " tokens = txt.split()\n", 578 | " tokens = [lemmatizer.lemmatize(token) for token in tokens]\n", 579 | " txt = ' '.join(tokens)\n", 580 | " txt = txt.translate(str.maketrans('', '', string.punctuation))\n", 581 | " tokens = txt.split()\n", 582 | " tokens = [token for token in tokens if token not in stopwords]\n", 583 | " txt = ' '.join(tokens)\n", 584 | " txt = re.sub(r'[0-9]+', '', txt)\n", 585 | " return txt" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 91, 591 | "id": "4be1bc5b", 592 | "metadata": { 593 | "ExecuteTime": { 594 | "end_time": "2022-06-07T13:04:33.152749Z", 595 | "start_time": "2022-06-07T13:04:33.134307Z" 596 | } 597 | }, 598 | "outputs": [], 599 | "source": [ 600 | "df = df.reset_index() # restore the indices of the dataframe so that it starts from 0 and skips nothing" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 92, 606 | "id": "8cd240ae", 607 | "metadata": { 608 | 
"ExecuteTime": { 609 | "end_time": "2022-06-07T13:04:34.320488Z", 610 | "start_time": "2022-06-07T13:04:34.306446Z" 611 | } 612 | }, 613 | "outputs": [ 614 | { 615 | "name": "stdout", 616 | "output_type": "stream", 617 | "text": [ 618 | "The original text was:\n", 619 | " Well what im working on isn`t QUITE ready to post about publicly (still beta testing) but its a cool new script I coded\n", 620 | " The preprocessed text is: \n", 621 | "im working isnt ready post publicly beta testing cool new script coded\n" 622 | ] 623 | } 624 | ], 625 | "source": [ 626 | "original_txt = df['text'][50]\n", 627 | "processed_txt = preprocess(df['text'][50])\n", 628 | "print(f'The original text was:\\n{original_txt}\\n The preprocessed text is: \\n{processed_txt}')" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": 93, 634 | "id": "c4560071", 635 | "metadata": { 636 | "ExecuteTime": { 637 | "end_time": "2022-06-07T13:06:10.429918Z", 638 | "start_time": "2022-06-07T13:06:09.580485Z" 639 | } 640 | }, 641 | "outputs": [ 642 | { 643 | "data": { 644 | "text/html": [ 645 | "
\n", 646 | "\n", 659 | "\n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | "
indextextsentimentpreprocessed_text
01Sooo SAD I will miss you here in San Diego!!!negativesooo sad miss san diego
12my boss is bullying me...negativebos bullying
23what interview! leave me alonenegativeinterview leave
34Sons of ****, why couldn`t they put them on t...negativeson couldnt release bought
462am feedings for the baby are fun when he is a...positiveam feeding baby fun smile coo
\n", 707 | "
" 708 | ], 709 | "text/plain": [ 710 | " index text sentiment \\\n", 711 | "0 1 Sooo SAD I will miss you here in San Diego!!! negative \n", 712 | "1 2 my boss is bullying me... negative \n", 713 | "2 3 what interview! leave me alone negative \n", 714 | "3 4 Sons of ****, why couldn`t they put them on t... negative \n", 715 | "4 6 2am feedings for the baby are fun when he is a... positive \n", 716 | "\n", 717 | " preprocessed_text \n", 718 | "0 sooo sad miss san diego \n", 719 | "1 bos bullying \n", 720 | "2 interview leave \n", 721 | "3 son couldnt release bought \n", 722 | "4 am feeding baby fun smile coo " 723 | ] 724 | }, 725 | "execution_count": 93, 726 | "metadata": {}, 727 | "output_type": "execute_result" 728 | } 729 | ], 730 | "source": [ 731 | "df['preprocessed_text'] = df['text'].apply(lambda x: preprocess(str(x)))\n", 732 | "df.head()" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": 94, 738 | "id": "4f6d45d2", 739 | "metadata": { 740 | "ExecuteTime": { 741 | "end_time": "2022-06-07T13:06:34.307461Z", 742 | "start_time": "2022-06-07T13:06:34.284354Z" 743 | } 744 | }, 745 | "outputs": [], 746 | "source": [ 747 | "texts = df['preprocessed_text'].to_list()" 748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "execution_count": 95, 753 | "id": "f8ff8ab4", 754 | "metadata": { 755 | "ExecuteTime": { 756 | "end_time": "2022-06-07T13:06:43.595751Z", 757 | "start_time": "2022-06-07T13:06:43.384307Z" 758 | } 759 | }, 760 | "outputs": [], 761 | "source": [ 762 | "# vectorizer = CountVectorizer(max_features=config['max_features'])\n", 763 | "# features = vectorizer.fit_transform(texts)\n", 764 | "vectorizer = TfidfVectorizer(max_features=config['max_features'])\n", 765 | "features = vectorizer.fit_transform(texts)" 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": 96, 771 | "id": "e03cd4b0", 772 | "metadata": { 773 | "ExecuteTime": { 774 | "end_time": "2022-06-07T13:07:54.825763Z", 775 | "start_time": 
"2022-06-07T13:07:54.802178Z" 776 | } 777 | }, 778 | "outputs": [ 779 | { 780 | "data": { 781 | "text/plain": [ 782 | "array(['able', 'absolutely', 'account', 'ache', 'actually', 'add',\n", 783 | " 'afford', 'afraid', 'afternoon', 'age', 'ago', 'agree', 'ah',\n", 784 | " 'ahh', 'ahhh', 'aint', 'air', 'album', 'alot', 'alright', 'am',\n", 785 | " 'amazing', 'america', 'american', 'annoying', 'answer', 'anymore',\n", 786 | " 'anyways', 'apart', 'app', 'apparently', 'apple', 'appreciate',\n", 787 | " 'arent', 'arm', 'art', 'ask', 'asking', 'asleep', 'ate', 'aw',\n", 788 | " 'awake', 'away', 'awesome', 'awful', 'aww', 'awww', 'awwww',\n", 789 | " 'babe', 'baby', 'bad', 'bag', 'ball', 'band', 'bank', 'bar',\n", 790 | " 'barely', 'bbq', 'bc', 'bday', 'beach', 'beat', 'beautiful', 'bed',\n", 791 | " 'beer', 'believe', 'ben', 'best', 'bet', 'better', 'bgt', 'big',\n", 792 | " 'bike', 'bird', 'birthday', 'bit', 'black', 'blackberry', 'blah',\n", 793 | " 'blast', 'bless', 'blessed', 'block', 'blog', 'blood', 'bloody',\n", 794 | " 'blue', 'body', 'boo', 'book', 'bored', 'boring', 'bought', 'bout',\n", 795 | " 'boy', 'boyfriend', 'break', 'breakfast', 'breaking', 'bring',\n", 796 | " 'britain', 'bro', 'broke', 'broken', 'brother', 'brought', 'btw',\n", 797 | " 'bug', 'bummed', 'bummer', 'burn', 'burned', 'burnt', 'bus',\n", 798 | " 'business', 'busy', 'buy', 'bye', 'cake', 'called', 'came',\n", 799 | " 'camera', 'cancelled', 'cant', 'car', 'card', 'care', 'case',\n", 800 | " 'cat', 'catch', 'cause', 'cd', 'chance', 'change', 'chat', 'check',\n", 801 | " 'checked', 'cheer', 'cheese', 'chicken', 'child', 'chill',\n", 802 | " 'chillin', 'chocolate', 'choice', 'church', 'city', 'class',\n", 803 | " 'clean', 'cleaning', 'close', 'clothes', 'club', 'co', 'coffee',\n", 804 | " 'cold', 'college', 'come', 'coming', 'comment', 'company',\n", 805 | " 'completely', 'computer', 'concert', 'congrats', 'congratulation',\n", 806 | " 'congratulations', 'cook', 'cool', 'copy', 'cost', 
'couldnt',\n", 807 | " 'couple', 'course', 'cousin', 'cover', 'coz', 'crazy', 'cream',\n", 808 | " 'cried', 'cry', 'cup', 'currently', 'cut', 'cute', 'cuz', 'da',\n", 809 | " 'dad', 'daddy', 'dammit', 'dance', 'dancing', 'dang', 'dark',\n", 810 | " 'darn', 'date', 'daughter', 'david', 'day', 'days', 'dead', 'deal',\n", 811 | " 'dear', 'death', 'decided', 'def', 'definitely', 'degree',\n", 812 | " 'delicious', 'depressed', 'depressing', 'deserve', 'design',\n", 813 | " 'didnt', 'die', 'died', 'different', 'dinner', 'disappointed',\n", 814 | " 'dnt', 'doe', 'doesnt', 'dog', 'dont', 'dream', 'dress', 'drink',\n", 815 | " 'drinking', 'drive', 'driving', 'drunk', 'dude', 'dvd', 'earlier',\n", 816 | " 'early', 'earth', 'easy', 'eat', 'eating', 'eh', 'em', 'email',\n", 817 | " 'end', 'english', 'enjoy', 'enjoyed', 'enjoying', 'entire', 'epic',\n", 818 | " 'episode', 'especially', 'etc', 'evening', 'event', 'everybody',\n", 819 | " 'everyday', 'exactly', 'exam', 'excellent', 'excited', 'exciting',\n", 820 | " 'exhausted', 'expensive', 'experience', 'extra', 'eye', 'fab',\n", 821 | " 'face', 'facebook', 'fact', 'fail', 'failed', 'fair', 'fall',\n", 822 | " 'falling', 'fam', 'family', 'fan', 'fantastic', 'far', 'fast',\n", 823 | " 'fav', 'fave', 'favorite', 'favourite', 'fb', 'feel', 'feelin',\n", 824 | " 'feeling', 'fell', 'felt', 'fever', 'ff', 'fight', 'figure',\n", 825 | " 'film', 'final', 'finally', 'find', 'fine', 'finger', 'finish',\n", 826 | " 'finished', 'fit', 'fix', 'flight', 'flower', 'flu', 'fly', 'fml',\n", 827 | " 'follow', 'followed', 'follower', 'followers', 'followfriday',\n", 828 | " 'following', 'food', 'foot', 'forever', 'forget', 'forgot',\n", 829 | " 'forward', 'found', 'freakin', 'freaking', 'free', 'french',\n", 830 | " 'fresh', 'friday', 'friend', 'friends', 'fun', 'funny', 'future',\n", 831 | " 'game', 'garden', 'gave', 'germany', 'gettin', 'getting', 'gift',\n", 832 | " 'gig', 'girl', 'girls', 'giving', 'glad', 'god', 'goin', 'going',\n", 833 | " 
'gone', 'gonna', 'good', 'goodbye', 'goodnight', 'google',\n", 834 | " 'gorgeous', 'gosh', 'got', 'gotta', 'graduation', 'great', 'green',\n", 835 | " 'group', 'guess', 'gutted', 'guy', 'guys', 'gym', 'ha', 'haha',\n", 836 | " 'hahah', 'hahaha', 'hahahaha', 'hair', 'half', 'hand', 'hang',\n", 837 | " 'hanging', 'hannah', 'happen', 'happened', 'happens', 'happiness',\n", 838 | " 'happy', 'hard', 'hate', 'havent', 'having', 'head', 'headache',\n", 839 | " 'heading', 'healthy', 'hear', 'heard', 'heart', 'heat', 'hehe',\n", 840 | " 'hell', 'hello', 'help', 'helping', 'hes', 'hey', 'hi', 'high',\n", 841 | " 'hilarious', 'hill', 'history', 'hit', 'hmm', 'holiday', 'home',\n", 842 | " 'homework', 'hope', 'hopefully', 'hoping', 'horrible', 'hospital',\n", 843 | " 'hot', 'hour', 'hours', 'house', 'hr', 'hubby', 'hug', 'huge',\n", 844 | " 'hugs', 'huh', 'hun', 'hungry', 'hurt', 'hurting', 'hurts', 'ice',\n", 845 | " 'id', 'idea', 'idk', 'ill', 'im', 'ima', 'important', 'impressed',\n", 846 | " 'info', 'inside', 'instead', 'interesting', 'internet',\n", 847 | " 'interview', 'iphone', 'ipod', 'isnt', 'itll', 'itunes', 'ive',\n", 848 | " 'iï', 'jealous', 'job', 'john', 'join', 'joke', 'jon', 'jonas',\n", 849 | " 'joy', 'july', 'june', 'jus', 'justin', 'key', 'kick', 'kid',\n", 850 | " 'kids', 'kill', 'killed', 'killing', 'kind', 'kinda', 'kiss',\n", 851 | " 'knee', 'knew', 'know', 'la', 'lady', 'lake', 'lame', 'laptop',\n", 852 | " 'late', 'lately', 'later', 'laugh', 'lay', 'laying', 'lazy', 'le',\n", 853 | " 'learn', 'leave', 'leaving', 'left', 'leg', 'let', 'lets', 'life',\n", 854 | " 'light', 'like', 'liked', 'lil', 'line', 'link', 'list', 'listen',\n", 855 | " 'listening', 'little', 'live', 'living', 'lmao', 'load', 'lol',\n", 856 | " 'london', 'lonely', 'long', 'longer', 'look', 'looked', 'looking',\n", 857 | " 'lose', 'losing', 'lost', 'lot', 'love', 'loved', 'lovely',\n", 858 | " 'loving', 'low', 'luck', 'lucky', 'lunch', 'luv', 'ma', 'mac',\n", 859 | " 'macbook', 'mad', 
'magic', 'making', 'mama', 'man', 'mate', 'math',\n", 860 | " 'matter', 'maybe', 'mean', 'meant', 'meet', 'meeting', 'men',\n", 861 | " 'mention', 'mess', 'message', 'messed', 'met', 'middle', 'mile',\n", 862 | " 'miley', 'min', 'mind', 'minute', 'miss', 'missed', 'missing',\n", 863 | " 'mobile', 'mom', 'moment', 'momma', 'mommy', 'moms', 'monday',\n", 864 | " 'money', 'month', 'months', 'mood', 'moon', 'morning', 'mother',\n", 865 | " 'mothers', 'mouth', 'moved', 'movie', 'moving', 'mr', 'mum',\n", 866 | " 'music', 'nap', 'nd', 'near', 'neck', 'need', 'needed', 'new',\n", 867 | " 'news', 'nice', 'night', 'nite', 'nope', 'nose', 'note', 'number',\n", 868 | " 'ny', 'nyc', 'office', 'officially', 'oh', 'ohh', 'ohhh', 'ok',\n", 869 | " 'okay', 'old', 'omg', 'ones', 'online', 'ooh', 'oops', 'open',\n", 870 | " 'order', 'ouch', 'outside', 'outta', 'packing', 'page', 'pain',\n", 871 | " 'paper', 'parent', 'park', 'party', 'pas', 'past', 'pay', 'pc',\n", 872 | " 'peace', 'people', 'perfect', 'person', 'phone', 'photo', 'pic',\n", 873 | " 'pick', 'picture', 'pink', 'pissed', 'pizza', 'place', 'plan',\n", 874 | " 'play', 'played', 'playing', 'plus', 'pm', 'point', 'pool', 'poor',\n", 875 | " 'post', 'posted', 'power', 'ppl', 'pray', 'present', 'pretty',\n", 876 | " 'prob', 'probably', 'problem', 'productive', 'profile', 'project',\n", 877 | " 'prom', 'proud', 'ps', 'public', 'puppy', 'putting', 'question',\n", 878 | " 'quick', 'quiet', 'quote', 'radio', 'rain', 'raining', 'rainy',\n", 879 | " 'ran', 'read', 'reading', 'ready', 'real', 'realized', 'reason',\n", 880 | " 'red', 'relaxing', 'remember', 'reply', 'rest', 'return', 'review',\n", 881 | " 'revision', 'ride', 'right', 'ring', 'rip', 'rock', 'room',\n", 882 | " 'round', 'run', 'running', 'sad', 'sadly', 'safe', 'said', 'sale',\n", 883 | " 'sam', 'sat', 'saturday', 'save', 'saw', 'saying', 'scared',\n", 884 | " 'scary', 'school', 'screen', 'season', 'second', 'seeing', 'seen',\n", 885 | " 'self', 'send', 'sending', 
'sense', 'sent', 'series', 'seriously',\n", 886 | " 'service', 'session', 'set', 'sexy', 'shall', 'shame', 'share',\n", 887 | " 'sharing', 'shes', 'shift', 'shirt', 'shoe', 'shop', 'shopping',\n", 888 | " 'short', 'shot', 'shouldnt', 'shower', 'showing', 'shut', 'si',\n", 889 | " 'sick', 'sigh', 'sign', 'simple', 'sing', 'singing', 'single',\n", 890 | " 'sister', 'sit', 'site', 'sitting', 'sleep', 'sleeping', 'sleepy',\n", 891 | " 'slept', 'slow', 'small', 'smile', 'snl', 'sold', 'son', 'song',\n", 892 | " 'soo', 'soon', 'sooo', 'soooo', 'sooooo', 'sore', 'sorry', 'sound',\n", 893 | " 'special', 'spend', 'spending', 'spent', 'st', 'stand', 'star',\n", 894 | " 'starbucks', 'start', 'started', 'starting', 'state', 'stay',\n", 895 | " 'stick', 'stomach', 'stop', 'stopped', 'store', 'story', 'strange',\n", 896 | " 'stressed', 'stuck', 'student', 'study', 'studying', 'stuff',\n", 897 | " 'stupid', 'suck', 'sucks', 'summer', 'sun', 'sunday', 'sunny',\n", 898 | " 'sunshine', 'super', 'support', 'supposed', 'sure', 'surgery',\n", 899 | " 'surprise', 'sux', 'swear', 'sweet', 'sweetie', 'swine', 'system',\n", 900 | " 'taken', 'taking', 'talent', 'talk', 'talking', 'taste', 'taylor',\n", 901 | " 'tea', 'team', 'tear', 'tell', 'terrible', 'test', 'text', 'th',\n", 902 | " 'thank', 'thanks', 'thanx', 'thats', 'theres', 'theyre', 'thing',\n", 903 | " 'things', 'think', 'thinking', 'tho', 'thought', 'throat',\n", 904 | " 'thursday', 'thx', 'ticket', 'til', 'till', 'time', 'times',\n", 905 | " 'tired', 'today', 'told', 'tom', 'tomorrow', 'tonight', 'took',\n", 906 | " 'totally', 'touch', 'tour', 'town', 'track', 'traffic', 'train',\n", 907 | " 'travel', 'trek', 'tried', 'trip', 'trouble', 'true', 'truly',\n", 908 | " 'try', 'trying', 'tuesday', 'tummy', 'turn', 'turned', 'tv',\n", 909 | " 'tweet', 'tweetdeck', 'tweeting', 'tweets', 'twilight', 'twitter',\n", 910 | " 'ugh', 'ugly', 'uk', 'understand', 'unfortunately', 'update',\n", 911 | " 'upload', 'upset', 'ur', 'use', 'usually', 
'version', 'video',\n", 912 | " 'visit', 'voice', 'vote', 'wa', 'wait', 'waiting', 'wake',\n", 913 | " 'waking', 'walk', 'walking', 'wanna', 'want', 'wanted', 'war',\n", 914 | " 'warm', 'wasnt', 'watch', 'watched', 'watching', 'water', 'way',\n", 915 | " 'wear', 'wearing', 'weather', 'website', 'wedding', 'week',\n", 916 | " 'weekend', 'weeks', 'weird', 'welcome', 'went', 'whats', 'white',\n", 917 | " 'wife', 'win', 'window', 'wine', 'wish', 'wishing', 'woke',\n", 918 | " 'woman', 'won', 'wonder', 'wonderful', 'wondering', 'wont', 'woo',\n", 919 | " 'word', 'work', 'worked', 'working', 'world', 'worried', 'worry',\n", 920 | " 'worse', 'worst', 'worth', 'wouldnt', 'wow', 'write', 'writing',\n", 921 | " 'wrong', 'wtf', 'xd', 'xoxo', 'ya', 'yall', 'yay', 'yea', 'yeah',\n", 922 | " 'year', 'years', 'yep', 'yes', 'yesterday', 'yo', 'youll', 'young',\n", 923 | " 'youre', 'youtube', 'youve', 'yr', 'yum', 'yummy', '½m', '½s'],\n", 924 | " dtype=object)" 925 | ] 926 | }, 927 | "execution_count": 96, 928 | "metadata": {}, 929 | "output_type": "execute_result" 930 | } 931 | ], 932 | "source": [ 933 | "vectorizer.get_feature_names_out()" 934 | ] 935 | }, 936 | { 937 | "cell_type": "code", 938 | "execution_count": 97, 939 | "id": "84232a1f", 940 | "metadata": { 941 | "ExecuteTime": { 942 | "end_time": "2022-06-07T13:08:36.848531Z", 943 | "start_time": "2022-06-07T13:08:36.766568Z" 944 | } 945 | }, 946 | "outputs": [], 947 | "source": [ 948 | "np_features = features.toarray()" 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": 98, 954 | "id": "50d5d571", 955 | "metadata": { 956 | "ExecuteTime": { 957 | "end_time": "2022-06-07T13:08:38.328208Z", 958 | "start_time": "2022-06-07T13:08:38.320072Z" 959 | } 960 | }, 961 | "outputs": [ 962 | { 963 | "data": { 964 | "text/plain": [ 965 | "(16363, 1000)" 966 | ] 967 | }, 968 | "execution_count": 98, 969 | "metadata": {}, 970 | "output_type": "execute_result" 971 | } 972 | ], 973 | "source": [ 974 | 
"np_features.shape" 975 | ] 976 | }, 977 | { 978 | "cell_type": "markdown", 979 | "id": "6e2e3479", 980 | "metadata": {}, 981 | "source": [ 982 | "Next, we need numeric labels: the `sentiment` column is mapped to `0` for `negative` and `1` for `positive`." 983 | ] 984 | }, 985 | { 986 | "cell_type": "code", 987 | "execution_count": 99, 988 | "id": "dfa5b4f1", 989 | "metadata": { 990 | "ExecuteTime": { 991 | "end_time": "2022-06-07T13:09:21.017350Z", 992 | "start_time": "2022-06-07T13:09:20.981632Z" 993 | } 994 | }, 995 | "outputs": [ 996 | { 997 | "data": { 998 | "text/html": [ 999 | "
\n", 1000 | "\n", 1013 | "\n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | "
indextextsentimentpreprocessed_textnum_sentiment
01Sooo SAD I will miss you here in San Diego!!!negativesooo sad miss san diego0
12my boss is bullying me...negativebos bullying0
23what interview! leave me alonenegativeinterview leave0
34Sons of ****, why couldn`t they put them on t...negativeson couldnt release bought0
462am feedings for the baby are fun when he is a...positiveam feeding baby fun smile coo1
\n", 1067 | "
" 1068 | ], 1069 | "text/plain": [ 1070 | " index text sentiment \\\n", 1071 | "0 1 Sooo SAD I will miss you here in San Diego!!! negative \n", 1072 | "1 2 my boss is bullying me... negative \n", 1073 | "2 3 what interview! leave me alone negative \n", 1074 | "3 4 Sons of ****, why couldn`t they put them on t... negative \n", 1075 | "4 6 2am feedings for the baby are fun when he is a... positive \n", 1076 | "\n", 1077 | " preprocessed_text num_sentiment \n", 1078 | "0 sooo sad miss san diego 0 \n", 1079 | "1 bos bullying 0 \n", 1080 | "2 interview leave 0 \n", 1081 | "3 son couldnt release bought 0 \n", 1082 | "4 am feeding baby fun smile coo 1 " 1083 | ] 1084 | }, 1085 | "execution_count": 99, 1086 | "metadata": {}, 1087 | "output_type": "execute_result" 1088 | } 1089 | ], 1090 | "source": [ 1091 | "df['num_sentiment'] = df['sentiment'].apply(lambda x: 0 if x == 'negative' else 1)\n", 1092 | "df.head()" 1093 | ] 1094 | }, 1095 | { 1096 | "cell_type": "code", 1097 | "execution_count": 100, 1098 | "id": "d7c9d5ca", 1099 | "metadata": { 1100 | "ExecuteTime": { 1101 | "end_time": "2022-06-07T13:09:30.734110Z", 1102 | "start_time": "2022-06-07T13:09:30.729289Z" 1103 | } 1104 | }, 1105 | "outputs": [], 1106 | "source": [ 1107 | "labels = df['num_sentiment'].to_list()" 1108 | ] 1109 | }, 1110 | { 1111 | "cell_type": "code", 1112 | "execution_count": 101, 1113 | "id": "b257eefa", 1114 | "metadata": { 1115 | "ExecuteTime": { 1116 | "end_time": "2022-06-07T13:09:33.567310Z", 1117 | "start_time": "2022-06-07T13:09:33.559989Z" 1118 | } 1119 | }, 1120 | "outputs": [ 1121 | { 1122 | "data": { 1123 | "text/plain": [ 1124 | "[0, 0, 0, 0, 1]" 1125 | ] 1126 | }, 1127 | "execution_count": 101, 1128 | "metadata": {}, 1129 | "output_type": "execute_result" 1130 | } 1131 | ], 1132 | "source": [ 1133 | "labels[:5]" 1134 | ] 1135 | }, 1136 | { 1137 | "cell_type": "markdown", 1138 | "id": "071a7cf9", 1139 | "metadata": {}, 1140 | "source": [ 1141 | "## Train / Test / Dev Split" 1142 | ] 
1143 | }, 1144 | { 1145 | "cell_type": "code", 1146 | "execution_count": 102, 1147 | "id": "1ba0006a", 1148 | "metadata": { 1149 | "ExecuteTime": { 1150 | "end_time": "2022-06-07T13:12:31.888186Z", 1151 | "start_time": "2022-06-07T13:12:31.589910Z" 1152 | } 1153 | }, 1154 | "outputs": [], 1155 | "source": [ 1156 | "f_train, f_rem, l_train, l_rem = train_test_split(np_features, labels, test_size=1-config['train_percentage']/100, random_state=50)\n", 1157 | "f_test, f_dev, l_test, l_dev = train_test_split(f_rem, l_rem, test_size=0.5, random_state=50)" 1158 | ] 1159 | }, 1160 | { 1161 | "cell_type": "code", 1162 | "execution_count": 103, 1163 | "id": "55f1275e", 1164 | "metadata": { 1165 | "ExecuteTime": { 1166 | "end_time": "2022-06-07T13:12:33.276538Z", 1167 | "start_time": "2022-06-07T13:12:33.271927Z" 1168 | } 1169 | }, 1170 | "outputs": [ 1171 | { 1172 | "name": "stdout", 1173 | "output_type": "stream", 1174 | "text": [ 1175 | "train features: (14726, 1000), dev features: (819, 1000), test features: (818, 1000)\n" 1176 | ] 1177 | } 1178 | ], 1179 | "source": [ 1180 | "print(f'train features: {f_train.shape}, dev features: {f_dev.shape}, test features: {f_test.shape}')" 1181 | ] 1182 | }, 1183 | { 1184 | "cell_type": "code", 1185 | "execution_count": 104, 1186 | "id": "8d226c3b", 1187 | "metadata": { 1188 | "ExecuteTime": { 1189 | "end_time": "2022-06-07T13:12:51.450352Z", 1190 | "start_time": "2022-06-07T13:12:51.437967Z" 1191 | } 1192 | }, 1193 | "outputs": [ 1194 | { 1195 | "name": "stdout", 1196 | "output_type": "stream", 1197 | "text": [ 1198 | "train labels: 14726, dev labels: 819, test labels: 818\n" 1199 | ] 1200 | } 1201 | ], 1202 | "source": [ 1203 | "print(f'train labels: {len(l_train)}, dev labels: {len(l_dev)}, test labels: {len(l_test)}')" 1204 | ] 1205 | }, 1206 | { 1207 | "cell_type": "markdown", 1208 | "id": "4bd2ef51", 1209 | "metadata": {}, 1210 | "source": [ 1211 | "## Converting Everything to Tensors" 1212 | ] 1213 | }, 1214 | { 1215 | 
"cell_type": "markdown", 1216 | "id": "b45678ad", 1217 | "metadata": {}, 1218 | "source": [ 1219 | "The numpy array we defined above should be converted to a tensor. This tensor will be used in a \"Dataset\" object." 1220 | ] 1221 | }, 1222 | { 1223 | "cell_type": "code", 1224 | "execution_count": 105, 1225 | "id": "6b08c1fe", 1226 | "metadata": { 1227 | "ExecuteTime": { 1228 | "end_time": "2022-06-07T13:14:40.849936Z", 1229 | "start_time": "2022-06-07T13:14:40.837745Z" 1230 | } 1231 | }, 1232 | "outputs": [], 1233 | "source": [ 1234 | "class MyVectorDataset(Dataset):\n", 1235 | " def __init__(self, features, labels):\n", 1236 | " self.features = features\n", 1237 | " self.labels = np.array(labels).reshape(-1, 1)\n", 1238 | " def __len__(self):\n", 1239 | " return self.features.shape[0]\n", 1240 | " def __getitem__(self, idx):\n", 1241 | " return torch.Tensor(self.features[idx]), torch.Tensor(self.labels[idx])" 1242 | ] 1243 | }, 1244 | { 1245 | "cell_type": "code", 1246 | "execution_count": 106, 1247 | "id": "80e62889", 1248 | "metadata": { 1249 | "ExecuteTime": { 1250 | "end_time": "2022-06-07T13:14:54.781392Z", 1251 | "start_time": "2022-06-07T13:14:54.775362Z" 1252 | } 1253 | }, 1254 | "outputs": [], 1255 | "source": [ 1256 | "train_dataset = MyVectorDataset(f_train, l_train)\n", 1257 | "test_dataset = MyVectorDataset(f_test, l_test)\n", 1258 | "dev_dataset = MyVectorDataset(f_dev, l_dev)" 1259 | ] 1260 | }, 1261 | { 1262 | "cell_type": "code", 1263 | "execution_count": 107, 1264 | "id": "7e443816", 1265 | "metadata": { 1266 | "ExecuteTime": { 1267 | "end_time": "2022-06-07T13:15:09.393343Z", 1268 | "start_time": "2022-06-07T13:15:09.378632Z" 1269 | } 1270 | }, 1271 | "outputs": [], 1272 | "source": [ 1273 | "train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)\n", 1274 | "test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)\n", 1275 | "dev_dataloader = DataLoader(dev_dataset, 
batch_size=config['batch_size'], shuffle=True)" 1276 | ] 1277 | }, 1278 | { 1279 | "cell_type": "markdown", 1280 | "id": "c5006d18", 1281 | "metadata": {}, 1282 | "source": [ 1283 | "# Neural Net Architecture" 1284 | ] 1285 | }, 1286 | { 1287 | "cell_type": "code", 1288 | "execution_count": 108, 1289 | "id": "070b61b4", 1290 | "metadata": { 1291 | "ExecuteTime": { 1292 | "end_time": "2022-06-07T13:15:46.833065Z", 1293 | "start_time": "2022-06-07T13:15:46.817936Z" 1294 | } 1295 | }, 1296 | "outputs": [ 1297 | { 1298 | "data": { 1299 | "text/plain": [ 1300 | "'cpu'" 1301 | ] 1302 | }, 1303 | "execution_count": 108, 1304 | "metadata": {}, 1305 | "output_type": "execute_result" 1306 | } 1307 | ], 1308 | "source": [ 1309 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 1310 | "device" 1311 | ] 1312 | }, 1313 | { 1314 | "cell_type": "code", 1315 | "execution_count": 109, 1316 | "id": "e9bdad21", 1317 | "metadata": { 1318 | "ExecuteTime": { 1319 | "end_time": "2022-06-07T13:17:43.987486Z", 1320 | "start_time": "2022-06-07T13:17:43.974767Z" 1321 | } 1322 | }, 1323 | "outputs": [], 1324 | "source": [ 1325 | "class my_neural_net(torch.nn.Module):\n", 1326 | " def __init__(self):\n", 1327 | " super(my_neural_net, self).__init__() \n", 1328 | " self.first_layer = torch.nn.Sequential( \n", 1329 | " nn.Linear(config['max_features'], 1),\n", 1330 | " nn.Sigmoid()\n", 1331 | " )\n", 1332 | " def forward(self, x):\n", 1333 | " output = self.first_layer(x)\n", 1334 | " return output" 1335 | ] 1336 | }, 1337 | { 1338 | "cell_type": "code", 1339 | "execution_count": 110, 1340 | "id": "01e8f9a4", 1341 | "metadata": { 1342 | "ExecuteTime": { 1343 | "end_time": "2022-06-07T13:17:46.903195Z", 1344 | "start_time": "2022-06-07T13:17:46.874879Z" 1345 | } 1346 | }, 1347 | "outputs": [], 1348 | "source": [ 1349 | "simple_nn = my_neural_net()" 1350 | ] 1351 | }, 1352 | { 1353 | "cell_type": "code", 1354 | "execution_count": 111, 1355 | "id": "3d469445", 1356 | "metadata": { 1357 | 
"ExecuteTime": { 1358 | "end_time": "2022-06-07T13:18:01.254913Z", 1359 | "start_time": "2022-06-07T13:18:01.235635Z" 1360 | } 1361 | }, 1362 | "outputs": [], 1363 | "source": [ 1364 | "simple_nn = simple_nn.to(device)" 1365 | ] 1366 | }, 1367 | { 1368 | "cell_type": "code", 1369 | "execution_count": 112, 1370 | "id": "eafc60c2", 1371 | "metadata": { 1372 | "ExecuteTime": { 1373 | "end_time": "2022-06-07T13:18:27.119810Z", 1374 | "start_time": "2022-06-07T13:18:27.032854Z" 1375 | } 1376 | }, 1377 | "outputs": [ 1378 | { 1379 | "data": { 1380 | "text/plain": [ 1381 | "tensor([[0.4972],\n", 1382 | " [0.4993]], grad_fn=)" 1383 | ] 1384 | }, 1385 | "execution_count": 112, 1386 | "metadata": {}, 1387 | "output_type": "execute_result" 1388 | } 1389 | ], 1390 | "source": [ 1391 | "simple_nn(train_dataset[:2][0])" 1392 | ] 1393 | }, 1394 | { 1395 | "cell_type": "code", 1396 | "execution_count": 113, 1397 | "id": "987951b1", 1398 | "metadata": { 1399 | "ExecuteTime": { 1400 | "end_time": "2022-06-07T13:18:41.231088Z", 1401 | "start_time": "2022-06-07T13:18:41.215579Z" 1402 | } 1403 | }, 1404 | "outputs": [ 1405 | { 1406 | "data": { 1407 | "text/plain": [ 1408 | "torch.Size([2, 1])" 1409 | ] 1410 | }, 1411 | "execution_count": 113, 1412 | "metadata": {}, 1413 | "output_type": "execute_result" 1414 | } 1415 | ], 1416 | "source": [ 1417 | "simple_nn(train_dataset[:2][0]).shape" 1418 | ] 1419 | }, 1420 | { 1421 | "cell_type": "markdown", 1422 | "id": "2904153e", 1423 | "metadata": {}, 1424 | "source": [ 1425 | "# Training" 1426 | ] 1427 | }, 1428 | { 1429 | "cell_type": "markdown", 1430 | "id": "afa42b77", 1431 | "metadata": {}, 1432 | "source": [ 1433 | "## Binary Cross-Entropy" 1434 | ] 1435 | }, 1436 | { 1437 | "cell_type": "markdown", 1438 | "id": "68284027", 1439 | "metadata": {}, 1440 | "source": [ 1441 | "For each point, the loss is calculated like this: $l_n = -w_n[y_n.\\log(\\mathrm{pred}_n)+(1-y_n).\\log(1-\\mathrm{pred}_n)]$ where $w_n$ is a rescaling factor
" 1442 | ] 1443 | }, 1444 | { 1445 | "cell_type": "markdown", 1446 | "id": "66b86a65", 1447 | "metadata": {}, 1448 | "source": [ 1449 | "Assume that $w_n=1$" 1450 | ] 1451 | }, 1452 | { 1453 | "cell_type": "markdown", 1454 | "id": "92294e91", 1455 | "metadata": {}, 1456 | "source": [ 1457 | "If $y_n=0$ and $\\mathrm{pred}_n=1$, then we'll have $l_n=-w_n(0.log(1)+1.log(0))=-w_n(0.log(1)+1.-\\infty)=+\\infty$" 1458 | ] 1459 | }, 1460 | { 1461 | "cell_type": "markdown", 1462 | "id": "05728e3f", 1463 | "metadata": {}, 1464 | "source": [ 1465 | "If $y_n=0$ and $\\mathrm{pred}_n=0.1$, then we'll have $l_n=-w_n(0.log(0.1)+1.log(0.9))=-(-0.04)=0.04$" 1466 | ] 1467 | }, 1468 | { 1469 | "cell_type": "markdown", 1470 | "id": "32c16527", 1471 | "metadata": {}, 1472 | "source": [ 1473 | "If $y_n=0$ and $\\mathrm{pred}_n=0.9$, then we'll have $l_n=-w_n(0.log(0.9)+1.log(0.1))=-(-1)=1$" 1474 | ] 1475 | }, 1476 | { 1477 | "cell_type": "code", 1478 | "execution_count": 114, 1479 | "id": "c86608dd", 1480 | "metadata": { 1481 | "ExecuteTime": { 1482 | "end_time": "2022-06-07T13:21:25.439043Z", 1483 | "start_time": "2022-06-07T13:21:25.428685Z" 1484 | } 1485 | }, 1486 | "outputs": [], 1487 | "source": [ 1488 | "loss_fn = nn.BCELoss()" 1489 | ] 1490 | }, 1491 | { 1492 | "cell_type": "markdown", 1493 | "id": "3e9b5ddc", 1494 | "metadata": {}, 1495 | "source": [ 1496 | "## Optimizer" 1497 | ] 1498 | }, 1499 | { 1500 | "cell_type": "markdown", 1501 | "id": "518ac0cf", 1502 | "metadata": {}, 1503 | "source": [ 1504 | "Note that stochastic gradient descent performs a parameter update for each training example $x_i$ and label $y_i$" 1505 | ] 1506 | }, 1507 | { 1508 | "cell_type": "markdown", 1509 | "id": "484ba024", 1510 | "metadata": {}, 1511 | "source": [ 1512 | "$\\theta = \\theta - \\eta.\\nabla_\\theta J(\\theta;x_i;y_i)$" 1513 | ] 1514 | }, 1515 | { 1516 | "cell_type": "code", 1517 | "execution_count": 115, 1518 | "id": "5f6fcd5a", 1519 | "metadata": { 1520 | "ExecuteTime": { 1521 | 
"end_time": "2022-06-07T13:23:13.579253Z", 1522 | "start_time": "2022-06-07T13:23:13.565666Z" 1523 | } 1524 | }, 1525 | "outputs": [], 1526 | "source": [ 1527 | "optimizer = torch.optim.SGD(simple_nn.parameters(), lr=config['learning_rate'])" 1528 | ] 1529 | }, 1530 | { 1531 | "cell_type": "code", 1532 | "execution_count": 116, 1533 | "id": "95de2f05", 1534 | "metadata": { 1535 | "ExecuteTime": { 1536 | "end_time": "2022-06-07T13:24:13.074906Z", 1537 | "start_time": "2022-06-07T13:24:13.063585Z" 1538 | } 1539 | }, 1540 | "outputs": [], 1541 | "source": [ 1542 | "def output_to_label(out):\n", 1543 | " dist_to_0 = abs(out)\n", 1544 | " dist_to_1 = abs(out-1)\n", 1545 | " if dist_to_0 <= dist_to_1:\n", 1546 | " return 0\n", 1547 | " else:\n", 1548 | " return 1" 1549 | ] 1550 | }, 1551 | { 1552 | "cell_type": "code", 1553 | "execution_count": 117, 1554 | "id": "bbd5fc91", 1555 | "metadata": { 1556 | "ExecuteTime": { 1557 | "end_time": "2022-06-07T13:31:46.646317Z", 1558 | "start_time": "2022-06-07T13:31:46.616928Z" 1559 | } 1560 | }, 1561 | "outputs": [], 1562 | "source": [ 1563 | "def train_loop(dataloader, model, loss_fn, optimizer, epoch_num):\n", 1564 | " num_points = len(dataloader.dataset)\n", 1565 | " for batch, (features, labels) in enumerate(dataloader): \n", 1566 | " # Compute prediction and loss\n", 1567 | " pred = model(features)\n", 1568 | " loss = loss_fn(pred, labels)\n", 1569 | " \n", 1570 | " # Backpropagation\n", 1571 | " optimizer.zero_grad() # sets gradients of all model parameters to zero\n", 1572 | " loss.backward() # calculate the gradients again\n", 1573 | " optimizer.step() # w = w - learning_rate * grad(loss)_with_respect_to_w\n", 1574 | "\n", 1575 | " if batch % 100 == 0:\n", 1576 | " loss, current = loss.item(), batch * len(features)\n", 1577 | " print(f\"\\r Epoch {epoch_num} - loss: {loss:>7f} [{current:>5d}/{num_points:>5d}]\", end=\" \")\n", 1578 | "\n", 1579 | "\n", 1580 | "def test_loop(dataloader, model, loss_fn, epoch_num, name):\n", 
1581 | " num_points = len(dataloader.dataset)\n", 1582 | " sum_test_loss, correct = 0, 0\n", 1583 | "\n", 1584 | " with torch.no_grad():\n", 1585 | " for batch, (features, labels) in enumerate(dataloader):\n", 1586 | " pred = model(features)\n", 1587 | " sum_test_loss += loss_fn(pred, labels).item() # add the current loss to the sum of the losses\n", 1588 | " # convert the outputs of the model on the current batch to a numpy array\n", 1589 | " pred_lst = list(pred.numpy().squeeze())\n", 1590 | " pred_lst = [output_to_label(item) for item in pred_lst]\n", 1591 | " # convert the original labels corresponding to the current batch to a numpy array\n", 1592 | " output_lst = list(labels.numpy().squeeze()) \n", 1593 | " # determine the points for which the model is correctly predicting the label (add a 1 for each)\n", 1594 | " match_lst = [1 if p==o else 0 for (p, o) in zip(pred_lst, output_lst)] \n", 1595 | " # count how many points are labeled correctly in this batch and add the number to the overall count of the correct labeled points\n", 1596 | " correct += sum(match_lst) \n", 1597 | " \n", 1598 | " sum_test_loss /= num_points\n", 1599 | " correct /= num_points\n", 1600 | " print(f\"\\r Epoch {epoch_num} - {name} Error: Accuracy: {(100*correct):>0.1f}%, Avg loss: {sum_test_loss:>8f}\", end=\" \")" 1601 | ] 1602 | }, 1603 | { 1604 | "cell_type": "code", 1605 | "execution_count": 118, 1606 | "id": "2956a42e", 1607 | "metadata": { 1608 | "ExecuteTime": { 1609 | "end_time": "2022-06-07T13:33:38.883705Z", 1610 | "start_time": "2022-06-07T13:31:56.079832Z" 1611 | } 1612 | }, 1613 | "outputs": [ 1614 | { 1615 | "name": "stdout", 1616 | "output_type": "stream", 1617 | "text": [ 1618 | " Epoch 200 - Development/Validation Error: Accuracy: 86.1%, Avg loss: 0.010535 " 1619 | ] 1620 | } 1621 | ], 1622 | "source": [ 1623 | "for epoch_num in range(1, config['num_epochs']+1):\n", 1624 | " train_loop(train_dataloader, simple_nn, loss_fn, optimizer, epoch_num)\n", 1625 | " 
test_loop(dev_dataloader, simple_nn, loss_fn, epoch_num, 'Development/Validation')" 1626 | ] 1627 | }, 1628 | { 1629 | "cell_type": "code", 1630 | "execution_count": 119, 1631 | "id": "7cc97521", 1632 | "metadata": { 1633 | "ExecuteTime": { 1634 | "end_time": "2022-06-07T13:33:38.958639Z", 1635 | "start_time": "2022-06-07T13:33:38.886807Z" 1636 | } 1637 | }, 1638 | "outputs": [ 1639 | { 1640 | "name": "stdout", 1641 | "output_type": "stream", 1642 | "text": [ 1643 | "\r", 1644 | " Epoch 200 - Test Error: Accuracy: 83.7%, Avg loss: 0.011854 " 1645 | ] 1646 | } 1647 | ], 1648 | "source": [ 1649 | "test_loop(test_dataloader, simple_nn, loss_fn, epoch_num, 'Test')" 1650 | ] 1651 | }, 1652 | { 1653 | "cell_type": "markdown", 1654 | "id": "ee1517b0", 1655 | "metadata": {}, 1656 | "source": [ 1657 | "# Saving the Model" 1658 | ] 1659 | }, 1660 | { 1661 | "cell_type": "code", 1662 | "execution_count": 120, 1663 | "id": "3f0c4405", 1664 | "metadata": { 1665 | "ExecuteTime": { 1666 | "end_time": "2022-06-07T13:34:27.534587Z", 1667 | "start_time": "2022-06-07T13:34:27.507954Z" 1668 | } 1669 | }, 1670 | "outputs": [], 1671 | "source": [ 1672 | "torch.save(simple_nn.state_dict(), \"neural_net.pth\")" 1673 | ] 1674 | }, 1675 | { 1676 | "cell_type": "markdown", 1677 | "id": "1e4def5d", 1678 | "metadata": {}, 1679 | "source": [ 1680 | "# Load the Model" 1681 | ] 1682 | }, 1683 | { 1684 | "cell_type": "code", 1685 | "execution_count": 121, 1686 | "id": "b9feed95", 1687 | "metadata": { 1688 | "ExecuteTime": { 1689 | "end_time": "2022-06-07T13:35:28.453060Z", 1690 | "start_time": "2022-06-07T13:35:28.411372Z" 1691 | } 1692 | }, 1693 | "outputs": [ 1694 | { 1695 | "data": { 1696 | "text/plain": [ 1697 | "my_neural_net(\n", 1698 | " (first_layer): Sequential(\n", 1699 | " (0): Linear(in_features=1000, out_features=1, bias=True)\n", 1700 | " (1): Sigmoid()\n", 1701 | " )\n", 1702 | ")" 1703 | ] 1704 | }, 1705 | "execution_count": 121, 1706 | "metadata": {}, 1707 | "output_type": 
"execute_result" 1708 | } 1709 | ], 1710 | "source": [ 1711 | "model = my_neural_net()\n", 1712 | "model.load_state_dict(torch.load(\"neural_net.pth\"))\n", 1713 | "model.eval() # use this line if you have Dropout and BatchNormalization layers in your model" 1714 | ] 1715 | }, 1716 | { 1717 | "cell_type": "code", 1718 | "execution_count": 122, 1719 | "id": "900fd5ef", 1720 | "metadata": { 1721 | "ExecuteTime": { 1722 | "end_time": "2022-06-07T13:35:33.428980Z", 1723 | "start_time": "2022-06-07T13:35:33.412855Z" 1724 | } 1725 | }, 1726 | "outputs": [ 1727 | { 1728 | "data": { 1729 | "text/plain": [ 1730 | "tensor([[0.1337],\n", 1731 | " [0.9956]], grad_fn=)" 1732 | ] 1733 | }, 1734 | "execution_count": 122, 1735 | "metadata": {}, 1736 | "output_type": "execute_result" 1737 | } 1738 | ], 1739 | "source": [ 1740 | "model(test_dataset[:2][0])" 1741 | ] 1742 | }, 1743 | { 1744 | "cell_type": "code", 1745 | "execution_count": 123, 1746 | "id": "c2659ab6", 1747 | "metadata": { 1748 | "ExecuteTime": { 1749 | "end_time": "2022-06-07T13:35:41.701427Z", 1750 | "start_time": "2022-06-07T13:35:41.690160Z" 1751 | } 1752 | }, 1753 | "outputs": [ 1754 | { 1755 | "data": { 1756 | "text/plain": [ 1757 | "[0, 1]" 1758 | ] 1759 | }, 1760 | "execution_count": 123, 1761 | "metadata": {}, 1762 | "output_type": "execute_result" 1763 | } 1764 | ], 1765 | "source": [ 1766 | "l_test[:2]" 1767 | ] 1768 | }, 1769 | { 1770 | "cell_type": "markdown", 1771 | "id": "fe0df351", 1772 | "metadata": {}, 1773 | "source": [ 1774 | "# The ONNX Format" 1775 | ] 1776 | }, 1777 | { 1778 | "cell_type": "markdown", 1779 | "id": "c2e11893", 1780 | "metadata": {}, 1781 | "source": [ 1782 | "This format is useful when you want to use your model while coding in Java, Javascript, and C#!" 
1783 | ] 1784 | }, 1785 | { 1786 | "cell_type": "markdown", 1787 | "id": "3c98b040", 1788 | "metadata": {}, 1789 | "source": [ 1790 | "## Save the Model" 1791 | ] 1792 | }, 1793 | { 1794 | "cell_type": "code", 1795 | "execution_count": 124, 1796 | "id": "0e75ced9", 1797 | "metadata": { 1798 | "ExecuteTime": { 1799 | "end_time": "2022-06-07T13:36:49.073992Z", 1800 | "start_time": "2022-06-07T13:36:49.047633Z" 1801 | } 1802 | }, 1803 | "outputs": [], 1804 | "source": [ 1805 | "dummy_input = torch.zeros((1,config['max_features']))" 1806 | ] 1807 | }, 1808 | { 1809 | "cell_type": "code", 1810 | "execution_count": 125, 1811 | "id": "3e259531", 1812 | "metadata": { 1813 | "ExecuteTime": { 1814 | "end_time": "2022-06-07T13:37:00.073765Z", 1815 | "start_time": "2022-06-07T13:36:59.889922Z" 1816 | } 1817 | }, 1818 | "outputs": [], 1819 | "source": [ 1820 | "onnx.export(model, dummy_input, 'neural_net.onnx')" 1821 | ] 1822 | }, 1823 | { 1824 | "cell_type": "markdown", 1825 | "id": "ab6d5742", 1826 | "metadata": {}, 1827 | "source": [ 1828 | "## Inference" 1829 | ] 1830 | }, 1831 | { 1832 | "cell_type": "code", 1833 | "execution_count": 126, 1834 | "id": "d9639473", 1835 | "metadata": { 1836 | "ExecuteTime": { 1837 | "end_time": "2022-06-07T13:37:26.839286Z", 1838 | "start_time": "2022-06-07T13:37:26.756831Z" 1839 | } 1840 | }, 1841 | "outputs": [], 1842 | "source": [ 1843 | "session = onnxruntime.InferenceSession('neural_net.onnx', None) # None: we want all of the outputs" 1844 | ] 1845 | }, 1846 | { 1847 | "cell_type": "code", 1848 | "execution_count": 127, 1849 | "id": "b9ae9c59", 1850 | "metadata": { 1851 | "ExecuteTime": { 1852 | "end_time": "2022-06-07T13:37:32.213835Z", 1853 | "start_time": "2022-06-07T13:37:32.198022Z" 1854 | } 1855 | }, 1856 | "outputs": [], 1857 | "source": [ 1858 | "input_name = session.get_inputs()[0].name\n", 1859 | "output_name = session.get_outputs()[0].name" 1860 | ] 1861 | }, 1862 | { 1863 | "cell_type": "code", 1864 | "execution_count": 128, 
1865 | "id": "97f068c7", 1866 | "metadata": { 1867 | "ExecuteTime": { 1868 | "end_time": "2022-06-07T13:37:34.609760Z", 1869 | "start_time": "2022-06-07T13:37:34.599874Z" 1870 | } 1871 | }, 1872 | "outputs": [ 1873 | { 1874 | "data": { 1875 | "text/plain": [ 1876 | "'onnx::Gemm_0'" 1877 | ] 1878 | }, 1879 | "execution_count": 128, 1880 | "metadata": {}, 1881 | "output_type": "execute_result" 1882 | } 1883 | ], 1884 | "source": [ 1885 | "input_name" 1886 | ] 1887 | }, 1888 | { 1889 | "cell_type": "code", 1890 | "execution_count": 129, 1891 | "id": "5104f885", 1892 | "metadata": { 1893 | "ExecuteTime": { 1894 | "end_time": "2022-06-07T13:37:35.333370Z", 1895 | "start_time": "2022-06-07T13:37:35.326109Z" 1896 | } 1897 | }, 1898 | "outputs": [ 1899 | { 1900 | "data": { 1901 | "text/plain": [ 1902 | "'4'" 1903 | ] 1904 | }, 1905 | "execution_count": 129, 1906 | "metadata": {}, 1907 | "output_type": "execute_result" 1908 | } 1909 | ], 1910 | "source": [ 1911 | "output_name" 1912 | ] 1913 | }, 1914 | { 1915 | "cell_type": "code", 1916 | "execution_count": 130, 1917 | "id": "60e0d09d", 1918 | "metadata": { 1919 | "ExecuteTime": { 1920 | "end_time": "2022-06-07T13:38:21.799264Z", 1921 | "start_time": "2022-06-07T13:38:21.764289Z" 1922 | } 1923 | }, 1924 | "outputs": [], 1925 | "source": [ 1926 | "result = session.run([output_name], {input_name: test_dataset[0][0].numpy().reshape(1,-1)})" 1927 | ] 1928 | }, 1929 | { 1930 | "cell_type": "code", 1931 | "execution_count": 131, 1932 | "id": "881770f9", 1933 | "metadata": { 1934 | "ExecuteTime": { 1935 | "end_time": "2022-06-07T13:38:23.073119Z", 1936 | "start_time": "2022-06-07T13:38:23.051015Z" 1937 | } 1938 | }, 1939 | "outputs": [ 1940 | { 1941 | "data": { 1942 | "text/plain": [ 1943 | "[array([[0.1336531]], dtype=float32)]" 1944 | ] 1945 | }, 1946 | "execution_count": 131, 1947 | "metadata": {}, 1948 | "output_type": "execute_result" 1949 | } 1950 | ], 1951 | "source": [ 1952 | "result" 1953 | ] 1954 | } 1955 | ], 1956 | 
"metadata": { 1957 | "kernelspec": { 1958 | "display_name": "Python [conda env:torch] *", 1959 | "language": "python", 1960 | "name": "conda-env-torch-py" 1961 | }, 1962 | "language_info": { 1963 | "codemirror_mode": { 1964 | "name": "ipython", 1965 | "version": 3 1966 | }, 1967 | "file_extension": ".py", 1968 | "mimetype": "text/x-python", 1969 | "name": "python", 1970 | "nbconvert_exporter": "python", 1971 | "pygments_lexer": "ipython3", 1972 | "version": "3.9.12" 1973 | } 1974 | }, 1975 | "nbformat": 4, 1976 | "nbformat_minor": 5 1977 | } 1978 | -------------------------------------------------------------------------------- /Pytorch-2-Datasets.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Pytorch-2-Datasets.ipynb","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyP4amUSuV/jfm6UQ0anPkIt"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"f8169018ae794abc94ab196d8125bf61":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_ce3bc39cc1b24ceba190382443f15672","IPY_MODEL_addc57f1484147b1b1330354dffab93c","IPY_MODEL_13181970177f43e9a5189e9e06d78490"],"layout":"IPY_MODEL_faf6d3b8c95b41468f8537b8c086bd93"}},"ce3bc39cc1b24ceba190382443f15672":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","
_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_745fe054690a4c8f8893e7569ee16575","placeholder":"​","style":"IPY_MODEL_3482fd34d144442cb44c123ce743018f","value":""}},"addc57f1484147b1b1330354dffab93c":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_244e1fb794ea4e0a8355b545b132421f","max":9912422,"min":0,"orientation":"horizontal","style":"IPY_MODEL_a32baf002d8e41b1ac499cf851f8b2d1","value":9912422}},"13181970177f43e9a5189e9e06d78490":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_79dc2dbf4bf445b6b1671e23f9e4eb95","placeholder":"​","style":"IPY_MODEL_923c5621509d4353aa223c7be8a5894a","value":" 9913344/? 
[00:00<00:00, 17356071.71it/s]"}},"faf6d3b8c95b41468f8537b8c086bd93":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"745fe054690a4c8f8893e7569ee16575":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":nul
l,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"3482fd34d144442cb44c123ce743018f":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"244e1fb794ea4e0a8355b545b132421f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"a32baf002d8e41b1ac499cf851f8b2d1":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"79dc2dbf4bf445b6b1671e
23f9e4eb95":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"923c5621509d4353aa223c7be8a5894a":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"20a781dc8b4c4a358f3c17a20bf9f06e":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_0b36465265be4e909a66748ed48a2d78","IPY_MODEL_bea0c17dc6bf476987601b0400c7652a","IPY_MODEL_659ce58f0c8e42d487689f987b28f385"],"layout":"IPY_MODEL_e626b1001e03
434e8ea36a4eabf4c06a"}},"0b36465265be4e909a66748ed48a2d78":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_8704e19310124cc4839f3c2fc6c0e0f7","placeholder":"​","style":"IPY_MODEL_6fd723aea7004f6088333c739657d3a4","value":""}},"bea0c17dc6bf476987601b0400c7652a":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_77c13c0bd0df4d54879d5ef6a8828917","max":28881,"min":0,"orientation":"horizontal","style":"IPY_MODEL_96858592543446beb18fc6f09275e548","value":28881}},"659ce58f0c8e42d487689f987b28f385":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_748c4ee513684d709eeba80a84fe10e3","placeholder":"​","style":"IPY_MODEL_1e232368d1b2493597dacb66f1c644e3","value":" 29696/? 
[00:00<00:00, 760830.57it/s]"}},"e626b1001e03434e8ea36a4eabf4c06a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"8704e19310124cc4839f3c2fc6c0e0f7":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,
"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"6fd723aea7004f6088333c739657d3a4":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"77c13c0bd0df4d54879d5ef6a8828917":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"96858592543446beb18fc6f09275e548":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"748c4ee513684d709eeba80a
84fe10e3":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"1e232368d1b2493597dacb66f1c644e3":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"47f3e099c41c4ff4b1425888c84b6c93":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_1a493af8310b44faba94228cec91da44","IPY_MODEL_f119714fb8b7434bb5e670d34a0ce86a","IPY_MODEL_f2087dbf65644d73ab87ff7545b961c8"],"layout":"IPY_MODEL_a899f6f9e42646
6bbe7281b5fe10ced0"}},"1a493af8310b44faba94228cec91da44":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_8fc5f28192cb4090ad0103a206fa2062","placeholder":"​","style":"IPY_MODEL_7ba84d4abfdd4e58b85758a7440d1213","value":""}},"f119714fb8b7434bb5e670d34a0ce86a":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_9da2e9408960485bb6d170223c966834","max":1648877,"min":0,"orientation":"horizontal","style":"IPY_MODEL_bbbe72dbf6e14ab0ab5f40e21cf71291","value":1648877}},"f2087dbf65644d73ab87ff7545b961c8":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_9b3b85bd9bbf4b8b8769650d31919f3e","placeholder":"​","style":"IPY_MODEL_33b16335b7654d6f9d5a95260c41f12b","value":" 1649664/? 
[00:00<00:00, 16444553.92it/s]"}},"a899f6f9e426466bbe7281b5fe10ced0":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"8fc5f28192cb4090ad0103a206fa2062":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":nul
l,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"7ba84d4abfdd4e58b85758a7440d1213":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"9da2e9408960485bb6d170223c966834":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"bbbe72dbf6e14ab0ab5f40e21cf71291":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"9b3b85bd9bbf4b8b876965
0d31919f3e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"33b16335b7654d6f9d5a95260c41f12b":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"72444c5fc0d54669ad6a00cb103d0f90":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_55dd5d2253e84607950df2d9c929cde2","IPY_MODEL_37ccaa7f295a498586d2b582eb5eac35","IPY_MODEL_9f746863b41547aaa2ebf55a3f0eb906"],"layout":"IPY_MODEL_10cdc08bf73e
4adabbef6c263ec531ed"}},"55dd5d2253e84607950df2d9c929cde2":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_2fbe8e189bb547fca6ff8cc23843c9f7","placeholder":"​","style":"IPY_MODEL_31e986a031594f76b492c4c10c33b25f","value":""}},"37ccaa7f295a498586d2b582eb5eac35":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_25f1514f92ae499ba42e2e7a3dd12602","max":4542,"min":0,"orientation":"horizontal","style":"IPY_MODEL_41af5ad49bf64f849b5430b2c514c2bb","value":4542}},"9f746863b41547aaa2ebf55a3f0eb906":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_0fccb71e51d24b88ae39cc00c357d413","placeholder":"​","style":"IPY_MODEL_6c3d89cf87fc4471811ed841cd179831","value":" 5120/? 
[00:00<00:00, 136832.95it/s]"}},"10cdc08bf73e4adabbef6c263ec531ed":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"2fbe8e189bb547fca6ff8cc23843c9f7":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,
"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"31e986a031594f76b492c4c10c33b25f":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"25f1514f92ae499ba42e2e7a3dd12602":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"41af5ad49bf64f849b5430b2c514c2bb":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"0fccb71e51d24b88ae39cc00
c357d413":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"6c3d89cf87fc4471811ed841cd179831":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"markdown","source":["Necessary Packages and Classes"],"metadata":{"id":"5TdGjA-p3jdo"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"BdvVhpM94mJq"},"outputs":[],"source":["import torch\n","from torch.utils.data import Dataset\n","from torch.utils.data import DataLoader\n","import torchvision\n","import matplotlib.pyplot as plt\n","import os"]},{"cell_type":"markdown","source":["Making the dataset files"],"metadata":{"id":"JnXhagD93nfa"}},{"cell_type":"code","source":["!mkdir /content/mydataset\n","!mkdir 
/content/mydataset/train/\n","!mkdir /content/mydataset/valid/\n","!mkdir /content/mydataset/test/"],"metadata":{"id":"cdsICr_V49Cn"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for i in range(80):\n"," with open(f'/content/mydataset/train/{str(i)}.txt','w', encoding='utf-8') as file:\n"," file.write(f'training file #{str(i)}')\n","for i in range(10):\n"," with open(f'/content/mydataset/valid/{str(i)}.txt', 'w', encoding='utf-8') as file:\n"," file.write(f'validation file #{str(i)}')\n","for i in range(10):\n"," with open(f'/content/mydataset/test/{str(i)}.txt', 'w', encoding='utf-8') as file:\n"," file.write(f'testing file #{str(i)}')"],"metadata":{"id":"lGAFUK2Y5Gua"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Loading it into a Dataset object recognizable by Pytorch"],"metadata":{"id":"qJxzqe0x3qFT"}},{"cell_type":"code","source":["class MyTextDataset(Dataset):\n"," def __init__(self, path):\n"," def sort_key(item):\n"," return int(item.replace('.txt',''))\n"," \n"," files_lst = os.listdir(path)\n"," files_lst = sorted([item for item in files_lst if item.endswith('.txt')], key=sort_key)\n"," self.texts = []\n"," for file_name in files_lst:\n"," with open(path+file_name, 'r') as file:\n"," file_content = file.read()\n"," self.texts.append(file_content)\n"," def __len__(self):\n"," return len(self.texts)\n"," def __getitem__(self, idx):\n"," original_item = self.texts[idx]\n"," tokens = original_item.split()\n"," length_lst = [len(token) for token in tokens]\n"," length_lst = length_lst[:20] if len(length_lst) >= 20 else length_lst + [0 for i in range(20-len(length_lst))]\n"," return torch.Tensor(length_lst)"],"metadata":{"id":"c-P2ZMJK6Qfm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_dataset = MyTextDataset('/content/mydataset/train/')"],"metadata":{"id":"b8QAAcbR6KRf"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Working with the 
dataset"],"metadata":{"id":"-vo9wuig_ckS"}},{"cell_type":"code","source":["len(train_dataset)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"YgfvHmqK5b99","executionInfo":{"status":"ok","timestamp":1652710784674,"user_tz":-270,"elapsed":641,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"cfb8ff51-4e32-487d-dddb-31c792e8720a"},"execution_count":8,"outputs":[{"output_type":"execute_result","data":{"text/plain":["80"]},"metadata":{},"execution_count":8}]},{"cell_type":"code","source":["train_dataset[54]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"c6XFv3b85fHA","executionInfo":{"status":"ok","timestamp":1652710789959,"user_tz":-270,"elapsed":363,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"4b1fd2f6-0155-43ec-d6ed-28fb81e66a61"},"execution_count":9,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([8., 4., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n"," 0., 0.])"]},"metadata":{},"execution_count":9}]},{"cell_type":"code","source":["train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)"],"metadata":{"id":"YowgVTT45dLs","executionInfo":{"status":"ok","timestamp":1652710934933,"user_tz":-270,"elapsed":601,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}}},"execution_count":11,"outputs":[]},{"cell_type":"code","source":["for idx, batch in enumerate(iter(train_dataloader)):\n"," print(f'Batch {str(idx)} : {str(batch.size())}')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"KX1v6_DW-Be6","executionInfo":{"status":"ok","timestamp":1652710936289,"user_tz":-270,"elapsed":5,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"ac6bc19d-4ab1-4d83-8cdb-fc28325971ff"},"execution_count":12,"outputs":[{"output_type":"stream","name":"stdout","text":["Batch 0 : torch.Size([32, 20])\n","Batch 1 : torch.Size([32, 20])\n","Batch 
2 : torch.Size([16, 20])\n"]}]},{"cell_type":"markdown","source":["MNIST Dataset"],"metadata":{"id":"eLLzRaAiB1Zj"}},{"cell_type":"code","source":["mnist_dataset = torchvision.datasets.MNIST('/', train=True, download=True,\n"," transform=torchvision.transforms.Compose([\n"," torchvision.transforms.ToTensor()\n"," ]))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":431,"referenced_widgets":["f8169018ae794abc94ab196d8125bf61","ce3bc39cc1b24ceba190382443f15672","addc57f1484147b1b1330354dffab93c","13181970177f43e9a5189e9e06d78490","faf6d3b8c95b41468f8537b8c086bd93","745fe054690a4c8f8893e7569ee16575","3482fd34d144442cb44c123ce743018f","244e1fb794ea4e0a8355b545b132421f","a32baf002d8e41b1ac499cf851f8b2d1","79dc2dbf4bf445b6b1671e23f9e4eb95","923c5621509d4353aa223c7be8a5894a","20a781dc8b4c4a358f3c17a20bf9f06e","0b36465265be4e909a66748ed48a2d78","bea0c17dc6bf476987601b0400c7652a","659ce58f0c8e42d487689f987b28f385","e626b1001e03434e8ea36a4eabf4c06a","8704e19310124cc4839f3c2fc6c0e0f7","6fd723aea7004f6088333c739657d3a4","77c13c0bd0df4d54879d5ef6a8828917","96858592543446beb18fc6f09275e548","748c4ee513684d709eeba80a84fe10e3","1e232368d1b2493597dacb66f1c644e3","47f3e099c41c4ff4b1425888c84b6c93","1a493af8310b44faba94228cec91da44","f119714fb8b7434bb5e670d34a0ce86a","f2087dbf65644d73ab87ff7545b961c8","a899f6f9e426466bbe7281b5fe10ced0","8fc5f28192cb4090ad0103a206fa2062","7ba84d4abfdd4e58b85758a7440d1213","9da2e9408960485bb6d170223c966834","bbbe72dbf6e14ab0ab5f40e21cf71291","9b3b85bd9bbf4b8b8769650d31919f3e","33b16335b7654d6f9d5a95260c41f12b","72444c5fc0d54669ad6a00cb103d0f90","55dd5d2253e84607950df2d9c929cde2","37ccaa7f295a498586d2b582eb5eac35","9f746863b41547aaa2ebf55a3f0eb906","10cdc08bf73e4adabbef6c263ec531ed","2fbe8e189bb547fca6ff8cc23843c9f7","31e986a031594f76b492c4c10c33b25f","25f1514f92ae499ba42e2e7a3dd12602","41af5ad49bf64f849b5430b2c514c2bb","0fccb71e51d24b88ae39cc00c357d413","6c3d89cf87fc4471811ed841cd179831"]},"id":"ddci5CmCCQin","executionInfo":{"status
":"ok","timestamp":1652711088035,"user_tz":-270,"elapsed":2007,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"f1a314d3-bb08-4705-d589-0805008c348c"},"execution_count":13,"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n","Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /MNIST/raw/train-images-idx3-ubyte.gz\n"]},{"output_type":"display_data","data":{"text/plain":[" 0%| | 0/9912422 [00:00"]},"metadata":{},"execution_count":20},{"output_type":"display_data","data":{"text/plain":["
"],"image/png":"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANW0lEQVR4nO3df6xX9X3H8dereMUJWqVagkIKFsxKlwzXG9TUGRs6p/yDbTYjWwztHNemmtjNbHMuiy7LVtbWmmZrXLASsbE0nVbETbsy2sZ2U8rFUH5IV9FhgPHDljnRdciP9/64h+6C9/u5l+/3fH/A+/lIbr7f73mfc8/bE1+c8z0/7scRIQCnv3d1uwEAnUHYgSQIO5AEYQeSIOxAEmd0cmVnenycpQmdXCWQyv/qLb0dBz1SraWw275O0pckjZP0lYhYUpr/LE3Q5Z7XyioBFKyNNQ1rTR/G2x4n6cuSrpc0W9JC27Ob/X0A2quV7+xzJW2LiFci4m1JX5e0oJ62ANStlbBfLGnHsM87q2nHsT1ge9D24CEdbGF1AFrR9rPxEbE0Ivojor9P49u9OgANtBL2XZKmDfs8tZoGoAe1EvZ1kmbZnmH7TEk3SVpVT1sA6tb0pbeIOGz7dkn/rKFLb8siYkttnQGoVUvX2SPiaUlP19QLgDbidlkgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEi0N2Wx7u6QDko5IOhwR/XU0BaB+LYW98pGI+GkNvwdAG3EYDyTRathD0rdtr7c9MNIMtgdsD9oePKSDLa4OQLNaPYy/KiJ22X6vpNW2fxwRzw6fISKWSloqSed6UrS4PgBNamnPHhG7qtd9kp6QNLeOpgDUr+mw255g+5xj7yVdK2lzXY0BqFcrh/GTJT1h+9jv+VpEfKuWrgDUrumwR8Qrkn61xl4AtBGX3oAkCDuQBGEHkiDsQBKEHUiijgdh0MPGnffuYn3bXbOL9c9+/NFi/aNn7y3WHzswo2HtNydsKy577YN/XKxP+8t/K9ZxPPbsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE19lPA6Vr6Wc/1VdcduslXy7WL/3OLcX69IfL+4sz1qxvWLvvzz9eXPZbiz9XrC9e/eliXc9vLNeTYc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0lwnf00cP2/bm9Yu6jvv4rLfuS2TxXrM1f+sJmWxmT6324p1n+86Pxi/Se3lu8huPT5k27ptMaeHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dr7KeBni68s1hec8/mGtd/5wzuLy05Yubapnupw5PX/LtY/+x/zO9RJDqPu2W0vs73P9uZh0ybZXm37peq1fPcDgK4by2H8w5KuO2HaXZLWRMQsSWuqzwB62Khhj4hnJe0/YfICScur98sl3VBzXwBq1ux39skRsbt6v0fS5EYz2h6QNCBJZ+nsJlcHoFUtn42PiJAUhfrSiOiPiP4+jW91dQCa1GzY99qeIknV6776WgLQDs2GfZWkRdX7RZKerKcdAO0y6nd22yskXSPpAts7Jd0jaYmkb9i+RdKrkm5sZ5PZXTHwQrH+yOv9DWsTHuvedfRWbd95QbdbOK2MGvaIWNigNK/mXgC0EbfLAkkQdiAJwg4kQdiBJAg7kASPuJ4CLpv4arE+re9nDWsr/nRxcdn/ufRgUz0dc8+VTxXrf73hxGeo/t+Mm8pDKn9oZvm/e/229xXrOB57diAJwg4kQdiBJAg7kARhB5Ig7EAShB1IwkN/aKYzzvWkuNw8LHeyfn7D3GL995asbFg7EuV/z8f5aLG+fEf5z1i/dmBisf6jy7/asDbrsU8Xl/38
9V8r1pfNu7pYP7xjZ7F+Oloba/RG7PdINfbsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEz7OfAn5p5Q+L9RUrL2rbus9U+ZnyqZd9sFg/+o+N7+P47aufLy77V/f/brF+4Y7ninUcjz07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBdXa05MDM8vPs79KIj1ZLkp77i/Jz+heu5Dp6nUbds9teZnuf7c3Dpt1re5ftDdXP/Pa2CaBVYzmMf1jSSMN63B8Rc6qfp+ttC0DdRg17RDwraX8HegHQRq2coLvd9sbqMP/8RjPZHrA9aHvwkFobVwxA85oN+wOS3i9pjqTdku5rNGNELI2I/ojo79P4JlcHoFVNhT0i9kbEkYg4KulBSeXTqgC6rqmw254y7OPHJG1uNC+A3jDqdXbbKyRdI+kC2zsl3SPpGttzJIWk7ZJubWOP6GF7FrxdrB9V4+fZR3tOH/UaNewRsXCEyQ+1oRcAbcTtskAShB1IgrADSRB2IAnCDiTBI65oyQem7inWV73V8E5qdBh7diAJwg4kQdiBJAg7kARhB5Ig7EAShB1IguvsaMnfX/IPxfqvP/MHDWuXal3d7aCAPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMF1dhS9/IUrivWpZ2wo1n/5gbca1o421RGaxZ4dSIKwA0kQdiAJwg4kQdiBJAg7kARhB5LgOjuKPnHt94r1x988t1iPrS/X1wxaMuqe3fY029+1/aLtLbbvqKZPsr3a9kvVK6MBAD1sLIfxhyXdGRGzJV0h6TbbsyXdJWlNRMyStKb6DKBHjRr2iNgdES9U7w9I2irpYkkLJC2vZlsu6YZ2NQmgdSf1nd32dEmXSVoraXJE7K5KeyRNbrDMgKQBSTpLZzfbJ4AWjflsvO2Jkh6X9JmIeGN4LSJCUoy0XEQsjYj+iOjv0/iWmgXQvDGF3XafhoL+aER8s5q81/aUqj5F0r72tAigDqMextu2pIckbY2ILw4rrZK0SNKS6vXJtnSItho3+b3F+kcnPlOsL1r3yWJ9+sGNJ90T2mMs39k/LOlmSZtsH3t4+W4Nhfwbtm+R9KqkG9vTIoA6jBr2iPiBJDcoz6u3HQDtwu2yQBKEHUiCsANJEHYgCcIOJMEjrsntuHlmsf6hUW56nPQkt0CfKtizA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASXGc/zY07793Fev9vbSrWdx/5ebE+6fs7i/XDxSo6iT07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBdfbT3K5FHyzWn5r2d8X6Bx75o2J9xo7nTrondAd7diAJwg4kQdiBJAg7kARhB5Ig7EAShB1IYizjs0+T9IikyZJC0tKI+JLteyUtlvRaNevdEfF0uxpFY6Ux1n9/4J+Ky97xn1cW6zP/5sVi/Uixil4ylptqDku6MyJesH2OpPW2V1e1+yPiC+1rD0BdxjI++25Ju6v3B2xvlXRxuxsDUK+T+s5ue7qkyyStrSbdbnuj7WW2z2+wzIDtQduDh3SwpWYBNG/MYbc9UdLjkj4TEW9IekDS+yXN0dCe/76RlouIpRHRHxH9fRpl4DAAbTOmsNvu01DQH42Ib0pSROyNiCMRcVTSg5Lmtq9NAK0aNey2LekhSVsj4ovDpk8ZNtvHJG2uvz0AdRnL2fgPS7pZ0ibbG6ppd0taaHuOhi7HbZd0a1s6xKiOXnRhw9qnznumuOyV988v1t/zOo+wni7Gcjb+B5I8Qolr6sAphDvogCQIO5AEYQeSIOxAEoQdSIKwA0k4Ijq2snM9KS73vI6tD8hmbazRG7F/pEvl7NmBLAg7kARhB5Ig7EAShB1IgrADSRB2IImOXme3/ZqkV4dNukDSTzvWwMnp1d56tS+J3ppVZ2/vi4gR/8BBR8P+jpXbgxHR37UGCnq1t17tS6K3ZnWqNw7jgSQIO5BEt8O+tMvrL+nV3nq1L4nemtWR3rr6nR1A53R7
zw6gQwg7kERXwm77Otv/bnub7bu60UMjtrfb3mR7g+3BLveyzPY+25uHTZtke7Xtl6rXEcfY61Jv99reVW27DbbLf5S+fb1Ns/1d2y/a3mL7jmp6V7ddoa+ObLeOf2e3PU7STyT9hqSdktZJWhgR5YHAO8T2dkn9EdH1GzBsXy3pTUmPRMSvVNM+J2l/RCyp/qE8PyL+pEd6u1fSm90exrsarWjK8GHGJd0g6RPq4rYr9HWjOrDdurFnnytpW0S8EhFvS/q6pAVd6KPnRcSzkvafMHmBpOXV++Ua+p+l4xr01hMiYndEvFC9PyDp2DDjXd12hb46ohthv1jSjmGfd6q3xnsPSd+2vd72QLebGcHkiNhdvd8jaXI3mxnBqMN4d9IJw4z3zLZrZvjzVnGC7p2uiohfk3S9pNuqw9WeFEPfwXrp2umYhvHulBGGGf+Fbm67Zoc/b1U3wr5L0rRhn6dW03pCROyqXvdJekK9NxT13mMj6Fav+7rczy/00jDeIw0zrh7Ydt0c/rwbYV8naZbtGbbPlHSTpFVd6OMdbE+oTpzI9gRJ16r3hqJeJWlR9X6RpCe72MtxemUY70bDjKvL267rw59HRMd/JM3X0Bn5lyX9WTd6aNDXJZJ+VP1s6XZvklZo6LDukIbObdwi6T2S1kh6SdK/SJrUQ719VdImSRs1FKwpXertKg0dom+UtKH6md/tbVfoqyPbjdtlgSQ4QQckQdiBJAg7kARhB5Ig7EAShB1IgrADSfwf5qvsnhUvXpMAAAAASUVORK5CYII=\n"},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","source":["The shape of this data point"],"metadata":{"id":"YElS1NX9FWrK"}},{"cell_type":"code","source":["mnist_dataset[50][0].shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"kkAaY3QUD-gL","executionInfo":{"status":"ok","timestamp":1652711160379,"user_tz":-270,"elapsed":497,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"c7138a90-192d-4f11-c9d0-7c8bf6387a27"},"execution_count":17,"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.Size([1, 28, 28])"]},"metadata":{},"execution_count":17}]},{"cell_type":"code","source":["mnist_dataset[50][1]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XijGI055EQwo","executionInfo":{"status":"ok","timestamp":1652711182883,"user_tz":-270,"elapsed":590,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"d097f38e-c215-4f57-9ba7-e8433cd5f969"},"execution_count":18,"outputs":[{"output_type":"execute_result","data":{"text/plain":["3"]},"metadata":{},"execution_count":18}]},{"cell_type":"markdown","source":["What about the 
batches?"],"metadata":{"id":"nMleRr7BF84x"}},{"cell_type":"code","source":["for idx, batch in enumerate(iter(mnist_dataloader)):\n"," print(f'Batch {str(idx)} : {batch[0].size()}')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_CIR453KCBAf","executionInfo":{"status":"ok","timestamp":1652711238698,"user_tz":-270,"elapsed":7109,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"6b18c2b4-50eb-42c1-cfec-b459b863493e"},"execution_count":21,"outputs":[{"output_type":"stream","name":"stdout","text":["Batch 0 : torch.Size([64, 1, 28, 28])\n","Batch 1 : torch.Size([64, 1, 28, 28])\n","Batch 2 : torch.Size([64, 1, 28, 28])\n","Batch 3 : torch.Size([64, 1, 28, 28])\n","Batch 4 : torch.Size([64, 1, 28, 28])\n","Batch 5 : torch.Size([64, 1, 28, 28])\n","Batch 6 : torch.Size([64, 1, 28, 28])\n","Batch 7 : torch.Size([64, 1, 28, 28])\n","Batch 8 : torch.Size([64, 1, 28, 28])\n","Batch 9 : torch.Size([64, 1, 28, 28])\n","Batch 10 : torch.Size([64, 1, 28, 28])\n","Batch 11 : torch.Size([64, 1, 28, 28])\n","Batch 12 : torch.Size([64, 1, 28, 28])\n","Batch 13 : torch.Size([64, 1, 28, 28])\n","Batch 14 : torch.Size([64, 1, 28, 28])\n","Batch 15 : torch.Size([64, 1, 28, 28])\n","Batch 16 : torch.Size([64, 1, 28, 28])\n","Batch 17 : torch.Size([64, 1, 28, 28])\n","Batch 18 : torch.Size([64, 1, 28, 28])\n","Batch 19 : torch.Size([64, 1, 28, 28])\n","Batch 20 : torch.Size([64, 1, 28, 28])\n","Batch 21 : torch.Size([64, 1, 28, 28])\n","Batch 22 : torch.Size([64, 1, 28, 28])\n","Batch 23 : torch.Size([64, 1, 28, 28])\n","Batch 24 : torch.Size([64, 1, 28, 28])\n","Batch 25 : torch.Size([64, 1, 28, 28])\n","Batch 26 : torch.Size([64, 1, 28, 28])\n","Batch 27 : torch.Size([64, 1, 28, 28])\n","Batch 28 : torch.Size([64, 1, 28, 28])\n","Batch 29 : torch.Size([64, 1, 28, 28])\n","Batch 30 : torch.Size([64, 1, 28, 28])\n","Batch 31 : torch.Size([64, 1, 28, 28])\n","Batch 32 : torch.Size([64, 1, 28, 28])\n","Batch 33 : torch.Size([64, 1, 28, 
28])\n","Batch 34 : torch.Size([64, 1, 28, 28])\n","Batch 35 : torch.Size([64, 1, 28, 28])\n","Batch 36 : torch.Size([64, 1, 28, 28])\n","Batch 37 : torch.Size([64, 1, 28, 28])\n","Batch 38 : torch.Size([64, 1, 28, 28])\n","Batch 39 : torch.Size([64, 1, 28, 28])\n","Batch 40 : torch.Size([64, 1, 28, 28])\n","Batch 41 : torch.Size([64, 1, 28, 28])\n","Batch 42 : torch.Size([64, 1, 28, 28])\n","Batch 43 : torch.Size([64, 1, 28, 28])\n","Batch 44 : torch.Size([64, 1, 28, 28])\n","Batch 45 : torch.Size([64, 1, 28, 28])\n","Batch 46 : torch.Size([64, 1, 28, 28])\n","Batch 47 : torch.Size([64, 1, 28, 28])\n","Batch 48 : torch.Size([64, 1, 28, 28])\n","Batch 49 : torch.Size([64, 1, 28, 28])\n","Batch 50 : torch.Size([64, 1, 28, 28])\n","Batch 51 : torch.Size([64, 1, 28, 28])\n","Batch 52 : torch.Size([64, 1, 28, 28])\n","Batch 53 : torch.Size([64, 1, 28, 28])\n","Batch 54 : torch.Size([64, 1, 28, 28])\n","Batch 55 : torch.Size([64, 1, 28, 28])\n","Batch 56 : torch.Size([64, 1, 28, 28])\n","Batch 57 : torch.Size([64, 1, 28, 28])\n","Batch 58 : torch.Size([64, 1, 28, 28])\n","Batch 59 : torch.Size([64, 1, 28, 28])\n","Batch 60 : torch.Size([64, 1, 28, 28])\n","Batch 61 : torch.Size([64, 1, 28, 28])\n","Batch 62 : torch.Size([64, 1, 28, 28])\n","Batch 63 : torch.Size([64, 1, 28, 28])\n","Batch 64 : torch.Size([64, 1, 28, 28])\n","Batch 65 : torch.Size([64, 1, 28, 28])\n","Batch 66 : torch.Size([64, 1, 28, 28])\n","Batch 67 : torch.Size([64, 1, 28, 28])\n","Batch 68 : torch.Size([64, 1, 28, 28])\n","Batch 69 : torch.Size([64, 1, 28, 28])\n","Batch 70 : torch.Size([64, 1, 28, 28])\n","Batch 71 : torch.Size([64, 1, 28, 28])\n","Batch 72 : torch.Size([64, 1, 28, 28])\n","Batch 73 : torch.Size([64, 1, 28, 28])\n","Batch 74 : torch.Size([64, 1, 28, 28])\n","Batch 75 : torch.Size([64, 1, 28, 28])\n","Batch 76 : torch.Size([64, 1, 28, 28])\n","Batch 77 : torch.Size([64, 1, 28, 28])\n","Batch 78 : torch.Size([64, 1, 28, 28])\n","Batch 79 : torch.Size([64, 1, 28, 28])\n","Batch 80 : 
torch.Size([64, 1, 28, 28])\n","Batch 81 : torch.Size([64, 1, 28, 28])\n","Batch 82 : torch.Size([64, 1, 28, 28])\n","Batch 83 : torch.Size([64, 1, 28, 28])\n","Batch 84 : torch.Size([64, 1, 28, 28])\n","Batch 85 : torch.Size([64, 1, 28, 28])\n","Batch 86 : torch.Size([64, 1, 28, 28])\n","Batch 87 : torch.Size([64, 1, 28, 28])\n","Batch 88 : torch.Size([64, 1, 28, 28])\n","Batch 89 : torch.Size([64, 1, 28, 28])\n","Batch 90 : torch.Size([64, 1, 28, 28])\n","Batch 91 : torch.Size([64, 1, 28, 28])\n","Batch 92 : torch.Size([64, 1, 28, 28])\n","Batch 93 : torch.Size([64, 1, 28, 28])\n","Batch 94 : torch.Size([64, 1, 28, 28])\n","Batch 95 : torch.Size([64, 1, 28, 28])\n","Batch 96 : torch.Size([64, 1, 28, 28])\n","Batch 97 : torch.Size([64, 1, 28, 28])\n","Batch 98 : torch.Size([64, 1, 28, 28])\n","Batch 99 : torch.Size([64, 1, 28, 28])\n","Batch 100 : torch.Size([64, 1, 28, 28])\n","Batch 101 : torch.Size([64, 1, 28, 28])\n","Batch 102 : torch.Size([64, 1, 28, 28])\n","Batch 103 : torch.Size([64, 1, 28, 28])\n","Batch 104 : torch.Size([64, 1, 28, 28])\n","Batch 105 : torch.Size([64, 1, 28, 28])\n","Batch 106 : torch.Size([64, 1, 28, 28])\n","Batch 107 : torch.Size([64, 1, 28, 28])\n","Batch 108 : torch.Size([64, 1, 28, 28])\n","Batch 109 : torch.Size([64, 1, 28, 28])\n","Batch 110 : torch.Size([64, 1, 28, 28])\n","Batch 111 : torch.Size([64, 1, 28, 28])\n","Batch 112 : torch.Size([64, 1, 28, 28])\n","Batch 113 : torch.Size([64, 1, 28, 28])\n","Batch 114 : torch.Size([64, 1, 28, 28])\n","Batch 115 : torch.Size([64, 1, 28, 28])\n","Batch 116 : torch.Size([64, 1, 28, 28])\n","Batch 117 : torch.Size([64, 1, 28, 28])\n","Batch 118 : torch.Size([64, 1, 28, 28])\n","Batch 119 : torch.Size([64, 1, 28, 28])\n","Batch 120 : torch.Size([64, 1, 28, 28])\n","Batch 121 : torch.Size([64, 1, 28, 28])\n","Batch 122 : torch.Size([64, 1, 28, 28])\n","Batch 123 : torch.Size([64, 1, 28, 28])\n","Batch 124 : torch.Size([64, 1, 28, 28])\n","Batch 125 : torch.Size([64, 1, 28, 28])\n","Batch 
126 : torch.Size([64, 1, 28, 28])\n","Batch 127 : torch.Size([64, 1, 28, 28])\n","Batch 128 : torch.Size([64, 1, 28, 28])\n","Batch 129 : torch.Size([64, 1, 28, 28])\n","Batch 130 : torch.Size([64, 1, 28, 28])\n","Batch 131 : torch.Size([64, 1, 28, 28])\n","Batch 132 : torch.Size([64, 1, 28, 28])\n","Batch 133 : torch.Size([64, 1, 28, 28])\n","Batch 134 : torch.Size([64, 1, 28, 28])\n","Batch 135 : torch.Size([64, 1, 28, 28])\n","Batch 136 : torch.Size([64, 1, 28, 28])\n","Batch 137 : torch.Size([64, 1, 28, 28])\n","Batch 138 : torch.Size([64, 1, 28, 28])\n","Batch 139 : torch.Size([64, 1, 28, 28])\n","Batch 140 : torch.Size([64, 1, 28, 28])\n","Batch 141 : torch.Size([64, 1, 28, 28])\n","Batch 142 : torch.Size([64, 1, 28, 28])\n","Batch 143 : torch.Size([64, 1, 28, 28])\n","Batch 144 : torch.Size([64, 1, 28, 28])\n","Batch 145 : torch.Size([64, 1, 28, 28])\n","Batch 146 : torch.Size([64, 1, 28, 28])\n","Batch 147 : torch.Size([64, 1, 28, 28])\n","Batch 148 : torch.Size([64, 1, 28, 28])\n","Batch 149 : torch.Size([64, 1, 28, 28])\n","Batch 150 : torch.Size([64, 1, 28, 28])\n","Batch 151 : torch.Size([64, 1, 28, 28])\n","Batch 152 : torch.Size([64, 1, 28, 28])\n","Batch 153 : torch.Size([64, 1, 28, 28])\n","Batch 154 : torch.Size([64, 1, 28, 28])\n","Batch 155 : torch.Size([64, 1, 28, 28])\n","Batch 156 : torch.Size([64, 1, 28, 28])\n","Batch 157 : torch.Size([64, 1, 28, 28])\n","Batch 158 : torch.Size([64, 1, 28, 28])\n","Batch 159 : torch.Size([64, 1, 28, 28])\n","Batch 160 : torch.Size([64, 1, 28, 28])\n","Batch 161 : torch.Size([64, 1, 28, 28])\n","Batch 162 : torch.Size([64, 1, 28, 28])\n","Batch 163 : torch.Size([64, 1, 28, 28])\n","Batch 164 : torch.Size([64, 1, 28, 28])\n","Batch 165 : torch.Size([64, 1, 28, 28])\n","Batch 166 : torch.Size([64, 1, 28, 28])\n","Batch 167 : torch.Size([64, 1, 28, 28])\n","Batch 168 : torch.Size([64, 1, 28, 28])\n","Batch 169 : torch.Size([64, 1, 28, 28])\n","Batch 170 : torch.Size([64, 1, 28, 28])\n","Batch 171 : 
torch.Size([64, 1, 28, 28])\n","Batch 172 : torch.Size([64, 1, 28, 28])\n","Batch 173 : torch.Size([64, 1, 28, 28])\n","Batch 174 : torch.Size([64, 1, 28, 28])\n","Batch 175 : torch.Size([64, 1, 28, 28])\n","Batch 176 : torch.Size([64, 1, 28, 28])\n","Batch 177 : torch.Size([64, 1, 28, 28])\n","Batch 178 : torch.Size([64, 1, 28, 28])\n","Batch 179 : torch.Size([64, 1, 28, 28])\n","Batch 180 : torch.Size([64, 1, 28, 28])\n","Batch 181 : torch.Size([64, 1, 28, 28])\n","Batch 182 : torch.Size([64, 1, 28, 28])\n","Batch 183 : torch.Size([64, 1, 28, 28])\n","Batch 184 : torch.Size([64, 1, 28, 28])\n","Batch 185 : torch.Size([64, 1, 28, 28])\n","Batch 186 : torch.Size([64, 1, 28, 28])\n","Batch 187 : torch.Size([64, 1, 28, 28])\n","Batch 188 : torch.Size([64, 1, 28, 28])\n","Batch 189 : torch.Size([64, 1, 28, 28])\n","Batch 190 : torch.Size([64, 1, 28, 28])\n","Batch 191 : torch.Size([64, 1, 28, 28])\n","Batch 192 : torch.Size([64, 1, 28, 28])\n","Batch 193 : torch.Size([64, 1, 28, 28])\n","Batch 194 : torch.Size([64, 1, 28, 28])\n","Batch 195 : torch.Size([64, 1, 28, 28])\n","Batch 196 : torch.Size([64, 1, 28, 28])\n","Batch 197 : torch.Size([64, 1, 28, 28])\n","Batch 198 : torch.Size([64, 1, 28, 28])\n","Batch 199 : torch.Size([64, 1, 28, 28])\n","Batch 200 : torch.Size([64, 1, 28, 28])\n","Batch 201 : torch.Size([64, 1, 28, 28])\n","Batch 202 : torch.Size([64, 1, 28, 28])\n","Batch 203 : torch.Size([64, 1, 28, 28])\n","Batch 204 : torch.Size([64, 1, 28, 28])\n","Batch 205 : torch.Size([64, 1, 28, 28])\n","Batch 206 : torch.Size([64, 1, 28, 28])\n","Batch 207 : torch.Size([64, 1, 28, 28])\n","Batch 208 : torch.Size([64, 1, 28, 28])\n","Batch 209 : torch.Size([64, 1, 28, 28])\n","Batch 210 : torch.Size([64, 1, 28, 28])\n","Batch 211 : torch.Size([64, 1, 28, 28])\n","Batch 212 : torch.Size([64, 1, 28, 28])\n","Batch 213 : torch.Size([64, 1, 28, 28])\n","Batch 214 : torch.Size([64, 1, 28, 28])\n","Batch 215 : torch.Size([64, 1, 28, 28])\n","Batch 216 : torch.Size([64, 1, 
28, 28])\n","Batch 217 : torch.Size([64, 1, 28, 28])\n","Batch 218 : torch.Size([64, 1, 28, 28])\n","Batch 219 : torch.Size([64, 1, 28, 28])\n","Batch 220 : torch.Size([64, 1, 28, 28])\n","Batch 221 : torch.Size([64, 1, 28, 28])\n","Batch 222 : torch.Size([64, 1, 28, 28])\n","Batch 223 : torch.Size([64, 1, 28, 28])\n","Batch 224 : torch.Size([64, 1, 28, 28])\n","Batch 225 : torch.Size([64, 1, 28, 28])\n","Batch 226 : torch.Size([64, 1, 28, 28])\n","Batch 227 : torch.Size([64, 1, 28, 28])\n","Batch 228 : torch.Size([64, 1, 28, 28])\n","Batch 229 : torch.Size([64, 1, 28, 28])\n","Batch 230 : torch.Size([64, 1, 28, 28])\n","Batch 231 : torch.Size([64, 1, 28, 28])\n","Batch 232 : torch.Size([64, 1, 28, 28])\n","Batch 233 : torch.Size([64, 1, 28, 28])\n","Batch 234 : torch.Size([64, 1, 28, 28])\n","Batch 235 : torch.Size([64, 1, 28, 28])\n","Batch 236 : torch.Size([64, 1, 28, 28])\n","Batch 237 : torch.Size([64, 1, 28, 28])\n","Batch 238 : torch.Size([64, 1, 28, 28])\n","Batch 239 : torch.Size([64, 1, 28, 28])\n","Batch 240 : torch.Size([64, 1, 28, 28])\n","Batch 241 : torch.Size([64, 1, 28, 28])\n","Batch 242 : torch.Size([64, 1, 28, 28])\n","Batch 243 : torch.Size([64, 1, 28, 28])\n","Batch 244 : torch.Size([64, 1, 28, 28])\n","Batch 245 : torch.Size([64, 1, 28, 28])\n","Batch 246 : torch.Size([64, 1, 28, 28])\n","Batch 247 : torch.Size([64, 1, 28, 28])\n","Batch 248 : torch.Size([64, 1, 28, 28])\n","Batch 249 : torch.Size([64, 1, 28, 28])\n","Batch 250 : torch.Size([64, 1, 28, 28])\n","Batch 251 : torch.Size([64, 1, 28, 28])\n","Batch 252 : torch.Size([64, 1, 28, 28])\n","Batch 253 : torch.Size([64, 1, 28, 28])\n","Batch 254 : torch.Size([64, 1, 28, 28])\n","Batch 255 : torch.Size([64, 1, 28, 28])\n","Batch 256 : torch.Size([64, 1, 28, 28])\n","Batch 257 : torch.Size([64, 1, 28, 28])\n","Batch 258 : torch.Size([64, 1, 28, 28])\n","Batch 259 : torch.Size([64, 1, 28, 28])\n","Batch 260 : torch.Size([64, 1, 28, 28])\n","Batch 261 : torch.Size([64, 1, 28, 28])\n","Batch 
262 : torch.Size([64, 1, 28, 28])\n","Batch 263 : torch.Size([64, 1, 28, 28])\n","Batch 264 : torch.Size([64, 1, 28, 28])\n","Batch 265 : torch.Size([64, 1, 28, 28])\n","Batch 266 : torch.Size([64, 1, 28, 28])\n","Batch 267 : torch.Size([64, 1, 28, 28])\n","Batch 268 : torch.Size([64, 1, 28, 28])\n","Batch 269 : torch.Size([64, 1, 28, 28])\n","Batch 270 : torch.Size([64, 1, 28, 28])\n","Batch 271 : torch.Size([64, 1, 28, 28])\n","Batch 272 : torch.Size([64, 1, 28, 28])\n","Batch 273 : torch.Size([64, 1, 28, 28])\n","Batch 274 : torch.Size([64, 1, 28, 28])\n","Batch 275 : torch.Size([64, 1, 28, 28])\n","Batch 276 : torch.Size([64, 1, 28, 28])\n","Batch 277 : torch.Size([64, 1, 28, 28])\n","Batch 278 : torch.Size([64, 1, 28, 28])\n","Batch 279 : torch.Size([64, 1, 28, 28])\n","Batch 280 : torch.Size([64, 1, 28, 28])\n","Batch 281 : torch.Size([64, 1, 28, 28])\n","Batch 282 : torch.Size([64, 1, 28, 28])\n","Batch 283 : torch.Size([64, 1, 28, 28])\n","Batch 284 : torch.Size([64, 1, 28, 28])\n","Batch 285 : torch.Size([64, 1, 28, 28])\n","Batch 286 : torch.Size([64, 1, 28, 28])\n","Batch 287 : torch.Size([64, 1, 28, 28])\n","Batch 288 : torch.Size([64, 1, 28, 28])\n","Batch 289 : torch.Size([64, 1, 28, 28])\n","Batch 290 : torch.Size([64, 1, 28, 28])\n","Batch 291 : torch.Size([64, 1, 28, 28])\n","Batch 292 : torch.Size([64, 1, 28, 28])\n","Batch 293 : torch.Size([64, 1, 28, 28])\n","Batch 294 : torch.Size([64, 1, 28, 28])\n","Batch 295 : torch.Size([64, 1, 28, 28])\n","Batch 296 : torch.Size([64, 1, 28, 28])\n","Batch 297 : torch.Size([64, 1, 28, 28])\n","Batch 298 : torch.Size([64, 1, 28, 28])\n","Batch 299 : torch.Size([64, 1, 28, 28])\n","Batch 300 : torch.Size([64, 1, 28, 28])\n","Batch 301 : torch.Size([64, 1, 28, 28])\n","Batch 302 : torch.Size([64, 1, 28, 28])\n","Batch 303 : torch.Size([64, 1, 28, 28])\n","Batch 304 : torch.Size([64, 1, 28, 28])\n","Batch 305 : torch.Size([64, 1, 28, 28])\n","Batch 306 : torch.Size([64, 1, 28, 28])\n","Batch 307 : 
torch.Size([64, 1, 28, 28])\n","Batch 308 : torch.Size([64, 1, 28, 28])\n","Batch 309 : torch.Size([64, 1, 28, 28])\n","Batch 310 : torch.Size([64, 1, 28, 28])\n","Batch 311 : torch.Size([64, 1, 28, 28])\n","Batch 312 : torch.Size([64, 1, 28, 28])\n","Batch 313 : torch.Size([64, 1, 28, 28])\n","Batch 314 : torch.Size([64, 1, 28, 28])\n","Batch 315 : torch.Size([64, 1, 28, 28])\n","Batch 316 : torch.Size([64, 1, 28, 28])\n","Batch 317 : torch.Size([64, 1, 28, 28])\n","Batch 318 : torch.Size([64, 1, 28, 28])\n","Batch 319 : torch.Size([64, 1, 28, 28])\n","Batch 320 : torch.Size([64, 1, 28, 28])\n","Batch 321 : torch.Size([64, 1, 28, 28])\n","Batch 322 : torch.Size([64, 1, 28, 28])\n","Batch 323 : torch.Size([64, 1, 28, 28])\n","Batch 324 : torch.Size([64, 1, 28, 28])\n","Batch 325 : torch.Size([64, 1, 28, 28])\n","Batch 326 : torch.Size([64, 1, 28, 28])\n","Batch 327 : torch.Size([64, 1, 28, 28])\n","Batch 328 : torch.Size([64, 1, 28, 28])\n","Batch 329 : torch.Size([64, 1, 28, 28])\n","Batch 330 : torch.Size([64, 1, 28, 28])\n","Batch 331 : torch.Size([64, 1, 28, 28])\n","Batch 332 : torch.Size([64, 1, 28, 28])\n","Batch 333 : torch.Size([64, 1, 28, 28])\n","Batch 334 : torch.Size([64, 1, 28, 28])\n","Batch 335 : torch.Size([64, 1, 28, 28])\n","Batch 336 : torch.Size([64, 1, 28, 28])\n","Batch 337 : torch.Size([64, 1, 28, 28])\n","Batch 338 : torch.Size([64, 1, 28, 28])\n","Batch 339 : torch.Size([64, 1, 28, 28])\n","Batch 340 : torch.Size([64, 1, 28, 28])\n","Batch 341 : torch.Size([64, 1, 28, 28])\n","Batch 342 : torch.Size([64, 1, 28, 28])\n","Batch 343 : torch.Size([64, 1, 28, 28])\n","Batch 344 : torch.Size([64, 1, 28, 28])\n","Batch 345 : torch.Size([64, 1, 28, 28])\n","Batch 346 : torch.Size([64, 1, 28, 28])\n","Batch 347 : torch.Size([64, 1, 28, 28])\n","Batch 348 : torch.Size([64, 1, 28, 28])\n","Batch 349 : torch.Size([64, 1, 28, 28])\n","Batch 350 : torch.Size([64, 1, 28, 28])\n","Batch 351 : torch.Size([64, 1, 28, 28])\n","Batch 352 : torch.Size([64, 1, 
28, 28])\n","Batch 353 : torch.Size([64, 1, 28, 28])\n","Batch 354 : torch.Size([64, 1, 28, 28])\n","Batch 355 : torch.Size([64, 1, 28, 28])\n","Batch 356 : torch.Size([64, 1, 28, 28])\n","Batch 357 : torch.Size([64, 1, 28, 28])\n","Batch 358 : torch.Size([64, 1, 28, 28])\n","Batch 359 : torch.Size([64, 1, 28, 28])\n","Batch 360 : torch.Size([64, 1, 28, 28])\n","Batch 361 : torch.Size([64, 1, 28, 28])\n","Batch 362 : torch.Size([64, 1, 28, 28])\n","Batch 363 : torch.Size([64, 1, 28, 28])\n","Batch 364 : torch.Size([64, 1, 28, 28])\n","Batch 365 : torch.Size([64, 1, 28, 28])\n","Batch 366 : torch.Size([64, 1, 28, 28])\n","Batch 367 : torch.Size([64, 1, 28, 28])\n","Batch 368 : torch.Size([64, 1, 28, 28])\n","Batch 369 : torch.Size([64, 1, 28, 28])\n","Batch 370 : torch.Size([64, 1, 28, 28])\n","Batch 371 : torch.Size([64, 1, 28, 28])\n","Batch 372 : torch.Size([64, 1, 28, 28])\n","Batch 373 : torch.Size([64, 1, 28, 28])\n","Batch 374 : torch.Size([64, 1, 28, 28])\n","Batch 375 : torch.Size([64, 1, 28, 28])\n","Batch 376 : torch.Size([64, 1, 28, 28])\n","Batch 377 : torch.Size([64, 1, 28, 28])\n","Batch 378 : torch.Size([64, 1, 28, 28])\n","Batch 379 : torch.Size([64, 1, 28, 28])\n","Batch 380 : torch.Size([64, 1, 28, 28])\n","Batch 381 : torch.Size([64, 1, 28, 28])\n","Batch 382 : torch.Size([64, 1, 28, 28])\n","Batch 383 : torch.Size([64, 1, 28, 28])\n","Batch 384 : torch.Size([64, 1, 28, 28])\n","Batch 385 : torch.Size([64, 1, 28, 28])\n","Batch 386 : torch.Size([64, 1, 28, 28])\n","Batch 387 : torch.Size([64, 1, 28, 28])\n","Batch 388 : torch.Size([64, 1, 28, 28])\n","Batch 389 : torch.Size([64, 1, 28, 28])\n","Batch 390 : torch.Size([64, 1, 28, 28])\n","Batch 391 : torch.Size([64, 1, 28, 28])\n","Batch 392 : torch.Size([64, 1, 28, 28])\n","Batch 393 : torch.Size([64, 1, 28, 28])\n","Batch 394 : torch.Size([64, 1, 28, 28])\n","Batch 395 : torch.Size([64, 1, 28, 28])\n","Batch 396 : torch.Size([64, 1, 28, 28])\n","Batch 397 : torch.Size([64, 1, 28, 28])\n","Batch 
398 : torch.Size([64, 1, 28, 28])\n","Batch 399 : torch.Size([64, 1, 28, 28])\n","Batch 400 : torch.Size([64, 1, 28, 28])\n","Batch 401 : torch.Size([64, 1, 28, 28])\n","Batch 402 : torch.Size([64, 1, 28, 28])\n","Batch 403 : torch.Size([64, 1, 28, 28])\n","Batch 404 : torch.Size([64, 1, 28, 28])\n","Batch 405 : torch.Size([64, 1, 28, 28])\n","Batch 406 : torch.Size([64, 1, 28, 28])\n","Batch 407 : torch.Size([64, 1, 28, 28])\n","Batch 408 : torch.Size([64, 1, 28, 28])\n","Batch 409 : torch.Size([64, 1, 28, 28])\n","Batch 410 : torch.Size([64, 1, 28, 28])\n","Batch 411 : torch.Size([64, 1, 28, 28])\n","Batch 412 : torch.Size([64, 1, 28, 28])\n","Batch 413 : torch.Size([64, 1, 28, 28])\n","Batch 414 : torch.Size([64, 1, 28, 28])\n","Batch 415 : torch.Size([64, 1, 28, 28])\n","Batch 416 : torch.Size([64, 1, 28, 28])\n","Batch 417 : torch.Size([64, 1, 28, 28])\n","Batch 418 : torch.Size([64, 1, 28, 28])\n","Batch 419 : torch.Size([64, 1, 28, 28])\n","Batch 420 : torch.Size([64, 1, 28, 28])\n","Batch 421 : torch.Size([64, 1, 28, 28])\n","Batch 422 : torch.Size([64, 1, 28, 28])\n","Batch 423 : torch.Size([64, 1, 28, 28])\n","Batch 424 : torch.Size([64, 1, 28, 28])\n","Batch 425 : torch.Size([64, 1, 28, 28])\n","Batch 426 : torch.Size([64, 1, 28, 28])\n","Batch 427 : torch.Size([64, 1, 28, 28])\n","Batch 428 : torch.Size([64, 1, 28, 28])\n","Batch 429 : torch.Size([64, 1, 28, 28])\n","Batch 430 : torch.Size([64, 1, 28, 28])\n","Batch 431 : torch.Size([64, 1, 28, 28])\n","Batch 432 : torch.Size([64, 1, 28, 28])\n","Batch 433 : torch.Size([64, 1, 28, 28])\n","Batch 434 : torch.Size([64, 1, 28, 28])\n","Batch 435 : torch.Size([64, 1, 28, 28])\n","Batch 436 : torch.Size([64, 1, 28, 28])\n","Batch 437 : torch.Size([64, 1, 28, 28])\n","Batch 438 : torch.Size([64, 1, 28, 28])\n","Batch 439 : torch.Size([64, 1, 28, 28])\n","Batch 440 : torch.Size([64, 1, 28, 28])\n","Batch 441 : torch.Size([64, 1, 28, 28])\n","Batch 442 : torch.Size([64, 1, 28, 28])\n","Batch 443 : 
torch.Size([64, 1, 28, 28])\n","Batch 444 : torch.Size([64, 1, 28, 28])\n","Batch 445 : torch.Size([64, 1, 28, 28])\n","Batch 446 : torch.Size([64, 1, 28, 28])\n","Batch 447 : torch.Size([64, 1, 28, 28])\n","Batch 448 : torch.Size([64, 1, 28, 28])\n","Batch 449 : torch.Size([64, 1, 28, 28])\n","Batch 450 : torch.Size([64, 1, 28, 28])\n","Batch 451 : torch.Size([64, 1, 28, 28])\n","Batch 452 : torch.Size([64, 1, 28, 28])\n","Batch 453 : torch.Size([64, 1, 28, 28])\n","Batch 454 : torch.Size([64, 1, 28, 28])\n","Batch 455 : torch.Size([64, 1, 28, 28])\n","Batch 456 : torch.Size([64, 1, 28, 28])\n","Batch 457 : torch.Size([64, 1, 28, 28])\n","Batch 458 : torch.Size([64, 1, 28, 28])\n","Batch 459 : torch.Size([64, 1, 28, 28])\n","Batch 460 : torch.Size([64, 1, 28, 28])\n","Batch 461 : torch.Size([64, 1, 28, 28])\n","Batch 462 : torch.Size([64, 1, 28, 28])\n","Batch 463 : torch.Size([64, 1, 28, 28])\n","Batch 464 : torch.Size([64, 1, 28, 28])\n","Batch 465 : torch.Size([64, 1, 28, 28])\n","Batch 466 : torch.Size([64, 1, 28, 28])\n","Batch 467 : torch.Size([64, 1, 28, 28])\n","Batch 468 : torch.Size([64, 1, 28, 28])\n","Batch 469 : torch.Size([64, 1, 28, 28])\n","Batch 470 : torch.Size([64, 1, 28, 28])\n","Batch 471 : torch.Size([64, 1, 28, 28])\n","Batch 472 : torch.Size([64, 1, 28, 28])\n","Batch 473 : torch.Size([64, 1, 28, 28])\n","Batch 474 : torch.Size([64, 1, 28, 28])\n","Batch 475 : torch.Size([64, 1, 28, 28])\n","Batch 476 : torch.Size([64, 1, 28, 28])\n","Batch 477 : torch.Size([64, 1, 28, 28])\n","Batch 478 : torch.Size([64, 1, 28, 28])\n","Batch 479 : torch.Size([64, 1, 28, 28])\n","Batch 480 : torch.Size([64, 1, 28, 28])\n","Batch 481 : torch.Size([64, 1, 28, 28])\n","Batch 482 : torch.Size([64, 1, 28, 28])\n","Batch 483 : torch.Size([64, 1, 28, 28])\n","Batch 484 : torch.Size([64, 1, 28, 28])\n","Batch 485 : torch.Size([64, 1, 28, 28])\n","Batch 486 : torch.Size([64, 1, 28, 28])\n","Batch 487 : torch.Size([64, 1, 28, 28])\n","Batch 488 : torch.Size([64, 1, 
28, 28])\n","Batch 489 : torch.Size([64, 1, 28, 28])\n","Batch 490 : torch.Size([64, 1, 28, 28])\n","Batch 491 : torch.Size([64, 1, 28, 28])\n","Batch 492 : torch.Size([64, 1, 28, 28])\n","Batch 493 : torch.Size([64, 1, 28, 28])\n","Batch 494 : torch.Size([64, 1, 28, 28])\n","Batch 495 : torch.Size([64, 1, 28, 28])\n","Batch 496 : torch.Size([64, 1, 28, 28])\n","Batch 497 : torch.Size([64, 1, 28, 28])\n","Batch 498 : torch.Size([64, 1, 28, 28])\n","Batch 499 : torch.Size([64, 1, 28, 28])\n","Batch 500 : torch.Size([64, 1, 28, 28])\n","Batch 501 : torch.Size([64, 1, 28, 28])\n","Batch 502 : torch.Size([64, 1, 28, 28])\n","Batch 503 : torch.Size([64, 1, 28, 28])\n","Batch 504 : torch.Size([64, 1, 28, 28])\n","Batch 505 : torch.Size([64, 1, 28, 28])\n","Batch 506 : torch.Size([64, 1, 28, 28])\n","Batch 507 : torch.Size([64, 1, 28, 28])\n","Batch 508 : torch.Size([64, 1, 28, 28])\n","Batch 509 : torch.Size([64, 1, 28, 28])\n","Batch 510 : torch.Size([64, 1, 28, 28])\n","Batch 511 : torch.Size([64, 1, 28, 28])\n","Batch 512 : torch.Size([64, 1, 28, 28])\n","Batch 513 : torch.Size([64, 1, 28, 28])\n","Batch 514 : torch.Size([64, 1, 28, 28])\n","Batch 515 : torch.Size([64, 1, 28, 28])\n","Batch 516 : torch.Size([64, 1, 28, 28])\n","Batch 517 : torch.Size([64, 1, 28, 28])\n","Batch 518 : torch.Size([64, 1, 28, 28])\n","Batch 519 : torch.Size([64, 1, 28, 28])\n","Batch 520 : torch.Size([64, 1, 28, 28])\n","Batch 521 : torch.Size([64, 1, 28, 28])\n","Batch 522 : torch.Size([64, 1, 28, 28])\n","Batch 523 : torch.Size([64, 1, 28, 28])\n","Batch 524 : torch.Size([64, 1, 28, 28])\n","Batch 525 : torch.Size([64, 1, 28, 28])\n","Batch 526 : torch.Size([64, 1, 28, 28])\n","Batch 527 : torch.Size([64, 1, 28, 28])\n","Batch 528 : torch.Size([64, 1, 28, 28])\n","Batch 529 : torch.Size([64, 1, 28, 28])\n","Batch 530 : torch.Size([64, 1, 28, 28])\n","Batch 531 : torch.Size([64, 1, 28, 28])\n","Batch 532 : torch.Size([64, 1, 28, 28])\n","Batch 533 : torch.Size([64, 1, 28, 28])\n","Batch 
534 : torch.Size([64, 1, 28, 28])\n","Batch 535 : torch.Size([64, 1, 28, 28])\n","Batch 536 : torch.Size([64, 1, 28, 28])\n","Batch 537 : torch.Size([64, 1, 28, 28])\n","Batch 538 : torch.Size([64, 1, 28, 28])\n","Batch 539 : torch.Size([64, 1, 28, 28])\n","Batch 540 : torch.Size([64, 1, 28, 28])\n","Batch 541 : torch.Size([64, 1, 28, 28])\n","Batch 542 : torch.Size([64, 1, 28, 28])\n","Batch 543 : torch.Size([64, 1, 28, 28])\n","Batch 544 : torch.Size([64, 1, 28, 28])\n","Batch 545 : torch.Size([64, 1, 28, 28])\n","Batch 546 : torch.Size([64, 1, 28, 28])\n","Batch 547 : torch.Size([64, 1, 28, 28])\n","Batch 548 : torch.Size([64, 1, 28, 28])\n","Batch 549 : torch.Size([64, 1, 28, 28])\n","Batch 550 : torch.Size([64, 1, 28, 28])\n","Batch 551 : torch.Size([64, 1, 28, 28])\n","Batch 552 : torch.Size([64, 1, 28, 28])\n","Batch 553 : torch.Size([64, 1, 28, 28])\n","Batch 554 : torch.Size([64, 1, 28, 28])\n","Batch 555 : torch.Size([64, 1, 28, 28])\n","Batch 556 : torch.Size([64, 1, 28, 28])\n","Batch 557 : torch.Size([64, 1, 28, 28])\n","Batch 558 : torch.Size([64, 1, 28, 28])\n","Batch 559 : torch.Size([64, 1, 28, 28])\n","Batch 560 : torch.Size([64, 1, 28, 28])\n","Batch 561 : torch.Size([64, 1, 28, 28])\n","Batch 562 : torch.Size([64, 1, 28, 28])\n","Batch 563 : torch.Size([64, 1, 28, 28])\n","Batch 564 : torch.Size([64, 1, 28, 28])\n","Batch 565 : torch.Size([64, 1, 28, 28])\n","Batch 566 : torch.Size([64, 1, 28, 28])\n","Batch 567 : torch.Size([64, 1, 28, 28])\n","Batch 568 : torch.Size([64, 1, 28, 28])\n","Batch 569 : torch.Size([64, 1, 28, 28])\n","Batch 570 : torch.Size([64, 1, 28, 28])\n","Batch 571 : torch.Size([64, 1, 28, 28])\n","Batch 572 : torch.Size([64, 1, 28, 28])\n","Batch 573 : torch.Size([64, 1, 28, 28])\n","Batch 574 : torch.Size([64, 1, 28, 28])\n","Batch 575 : torch.Size([64, 1, 28, 28])\n","Batch 576 : torch.Size([64, 1, 28, 28])\n","Batch 577 : torch.Size([64, 1, 28, 28])\n","Batch 578 : torch.Size([64, 1, 28, 28])\n","Batch 579 : 
torch.Size([64, 1, 28, 28])\n","Batch 580 : torch.Size([64, 1, 28, 28])\n","Batch 581 : torch.Size([64, 1, 28, 28])\n","Batch 582 : torch.Size([64, 1, 28, 28])\n","Batch 583 : torch.Size([64, 1, 28, 28])\n","Batch 584 : torch.Size([64, 1, 28, 28])\n","Batch 585 : torch.Size([64, 1, 28, 28])\n","Batch 586 : torch.Size([64, 1, 28, 28])\n","Batch 587 : torch.Size([64, 1, 28, 28])\n","Batch 588 : torch.Size([64, 1, 28, 28])\n","Batch 589 : torch.Size([64, 1, 28, 28])\n","Batch 590 : torch.Size([64, 1, 28, 28])\n","Batch 591 : torch.Size([64, 1, 28, 28])\n","Batch 592 : torch.Size([64, 1, 28, 28])\n","Batch 593 : torch.Size([64, 1, 28, 28])\n","Batch 594 : torch.Size([64, 1, 28, 28])\n","Batch 595 : torch.Size([64, 1, 28, 28])\n","Batch 596 : torch.Size([64, 1, 28, 28])\n","Batch 597 : torch.Size([64, 1, 28, 28])\n","Batch 598 : torch.Size([64, 1, 28, 28])\n","Batch 599 : torch.Size([64, 1, 28, 28])\n","Batch 600 : torch.Size([64, 1, 28, 28])\n","Batch 601 : torch.Size([64, 1, 28, 28])\n","Batch 602 : torch.Size([64, 1, 28, 28])\n","Batch 603 : torch.Size([64, 1, 28, 28])\n","Batch 604 : torch.Size([64, 1, 28, 28])\n","Batch 605 : torch.Size([64, 1, 28, 28])\n","Batch 606 : torch.Size([64, 1, 28, 28])\n","Batch 607 : torch.Size([64, 1, 28, 28])\n","Batch 608 : torch.Size([64, 1, 28, 28])\n","Batch 609 : torch.Size([64, 1, 28, 28])\n","Batch 610 : torch.Size([64, 1, 28, 28])\n","Batch 611 : torch.Size([64, 1, 28, 28])\n","Batch 612 : torch.Size([64, 1, 28, 28])\n","Batch 613 : torch.Size([64, 1, 28, 28])\n","Batch 614 : torch.Size([64, 1, 28, 28])\n","Batch 615 : torch.Size([64, 1, 28, 28])\n","Batch 616 : torch.Size([64, 1, 28, 28])\n","Batch 617 : torch.Size([64, 1, 28, 28])\n","Batch 618 : torch.Size([64, 1, 28, 28])\n","Batch 619 : torch.Size([64, 1, 28, 28])\n","Batch 620 : torch.Size([64, 1, 28, 28])\n","Batch 621 : torch.Size([64, 1, 28, 28])\n","Batch 622 : torch.Size([64, 1, 28, 28])\n","Batch 623 : torch.Size([64, 1, 28, 28])\n","Batch 624 : torch.Size([64, 1, 
28, 28])\n","Batch 625 : torch.Size([64, 1, 28, 28])\n","Batch 626 : torch.Size([64, 1, 28, 28])\n","Batch 627 : torch.Size([64, 1, 28, 28])\n","Batch 628 : torch.Size([64, 1, 28, 28])\n","Batch 629 : torch.Size([64, 1, 28, 28])\n","Batch 630 : torch.Size([64, 1, 28, 28])\n","Batch 631 : torch.Size([64, 1, 28, 28])\n","Batch 632 : torch.Size([64, 1, 28, 28])\n","Batch 633 : torch.Size([64, 1, 28, 28])\n","Batch 634 : torch.Size([64, 1, 28, 28])\n","Batch 635 : torch.Size([64, 1, 28, 28])\n","Batch 636 : torch.Size([64, 1, 28, 28])\n","Batch 637 : torch.Size([64, 1, 28, 28])\n","Batch 638 : torch.Size([64, 1, 28, 28])\n","Batch 639 : torch.Size([64, 1, 28, 28])\n","Batch 640 : torch.Size([64, 1, 28, 28])\n","Batch 641 : torch.Size([64, 1, 28, 28])\n","Batch 642 : torch.Size([64, 1, 28, 28])\n","Batch 643 : torch.Size([64, 1, 28, 28])\n","Batch 644 : torch.Size([64, 1, 28, 28])\n","Batch 645 : torch.Size([64, 1, 28, 28])\n","Batch 646 : torch.Size([64, 1, 28, 28])\n","Batch 647 : torch.Size([64, 1, 28, 28])\n","Batch 648 : torch.Size([64, 1, 28, 28])\n","Batch 649 : torch.Size([64, 1, 28, 28])\n","Batch 650 : torch.Size([64, 1, 28, 28])\n","Batch 651 : torch.Size([64, 1, 28, 28])\n","Batch 652 : torch.Size([64, 1, 28, 28])\n","Batch 653 : torch.Size([64, 1, 28, 28])\n","Batch 654 : torch.Size([64, 1, 28, 28])\n","Batch 655 : torch.Size([64, 1, 28, 28])\n","Batch 656 : torch.Size([64, 1, 28, 28])\n","Batch 657 : torch.Size([64, 1, 28, 28])\n","Batch 658 : torch.Size([64, 1, 28, 28])\n","Batch 659 : torch.Size([64, 1, 28, 28])\n","Batch 660 : torch.Size([64, 1, 28, 28])\n","Batch 661 : torch.Size([64, 1, 28, 28])\n","Batch 662 : torch.Size([64, 1, 28, 28])\n","Batch 663 : torch.Size([64, 1, 28, 28])\n","Batch 664 : torch.Size([64, 1, 28, 28])\n","Batch 665 : torch.Size([64, 1, 28, 28])\n","Batch 666 : torch.Size([64, 1, 28, 28])\n","Batch 667 : torch.Size([64, 1, 28, 28])\n","Batch 668 : torch.Size([64, 1, 28, 28])\n","Batch 669 : torch.Size([64, 1, 28, 28])\n","Batch 
670 : torch.Size([64, 1, 28, 28])\n","Batch 671 : torch.Size([64, 1, 28, 28])\n","Batch 672 : torch.Size([64, 1, 28, 28])\n","Batch 673 : torch.Size([64, 1, 28, 28])\n","Batch 674 : torch.Size([64, 1, 28, 28])\n","Batch 675 : torch.Size([64, 1, 28, 28])\n","Batch 676 : torch.Size([64, 1, 28, 28])\n","Batch 677 : torch.Size([64, 1, 28, 28])\n","Batch 678 : torch.Size([64, 1, 28, 28])\n","Batch 679 : torch.Size([64, 1, 28, 28])\n","Batch 680 : torch.Size([64, 1, 28, 28])\n","Batch 681 : torch.Size([64, 1, 28, 28])\n","Batch 682 : torch.Size([64, 1, 28, 28])\n","Batch 683 : torch.Size([64, 1, 28, 28])\n","Batch 684 : torch.Size([64, 1, 28, 28])\n","Batch 685 : torch.Size([64, 1, 28, 28])\n","Batch 686 : torch.Size([64, 1, 28, 28])\n","Batch 687 : torch.Size([64, 1, 28, 28])\n","Batch 688 : torch.Size([64, 1, 28, 28])\n","Batch 689 : torch.Size([64, 1, 28, 28])\n","Batch 690 : torch.Size([64, 1, 28, 28])\n","Batch 691 : torch.Size([64, 1, 28, 28])\n","Batch 692 : torch.Size([64, 1, 28, 28])\n","Batch 693 : torch.Size([64, 1, 28, 28])\n","Batch 694 : torch.Size([64, 1, 28, 28])\n","Batch 695 : torch.Size([64, 1, 28, 28])\n","Batch 696 : torch.Size([64, 1, 28, 28])\n","Batch 697 : torch.Size([64, 1, 28, 28])\n","Batch 698 : torch.Size([64, 1, 28, 28])\n","Batch 699 : torch.Size([64, 1, 28, 28])\n","Batch 700 : torch.Size([64, 1, 28, 28])\n","Batch 701 : torch.Size([64, 1, 28, 28])\n","Batch 702 : torch.Size([64, 1, 28, 28])\n","Batch 703 : torch.Size([64, 1, 28, 28])\n","Batch 704 : torch.Size([64, 1, 28, 28])\n","Batch 705 : torch.Size([64, 1, 28, 28])\n","Batch 706 : torch.Size([64, 1, 28, 28])\n","Batch 707 : torch.Size([64, 1, 28, 28])\n","Batch 708 : torch.Size([64, 1, 28, 28])\n","Batch 709 : torch.Size([64, 1, 28, 28])\n","Batch 710 : torch.Size([64, 1, 28, 28])\n","Batch 711 : torch.Size([64, 1, 28, 28])\n","Batch 712 : torch.Size([64, 1, 28, 28])\n","Batch 713 : torch.Size([64, 1, 28, 28])\n","Batch 714 : torch.Size([64, 1, 28, 28])\n","Batch 715 : 
torch.Size([64, 1, 28, 28])\n","Batch 716 : torch.Size([64, 1, 28, 28])\n","Batch 717 : torch.Size([64, 1, 28, 28])\n","Batch 718 : torch.Size([64, 1, 28, 28])\n","Batch 719 : torch.Size([64, 1, 28, 28])\n","Batch 720 : torch.Size([64, 1, 28, 28])\n","Batch 721 : torch.Size([64, 1, 28, 28])\n","Batch 722 : torch.Size([64, 1, 28, 28])\n","Batch 723 : torch.Size([64, 1, 28, 28])\n","Batch 724 : torch.Size([64, 1, 28, 28])\n","Batch 725 : torch.Size([64, 1, 28, 28])\n","Batch 726 : torch.Size([64, 1, 28, 28])\n","Batch 727 : torch.Size([64, 1, 28, 28])\n","Batch 728 : torch.Size([64, 1, 28, 28])\n","Batch 729 : torch.Size([64, 1, 28, 28])\n","Batch 730 : torch.Size([64, 1, 28, 28])\n","Batch 731 : torch.Size([64, 1, 28, 28])\n","Batch 732 : torch.Size([64, 1, 28, 28])\n","Batch 733 : torch.Size([64, 1, 28, 28])\n","Batch 734 : torch.Size([64, 1, 28, 28])\n","Batch 735 : torch.Size([64, 1, 28, 28])\n","Batch 736 : torch.Size([64, 1, 28, 28])\n","Batch 737 : torch.Size([64, 1, 28, 28])\n","Batch 738 : torch.Size([64, 1, 28, 28])\n","Batch 739 : torch.Size([64, 1, 28, 28])\n","Batch 740 : torch.Size([64, 1, 28, 28])\n","Batch 741 : torch.Size([64, 1, 28, 28])\n","Batch 742 : torch.Size([64, 1, 28, 28])\n","Batch 743 : torch.Size([64, 1, 28, 28])\n","Batch 744 : torch.Size([64, 1, 28, 28])\n","Batch 745 : torch.Size([64, 1, 28, 28])\n","Batch 746 : torch.Size([64, 1, 28, 28])\n","Batch 747 : torch.Size([64, 1, 28, 28])\n","Batch 748 : torch.Size([64, 1, 28, 28])\n","Batch 749 : torch.Size([64, 1, 28, 28])\n","Batch 750 : torch.Size([64, 1, 28, 28])\n","Batch 751 : torch.Size([64, 1, 28, 28])\n","Batch 752 : torch.Size([64, 1, 28, 28])\n","Batch 753 : torch.Size([64, 1, 28, 28])\n","Batch 754 : torch.Size([64, 1, 28, 28])\n","Batch 755 : torch.Size([64, 1, 28, 28])\n","Batch 756 : torch.Size([64, 1, 28, 28])\n","Batch 757 : torch.Size([64, 1, 28, 28])\n","Batch 758 : torch.Size([64, 1, 28, 28])\n","Batch 759 : torch.Size([64, 1, 28, 28])\n","Batch 760 : torch.Size([64, 1, 
28, 28])\n","Batch 761 : torch.Size([64, 1, 28, 28])\n","Batch 762 : torch.Size([64, 1, 28, 28])\n","Batch 763 : torch.Size([64, 1, 28, 28])\n","Batch 764 : torch.Size([64, 1, 28, 28])\n","Batch 765 : torch.Size([64, 1, 28, 28])\n","Batch 766 : torch.Size([64, 1, 28, 28])\n","Batch 767 : torch.Size([64, 1, 28, 28])\n","Batch 768 : torch.Size([64, 1, 28, 28])\n","Batch 769 : torch.Size([64, 1, 28, 28])\n","Batch 770 : torch.Size([64, 1, 28, 28])\n","Batch 771 : torch.Size([64, 1, 28, 28])\n","Batch 772 : torch.Size([64, 1, 28, 28])\n","Batch 773 : torch.Size([64, 1, 28, 28])\n","Batch 774 : torch.Size([64, 1, 28, 28])\n","Batch 775 : torch.Size([64, 1, 28, 28])\n","Batch 776 : torch.Size([64, 1, 28, 28])\n","Batch 777 : torch.Size([64, 1, 28, 28])\n","Batch 778 : torch.Size([64, 1, 28, 28])\n","Batch 779 : torch.Size([64, 1, 28, 28])\n","Batch 780 : torch.Size([64, 1, 28, 28])\n","Batch 781 : torch.Size([64, 1, 28, 28])\n","Batch 782 : torch.Size([64, 1, 28, 28])\n","Batch 783 : torch.Size([64, 1, 28, 28])\n","Batch 784 : torch.Size([64, 1, 28, 28])\n","Batch 785 : torch.Size([64, 1, 28, 28])\n","Batch 786 : torch.Size([64, 1, 28, 28])\n","Batch 787 : torch.Size([64, 1, 28, 28])\n","Batch 788 : torch.Size([64, 1, 28, 28])\n","Batch 789 : torch.Size([64, 1, 28, 28])\n","Batch 790 : torch.Size([64, 1, 28, 28])\n","Batch 791 : torch.Size([64, 1, 28, 28])\n","Batch 792 : torch.Size([64, 1, 28, 28])\n","Batch 793 : torch.Size([64, 1, 28, 28])\n","Batch 794 : torch.Size([64, 1, 28, 28])\n","Batch 795 : torch.Size([64, 1, 28, 28])\n","Batch 796 : torch.Size([64, 1, 28, 28])\n","Batch 797 : torch.Size([64, 1, 28, 28])\n","Batch 798 : torch.Size([64, 1, 28, 28])\n","Batch 799 : torch.Size([64, 1, 28, 28])\n","Batch 800 : torch.Size([64, 1, 28, 28])\n","Batch 801 : torch.Size([64, 1, 28, 28])\n","Batch 802 : torch.Size([64, 1, 28, 28])\n","Batch 803 : torch.Size([64, 1, 28, 28])\n","Batch 804 : torch.Size([64, 1, 28, 28])\n","Batch 805 : torch.Size([64, 1, 28, 28])\n","Batch 
806 : torch.Size([64, 1, 28, 28])\n","Batch 807 : torch.Size([64, 1, 28, 28])\n","Batch 808 : torch.Size([64, 1, 28, 28])\n","Batch 809 : torch.Size([64, 1, 28, 28])\n","Batch 810 : torch.Size([64, 1, 28, 28])\n","Batch 811 : torch.Size([64, 1, 28, 28])\n","Batch 812 : torch.Size([64, 1, 28, 28])\n","Batch 813 : torch.Size([64, 1, 28, 28])\n","Batch 814 : torch.Size([64, 1, 28, 28])\n","Batch 815 : torch.Size([64, 1, 28, 28])\n","Batch 816 : torch.Size([64, 1, 28, 28])\n","Batch 817 : torch.Size([64, 1, 28, 28])\n","Batch 818 : torch.Size([64, 1, 28, 28])\n","Batch 819 : torch.Size([64, 1, 28, 28])\n","Batch 820 : torch.Size([64, 1, 28, 28])\n","Batch 821 : torch.Size([64, 1, 28, 28])\n","Batch 822 : torch.Size([64, 1, 28, 28])\n","Batch 823 : torch.Size([64, 1, 28, 28])\n","Batch 824 : torch.Size([64, 1, 28, 28])\n","Batch 825 : torch.Size([64, 1, 28, 28])\n","Batch 826 : torch.Size([64, 1, 28, 28])\n","Batch 827 : torch.Size([64, 1, 28, 28])\n","Batch 828 : torch.Size([64, 1, 28, 28])\n","Batch 829 : torch.Size([64, 1, 28, 28])\n","Batch 830 : torch.Size([64, 1, 28, 28])\n","Batch 831 : torch.Size([64, 1, 28, 28])\n","Batch 832 : torch.Size([64, 1, 28, 28])\n","Batch 833 : torch.Size([64, 1, 28, 28])\n","Batch 834 : torch.Size([64, 1, 28, 28])\n","Batch 835 : torch.Size([64, 1, 28, 28])\n","Batch 836 : torch.Size([64, 1, 28, 28])\n","Batch 837 : torch.Size([64, 1, 28, 28])\n","Batch 838 : torch.Size([64, 1, 28, 28])\n","Batch 839 : torch.Size([64, 1, 28, 28])\n","Batch 840 : torch.Size([64, 1, 28, 28])\n","Batch 841 : torch.Size([64, 1, 28, 28])\n","Batch 842 : torch.Size([64, 1, 28, 28])\n","Batch 843 : torch.Size([64, 1, 28, 28])\n","Batch 844 : torch.Size([64, 1, 28, 28])\n","Batch 845 : torch.Size([64, 1, 28, 28])\n","Batch 846 : torch.Size([64, 1, 28, 28])\n","Batch 847 : torch.Size([64, 1, 28, 28])\n","Batch 848 : torch.Size([64, 1, 28, 28])\n","Batch 849 : torch.Size([64, 1, 28, 28])\n","Batch 850 : torch.Size([64, 1, 28, 28])\n","Batch 851 : 
torch.Size([64, 1, 28, 28])\n","Batch 852 : torch.Size([64, 1, 28, 28])\n","Batch 853 : torch.Size([64, 1, 28, 28])\n","Batch 854 : torch.Size([64, 1, 28, 28])\n","Batch 855 : torch.Size([64, 1, 28, 28])\n","Batch 856 : torch.Size([64, 1, 28, 28])\n","Batch 857 : torch.Size([64, 1, 28, 28])\n","Batch 858 : torch.Size([64, 1, 28, 28])\n","Batch 859 : torch.Size([64, 1, 28, 28])\n","Batch 860 : torch.Size([64, 1, 28, 28])\n","Batch 861 : torch.Size([64, 1, 28, 28])\n","Batch 862 : torch.Size([64, 1, 28, 28])\n","Batch 863 : torch.Size([64, 1, 28, 28])\n","Batch 864 : torch.Size([64, 1, 28, 28])\n","Batch 865 : torch.Size([64, 1, 28, 28])\n","Batch 866 : torch.Size([64, 1, 28, 28])\n","Batch 867 : torch.Size([64, 1, 28, 28])\n","Batch 868 : torch.Size([64, 1, 28, 28])\n","Batch 869 : torch.Size([64, 1, 28, 28])\n","Batch 870 : torch.Size([64, 1, 28, 28])\n","Batch 871 : torch.Size([64, 1, 28, 28])\n","Batch 872 : torch.Size([64, 1, 28, 28])\n","Batch 873 : torch.Size([64, 1, 28, 28])\n","Batch 874 : torch.Size([64, 1, 28, 28])\n","Batch 875 : torch.Size([64, 1, 28, 28])\n","Batch 876 : torch.Size([64, 1, 28, 28])\n","Batch 877 : torch.Size([64, 1, 28, 28])\n","Batch 878 : torch.Size([64, 1, 28, 28])\n","Batch 879 : torch.Size([64, 1, 28, 28])\n","Batch 880 : torch.Size([64, 1, 28, 28])\n","Batch 881 : torch.Size([64, 1, 28, 28])\n","Batch 882 : torch.Size([64, 1, 28, 28])\n","Batch 883 : torch.Size([64, 1, 28, 28])\n","Batch 884 : torch.Size([64, 1, 28, 28])\n","Batch 885 : torch.Size([64, 1, 28, 28])\n","Batch 886 : torch.Size([64, 1, 28, 28])\n","Batch 887 : torch.Size([64, 1, 28, 28])\n","Batch 888 : torch.Size([64, 1, 28, 28])\n","Batch 889 : torch.Size([64, 1, 28, 28])\n","Batch 890 : torch.Size([64, 1, 28, 28])\n","Batch 891 : torch.Size([64, 1, 28, 28])\n","Batch 892 : torch.Size([64, 1, 28, 28])\n","Batch 893 : torch.Size([64, 1, 28, 28])\n","Batch 894 : torch.Size([64, 1, 28, 28])\n","Batch 895 : torch.Size([64, 1, 28, 28])\n","Batch 896 : torch.Size([64, 1, 
28, 28])\n","Batch 897 : torch.Size([64, 1, 28, 28])\n","Batch 898 : torch.Size([64, 1, 28, 28])\n","Batch 899 : torch.Size([64, 1, 28, 28])\n","Batch 900 : torch.Size([64, 1, 28, 28])\n","Batch 901 : torch.Size([64, 1, 28, 28])\n","Batch 902 : torch.Size([64, 1, 28, 28])\n","Batch 903 : torch.Size([64, 1, 28, 28])\n","Batch 904 : torch.Size([64, 1, 28, 28])\n","Batch 905 : torch.Size([64, 1, 28, 28])\n","Batch 906 : torch.Size([64, 1, 28, 28])\n","Batch 907 : torch.Size([64, 1, 28, 28])\n","Batch 908 : torch.Size([64, 1, 28, 28])\n","Batch 909 : torch.Size([64, 1, 28, 28])\n","Batch 910 : torch.Size([64, 1, 28, 28])\n","Batch 911 : torch.Size([64, 1, 28, 28])\n","Batch 912 : torch.Size([64, 1, 28, 28])\n","Batch 913 : torch.Size([64, 1, 28, 28])\n","Batch 914 : torch.Size([64, 1, 28, 28])\n","Batch 915 : torch.Size([64, 1, 28, 28])\n","Batch 916 : torch.Size([64, 1, 28, 28])\n","Batch 917 : torch.Size([64, 1, 28, 28])\n","Batch 918 : torch.Size([64, 1, 28, 28])\n","Batch 919 : torch.Size([64, 1, 28, 28])\n","Batch 920 : torch.Size([64, 1, 28, 28])\n","Batch 921 : torch.Size([64, 1, 28, 28])\n","Batch 922 : torch.Size([64, 1, 28, 28])\n","Batch 923 : torch.Size([64, 1, 28, 28])\n","Batch 924 : torch.Size([64, 1, 28, 28])\n","Batch 925 : torch.Size([64, 1, 28, 28])\n","Batch 926 : torch.Size([64, 1, 28, 28])\n","Batch 927 : torch.Size([64, 1, 28, 28])\n","Batch 928 : torch.Size([64, 1, 28, 28])\n","Batch 929 : torch.Size([64, 1, 28, 28])\n","Batch 930 : torch.Size([64, 1, 28, 28])\n","Batch 931 : torch.Size([64, 1, 28, 28])\n","Batch 932 : torch.Size([64, 1, 28, 28])\n","Batch 933 : torch.Size([64, 1, 28, 28])\n","Batch 934 : torch.Size([64, 1, 28, 28])\n","Batch 935 : torch.Size([64, 1, 28, 28])\n","Batch 936 : torch.Size([64, 1, 28, 28])\n","Batch 937 : torch.Size([32, 1, 28, 
28])\n"]}]},{"cell_type":"code","source":["937*64+32"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"MvRWB_vlREHP","executionInfo":{"status":"ok","timestamp":1652711261032,"user_tz":-270,"elapsed":411,"user":{"displayName":"Arman Malekzadeh","userId":"13206128678526609465"}},"outputId":"9d4ec8a4-96cc-4d9f-d9bc-e195e867805f"},"execution_count":23,"outputs":[{"output_type":"execute_result","data":{"text/plain":["60000"]},"metadata":{},"execution_count":23}]}]} --------------------------------------------------------------------------------