├── 978-1-4842-8924-2.jpg ├── errata.md ├── README.md ├── Contributing.md ├── LICENSE.txt ├── Torch_AI_7_2Ed.ipynb ├── Torch_AI_4_2Ed.ipynb └── Torch_AI_2_2Ed.ipynb /978-1-4842-8924-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Apress/pytorch-recipes-2e/HEAD/978-1-4842-8924-2.jpg -------------------------------------------------------------------------------- /errata.md: -------------------------------------------------------------------------------- 1 | # Errata for *PyTorch Recipes, 2nd Edition* 2 | 3 | On **page xx** [Summary of error]: 4 | 5 | Details of error here. Highlight key pieces in **bold**. 6 | 7 | *** 8 | 9 | On **page xx** [Summary of error]: 10 | 11 | Details of error here. Highlight key pieces in **bold**. 12 | 13 | *** -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Apress Source Code 2 | 3 | This repository accompanies [*PyTorch Recipes, 2nd Edition*](https://link.springer.com/book/10.1007/978-1-4842-8925-9) by Pradeepta Mishra (Apress, 2023). 4 | 5 | [comment]: #cover 6 | ![Cover image](978-1-4842-8924-2.jpg) 7 | 8 | Download the files as a zip using the green button, or clone the repository to your machine using Git. 9 | 10 | ## Releases 11 | 12 | Release v1.0 corresponds to the code in the published book, without corrections or updates. 13 | 14 | ## Contributions 15 | 16 | See the file Contributing.md for more information on how you can contribute to this repository. -------------------------------------------------------------------------------- /Contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing to Apress Source Code 2 | 3 | Copyright for Apress source code belongs to the author(s). 
However, under fair use you are encouraged to fork and contribute minor corrections and updates for the benefit of the author(s) and other readers. 4 | 5 | ## How to Contribute 6 | 7 | 1. Make sure you have a GitHub account. 8 | 2. Fork the repository for the relevant book. 9 | 3. Create a new branch on which to make your change, e.g. 10 | `git checkout -b my_code_contribution` 11 | 4. Commit your change. Include a commit message describing the correction. Please note that if your commit message is not clear, the correction will not be accepted. 12 | 5. Submit a pull request. 13 | 14 | Thank you for your contribution! -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Freeware License, some rights reserved 2 | 3 | Copyright (c) 2023 Pradeepta Mishra 4 | 5 | Permission is hereby granted, free of charge, to anyone obtaining a copy 6 | of this software and associated documentation files (the "Software"), 7 | to work with the Software within the limits of freeware distribution and fair use. 8 | This includes the rights to use, copy, and modify the Software for personal use. 9 | Users are also allowed and encouraged to submit corrections and modifications 10 | to the Software for the benefit of other users. 11 | 12 | It is not allowed to reuse, modify, or redistribute the Software for 13 | commercial use in any way, or for a user’s educational materials such as books 14 | or blog articles without prior permission from the copyright holder. 15 | 16 | The above copyright notice and this permission notice need to be included 17 | in all copies or substantial portions of the software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS OR APRESS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | 27 | 28 | -------------------------------------------------------------------------------- /Torch_AI_7_2Ed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/" 9 | }, 10 | "id": "MytZ1UizABOx", 11 | "outputId": "9d21f6b5-8e50-4c10-bf97-034d33dbb932" 12 | }, 13 | "outputs": [ 14 | { 15 | "output_type": "execute_result", 16 | "data": { 17 | "text/plain": [ 18 | "" 19 | ] 20 | }, 21 | "metadata": {}, 22 | "execution_count": 2 23 | } 24 | ], 25 | "source": [ 26 | "import torch\n", 27 | "import torch.nn as nn\n", 28 | "import torch.nn.functional as F\n", 29 | "import torch.optim as optim\n", 30 | "\n", 31 | "torch.manual_seed(1234)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": { 38 | "id": "xqNfr9D2ABO2" 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "word_to_ix = {\"data\": 0, \"science\": 1}" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": { 49 | "colab": { 50 | "base_uri": "https://localhost:8080/" 51 | }, 52 | "id": "oCWoOY8jABO2", 53 | "outputId": "ae1366a2-791a-433a-b5eb-b570af9ad7fd" 54 | }, 55 | "outputs": [ 56 | { 57 | "output_type": "execute_result", 58 | "data": { 59 | "text/plain": [ 60 | "{'data': 0, 'science': 1}" 61 | ] 62 | }, 63 | "metadata": {}, 64 | "execution_count": 4 65 | } 66 | ], 67 | "source": [ 68 | "word_to_ix" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 5, 74 | "metadata": { 75 | "id": "UTgWYqqHABO3" 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | 
"embeds = nn.Embedding(2, 5) # 2 words in vocab, 5 dimensional embeddings" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 6, 85 | "metadata": { 86 | "colab": { 87 | "base_uri": "https://localhost:8080/" 88 | }, 89 | "id": "sUGENyHJABO3", 90 | "outputId": "3a46dae5-da99-4891-88f8-603bb2d40927" 91 | }, 92 | "outputs": [ 93 | { 94 | "output_type": "execute_result", 95 | "data": { 96 | "text/plain": [ 97 | "Embedding(2, 5)" 98 | ] 99 | }, 100 | "metadata": {}, 101 | "execution_count": 6 102 | } 103 | ], 104 | "source": [ 105 | "embeds" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 7, 111 | "metadata": { 112 | "colab": { 113 | "base_uri": "https://localhost:8080/" 114 | }, 115 | "id": "V681Q9XnABO4", 116 | "outputId": "c793bd9d-daec-4efd-db46-cf111a9a75c8" 117 | }, 118 | "outputs": [ 119 | { 120 | "output_type": "execute_result", 121 | "data": { 122 | "text/plain": [ 123 | "tensor([0])" 124 | ] 125 | }, 126 | "metadata": {}, 127 | "execution_count": 7 128 | } 129 | ], 130 | "source": [ 131 | "lookup_tensor = torch.tensor([word_to_ix[\"data\"]], dtype=torch.long)\n", 132 | "lookup_tensor" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 8, 138 | "metadata": { 139 | "colab": { 140 | "base_uri": "https://localhost:8080/" 141 | }, 142 | "id": "1UXOOYbEABO4", 143 | "outputId": "5034dc9b-c6b3-493b-b885-d847e74a1415" 144 | }, 145 | "outputs": [ 146 | { 147 | "output_type": "stream", 148 | "name": "stdout", 149 | "text": [ 150 | "tensor([[ 0.0461, 0.4024, -1.0115, 0.2167, -0.6123]],\n", 151 | " grad_fn=)\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "hello_embed = embeds(lookup_tensor)\n", 157 | "print(hello_embed)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 9, 163 | "metadata": { 164 | "id": "2HfByOz4ABO5" 165 | }, 166 | "outputs": [], 167 | "source": [ 168 | "CONTEXT_SIZE = 2" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 
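In the Torch_AI_7_2Ed.ipynb embedding cells, `nn.Embedding(2, 5)` acts as a trainable lookup table: indexing it with a `torch.long` tensor of word ids returns the matching rows of its weight matrix. The sketch below checks that property. It reuses the notebook's two-word vocabulary; looking up both words in one call is an addition made here for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(1234)

# Same toy vocabulary as the notebook: two words, 5-dimensional vectors.
word_to_ix = {"data": 0, "science": 1}
embeds = nn.Embedding(num_embeddings=2, embedding_dim=5)

# Look up both words at once; the indices must be an integer (long) tensor.
idx = torch.tensor([word_to_ix["data"], word_to_ix["science"]], dtype=torch.long)
vectors = embeds(idx)  # shape: (2, 5), one row per looked-up word

# An embedding lookup is row selection from the (vocab_size, dim) weight matrix.
same_rows = embeds.weight[idx]

print(vectors.shape)
print(torch.equal(vectors, same_rows))
```

Because the lookup is only row selection, gradients flow back only to the rows that were actually used in a given forward pass.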
10, 174 | "metadata": { 175 | "id": "8zO20Kr3ABO6" 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "EMBEDDING_DIM = 10" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 11, 185 | "metadata": { 186 | "colab": { 187 | "base_uri": "https://localhost:8080/" 188 | }, 189 | "id": "sCpDiugWABO6", 190 | "outputId": "9e369fe0-f7e8-465c-8f4b-5297f92329b1" 191 | }, 192 | "outputs": [ 193 | { 194 | "output_type": "stream", 195 | "name": "stdout", 196 | "text": [ 197 | "[(['The', 'popularity'], 'of'), (['popularity', 'of'], 'the'), (['of', 'the'], 'term')]\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "test_sentence = \"\"\"The popularity of the term \"data science\" has exploded in \n", 203 | "business environments and academia, as indicated by a jump in job openings.[32] \n", 204 | "However, many critical academics and journalists see no distinction between data \n", 205 | "science and statistics. Writing in Forbes, Gil Press argues that data science is a \n", 206 | "buzzword without a clear definition and has simply replaced “business analytics” in \n", 207 | "contexts such as graduate degree programs.[7] In the question-and-answer section of \n", 208 | "his keynote address at the Joint Statistical Meetings of American Statistical \n", 209 | "Association, noted applied statistician Nate Silver said, “I think data-scientist \n", 210 | "is a sexed up term for a statistician....Statistics is a branch of science. 
\n", 211 | "Data scientist is slightly redundant in some way and people shouldn’t berate the \n", 212 | "term statistician.”[9] Similarly, in business sector, multiple researchers and \n", 213 | "analysts state that data scientists alone are far from being sufficient in granting \n", 214 | "companies a real competitive advantage[33] and consider data scientists as only \n", 215 | "one of the four greater job families companies require to leverage big \n", 216 | "data effectively, namely: data analysts, data scientists, big data developers \n", 217 | "and big data engineers.[34]\n", 218 | "\n", 219 | "On the other hand, responses to criticism are as numerous. In a 2014 Wall Street \n", 220 | "Journal article, Irving Wladawsky-Berger compares the data science enthusiasm with \n", 221 | "the dawn of computer science. He argues data science, like any other interdisciplinary \n", 222 | "field, employs methodologies and practices from across the academia and industry, but \n", 223 | "then it will morph them into a new discipline. He brings to attention the sharp criticisms \n", 224 | "computer science, now a well respected academic discipline, had to once face.[35] Likewise, \n", 225 | "NYU Stern's Vasant Dhar, as do many other academic proponents of data science,[35] argues \n", 226 | "more specifically in December 2013 that data science is different from the existing practice \n", 227 | "of data analysis across all disciplines, which focuses only on explaining data sets. \n", 228 | "Data science seeks actionable and consistent pattern for predictive uses.[1] This practical \n", 229 | "engineering goal takes data science beyond traditional analytics. 
Now the data in those \n", 230 | "disciplines and applied fields that lacked solid theories, like health science and social \n", 231 | "science, could be sought and utilized to generate powerful predictive models.[1]\"\"\".split()\n", 232 | "# we should tokenize the input, but we will ignore that for now\n", 233 | "# build a list of tuples. Each tuple is ([ word_i-2, word_i-1 ], target word)\n", 234 | "trigrams = [([test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2])\n", 235 | " for i in range(len(test_sentence) - 2)]\n", 236 | "# print the first 3, just so you can see what they look like\n", 237 | "print(trigrams[:3])\n", 238 | "\n", 239 | "vocab = set(test_sentence)\n", 240 | "word_to_ix = {word: i for i, word in enumerate(vocab)}\n", 241 | "\n" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 12, 247 | "metadata": { 248 | "id": "dFnZJ8U8ABO8" 249 | }, 250 | "outputs": [], 251 | "source": [ 252 | "class NGramLanguageModeler(nn.Module):\n", 253 | "\n", 254 | " def __init__(self, vocab_size, embedding_dim, context_size):\n", 255 | " super(NGramLanguageModeler, self).__init__()\n", 256 | " self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n", 257 | " self.linear1 = nn.Linear(context_size * embedding_dim, 128)\n", 258 | " self.linear2 = nn.Linear(128, vocab_size)\n", 259 | "\n", 260 | " def forward(self, inputs):\n", 261 | " embeds = self.embeddings(inputs).view((1, -1))\n", 262 | " out = F.relu(self.linear1(embeds))\n", 263 | " out = self.linear2(out)\n", 264 | " log_probs = F.log_softmax(out, dim=1)\n", 265 | " return log_probs\n", 266 | "\n", 267 | "\n", 268 | "losses = []\n", 269 | "loss_function = nn.NLLLoss()\n", 270 | "model = NGramLanguageModeler(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)\n", 271 | "optimizer = optim.SGD(model.parameters(), lr=0.001)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 13, 277 | "metadata": { 278 | "colab": { 279 | "base_uri": "https://localhost:8080/" 280 
| }, 281 | "id": "kqDULD3XABO9", 282 | "outputId": "c203d873-8f11-4710-9a10-65ef0649291d" 283 | }, 284 | "outputs": [ 285 | { 286 | "output_type": "execute_result", 287 | "data": { 288 | "text/plain": [ 289 | "[]" 290 | ] 291 | }, 292 | "metadata": {}, 293 | "execution_count": 13 294 | } 295 | ], 296 | "source": [ 297 | "losses" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 14, 303 | "metadata": { 304 | "colab": { 305 | "base_uri": "https://localhost:8080/" 306 | }, 307 | "id": "kEt72PU4ABO9", 308 | "outputId": "21ba8b18-3271-4afa-f2f4-bb5c85e2ed15" 309 | }, 310 | "outputs": [ 311 | { 312 | "output_type": "execute_result", 313 | "data": { 314 | "text/plain": [ 315 | "NLLLoss()" 316 | ] 317 | }, 318 | "metadata": {}, 319 | "execution_count": 14 320 | } 321 | ], 322 | "source": [ 323 | "loss_function" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 15, 329 | "metadata": { 330 | "colab": { 331 | "base_uri": "https://localhost:8080/" 332 | }, 333 | "id": "S0uxHOapABO-", 334 | "outputId": "25125ea0-4039-48de-88ae-cc60f858a0e9" 335 | }, 336 | "outputs": [ 337 | { 338 | "output_type": "execute_result", 339 | "data": { 340 | "text/plain": [ 341 | "NGramLanguageModeler(\n", 342 | " (embeddings): Embedding(228, 10)\n", 343 | " (linear1): Linear(in_features=20, out_features=128, bias=True)\n", 344 | " (linear2): Linear(in_features=128, out_features=228, bias=True)\n", 345 | ")" 346 | ] 347 | }, 348 | "metadata": {}, 349 | "execution_count": 15 350 | } 351 | ], 352 | "source": [ 353 | "model" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 16, 359 | "metadata": { 360 | "colab": { 361 | "base_uri": "https://localhost:8080/" 362 | }, 363 | "id": "2iO3WkKEABO-", 364 | "outputId": "638ca290-f285-4737-a8ed-abfebdee4957" 365 | }, 366 | "outputs": [ 367 | { 368 | "output_type": "execute_result", 369 | "data": { 370 | "text/plain": [ 371 | "SGD (\n", 372 | "Parameter Group 0\n", 373 | " dampening: 
0\n", 374 | " foreach: None\n", 375 | " lr: 0.001\n", 376 | " maximize: False\n", 377 | " momentum: 0\n", 378 | " nesterov: False\n", 379 | " weight_decay: 0\n", 380 | ")" 381 | ] 382 | }, 383 | "metadata": {}, 384 | "execution_count": 16 385 | } 386 | ], 387 | "source": [ 388 | "optimizer" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 17, 394 | "metadata": { 395 | "colab": { 396 | "base_uri": "https://localhost:8080/" 397 | }, 398 | "id": "txaER7vdABO-", 399 | "outputId": "58684503-1163-40f9-df29-0e0f0e0fb1f9" 400 | }, 401 | "outputs": [ 402 | { 403 | "output_type": "stream", 404 | "name": "stdout", 405 | "text": [ 406 | "[1873.4337797164917, 1859.2190294265747, 1845.3114666938782, 1831.6828165054321, 1818.3093104362488, 1805.180431842804, 1792.2873740196228, 1779.6297824382782, 1767.2129256725311, 1755.0498096942902]\n" 407 | ] 408 | } 409 | ], 410 | "source": [ 411 | "for epoch in range(10):\n", 412 | " total_loss = 0\n", 413 | " for context, target in trigrams:\n", 414 | "\n", 415 | " # Step 1. Prepare the inputs to be passed to the model (i.e., turn the words\n", 416 | " # into integer indices and wrap them in tensors)\n", 417 | " context_idxs = torch.tensor([word_to_ix[w] for w in context], dtype=torch.long)\n", 418 | "\n", 419 | " # Step 2. Recall that torch *accumulates* gradients. Before passing in a\n", 420 | " # new instance, you need to zero out the gradients from the old\n", 421 | " # instance\n", 422 | " model.zero_grad()\n", 423 | "\n", 424 | " # Step 3. Run the forward pass, getting log probabilities over next\n", 425 | " # words\n", 426 | " log_probs = model(context_idxs)\n", 427 | "\n", 428 | " # Step 4. Compute your loss function. (Again, PyTorch wants the target\n", 429 | " # word wrapped in a tensor)\n", 430 | " loss = loss_function(log_probs, torch.tensor([word_to_ix[target]], dtype=torch.long))\n", 431 | "\n", 432 | " # Step 5.
Do the backward pass and update the parameters\n", 433 | " loss.backward()\n", 434 | " optimizer.step()\n", 435 | "\n", 436 | " # Get the Python number from a 1-element Tensor by calling tensor.item()\n", 437 | " total_loss += loss.item()\n", 438 | " losses.append(total_loss)\n", 439 | "print(losses) # The loss decreased every iteration over the training data!" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 18, 445 | "metadata": { 446 | "id": "SK-6y5JPABO_" 447 | }, 448 | "outputs": [], 449 | "source": [ 450 | "CONTEXT_SIZE = 2 # 2 words to the left, 2 to the right" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 19, 456 | "metadata": { 457 | "id": "V9IRHtxaABO_" 458 | }, 459 | "outputs": [], 460 | "source": [ 461 | "raw_text = \"\"\"For the future of data science, Donoho projects an ever-growing \n", 462 | "environment for open science where data sets used for academic publications are \n", 463 | "accessible to all researchers.[36] US National Institute of Health has already announced \n", 464 | "plans to enhance reproducibility and transparency of research data.[39] Other big journals \n", 465 | "are likewise following suit.[40][41] This way, the future of data science not only exceeds \n", 466 | "the boundary of statistical theories in scale and methodology, but data science will \n", 467 | "revolutionize current academia and research paradigms.[36] As Donoho concludes, \"the scope \n", 468 | "and impact of data science will continue to expand enormously in coming decades as scientific \n", 469 | "data and data about science itself become ubiquitously available.\"[36]\"\"\".split()" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 20, 475 | "metadata": { 476 | "colab": { 477 | "base_uri": "https://localhost:8080/" 478 | }, 479 | "id": "EFjl5-URABPA", 480 | "outputId": "4b9e6699-b4e6-44c8-a7b9-04416d635a5d" 481 | }, 482 | "outputs": [ 483 | { 484 | "output_type": "stream", 485 |
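The n-gram training loop in this notebook feeds one trigram per forward pass. A common alternative is to stack every (context, target) pair into tensors and process them in a single batch. The following is a minimal sketch of that variant. The toy corpus and the class name `BatchedNGram` are hypothetical. The layer sizes mirror `NGramLanguageModeler`, but `view` uses the batch size instead of the hard-coded `(1, -1)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

CONTEXT_SIZE = 2
EMBEDDING_DIM = 10

# Hypothetical toy corpus (not from the book), split on whitespace as in the notebook.
text = "data science uses statistics and data science uses computing".split()
vocab = sorted(set(text))
word_to_ix = {w: i for i, w in enumerate(vocab)}

# Every trigram becomes one row: contexts is (N, 2), targets is (N,).
contexts = torch.tensor([[word_to_ix[text[i]], word_to_ix[text[i + 1]]]
                         for i in range(len(text) - 2)], dtype=torch.long)
targets = torch.tensor([word_to_ix[text[i + 2]] for i in range(len(text) - 2)],
                       dtype=torch.long)

class BatchedNGram(nn.Module):  # hypothetical name, batched variant of the notebook model
    def __init__(self, vocab_size, embedding_dim, context_size):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(context_size * embedding_dim, 128)
        self.linear2 = nn.Linear(128, vocab_size)

    def forward(self, inputs):
        # (batch, context_size) -> (batch, context_size * embedding_dim)
        embeds = self.embeddings(inputs).view(inputs.shape[0], -1)
        out = F.relu(self.linear1(embeds))
        return F.log_softmax(self.linear2(out), dim=1)  # (batch, vocab_size)

model = BatchedNGram(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
loss_function = nn.NLLLoss()  # averages over the batch by default
optimizer = optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(20):
    optimizer.zero_grad()
    loss = loss_function(model(contexts), targets)  # one step over all trigrams
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

`NLLLoss` averages over the batch here, whereas the notebook sums per-trigram losses into `total_loss`. The two loss curves therefore differ in scale and cannot be compared directly.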
"name": "stdout", 486 | "text": [ 487 | "[(['For', 'the', 'of', 'data'], 'future'), (['the', 'future', 'data', 'science,'], 'of'), (['future', 'of', 'science,', 'Donoho'], 'data'), (['of', 'data', 'Donoho', 'projects'], 'science,'), (['data', 'science,', 'projects', 'an'], 'Donoho')]\n" 488 | ] 489 | } 490 | ], 491 | "source": [ 492 | "# By deriving a set from `raw_text`, we deduplicate the array\n", 493 | "vocab = set(raw_text)\n", 494 | "vocab_size = len(vocab)\n", 495 | "\n", 496 | "word_to_ix = {word: i for i, word in enumerate(vocab)}\n", 497 | "data = []\n", 498 | "for i in range(2, len(raw_text) - 2):\n", 499 | " context = [raw_text[i - 2], raw_text[i - 1],\n", 500 | " raw_text[i + 1], raw_text[i + 2]]\n", 501 | " target = raw_text[i]\n", 502 | " data.append((context, target))\n", 503 | "print(data[:5])" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 21, 509 | "metadata": { 510 | "colab": { 511 | "base_uri": "https://localhost:8080/" 512 | }, 513 | "id": "z0l0GlqRABPA", 514 | "outputId": "35c3f42c-f6bc-464c-bd37-520b0c914853" 515 | }, 516 | "outputs": [ 517 | { 518 | "output_type": "execute_result", 519 | "data": { 520 | "text/plain": [ 521 | "tensor([26, 54, 63, 18])" 522 | ] 523 | }, 524 | "metadata": {}, 525 | "execution_count": 21 526 | } 527 | ], 528 | "source": [ 529 | "class CBOW(nn.Module):\n", 530 | "\n", 531 | " def __init__(self):\n", 532 | " pass\n", 533 | "\n", 534 | " def forward(self, inputs):\n", 535 | " pass\n", 536 | "# create your model and train. 
Here are some functions to help you make\n", 537 | "# the data ready for use by your module\n", 538 | "def make_context_vector(context, word_to_ix):\n", 539 | " idxs = [word_to_ix[w] for w in context]\n", 540 | " return torch.tensor(idxs, dtype=torch.long)\n", 541 | "\n", 542 | "make_context_vector(data[0][0], word_to_ix) # example" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 22, 548 | "metadata": { 549 | "colab": { 550 | "base_uri": "https://localhost:8080/" 551 | }, 552 | "id": "ameEBRbmABPA", 553 | "outputId": "1dd89e6e-48bf-4e34-8ac5-30f54ba9c095" 554 | }, 555 | "outputs": [ 556 | { 557 | "output_type": "stream", 558 | "name": "stdout", 559 | "text": [ 560 | "tensor([[-0.7850, 0.8883, 1.1011],\n", 561 | " [ 0.3344, -0.3598, 0.5535]], grad_fn=)\n" 562 | ] 563 | } 564 | ], 565 | "source": [ 566 | "lin = nn.Linear(5, 3) # maps from R^5 to R^3, parameters A, b\n", 567 | "# data is 2x5. A maps from 5 to 3... can we map \"data\" under A?\n", 568 | "data = torch.randn(2, 5)\n", 569 | "print(lin(data)) # yes" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": 23, 575 | "metadata": { 576 | "colab": { 577 | "base_uri": "https://localhost:8080/" 578 | }, 579 | "id": "yivUasI0ABPA", 580 | "outputId": "c2005fb8-4ef4-4759-ed21-17af5e2a083b" 581 | }, 582 | "outputs": [ 583 | { 584 | "output_type": "stream", 585 | "name": "stdout", 586 | "text": [ 587 | "tensor([[ 1.6053, -0.1710],\n", 588 | " [ 1.4815, -1.1123]])\n", 589 | "tensor([[1.6053, 0.0000],\n", 590 | " [1.4815, 0.0000]])\n" 591 | ] 592 | } 593 | ], 594 | "source": [ 595 | "data = torch.randn(2, 2)\n", 596 | "print(data)\n", 597 | "print(F.relu(data))" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 24, 603 | "metadata": { 604 | "colab": { 605 | "base_uri": "https://localhost:8080/" 606 | }, 607 | "id": "7EqjzxRDABPB", 608 | "outputId": "eb202843-e49e-4f33-97d5-b58ac3984285" 609 | }, 610 | "outputs": [ 611 | { 612 | "output_type":
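The `CBOW` class in the notebook is left as a stub. Its `__init__` never calls `super().__init__()`, so it cannot register layers as written. The sketch below fills in one common continuous-bag-of-words formulation, in which the four context embeddings are summed into a single vector and then projected onto the vocabulary. The summation choice, the short corpus, and the hyperparameters are assumptions made for illustration and are not taken from the book. The pairs are built the same way as the notebook's `data` list.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

CONTEXT_SIZE = 2  # words on each side of the target

# Hypothetical short corpus; the notebook uses a longer raw_text.
raw_text = ("for the future of data science donoho projects an ever growing "
            "environment for open science").split()
vocab = sorted(set(raw_text))
word_to_ix = {w: i for i, w in enumerate(vocab)}

# (context, target) pairs: two words left and two words right of each target.
data = []
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):
    context = [raw_text[i - 2], raw_text[i - 1], raw_text[i + 1], raw_text[i + 2]]
    data.append((context, raw_text[i]))

def make_context_vector(context, word_to_ix):
    return torch.tensor([word_to_ix[w] for w in context], dtype=torch.long)

class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()  # required before assigning submodules
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs):
        # Sum the context embeddings, then score every word in the vocabulary.
        summed = self.embeddings(inputs).sum(dim=0, keepdim=True)  # (1, embedding_dim)
        return F.log_softmax(self.linear(summed), dim=1)           # (1, vocab_size)

model = CBOW(len(vocab), embedding_dim=10)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05)

epoch_losses = []
for epoch in range(30):
    total = 0.0
    for context, target in data:
        model.zero_grad()
        log_probs = model(make_context_vector(context, word_to_ix))
        loss = loss_function(log_probs, torch.tensor([word_to_ix[target]]))
        loss.backward()
        optimizer.step()
        total += loss.item()
    epoch_losses.append(total)

print(epoch_losses[0], epoch_losses[-1])
```

Averaging the context embeddings instead of summing them is an equally common choice. It only rescales the input to the linear layer.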
"stream", 613 | "name": "stdout", 614 | "text": [ 615 | "tensor([-0.4417, -2.5164, -0.2034, -2.1575, -1.2533])\n", 616 | "tensor([0.3313, 0.0416, 0.4204, 0.0596, 0.1471])\n", 617 | "tensor(1.0000)\n", 618 | "tensor([-1.1048, -3.1795, -0.8665, -2.8206, -1.9164])\n" 619 | ] 620 | } 621 | ], 622 | "source": [ 623 | "# Softmax is also in torch.nn.functional\n", 624 | "data = torch.randn(5)\n", 625 | "print(data)\n", 626 | "print(F.softmax(data, dim=0))\n", 627 | "print(F.softmax(data, dim=0).sum()) # Sums to 1 because it is a distribution!\n", 628 | "print(F.log_softmax(data, dim=0)) # there's also log_softmax" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": 25, 634 | "metadata": { 635 | "colab": { 636 | "base_uri": "https://localhost:8080/" 637 | }, 638 | "id": "Fp1Sca5kABPB", 639 | "outputId": "d2e5af42-6466-467c-e0b6-53d513cb7007" 640 | }, 641 | "outputs": [ 642 | { 643 | "output_type": "stream", 644 | "name": "stdout", 645 | "text": [ 646 | "tensor([[[-0.1500, 0.0547, 0.3930]],\n", 647 | "\n", 648 | " [[-0.1313, -0.0478, 0.0857]],\n", 649 | "\n", 650 | " [[-0.1131, 0.0047, -0.1003]],\n", 651 | "\n", 652 | " [[ 0.0176, -0.2464, -0.1589]],\n", 653 | "\n", 654 | " [[-0.0523, 0.1781, -0.1713]]], grad_fn=)\n", 655 | "(tensor([[[-0.0523, 0.1781, -0.1713]]], grad_fn=), tensor([[[-0.1997, 0.5137, -0.6064]]], grad_fn=))\n" 656 | ] 657 | } 658 | ], 659 | "source": [ 660 | "lstm = nn.LSTM(3, 3) # Input dim is 3, output dim is 3\n", 661 | "inputs = [torch.randn(1, 3) for _ in range(5)] # make a sequence of length 5\n", 662 | "\n", 663 | "# initialize the hidden state.\n", 664 | "hidden = (torch.randn(1, 1, 3),\n", 665 | " torch.randn(1, 1, 3))\n", 666 | "for i in inputs:\n", 667 | " # Step through the sequence one element at a time.\n", 668 | " # After each step, hidden contains the hidden state.\n", 669 | " out, hidden = lstm(i.view(1, 1, -1), hidden)\n", 670 | "inputs = torch.cat(inputs).view(len(inputs), 1, -1)\n", 671 | "hidden = (torch.randn(1,
1, 3), torch.randn(1, 1, 3)) # clean out hidden state\n", 672 | "out, hidden = lstm(inputs, hidden)\n", 673 | "print(out)\n", 674 | "print(hidden)" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": 26, 680 | "metadata": { 681 | "id": "Avuv880qABPB" 682 | }, 683 | "outputs": [], 684 | "source": [ 685 | "def prepare_sequence(seq, to_ix):\n", 686 | " idxs = [to_ix[w] for w in seq]\n", 687 | " return torch.tensor(idxs, dtype=torch.long)\n", 688 | "\n", 689 | "\n", 690 | "training_data = [\n", 691 | " (\"Probability and random variable are integral part of computation \".split(), \n", 692 | " [\"DET\", \"NN\", \"V\", \"DET\", \"NN\"]), # 5 tags for 9 words: NLLLoss needs one tag per word\n", 693 | " (\"Understanding of the probability and associated concepts are essential\".split(), \n", 694 | " [\"NN\", \"V\", \"DET\", \"NN\"]) # 4 tags for 9 words: extend before training\n", 695 | "]" 696 | ] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "execution_count": 27, 701 | "metadata": { 702 | "colab": { 703 | "base_uri": "https://localhost:8080/" 704 | }, 705 | "id": "8l52IQxvABPC", 706 | "outputId": "2cb9e55a-526b-433f-d97f-093391679871" 707 | }, 708 | "outputs": [ 709 | { 710 | "output_type": "execute_result", 711 | "data": { 712 | "text/plain": [ 713 | "[(['Probability',\n", 714 | " 'and',\n", 715 | " 'random',\n", 716 | " 'variable',\n", 717 | " 'are',\n", 718 | " 'integral',\n", 719 | " 'part',\n", 720 | " 'of',\n", 721 | " 'computation'],\n", 722 | " ['DET', 'NN', 'V', 'DET', 'NN']),\n", 723 | " (['Understanding',\n", 724 | " 'of',\n", 725 | " 'the',\n", 726 | " 'probability',\n", 727 | " 'and',\n", 728 | " 'associated',\n", 729 | " 'concepts',\n", 730 | " 'are',\n", 731 | " 'essential'],\n", 732 | " ['NN', 'V', 'DET', 'NN'])]" 733 | ] 734 | }, 735 | "metadata": {}, 736 | "execution_count": 27 737 | } 738 | ], 739 | "source": [ 740 | "training_data" 741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "execution_count": 28, 746 | "metadata": { 747 | "colab": { 748 | "base_uri": "https://localhost:8080/" 749 | }, 750 | "id":
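The `nn.LSTM` cell in this notebook runs the sequence one step at a time and then again in a single call. Before the second run, however, it draws a fresh random hidden state, so the two outputs cannot be compared directly. The sketch below reuses one initial state for both runs, assuming a single-layer, unidirectional LSTM with the default `(seq_len, batch, features)` layout. It shows that the two approaches agree up to floating-point rounding, and that the last output step equals the final hidden state `h_n`.

```python
import torch
import torch.nn as nn

torch.manual_seed(1)

lstm = nn.LSTM(input_size=3, hidden_size=3)        # layout: (seq_len, batch, features)
inputs = [torch.randn(1, 3) for _ in range(5)]     # a sequence of 5 steps
h0 = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # (h_0, c_0): (num_layers, batch, hidden)

# 1) Step through the sequence, carrying the hidden state forward.
hidden = h0
step_outputs = []
for x in inputs:
    out, hidden = lstm(x.view(1, 1, -1), hidden)
    step_outputs.append(out)
step_outputs = torch.cat(step_outputs)             # (5, 1, 3)

# 2) Feed the whole sequence at once from the SAME initial state.
seq = torch.cat(inputs).view(len(inputs), 1, -1)   # (5, 1, 3)
seq_out, (h_n, c_n) = lstm(seq, h0)

print(torch.allclose(step_outputs, seq_out, atol=1e-5))
print(torch.allclose(seq_out[-1], h_n[0], atol=1e-6))
```

The single-call form is usually preferred because the input projections for all time steps can be computed together.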
"H7_X_0SMABPC", 751 | "outputId": "64e9ced5-c629-41b5-aad8-f8886179fb5a" 752 | }, 753 | "outputs": [ 754 | { 755 | "output_type": "stream", 756 | "name": "stdout", 757 | "text": [ 758 | "{'Probability': 0, 'and': 1, 'random': 2, 'variable': 3, 'are': 4, 'integral': 5, 'part': 6, 'of': 7, 'computation': 8, 'Understanding': 9, 'the': 10, 'probability': 11, 'associated': 12, 'concepts': 13, 'essential': 14}\n" 759 | ] 760 | } 761 | ], 762 | "source": [ 763 | "word_to_ix = {}\n", 764 | "for sent, tags in training_data:\n", 765 | " for word in sent:\n", 766 | " if word not in word_to_ix:\n", 767 | " word_to_ix[word] = len(word_to_ix)\n", 768 | "print(word_to_ix)\n", 769 | "tag_to_ix = {\"DET\": 0, \"NN\": 1, \"V\": 2}\n", 770 | "\n", 771 | "EMBEDDING_DIM = 6\n", 772 | "HIDDEN_DIM = 6" 773 | ] 774 | }, 775 | { 776 | "cell_type": "code", 777 | "execution_count": 29, 778 | "metadata": { 779 | "id": "4-HA50ozABPC" 780 | }, 781 | "outputs": [], 782 | "source": [ 783 | "class LSTMTagger(nn.Module):\n", 784 | "\n", 785 | " def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):\n", 786 | " super(LSTMTagger, self).__init__()\n", 787 | " self.hidden_dim = hidden_dim\n", 788 | "\n", 789 | " self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)\n", 790 | "\n", 791 | " # The LSTM takes word embeddings as inputs and outputs hidden states\n", 792 | " # with dimensionality hidden_dim.\n", 793 | " self.lstm = nn.LSTM(embedding_dim, hidden_dim)\n", 794 | "\n", 795 | " # The linear layer that maps from hidden state space to tag space\n", 796 | " self.hidden2tag = nn.Linear(hidden_dim, tagset_size)\n", 797 | " self.hidden = self.init_hidden()\n", 798 | "\n", 799 | " def init_hidden(self):\n", 800 | " # Before we've done anything, we don't have any hidden state.\n", 801 | " # Refer to the PyTorch documentation to see exactly\n", 802 | " # why they have this dimensionality.\n", 803 | " # The axes semantics are (num_layers, minibatch_size, hidden_dim)\n", 804 | "
return (torch.zeros(1, 1, self.hidden_dim),\n", 805 | " torch.zeros(1, 1, self.hidden_dim))\n", 806 | "\n", 807 | " def forward(self, sentence):\n", 808 | " embeds = self.word_embeddings(sentence)\n", 809 | " lstm_out, self.hidden = self.lstm(\n", 810 | " embeds.view(len(sentence), 1, -1), self.hidden)\n", 811 | " tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))\n", 812 | " tag_scores = F.log_softmax(tag_space, dim=1)\n", 813 | " return tag_scores" 814 | ] 815 | }, 816 | { 817 | "cell_type": "code", 818 | "execution_count": 30, 819 | "metadata": { 820 | "colab": { 821 | "base_uri": "https://localhost:8080/" 822 | }, 823 | "id": "vX4KX0VlABPC", 824 | "outputId": "597cbd66-a121-4478-9b57-c2d1350ecaa5" 825 | }, 826 | "outputs": [ 827 | { 828 | "output_type": "execute_result", 829 | "data": { 830 | "text/plain": [ 831 | "SGD (\n", 832 | "Parameter Group 0\n", 833 | " dampening: 0\n", 834 | " foreach: None\n", 835 | " lr: 0.1\n", 836 | " maximize: False\n", 837 | " momentum: 0\n", 838 | " nesterov: False\n", 839 | " weight_decay: 0\n", 840 | ")" 841 | ] 842 | }, 843 | "metadata": {}, 844 | "execution_count": 30 845 | } 846 | ], 847 | "source": [ 848 | "model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))\n", 849 | "loss_function = nn.NLLLoss()\n", 850 | "optimizer = optim.SGD(model.parameters(), lr=0.1)\n", 851 | "model\n", 852 | "loss_function\n", 853 | "optimizer" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": 31, 859 | "metadata": { 860 | "colab": { 861 | "base_uri": "https://localhost:8080/" 862 | }, 863 | "id": "WWO3x1n2ABPD", 864 | "outputId": "3913b377-e2f4-40fb-bd31-ce6e959511a4" 865 | }, 866 | "outputs": [ 867 | { 868 | "output_type": "stream", 869 | "name": "stdout", 870 | "text": [ 871 | "tensor([[-1.0414, -1.1928, -1.0680],\n", 872 | " [-1.0747, -1.2163, -1.0154],\n", 873 | " [-1.0706, -1.2298, -1.0083],\n", 874 | " [-1.0661, -1.2428, -1.0022],\n", 875 | " [-1.0013, -1.2948, -1.0254],\n", 
876 | " [-1.0539, -1.2640, -0.9973],\n", 877 | " [-1.0718, -1.2705, -0.9757],\n", 878 | " [-0.9919, -1.2527, -1.0689],\n", 879 | " [-0.9726, -1.2880, -1.0611]])\n" 880 | ] 881 | } 882 | ], 883 | "source": [ 884 | "with torch.no_grad():\n", 885 | " inputs = prepare_sequence(training_data[0][0], word_to_ix)\n", 886 | " tag_scores = model(inputs)\n", 887 | " print(tag_scores)" 888 | ] 889 | }, 890 | { 891 | "cell_type": "code", 892 | "execution_count": 33, 893 | "metadata": { 894 | "colab": { 895 | "base_uri": "https://localhost:8080/" 896 | }, 897 | "id": "ZCblg5maABPD", 898 | "outputId": "608568d2-02ca-49f1-a8e7-ea849707396c" 899 | }, 900 | "outputs": [ 901 | { 902 | "output_type": "stream", 903 | "name": "stdout", 904 | "text": [ 905 | "tensor([[-0.9758, -1.2959, -1.0513],\n", 906 | " [-1.0012, -1.2932, -1.0267],\n", 907 | " [-1.0178, -1.2790, -1.0208],\n", 908 | " [-1.0394, -1.2699, -1.0066],\n", 909 | " [-0.9647, -1.3276, -1.0391],\n", 910 | " [-1.0317, -1.2859, -1.0019],\n", 911 | " [-1.0524, -1.2877, -0.9808],\n", 912 | " [-0.9821, -1.2619, -1.0719],\n", 913 | " [-0.9645, -1.2950, -1.0644]])\n" 914 | ] 915 | } 916 | ], 917 | "source": [ 918 | "# See what the scores are after training\n", 919 | "with torch.no_grad():\n", 920 | " inputs = prepare_sequence(training_data[0][0], word_to_ix)\n", 921 | " tag_scores = model(inputs)\n", 922 | " print(tag_scores)" 923 | ] 924 | }, 925 | { 926 | "cell_type": "code", 927 | "execution_count": null, 928 | "metadata": { 929 | "id": "Qy4Zz42xABPD" 930 | }, 931 | "outputs": [], 932 | "source": [] 933 | } 934 | ], 935 | "metadata": { 936 | "kernelspec": { 937 | "display_name": "Python 3", 938 | "language": "python", 939 | "name": "python3" 940 | }, 941 | "language_info": { 942 | "codemirror_mode": { 943 | "name": "ipython", 944 | "version": 3 945 | }, 946 | "file_extension": ".py", 947 | "mimetype": "text/x-python", 948 | "name": "python", 949 | "nbconvert_exporter": "python", 950 | "pygments_lexer": "ipython3", 951 | "version": 
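Torch_AI_7_2Ed.ipynb prints tag scores before training (execution count 31) and "after training" (execution count 33), but no training cell appears between them. The sketch below shows what such a loop might look like. The two sentences and their tags are hypothetical and were chosen so that each word has exactly one tag, as `nn.NLLLoss` requires. The tagger is also simplified so that no hidden state is stored between calls. The notebook's class keeps `self.hidden` across calls, so a loop built on it would probably need to reset `model.hidden = model.init_hidden()` before each sentence.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

# Hypothetical, aligned training data: one tag per word.
training_data = [
    ("the model learns the data".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("a student reads the book".split(), ["DET", "NN", "V", "DET", "NN"]),
]
for words, tags in training_data:
    assert len(words) == len(tags), "NLLLoss needs one tag per word"

word_to_ix = {}
for words, _ in training_data:
    for w in words:
        word_to_ix.setdefault(w, len(word_to_ix))
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}

def prepare_sequence(seq, to_ix):
    return torch.tensor([to_ix[w] for w in seq], dtype=torch.long)

class LSTMTagger(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence).view(len(sentence), 1, -1)
        # No state passed: nn.LSTM starts from zeros, so sentences are independent.
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        return F.log_softmax(tag_space, dim=1)

model = LSTMTagger(6, 6, len(word_to_ix), len(tag_to_ix))
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

epoch_losses = []
for epoch in range(300):
    total = 0.0
    for sentence, tags in training_data:
        model.zero_grad()
        tag_scores = model(prepare_sequence(sentence, word_to_ix))
        loss = loss_function(tag_scores, prepare_sequence(tags, tag_to_ix))
        loss.backward()
        optimizer.step()
        total += loss.item()
    epoch_losses.append(total)

with torch.no_grad():
    scores = model(prepare_sequence(training_data[0][0], word_to_ix))
    ix_to_tag = list(tag_to_ix)  # insertion order matches indices 0, 1, 2
    predicted = [ix_to_tag[i] for i in scores.argmax(dim=1).tolist()]
print(predicted)
```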
"3.6.4" 952 | }, 953 | "colab": { 954 | "name": "Torch_AI_7_2Ed.ipynb", 955 | "provenance": [], 956 | "collapsed_sections": [] 957 | } 958 | }, 959 | "nbformat": 4, 960 | "nbformat_minor": 0 961 | } -------------------------------------------------------------------------------- /Torch_AI_4_2Ed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "collapsed": true, 8 | "id": "JCfKPwoa61Fx" 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "from __future__ import print_function\n", 13 | "import torch\n", 14 | "import numpy as np\n", 15 | "import torch.optim\n", 16 | "import torch.nn as nn\n", 17 | "import torch.optim as optim\n", 18 | "import torch.nn.init as init\n", 19 | "import torch.nn.functional as F\n", 20 | "from torch.autograd import Variable" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "source": [ 26 | "import warnings\n", 27 | "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n" 28 | ], 29 | "metadata": { 30 | "id": "dP1q5hkL7Ls4" 31 | }, 32 | "execution_count": 3, 33 | "outputs": [] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 4, 38 | "metadata": { 39 | "collapsed": true, 40 | "id": "I-55AkKp61F2" 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "# torch.nn: neural networks can be constructed using the torch.nn package."
45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 5, 50 | "metadata": { 51 | "colab": { 52 | "base_uri": "https://localhost:8080/" 53 | }, 54 | "id": "48z4bsw_61F3", 55 | "outputId": "fd9e291b-749b-452c-8648-512ffe54b4fc" 56 | }, 57 | "outputs": [ 58 | { 59 | "output_type": "stream", 60 | "name": "stdout", 61 | "text": [ 62 | "Output size : torch.Size([100, 5])\n", 63 | "Output size : torch.Size([100, 5])\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "x = Variable(torch.randn(100, 10))\n", 69 | "y = Variable(torch.randn(100, 30))\n", 70 | "\n", 71 | "linear = nn.Linear(in_features=10, out_features=5, bias=True)\n", 72 | "output_linear = linear(x)\n", 73 | "print('Output size : ', output_linear.size())\n", 74 | "\n", 75 | "bilinear = nn.Bilinear(in1_features=10, in2_features=30, out_features=5, bias=True)\n", 76 | "output_bilinear = bilinear(x, y)\n", 77 | "print('Output size : ', output_bilinear.size())" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 6, 83 | "metadata": { 84 | "colab": { 85 | "base_uri": "https://localhost:8080/" 86 | }, 87 | "id": "hZAXdW7b61F4", 88 | "outputId": "8b3d56ab-31a0-4254-e287-389f2b531146" 89 | }, 90 | "outputs": [ 91 | { 92 | "output_type": "stream", 93 | "name": "stdout", 94 | "text": [ 95 | "Output size : torch.Size([100, 10])\n", 96 | "Output size : torch.Size([100, 30])\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "x = Variable(torch.randn(100, 10))\n", 102 | "y = Variable(torch.randn(100, 30))\n", 103 | "\n", 104 | "sig = nn.Sigmoid()\n", 105 | "output_sig = sig(x)\n", 106 | "output_sigy = sig(y)\n", 107 | "print('Output size : ', output_sig.size())\n", 108 | "print('Output size : ', output_sigy.size())" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 7, 114 | "metadata": { 115 | "colab": { 116 | "base_uri": "https://localhost:8080/" 117 | }, 118 | "id": "Em-RXt1u61F5", 119 | "outputId": "06eb30e4-874d-40ef-9fa1-16e170c217a2" 120 | }, 121 | 
"outputs": [ 122 | { 123 | "output_type": "stream", 124 | "name": "stdout", 125 | "text": [ 126 | "tensor([-1.5454, 0.3599, 2.2720, 0.7115, 0.5296, 0.6176, 1.8854, 0.4854,\n", 127 | " -0.3893, 0.8369])\n", 128 | "tensor([0.1758, 0.5890, 0.9065, 0.6707, 0.6294, 0.6497, 0.8682, 0.6190, 0.4039,\n", 129 | " 0.6978])\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "print(x[0])\n", 135 | "print(output_sig[0])" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 8, 141 | "metadata": { 142 | "colab": { 143 | "base_uri": "https://localhost:8080/" 144 | }, 145 | "id": "zBMtVCYW61F6", 146 | "outputId": "18a7644c-d91b-4d28-b3db-bc0984bb533a" 147 | }, 148 | "outputs": [ 149 | { 150 | "output_type": "stream", 151 | "name": "stdout", 152 | "text": [ 153 | "Output size : torch.Size([100, 10])\n", 154 | "Output size : torch.Size([100, 30])\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "x = Variable(torch.randn(100, 10))\n", 160 | "y = Variable(torch.randn(100, 30))\n", 161 | "\n", 162 | "func = nn.Tanh()\n", 163 | "output_x = func(x)\n", 164 | "output_y = func(y)\n", 165 | "print('Output size : ', output_x.size())\n", 166 | "print('Output size : ', output_y.size())" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 9, 172 | "metadata": { 173 | "colab": { 174 | "base_uri": "https://localhost:8080/" 175 | }, 176 | "id": "RHll2pGj61F7", 177 | "outputId": "a30d72d5-f97c-45b7-c9f6-cc0feeb98cdf" 178 | }, 179 | "outputs": [ 180 | { 181 | "output_type": "stream", 182 | "name": "stdout", 183 | "text": [ 184 | "tensor([ 1.6056, 0.1092, 0.2044, 1.0537, -0.8658, -0.9111, -1.1586, -1.7745,\n", 185 | " -0.8922, -2.3219])\n", 186 | "tensor([ 0.9225, 0.1087, 0.2016, 0.7832, -0.6992, -0.7217, -0.8206, -0.9441,\n", 187 | " -0.7125, -0.9809])\n", 188 | "tensor([ 0.2153, 1.3900, 0.4259, -0.3347, -1.2087, -0.1930, 0.1645, -1.5867,\n", 189 | " -0.1752, 0.3863, 0.6141, 1.6769, -0.8080, 0.3790, -0.7446, 0.1795,\n", 190 | " -1.5132, 0.8282, 
1.6872, 0.7207, -0.6874, 0.0136, 0.3600, 1.9525,\n", 191 | " -0.1363, -0.2002, 0.4026, -0.1413, 2.2343, 1.0469])\n", 192 | "tensor([ 0.2121, 0.8832, 0.4019, -0.3228, -0.8363, -0.1907, 0.1631, -0.9196,\n", 193 | " -0.1735, 0.3682, 0.5470, 0.9325, -0.6685, 0.3619, -0.6319, 0.1776,\n", 194 | " -0.9075, 0.6795, 0.9338, 0.6173, -0.5963, 0.0136, 0.3452, 0.9605,\n", 195 | " -0.1355, -0.1976, 0.3821, -0.1404, 0.9773, 0.7806])\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "print(x[0])\n", 201 | "print(output_x[0])\n", 202 | "print(y[0])\n", 203 | "print(output_y[0])" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 10, 209 | "metadata": { 210 | "colab": { 211 | "base_uri": "https://localhost:8080/" 212 | }, 213 | "id": "4mvr4lQ561F8", 214 | "outputId": "cec54c4c-2877-4219-9c0b-b646d717dc53" 215 | }, 216 | "outputs": [ 217 | { 218 | "output_type": "stream", 219 | "name": "stdout", 220 | "text": [ 221 | "Output size : torch.Size([100, 10])\n", 222 | "Output size : torch.Size([100, 30])\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "x = Variable(torch.randn(100, 10))\n", 228 | "y = Variable(torch.randn(100, 30))\n", 229 | "\n", 230 | "func = nn.LogSigmoid()\n", 231 | "output_x = func(x)\n", 232 | "output_y = func(y)\n", 233 | "print('Output size : ', output_x.size())\n", 234 | "print('Output size : ', output_y.size())" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 11, 240 | "metadata": { 241 | "colab": { 242 | "base_uri": "https://localhost:8080/" 243 | }, 244 | "id": "0wI2XtXC61F9", 245 | "outputId": "312202f9-fe71-4b56-9c3e-30e70eda32f8" 246 | }, 247 | "outputs": [ 248 | { 249 | "output_type": "stream", 250 | "name": "stdout", 251 | "text": [ 252 | "tensor([-0.9983, -0.2337, 0.7794, 1.0399, -1.4705, -1.4177, -0.2531, -1.0391,\n", 253 | " -1.1570, -0.5105])\n", 254 | "tensor([-1.3120, -0.8168, -0.3775, -0.3027, -1.6773, -1.6346, -0.8277, -1.3420,\n", 255 | " -1.4304, -0.9806])\n", 256 | "tensor([-0.3758, 
-1.1889, 0.7846, 0.8277, 0.1351, 0.2677, -0.2810, -1.1610,\n", 257 | " -0.6973, -0.1106, 0.6361, 1.4497, -0.6007, -0.1102, 0.8876, -0.1440,\n", 258 | " -0.2914, -0.0144, 1.4152, 2.1429, 0.8828, 0.9561, -0.1876, 1.1487,\n", 259 | " 0.6150, -0.1044, 1.3075, -0.1601, -0.4018, -1.2599])\n", 260 | "tensor([-0.8986, -1.4547, -0.3759, -0.3626, -0.6279, -0.5683, -0.8435, -1.4335,\n", 261 | " -1.1014, -0.7500, -0.4249, -0.2108, -1.0379, -0.7498, -0.3448, -0.7677,\n", 262 | " -0.8494, -0.7004, -0.2174, -0.1109, -0.3462, -0.3253, -0.7913, -0.2754,\n", 263 | " -0.4322, -0.7467, -0.2394, -0.7764, -0.9141, -1.5096])\n" 264 | ] 265 | } 266 | ], 267 | "source": [ 268 | "print(x[0])\n", 269 | "print(output_x[0])\n", 270 | "print(y[0])\n", 271 | "print(output_y[0])" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 12, 277 | "metadata": { 278 | "colab": { 279 | "base_uri": "https://localhost:8080/" 280 | }, 281 | "id": "0zapVEqM61F9", 282 | "outputId": "4eb2b683-12dc-4287-dc86-8c44a0bb3e4e" 283 | }, 284 | "outputs": [ 285 | { 286 | "output_type": "stream", 287 | "name": "stdout", 288 | "text": [ 289 | "Output size : torch.Size([100, 10])\n", 290 | "Output size : torch.Size([100, 30])\n" 291 | ] 292 | } 293 | ], 294 | "source": [ 295 | "x = Variable(torch.randn(100, 10))\n", 296 | "y = Variable(torch.randn(100, 30))\n", 297 | "\n", 298 | "func = nn.ReLU()\n", 299 | "output_x = func(x)\n", 300 | "output_y = func(y)\n", 301 | "print('Output size : ', output_x.size())\n", 302 | "print('Output size : ', output_y.size())" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 13, 308 | "metadata": { 309 | "colab": { 310 | "base_uri": "https://localhost:8080/" 311 | }, 312 | "id": "KrZfvpH861F-", 313 | "outputId": "fcf0d7b5-57d6-4d41-fae3-27fc68ae3329" 314 | }, 315 | "outputs": [ 316 | { 317 | "output_type": "stream", 318 | "name": "stdout", 319 | "text": [ 320 | "tensor([-0.6479, -0.8856, 0.5144, -0.5064, 0.3280, -1.8378, 0.5670, 
0.9095,\n", 321 | " -2.6267, -1.0119])\n", 322 | "tensor([0.0000, 0.0000, 0.5144, 0.0000, 0.3280, 0.0000, 0.5670, 0.9095, 0.0000,\n", 323 | " 0.0000])\n", 324 | "tensor([-1.4458, 0.8328, 0.6534, 2.0404, 0.9053, -0.2829, -0.5712, 0.0323,\n", 325 | " 0.9757, -1.5787, 1.9665, 1.0276, -1.0536, 0.0588, 0.5085, 0.1956,\n", 326 | " -0.4490, -0.8927, 0.0128, -0.5971, -0.0677, 0.0101, 0.9477, 1.1218,\n", 327 | " -1.0648, -0.8439, 0.3422, 0.6930, -0.4311, -1.2920])\n", 328 | "tensor([0.0000, 0.8328, 0.6534, 2.0404, 0.9053, 0.0000, 0.0000, 0.0323, 0.9757,\n", 329 | " 0.0000, 1.9665, 1.0276, 0.0000, 0.0588, 0.5085, 0.1956, 0.0000, 0.0000,\n", 330 | " 0.0128, 0.0000, 0.0000, 0.0101, 0.9477, 1.1218, 0.0000, 0.0000, 0.3422,\n", 331 | " 0.6930, 0.0000, 0.0000])\n" 332 | ] 333 | } 334 | ], 335 | "source": [ 336 | "print(x[0])\n", 337 | "print(output_x[0])\n", 338 | "print(y[0])\n", 339 | "print(output_y[0])" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 14, 345 | "metadata": { 346 | "colab": { 347 | "base_uri": "https://localhost:8080/" 348 | }, 349 | "id": "O_IloLJA61F-", 350 | "outputId": "3cd1b78b-2e71-4009-9d08-2d21124eafcd" 351 | }, 352 | "outputs": [ 353 | { 354 | "output_type": "stream", 355 | "name": "stdout", 356 | "text": [ 357 | "Output size : torch.Size([100, 10])\n", 358 | "Output size : torch.Size([100, 30])\n" 359 | ] 360 | } 361 | ], 362 | "source": [ 363 | "x = Variable(torch.randn(100, 10))\n", 364 | "y = Variable(torch.randn(100, 30))\n", 365 | "\n", 366 | "func = nn.LeakyReLU()\n", 367 | "output_x = func(x)\n", 368 | "output_y = func(y)\n", 369 | "print('Output size : ', output_x.size())\n", 370 | "print('Output size : ', output_y.size())" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 15, 376 | "metadata": { 377 | "colab": { 378 | "base_uri": "https://localhost:8080/" 379 | }, 380 | "id": "CIqN-nD961F_", 381 | "outputId": "dc05be10-0bac-4628-a16c-0adff13311c2" 382 | }, 383 | "outputs": [ 384 | { 385 | 
"output_type": "stream", 386 | "name": "stdout", 387 | "text": [ 388 | "tensor([ 0.3611, -0.3622, 0.5740, -0.3404, -0.1284, 1.4639, 1.3272, 0.0636,\n", 389 | " -1.1366, 1.1084])\n", 390 | "tensor([ 3.6107e-01, -3.6216e-03, 5.7399e-01, -3.4043e-03, -1.2843e-03,\n", 391 | " 1.4639e+00, 1.3272e+00, 6.3646e-02, -1.1366e-02, 1.1084e+00])\n", 392 | "tensor([-0.4000, -0.2603, 0.5494, -1.1904, 1.0810, 0.0770, 0.5700, -1.0860,\n", 393 | " 0.6954, -0.3596, -0.7211, -0.5289, 1.8362, -1.4268, -1.1033, 0.0696,\n", 394 | " 0.5678, 0.5952, 0.2172, 0.5269, 1.4032, -0.3520, -0.7009, 0.0710,\n", 395 | " -0.2730, -1.4919, -1.3549, 0.1566, -1.0187, 0.0810])\n", 396 | "tensor([-0.0040, -0.0026, 0.5494, -0.0119, 1.0810, 0.0770, 0.5700, -0.0109,\n", 397 | " 0.6954, -0.0036, -0.0072, -0.0053, 1.8362, -0.0143, -0.0110, 0.0696,\n", 398 | " 0.5678, 0.5952, 0.2172, 0.5269, 1.4032, -0.0035, -0.0070, 0.0710,\n", 399 | " -0.0027, -0.0149, -0.0135, 0.1566, -0.0102, 0.0810])\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "print(x[0])\n", 405 | "print(output_x[0])\n", 406 | "print(y[0])\n", 407 | "print(output_y[0])" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 16, 413 | "metadata": { 414 | "collapsed": true, 415 | "id": "LKuRr3kH61GA" 416 | }, 417 | "outputs": [], 418 | "source": [ 419 | "import torch.nn.functional as F\n", 420 | "from torch.autograd import Variable\n", 421 | "import matplotlib.pyplot as plt\n", 422 | "%matplotlib inline" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 17, 428 | "metadata": { 429 | "id": "A460kQ_u61GA" 430 | }, 431 | "outputs": [], 432 | "source": [ 433 | "x = torch.linspace(-10, 10, 1500) \n", 434 | "x = Variable(x)\n", 435 | "x_1 = x.data.numpy() # transforming into a NumPy array" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 18, 441 | "metadata": { 442 | "collapsed": true, 443 | "id": "i8gF1z8O61GB" 444 | }, 445 | "outputs": [], 446 | "source": [ 447 | "y_relu = 
F.relu(x).data.numpy()\n", 448 | "y_sigmoid = torch.sigmoid(x).data.numpy()\n", 449 | "y_tanh = torch.tanh(x).data.numpy()\n", 450 | "y_softplus = F.softplus(x).data.numpy()" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 19, 456 | "metadata": { 457 | "colab": { 458 | "base_uri": "https://localhost:8080/", 459 | "height": 283 460 | }, 461 | "id": "A5BENnoi61GB", 462 | "outputId": "5e1092ab-0fb3-46ee-ba1f-1afd725eeaa2" 463 | }, 464 | "outputs": [ 465 | { 466 | "output_type": "execute_result", 467 | "data": { 468 | "text/plain": [ 469 | "" 470 | ] 471 | }, 472 | "metadata": {}, 473 | "execution_count": 19 474 | }, 475 | { 476 | "output_type": "display_data", 477 | "data": { 478 | "text/plain": [ 479 | "
" 480 | ], 481 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAD4CAYAAAC5S3KDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAdAUlEQVR4nO3deXjU1b3H8fdXQBalIiF6UbRBS5VCFSGuXK7sILIpWLEugLZYuVq1i9XrvZa6PW4UpeKCiLggUDYJiMoS1EetKCCgshRUqIkgkUVUZD/3jzPYGBMImeX8Zubzep48TGZ+mfnkl1/y4bfMOeacQ0REJGoOCR1ARESkPCooERGJJBWUiIhEkgpKREQiSQUlIiKRVD2VL9agQQOXl5eXypcUEZEIW7hw4RfOudzyHktpQeXl5bFgwYJUvqSIiESYma2t6DEd4hMRkUhSQYmISCSpoEREJJJSeg6qPLt27aKoqIjt27eHjpIStWrVolGjRtSoUSN0FBGRSAteUEVFRdStW5e8vDzMLHScpHLOsXHjRoqKimjcuHHoOCIikRb8EN/27dvJycnJ+HICMDNycnKyZm9RRCQewQsKyIpy2iebvlcRkXhEoqBERETKUkEB1apVo0WLFjRv3pwePXqwZcuW/S4/ZMgQHnjgge/dN2DAACZNmvS9+w4//PCEZxURyRYqKKB27dosXryYDz74gPr16zNixIjQkUREsp4Kqoyzzz6b4uJiAD766CO6du1Kq1ataNOmDStWrAicTkQkewS/zLy0G26AxYsT+5wtWsCDD1Zu2T179jB37lyuuuoqAAYNGsRjjz1GkyZNmD9/PoMHD6awsDCxAUVEpFyRKqhQvv32W1q0aEFxcTFNmzalU6dOfP3117z11ltcdNFF3y23Y8eOCp+jvKvzdMWeiEjVHbCgzGw00B3Y4JxrHruvPjAByAPWAL9wzm2ON0xl93QSbd85qG3bttGlSxdGjBjBgAEDqFevHosruUuXk5PD5s3/XgWbNm2iQYMGyYosIpLxKnMOagzQtcx9NwNznXNNgLmxz9NenTp1GD58OEOHDqVOnTo0btyYiRMnAn4UiCVLllT4tW3btmXChAns3LkTgDFjxtCuXbuU5BYRyUQHLCjn3OvApjJ39wKejt1+Guid4FzBnHbaaZxyyimMGzeOsWPH8uSTT3LqqafSrFkzpk2b9t1yd955J40aNfruo3v37rRp04ZWrVrRokUL3nzzTe69996A34mISPJMmwbJHhTHnHMHXsgsD5hR6hDfFudcvdhtAzbv+7ycrx0EDAI4/vjjW61d+/25qZYvX07Tpk3j+BbSTzZ+zyKSOR57DK65Bu68E269Nb7nMrOFzrn88h6L+zJz5xuuwpZzzo10zuU75/Jzc8ud1VdERNLEqFG+nHr0gD/+MbmvVdWC+tzMGgLE/t2QuEgiIhJFTz8NgwZB164wcSIcemhyX6+qBVUA9I/d7g9M28+yB1SZw4yZIpu+VxHJHGPHwsCB0KEDTJkCNWsm/zUPWFBmNg74B3CSmRWZ2VXAPUAnM1sFdIx9XiW1atVi48aNWfGHe998ULVq1QodRUSk0iZMgCuugLZt/cURtWun5nUP+D4o59wlFTzUIREBGjVqRFFRESUlJYl4usjbN6OuiEg6mDwZLr0UWreG6dOhTp3UvXbwkSRq1Kih2WVFRCKooAD69YMzz4QXX4TDDkvt62uwWBER+YGZM6FvX2jZEl56CerWTX0GFZSIiHzPrFlw4YXw85/DK6/Aj34UJocKSkREvjN3LvTqBSefDLNnQ71yh2BIDRWUiIgA8Npr/g24P/kJzJkD9euHzaOCEhER3ngDzj8f8vL8XlQUJmNQQYmIZLm334Zu3eDYY305HXVU6ESeCkpEJIstWABduvhSKiyEhg1DJ/o3FZSISJZ67z3o1AlycmDeP
L8HFSUqKBGRLLR0KXTs6C8hLyyE444LneiHVFAiIlnmww/9oK+1a/tyyssLnah8KigRkSyyYoUvpxo1/GG9E08MnahiKigRkSyxahW0b+9vFxZCkyZh8xxI8MFiRUQk+T7+2JfT7t1+z+nkk0MnOjAVlIhIhluzBtq1g23bfDk1axY6UeWooEREMtinn/o9p61b/ZtwTzkldKLKU0GJiGSo4mK/57Rxox9br2XL0IkOjgpKRCQDrVvn95w2bPDTZ5x+euhEB08FJSKSYTZs8JeSFxf7+ZzOOit0oqpRQYmIZJAvvvAjRKxZ42fCbd06dKKqU0GJiGSITZv82HqrVsGMGXDuuaETxUcFJSKSAbZsgc6dYdkyKCjwh/jSnQpKRCTNffmlnzJj6VJ44QV/OxOooERE0thXX8F558GiRTB5sp94MFOooERE0tQ33/hp2t95ByZMgJ49QydKLBWUiEga2rYNevSAN9+E55+HPn1CJ0o8FZSISJrZvh1694ZXX4Vnn4WLLw6dKDk03YaISBrZsQMuuMAPXTR6NFx6aehEyRNXQZnZjWb2oZl9YGbjzKxWooKJiMj37dwJffvCyy/DyJEwYEDoRMlV5YIys2OB3wL5zrnmQDWgX6KCiYjIv+3aBf36+TfgPvII/OpXoRMlX7yH+KoDtc2sOlAH+Cz+SCIiUtru3f5Q3tSpMHw4XHNN6ESpUeWCcs4VAw8A/wLWAV8652aVXc7MBpnZAjNbUFJSUvWkIiJZaM8euOIKmDgRhg6F664LnSh14jnEdyTQC2gMHAMcZmaXlV3OOTfSOZfvnMvPzc2telIRkSyzZw9ceSWMGwf33AO/+13oRKkVzyG+jsAnzrkS59wuYApwTmJiiYhkt717YdAgeOYZuP12+NOfQidKvXgK6l/AWWZWx8wM6AAsT0wsEZHs5RwMHuwvI/+///Mf2Siec1DzgUnAIuD92HONTFAuEZGs5Jw/z/T443DzzfCXv4ROFE5cI0k45/4M/DlBWUREsppz/jzTiBHw+9/D3XeDWehU4WgkCRGRCHDOn2d68EG4/nq4//7sLidQQYmIBOecP890//3+PU7DhqmcQAUlIhLc7bfDXXf50SEefljltI8KSkQkoLvvhiFD/Lh6jz8Oh+iv8ne0KkREArn/frj1VrjsMhg1SuVUllaHiEgAw4bBTTf5AWCfegqqVQudKHpUUCIiKTZihL+cvE8fP+FgdU0dWy4VlIhICo0cCddeC716+TH2VE4VU0GJiKTI6NFw9dXQrRtMmAA1aoROFG0qKBGRFHjmGX8ZeefOMHky1KwZOlH0qaBERJLs+edh4EBo3x5eeAFq1QqdKD2ooEREkmjiRLj8cmjTBgoKoHbt0InShwpKRCRJpk6FSy6Bc86BGTOgTp3QidKLCkpEJAlmzICLL4bTT4eZM+Hww0MnSj8qKBGRBHv5Zf8ep1NP9bfr1g2dKD2poEREEmj2bOjdG5o1g1mz4IgjQidKXyooEZEEmTcPevaEk07yRXXkkaETpTcVlIhIArz+OnTvDieeCHPmQE5O6ETpTwUlIhKnt97yo0McfzzMnQu5uaETZQYVlIhIHN55B7p2hWOOgcJCOPro0IkyhwpKRKSKFi70Qxfl5vpyatgwdKLMooISEamCxYuhUyeoV8+XU6NGoRNlHhWUiMhBev996NjRv/l23jz48Y9DJ8pMKigRkYOwbBl06OBHIy8shMaNQyfKXCooEZFKWrnSj0herZrfc/rJT0Inymyay1FEpBJWr/bltHcvvPYa/PSnoRNlPhWUiMgBfPKJL6cdO+DVV6Fp09CJskNch/jMrJ6ZTTKzFWa23MzOTlQwEZEoWLsW2rWDr7/2I0Q0bx46UfaIdw/qIeBl51xfMzsU0GwnIpIxior8ntOWLX6EiBYtQifKLlUuKDM7AvgvYACAc24nsDMxsUREwvrsM7/n9MUXfuDXVq1CJ8o+8RziawyUAE+Z2XtmNsrMD
iu7kJkNMrMFZragpKQkjpcTEUmN9ev9ntP69X4+pzPOCJ0oO8VTUNWBlsCjzrnTgG+Am8su5Jwb6ZzLd87l52oERRGJuJIS/z6nTz/1M+GerTPrwcRTUEVAkXNufuzzSfjCEhFJSxs3+hEiPvnET9nepk3oRNmtygXlnFsPfGpmJ8Xu6gAsS0gqEZEU27zZj623ciUUFPjzTxJWvFfxXQeMjV3B9zEwMP5IIiKptWWLH5X8ww9h2jS/FyXhxVVQzrnFQH6CsoiIpNzWrX4+pyVLYMoUf1uiQSNJiEjW+vprPxPuwoUwcaKfsl2iQwUlIlnpm2/g/PPh7bdh/Hjo3Tt0IilLBSUiWefbb6FnT3jjDXjuOejbN3QiKY8KSkSyyvbtfm9p3jx4+mm45JLQiaQiKigRyRo7dkCfPjBrFjz5JFx+eehEsj+asFBEssLOnfCLX/jRIR5/HK68MnQiORAVlIhkvF27/KG8ggJ4+GEYNCh0IqkMFZSIZLTdu/2hvClTYNgw+O//Dp1IKksFJSIZa88eGDAAJkyA++6DG24InUgOhgpKRDLS3r3wq1/B2LFw113wxz+GTiQHSwUlIhln7164+moYMwaGDIH/+Z/QiaQqVFAiklGcg2uvhVGj4NZb4bbbQieSqlJBiUjGcA6uvx4efRRuugnuuAPMQqeSqlJBiUhGcA7+8Af429/gxhvhnntUTulOBSUiac85f57pr3/1h/eGDlU5ZQIVlIikvT//2e8xXX01DB+ucsoUKigRSWt33OE/rroKHnlE5ZRJVFAikrbuucdfpXfFFTByJByiv2gZRT9OEUlLQ4fCLbfAL38Jo0ernDKRfqQiknaGD/dX7F10kZ/TqVq10IkkGVRQIpJWHn3Uv9fpggv8MEbVNatdxlJBiUjaeOIJGDwYevSA8eOhRo3QiSSZVFAikhbGjPGXkZ93HkycCIceGjqRJJsKSkQi77nn/Ay4HTv6eZ1q1gydSFJBBSUikTZ+PPTvD23bwgsvQK1aoRNJqqigRCSyJk2Cyy6D1q1h+nSoUyd0IkklFZSIRNK0aXDJJXDmmfDii3DYYaETSarFXVBmVs3M3jOzGYkIJCLy4ov+PU4tW8JLL0HduqETSQiJ2IO6HliegOcREeGVV+DCC+GUU/ztH/0odCIJJa6CMrNGwPnAqMTEEZFsNncu9O4NP/sZzJoF9eqFTiQhxbsH9SBwE7C3ogXMbJCZLTCzBSUlJXG+nIhkqldf9W/AbdIEZs+G+vVDJ5LQqlxQZtYd2OCcW7i/5ZxzI51z+c65/Nzc3Kq+nIhksDfegO7dIS8P5syBBg1CJ5IoiGcPqjXQ08zWAOOB9mb2XEJSiUjW+Mc//OgQxx7rD/EddVToRBIVVS4o59wtzrlGzrk8oB9Q6Jy7LGHJRCTjvfsudO0KRx8NhYXQsGHoRBIleh+UiASxaBF07gw5OTBvnt+DEiktIQPVO+deBV5NxHOJSOZbsgQ6dfKXkBcWwnHHhU4kUaQ9KBFJqQ8+8IO+1qnj95zy8kInkqhSQYlIyixfDh06+HmcCgvhhBNCJ5IoU0GJSEr885/Qvj2Y+XJq0iR0Iok6TZYsIkn30Ue+nPbs8Yf1Tj45dCJJByooEUmqNWt8OX37rS+nZs1CJ5J0oYISkaT517+gXTvYutUf1jvllNCJJJ2ooEQkKYqL/Z7T5s1++KLTTgudSNKNCkpEEm7dOr/ntGGDH/g1Pz90IklHKigRSajPP/d7Tp995udzOvPM0IkkXamgRCRhvvjCvwl37Vo/E27r1qETSTpTQYlIQmza5Mtp9Wo/Zfu554ZOJOlOBSUicdu82Y+tt2IFFBT4Q3wi8VJBiUhcvvwSunTxY+xNnepHKBdJBBWUiFTZV1/5yQbfew8mT4Zu3UInkkyighKRKvnmG19I77wDEyZAz56hE0mmUUGJyEHbtg26d4e33oLnn4c+fUInkkykghKRg/Ltt9CrF7z2Gjz7LFx8c
ehEkqk03YaIVNr27XDhhTB3Ljz1FFx6aehEksm0ByUilbJzJ/TtCy+/DKNGQf/+oRNJptMelIgc0K5d/lDeiy/Co4/CVVeFTiTZQAUlIvu1ezf88pfwwgswfDj85jehE0m2UEGJSIX27IHLL4dJk2DoULjuutCJJJuooESkXHv2wMCBMH483HMP/O53oRNJtlFBicgP7N0Lv/61v4z8jjvgT38KnUiykQpKRL5n71645hp/Gfltt8H//m/oRJKtVFAi8h3n/HmmkSPhlltgyJDQiSSbqaBEBPDldOON8Mgj8Ic/wF13gVnoVJLNqlxQZnacmc0zs2Vm9qGZXZ/IYCKSOs7BTTfBQw/B9dfDffepnCS8eEaS2A383jm3yMzqAgvNbLZzblmCsolICjjnzzM98AAMHgzDhqmcJBqqvAflnFvnnFsUu/0VsBw4NlHBRCQ1/vIXuPtuf9Xe3/6mcpLoSMg5KDPLA04D5pfz2CAzW2BmC0pKShLxciKSIHfd5Qtq4EB47DE4RGelJULi3hzN7HBgMnCDc25r2cedcyOdc/nOufzc3Nx4X05EEuS++/yhvcsvhyeeUDlJ9MS1SZpZDXw5jXXOTUlMJBFJtmHD/Jtv+/Xz73eqVi10IpEfiucqPgOeBJY75/6auEgikkwPP+yHLerTx48UoXKSqIpnD6o1cDnQ3swWxz66JSiXiCTB44/7N+L26gXjxkF1zQgnEVblzdM59wag631E0sSTT/qpMs4/HyZMgBo1QicS2T+dFhXJAk8/7S8j79LFT51Rs2boRCIHpoISyXDPP+8vI+/QAaZOhVq1QicSqRwVlEgG+/vf/WXk554L06ZB7dqhE4lUngpKJENNneqnaj/nHJg+HerUCZ1I5OCooEQy0PTpcPHFcPrpMHMmHH546EQiB08FJZJhXnoJ+vaFFi3g5Zehbt3QiUSqRgUlkkFmzYILLoDmzeGVV+CII0InEqk6FZRIhigs9G/APflkX1RHHhk6kUh8VFAiGeD116FHDzjxRJg9G3JyQicSiZ8KSiTNvfkmdOsGxx8Pc+eCJg2QTKGCEklj8+fDeefBMcf4Q3xHHx06kUjiqKBE0tSCBX7ootxcX04NG4ZOJJJYKiiRNPTee9C5s78QYt48aNQodCKRxFNBiaSZ99+HTp38m2/nzfPnnkQykQpKJI0sW+YHfa1Z05dTXl7oRCLJo4ISSRMrV0L79n4G3Hnz/CXlIplM82mKpIHVq305OefL6ac/DZ1IJPlUUCIR9/HH0K4d7Nzpy6lp09CJRFJDBSUSYWvX+nLats1fSt68eehEIqmjghKJqE8/9eW0dasfIeLUU0MnEkktFZRIBBUX+3NOGzfCnDnQsmXoRCKpp4ISiZj16305rV/vRyU//fTQiUTCUEGJRMiGDf59TkVFfrLBs88OnUgkHBWUSER88QV07AiffOKnaW/TJnQikbBUUCIRsGmTH75o1SqYPh3atg2dSCQ8FZRIYFu2+IFfly2DggK/FyUiKiiRoLZuha5dYelSmDLFT58hIp4KSiSQr77ykw0uXAgTJ0L37qETiURLXIPFmllXM1tpZqvN7OZEhRLJdN984wtp/nwYNw569w6dSCR6qrwHZWbVgBFAJ6AIeNfMCpxzyxIVrqzdu/0hEZF0tm0b9O8Pb7wBY8dC376hE4lEUzyH+M4AVjvnPgYws/FALyBpBbViBfz858l6dpHUOeQQGDMG+vULnUQkuuIpqGOBT0t9XgScWXYhMxsEDAI4Ps6pPxs2hIceiuspRCLhzDP9h4hULOkXSTjnRgIjAfLz8108z5WTA7/9bUJiiYhIxMVzkUQxcFypzxvF7hMREYlbPAX1LtDEzBqb2aFAP6AgMbFERCTbVfkQn3Nut5ldC7wCVANGO+c+TFgyERHJanGdg3LOzQRmJiiLiIjId+J6o66IiEiyqKBERCSSVFAiIhJJKigREYkkFZSIiESSCkpERCJJBSUiIpGkg
hIRkUhSQYmISCSpoEREJJJUUCIiEkkqKBERiSRzLq45BA/uxcxKgLUJeKoGwBcJeJ5UUd7kS7fM6ZYX0i+z8iZfIjL/2DmXW94DKS2oRDGzBc65/NA5Kkt5ky/dMqdbXki/zMqbfMnOrEN8IiISSSooERGJpHQtqJGhAxwk5U2+dMucbnkh/TIrb/IlNXNanoMSEZHMl657UCIikuFUUCIiEkmRLCgzu8jMPjSzvWaWX+axW8xstZmtNLMuFXx9YzObH1tugpkdmprk373+BDNbHPtYY2aLK1hujZm9H1tuQSozlskxxMyKS2XuVsFyXWPrfbWZ3ZzqnGWy3G9mK8xsqZlNNbN6FSwXdB0faJ2ZWc3Y9rI6ts3mpTpjqSzHmdk8M1sW+/27vpxl2prZl6W2ldtCZC2Tab8/Y/OGx9bxUjNrGSJnLMtJpdbdYjPbamY3lFkm+Do2s9FmtsHMPih1X30zm21mq2L/HlnB1/aPLbPKzPrHFcQ5F7kPoClwEvAqkF/q/p8BS4CaQGPgI6BaOV//d6Bf7PZjwDUBv5ehwG0VPLYGaBCB9T0E+MMBlqkWW98nAIfGfg4/C5i5M1A9dvte4N6orePKrDNgMPBY7HY/YELAddoQaBm7XRf4Zzl52wIzQmWsys8Y6Aa8BBhwFjA/dOZS28d6/BtVI7WOgf8CWgIflLrvPuDm2O2by/udA+oDH8f+PTJ2+8iq5ojkHpRzbrlzbmU5D/UCxjvndjjnPgFWA2eUXsDMDGgPTIrd9TTQO5l5KxLL8gtgXIjXT7AzgNXOuY+dczuB8fifRxDOuVnOud2xT98GGoXKsh+VWWe98Nso+G22Q2y7STnn3Drn3KLY7a+A5cCxIbIkWC/gGee9DdQzs4ahQwEdgI+cc4kYXSehnHOvA5vK3F16W63o72oXYLZzbpNzbjMwG+ha1RyRLKj9OBb4tNTnRfzwFygH2FLqj1d5y6RKG+Bz59yqCh53wCwzW2hmg1KYqzzXxg5/jK5g170y6z6UK/H/Qy5PyHVcmXX23TKxbfZL/DYcVOxQ42nA/HIePtvMlpjZS2bWLKXBynegn3FUt91+VPyf16itY4CjnXPrYrfXA0eXs0xC13X1qn5hvMxsDvAf5Tx0q3NuWqrzHKxK5r+E/e89/adzrtjMjgJmm9mK2P9cEm5/eYFHgTvwv+h34A9LXpmMHAejMuvYzG4FdgNjK3ialK3jTGFmhwOTgRucc1vLPLwIf0jq69i5yheAJqnOWEba/Yxj58V7AreU83AU1/H3OOecmSX9PUrBCso517EKX1YMHFfq80ax+0rbiN+Frx77H2l5y8TtQPnNrDpwIdBqP89RHPt3g5lNxR8SSsovVmXXt5k9Acwo56HKrPuEqsQ6HgB0Bzq42AHwcp4jZeu4HJVZZ/uWKYptM0fgt+EgzKwGvpzGOuemlH28dGE552aa2SNm1sA5F2yQ00r8jFO+7VbCecAi59znZR+I4jqO+dzMGjrn1sUOkW4oZ5li/Dm0fRrhryWoknQ7xFcA9Itd+dQY/7+Kd0ovEPtDNQ/oG7urPxBij6wjsMI5V1Teg2Z2mJnV3Xcbf9L/g/KWTbYyx+MvqCDHu0AT81dIHoo/PFGQinzlMbOuwE1AT+fctgqWCb2OK7POCvDbKPhttrCisk222LmvJ4Hlzrm/VrDMf+w7R2ZmZ+D/hoQs1Mr8jAuAK2JX850FfFnqUFUoFR5dido6LqX0tlrR39VXgM5mdmTsVEHn2H1VE/JKkYo+8H8ki4AdwOfAK6UeuxV/ZdRK4LxS988EjondPgFfXKuBiUDNAN/DGOA3Ze47BphZKuOS2MeH+MNWodb3s8D7wNLYRtiwbN7Y593wV3Z9FDJvLMtq/LHuxbGPfVfCRWodl7fOgNvxxQpQK7aNro5tsycEXKf/iT/Mu7TUeu0G/GbftgxcG1uXS/AXp5wTeDso92dcJrMBI2I/g/cpdWVwoMyH4QvniFL3R
Wod48tzHbAr9rf4Kvy50bnAKmAOUD+2bD4wqtTXXhnbnlcDA+PJoaGOREQkktLtEJ+IiGQJFZSIiESSCkpERCJJBSUiIpGkghIRkUhSQYmISCSpoEREJJL+H//UzeaLWVhHAAAAAElFTkSuQmCC\n" 482 | }, 483 | "metadata": { 484 | "needs_background": "light" 485 | } 486 | } 487 | ], 488 | "source": [ 489 | "plt.figure(figsize=(7, 4))\n", 490 | "plt.plot(x_1, y_relu, c='blue', label='ReLU')\n", 491 | "plt.ylim((-1, 11))\n", 492 | "plt.legend(loc='best')" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 20, 498 | "metadata": { 499 | "colab": { 500 | "base_uri": "https://localhost:8080/", 501 | "height": 287 502 | }, 503 | "id": "Y6m9Rbnc61GB", 504 | "outputId": "5f16dcf3-b8ee-4be2-ce79-3228bf3c6f37" 505 | }, 506 | "outputs": [ 507 | { 508 | "output_type": "execute_result", 509 | "data": { 510 | "text/plain": [ 511 | "" 512 | ] 513 | }, 514 | "metadata": {}, 515 | "execution_count": 20 516 | }, 517 | { 518 | "output_type": "display_data", 519 | "data": { 520 | "text/plain": [ 521 | "
" 522 | ], 523 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbQAAAD8CAYAAAAfSFHzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXhV1dn38e8NBFDm0QFEUEFlEIeItvIIigNiC9I6VquIihOtOCDY2tbhcaBq69uKKOJQhzJUntKoOFcEZwKiCIpGRQmiAiLFKkLgfv9YJ3IIJyGQnbPP8Ptc176yh5Wz7uyc5D5r7bXXNndHREQk29WJOwAREZEoKKGJiEhOUEITEZGcoIQmIiI5QQlNRERyghKaiIjkhEgSmpndZ2Zfmtk7lRw/3czeNrP5ZvaKmfWMol4REZFyUbXQHgD6V3H8Y6CPu/cArgfGR1SviIgIAPWieBF3n2lmHas4/krS5mtA+yjqFRERKRdJQttG5wBPpjpgZsOAYQCNGjU6aJ999klnXCIikuHmzJmzwt3bpDqW1oRmZkcQElrvVMfdfTyJ7sjCwkIvLi5OY3QiIpLpzOyTyo6lLaGZ2X7ABOA4d1+ZrnpFRCQ/pGXYvpl1AP4P+KW7v5+OOkVEJL9E0kIzs4lAX6C1mZUCfwAKANz9LuD3QCvgTjMDKHP3wijqFhERgehGOZ62lePnAudGUZeISDZbv349paWlrF27Nu5QMlrDhg1p3749BQUF1f6eOEY5iojkrdLSUpo0aULHjh1J9FhJBe7OypUrKS0tpVOnTtX+Pk19JSKSRmvXrqVVq1ZKZlUwM1q1arXNrVglNBGRNFMy27rtOUdKaCIikhOU0ERE8ty5557LwoULa7WOAQMG8PXXX2+x/5prruHWW2+NpA4NChERyXMTJkyo9TqmT59e63WohSYikkf++9//cvzxx9OzZ0+6d+/O5MmT6du3L+VTDd5777106dKFXr16cd555zF8+HAAhgwZwoUXXsihhx7KHnvswYwZMxg6dCj77rsvQ4YM+eH1J06cSI8ePejevTujRo36YX/Hjh1ZsWIFADfccANdunShd+/eLFq0KLKfTS00EZGYjBgB8+ZF+5r77w+331758aeeeopdd92VJ554AoDVq1czbtw4AD777DOuv/565s6dS5MmTTjyyCPp2XPT4ytXrVrFq6++SlFREQMHDuTll19mwoQJHHzwwcybN4+2bdsyatQo5syZQ4sWLTjmmGOYNm0aJ5xwwg+vMWfOHCZNmsS8efMoKyvjwAMP5KCDDorkZ1cLTUQkj/To0YNnn32WUaNGMWvWLJo1a/bDsTfeeIM+ffrQsmVLCgoKOOmkkzb73p/+9KeYGT169GCnnXaiR48e1KlTh27durF48WJmz55N3759adOmDfXq1eP0009n5syZm73GrFmzGDx4MDvuuCNNmzZl4MCBkf1saqGJiMSkqpZUbenSpQtz585l+vTpXH311fTr16/a39ugQQMA6tSp88N6+XZZWdk2zepRG9RCExHJI5999hk77rgjZ5xxBiNHjmTu3Lk/HDv44IN58cUXWbVqFWVlZUydOnWbXrtXr168+OKLrFixgg0bNjBx4kT69OmzWZnDDz+cadOm8d1337FmzRoee+yxSH4uUAtNRCSvzJ8/n5EjR1KnTh0KCgoYN24cV1xxBQDt2rXjN7/5Db169aJly5bss88+m3VJbs0uu+zCzTffzBFHHIG7c/zxxzNo0KDNyhx44IGccsop9OzZk7Zt23LwwQdH9rOZu0f2YlHSAz5FJBe9++677LvvvnGHUalvvvmGxo0bU1ZWxuDBgxk6dCiDBw+OJZZU58rM5lT2tBZ1OYqIyA+uueYa9t9/f7p3706nTp02G6GY6dTlKCIiP4hq1o44qIUmIpJmmXqpJ5NszzlSQhMRSaOGDRuyc
uVKJbUqlD8PrWHDhtv0fepyFBFJo/bt21NaWsry5cvjDiWjlT+xelsooYmIpFFBQcE2PYVZqk9djiIikhOU0EREJCcooYmISE6IJKGZ2X1m9qWZvVPJcTOzv5hZiZm9bWYHRlGviIhIuahaaA8A/as4fhzQObEMA8ZFVK+IiAgQ0ShHd59pZh2rKDIIeNDDjRevmVlzM9vF3ZdFUb+ISKZxhw0bNl/Kyqq/b+PGTa8T5bK110z1c1S1vS37Bg6EerU4tj5dw/bbAUuStksT+zZLaGY2jNCCo0OHDmkKTUTyzYYNsGoVrFwJX30Fq1fDmjXwzTdbLuX7166F779Pvaxbt+V2eUKSTb75JjcSWrW4+3hgPITZ9mMOR0SyzOrVsGQJLF265fLll5sS2Ndfp25BJKtbF5o0gcaNoVEj2GEHaNAgLE2aQOvWYb1+/U37y5f69cM/7rp1N1+qu698MQuxmEW7bO01K6q4rzplUu3bxok/tlm6EtpSYLek7faJfSIi22TDBnj/fXj77fD1gw+gpCR8XbFiy/Jt2kC7dtC2Ley5J7RqtWlp2TJ8bdZsU/Jq3Dis16+f+p+0ZK50JbQiYLiZTQIOAVbr+pmIbI07vPcevPQSzJ0Lb74ZEtl3320q0749dO4MP/sZ7LUX7L57SGDt2sEuu4QWk+SHSBKamU0E+gKtzawU+ANQAODudwHTgQFACfAtcHYU9YpIbnGHRYvgqadg5kyYNWtTq6t5c9h/fzj/fDjgANhvP+jSBXbcMd6YJXNENcrxtK0cd+DiKOoSkdyyYQO8+CIUFcHjj8OHH4b9e+wBxx8P//M/YencWV2AUrWMGhQiIvnBHd56Cx5+GP7+d1i2LHQN9usHl18OAwaErkORbaGEJiJps3YtTJwIf/1ruB5WUBCS1+mnh6+NGsUdoWQzJTQRqXUrV8Ltt8Ndd4VrYt26wdixcOqpYaShSBSU0ESk1qxcCX/6E/zlL/Df/8KgQfDrX0PfvroeJtFTQhORyK1bB3fcAddeG2baOPlk+N3vQstMpLYooYlIpJ5+Gi65JAy/798fbrkFunePOyrJB3oemohEYtUqOOuskMQ2bIDHHoPp05XMJH3UQhORGnvySTj3XPjiC7j66rBohg5JN7XQRGS7lZXB6NFhyH2LFvD663D99UpmEg+10ERkuyxbFobdz5wZpqO6/fban01dpCpKaCKyzebNC9NSff01PPgg/PKXcUckooQmItvoqafgpJPCZMGvvAI9e8YdkUiga2giUm333w8/+Ul4TMtrrymZSWZRQhORarn7bhg6NEwgPHNmeN6YSCZRQhORrbrjDrjggtA6KyoKT3QWyTRKaCJSpXHj4Fe/ghNOgKlTNSRfMpcGhYhIpaZMgYsvhp/+NKwXFMQdkUjl1EITkZSefx7OOAMOOwwmT1Yyk8ynhCYiW3jrrdDFuPfe4ZrZDjvEHZHI1imhichmli8Pzy1r1izcc9aiRdwRiVSPrqGJyA/Wrw83TX/+OcyapaH5kl2U0ETkByNGwIsvwkMPwcEHxx2NyLaJpMvRzPqb2SIzKzGz0SmOdzCzF8zsTTN728wGRFGviETnkUfgzjvhiivCYBCRbFPjhGZmdYGxwHFAV+A0M+taodjVwBR3PwA4FbizpvWKSHQ++CDcON27N9x0U9zRiGyfKFpovYASd//I3dcBk4BBFco40DSx3gz4LIJ6RSQC338fHgNTUAB//zvU04UIyVJRvHXbAUuStkuBQyqUuQZ4xsx+BTQCjkr1QmY2DBgG0KFDhwhCE5GtGT0a5s6FadNgt93ijkZk+6Vr2P5pwAPu3h4YADxkZlvU7e7j3b3Q3QvbtGmTptBE8te//x0ezDl8eBiqL5LNokhoS4Hkz3XtE/uSnQNMAXD3V4GGQOsI6haR7bRmTZg9v3NnGDMm7mhEai6KhDYb6GxmncysPmHQR1GFMp8C/QDMbF9CQlseQd0isp2uvBI+/RQeeAB23
DHuaERqrsYJzd3LgOHA08C7hNGMC8zsOjMbmCh2OXCemb0FTASGuLvXtG4R2T7PPQd33QWXXQY//nHc0YhEwzI1rxQWFnpxcXHcYYjknG+/hW7dwmNg3nxT8zRKdjGzOe5emOqYBuiK5JkbboDFi8OMIEpmkks0ObFIHnn3XbjlFjjzTDj88LijEYmWEppInnCHiy6Cxo1DUhPJNepyFMkTjzwCM2aEwSBt28YdjUj01EITyQNr1oRJh3v1gvPOizsakdqhFppIHhgzBr74Ijx9uo4+xkqO0ltbJMd9+incdhucfnpooYnkKiU0kRx31VXh6403xhuHSG1TQhPJYa+/Hh4Jc/nloAdYSK5TQhPJUe5haquddoJRo+KORqT2aVCISI6aNg1eeQXGj4cmTeKORqT2qYUmkoM2bICrr4a994azz447GpH0UAtNJAdNnAgLF8LkyVBPf+WSJ9RCE8kx69fDH/4APXvCiSfGHY1I+uizm0iOuf9++OgjePxx3UQt+UVvd5EcsnYtXHcd/OhHMGBA3NGIpJdaaCI5ZNw4WLoUHnoIzOKORiS91EITyRHffgs33QRHHQVHHBF3NCLpp4QmkiPuuw+WLw8DQkTykRKaSA5Yvx5uvRUOOwx69447GpF46BqaSA6YPBk++QTuuCPuSETioxaaSJbbuBFuvhm6d9fIRslvkSQ0M+tvZovMrMTMRldS5mQzW2hmC8zs71HUKyIwfTosWBAmINZ9Z5LPatzlaGZ1gbHA0UApMNvMitx9YVKZzsBVwGHuvsrM2ta0XhEJbr4Zdt8dTjkl7khE4hXF57leQIm7f+Tu64BJwKAKZc4Dxrr7KgB3/zKCekXy3ksvwcsvwxVXQEFB3NGIxCuKhNYOWJK0XZrYl6wL0MXMXjaz18ysfwT1iuS9MWOgdWsYOjTuSETil65RjvWAzkBfoD0w08x6uPvXyYXMbBgwDKCDHq8rUqX588N8jdddBzvuGHc0IvGLooW2FNgtabt9Yl+yUqDI3de7+8fA+4QEtxl3H+/uhe5e2KZNmwhCE8ldf/wjNGoEF18cdyQimSGKhDYb6GxmncysPnAqUFShzDRC6wwza03ogvwogrpF8tLixeGZZ+efDy1bxh2NSGaocUJz9zJgOPA08C4wxd0XmNl1ZjYwUexpYKWZLQReAEa6+8qa1i2Sr267LQzRv/TSuCMRyRzm7nHHkFJhYaEXFxfHHYZIxlm+PAzTP+00uPfeuKMRSS8zm+PuhamO6TZMkSzz17+G556NHBl3JCKZRQlNJIusWRPmaxw8GPbZJ+5oRDKLEppIFrnnHli1KkxzJSKbU0ITyRLffx8GgxxxBPTqFXc0IplHj48RyRKPPAKffQb33x93JCKZSS00kSywYUO4kfqAA+Doo+OORiQzqYUmkgX+9S9YtCg8yNMs7mhEMpNaaCIZzj08ImbPPeHnP487GpHMpRaaSIabMQNmz4a774a6deOORiRzqYUmkuFuvhl22gnOPDPuSEQymxKaSAabOxeeeSbM2diwYdzRiGQ2JTSRDDZmDDRtChdcEHckIplPCU0kQ33wATz6KFx0ETRrFnc0IplPCU0kQ916KxQUwCWXxB2JSHZQQhPJQMuWwQMPwNlnw847xx2NSHZQQhPJQLffDmVlcMUVcUcikj2U0EQyzOrVcNddcPLJ4WZqEakeJTSRDDNuHPznP3DllXFHIpJdlNBEMsh334XuxmOPDRMRi0j1KaGJZJC//Q2++AJGj447EpHso4QmkiHKysIjYg45BPr0iTsakeyjyYlFMsTkyfDxx6HLUY+IEdl2kbTQzKy/mS0ysxIzq7SzxMx+bmZuZoVR1CuSKzZuDJMQd+sGP/lJ3NGIZKcat9DMrC4wFjgaKAVmm1mRuy+sUK4JcAnwek3rFMk1jz8O77wDDz8MdXQhQGS7RPGn0wsocfeP3H0dMAkYlKLc9cAYYG0EdYrkDHe48Ubo1AlOOSXuaESyV
xQJrR2wJGm7NLHvB2Z2ILCbuz8RQX0iOWXGDHj99XDfWT1d1RbZbrXeuWFmdYA/AZdXo+wwMys2s+Lly5fXdmgiGeHGG8N8jUOGxB2JSHaLIqEtBXZL2m6f2FeuCdAdmGFmi4FDgaJUA0Pcfby7F7p7YZs2bSIITSSzzZ4Nzz0Hl12mB3iK1FQUCW020NnMOplZfeBUoKj8oLuvdvfW7t7R3TsCrwED3b04grpFstpNN0Hz5nqAp0gUapzQ3L0MGA48DbwLTHH3BWZ2nZkNrOnri+Sqd96Bf/4TfvUraNIk7mhEsp+5e9wxpFRYWOjFxWrESe46+WR46ilYvBhatow7GpHsYGZz3D3lvcy640UkBvPnwz/+EZ5GrWQmEg0lNJEYXHstNG0Kl14adyQiuUMJTSTN3noLpk5V60wkakpoIml27bXQrJlaZyJRU0ITSaN588LIxhEjoEWLuKMRyS1KaCJpVN46GzEi7khEco8SmkiavPYaTJsGl18ebqYWkWgpoYmkgTuMGgVt2+ramUht0dzeImnw5JMwcyaMHQuNG8cdjUhuUgtNpJZt2ACjR8Nee8F558UdjUjuUgtNpJY98kiYGWTyZCgoiDsakdylFppILVq7Fn73OygshBNPjDsakdymFppILfrzn+HTT+G++6COPj6K1Cr9iYnUkqVL4YYbYNAg6Ncv7mhEcp8SmkgtGT0a1q+H226LOxKR/KCEJlILXn0VHn443ES9555xRyOSH5TQRCK2cSP8+tew667wm9/EHY1I/tCgEJGITZgAxcXw0EO6iVokndRCE4nQsmVw5ZXQty+cfnrc0YjkFyU0kQiNGBHuPbv7bjCLOxqR/KKEJhKRJ56AKVPg6quhS5e4oxHJP0poIhH45hu46CLo2jV0OYpI+mlQiEgELrsMliyBWbOgfv24oxHJT5G00Mysv5ktMrMSMxud4vhlZrbQzN42s+fNbPco6hXJBE88AffcAyNHwmGHxR2NSP6qcUIzs7rAWOA4oCtwmpl1rVDsTaDQ3fcDHgX+WNN6RTLBihVwzjmw335w3XVxRyOS36JoofUCStz9I3dfB0wCBiUXcPcX3P3bxOZrQPsI6hWJlTtccAF89VW456xBg7gjEslvUSS0dsCSpO3SxL7KnAM8meqAmQ0zs2IzK16+fHkEoYnUnnHjYOpU+N//DS00EYlXWkc5mtkZQCFwS6rj7j7e3QvdvbBNmzbpDE1kmxQXw6WXwoABcMUVcUcjIhDNKMelwG5J2+0T+zZjZkcBvwX6uPv3EdQrEotVq+Ckk2DnneHBB/WcM5FMEUVCmw10NrNOhER2KvCL5AJmdgBwN9Df3b+MoE6RWGzYAGeeGZ51NmsWtGoVd0QiUq7Gny3dvQwYDjwNvAtMcfcFZnadmQ1MFLsFaAz8w8zmmVlRTesVicNVV8Hjj8Ptt8Mhh8QdjYgki+TGanefDkyvsO/3SetHRVGPSJwmTIBbboHhw8OsICKSWdT7L1INL7wAF14Ixx4Lf/5z3NGISCpKaCJbUVwMgwaFCYcnT4Z6mjBOJCMpoYlU4Z13QqusVSt45hlo1izuiESkMkpoIpUoKYFjjgkzgDz3HLSraroAEYmdEppICgsWwOGHw7p18OyzsOeecUckIlujhCZSwezZIZkBzJgB3brFGo6IVJMSmkiS556Dfv2gadNw43T37nFHJCLVpYQmkjB2LPTvD7vvDi+9pG5GkWyjhCZ5b926cKP08OFw3HHwyisaACKSjZTQJK99+CH07h0eBTNyJEybBk2axB2ViGwP3SIqeckdJk2C88+HunXh0Ufh5z+POyoRqQm10CTvLF0KP/sZ/OIX0KMHzJunZCaSC5TQJG+UlYWuxa5d4amnYMwYePHFMAhERLKfuhwl57nDk0+Ga2QLF8KRR8Ldd8Nee8UdmYhESS00yVnuYZb8o46C448PoxmnTg33mimZieQeJTTJORs3QlER/PjHoTW2YEF4IOeCB
eHamVncEYpIbVCXo+SMZcvg/vvhnntg8WLo2BHuvBOGDIEddog5OBGpdUpoktX+8x947DGYMgWmTw8DP444Am66KYxcLCiIO0IRSRclNMk6n3wSZsAvKoKnnw7Xxtq1gxEj4LzzwoM4RST/KKFJRnOHJUvg9dfDEPtnn4X33w/HdtsNLr4YTjwRDj0U6uiKsEheU0KTjLFxI3z8cRhaP38+vPFGSGSffx6ON2oEffrAhRfC0UeH+8k0wENEyimhSVqtWxdaXIsXh67DxYvho49CEnvvPfjuu01lO3cOQ+4PPRQOOQT22w/q148rchHJdJEkNDPrD/w/oC4wwd1vrnC8AfAgcBCwEjjF3RdHUbfEZ+NGWLMmDMwoX77+Gr78MvXy+edhJKL7pteoUydc/+raNQzm6No1PFBz332hWbP4fjYRyT41TmhmVhcYCxwNlAKzzazI3RcmFTsHWOXue5nZqcAY4JSa1r01ZWXhn677pn+iVa1v7XgmlIXwM5WVwYYN2/+1fP3772Ht2i2X775LvS85ga1ZU/X5b9gQdtoJ2raFXXaBnj2hQ4cwnH733cPX9u01ElFEohFFC60XUOLuHwGY2SRgEJCc0AYB1yTWHwXuMDNzT/6sHr3Ro+G222qzhtxSv35IQhWXHXYIX5s3DwmpadOwNGu2ab18ad48JLC2bcM1L13jEpF0iSKhtQOWJG2XAodUVsbdy8xsNdAKWJFcyMyGAcMAOnToUOPAjjsOWrYsf+1N/1yrWo+r7La8lhnUqxeWunW3XN+Wrw0ahGTVoIFGCYpIdsuoQSHuPh4YD1BYWFjj1lu/fmEREZHcF8Vn8qXAbknb7RP7UpYxs3pAM8LgEBERkUhEkdBmA53NrJOZ1QdOBYoqlCkCzkqsnwj8u7avn4mISH6pcZdj4prYcOBpwrD9+9x9gZldBxS7exFwL/CQmZUAXxGSnoiISGQiuYbm7tOB6RX2/T5pfS1wUhR1iYiIpKJxbSIikhOU0EREJCcooYmISE5QQhMRkZyghCYiIjlBCU1ERHKCEpqIiOQEJTQREckJSmgiIpITlNBERCQnKKGJiEhOUEITEZGcoIQmIiI5QQlNRERyghKaiIjkBCU0ERHJCUpoIiKSE5TQREQkJyihiYhITlBCExGRnKCEJiIiOUEJTUREckKNEpqZtTSzZ83sg8TXFinK7G9mr5rZAjN728xOqUmdIiIiqdS0hTYaeN7dOwPPJ7Yr+hY40927Af2B282seQ3rFRER2UxNE9og4G+J9b8BJ1Qs4O7vu/sHifXPgC+BNjWsV0REZDP1avj9O7n7ssT658BOVRU2s15AfeDDSo4PA4YlNr8xs0U1jA+gNbAigtdJp2yLWfHWvmyLOdviheyLOV/j3b2yA+buVX6nmT0H7Jzi0G+Bv7l786Syq9x9i+toiWO7ADOAs9z9tWoEHQkzK3b3wnTVF4Vsi1nx1r5siznb4oXsi1nxbmmrLTR3P6qyY2b2hZnt4u7LEgnry0rKNQWeAH6bzmQmIiL5o6bX0IqAsxLrZwH/qljAzOoD/wQedPdHa1ifiIhISjVNaDcDR5vZB8BRiW3MrNDMJiTKnAwcDgwxs3mJZf8a1rstxqexrqhkW8yKt/ZlW8zZFi9kX8yKt4KtXkMTERHJBpopREREcoISmoiI5IScSGhmdlJiaq2NZlZY4dhVZlZiZovM7NhKvr+Tmb2eKDc5MZAlbRJ1ll9fXGxm8yopt9jM5ifKFaczxgpxXGNmS5NiHlBJuf6J815iZqlmkUkLM7vFzN5LTL32z8pmqon7/G7tfJlZg8R7pSTxfu2Y7hgrxLObmb1gZgsTf3+XpCjT18xWJ71Xfh9HrEnxVPk7tuAviXP8tpkdGEecSfHsnXTu5pnZf8xsRIUysZ5jM7vPzL40s3eS9m11WsREubMSZT4ws7NSldkm7p71C7AvsDfhPrfCpP1dgbeABkAnwg3ddVN8/xTg1MT6XcCFM
f4stwG/r+TYYqB1Bpzva4ArtlKmbuJ870G4mf4toGtM8R4D1EusjwHGZNr5rc75Ai4C7kqsnwpMjvl9sAtwYGK9CfB+ipj7Ao/HGee2/I6BAcCTgAGHAq/HHXOF98jnwO6ZdI4Jg/4OBN5J2vdHYHRifXSqvzmgJfBR4muLxHqLmsSSEy00d3/X3VPNKjIImOTu37v7x0AJ0Cu5gJkZcCRQfktByim80iERy8nAxDjqj1gvoMTdP3L3dcAkwu8j7dz9GXcvS2y+BrSPI46tqM75Sp5q7lGgX+I9Ewt3X+bucxPra4B3gXZxxRORQYRbjNzDPbPNE/fYZoJ+wIfu/kncgSRz95nAVxV2b3VaROBY4Fl3/8rdVwHPEub73W45kdCq0A5YkrRdypZ/cK2Ar5P+4aUqky7/A3zhibkvU3DgGTObk5gmLE7DE10y91XSnVCdcx+HoYRP4KnEeX6rc75+KJN4v64mvH9jl+j+PAB4PcXhH5nZW2b2pJl1S2tgW9ra7zhT37cQWuWVfdjNpHMM1ZsWMfJzXdO5HNPGqpiCy923uKE701Qz/tOounXW292Xmllb4Fkzey/x6ShyVcULjAOuJ/xzuJ7QTTq0NuKoruqcXzP7LVAGPFLJy6Tt/OYSM2sMTAVGuPt/KhyeS+gi+yZxrXUa0DndMSbJyt9x4rr+QOCqFIcz7Rxvxt3dzNJyf1jWJDSvYgquKiwFdkvabp/Yl2wloVuhXuJTb6oyNba1+M2sHvAz4KAqXmNp4uuXZvZPQjdVrfwxVvd8m9k9wOMpDlXn3EemGud3CPAToJ8nOvBTvEbazm8K1Tlf5WVKE++XZoT3b2zMrICQzB5x9/+reDw5wbn7dDO708xau3ssk+pW43ec1vftNjgOmOvuX1Q8kGnnOKE60yIuJVz/K9eeMA5iu+V6l2MRcGpidFgnwqeWN5ILJP65vQCcmNiVcgqvNDgKeM/dS1MdNLNGZtakfJ0w0OGdVGVrW4VrCoMriWM20NnCCNL6hO6SonTEV5GZ9QeuBAa6+7eVlIn7/FbnfCVPNXci8O/KknM6JK7f3Qu86+5/qqTMzuXX+Sw8baMOMSXhav6Oi4AzE6MdDwVWJ3WdxanS3ptMOsdJtjotIvA0cIyZtUhctlxcpXgAAAESSURBVDgmsW/7xTUyJsqF8E+1FPge+AJ4OunYbwmjxxYBxyXtnw7smljfg5DoSoB/AA1i+BkeAC6osG9XYHpSjG8llgWErrS4zvdDwHzg7cQbd5eK8Sa2BxBGvn0Yc7wlhL76eYmlfKRgRp3fVOcLuI6QiAEaJt6fJYn36x5xndNEPL0J3c5vJ53bAcAF5e9lYHjifL5FGJDz4xjjTfk7rhCvAWMTv4P5JI2ajjHuRoQE1SxpX8acY0KiXQasT/wfPodwbfd54APgOaBlomwhMCHpe4cm3s8lwNk1jUVTX4mISE7I9S5HERHJE0poIiKSE5TQREQkJyihiYhITlBCExGRnKCEJiIiOUEJTUREcsL/B+XtYFanXrNhAAAAAElFTkSuQmCC\n" 524 | }, 525 | "metadata": { 526 | "needs_background": "light" 527 | } 528 | } 529 | ], 530 | "source": [ 531 | "plt.figure(figsize=(7, 4))\n", 532 | "plt.plot(x_1, y_sigmoid, c='blue', label='sigmoid')\n", 533 | "plt.ylim((-0.2, 1.2))\n", 534 | "plt.legend(loc='best')" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 21, 540 | "metadata": { 541 | "colab": { 542 | "base_uri": "https://localhost:8080/", 543 | 
"height": 283 544 | }, 545 | "id": "dm1BlKru61GC", 546 | "outputId": "df69a1fa-5387-41da-f657-b898925b84f2" 547 | }, 548 | "outputs": [ 549 | { 550 | "output_type": "execute_result", 551 | "data": { 552 | "text/plain": [ 553 | "" 554 | ] 555 | }, 556 | "metadata": {}, 557 | "execution_count": 21 558 | }, 559 | { 560 | "output_type": "display_data", 561 | "data": { 562 | "text/plain": [ 563 | "
" 564 | ], 565 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbQAAAD4CAYAAACE2RPlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAbTklEQVR4nO3de5wV9X3/8deHZQGNqFwVQbJoUPGSIK7EC6YoKpdGVFINaU20tvKI0faXWxMaHw9jkj4e1fhI0jTVKF6qtd4ao3UlFAWCVSvKHhBkEZSFYFxErhFRQFj28/tjZvG4nmV3OZfvmTnv5+NxHmcu3z3z2dnZfe/MfGfG3B0REZGk6xa6ABERkUJQoImISCoo0EREJBUUaCIikgoKNBERSYXuoQtoT//+/b2mpiZ0GSIiUkYWLVq02d0H5JpXtoFWU1NDJpMJXYaIiJQRM3uzvXk65CgiIqmgQBMRkVRQoImISCqU7Tm0XPbs2UNTUxO7du0KXUpR9erViyFDhlBdXR26FBGRxEhUoDU1NdG7d29qamows9DlFIW7s2XLFpqamhg2bFjockREEiNRhxx37dpFv379UhtmAGZGv379Ur8XKiJSaIkKNCDVYdaqEr5HEZFCS1ygiYiI5KJA66J3332X22+//YC/fuzYsbpgXESkCBRoXZRvoImISHEo0Lpo+vTprF69mpEjR/Ktb32LcePGMWrUKE455RSefPJJANauXcuIESO45pprOOmkk7jwwgvZuXPnvs/4zW9+w+jRoznuuON4/vnnQ30rIiKpkqhu+9m++U1YsqSwnzlyJPzLv+y/zc0330xDQwNLliyhubmZHTt2cOihh7J582bOOOMMJk+eDMCqVat4+OGHueuuu7j88sv57W9/yxVXXAFAc3MzCxcuZNasWfzoRz9i7ty5hf1GREQqUGIDrRy4Oz/4wQ947rnn6NatG+vWrWPDhg0ADBs2jJEjRwJw2mmnsXbt2n1fN2XKlJzTRUTkwBUk0MzsXuCLwEZ3PznHfAN+CUwCdgBXufvifJbZ0Z5UKTz44INs2rSJRYsWUV1dTU1Nzb7rx3r27LmvXVVV1ccOObbOq6qqorm5ubRFi4ikVKHOod0HTNjP/InA8Pg1Dfh1gZZbcr1792b79u0AbNu2jYEDB1JdXc38+fN58812n2ogIiJFVpA9NHd/zsxq9tPkYuA/3N2Bl8zscDMb5O7rC7H8UurXrx9nn302J598MqeffjorV67klFNOoba2lhNOOCF0eSKJtmcP7NwZvTc3w9690at1uO373r3Q0gLu0dfnet/fvK60kfydey50L+KJrlKdQxsMvJU13hRP+1igmdk0oj04hg4dWqLSuu6hhx7qsE1DQ8O+4e9+97v7hp999tl9w/3799c5NEmdlhbYuBH++Mfo1dQEW7bA1q0fvW/dCh98EIXXzp2wY0f0vndv6OqlmLZvh0MOKd7nl1WnEHefAcwAqK2t1f9EImXuww/h5Zdh4UJYtgxefRVWroS2tyLt1g369IG+faPXgAEwbBgcfDAcdFD0yh7u0QOqqqJX9+4ff287XFUVLaP1jnG53vc3ryttJD8HHVTczy9VoK0Djs4aHxJPE5GEefNNeOwxmDkTFiyIQg3gqKPglFPgvPPgmGNg6NDodfTRcPjhUaiJFFOpAq0OuN7MHgE+D2w70PNn7p76m/e6DthLmdm9Owqx226DF1+Mpn3uc/CNb8DYsXDWWdC/f9ASRQrWbf9hYCzQ38yagB8C1QDufgcwi6jLfiNRt/2/PpDl9OrViy1btqT6ETKtz0Pr1atX6FJE2LMH7rwT/vmf4e234bjjouHLLoNjjw1dncjHFaqX41c6mO/AdfkuZ8iQITQ1NbFp06Z8P6qstT6xWiSkOXPg+uvhjTfgC1+Au+6CCRN06FDKV1l1CulIdXW1nuIsUmQ7d
sA//APcfnu0R1ZXB1/8ojpFSPlLVKCJSHE1NcHkydF9Ur/9bfinfyp+zzSRQlGgiQgQdbkfPz66Puypp+DP/zx0RSJdo0ATEZYti7rbH3RQ1Ivx5E/ckVWk/CnQRCrcqlUwbhz07Anz58NnPhO6IpEDo/5KIhXs3Xfhooui21UpzCTptIcmUqFaWuCv/gpWr4Z586IejSJJpkATqVC/+hXMmhXd/eMLXwhdjUj+dMhRpAI1NMD3vx8dbrz22tDViBSGAk2kwrS0wDXXwKGHwt1364JpSQ8dchSpMPfdBy+9BPffDwMHhq5GpHC0hyZSQf70p+hQ45gx8NWvhq5GpLAUaCIV5NZbo6dG/+pXOtQo6aNAE6kQGzbAL38JU6fCyJGhqxEpPAWaSIW4+ebo6dI33RS6EpHiUKCJVIDNm+GOO6LzZrqAWtJKgSZSAX79a9i1K3rOmUhaKdBEUm7XLvi3f4OJE+HEE0NXI1I8CjSRlHvoIdi4Eb7zndCViBSXAk0k5e64A046KXremUiaKdBEUmzpUqivj251pevOJO0UaCIpds890YM7dVcQqQQKNJGU2rkTHngApkyBvn1DVyNSfAo0kZT63e+iJ1JffXXoSkRKQ4EmklKPPAJHHAHnnhu6EpHSUKCJpND27dEe2mWXQVVV6GpESkOBJpJCdXXRBdVf/nLoSkRKR4EmkkKPPgpDhsBZZ4WuRKR0FGgiKbNtG8yeDZdfDt30Gy4VRJu7SMrMng179kTd9UUqiQJNJGVmzoR+/eCMM0JXIlJaCjSRFGluhlmzYNIk9W6UyqNAE0mRBQtg61a46KLQlYiUngJNJEVmzoTu3WH8+NCViJSeAk0kRZ56Cv7sz+DQQ0NXIlJ6CjSRlPjjH2HFiuj8mUglUqCJpMS8edH7BReErUMklIIEmplNMLPXzazRzKbnmH+VmW0ysyXx628LsVwR+ci8eTBwIJx8cuhKRMLonu8HmFkVcBtwAdAE1JtZnbu/1qbpo+5+fb7LE5FPcoe5c2HcOD2ZWipXIfbQRgON7r7G3XcDjwAXF+BzRaSTli+HDRvg/PNDVyISTiECbTDwVtZ4UzytrS+Z2atm9piZHZ3rg8xsmpllzCyzadOmApQmUhlaz58p0KSSlapTyFNAjbt/FpgD3J+rkbvPcPdad68dMGBAiUoTSb65c2H4cBg6NHQlIuEUItDWAdl7XEPiafu4+xZ3/zAevRs4rQDLFRGiGxH/7/9G589EKlkhAq0eGG5mw8ysBzAVqMtuYGaDskYnAysKsFwRAV55JXpC9Xnnha5EJKy8ezm6e7OZXQ88DVQB97r7cjP7MZBx9zrg781sMtAMbAWuyne5IhJ5/vno/ZxzwtYhEpq5e+gacqqtrfVMJhO6DJGyd+ml0NAAq1aFrkSk+MxskbvX5pqnO4WIJJg7vPACjBkTuhKR8BRoIgn2xhuwebMCTQQUaCKJ1nr+TIEmokATSbQXXoABA+C440JXIhKeAk0kwVrPn+n+jSIKNJHEWr8eVq/W4UaRVgo0kYR64YXoXYEmElGgiSTUggXQqxecemroSkTKgwJNJKEWLYrCrLo6dCUi5UGBJpJAe/fC4sVQm/N+CSKVSYEmkkBvvAHvvw+n6bkVIvso0EQSqPU2p9pDE/mIAk0kgTIZOPhgOOGE0JWIlA8FmkgCZTIwahRUVYWuRKR8KNBEEqa5GZYs0eFGkbYUaCIJs3Il7NihQBNpS4EmkjCtHULUw1Hk4xRoIgmTycAhh+gO+yJtKdBEEiaTifbOuum3V+Rj9CshkiB79qhDiEh7FGgiCfLaa/Dhhwo0kVwUaCIJojuEiLRPgSaSIJkMHHYYHHts6EpEyo8CTSRBWjuEmIWuRKT8KNBEEuLDD2HpUh1uFGmPAk0kIRoaol6OCjSR3BRoIgmxaFH0rkATyU2BJpIQmQz06QM1NaErESlPCjSRhMhkor0zdQgRyU2BJpIAu3bBsmU63CiyPwo0k
QR49dXoOWgKNJH2KdBEEkB3CBHpmAJNJAEWLYIBA+Doo0NXIlK+FGgiCaA7hIh0TIEmUuZ27IDly3W4UaQjCjSRMrd0Kezdq0AT6UhBAs3MJpjZ62bWaGbTc8zvaWaPxvNfNrOaQixXpBKoQ4hI5+QdaGZWBdwGTAROBL5iZie2afY3wJ/c/TPAL4Bb8l2uSKXIZODII+Goo0JXIlLeCrGHNhpodPc17r4beAS4uE2bi4H74+HHgHFmOr0t0hmLFukOISKdUYhAGwy8lTXeFE/L2cbdm4FtQL+2H2Rm08wsY2aZTZs2FaA0kWR7/31YsSLq4Sgi+1dWnULcfYa717p77YABA0KXIxLckiXQ0qLzZyKdUYhAWwdkX+45JJ6Ws42ZdQcOA7YUYNkiqdbaIUR7aCIdK0Sg1QPDzWyYmfUApgJ1bdrUAVfGw38B/N7dvQDLFkm1TAYGD4ZBg0JXIlL+uuf7Ae7ebGbXA08DVcC97r7czH4MZNy9DrgHeMDMGoGtRKEnIh1ofWSMiHQs70ADcPdZwKw2027MGt4FXFaIZYlUivfegzfegCuuCF2JSDKUVacQEfnIK6+Au86fiXSWAk2kTKlDiEjXKNBEylQmA0OHwsCBoSsRSQYFmkiZUocQka5RoImUoXffhcZGBZpIVyjQRMrQ4sXRuwJNpPMUaCJlqLVDyKhRYesQSRIFmkgZymRg2DDo94lbeItIexRoImWovl6HG0W6SoEmUmY2bYK1a+H000NXIpIsCjSRMtN6/kyBJtI1CjSRMrNwYfR0at0hRKRrFGgiZaa+Hk44AXr3Dl2JSLIo0ETKiHsUaDrcKNJ1CjSRMvLWW7BxI4weHboSkeRRoImUkfr66F17aCJdp0ATKSP19VBdDZ/7XOhKRJJHgSZSRurr4bOfhZ49Q1cikjwKNJEy0dISXYOmw40iB0aBJlImVq2C995ToIkcKAWaSJlQhxCR/CjQRMrEwoVw8MEwYkToSkSSSYEmUiYWLIiuP+vePXQlIsmkQBMpAx98AK+8AmedFboSkeRSoImUgUwG9u5VoInkQ4EmUgZefDF6P+OMsHWIJJkCTaQMLFgAxx8P/fqFrkQkuRRoIoG5R3toOtwokh8Fmkhgq1bBli0KNJF8KdBEAms9f6ZAE8mPAk0ksBdfhMMPj55SLSIHToEmEtiCBXDmmdBNv40iedGvkEhAW7ZAQwOcfXboSkSST4EmEtBzz0Xv554btg6RNFCgiQQ0f350Q+La2tCViCSfAk0koGefjQ439ugRuhKR5Msr0Mysr5nNMbNV8XufdtrtNbMl8asun2WKpMXmzbBsGYwdG7oSkXTIdw9tOjDP3YcD8+LxXHa6+8j4NTnPZYqkQuv5MwWaSGHkG2gXA/fHw/cDl+T5eSIVo/X8mZ5QLVIY+QbaEe6+Ph5+BziinXa9zCxjZi+ZWbuhZ2bT4naZTZs25VmaSHl79lkYMwaqq0NXIpIOHT4b18zmAkfmmHVD9oi7u5l5Ox/zaXdfZ2bHAL83s2XuvrptI3efAcwAqK2tbe+zRBJv48bo+rO//MvQlYikR4eB5u7ntzfPzDaY2SB3X29mg4CN7XzGuvh9jZk9C5wKfCLQRCrFM89E7xdeGLYOkTTJ95BjHXBlPHwl8GTbBmbWx8x6xsP9gbOB1/JcrkiizZ4NAwbAqaeGrkQkPfINtJuBC8xsFXB+PI6Z1ZrZ3XGbEUDGzJYC84Gb3V2BJhWrpQWefhrGj9f9G0UKqcNDjvvj7luAcTmmZ4C/jYdfBE7JZzkiabJ4cXQN2sSJoSsRSRf9fyhSYrNngxlccEHoSkTSRYEmUmKzZ0f3bhwwIHQlIumiQBMpoc2bo+efTZgQuhKR9FGgiZRQXV3UKeTSS0NXIpI+CjSREnr8caipgZEjQ1cikj4KNJES2b4d5syBKVOiTiEiUlgKNJESmTULdu/W4UaRYlGgiZTI44/DEUfAmWeGrkQknRRoIiXw/vswc
yZccglUVYWuRiSdFGgiJfD447BjB3z1q6ErEUkvBZpICTzwABxzDJx1VuhKRNJLgSZSZOvWwbx5cMUV6t0oUkwKNJEie/BBcI8CTUSKR4EmUkQtLXDnnXDOOTB8eOhqRNJNgSZSRM88A2vWwHXXha5EJP0UaCJFdPvt0bVnuphapPgUaCJFsmZNdO3ZNddAjx6hqxFJPwWaSJHceitUV8PXvx66EpHKoEATKYL16+Hee+Gqq2Dw4NDViFQGBZpIEfz859DcDN/7XuhKRCqHAk2kwNatg9tug6lT4dhjQ1cjUjkUaCIFduON0d7ZT34SuhKRyqJAEymgZcvg3/8d/u7vons3ikjpKNBECqSlBa69Fvr0gRtuCF2NSOXpHroAkbS44w74v/+D++6Dvn1DVyNSebSHJlIAf/gDTJ8O558PX/ta6GpEKpMCTSRPu3fDl78M3brBjBl6RIxIKDrkKJKnb38b6uujp1IPGxa6GpHKpT00kTz84hfRNWff+Y5uQCwSmgJN5AD9539Ge2df+hLcckvoakREgSZyAO66K+r8cd55UbBVVYWuSEQUaCJdsHcvfP/7MG0aTJwYPR6mV6/QVYkIKNBEOu2tt2D8ePjpT6MLqJ94Ag46KHRVItJKgSbSgT17oidPn3QSLFgA99wTjeuhnSLlRYEm0o49e+Dhh6Mgu+46OP10aGiAq68OXZmI5KLr0ETaaGyMOnrMmBE9qPPkk+HJJ+Gii3TRtEg5U6BJxXvvPXjpJXjuuSi4Ghqi4JowAe68EyZNUi9GkSTIK9DM7DLgJmAEMNrdM+20mwD8EqgC7nb3m/NZrkhXtbTA5s3w9tuwahWsWBG9li+PXi0t0a2rxoyJLpa+9FL49KdDVy0iXZHvHloDMAW4s70GZlYF3AZcADQB9WZW5+6v5bnsDn3wAeza9dG4+8fna7y8xyHqJr97d8evnTth27Zob+u99z4a3ro1Omz4zjvRQzdbmUWBNWIETJkSBdnnPw+9e3+yBhFJhrwCzd1XANj+TyyMBhrdfU3c9hHgYqDogfbDH8LPflbspUg5+dSn4NBD4bDDovc+feDEE+Goo6LXoEHRgzePPx4OPjh0tSJSSKU4hzYYeCtrvAn4fK6GZjYNmAYwdOjQvBd8ySWfPGzUNns1Xl7jbad16wY9e0Zd5Pf36tUr2rvqrrPCIhWrw19/M5sLHJlj1g3u/mQhi3H3GcAMgNra2hwHoLpmzJjoJSIi6ddhoLn7+XkuYx1wdNb4kHiaiIhIwZTiwup6YLiZDTOzHsBUoK4EyxURkQqSV6CZ2aVm1gScCfzOzJ6Opx9lZrMA3L0ZuB54GlgB/Je7L8+vbBERkY/Lt5fjE8ATOaa/DUzKGp8FzMpnWSIiIvujezmKiEgqKNBERCQVFGgiIpIKCjQREUkFBZqIiKSCAk1ERFJBgSYiIqmgQBMRkVRQoImISCoo0EREJBUUaCIikgoKNBERSQVzz/s5mkVhZpuANwvwUf2BzQX4nFJKWs2qt/iSVnPS6oXk1Vyp9X7a3QfkmlG2gVYoZpZx99rQdXRF0mpWvcWXtJqTVi8kr2bV+0k65CgiIqmgQBMRkVSohECbEbqAA5C0mlVv8SWt5qTVC8mrWfW2kfpzaCIiUhkqYQ9NREQqgAJNRERSIRWBZmaXmdlyM2sxs9o28/7RzBrN7HUzG9/O1w8zs5fjdo+aWY/SVL5v+Y+a2ZL4tdbMlrTTbq2ZLYvbZUpZY5s6bjKzdVk1T2qn3YR4vTea2fRS15lVx61mttLMXjWzJ8zs8HbaBV2/Ha0vM+sZbyuN8fZaU+oa29RztJnNN7PX4t+//5ejzVgz25a1rdwYotasevb7M7bIv8br+FUzGxWizqx6js9ad0vM7D0z+2abNkHXsZnda2Ybzawha1pfM5tjZqvi9z7tfO2VcZtVZnZl3sW4e+JfwAjgeOBZoDZr+onAUqAnMAxYDVTl+Pr/AqbGw3cA1wb8Xn4G3
NjOvLVA/zJY3zcB3+2gTVW8vo8BesQ/hxMD1Xsh0D0evgW4pdzWb2fWF/AN4I54eCrwaODtYBAwKh7uDbyRo+axwMyQdXblZwxMAv4HMOAM4OXQNbfZRt4hurC4bNYx8AVgFNCQNe2nwPR4eHqu3zmgL7Amfu8TD/fJp5ZU7KG5+wp3fz3HrIuBR9z9Q3f/A9AIjM5uYGYGnAc8Fk+6H7ikmPW2J67lcuDhEMsvsNFAo7uvcffdwCNEP4+Sc/dn3L05Hn0JGBKijg50Zn1dTLR9QrS9jou3mSDcfb27L46HtwMrgMGh6imQi4H/8MhLwOFmNih0UbFxwGp3L8QdlArG3Z8DtraZnL2ttvc3dTwwx923uvufgDnAhHxqSUWg7cdg4K2s8SY++QvXD3g36w9erjalcg6wwd1XtTPfgWfMbJGZTSthXblcHx+SubedwwmdWfchXE30H3guIddvZ9bXvjbx9rqNaPsNLj78eSrwco7ZZ5rZUjP7HzM7qaSFfVJHP+Ny3W4h2itv75/dclrHAEe4+/p4+B3giBxtCr6uu+fzxaVkZnOBI3PMusHdnyx1PV3Vyfq/wv73zsa4+zozGwjMMbOV8X9HBbe/eoFfAz8h+uPwE6LDpFcXo47O6sz6NbMbgGbgwXY+pmTrN03M7BDgt8A33f29NrMXEx0iez8+1/rfwPBS15glkT/j+Lz+ZOAfc8wut3X8Me7uZlaS68MSE2jufv4BfNk64Ois8SHxtGxbiA4rdI//683VJm8d1W9m3YEpwGn7+Yx18ftGM3uC6DBVUX4ZO7u+zewuYGaOWZ1Z9wXTifV7FfBFYJzHB/BzfEbJ1m8OnVlfrW2a4u3lMKLtNxgzqyYKswfd/fG287MDzt1nmdntZtbf3YPcVLcTP+OSbrddMBFY7O4b2s4ot3Uc22Bmg9x9fXzIdmOONuuIzv+1GkLUD+KApf2QYx0wNe4dNozov5aF2Q3iP27zgb+IJ10JhNjjOx9Y6e5NuWaa2afMrHfrMFFHh4ZcbYutzTmFS9upox4YblEP0h5Eh0vqSlFfW2Y2AfgeMNndd7TTJvT67cz6qiPaPiHaXn/fXjiXQnz+7h5ghbv/vJ02R7ae5zOz0UR/c4KEcCd/xnXA1+LejmcA27IOnYXU7tGbclrHWbK31fb+pj4NXGhmfeLTFhfG0w5cqJ4xhXwR/VFtAj4ENgBPZ827gaj32OvAxKzps4Cj4uFjiIKuEfgN0DPA93Af8PU2044CZmXVuDR+LSc6lBZqfT8ALANejTfcQW3rjccnEfV8Wx243kaiY/VL4ldrT8GyWr+51hfwY6IgBugVb5+N8fZ6TKh1Gtczhuiw86tZ63YS8PXWbRm4Pl6fS4k65JwVsN6cP+M29RpwW/wzWEZWr+mAdX+KKKAOy5pWNuuYKGjXA3viv8N/Q3Rudx6wCpgL9I3b1gJ3Z33t1fH23Aj8db616NZXIiKSCmk/5CgiIhVCgSYiIqmgQBMRkVRQoImISCoo0EREJBUUaCIikgoKNBERSYX/D4yQGwcp2LC4AAAAAElFTkSuQmCC\n" 566 | }, 567 | "metadata": { 568 | "needs_background": "light" 569 | } 570 | } 571 | ], 572 | "source": [ 573 | "plt.figure(figsize=(7, 4))\n", 574 | "plt.plot(x_1, y_tanh, c='blue', label='tanh')\n", 575 | "plt.ylim((-1.2, 1.2))\n", 576 | "plt.legend(loc='best')" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 22, 582 | "metadata": { 583 | "colab": { 584 | "base_uri": "https://localhost:8080/", 
585 | "height": 283 586 | }, 587 | "id": "eTRgAPER61GD", 588 | "outputId": "65cca173-5e2e-41be-aa69-134d561b3b60" 589 | }, 590 | "outputs": [ 591 | { 592 | "output_type": "execute_result", 593 | "data": { 594 | "text/plain": [ 595 | "" 596 | ] 597 | }, 598 | "metadata": {}, 599 | "execution_count": 22 600 | }, 601 | { 602 | "output_type": "display_data", 603 | "data": { 604 | "text/plain": [ 605 | "
" 606 | ], 607 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAD4CAYAAAC5S3KDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAf70lEQVR4nO3de5yWc/7H8deng05SqZaiVH5Cqi07WGLJKR2UCisS4tcPG7W7ktjsPJyPqyw6bvippbZCOo9SyiNRSSFW6Og05Vc6bMf5/v743jGNmZru0/c+vJ+Pxzzmnvu+Zq73XHPPvOe67u/1vcw5h4iISKopEzqAiIhIcVRQIiKSklRQIiKSklRQIiKSklRQIiKSksolc2W1atVyDRo0SOYqRUQkhS1evHiDc652cY8ltaAaNGjAokWLkrlKERFJYWa2uqTHdIhPRERSkgpKRERSkgpKRERSUlJfgyrO7t27WbduHTt27AgdJaVVrFiRY489lvLly4eOIiKSFMELat26dVStWpUGDRpgZqHjpCTnHBs3bmTdunU0bNgwdBwRkaQIfohvx44d1KxZU+V0AGZGzZo1tZcpIlkleEEBKqdS0DYSkWyTEgUlIiJSlArqEMybN49TTjmFFi1asGDBAqZOnXrQz1m1ahVNmzZNQjoRkcyigjoEY8aMYcCAASxdupTPPvusVAUlIiLRyfqC2rZtG+3bt+fXv/41TZs2ZezYscyaNYuWLVvSrFkzevbsyc6dOxk5ciTjxo1j4MCBdOvWjXvvvZexY8fSokULxo4dS25uLtdeey1nnnkmJ5xwAiNGjPjFul544QV69+7908cdOnRgzpw57N27l+uvv56mTZvSrFkznnrqqWRuAhGRlBR8mHlhffvC0qXx/ZotWsCgQSU/Pn36dOrWrcuUKVMA2Lx5M02bNmXWrFk0btyYHj16MGTIEPr27cv8+fPp0KEDl19+OS+88AKLFi3imWeeASA3N5dly5bx7rvvsm3bNlq2bEn79u1LlXHp0qWsX7+ejz76CIBNmzbF9k2LiGSArN+DatasGXl5efTv35958+axatUqGjZsSOPGjQG47rrrePvtt0v1tTp16kSlSpWoVasWrVu35r333ivV5zVq1Igvv/yS2267jenTp3PEEUdE/f2IiGSKlNqDOtCeTqI0btyYJUuWMHXqVP7yl79w/vnnR/21ig4FL/pxuXLlKCgo+Onjfec11ahRgw8//JAZM2YwdOhQxo0bx6hRo6LOISKSCbJ+D+rrr7+mcuXKdO/enX79+rFgwQJWrVrFypUrAXjppZc499xzf/F5VatWZcuWLfvd9/rrr7Njxw42btzInDlzOO200/Z7vEGDBixdupSCggLWrl370x7Whg0bKCgooGvXrjzwwAMsWbIkQd+tiEj6OOgelJmNAjoA3zvnmkbuOxIYCzQAVgFXOuf+L3ExE2f58uX069ePMmXKUL58eYYMGcLmzZu54oor2LNnD6eddho333zzLz6vdevWPPLII7Ro0YIBAwYA0Lx5c1q3bs2GDRsYOHAgdevWZdWqVT99TqtWrWjYsCFNmjTh5JNP5tRTTwVg/fr13HDDDT/tXT388MOJ/8ZFRFKcOecOvIDZ74CtwP8WKqjHgB+cc4+Y2V1ADedc/4OtLCcnxxW9YOGKFSs4+eSTo82fMnJzczn88MO54447EraOTNlWIpL+8vOhdrHXwT00ZrbYOZdT3GMHPcTnnHsb+KHI3Z2AFyO3XwQuiymhiIikjY8/hhNPhCFDErueaAdJHOWc+yZy+1vgqJIWNLNeQC+A+vXrR7m61Jebmxs6gohIwq1eDRdfDBUrQtu2iV1XzIMknD9GWOJxQufccOdcjnMup3YJ+4MHO8wo2kYiEl5+vi+n7dthxgxo0CCx64u2oL4zszoAkfffRxugYsWKbNy4UX+AD2Df9
aAqVqwYOoqIZKktW/we05o1MHkyNGuW+HVGe4hvEnAd8Ejk/evRBjj22GNZt24d+fn50X6JrLDviroiIsm2cyd07uxn+nntNWjVKjnrLc0w85eB84BaZrYO+Cu+mMaZ2Y3AauDKaAOUL19eV4kVEUlRe/dC9+4waxa8+CJ06JC8dR+0oJxz3Up46II4ZxERkRTiHPTuDePHwxNPQI8eyV1/1s8kISIixcvNhaFDoX9/+POfk79+FZSIiPzCM8/AffdBz54QanIbFZSIiOzn5Zfh9tuhUycYNgyKzHudNCooERH5yYwZ/rWmc87xRVUu4DUvVFAiIgLAwoXQtSuccgpMmgSVKoXNo4ISERFWrID27eGoo2D6dKhWLXQiFZSISNZbu9ZPYVSuHMycCUcfHTqRl1JX1BURkeTauNGX048/wty5cPzxoRP9TAUlIpKltm71h/W++soPjmjRInSi/amgRESy0K5dcPnl8P77MGECnHtu6ES/pIISEckyBQVw/fV+r2nkSLgsRS85q0ESIiJZxDno08ef4/TII3DjjaETlUwFJSKSRR580E9j9Kc/wZ13hk5zYCooEZEsMXQoDBwI114Ljz8ebgqj0lJBiYhkgfHj4dZb/ai9f/wDyqTBX/80iCgiIrGYNQuuuQbOPBPGjYPy5UMnKh0VlIhIBlu82I/Sa9wYJk+GypVDJyo9FZSISIb697+hbVuoWdPPr1ejRuhEh0YFJSKSgdav91MYgZ9f75hjwuaJhk7UFRHJMD/8AG3a+Hn25szxh/fSkQpKRCSDbN8Ol14Kn38O06bBb34TOlH0VFAiIhli92644gpYsMCP1jv//NCJYqOCEhHJAAUF0LMnTJ3qT8i9/PLQiWKnQRIiImnOObjjDhg9Gu6/H/7nf0Inig8VlIhImnv0UXjqKbjtNrjnntBp4kcFJSKSxkaOhAEDoFs3GDQo9efXOxQqKBGRNPXaa/5wXps28MIL6TG/3qHIsG9HRCQ7zJ0LV10Fp53mr4h72GGhE8WfCkpEJM188AF07AiNGsGUKVClSuhEiRFTQZnZH83sYzP7yMxeNrOK8QomIiK/tHIlXHIJVKvmL9les2boRIkTdUGZ2THA7UCOc64pUBa4Kl7BRERkf998419v2rvXz69Xr17oRIkV64m65YBKZrYbqAx8HXskEREpatMmPzP5d9/B7Nlw0kmhEyVe1HtQzrn1wBPAGuAbYLNzbma8gomIiPef//jXnD75BCZOhNNPD50oOWI5xFcD6AQ0BOoCVcysezHL9TKzRWa2KD8/P/qkIiJZaM8eP1pv/nx46aWfL6GRDWIZJHEh8JVzLt85txuYCJxVdCHn3HDnXI5zLqd27doxrE5EJLs4B716waRJ8PTT8Pvfh06UXLEU1Brgt2ZW2cwMuABYEZ9YIiJy113w/PNw773Qu3foNMkXy2tQC4HxwBJgeeRrDY9TLhGRrPbEE/DYY3DLLZCbGzpNGDGN4nPO/RX4a5yyiIgI8OKL0K+fv7bT3/+eWfPrHQrNJCEikkImT4Ybb4QLLvCDIsqWDZ0oHBWUiEiKmD/f7zW1bAmvvgoVKoROFJYKSkQkBSxbBh06QP36/qq4VauGThSeCkpEJLCvvvLz61Wp4qcw0hk5XqxTHYmISAy+/96ffLtjB8ybB8cdFzpR6lBBiYgE8uOPfs9p/Xp480045ZTQiVKLCkpEJIAdO+Cyy2D5cj9TxFm/mIdHVFAiIkm2dy9ccw289ZYfSt62behEqUmDJEREksg5PzvExInw1FPQ/RdTbMs+KigRkSQaOBBGjIC774a+fUOnSW0qKBGRJBk8GB58EG66CR54IHSa1KeCEhFJgjFj/B5T584wZEj2zq93KFRQIiIJNm0aXH89nHce/POfUE7D00pFBSUikkALFkDXrtCsGbz+OlSsGDpR+lBBiYgkyMcfQ/v2ULeu34s64ojQidKLCkpEJAHWrIE2bfyM5DNnwlFHhU6Uf
nQkVEQkzvLz/fx6W7fC229Do0ahE6UnFZSISBxt2QLt2sHq1X7PqXnz0InSlwpKRCROdu6ELl3ggw/8BQfPOSd0ovSmghIRiYO9e6FHDz8r+fPPw6WXhk6U/jRIQkQkRs7B7bfDuHHw+OP+nCeJnQpKRCRG990Hzz0H/frBHXeETpM5VFAiIjF47jnIzfV7TY8+GjpNZlFBiYhEadw46N3bv940YoTm14s3FZSISBTy8vy1nFq1grFjNb9eIqigREQO0fvv+1nJTzoJ3ngDKlUKnSgzqaBERA7Bp5/6S7TXrg0zZkD16qETZS4VlIhIKa1b56cwKlvWH+KrUyd0osymo6YiIqWwcaMvp02bYO5c+K//Cp0o86mgREQOYts26NABvvjCH9Zr2TJ0ouwQ0yE+M6tuZuPN7FMzW2FmZ8YrmIhIKti1Cy6/HN57D155xV8VV5Ij1j2owcB059zlZnYYUDkOmUREUkJBAdxwA0yf7s9z6tw5dKLsEnVBmVk14HfA9QDOuV3ArvjEEhEJyzn44x/hn/+Ehx6Cm24KnSj7xHKIryGQDzxvZh+Y2Ugzq1J0ITPrZWaLzGxRfn5+DKsTEUmehx+Gp5+Gvn3hrrtCp8lOsRRUOeBUYIhzriWwDfjFj9E5N9w5l+Ocy6ldu3YMqxMRSY7hw+Gee/xMEU8+qSmMQomloNYB65xzCyMfj8cXlohI2powAW65xZ+MO2oUlNHZosFEvemdc98Ca83sxMhdFwCfxCWViEgAs2fD1VfDGWfAv/4F5cuHTpTdYh3FdxswJjKC70vghtgjiYgk36JF0KkTnHACTJ4MVX7xirokW0wF5ZxbCuTEKYuISBD75terVcufiHvkkaETCWguPhHJcmvX+imMypSBmTPhmGNCJ5J9NNWRiGStDRt8OW3eDHPm+MN7kjpUUCKSlbZuhfbt4auvNL9eqlJBiUjW2bkTunSBxYth4kQ499zQiaQ4KigRySp798K11/rrOT3/PHTsGDqRlESDJEQkazgHf/iDP8fpiSfg+utDJ5IDUUGJSNa4914YNgz694c//zl0GjkYFZSIZIXBg+GBB+DGG/1EsJL6VFAikvFGj/azknfpAkOHavLXdKGCEpGMNmWKf62pdWsYMwbKaWhY2lBBiUjGmj/fX669RQt47TWoWDF0IjkUKigRyUjLlkGHDlC/PkybBkccETqRHCoVlIhknC+/hDZt4PDD/flOulZqetLRWBHJKN9+CxddBLt2wbx5fg9K0pMKSkQyxqZNfs/pu+9g1ixo0iR0IomFCkpEMsL27XDppbBihb/g4BlnhE4ksVJBiUja270bfv97eOcdeOUVfwkNSX8qKBFJawUF0LOn32saMgSuvDJ0IokXjeITkbTlnJ9Tb/RouP9+uPnm0IkknlRQIpK2Hn4YBg2CPn3gnntCp5F4U0GJSFoaNsyXUvfu8Le/aX69TKSCEpG0M3483HKLv2T7qFFQRn/JMpJ+rCKSVvLy4Oqr4ayzYNw4KF8+dCJJFBWUiKSN996Dzp3hpJPgjTegcuXQiSSRVFAikhZWrIB27eBXv4IZM6BGjdCJJNFUUCKS8lav9ifflivnD/HVqRM6kSSDTtQVkZT23Xdw4YWwdSvMnQvHHx86kSSLCkpEUta+yV+//trvOTVvHjqRJJMKSkRS0rZtfhj5J5/4ARFnnRU6kSSbCkpEUs6uXf5S7e++C2PH+r0oyT4xD5Iws7Jm9oGZTY5HIBHJbnv3+tkhpk+H4cN9UUl2iscovj7Aijh8HRHJcs75CV//9S944gm48cbQiSSkmArKzI4F2gMj4xNHRLLZXXfByJFw991+lnLJbrHuQQ0C7gQKSlrAzHqZ2SIzW5Sfnx/j6kQkUz3yCDz2GNx6KzzwQOg0kgqiLigz6wB875xbfKDlnHPDnXM5zrmc2rVrR7s6EclgQ4fCgAF+jr2//10zk4sXyx5UK6Cjma0CXgHON7PRcUklIlnj5Zf9X
lP79vDCC5qZXH4W9VPBOTfAOXesc64BcBUw2znXPW7JRCTjTZ0KPXrAOef4gRGamVwK0/8qIhLEvHnQtaufHeKNN6BSpdCJJNXE5URd59wcYE48vpaIZL4lS6BDB2jQwJ/vdMQRoRNJKtIelIgk1WefwSWXQPXqMHMmaOyUlEQFJSJJs2YNXHSRH6WXlwf16oVOJKlMc/GJSFJ8/70vpx9/hDlzoHHj0Ikk1amgRCThfvjBl9Patf6wXosWoRNJOlBBiUhCbdkCbdvCp5/C5Mlw9tmhE0m6UEGJSMJs3+5H6y1eDBMm+L0okdJSQYlIQuzc6c9zmjcPxoyBTp1CJ5J0o4ISkbjbs8fPqzd9up+dvFu30IkkHWmYuYjEVUEB3HADTJwIgwbpmk4SPRWUiMSNc37i19Gj/SUz+vQJnUjSmQpKROLCOejXD4YN8xcevPvu0Ikk3amgRCQu7rsPnnwSeveGhx7SNZ0kdiooEYnZE09Abq5/7WnwYJWTxIcKSkRiMmSIP7R35ZUwYoQuOCjxo6eSiETtxRf9oIgOHeCll6Bs2dCJJJOooEQkKmPG+EN6F17or4Z72GGhE0mmUUGJyCEbO9Zfqv288+D116FixdCJJBOpoETkkEyYANdcA61a+Uu1V64cOpFkKhWUiJTa66/DVVfBGWfAlClQpUroRJLJVFAiUipTpsAVV8Cpp8K0aVC1auhEkulUUCJyUDNmQJcu0Ly5v33EEaETSTZQQYnIAc2aBZddBk2a+KvhVq8eOpFkCxWUiJRozhy49FI44QTIy4MjjwydSLKJCkpEijV3rj8Bt2FDePNNqFUrdCLJNiooEfmFWbOgbVuoX9/f/tWvQieSbKSCEpH9zJzp95yOP94f4jv66NCJJFupoETkJ9OmQceO0LgxzJ6tPScJSwUlIoCfFWLfaL3Zs6F27dCJJNupoESE116Drl39eU6zZkHNmqETiaigRLLe+PE/zxCRlwc1aoROJOJFXVBmVs/M3jKzT8zsYzPrE89gIpJ4L7/s59Y7/XSdhCupJ5Y9qD3An51zTYDfAn8wsybxiSUiiTZsmJ+V/OyzYfp0TV8kqSfqgnLOfeOcWxK5vQVYARwTr2AikjiPPQY33wzt2mniV0ldcXkNyswaAC2BhcU81svMFpnZovz8/HisTkSi5BzcfTf07+8P7b36KlSqFDqVSPFiLigzOxyYAPR1zv1Y9HHn3HDnXI5zLqe2xq2KBFNQAL17w8MPQ69eMHo0lC8fOpVIyWIqKDMrjy+nMc65ifGJJCLxtns3XHcdPPcc9OsHQ4dC2bKhU4kcWLloP9HMDPgHsMI597f4RRKReNq2zR/OmzwZHnwQBgwAs9CpRA4u6oICWgHXAsvNbGnkvrudc1NjjyUi8ZCf7y+X8d578OyzcOutoROJlF7UBeWcmw/o/zCRFPXll3DJJbB2LUyYAJ07h04kcmhi2YMSkRS1eLEfQr5nj5+66KyzQicSOXSa6kgkw8yYAeee64ePv/OOyknSlwpKJIMMGwbt2/tLtC9YACedFDqRSPRUUCIZYM8e6NPHzw7Rpo2/XHudOqFTicRGBSWS5jZv9iP1nn4a/vhHmDRJ8+pJZtAgCZE09sUXvpw+/xyGD4f//u/QiUTiRwUlkqby8qBbNz+F0cyZ0Lp16EQi8aVDfCJppqAAHnrIv9ZUpw4sXKhyksykPSiRNLJpE/ToAW+8AVdf7Q/rVakSOpVIYqigRNLEsmXQpQusXu0HRPTurTn1JLPpEJ9IinMOnnnGX5Z9+3aYMwduu03lJJlPBSWSwvLzoWNHX0jnnw8ffACtWoVOJZIcKiiRFJWXB82b+xF6gwfDlClw1FGhU4kkjwpKJMVs2eJfX7r4YjjySHj/fbj9dh3Sk+yjghJJIXl50KyZv/Jtnz6+nJo3D51KJAwVlEgK2LQJbrrJ7zVVqADz5sGgQVC5cuhkIuGooEQCKiiA55+Hxo39+/79YelSDYQQA
Z0HJRLMokX+taaFC+HMM2H6dDj11NCpRFKH9qBEkmzNGujZ05/XtGoVvPgizJ+vchIpSntQIkmyYQM8/DA8+6w/+fZPf4KBA6FatdDJRFKTCkokwfLz/dREgwfDtm1w3XWQmwv164dOJpLaVFAiCbJmDTz5JIwYAf/5j59H7/77oUmT0MlE0oMKSiSOnIN33oEhQ2DcOH9f9+5w551w8slhs4mkGxWUSBz8+COMGeOLaflyf8n13r39Jdh1KE8kOiookSjt3AnTpvlimjwZduyAli39Ib1u3XSdJpFYqaBEDsHWrX46ojfegIkTYfNmqF3bzwJx7bVw2mmaM08kXlRQIgfgHHzyCcye7WcTf+st2LXLDw3v2BGuuQYuuADK6TdJJO70ayVSyO7d8PHHfi68uXP924YN/rHGjf3rSpde6qciKl8+bFaRTKeCkqy1ZQt8+qmf+27JEli82F9WfedO//hxx0G7dnDuuf7t+OPD5hXJNiooyVjOwQ8/wNq1/pykr77yhfTZZ/7t669/XrZaNT/V0G23wW9+4+fGO+64cNlFJMaCMrNLgMFAWWCkc+6RuKQSKYFzfkj3hg1+hob8/P1v5+fD+vW+lNauhe3b9//86tXhpJPgoov8+xNP9NdbatRIgxtEUk3UBWVmZYFngYuAdcD7ZjbJOfdJvMJJWM7B3r3Fv+3ZU/JjhZfZudO/7dhRutvbt/sC2rJl/7d9923d6nMVp0IFP6Kubl1o2tQfnqtXz5+HVK8eNGjgH1cRiaSHWPagTgdWOue+BDCzV4BOQMIK6vPPoVMnf7vwH6lE3k7WelIhS0HB/gVTUhEkQpkyvmCqVIGqVX9+q1kTGjbc/75q1XzR1Kq1//sqVVQ+IpkkloI6Blhb6ON1wBlFFzKzXkAvgPoxnlJfsaL/z/jnr52c28laT+gsZlC2bOneypUr3TIVKvifW4UK+98uep+GaYtIUQn/s+CcGw4MB8jJyYnpf/J69X6e30xERDJbLBcsXA/UK/TxsZH7REREYhZLQb0PnGBmDc3sMOAqYFJ8YomISLYzF8Mr4WbWDhiEH2Y+yjn34EGWzwdWR73Cn9UCNsTh6ySL8iZeumVOt7yQfpmVN/Hikfk451zt4h6IqaBCMbNFzrmc0DlKS3kTL90yp1teSL/Mypt4ic4cyyE+ERGRhFFBiYhISkrXghoeOsAhUt7ES7fM6ZYX0i+z8iZeQjOn5WtQIiKS+dJ1D0pERDKcCkpERFJSShaUmV1hZh+bWYGZ5RR5bICZrTSzz8ysTQmf39DMFkaWGxs5kThpIutcGnlbZWZLS1hulZktjyy3KJkZi+TINbP1hTK3K2G5SyLbfaWZ3ZXsnEWyPG5mn5rZMjN71cyql7Bc0G18sG1mZhUiz5eVkedsg2RnLJSlnpm9ZWafRH7/+hSzzHlmtrnQc+XeEFmLZDrgz9i8pyPbeJmZnRoiZyTLiYW23VIz+9HM+hZZJvg2NrNRZva9mX1U6L4jzSzPzD6PvK9RwudeF1nmczO7LqYgzrmUewNOBk4E5gA5he5vAnwIVAAaAl8AZYv5/HHAVZHbQ4FbAn4vTwL3lvDYKqBWCmzvXOCOgyxTNrK9GwGHRX4OTQJmvhgoF7n9KPBoqm3j0mwz4FZgaOT2VcDYgNu0DnBq5HZV4N/F5D0PmBwqYzQ/Y6AdMA0w4LfAwtCZCz0/vsWfqJpS2xj4HXAq8FGh+x4D7orcvqu43zngSODLyPsakds1os2RkntQzrkVzrnPinmoE/CKc26nc+4rYCX+sh8/MTMDzgfGR+56EbgskXlLEslyJfByiPXH2U+XV3HO7QL2XV4lCOfcTOfcnsiH7+Lngkw1pdlmnfDPUfDP2Qsiz5ukc85945xbErm9BViBv2pBuusE/K/z3gWqm1md0KGAC4AvnHPxmF0nrpxzbwM/FLm78HO1pL+rbYA859wPzrn/A/KAS6LNkZIFdQDFXeKj6C9QTWBToT9ex
S2TLOcA3znnPi/hcQfMNLPFkcuShNQ7cvhjVAm77qXZ9qH0xP+HXJyQ27g02+ynZSLP2c3453BQkUONLYGFxTx8ppl9aGbTzOyUpAYr3sF+xqn63L2Kkv95TbVtDHCUc+6byO1vgaOKWSau2zrYVXjM7E3g6GIeusc593qy8xyqUubvxoH3ns52zq03s18BeWb2aeQ/l7g7UF5gCHA//hf9fvxhyZ6JyHEoSrONzeweYA8wpoQvk7RtnCnM7HBgAtDXOfdjkYeX4A9JbY28VvkacEKyMxaRdj/jyOviHYEBxTycitt4P845Z2YJP0cpWEE55y6M4tNKc4mPjfhd+HKR/0gTchmQg+U3s3JAF+A3B/ga6yPvvzezV/GHhBLyi1Xa7W1mI4DJxTyU9MurlGIbXw90AC5wkQPgxXyNpG3jYpRmm+1bZl3kOVMN/xwOwszK48tpjHNuYtHHCxeWc26qmT1nZrWcc8EmOS3FzzgVLw3UFljinPuu6AOpuI0jvjOzOs65byKHSL8vZpn1+NfQ9jkWP5YgKul2iG8ScFVk5FND/H8V7xVeIPKH6i3g8shd1wEh9sguBD51zq0r7kEzq2JmVffdxr/o/1FxyyZakePxnUvIkVKXVzGzS4A7gY7Oue0lLBN6G5dmm03CP0fBP2dnl1S2iRZ57esfwArn3N9KWObofa+Rmdnp+L8hIQu1ND/jSUCPyGi+3wKbCx2qCqXEoyupto0LKfxcLenv6gzgYjOrEXmp4OLIfdEJOVKkpDf8H8l1wE7gO2BGocfuwY+M+gxoW+j+qUDdyO1G+OJaCfwLqBDge3gBuLnIfXWBqYUyfhh5+xh/2CrU9n4JWA4sizwJ6xTNG/m4HX5k1xch80ayrMQf614aeds3Ei6ltnFx2wy4D1+sABUjz9GVkedso4Db9Gz8Yd5lhbZrO+Dmfc9loHdkW36IH5xyVuDnQbE/4yKZDXg28jNYTqGRwYEyV8EXTrVC96XUNsaX5zfA7sjf4hvxr43OAj4H3gSOjCybA4ws9Lk9I8/nlcANseTQVEciIpKS0u0Qn4iIZAkVlIiIpCQVlIiIpCQVlIiIpCQVlIiIpCQVlIiIpCQVlIiIpKT/B6UAYpYmuXq+AAAAAElFTkSuQmCC\n" 608 | }, 609 | "metadata": { 610 | "needs_background": "light" 611 | } 612 | } 613 | ], 614 | "source": [ 615 | "plt.figure(figsize=(7, 4))\n", 616 | "plt.plot(x_1, y_softplus, c='blue', label='softplus')\n", 617 | "plt.ylim((-0.2, 11))\n", 618 | "plt.legend(loc='best')" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": 23, 624 | "metadata": { 625 | "collapsed": true, 626 | "id": "aP1odC6d61GD" 627 | }, 628 | "outputs": [], 629 | "source": [ 630 | "def prep_data():\n", 631 | " train_X = np.asarray([13.3,14.4,15.5,16.71,16.93,14.168,19.779,16.182,\n", 632 | " 17.59,12.167,17.042,10.791,15.313,17.997,15.654,\n", 633 | " 19.27,13.1])\n", 634 | " train_Y = np.asarray([11.7,12.76,12.09,13.19,11.694,11.573,13.366,12.596,\n", 635 | " 12.53,11.221,12.827,13.465,11.65,12.904,12.42,12.94,\n", 636 | " 11.3])\n", 637 | " 
    dtype = torch.FloatTensor\n", 638 | "    X = torch.from_numpy(train_X).type(dtype)\n", 639 | "    X = X.view(-1, 1)  # shape (17, 1): 17 samples, 1 feature\n", 640 | "    y = torch.from_numpy(train_Y).type(dtype)  # shape (17,)\n", 641 | "    return X,y" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 24, 647 | "metadata": { 648 | "collapsed": true, 649 | "id": "DVryHmSW61GD" 650 | }, 651 | "outputs": [], 652 | "source": [ 653 | "# initialize the learnable parameters w and b" 654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": 25, 659 | "metadata": { 660 | "collapsed": true, 661 | "id": "mkoMb2bB61GD" 662 | }, 663 | "outputs": [], 664 | "source": [ 665 | "def set_weights():\n", 666 | "    w = torch.randn(1, requires_grad=True)\n", 667 | "    b = torch.randn(1, requires_grad=True)\n", 668 | "    return w,b" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": 26, 674 | "metadata": { 675 | "collapsed": true, 676 | "id": "UGuKavsI61GE" 677 | }, 678 | "outputs": [], 679 | "source": [ 680 | "# define the model: y_pred = x*w + b" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 27, 686 | "metadata": { 687 | "collapsed": true, 688 | "id": "wxNs7d4f61GE" 689 | }, 690 | "outputs": [], 691 | "source": [ 692 | "def build_network(x):\n", 693 | "    y_pred = torch.matmul(x,w)+b\n", 694 | "    return y_pred" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": 28, 700 | "metadata": { 701 | "colab": { 702 | "base_uri": "https://localhost:8080/" 703 | }, 704 | "id": "PHBqbn9D61GE", 705 | "outputId": "49da70e2-d18d-4c78-becf-1e6a8d311686" 706 | }, 707 | "outputs": [ 708 | { 709 | "output_type": "execute_result", 710 | "data": { 711 | "text/plain": [ 712 | "Linear(in_features=1, out_features=1, bias=True)" 713 | ] 714 | }, 715 | "metadata": {}, 716 | "execution_count": 28 717 | } 718 | ], 719 | "source": [ 720 | "# the same model using PyTorch's built-in linear layer\n", 721 | "import torch.nn as nn\n", 722 | "f = nn.Linear(1, 1)  # 1 input feature -> 1 output; computes x*W^T + b\n", 723 | "f" 724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": 29, 729 | "metadata": { 730 | "collapsed": true, 731 | "id": "wxh4e4b-61GF" 732 | }, 733 | "outputs": [], 734 | "source": [ 735 | "# compute the sum-of-squared-errors loss and backpropagate" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": 35, 741 | "metadata": { 742 | "collapsed": true, 743 | "id": "Nh4jA4H561GF" 744 | }, 745 | "outputs": [], 746 | "source": [ 747 | "def loss_calc(y,y_pred):\n", 748 | "    loss = (y_pred-y).pow(2).sum()\n", 749 | "    for param in [w,b]:\n", 750 | "        if param.grad is not None: param.grad.zero_()  # clear gradients from the previous step\n", 751 | "    loss.backward()\n", 752 | "    return loss.data" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": 36, 758 | "metadata": { 759 | "collapsed": true, 760 | "id": "JrM-2sCf61GF" 761 | }, 762 | "outputs": [], 763 | "source": [ 764 | "# gradient-descent update of w and b" 765 | ] 766 | }, 767 | { 768 | "cell_type": "code", 769 | "execution_count": 37, 770 | "metadata": { 771 | "collapsed": true, 772 | "id": "lMRYtQ_261GF" 773 | }, 774 | "outputs": [], 775 | "source": [ 776 | "def optimize(learning_rate):\n", 777 | "    w.data -= learning_rate * w.grad.data\n", 778 | "    b.data -= learning_rate * b.grad.data" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": 38, 784 | "metadata": { 785 | "collapsed": true, 786 | "id": "FUXYASCD61GG" 787 | }, 788 | "outputs": [], 789 | "source": [ 790 | "learning_rate = 1e-4" 791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": 39, 796 | "metadata": { 797 | "colab": { 798 | "base_uri": "https://localhost:8080/" 799 | }, 800 | "id": "N6o2Vuhy61GG", 801 | "outputId": "67e21a5b-9845-4d52-c70e-709460962e22" 802 | }, 803 | "outputs": [ 804 | { 805 | "output_type": "stream", 806 | "name": "stdout", 807 | "text": [ 808 | "tensor(5954.0488)\n", 809 | "tensor(44.9320)\n", 810 | "tensor(39.5382)\n", 811 | "tensor(34.9094)\n", 812 |
"tensor(30.9371)\n" 813 | ] 814 | } 815 | ], 816 | "source": [ 817 | "x,y = prep_data()  # x: training inputs, y: target values\n", 818 | "w,b = set_weights()  # w, b: learnable parameters\n", 819 | "for i in range(5000):\n", 820 | "    y_pred = build_network(x)  # computes x*w + b\n", 821 | "    loss = loss_calc(y,y_pred)  # squared-error loss (also runs backward)\n", 822 | "    if i % 1000 == 0:\n", 823 | "        print(loss)\n", 824 | "    optimize(learning_rate)  # gradient-descent step on w and b" 825 | ] 826 | }, 827 | { 828 | "cell_type": "code", 829 | "execution_count": 40, 830 | "metadata": { 831 | "collapsed": true, 832 | "id": "sTb-AV9I61GG" 833 | }, 834 | "outputs": [], 835 | "source": [ 836 | "import matplotlib.pyplot as plt\n", 837 | "%matplotlib inline" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": 41, 843 | "metadata": { 844 | "colab": { 845 | "base_uri": "https://localhost:8080/", 846 | "height": 283 847 | }, 848 | "id": "_wK7dOLB61GG", 849 | "outputId": "e5430fcf-6f9e-437f-bcf7-cb8ae7ae4024" 850 | }, 851 | "outputs": [ 852 | { 853 | "output_type": "execute_result", 854 | "data": { 855 | "text/plain": [ 856 | "[]" 857 | ] 858 | }, 859 | "metadata": {}, 860 | "execution_count": 41 861 | }, 862 | { 863 | "output_type": "display_data", 864 | "data": { 865 | "text/plain": [ 866 | "
" 867 | ], 868 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAZQ0lEQVR4nO3deXxU1d3H8c8PpBC1JSKgoMRYRSxuoLG1ahWXCo8bSK2tu5WnPFrrVkXBBeJWqbhUrbZFQcRaXlpBsNAWFxTUurGooKgoAhJEVAQ3kCXn+WOCkzuZYZLJ3Dn3znzf/8A5M8n8XpPw5Tfn3nuuOecQEZH4aeG7ABERyY0CXEQkphTgIiIxpQAXEYkpBbiISExtUcgXa9++vausrCzkS4qIxN6sWbM+cc51SJ0vaIBXVlYyc+bMQr6kiEjsmdnidPNaQhERiSkFuIhITCnARURiSgEuIhJTCnARkZhSgIuIxJQCXEQkphTgIiJh+uRdmD4CNq7P+7cu6IU8IiIlwzn4x5nw5qTEeJ9fQHlFXl9CAS4ikm/L5sDIXsnxCSPzHt6gABcRyZ/aWrivD3zwUmK8VUe4eB5s0TqUl1OAi4jkw8JnYGzf5PjU8dD1yFBfUgEuItIcG9fDHfvC6iWJ8fZ7w8BnoEXL0F9aAS4ikqs3HoV/nJUcD3gSuuxfsJdXgIuINNW6r2D4TlBbd2pg195wykNgVtAyFOAiIk3xyr0w5ZLk+DcvQcfdvZSiABcRaYyvV8JNOyfH+50Fx93urRxQgIuIZPfMcHjmxuT4onlQ3sVfPXUU4CIimdTMhnsOS44PvRwOu8JfPSkU4CIi6VS3DY4HLYSttvVTSwbazEpEpL75/wyGd4fdoXp15MIb1IGLiCQ4B9eUB+cueQe+u52fehpBHbiIyEt/DYb37scmuu4IhzeoAxeRUrZxPVzXPjg3pAZab+2nniZSBy4ipWnqlcHwPvCCRNcdk/AGdeAiUmq++RJu3CE4d/Wn0DJ+cRi/ikVEcjXuFHh7SnJ89M3ww1/7q6eZFOAiUvy++Ahu2S04N2xVwTefyjcFuIgUtzt6wsqFyfEvx8HuR/urJ4+yHsQ0s9FmtsLM5qV57BIzc2bWPt3Xioh48/HbiQty6od39eqiCW9oXAc+BvgTMLb+pJl1AY4CluS/LBGRZki9DP5/n4Idq/zUEqKsHbhzbgawMs1DtwGXAS7fRYmI5GTxf4Ph3aJVousuwvCGHNfAzawvUOOce82yHAQws4HAQICKiopcXk5EJLvUrvv82bDtLn5qKZAmX8hjZlsCVwBDG/N859xI51yVc66qQ4cOTX05EZHNmzchGN6d9kl03UUe3pBbB74LsDOwqfveEZhtZj90zi3PZ3EiIhml23wqglu+hqnJAe6cmwt03DQ2s0VAlXPukzzWJSKS2X/vhMevSo73PBFOHOWvHk+yBriZjQN6Ae3NbCkwzDlXeu+UiPi3YR1cn7IUe8WH8J0t/dTjWdYAd86dnOXxyrxVIyKSyZRL4ZV7kuNDBsHhV2V+fgnQlZgiEm1rP4fhKTcQHroSWrT0U0+EKMBFJLoeOAHem5YcH38n7HuGv3oiRgEuItGzugZu6x6cK4LNp/JNAS4i0XLLD+CLZcnxqY9A15/6qyfCFOAiEg0fvQF/PjA4V73aTy0xoQAXEf9SL4MfOB069/BTS4wowEXEn4XTYezxyXHr78GQD/zVEzMKcBHxI7XrvvB12GYnP7XElO5KLyKFNePmYHh3+VFirVvh3WTqwEWkMGo3wrXtgnOXvQ9btkv/fMlKAS4i4fvbz+DdJ5PjNuUweLG/eopE5AN84pwaRkx9m2Wr1tC5vIxBvbvRr+cOvssSkcZY9xX8vnNwbkgNtN7aTz1FJtIBPnFODUMmzGXN+o0A1Kxaw5AJcwEU4iJRN2JX+Orj5Pj7veCMSb6qK
UqRDvARU9/+Nrw3WbN+IyOmvq0AF4mqL5bDLd2Cc9p8KhSRDvBlq9Y0aV5EPEs9NfDHv4XeN/ippQREOsA7l5dRkyasO5eXeahGRDJaPhf+cnBwTpfBA+Eex4v0eeCDenejrFXwY1dZq5YM6t0tw1eISMFVtw2G97F/VHjX2XQcr2bVGhzJ43gT59Tk5ftHOsD79dyBG/vvxQ7lZRiwQ3kZN/bfS+vfIlEw528Nl0yqV0PVr/zUE0GbO46XD5FeQoFEiCuwRSImNbhPmwC7HuGnlggL+zhe5ANcRCLkP1fAi3cF57RcklHYx/EivYQiIhFS3TYY3gOeUHhnEfZxPHXgIjEX+tXKo/vAkheCcwruRtn0cwjr56MAF4mxUK9W3rgermsfnLtoLpRXNO/7lpgwj+MpwEViLLSrlVMPUoK67ghSgIvEWN7PcljzGfyhMjinzaciSwEuEmN5PcshtesuaweXv59jZVIIOgtFJMbycpbDJwsahvfQzxTeMaAOXCTGmn2WQ2pwdzsGTv57nquUsGQNcDMbDRwLrHDO7Vk3dx3QF6gFVgBnOeeWhVmoiKSX01kO7z0ND/QLzukgZew0ZgllDNAnZW6Ec25v51wPYDIwNN+FiUhIqtsGw7vXFQrvmMragTvnZphZZcrc5/WGWwEuv2WJSN49dS08e0twTsEdazmvgZvZDcAZwGrgsM08byAwEKCiQhcAiHiRutZ9zK2w/wA/tcREHO7Ha85lb57rOvDJm9bAUx4bArRxzg3L9n2qqqrczJkzcyhTRHJyyw/gi5TDUx677jiEIjS8whUSZ/f42s7azGY556pS5/NxGuGDwM/y8H1EJF+cS3Td9cP7zMnewzvMmxvkU9j7eOdLTksoZtbVObegbtgXeCt/JYlIs0T0Mvg43aQ8LvfjbcxphOOAXkB7M1sKDAOONrNuJE4jXAycE2aRItII676G33cKzp33CnTY7duhzyWMuIQixOd+vI05C+XkNNOjQqhFRHLViK471J0LGyEuoQiJK1zTrYFH7X68upReImHinBoOGj6NnQdP4aDh0yK5LhpJny1uGN6Dl6RdMvG9rhunm5TH5X68upRevPPdGcZWE9e6fS9hhH1zg3yLw/14FeDiXZwObkXCwukw9vjg3NDPoMXmP1BHYQkjDqEYJ1pCEe98d4axUt02GN4tv5PourOEN8RrCUMaRx24eBeFzjDyXrgLpl4RnGviqYFxW8KQ7BTg4l1cjvh7k8ctX7WEUVwU4OJdqXaGWc/Jfuh0mP9Y8IsicEGORIcCXCKh1DrDrGfepHbdP70ODrqg0GVKxCnARTzIdObNUZP2hUlrg09W1y0ZKMBFPEg9w8ao5f02pwWfdPpE2CXjTs1exWVXwWKnABfxoP6ZN4vanNLwCRHuunXhVXToPHARDwb17ka7VusbhPfjRz4e6fAG/5fkS5I6cBEP+k3qTr/gNTVM7PtmLDpYXXgVHQpwkSZq1vrvqiXwx72Cc1cuh1Zl9Ev/FZGjC6+iQ0soIk3QrLvKVLdtGN7Vq6FVvIJPl+RHhzpwkSbIaeOtRc/DmKODc8NWgVlIVYarVC+8iiIFuEgTNHn9N/WCnJ0Ogl/9K89VFV6pXXgVVQpwkSZo9PrvzNEw+eLgXMTPLpH40Rq4SBM0av23um0wvA++WOEtoVAHLtIEm13/fex8mD02+AUKbgmRAlykidKu/6audZ94H+zZv3BFSUlSgIs0x537wafvBufUdUuBKMBFclG7Ea5tF5wbOB069/BTj5QkBbhIU43oCl+tCM6p6xYPFOAijfXNl3Bjytr3pQtg645+6pGSpwAX2YxN+548v/aEhg+q6xbPdB64SAYT59Rwx4SnG4T3pONeU3hLJCjARTLoN6k701qe9+14Vm1XKtf+nZueWOixKpEkLaGIpPrgFRh1ZGCqcu2DQGLzKe17LVGRtQM3s9FmtsLM5tWbG2Fmb5nZ62b2qJmVh1umSIFUtw2E9
+gNfahc+3c2hTdo32uJjsYsoYwB+qTMPQHs6ZzbG3gHGJLnukQK6/WHG1xNObHvm4ywXwXmtO+1REnWJRTn3Awzq0yZe7ze8EXgxPyWJVJAqZfBH3cH7Hfmt3fI0b7XElX5WAM/G3go04NmNhAYCFBRUZGHlxPJkyevgeduDc6lnF2ifa8lypoV4GZ2JbABeDDTc5xzI4GRAFVVVa45ryeSN6ld91lToPJgP7WI5CjnADezs4BjgSOccwpmiYex/WDh08E5ndMtMZVTgJtZH+Ay4FDn3Nf5LUkkBBs3wHXbBucueBXa7eynHpE8yBrgZjYO6AW0N7OlwDASZ520Bp6wxI1ZX3TOnRNinSK5u6ETrE/pM9R1SxFozFkoJ6eZHhVCLSL5tXY1DE85cD54CbRpm/75IjGjKzGlOKUepNyiDVz1kZ9aREKiAJfisnIh3NEzOHf1p9BSv+pSfPRbXYI2bZFadBenpHbd3z8Mzpjop5aIKNqftQAK8JIzcU4NQybMZc36jQDUrFrDkAlzAeL7D3vRczDmmOCcDlIW589aArSdbIkZMfXtb/9Bb7Jm/UZGTH3bU0XNVN02GN4H/67J4T1xTg0HDZ/GzoOncNDwaUycU5PnIv0oup+1NKAOvMRk2go1dlukzh4Lj50fnMuh6y7mLrVoftaSkTrwEpNpK9RYbZFa3TYY3v3vyXnJpJi71KL4WctmKcBLzKDe3Shr1TIwV6gtUpu9VPHvwQ0PVFavhr1PyrmmYu5Sff6spTC0hFJiNi0LFPrMhGYtVTgH16TcM2TAk9Bl/2bX1bm8jJo0YV0MXaqvn7UUjhVyH6qqqio3c+bMgr2eRMdBw6elDcodyst4fvDhmb/w3p/C0peDc3k8wyT1PxZIdKk39t9LQSeRYWaznHNVqfPqwKUgmrxUsXE9XNc+OHfRPCjvkte61KVKnCnApSCatFSRus4NoZ7XrZs2SFzpIKYURKMOqH29smF4D6nRRTkiGagDl4LIulSRGtxbdYRBCwpcpUi8KMClYNIuVXz8DtyVcjbJ0M+ghT4cimSjABd/Urvu3Y+FX2a8vaqIpFCAS+G9+xT8rX9wTuvcIk2mAC9Skd1GNLXrPuxKOPQyP7WIxJwCvAhFcoOm+f+Eh04LzqnrFmkWBXgR2twGTV4CPLXrPmksdO9b+DpEiowCPES+ljEis0HT83fAE1cH59R1i+SNAjwkPpcxvG/QlG7zqd/Ogva7Fub1RUqETrYNic99pr1uIzrpvIbhXb1a4S0SAnXgIfG5jOFlg6YN6+D6DsG5QQthq23De02REqcAD4nvZYyCbtD010Pgw9eS4w67w3kvFea1RUqYllBCUhJ3Q1nzWeIMk/rhfdUKhbdIgagDD0nR7zOdemrgXj+Hn93rpxaREqUAD1FR7jO98n24o0dwbtgqMPNTj0gJy7qEYmajzWyFmc2rN/dzM3vDzGrNrMFtfqRIVbcNhvfhVyXOMFF4i3jRmDXwMUCflLl5QH9gRr4Lkgj64OX0d4M/ZJCfekQEaMQSinNuhplVpszNBzB1XsUvNbhPvA/27J/+uSJSUKGvgZvZQGAgQEVFRdgvJ/ky9xEYPyA4p8vgRSIl9AB3zo0ERgJUVVW5sF9P8iC16x7wJHTZP/1zRcQbnYUiSdNHwNPXB+fUdYtElgJc0m8+deFrsE2ll3JEpHGyBriZjQN6Ae3NbCkwDFgJ3Al0AKaY2avOud5hFioheeRsmDc+OKeuWyQWGnMWyskZHno0z7VIIW34Bq7vGJy7fDGUlad/vohEjpZQStGDJ8GCqclx554w8Blf1YhIjhTgpWTNKvjDTsG5qz+Blq381CMizaIALxW37gGfL02OD7kMDr/SXz0i0mwK8GL32WK4fe/gnDafEikKCvBilnpBTt+7oeepfmoRkbxTgBejFfPh7gOCczo1UKToKMCLTWrXfda/oPIgP7WISKgU4MVi0XMw5pjkeIsyuGq5v3pEJHQK8
GKQ2nVfMAfafd9PLSJSMLqpcZzNfSQY3p17Jta6Fd4iJUEdeBzV1sK12wTnBi2Erbb1U4+IeKEOPG6e+2MwvPc6KdF1K7xFSo468LjYsA6u7xCcu3I5tCrzU4+IeKcOPA4mXxwM70MvT3TdCm+RkqYOPMrSbT41dCW0aOmnHhGJFAV4VN1/HLw/Izk+/k+w7+n+6hGRyFGAR83qpXDbHsE5bT4lImkowKPk5m7wZb2rJ08bD7se6a8eEYk0BXgULJ8Hf0nZr0SbT4lIFgpw31Ivg/+/GdBpHz+1iEisKMB9eW8aPHBCclzWDi5/3189IhI7CnAfUrvui+ZCeYWfWkQktnQhTyG9Oi4Y3hU/Tqx1K7xFJAfqwAsh3eZTly+Csm3SPl1EpDHUgYdt+ohgePc8PdF1K7xFpJnUgYdl4wa4+0fw6bvJuatWwBat/dUkIkVFAR6G+ZPhoXp3f+99I/z4N/7qEZGipADPp/VrYMSusO7LxHjnQ+CMx3QZvIiEIusauJmNNrMVZjav3lw7M3vCzBbU/akF3dkPwA3bJ8P7nOfgzH8qvEUkNI05iDkG6JMyNxh4yjnXFXiqblya1qxKnBr42G8T471/kThIuf1efusSkaKXNcCdczOAlSnTfYH76/5+P9Avz3XFw7O3BvfrvuBV6D/SXz0iUlJyXQPfzjn3Yd3flwPb5ameePj8Q7h19+T4wAvgqOv81SMiJanZBzGdc87MXKbHzWwgMBCgoqIIrjj8zxB48e7k+NIFsHVHf/WISMnKNcA/MrNOzrkPzawTsCLTE51zI4GRAFVVVRmDPvI+fQ/u3Dc5Pup6OPB8f/WISMnLNcAfA84Ehtf9OSlvFUWNc/DI2fDGhOTc4CXQpm3mrxERKYCsAW5m44BeQHszWwoMIxHcD5vZAGAxcFKYRXqz7FUYeWhy3O/P0OMUf/WIiNSTNcCdcydneOiIPNcSHbW1MOZoWPJCYlzWDn43H1q18VuXiEg9uhIz1fszEneE3+SUh2G33v7qERHJQAG+ycb1cOd+sGpxYtxxDzjnWWjR0m9dIiIZKMAB3pwED5+RHJ89FSoO8FePiEgjlHaAr/sa/lAJG79JjHc5Ak4br/1LRCQWSjfAZ94Hky9Kjs99Abbr7q8eEZEmKr0A/3ol3LRzctzzNOh7l796RERyVFoBPv0mePqG5PjC12GbnTI/X0QkwkojwD9fBrf+IDn+ySVwxFB/9YiI5EHxB/iUS+GVe5LjQe/BVu391SMikifFG+CfLIA/VSXHfYbDAef6q0dEJM+KL8Cdg4dOg7cmJ+eGLIXW3/VXk4hICIorwGtmwT2HJ8f974W9f+6vHhGREBVHgNfWwqgjEwEOsPX2cNHrsEVrv3WJiIQo/gH+3jR44ITk+NTx0PVIf/WIiBRIfAN8wzq4owd8XpMYd+oBv56mzadEpGTEM8DnjU/cJWeTAU9Cl/391SMi4kG8AvybL2F4F3C1ifFu/wMnj9PmUyJSkuIT4C/fA/+6NDk+72Xo0M1fPSIinsUjwGePTYb3fmfBcbd7LUdEJAriEeAdu0OXH8GJo6Htjr6rERGJhHgE+I5VMOBx31WIiERKC98FiIhIbhTgIiIxpQAXEYkpBbiISEwpwEVEYkoBLiISUwpwEZGYUoCLiMSUOecK92JmHwOLC/aChdce+MR3ERGi9yNI70eS3ougbO/HTs65DqmTBQ3wYmdmM51zVdmfWRr0fgTp/UjSexGU6/uhJRQRkZhSgIuIxJQCPL9G+i4gYvR+BOn9SNJ7EZTT+6E1cBGRmFIHLiISUwpwEZGYUoDnyMxGm9kKM5tXb26Emb1lZq+b2aNmVu6zxkJK937Ue+wSM3Nm1t5HbYWW6b0ws/Prfj/eMLObfNVXaBn+rfQwsxfN7FUzm2lmP/RZY6GYWRcze9rM3qz7Pbiwbr6dmT1hZgvq/tymMd9PAZ67MUCflLkngD2dc3sD7wBDCl2UR2No+H5gZl2Ao
4AlhS7IozGkvBdmdhjQF9jHObcHcLOHunwZQ8PfjZuAa5xzPYChdeNSsAG4xDnXHTgAOM/MugODgaecc12Bp+rGWSnAc+ScmwGsTJl73Dm3oW74IlAyN/BM937UuQ24DCiZo+UZ3otzgeHOuW/qnrOi4IV5kuH9cMD36v7eFlhW0KI8cc596JybXff3L4D5wA4k/nO/v+5p9wP9GvP9FODhORv4t+8ifDKzvkCNc+4137VEwG7AT8zsJTObbmb7+y7Is4uAEWb2AYlPI6X0aRUAM6sEegIvAds55z6se2g5sF1jvocCPARmdiWJj0oP+q7FFzPbEriCxMdjSdxAvB2Jj82DgIfNzPyW5NW5wMXOuS7AxcAoz/UUlJltDYwHLnLOfV7/MZc4t7tRn1gV4HlmZmcBxwKnutI+yX4XYGfgNTNbRGI5abaZbe+1Kn+WAhNcwstALYkNjErVmcCEur//AyiJg5gAZtaKRHg/6Jzb9B58ZGad6h7vBDRqiU0Bnkdm1ofEeu/xzrmvfdfjk3NurnOuo3Ou0jlXSSLA9nXOLfdcmi8TgcMAzGw34DuU9m58y4BD6/5+OLDAYy0FU/epaxQw3zl3a72HHiPxnxp1f05q1Pcr7SYxd2Y2DuhFoov6CBhGYh2vNfBp3dNedM6d46XAAkv3fjjnRtV7fBFQ5Zwr+tDK8LvxADAa6AGsAy51zk3zVWMhZXg/3gZuJ7G0tBb4jXNulq8aC8XMDgaeBeaS+BQGiaXGl4CHgQoSW26f5JxLd1JA8PspwEVE4klLKCIiMaUAFxGJKQW4iEhMKcBFRGJKAS4iElMKcBGRmFKAi4jE1P8DJKmZuwUdJOkAAAAASUVORK5CYII=\n" 869 | }, 870 | "metadata": { 871 | "needs_background": "light" 872 | } 873 | } 874 | ], 875 | "source": [ 876 | "x_numpy = x.data.numpy()\n", 877 | "y_numpy = y.data.numpy()\n", 878 | "y_pred = y_pred.data.numpy()\n", 879 | "plt.plot(x_numpy,y_numpy,'o')\n", 880 | "plt.plot(x_numpy,y_pred,'-')" 881 | ] 882 | }, 883 | { 884 | "cell_type": "code", 885 | "execution_count": 42, 886 | "metadata": { 887 | "collapsed": true, 888 | "id": "65OROyBG61GG" 889 | }, 890 | "outputs": [], 891 | "source": [ 892 | "x = Variable(torch.ones(4, 4) * 12.5, requires_grad=True)" 893 | ] 894 | }, 895 | { 896 | "cell_type": "code", 897 | "execution_count": 43, 898 | "metadata": { 899 | "colab": { 900 | "base_uri": "https://localhost:8080/" 901 | }, 902 | "id": "I5DKAY7w61GH", 903 | "outputId": "91f4dfe3-5859-4887-94f5-5c4508f64114" 904 | }, 905 | "outputs": [ 906 | { 907 | "output_type": "execute_result", 908 | "data": { 909 | "text/plain": [ 910 | "tensor([[12.5000, 12.5000, 12.5000, 12.5000],\n", 911 | " [12.5000, 12.5000, 12.5000, 12.5000],\n", 912 | " [12.5000, 12.5000, 12.5000, 12.5000],\n", 913 | " [12.5000, 12.5000, 12.5000, 
12.5000]], requires_grad=True)" 914 | ] 915 | }, 916 | "metadata": {}, 917 | "execution_count": 43 918 | } 919 | ], 920 | "source": [ 921 | "x" 922 | ] 923 | }, 924 | { 925 | "cell_type": "code", 926 | "execution_count": 44, 927 | "metadata": { 928 | "collapsed": true, 929 | "id": "06ZRd18A61GH" 930 | }, 931 | "outputs": [], 932 | "source": [ 933 | "fn = 2 * (x * x) + 5 * x + 6\n", 934 | "\n", 935 | "# fn = 2x^2 + 5x + 6, so d(fn)/dx = 4x + 5 = 4(12.5) + 5 = 55 for every element" 936 | ] 937 | }, 938 | { 939 | "cell_type": "code", 940 | "execution_count": 45, 941 | "metadata": { 942 | "id": "8R_j4jfs61GI" 943 | }, 944 | "outputs": [], 945 | "source": [ 946 | "fn.backward(torch.ones(4,4)) # fn is not a scalar, so pass a gradient of ones with the same shape as fn" 947 | ] 948 | }, 949 | { 950 | "cell_type": "code", 951 | "execution_count": 46, 952 | "metadata": { 953 | "colab": { 954 | "base_uri": "https://localhost:8080/" 955 | }, 956 | "id": "XXyZ_ZP761GI", 957 | "outputId": "8e278dba-0c88-428c-fbf6-9174cfd721a0" 958 | }, 959 | "outputs": [ 960 | { 961 | "output_type": "stream", 962 | "name": "stdout", 963 | "text": [ 964 | "tensor([[55., 55., 55., 55.],\n", 965 | " [55., 55., 55., 55.],\n", 966 | " [55., 55., 55., 55.],\n", 967 | " [55., 55., 55., 55.]])\n" 968 | ] 969 | } 970 | ], 971 | "source": [ 972 | "print(x.grad)" 973 | ] 974 | }, 975 | { 976 | "cell_type": "code", 977 | "source": [], 978 | "metadata": { 979 | "id": "bd6fd4ly9coh" 980 | }, 981 | "execution_count": null, 982 | "outputs": [] 983 | } 984 | ], 985 | "metadata": { 986 | "kernelspec": { 987 | "display_name": "Python 3", 988 | "language": "python", 989 | "name": "python3" 990 | }, 991 | "language_info": { 992 | "codemirror_mode": { 993 | "name": "ipython", 994 | "version": 3 995 | }, 996 | "file_extension": ".py", 997 | "mimetype": "text/x-python", 998 | "name": "python", 999 | "nbconvert_exporter": "python", 1000 | "pygments_lexer": "ipython3", 1001 | "version": "3.6.4" 1002 | }, 1003 | "colab": { 1004 | "name": "Torch_AI_4_2Ed.ipynb", 1005 | "provenance": [], 1006 | "collapsed_sections": [] 1007 | } 1008 | }, 1009 | "nbformat": 4, 
1010 | "nbformat_minor": 0 1011 | } -------------------------------------------------------------------------------- /Torch_AI_2_2Ed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "omrGLNyylOxD" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "source": [ 17 | "print(torch.cuda.is_available())" 18 | ], 19 | "metadata": { 20 | "colab": { 21 | "base_uri": "https://localhost:8080/" 22 | }, 23 | "id": "bSAwTCvT0bmS", 24 | "outputId": "cfa2fd4e-32c7-4c20-fdf6-7b41e2a9544c" 25 | }, 26 | "execution_count": null, 27 | "outputs": [ 28 | { 29 | "output_type": "stream", 30 | "name": "stdout", 31 | "text": [ 32 | "False\n" 33 | ] 34 | } 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "source": [ 40 | "# CUDA is NVIDIA's parallel computing platform and API for running computations on its GPUs" 41 | ], 42 | "metadata": { 43 | "id": "R2AXFFJd0bas" 44 | }, 45 | "execution_count": null, 46 | "outputs": [] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "source": [ 51 | "x = torch.randn(10)\n", 52 | "print(x.device)" 53 | ], 54 | "metadata": { 55 | "colab": { 56 | "base_uri": "https://localhost:8080/" 57 | }, 58 | "id": "f9whLITJ0bKS", 59 | "outputId": "84a6c8e8-e592-466b-eda5-55a440781b10" 60 | }, 61 | "execution_count": null, 62 | "outputs": [ 63 | { 64 | "output_type": "stream", 65 | "name": "stdout", 66 | "text": [ 67 | "cpu\n" 68 | ] 69 | } 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "id": "QRU4qnYolOxE" 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "# set the random seed so that randomly generated tensors are reproducible" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": { 87 | "colab": { 88 | "base_uri": "https://localhost:8080/" 89 | }, 90 | "id": "aHvKASaflOxF", 91 | "outputId": 
"195b235a-cd2e-4a20-d8a8-d946a39a89d5" 92 | }, 93 | "outputs": [ 94 | { 95 | "output_type": "execute_result", 96 | "data": { 97 | "text/plain": [ 98 | "" 99 | ] 100 | }, 101 | "metadata": {}, 102 | "execution_count": 3 103 | } 104 | ], 105 | "source": [ 106 | "torch.manual_seed(1234)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "colab": { 114 | "base_uri": "https://localhost:8080/" 115 | }, 116 | "id": "im5Spt1UlOxH", 117 | "outputId": "503f7232-e9be-4956-eae7-682c27076b13" 118 | }, 119 | "outputs": [ 120 | { 121 | "output_type": "execute_result", 122 | "data": { 123 | "text/plain": [ 124 | "tensor([[-0.1117, -0.4966, 0.1631, -0.8817],\n", 125 | " [ 0.0539, 0.6684, -0.0597, -0.4675],\n", 126 | " [-0.2153, 0.8840, -0.7584, -0.3689],\n", 127 | " [-0.3424, -1.4020, 0.3206, -1.0219]])" 128 | ] 129 | }, 130 | "metadata": {}, 131 | "execution_count": 4 132 | } 133 | ], 134 | "source": [ 135 | "torch.manual_seed(1234)\n", 136 | "torch.randn(4,4)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": { 143 | "id": "ekFUX0oslOxI" 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "#generate random numbers from a statistical distribution" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": { 154 | "colab": { 155 | "base_uri": "https://localhost:8080/" 156 | }, 157 | "id": "kSo2SOnjlOxI", 158 | "outputId": "dd7b324a-4004-453c-cdf0-6266d2fbff01" 159 | }, 160 | "outputs": [ 161 | { 162 | "output_type": "execute_result", 163 | "data": { 164 | "text/plain": [ 165 | "tensor([[0.2837, 0.6567, 0.2388, 0.7313],\n", 166 | " [0.6012, 0.3043, 0.2548, 0.6294],\n", 167 | " [0.9665, 0.7399, 0.4517, 0.4757],\n", 168 | " [0.7842, 0.1525, 0.6662, 0.3343]])" 169 | ] 170 | }, 171 | "metadata": {}, 172 | "execution_count": 6 173 | } 174 | ], 175 | "source": [ 176 | "torch.Tensor(4, 4).uniform_(0, 1) #random number from uniform 
distribution" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": { 183 | "id": "cJ1in1selOxJ" 184 | }, 185 | "outputs": [], 186 | "source": [ 187 | "# draw Bernoulli samples (0 or 1), treating each value in the \n", 188 | "# input tensor as the probability of drawing a 1" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": { 195 | "colab": { 196 | "base_uri": "https://localhost:8080/" 197 | }, 198 | "id": "r2lilTLrlOxK", 199 | "outputId": "df75cf14-b5aa-49e9-ee46-51672026f94d" 200 | }, 201 | "outputs": [ 202 | { 203 | "output_type": "execute_result", 204 | "data": { 205 | "text/plain": [ 206 | "tensor([[0., 0., 0., 0.],\n", 207 | " [1., 0., 1., 0.],\n", 208 | " [1., 0., 1., 1.],\n", 209 | " [0., 0., 0., 0.]])" 210 | ] 211 | }, 212 | "metadata": {}, 213 | "execution_count": 8 214 | } 215 | ], 216 | "source": [ 217 | "torch.bernoulli(torch.empty(4, 4).uniform_(0, 1))" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": { 224 | "id": "NixFXUm8lOxK" 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "# how to sample indices from a multinomial distribution" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": { 235 | "colab": { 236 | "base_uri": "https://localhost:8080/" 237 | }, 238 | "id": "5QiMoFXzlOxL", 239 | "outputId": "4e8cba6f-cbdc-480f-f077-6127363f6f13" 240 | }, 241 | "outputs": [ 242 | { 243 | "output_type": "execute_result", 244 | "data": { 245 | "text/plain": [ 246 | "tensor([10., 10., 13., 10., 34., 45., 65., 67., 87., 89., 87., 34.])" 247 | ] 248 | }, 249 | "metadata": {}, 250 | "execution_count": 10 251 | } 252 | ], 253 | "source": [ 254 | "torch.Tensor([10, 10, 13, 10,34,45,65,67,87,89,87,34])" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": { 261 | "colab": { 262 | "base_uri": "https://localhost:8080/" 263 | }, 264 | "id": 
"ko_bwKdrlOxM", 265 | "outputId": "f5348833-d6a3-414d-e17e-f6826c32f8e2" 266 | }, 267 | "outputs": [ 268 | { 269 | "output_type": "execute_result", 270 | "data": { 271 | "text/plain": [ 272 | "tensor([4, 5, 7])" 273 | ] 274 | }, 275 | "metadata": {}, 276 | "execution_count": 11 277 | } 278 | ], 279 | "source": [ 280 | "torch.multinomial(torch.tensor([10., 10., 13., 10., \n", 281 | " 34., 45., 65., 67., \n", 282 | " 87., 89., 87., 34.]), \n", 283 | " 3) # 3 distinct indices, each drawn with probability proportional to its weight" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": { 290 | "colab": { 291 | "base_uri": "https://localhost:8080/" 292 | }, 293 | "id": "4Z-2q3S7lOxM", 294 | "outputId": "903257f5-ccdb-457d-f592-6e6eb00bae05" 295 | }, 296 | "outputs": [ 297 | { 298 | "output_type": "execute_result", 299 | "data": { 300 | "text/plain": [ 301 | "tensor([10, 5, 9, 10, 5])" 302 | ] 303 | }, 304 | "metadata": {}, 305 | "execution_count": 12 306 | } 307 | ], 308 | "source": [ 309 | "torch.multinomial(torch.tensor([10., 10., 13., 10., \n", 310 | " 34., 45., 65., 67., \n", 311 | " 87., 89., 87., 34.]), \n", 312 | " 5, replacement=True) # with replacement, the same index can be drawn more than once" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": { 319 | "id": "LKo_-nV3lOxN" 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "#generate random numbers from the normal distribution" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": { 330 | "colab": { 331 | "base_uri": "https://localhost:8080/" 332 | }, 333 | "id": "PkmfwYNXlOxN", 334 | "outputId": "b5ab4453-1952-4eda-ece5-9471e9b6fde6" 335 | }, 336 | "outputs": [ 337 | { 338 | "output_type": "execute_result", 339 | "data": { 340 | "text/plain": [ 341 | "tensor([1.5236, 2.2441, 2.7375, 3.9521, 5.4380, 5.5158, 8.2489, 8.1645, 9.0575,\n", 342 | " 9.8627])" 343 | ] 344 | }, 345 | "metadata": {}, 346 | "execution_count": 14 347 | } 348 | ], 349 | "source": [ 350 | "torch.normal(mean=torch.arange(1., 11.), 
\n", 351 | " std=torch.arange(1, 0, -0.1))" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": { 358 | "colab": { 359 | "base_uri": "https://localhost:8080/" 360 | }, 361 | "id": "oXu1fQ8VlOxN", 362 | "outputId": "5a0aef29-c9cf-4bde-f78f-97719fe80d9a" 363 | }, 364 | "outputs": [ 365 | { 366 | "output_type": "execute_result", 367 | "data": { 368 | "text/plain": [ 369 | "tensor([ 1.1144, 0.0361, 1.2766, -1.3999, -0.1648])" 370 | ] 371 | }, 372 | "metadata": {}, 373 | "execution_count": 15 374 | } 375 | ], 376 | "source": [ 377 | "torch.normal(mean=0.5, \n", 378 | " std=torch.arange(1., 6.))" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": { 385 | "colab": { 386 | "base_uri": "https://localhost:8080/" 387 | }, 388 | "id": "UxK2xVShlOxO", 389 | "outputId": "7d21742b-65aa-40cd-a06c-a416cdc4e524" 390 | }, 391 | "outputs": [ 392 | { 393 | "output_type": "execute_result", 394 | "data": { 395 | "text/plain": [ 396 | "tensor([-0.0844])" 397 | ] 398 | }, 399 | "metadata": {}, 400 | "execution_count": 16 401 | } 402 | ], 403 | "source": [ 404 | "torch.normal(mean=0.5, \n", 405 | " std=torch.arange(0.2,0.6)) # arange uses the default step of 1, so std holds only 0.2 and one sample is drawn" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": { 412 | "colab": { 413 | "base_uri": "https://localhost:8080/" 414 | }, 415 | "id": "mioQhg_-lOxO", 416 | "outputId": "b0bd56ea-8d52-413e-c9f0-13bdc733d6d0" 417 | }, 418 | "outputs": [ 419 | { 420 | "output_type": "execute_result", 421 | "data": { 422 | "text/plain": [ 423 | "tensor(45.9167)" 424 | ] 425 | }, 426 | "metadata": {}, 427 | "execution_count": 17 428 | } 429 | ], 430 | "source": [ 431 | "# descriptive statistics: the mean of all elements\n", 432 | "torch.mean(torch.tensor([10., 10., 13., 10., 34., \n", 433 | " 45., 65., 67., 87., 89., 87., 34.]))" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": null, 439 | "metadata": { 440 | "colab": { 
441 | "base_uri": "https://localhost:8080/" 442 | }, 443 | "id": "3gi7FsB1lOxP", 444 | "outputId": "1e1d3e91-4418-42d5-f6f2-ee52ed2d2ac3" 445 | }, 446 | "outputs": [ 447 | { 448 | "output_type": "execute_result", 449 | "data": { 450 | "text/plain": [ 451 | "tensor([[-1.6406, 0.9295, 1.2907, 0.2612, 0.9711],\n", 452 | " [ 0.3551, 0.8562, -0.3635, -0.1552, -1.2282],\n", 453 | " [ 1.2445, 1.1750, -0.2217, -2.0901, -1.2658],\n", 454 | " [-1.8761, -0.6066, 0.7470, 0.4811, 0.6234]])" 455 | ] 456 | }, 457 | "metadata": {}, 458 | "execution_count": 18 459 | } 460 | ], 461 | "source": [ 462 | "# a random 4x5 tensor for computing statistics along rows and columns\n", 463 | "d = torch.randn(4, 5)\n", 464 | "d" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": null, 470 | "metadata": { 471 | "colab": { 472 | "base_uri": "https://localhost:8080/" 473 | }, 474 | "id": "6IvPjHWUlOxP", 475 | "outputId": "9ca2974e-781b-48f1-b231-93301a41910f" 476 | }, 477 | "outputs": [ 478 | { 479 | "output_type": "execute_result", 480 | "data": { 481 | "text/plain": [ 482 | "tensor([-0.4793, 0.5885, 0.3631, -0.3757, -0.2249])" 483 | ] 484 | }, 485 | "metadata": {}, 486 | "execution_count": 19 487 | } 488 | ], 489 | "source": [ 490 | "torch.mean(d,dim=0) # reduce over rows: one mean per column (5 values)" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": { 497 | "colab": { 498 | "base_uri": "https://localhost:8080/" 499 | }, 500 | "id": "CKzLUDMKlOxP", 501 | "outputId": "e1bde541-6e72-408a-c59d-cd52522c2272" 502 | }, 503 | "outputs": [ 504 | { 505 | "output_type": "execute_result", 506 | "data": { 507 | "text/plain": [ 508 | "tensor([ 0.3624, -0.1071, -0.2316, -0.1262])" 509 | ] 510 | }, 511 | "metadata": {}, 512 | "execution_count": 20 513 | } 514 | ], 515 | "source": [ 516 | "torch.mean(d,dim=1) # reduce over columns: one mean per row (4 values)" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": null, 522 | "metadata": { 523 | "colab": { 524 | "base_uri": "https://localhost:8080/" 525 | }, 526 | "id": "yPdAnDCHlOxQ", 527 | 
"outputId": "bf06580d-09c2-4556-f350-95f05a8353db" 528 | }, 529 | "outputs": [ 530 | { 531 | "output_type": "execute_result", 532 | "data": { 533 | "text/plain": [ 534 | "torch.return_types.median(\n", 535 | "values=tensor([-1.6406, 0.8562, -0.2217, -0.1552, -1.2282]),\n", 536 | "indices=tensor([0, 1, 2, 1, 1]))" 537 | ] 538 | }, 539 | "metadata": {}, 540 | "execution_count": 21 541 | } 542 | ], 543 | "source": [ 544 | "# compute the median; with an even count (4 rows) PyTorch returns the lower of the two middle values\n", 545 | "torch.median(d,dim=0)" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": null, 551 | "metadata": { 552 | "colab": { 553 | "base_uri": "https://localhost:8080/" 554 | }, 555 | "id": "MIL1nnAvlOxQ", 556 | "outputId": "67ad766e-ec6e-4016-8232-8c3a8b472f48" 557 | }, 558 | "outputs": [ 559 | { 560 | "output_type": "execute_result", 561 | "data": { 562 | "text/plain": [ 563 | "torch.return_types.median(\n", 564 | "values=tensor([ 0.9295, -0.1552, -0.2217, 0.4811]),\n", 565 | "indices=tensor([1, 3, 2, 3]))" 566 | ] 567 | }, 568 | "metadata": {}, 569 | "execution_count": 22 570 | } 571 | ], 572 | "source": [ 573 | "torch.median(d,dim=1)" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": null, 579 | "metadata": { 580 | "colab": { 581 | "base_uri": "https://localhost:8080/" 582 | }, 583 | "id": "5V2sseCblOxQ", 584 | "outputId": "a2b18314-df83-427a-ebc5-06fc9eceaca5" 585 | }, 586 | "outputs": [ 587 | { 588 | "output_type": "execute_result", 589 | "data": { 590 | "text/plain": [ 591 | "torch.return_types.mode(\n", 592 | "values=tensor([-1.6406, -1.2282, -2.0901, -1.8761]),\n", 593 | "indices=tensor([0, 4, 3, 0]))" 594 | ] 595 | }, 596 | "metadata": {}, 597 | "execution_count": 23 598 | } 599 | ], 600 | "source": [ 601 | "# compute the mode along the last dimension; every value in d is distinct, so each row returns its smallest value\n", 602 | "torch.mode(d)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "metadata": { 609 | "colab": { 610 | "base_uri": "https://localhost:8080/" 611 | }, 612 | "id": "eAwqQrC5lOxR", 613 | "outputId": 
"2a3c9ae8-e20f-4842-a1d1-3deac12f734a" 614 | }, 615 | "outputs": [ 616 | { 617 | "output_type": "execute_result", 618 | "data": { 619 | "text/plain": [ 620 | "torch.return_types.mode(\n", 621 | "values=tensor([-1.8761, -0.6066, -0.3635, -2.0901, -1.2658]),\n", 622 | "indices=tensor([3, 3, 1, 2, 2]))" 623 | ] 624 | }, 625 | "metadata": {}, 626 | "execution_count": 24 627 | } 628 | ], 629 | "source": [ 630 | "torch.mode(d,dim=0)" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": null, 636 | "metadata": { 637 | "colab": { 638 | "base_uri": "https://localhost:8080/" 639 | }, 640 | "id": "t9T8cEVNlOxR", 641 | "outputId": "d7192b35-b0e3-44b4-a08e-501113d29e4e" 642 | }, 643 | "outputs": [ 644 | { 645 | "output_type": "execute_result", 646 | "data": { 647 | "text/plain": [ 648 | "torch.return_types.mode(\n", 649 | "values=tensor([-1.6406, -1.2282, -2.0901, -1.8761]),\n", 650 | "indices=tensor([0, 4, 3, 0]))" 651 | ] 652 | }, 653 | "metadata": {}, 654 | "execution_count": 25 655 | } 656 | ], 657 | "source": [ 658 | "torch.mode(d,dim=1)" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": null, 664 | "metadata": { 665 | "colab": { 666 | "base_uri": "https://localhost:8080/" 667 | }, 668 | "id": "_Qi-Q2pLlOxR", 669 | "outputId": "a1fd67e9-35a7-45eb-f172-99ace52a658a" 670 | }, 671 | "outputs": [ 672 | { 673 | "output_type": "execute_result", 674 | "data": { 675 | "text/plain": [ 676 | "tensor(1.0944)" 677 | ] 678 | }, 679 | "metadata": {}, 680 | "execution_count": 26 681 | } 682 | ], 683 | "source": [ 684 | "# compute the standard deviation (by default torch.std applies Bessel's correction, dividing by N - 1)\n", 685 | "torch.std(d)" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": null, 691 | "metadata": { 692 | "colab": { 693 | "base_uri": "https://localhost:8080/" 694 | }, 695 | "id": "1MXZa3yMlOxS", 696 | "outputId": "7aebe106-b5ba-4c0d-e44d-ebcfb5c178ed" 697 | }, 698 | "outputs": [ 699 | { 700 | "output_type": "execute_result", 701 | "data": { 702 | "text/plain": 
[ 703 | "tensor([1.5240, 0.8083, 0.7911, 1.1730, 1.1889])" 704 | ] 705 | }, 706 | "metadata": {}, 707 | "execution_count": 27 708 | } 709 | ], 710 | "source": [ 711 | "torch.std(d,dim=0)" 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": null, 717 | "metadata": { 718 | "colab": { 719 | "base_uri": "https://localhost:8080/" 720 | }, 721 | "id": "i9dZjicllOxS", 722 | "outputId": "e63ba7ce-9751-4e2c-a3ea-b893af51e13f" 723 | }, 724 | "outputs": [ 725 | { 726 | "output_type": "execute_result", 727 | "data": { 728 | "text/plain": [ 729 | "tensor([1.1807, 0.7852, 1.4732, 1.1165])" 730 | ] 731 | }, 732 | "metadata": {}, 733 | "execution_count": 28 734 | } 735 | ], 736 | "source": [ 737 | "torch.std(d,dim=1)" 738 | ] 739 | }, 740 | { 741 | "cell_type": "code", 742 | "execution_count": null, 743 | "metadata": { 744 | "colab": { 745 | "base_uri": "https://localhost:8080/" 746 | }, 747 | "id": "q4YFfwpPlOxS", 748 | "outputId": "2cd24f43-b502-4487-8801-f39a8a37a5b0" 749 | }, 750 | "outputs": [ 751 | { 752 | "output_type": "execute_result", 753 | "data": { 754 | "text/plain": [ 755 | "tensor(1.1978)" 756 | ] 757 | }, 758 | "metadata": {}, 759 | "execution_count": 29 760 | } 761 | ], 762 | "source": [ 763 | "#compute variance\n", 764 | "torch.var(d)" 765 | ] 766 | }, 767 | { 768 | "cell_type": "code", 769 | "execution_count": null, 770 | "metadata": { 771 | "colab": { 772 | "base_uri": "https://localhost:8080/" 773 | }, 774 | "id": "gZravnGDlOxS", 775 | "outputId": "70d0eed9-f26a-4cbe-a820-4f102c90adbe" 776 | }, 777 | "outputs": [ 778 | { 779 | "output_type": "execute_result", 780 | "data": { 781 | "text/plain": [ 782 | "tensor([2.3224, 0.6534, 0.6259, 1.3758, 1.4134])" 783 | ] 784 | }, 785 | "metadata": {}, 786 | "execution_count": 30 787 | } 788 | ], 789 | "source": [ 790 | "torch.var(d,dim=0)" 791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": null, 796 | "metadata": { 797 | "colab": { 798 | "base_uri": 
"https://localhost:8080/" 799 | }, 800 | "id": "Wq_RCgmVlOxS", 801 | "outputId": "dacf2b5a-2c5a-430e-8123-a70b11111dec" 802 | }, 803 | "outputs": [ 804 | { 805 | "output_type": "execute_result", 806 | "data": { 807 | "text/plain": [ 808 | "tensor([1.3940, 0.6166, 2.1703, 1.2466])" 809 | ] 810 | }, 811 | "metadata": {}, 812 | "execution_count": 31 813 | } 814 | ], 815 | "source": [ 816 | "torch.var(d,dim=1)" 817 | ] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "execution_count": null, 822 | "metadata": { 823 | "colab": { 824 | "base_uri": "https://localhost:8080/" 825 | }, 826 | "id": "yfeNnEmMlOxT", 827 | "outputId": "885dcab6-140b-4d1e-d8cc-a31da2bfe19e" 828 | }, 829 | "outputs": [ 830 | { 831 | "output_type": "execute_result", 832 | "data": { 833 | "text/plain": [ 834 | "tensor(-2.0901)" 835 | ] 836 | }, 837 | "metadata": {}, 838 | "execution_count": 32 839 | } 840 | ], 841 | "source": [ 842 | "# compute min and max\n", 843 | "torch.min(d)" 844 | ] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "execution_count": null, 849 | "metadata": { 850 | "colab": { 851 | "base_uri": "https://localhost:8080/" 852 | }, 853 | "id": "Ll8qp443lOxT", 854 | "outputId": "f3977565-997a-440d-d165-5db65c57e2b4" 855 | }, 856 | "outputs": [ 857 | { 858 | "output_type": "execute_result", 859 | "data": { 860 | "text/plain": [ 861 | "torch.return_types.min(\n", 862 | "values=tensor([-1.8761, -0.6066, -0.3635, -2.0901, -1.2658]),\n", 863 | "indices=tensor([3, 3, 1, 2, 2]))" 864 | ] 865 | }, 866 | "metadata": {}, 867 | "execution_count": 33 868 | } 869 | ], 870 | "source": [ 871 | "torch.min(d,dim=0)" 872 | ] 873 | }, 874 | { 875 | "cell_type": "code", 876 | "execution_count": null, 877 | "metadata": { 878 | "colab": { 879 | "base_uri": "https://localhost:8080/" 880 | }, 881 | "id": "rqXNNp1dlOxT", 882 | "outputId": "ac287893-2707-43c9-88c0-69c31cd299c1" 883 | }, 884 | "outputs": [ 885 | { 886 | "output_type": "execute_result", 887 | "data": { 888 | "text/plain": [ 889 | 
"torch.return_types.min(\n", 890 | "values=tensor([-1.6406, -1.2282, -2.0901, -1.8761]),\n", 891 | "indices=tensor([0, 4, 3, 0]))" 892 | ] 893 | }, 894 | "metadata": {}, 895 | "execution_count": 34 896 | } 897 | ], 898 | "source": [ 899 | "torch.min(d,dim=1)" 900 | ] 901 | }, 902 | { 903 | "cell_type": "code", 904 | "execution_count": null, 905 | "metadata": { 906 | "colab": { 907 | "base_uri": "https://localhost:8080/" 908 | }, 909 | "id": "zhEQejqwlOxT", 910 | "outputId": "95e20b83-4743-44bb-87ad-801813c693f5" 911 | }, 912 | "outputs": [ 913 | { 914 | "output_type": "execute_result", 915 | "data": { 916 | "text/plain": [ 917 | "tensor(1.2907)" 918 | ] 919 | }, 920 | "metadata": {}, 921 | "execution_count": 35 922 | } 923 | ], 924 | "source": [ 925 | "torch.max(d)" 926 | ] 927 | }, 928 | { 929 | "cell_type": "code", 930 | "execution_count": null, 931 | "metadata": { 932 | "colab": { 933 | "base_uri": "https://localhost:8080/" 934 | }, 935 | "id": "JIqkH8e-lOxT", 936 | "outputId": "d0aedddc-b561-4c20-a39f-7edf8d6cbd54" 937 | }, 938 | "outputs": [ 939 | { 940 | "output_type": "execute_result", 941 | "data": { 942 | "text/plain": [ 943 | "torch.return_types.max(\n", 944 | "values=tensor([1.2445, 1.1750, 1.2907, 0.4811, 0.9711]),\n", 945 | "indices=tensor([2, 2, 0, 3, 0]))" 946 | ] 947 | }, 948 | "metadata": {}, 949 | "execution_count": 36 950 | } 951 | ], 952 | "source": [ 953 | "torch.max(d,dim=0)" 954 | ] 955 | }, 956 | { 957 | "cell_type": "code", 958 | "execution_count": null, 959 | "metadata": { 960 | "colab": { 961 | "base_uri": "https://localhost:8080/" 962 | }, 963 | "id": "OzjXObVmlOxU", 964 | "outputId": "944ac37f-e53d-4b59-c962-8e0f6f4bfad6" 965 | }, 966 | "outputs": [ 967 | { 968 | "output_type": "execute_result", 969 | "data": { 970 | "text/plain": [ 971 | "torch.return_types.max(\n", 972 | "values=tensor([1.2907, 0.8562, 1.2445, 0.7470]),\n", 973 | "indices=tensor([2, 1, 0, 2]))" 974 | ] 975 | }, 976 | "metadata": {}, 977 | "execution_count": 37 978 | } 
979 | ], 980 | "source": [ 981 | "torch.max(d,dim=1)" 982 | ] 983 | }, 984 | { 985 | "cell_type": "code", 986 | "execution_count": null, 987 | "metadata": { 988 | "colab": { 989 | "base_uri": "https://localhost:8080/" 990 | }, 991 | "id": "1FkKZf85lOxU", 992 | "outputId": "6abb196d-3fc4-4f6b-e1c0-0b8035b05e35" 993 | }, 994 | "outputs": [ 995 | { 996 | "output_type": "execute_result", 997 | "data": { 998 | "text/plain": [ 999 | "torch.return_types.sort(\n", 1000 | "values=tensor([[-1.6406, 0.2612, 0.9295, 0.9711, 1.2907],\n", 1001 | " [-1.2282, -0.3635, -0.1552, 0.3551, 0.8562],\n", 1002 | " [-2.0901, -1.2658, -0.2217, 1.1750, 1.2445],\n", 1003 | " [-1.8761, -0.6066, 0.4811, 0.6234, 0.7470]]),\n", 1004 | "indices=tensor([[0, 3, 1, 4, 2],\n", 1005 | " [4, 2, 3, 0, 1],\n", 1006 | " [3, 4, 2, 1, 0],\n", 1007 | " [0, 1, 3, 4, 2]]))" 1008 | ] 1009 | }, 1010 | "metadata": {}, 1011 | "execution_count": 38 1012 | } 1013 | ], 1014 | "source": [ 1015 | "# sorting a tensor\n", 1016 | "torch.sort(d)" 1017 | ] 1018 | }, 1019 | { 1020 | "cell_type": "code", 1021 | "execution_count": null, 1022 | "metadata": { 1023 | "colab": { 1024 | "base_uri": "https://localhost:8080/" 1025 | }, 1026 | "id": "j1TzTTp4lOxU", 1027 | "outputId": "b254cbe1-6eae-4dbd-e70d-3fad3407c5ee" 1028 | }, 1029 | "outputs": [ 1030 | { 1031 | "output_type": "execute_result", 1032 | "data": { 1033 | "text/plain": [ 1034 | "torch.return_types.sort(\n", 1035 | "values=tensor([[-1.8761, -0.6066, -0.3635, -2.0901, -1.2658],\n", 1036 | " [-1.6406, 0.8562, -0.2217, -0.1552, -1.2282],\n", 1037 | " [ 0.3551, 0.9295, 0.7470, 0.2612, 0.6234],\n", 1038 | " [ 1.2445, 1.1750, 1.2907, 0.4811, 0.9711]]),\n", 1039 | "indices=tensor([[3, 3, 1, 2, 2],\n", 1040 | " [0, 1, 2, 1, 1],\n", 1041 | " [1, 0, 3, 0, 3],\n", 1042 | " [2, 2, 0, 3, 0]]))" 1043 | ] 1044 | }, 1045 | "metadata": {}, 1046 | "execution_count": 39 1047 | } 1048 | ], 1049 | "source": [ 1050 | "torch.sort(d,dim=0)" 1051 | ] 1052 | }, 1053 | { 1054 | "cell_type": 
"code", 1055 | "execution_count": null, 1056 | "metadata": { 1057 | "colab": { 1058 | "base_uri": "https://localhost:8080/" 1059 | }, 1060 | "id": "pB-J5XrnlOxU", 1061 | "outputId": "32be57c6-d0d0-4920-895b-4018933dd2c3" 1062 | }, 1063 | "outputs": [ 1064 | { 1065 | "output_type": "execute_result", 1066 | "data": { 1067 | "text/plain": [ 1068 | "torch.return_types.sort(\n", 1069 | "values=tensor([[ 1.2445, 1.1750, 1.2907, 0.4811, 0.9711],\n", 1070 | " [ 0.3551, 0.9295, 0.7470, 0.2612, 0.6234],\n", 1071 | " [-1.6406, 0.8562, -0.2217, -0.1552, -1.2282],\n", 1072 | " [-1.8761, -0.6066, -0.3635, -2.0901, -1.2658]]),\n", 1073 | "indices=tensor([[2, 2, 0, 3, 0],\n", 1074 | " [1, 0, 3, 0, 3],\n", 1075 | " [0, 1, 2, 1, 1],\n", 1076 | " [3, 3, 1, 2, 2]]))" 1077 | ] 1078 | }, 1079 | "metadata": {}, 1080 | "execution_count": 40 1081 | } 1082 | ], 1083 | "source": [ 1084 | "torch.sort(d,dim=0,descending=True)" 1085 | ] 1086 | }, 1087 | { 1088 | "cell_type": "code", 1089 | "execution_count": null, 1090 | "metadata": { 1091 | "colab": { 1092 | "base_uri": "https://localhost:8080/" 1093 | }, 1094 | "id": "4FiVXGT3lOxU", 1095 | "outputId": "96857d5e-d1f8-42a5-8a60-711600a459da" 1096 | }, 1097 | "outputs": [ 1098 | { 1099 | "output_type": "execute_result", 1100 | "data": { 1101 | "text/plain": [ 1102 | "torch.return_types.sort(\n", 1103 | "values=tensor([[ 1.2907, 0.9711, 0.9295, 0.2612, -1.6406],\n", 1104 | " [ 0.8562, 0.3551, -0.1552, -0.3635, -1.2282],\n", 1105 | " [ 1.2445, 1.1750, -0.2217, -1.2658, -2.0901],\n", 1106 | " [ 0.7470, 0.6234, 0.4811, -0.6066, -1.8761]]),\n", 1107 | "indices=tensor([[2, 4, 1, 3, 0],\n", 1108 | " [1, 0, 3, 2, 4],\n", 1109 | " [0, 1, 2, 4, 3],\n", 1110 | " [2, 4, 3, 1, 0]]))" 1111 | ] 1112 | }, 1113 | "metadata": {}, 1114 | "execution_count": 41 1115 | } 1116 | ], 1117 | "source": [ 1118 | "torch.sort(d,dim=1,descending=True)" 1119 | ] 1120 | }, 1121 | { 1122 | "cell_type": "code", 1123 | "execution_count": null, 1124 | "metadata": { 1125 | "id": 
"ylndSPY7lOxU" 1126 | }, 1127 | "outputs": [], 1128 | "source": [ 1129 | "from torch.autograd import Variable  # deprecated since PyTorch 0.4: Variable(t) just returns a Tensor, and torch.randn(..., requires_grad=True) works directly" 1130 | ] 1131 | }, 1132 | { 1133 | "cell_type": "code", 1134 | "execution_count": null, 1135 | "metadata": { 1136 | "colab": { 1137 | "base_uri": "https://localhost:8080/" 1138 | }, 1139 | "id": "0i4C-GmrlOxV", 1140 | "outputId": "18b240e6-455d-4541-ddef-d09b40beaab0" 1141 | }, 1142 | "outputs": [ 1143 | { 1144 | "output_type": "execute_result", 1145 | "data": { 1146 | "text/plain": [ 1147 | "tensor([[1., 1.],\n", 1148 | " [1., 1.]], requires_grad=True)" 1149 | ] 1150 | }, 1151 | "metadata": {}, 1152 | "execution_count": 43 1153 | } 1154 | ], 1155 | "source": [ 1156 | "Variable(torch.ones(2,2),requires_grad=True)" 1157 | ] 1158 | }, 1159 | { 1160 | "cell_type": "code", 1161 | "execution_count": null, 1162 | "metadata": { 1163 | "id": "odD1Jb9VlOxV" 1164 | }, 1165 | "outputs": [], 1166 | "source": [ 1167 | "a, b = 12, 23\n", 1168 | "x1 = Variable(torch.randn(a,b),\n", 1169 | " requires_grad=True)\n", 1170 | "x2 = Variable(torch.randn(a,b),\n", 1171 | " requires_grad=True)\n", 1172 | "x3 = Variable(torch.randn(a,b),\n", 1173 | " requires_grad=True)" 1174 | ] 1175 | }, 1176 | { 1177 | "cell_type": "code", 1178 | "execution_count": null, 1179 | "metadata": { 1180 | "colab": { 1181 | "base_uri": "https://localhost:8080/" 1182 | }, 1183 | "id": "_PfiDQPalOxV", 1184 | "outputId": "a262f932-411a-403d-fae6-3b9e7397f2e5" 1185 | }, 1186 | "outputs": [ 1187 | { 1188 | "output_type": "stream", 1189 | "name": "stdout", 1190 | "text": [ 1191 | "tensor(3278.1235, grad_fn=<SumBackward0>)\n" 1192 | ] 1193 | } 1194 | ], 1195 | "source": [ 1196 | "c = x1 * x2  # note: c is not used below, so e.backward() only populates x3.grad\n", 1197 | "d = a + x3\n", 1198 | "e = torch.sum(d)\n", 1199 | "\n", 1200 | "e.backward()\n", 1201 | "\n", 1202 | "print(e)" 1203 | ] 1204 | }, 1205 | { 1206 | "cell_type": "code", 1207 | "execution_count": null, 1208 | "metadata": { 1209 | "colab": { 1210 | "base_uri": "https://localhost:8080/" 1211 | }, 1212 | "id": "MY85sfEWlOxV",
1213 | "outputId": "3a021eb4-dd07-4420-eceb-79be4b49c71c" 1214 | }, 1215 | "outputs": [ 1216 | { 1217 | "output_type": "execute_result", 1218 | "data": { 1219 | "text/plain": [ 1220 | "tensor([[-4.9545e-02, 6.2245e-01, 1.6573e-01, 3.1583e-01, 2.4915e-01,\n", 1221 | " -4.9784e-01, 2.9079e+00, 1.6201e+00, -6.4459e-01, -1.9885e-02,\n", 1222 | " 1.6222e+00, 1.4239e+00, 9.0691e-01, 7.6310e-02, 1.1225e+00,\n", 1223 | " -1.2433e+00, -6.7258e-01, 8.8433e-01, -6.6589e-01, -7.3347e-01,\n", 1224 | " -2.7599e-01, 5.5485e-01, -1.9303e+00],\n", 1225 | " [-7.6389e-01, -9.9300e-01, 1.4080e+00, 2.5969e-01, 8.0760e-01,\n", 1226 | " -1.2618e+00, -7.7109e-01, -1.8497e+00, 2.2400e-01, 2.2088e-01,\n", 1227 | " -8.2452e-01, -2.5581e-02, -1.9850e+00, -3.7880e-01, -1.5030e+00,\n", 1228 | " -2.9808e+00, 5.1149e-01, -9.3890e-01, -1.9421e+00, 9.8052e-01,\n", 1229 | " -7.4463e-02, -7.1181e-01, 1.6136e+00],\n", 1230 | " [-1.8821e+00, -1.8542e+00, -2.5013e-01, 1.6023e-01, 2.0778e-01,\n", 1231 | " 1.1879e+00, -7.3204e-01, 1.3668e+00, 9.4616e-01, 6.1018e-01,\n", 1232 | " -1.1084e-01, -2.5730e-01, 6.3652e-01, -4.8517e-01, -7.1437e-01,\n", 1233 | " -6.0011e-02, 2.5377e-01, 8.2522e-02, 1.0270e+00, 7.0402e-02,\n", 1234 | " 4.2977e-01, 8.2052e-01, -1.4562e+00],\n", 1235 | " [ 3.0134e-01, -1.9370e+00, -2.9235e-03, -1.0772e+00, 2.5160e-01,\n", 1236 | " 6.6866e-02, 1.4607e+00, 7.7117e-01, 1.2076e-03, 4.2106e-01,\n", 1237 | " 1.0153e+00, -4.9146e-01, -1.1049e+00, 1.9696e-01, -1.4355e-01,\n", 1238 | " 1.4660e+00, -1.0422e-01, -1.3760e-01, 3.7740e-01, 8.5221e-01,\n", 1239 | " -4.0656e-01, -1.2096e+00, 3.6486e-02],\n", 1240 | " [ 8.2569e-01, -6.4759e-01, -1.8021e+00, 2.4911e+00, -2.1786e-01,\n", 1241 | " 7.6996e-01, -9.1563e-01, 7.1281e-01, 1.1037e+00, -7.1190e-01,\n", 1242 | " -2.2254e-01, 3.5406e-01, 5.5607e-01, -6.8683e-01, -9.9216e-01,\n", 1243 | " -2.4422e-01, -3.4747e-01, -4.0146e-01, 8.2387e-01, -1.0022e+00,\n", 1244 | " -1.0535e+00, -9.4035e-01, -2.9630e-01],\n", 1245 | " [-2.3561e+00, 1.1116e-02, 
-5.8723e-01, -6.1191e-01, 7.4930e-01,\n", 1246 | " -4.5617e-01, -1.7884e+00, -9.5469e-01, 3.8564e-01, 1.2921e+00,\n", 1247 | " 8.0012e-01, 8.6142e-02, -1.8761e+00, 1.0685e+00, -1.8311e-01,\n", 1248 | " -1.0058e+00, 3.7973e-01, -2.9853e-01, -1.3616e+00, -3.3329e-01,\n", 1249 | " -1.2527e+00, -3.2175e-01, -9.4914e-01],\n", 1250 | " [ 3.4618e-01, -1.2652e+00, -2.1208e+00, 1.5305e+00, -1.0777e+00,\n", 1251 | " -3.1854e-01, -8.6873e-01, -3.5102e-01, -1.5684e+00, 1.3443e+00,\n", 1252 | " 3.7181e-02, -3.5979e-01, 1.0322e-01, -1.6268e+00, -1.1469e-02,\n", 1253 | " -1.0545e+00, -2.1433e+00, -5.4195e-01, -4.2554e-01, 1.8125e+00,\n", 1254 | " -1.3636e+00, 3.0668e-01, -1.8680e+00],\n", 1255 | " [ 8.1188e-01, 9.2804e-01, 2.5353e+00, 7.4317e-01, -9.2664e-01,\n", 1256 | " 1.0827e+00, -1.0121e+00, -2.4838e-01, 3.5444e-01, -1.0163e-01,\n", 1257 | " -3.6335e-01, 1.1053e+00, -9.1091e-01, -6.5603e-01, -7.0487e-01,\n", 1258 | " 9.6358e-01, 8.3500e-01, -1.2857e+00, -2.1408e-01, 1.0373e-01,\n", 1259 | " -4.0387e-01, -4.9104e-01, -9.8583e-01],\n", 1260 | " [ 1.2580e+00, -1.8494e+00, -9.1556e-01, 1.0041e+00, 4.0680e-01,\n", 1261 | " -6.7118e-01, 4.5932e-01, -6.7476e-01, -3.6495e-01, 1.2697e+00,\n", 1262 | " -6.7312e-01, -5.7493e-01, 9.2411e-01, 7.2763e-01, 6.2405e-01,\n", 1263 | " 9.3639e-01, 1.2043e+00, 1.9976e-01, -8.2096e-01, -8.3537e-01,\n", 1264 | " -3.3319e-01, -1.3532e+00, -4.1707e-02],\n", 1265 | " [ 7.4626e-01, -1.3165e-01, -6.2212e-01, 1.3070e+00, 1.4547e+00,\n", 1266 | " -2.1038e-01, -3.4147e-01, -3.6501e-01, 4.9274e-01, 3.4149e-01,\n", 1267 | " 1.3207e+00, -1.0124e+00, -6.5716e-01, -5.7463e-01, -2.0788e+00,\n", 1268 | " -4.5005e-02, 1.3279e+00, 1.3064e+00, -2.1880e+00, -1.0477e+00,\n", 1269 | " 3.1200e-01, -8.2720e-01, -8.5936e-01],\n", 1270 | " [-3.9723e-01, 4.2088e-01, 5.3277e-01, -7.2912e-01, -1.1725e+00,\n", 1271 | " 9.1836e-02, -5.6562e-01, -1.1900e+00, -9.4340e-01, -8.1699e-01,\n", 1272 | " -6.7582e-01, -8.0012e-02, 3.1367e-01, -7.1102e-01, 6.3412e-02,\n", 1273 | " 
-3.0735e+00, -8.9330e-01, -2.5040e-01, -3.1167e-02, 1.4725e+00,\n", 1274 | " 1.2173e+00, 6.0895e-01, -1.1243e+00],\n", 1275 | " [ 7.6267e-01, 8.0859e-01, -9.1959e-01, 1.6713e+00, 7.7652e-01,\n", 1276 | " -6.4198e-01, 1.2667e+00, -2.4394e-01, -3.8565e-01, -9.5948e-02,\n", 1277 | " 9.4557e-01, -1.2708e+00, 2.9587e-01, 1.5027e-01, 1.1833e+00,\n", 1278 | " 5.7281e-01, -2.5252e-01, -7.3179e-02, -1.4961e-02, 1.4511e-01,\n", 1279 | " 2.0197e-01, 2.7220e-01, 2.9865e-01]])" 1280 | ] 1281 | }, 1282 | "metadata": {}, 1283 | "execution_count": 46 1284 | } 1285 | ], 1286 | "source": [ 1287 | "x1.data" 1288 | ] 1289 | }, 1290 | { 1291 | "cell_type": "code", 1292 | "execution_count": null, 1293 | "metadata": { 1294 | "colab": { 1295 | "base_uri": "https://localhost:8080/" 1296 | }, 1297 | "id": "Vnaj-a9clOxV", 1298 | "outputId": "69813e5b-13f6-4456-e847-3d206dee7b40" 1299 | }, 1300 | "outputs": [ 1301 | { 1302 | "output_type": "execute_result", 1303 | "data": { 1304 | "text/plain": [ 1305 | "tensor([[-7.5597e-01, -1.1689e+00, -9.3890e-01, 8.8566e-01, 1.3764e+00,\n", 1306 | " -7.8276e-01, 2.2200e-01, 7.3758e-02, -6.9147e-01, -5.1308e-01,\n", 1307 | " 1.1427e+00, -1.0126e+00, 1.1602e-01, -1.0350e+00, 1.0803e+00,\n", 1308 | " -7.9977e-01, -9.1219e-02, 5.0242e-01, -4.5173e-01, -4.8067e-01,\n", 1309 | " 5.9066e-01, 1.6343e-01, -3.1368e-02],\n", 1310 | " [ 4.4646e-01, -2.1036e+00, -1.8971e+00, 1.4661e+00, 6.1199e-01,\n", 1311 | " 7.9935e-01, 9.3007e-01, 1.2819e+00, -6.2896e-01, 3.0550e-01,\n", 1312 | " 5.8429e-01, -2.1502e+00, 4.1505e-01, 3.1579e-03, -3.8580e-01,\n", 1313 | " 5.7681e-01, -1.1972e+00, -2.4006e-01, -1.3253e+00, 1.1415e+00,\n", 1314 | " -4.4164e-01, 3.4923e-01, -1.6730e+00],\n", 1315 | " [ 1.7173e+00, -1.0184e+00, 1.1152e+00, -5.0580e-01, 1.7331e-01,\n", 1316 | " -6.8128e-01, 2.5642e-02, 1.2838e+00, 1.3107e+00, 1.4736e+00,\n", 1317 | " 1.3023e+00, -1.4379e+00, -1.7115e+00, -8.0051e-01, 1.4921e+00,\n", 1318 | " -1.0768e-01, 1.8001e+00, 9.7892e-01, -5.9166e-01, 
8.5103e-02,\n", 1319 | " 7.5238e-02, 8.3984e-01, 1.3866e+00],\n", 1320 | " [-9.5211e-01, -4.9072e-01, -1.4583e+00, 1.3491e+00, 1.5738e+00,\n", 1321 | " 1.2403e+00, 8.7063e-01, -7.4153e-01, 1.7150e+00, 5.9718e-01,\n", 1322 | " 1.1104e-01, 5.8702e-01, 1.7935e+00, -4.3749e-01, -4.4637e-01,\n", 1323 | " -7.3716e-01, 6.1647e-02, -2.5027e-01, -1.4518e-01, -1.9606e-01,\n", 1324 | " -5.0812e-01, -8.0779e-01, 1.6439e+00],\n", 1325 | " [-1.2366e+00, 5.9530e-02, -1.3995e-01, 6.6461e-02, -1.6116e+00,\n", 1326 | " 1.0756e+00, -5.2225e-02, 1.0433e+00, -1.7803e-01, 5.4113e-01,\n", 1327 | " 1.5247e+00, -2.2933e-02, -1.0489e+00, 5.9934e-01, -1.1722e+00,\n", 1328 | " -3.5785e-01, 2.7906e+00, -1.8163e-01, 5.1971e-01, -3.1983e-01,\n", 1329 | " 1.2022e+00, -4.8330e-01, -2.8758e-01],\n", 1330 | " [ 1.0301e+00, 7.2947e-01, 1.0306e-01, 2.7964e-01, 1.2181e+00,\n", 1331 | " 5.6054e-01, 8.6046e-01, -1.0222e+00, -1.1600e+00, 7.2069e-01,\n", 1332 | " 1.4963e-01, -8.3145e-01, -1.9167e+00, -1.8408e+00, 9.6285e-01,\n", 1333 | " 2.2254e-02, -1.4754e+00, 4.3352e-01, -2.3785e-01, 3.4293e-01,\n", 1334 | " -2.3405e+00, 3.1645e-01, 1.4717e+00],\n", 1335 | " [-3.9061e-01, 9.7402e-02, 1.0617e+00, 1.2896e+00, 7.8817e-01,\n", 1336 | " -1.6625e+00, -1.1102e+00, 7.6015e-03, -1.0987e+00, -5.3510e-01,\n", 1337 | " -2.0040e-03, 1.0430e+00, -2.6981e+00, -5.0815e-01, -7.9794e-01,\n", 1338 | " 9.4735e-01, 5.9186e-01, -4.8238e-01, 8.9618e-02, 5.0087e-01,\n", 1339 | " 1.3167e-01, 1.2114e-01, -1.1848e+00],\n", 1340 | " [-1.4867e+00, -1.1119e+00, -5.5689e-01, -1.9381e-01, 6.6268e-01,\n", 1341 | " -3.6233e-01, 9.7089e-01, 1.5987e+00, 4.4115e-01, 1.3639e+00,\n", 1342 | " 3.1804e-01, -1.6259e-01, -4.1701e-02, 2.7475e-01, -6.1249e-01,\n", 1343 | " 2.1635e+00, -6.9394e-01, -6.7843e-01, -1.4126e+00, -6.1121e-01,\n", 1344 | " -2.7585e-01, -4.9035e-01, 2.5709e-01],\n", 1345 | " [ 1.1294e+00, 1.4856e+00, -6.7626e-01, 9.3721e-02, -6.6421e-02,\n", 1346 | " -1.0708e+00, -6.1524e-01, -4.2784e-01, 7.8178e-01, -6.3566e-01,\n", 1347 
| " -3.9345e-01, 8.0382e-01, 5.0099e-01, -6.9284e-01, -1.0052e-01,\n", 1348 | " -8.0754e-01, 5.8784e-01, 1.2261e+00, 1.1008e+00, -3.8728e-01,\n", 1349 | " 3.2881e-01, 6.4573e-01, -2.1029e-01],\n", 1350 | " [-9.9325e-01, -8.6854e-01, 7.7369e-02, 1.1841e+00, 5.0544e-01,\n", 1351 | " -5.6922e-01, 1.1186e+00, 1.0655e+00, 5.1852e-01, -2.0170e+00,\n", 1352 | " -1.3915e+00, 1.3255e+00, 4.3642e-01, -1.7325e+00, 4.2292e-01,\n", 1353 | " 2.0095e-01, -6.1319e-01, 1.4466e+00, -1.2209e+00, 4.0064e-01,\n", 1354 | " 3.0691e-01, 5.4840e-01, 2.1183e+00],\n", 1355 | " [-7.8591e-01, -6.0411e-01, 2.4474e-01, -2.3646e-01, 1.1019e+00,\n", 1356 | " -1.5382e+00, 8.6998e-02, -6.7848e-01, 7.1260e-01, 1.1081e+00,\n", 1357 | " 7.9302e-01, 3.6504e-01, -7.8480e-01, -3.5138e-01, 5.5498e-01,\n", 1358 | " -2.7474e-01, 4.1603e-01, -1.3930e+00, -8.6290e-01, -2.4129e-01,\n", 1359 | " -1.2446e+00, -2.1552e+00, 6.1390e-01],\n", 1360 | " [ 8.6510e-01, -3.1663e-02, -6.4619e-02, -3.0712e-01, 2.0007e-01,\n", 1361 | " 4.0111e-01, -1.2684e+00, 3.3103e-01, 6.2498e-01, -7.7015e-01,\n", 1362 | " 1.0212e+00, 7.1803e-01, -1.0951e-01, -4.4673e-02, 3.4933e-01,\n", 1363 | " -1.4548e-01, 9.7760e-01, 3.8946e-01, -6.9499e-01, -8.3517e-02,\n", 1364 | " -5.6924e-02, 1.1568e+00, -1.2049e+00]])" 1365 | ] 1366 | }, 1367 | "metadata": {}, 1368 | "execution_count": 47 1369 | } 1370 | ], 1371 | "source": [ 1372 | "x2.data" 1373 | ] 1374 | }, 1375 | { 1376 | "cell_type": "code", 1377 | "execution_count": null, 1378 | "metadata": { 1379 | "colab": { 1380 | "base_uri": "https://localhost:8080/" 1381 | }, 1382 | "id": "Dq1rnJpYlOxW", 1383 | "outputId": "83899120-cc80-4b9a-b145-6b66c4134d1b" 1384 | }, 1385 | "outputs": [ 1386 | { 1387 | "output_type": "execute_result", 1388 | "data": { 1389 | "text/plain": [ 1390 | "tensor([[ 0.2499, 0.2458, 0.1029, -0.6494, -0.3258, 0.8149, 0.4049, 0.2481,\n", 1391 | " 0.4841, 0.3293, -1.2471, 0.2117, 1.4315, 0.0502, -0.3668, 0.8378,\n", 1392 | " -0.7901, 0.0267, -0.3120, 2.4534, 0.7926, 0.2382, 
-0.5245],\n", 1393 | " [-0.2131, -2.0323, -0.3952, 0.7286, 1.3579, 1.3583, -0.5818, -0.3204,\n", 1394 | " -0.7242, -1.4629, 0.6109, 0.2401, -0.5455, 2.9154, -0.1985, 0.7507,\n", 1395 | " -0.0390, -0.8322, 0.1364, 0.4531, -0.1112, 1.0307, -1.1862],\n", 1396 | " [-0.9696, 0.1508, 0.9814, -1.9638, -0.2218, 0.1477, -0.1875, 0.9963,\n", 1397 | " 0.3611, 0.0238, 1.7371, -0.7392, -0.1709, 1.2120, 0.9385, 0.3245,\n", 1398 | " -1.7054, 0.7467, -1.4787, -0.2400, 0.4558, -1.0911, 0.1566],\n", 1399 | " [-1.0629, 0.7770, 0.4003, -0.7194, 0.3499, -0.3935, -0.3622, -3.4058,\n", 1400 | " -0.1108, -0.7741, 0.9103, -0.0961, -0.6697, -0.1985, -1.1560, -0.7709,\n", 1401 | " -0.4655, -0.3847, -0.4364, -0.7429, -0.5210, -0.8078, -0.4448],\n", 1402 | " [ 1.4265, 0.2566, 0.4157, 1.7485, -0.5296, -0.7132, -0.1676, -0.6071,\n", 1403 | " 0.0347, -0.9721, -1.6722, 1.7052, 0.3938, -0.7728, -0.6024, -0.3242,\n", 1404 | " -0.6425, -2.3893, -2.6920, -0.2646, 0.4027, -1.0844, 0.1360],\n", 1405 | " [ 2.1896, -0.2159, -0.4342, -0.2216, -0.7294, -0.1235, 0.6418, -2.1429,\n", 1406 | " -0.0135, 0.9174, -1.8431, 0.6344, 0.8557, -1.1231, -0.6695, 0.0306,\n", 1407 | " -2.3209, 1.9433, -1.7311, -1.2552, 0.4064, 1.6822, 1.3038],\n", 1408 | " [-0.3663, -1.2401, -0.2566, -0.0977, 0.2273, -1.5666, -0.7700, -0.1483,\n", 1409 | " -0.8830, -1.5841, -0.3711, -1.2338, -0.6142, 0.2801, 0.0178, 0.3047,\n", 1410 | " 0.8214, -2.6705, 0.3366, -0.4533, 0.0461, -0.2598, 0.7919],\n", 1411 | " [-0.8666, -0.3528, -1.5244, -0.4039, -1.4359, 0.7409, 1.1551, -1.2673,\n", 1412 | " -0.1306, -0.3743, -1.1143, -1.0104, 0.5946, 0.3441, 1.5870, -0.4498,\n", 1413 | " 0.0428, 0.2114, 0.3053, 0.1590, 0.6634, -0.2352, 0.4823],\n", 1414 | " [ 1.2331, 0.0465, -0.0451, 0.6444, 1.6787, 0.8955, 0.5074, -0.3619,\n", 1415 | " 1.4870, -1.2376, -0.0362, -0.2950, -2.5790, -1.0069, 1.8991, 0.0619,\n", 1416 | " 0.2949, 0.0973, -0.1098, 1.5281, -0.4748, -0.5837, -0.2411],\n", 1417 | " [ 0.9073, -0.2173, 0.4118, -1.8131, -0.7694, -0.7774, -0.4924, 
-0.7196,\n", 1418 | " -0.1061, -0.2974, -0.6724, 0.8060, 0.1783, -0.8092, 0.2455, -0.1944,\n", 1419 | " -0.3686, -0.2207, 1.3099, -1.3060, -1.4669, -0.3317, -0.4190],\n", 1420 | " [-1.4119, -2.5146, 2.0673, -1.0484, 0.4922, 1.7666, 0.0404, 0.4461,\n", 1421 | " -1.6570, -0.6111, -0.7872, 0.7080, -0.4652, 0.5982, 0.2778, 1.7961,\n", 1422 | " 0.7223, 0.5924, -0.8698, 0.4279, 1.7435, -0.2529, -1.1322],\n", 1423 | " [-0.5567, 0.1003, 0.1121, 1.2011, -0.7812, -0.6869, -0.1129, -0.2705,\n", 1424 | " -3.0171, 0.0981, -0.1275, 0.1092, -0.2081, -1.7573, 0.3798, -0.5517,\n", 1425 | " -0.1958, -0.7510, 1.3651, 0.3416, 1.2432, 0.8636, -0.2051]])" 1426 | ] 1427 | }, 1428 | "metadata": {}, 1429 | "execution_count": 48 1430 | } 1431 | ], 1432 | "source": [ 1433 | "x3.data" 1434 | ] 1435 | }, 1436 | { 1437 | "cell_type": "code", 1438 | "execution_count": null, 1439 | "metadata": { 1440 | "colab": { 1441 | "base_uri": "https://localhost:8080/" 1442 | }, 1443 | "id": "WU_q483FlOxW", 1444 | "outputId": "3dd2f92b-ba59-4d00-fea6-2c2c18fd6297" 1445 | }, 1446 | "outputs": [ 1447 | { 1448 | "output_type": "stream", 1449 | "name": "stdout", 1450 | "text": [ 1451 | "Gradient of Loss w.r.t. w1: -455.0\n", 1452 | "Gradient of Loss w.r.t. w2: -365.0\n", 1453 | "Gradient of Loss w.r.t. w3: -60.0\n", 1454 | "Gradient of Loss w.r.t. w4: -265.0\n" 1455 | ] 1456 | } 1457 | ], 1458 | "source": [ 1459 | "from torch import FloatTensor\n", 1460 | "from torch.autograd import Variable\n", 1461 | "\n", 1462 | "a = Variable(FloatTensor([5]))\n", 1463 | "\n", 1464 | "weights = [Variable(FloatTensor([i]), requires_grad=True) for i in (12, 53, 91, 73)]\n", 1465 | "\n", 1466 | "w1, w2, w3, w4 = weights\n", 1467 | "\n", 1468 | "b = w1 * a\n", 1469 | "c = w2 * a\n", 1470 | "d = w3 * b + w4 * c\n", 1471 | "Loss = (10 - d)  # = 10 - a*(w1*w3 + w2*w4), so dLoss/dw1 = -a*w3 = -455 and dLoss/dw3 = -a*w1 = -60\n", 1472 | "\n", 1473 | "Loss.backward()\n", 1474 | "\n", 1475 | "for index, weight in enumerate(weights, start=1):\n", 1476 | " gradient = weight.grad.item()\n", 1477 |
" print(f\"Gradient of Loss w.r.t. w{index}: {gradient}\")" 1478 | ] 1479 | }, 1491 | { 1492 | "cell_type": "code", 1493 | "execution_count": null, 1494 | "metadata": { 1495 | "id": "YgkJRNsAlOxW" 1496 | }, 1497 | "outputs": [], 1498 | "source": [ 1499 | "# Forward pass: a linear model y_pred = x * w (w is created in the next cell)\n", 1500 | "def forward(x):\n", 1501 | " return x * w" 1502 | ] 1503 | }, 1504 | { 1505 | "cell_type": "code", 1506 | "execution_count": null, 1507 | "metadata": { 1508 | "colab": { 1509 | "base_uri": "https://localhost:8080/" 1510 | }, 1511 | "id": "_vKnWJfRlOxW", 1512 | "outputId": "567b4f44-07ba-4ad4-930f-ceebbb2e3389" 1513 | }, 1514 | "outputs": [ 1515 | { 1516 | "output_type": "stream", 1517 | "name": "stdout", 1518 | "text": [ 1519 | "predict (before training) 4 tensor(4.)\n" 1520 | ] 1521 | } 1522 | ], 1523 | "source": [ 1524 | "import torch\n", 1525 | "from torch.autograd import Variable\n", 1526 | "\n", 1527 | "x_data = [11.0, 22.0, 33.0]\n", 1528 | "y_data = [21.0, 14.0, 64.0]\n", 1529 | "\n", 1530 | "w = Variable(torch.Tensor([1.0]), requires_grad=True) # arbitrary initial weight\n", 1531 | "\n", 1532 | "# Before training\n", 1533 | "print(\"predict (before training)\", 4, forward(4).data[0])" 1534 | ] 1535 | }, 1536 | { 1537 | "cell_type": "code", 1538 | "execution_count": null, 1539 | "metadata": { 1540 | "id": "gR-GPl9alOxX" 1541 | }, 1542 | "outputs": [], 1543 | "source": [ 1544 | "# Squared-error loss for a single (x, y) pair\n", 1545 | "def loss(x, y):\n", 1546 | " y_pred = forward(x)\n", 1547 | " return (y_pred - y) * (y_pred - y)" 1548 | ] 1549 | }, 1550 | { 1551 | "cell_type": "code", 1552 | "execution_count": null, 1553 | "metadata": { 1554 | "colab": { 1555 | "base_uri": "https://localhost:8080/" 1556 | }, 1557 | "id": "5C0AjcmXlOxX", 1558 | "outputId": "6dea7b8b-8912-47f3-88e4-9ed5b3cd10ec" 1559 | },
1560 | "outputs": [ 1561 | { 1562 | "output_type": "stream", 1563 | "name": "stdout", 1564 | "text": [ 1565 | "\tgrad: 11.0 21.0 tensor(-220.)\n", 1566 | "\tgrad: 22.0 14.0 tensor(2481.6001)\n", 1567 | "\tgrad: 33.0 64.0 tensor(-51303.6484)\n", 1568 | "progress: 0 tensor(604238.8125)\n", 1569 | "\tgrad: 11.0 21.0 tensor(118461.7578)\n", 1570 | "\tgrad: 22.0 14.0 tensor(-671630.6875)\n", 1571 | "\tgrad: 33.0 64.0 tensor(13114108.)\n", 1572 | "progress: 1 tensor(3.9481e+10)\n", 1573 | "\tgrad: 11.0 21.0 tensor(-30279010.)\n", 1574 | "\tgrad: 22.0 14.0 tensor(1.7199e+08)\n", 1575 | "\tgrad: 33.0 64.0 tensor(-3.3589e+09)\n", 1576 | "progress: 2 tensor(2.5900e+15)\n", 1577 | "\tgrad: 11.0 21.0 tensor(7.7553e+09)\n", 1578 | "\tgrad: 22.0 14.0 tensor(-4.4050e+10)\n", 1579 | "\tgrad: 33.0 64.0 tensor(8.6030e+11)\n", 1580 | "progress: 3 tensor(1.6991e+20)\n", 1581 | "\tgrad: 11.0 21.0 tensor(-1.9863e+12)\n", 1582 | "\tgrad: 22.0 14.0 tensor(1.1282e+13)\n", 1583 | "\tgrad: 33.0 64.0 tensor(-2.2034e+14)\n", 1584 | "progress: 4 tensor(1.1146e+25)\n", 1585 | "\tgrad: 11.0 21.0 tensor(5.0875e+14)\n", 1586 | "\tgrad: 22.0 14.0 tensor(-2.8897e+15)\n", 1587 | "\tgrad: 33.0 64.0 tensor(5.6436e+16)\n", 1588 | "progress: 5 tensor(7.3118e+29)\n", 1589 | "\tgrad: 11.0 21.0 tensor(-1.3030e+17)\n", 1590 | "\tgrad: 22.0 14.0 tensor(7.4013e+17)\n", 1591 | "\tgrad: 33.0 64.0 tensor(-1.4455e+19)\n", 1592 | "progress: 6 tensor(4.7966e+34)\n", 1593 | "\tgrad: 11.0 21.0 tensor(3.3374e+19)\n", 1594 | "\tgrad: 22.0 14.0 tensor(-1.8957e+20)\n", 1595 | "\tgrad: 33.0 64.0 tensor(3.7022e+21)\n", 1596 | "progress: 7 tensor(inf)\n", 1597 | "\tgrad: 11.0 21.0 tensor(-8.5480e+21)\n", 1598 | "\tgrad: 22.0 14.0 tensor(4.8553e+22)\n", 1599 | "\tgrad: 33.0 64.0 tensor(-9.4824e+23)\n", 1600 | "progress: 8 tensor(inf)\n", 1601 | "\tgrad: 11.0 21.0 tensor(2.1894e+24)\n", 1602 | "\tgrad: 22.0 14.0 tensor(-1.2436e+25)\n", 1603 | "\tgrad: 33.0 64.0 tensor(2.4287e+26)\n", 1604 | "progress: 9 tensor(inf)\n" 1605 | ] 
1606 | } 1607 | ], 1608 | "source": [ 1609 | "# Run the Training loop\n", 1610 | "for epoch in range(10):\n", 1611 | " for x_val, y_val in zip(x_data, y_data):\n", 1612 | " l = loss(x_val, y_val)\n", 1613 | " l.backward()\n", 1614 | " print(\"\\tgrad: \", x_val, y_val, w.grad.data[0])\n", 1615 | " w.data = w.data - 0.01 * w.grad.data  # each per-sample step is stable only if lr < 1/x**2 (about 9e-4 for x = 33); with lr = 0.01 the loss diverges to inf, as the output shows\n", 1616 | "\n", 1617 | " # Manually set the gradients to zero after updating weights\n", 1618 | " w.grad.data.zero_()\n", 1619 | "\n", 1620 | " print(\"progress:\", epoch, l.data[0])" 1621 | ] 1622 | }, 1623 | { 1624 | "cell_type": "code", 1625 | "execution_count": null, 1626 | "metadata": { 1627 | "colab": { 1628 | "base_uri": "https://localhost:8080/" 1629 | }, 1630 | "id": "Z6Q0Lq1TlOxX", 1631 | "outputId": "f559d811-e055-4798-daa6-4c44fd8a241e" 1632 | }, 1633 | "outputs": [ 1634 | { 1635 | "output_type": "stream", 1636 | "name": "stdout", 1637 | "text": [ 1638 | "predict (after training) 4 tensor(-9.2687e+24)\n" 1639 | ] 1640 | } 1641 | ], 1642 | "source": [ 1643 | "# After training\n", 1644 | "print(\"predict (after training)\", 4, forward(4).data[0])" 1645 | ] 1646 | }, 1647 | { 1648 | "cell_type": "code", 1649 | "execution_count": null, 1650 | "metadata": { 1651 | "colab": { 1652 | "base_uri": "https://localhost:8080/" 1653 | }, 1654 | "id": "M2dEQtSzlOxX", 1655 | "outputId": "42ecd0bf-e857-47d6-be31-faf4ae7a07f7" 1656 | }, 1657 | "outputs": [ 1658 | { 1659 | "output_type": "stream", 1660 | "name": "stdout", 1661 | "text": [ 1662 | "tensor([[-0.3071, -3.6691, -2.8417, -1.1818],\n", 1663 | " [-1.4654, -0.4344, -2.0130, -2.3842],\n", 1664 | " [ 1.3962, 1.4962, -2.0996, 1.8881],\n", 1665 | " [-1.9797, 0.2337, -1.0308, 0.1266]])\n" 1666 | ] 1667 | } 1668 | ], 1669 | "source": [ 1670 | "z = Variable(torch.Tensor(4, 4).uniform_(-5, 5))\n", 1671 | "print(z)" 1672 | ] 1673 | }, 1674 | { 1675 | "cell_type": "code", 1676 | "execution_count": null, 1677 | "metadata": { 1678 | "colab": { 1679 | "base_uri": "https://localhost:8080/" 1680 |
}, 1681 | "id": "E1pLWZk8lOxY", 1682 | "outputId": "bd37d1c8-1c98-4c94-ec86-7fb84c235fcc" 1683 | }, 1684 | "outputs": [ 1685 | { 1686 | "output_type": "stream", 1687 | "name": "stdout", 1688 | "text": [ 1689 | "Requires Gradient : False \n", 1691 | "Gradient : None \n", 1692 | "tensor([[-0.3071, -3.6691, -2.8417, -1.1818],\n", 1693 | " [-1.4654, -0.4344, -2.0130, -2.3842],\n", 1694 | " [ 1.3962, 1.4962, -2.0996, 1.8881],\n", 1695 | " [-1.9797, 0.2337, -1.0308, 0.1266]])\n" 1696 | ] 1697 | } 1706 | ], 1707 | "source": [ 1708 | "print('Requires Gradient : %s ' % (z.requires_grad))\n", 1709 | "# Variable.volatile was removed in PyTorch 0.4 (it always returns False and only emits a warning); use torch.no_grad() instead\n", 1710 | "print('Gradient : %s ' % (z.grad))\n", 1711 | "print(z.data)" 1712 | ] 1713 | }, 1714 | { 1715 | "cell_type": "code", 1716 | "execution_count": null, 1717 | "metadata": { 1718 | "colab": { 1719 | "base_uri": "https://localhost:8080/" 1720 | }, 1721 | "id": "JSrWTkLSlOxY", 1722 | "outputId": "e66d5f18-f574-4655-f4e2-e5cb6f4812c2" 1723 | }, 1724 | "outputs": [ 1725 | { 1726 | "output_type": "stream", 1727 | "name": "stdout", 1728 | "text": [ 1729 | "torch.Size([4, 4])\n" 1730 | ] 1731 | } 1732 | ], 1733 | "source": [ 1734 | "x = Variable(torch.Tensor(4, 4).uniform_(-4, 5))\n", 1735 | "y = Variable(torch.Tensor(4, 4).uniform_(-3, 2))\n", 1736 | "# matrix multiplication\n", 1737 | "z = torch.mm(x, y)\n", 1738 | "print(z.size())" 1739 | ] 1740 | }, 1741 | { 1742 | "cell_type": "code", 1743 | "execution_count": null, 1744 | "metadata": { 1745 | "colab": { 1746 | "base_uri": "https://localhost:8080/" 1747 | }, 1748 | "id": "THb7ZbbklOxY", 1749 | "outputId": "9cd93d2f-ac65-4e85-abe2-514ebb9455e8" 1750 | }, 1751 | "outputs": [ 1752 | { 1753 |
"output_type": "execute_result", 1754 | "data": { 1755 | "text/plain": [ 1756 | "tensor([[ 4.1714, -1.3202, 4.5641, 0.2945],\n", 1757 | " [-2.5585, 2.0775, 1.0064, 4.9148],\n", 1758 | " [ 3.9464, -3.6807, 4.3685, 4.6255],\n", 1759 | " [ 2.8453, 0.2704, 4.8619, 0.7903]])" 1760 | ] 1761 | }, 1762 | "metadata": {}, 1763 | "execution_count": 58 1764 | } 1765 | ], 1766 | "source": [ 1767 | "x.data" 1768 | ] 1769 | }, 1770 | { 1771 | "cell_type": "code", 1772 | "execution_count": null, 1773 | "metadata": { 1774 | "id": "esesJDtllOxY" 1775 | }, 1776 | "outputs": [], 1777 | "source": [ 1778 | "#tensor operations" 1779 | ] 1780 | }, 1781 | { 1782 | "cell_type": "code", 1783 | "execution_count": null, 1784 | "metadata": { 1785 | "colab": { 1786 | "base_uri": "https://localhost:8080/" 1787 | }, 1788 | "id": "f66PpVEvlOxZ", 1789 | "outputId": "bab93004-72a5-4200-abbe-9956e513a631" 1790 | }, 1791 | "outputs": [ 1792 | { 1793 | "output_type": "execute_result", 1794 | "data": { 1795 | "text/plain": [ 1796 | "tensor([[0.9002, 0.9188, 0.1386, 0.3701],\n", 1797 | " [0.1947, 0.2268, 0.9587, 0.2615],\n", 1798 | " [0.7256, 0.7673, 0.5667, 0.1863],\n", 1799 | " [0.4642, 0.4016, 0.9981, 0.8452]])" 1800 | ] 1801 | }, 1802 | "metadata": {}, 1803 | "execution_count": 60 1804 | } 1805 | ], 1806 | "source": [ 1807 | "mat1 = torch.FloatTensor(4,4).uniform_(0,1)\n", 1808 | "mat1" 1809 | ] 1810 | }, 1811 | { 1812 | "cell_type": "code", 1813 | "execution_count": null, 1814 | "metadata": { 1815 | "colab": { 1816 | "base_uri": "https://localhost:8080/" 1817 | }, 1818 | "id": "omnV7gP6lOxZ", 1819 | "outputId": "45ce18f9-ab70-42e7-964b-91acfdc090b5" 1820 | }, 1821 | "outputs": [ 1822 | { 1823 | "output_type": "execute_result", 1824 | "data": { 1825 | "text/plain": [ 1826 | "tensor([[0.4962, 0.4947, 0.8344, 0.6721],\n", 1827 | " [0.1182, 0.5997, 0.8990, 0.8252],\n", 1828 | " [0.1466, 0.1093, 0.8135, 0.9047],\n", 1829 | " [0.2486, 0.1873, 0.6159, 0.2471]])" 1830 | ] 1831 | }, 1832 | "metadata": {}, 
1833 | "execution_count": 72 1834 | } 1835 | ], 1836 | "source": [ 1837 | "mat2 = torch.FloatTensor(4,4).uniform_(0,1)\n", 1838 | "mat2" 1839 | ] 1840 | }, 1841 | { 1842 | "cell_type": "code", 1843 | "execution_count": null, 1844 | "metadata": { 1845 | "colab": { 1846 | "base_uri": "https://localhost:8080/" 1847 | }, 1848 | "id": "kvQexmLLlOxZ", 1849 | "outputId": "040e431f-678e-4e61-b0ab-54610f6b0df2" 1850 | }, 1851 | "outputs": [ 1852 | { 1853 | "output_type": "execute_result", 1854 | "data": { 1855 | "text/plain": [ 1856 | "tensor([0.7582, 0.6879, 0.8949, 0.3995])" 1857 | ] 1858 | }, 1859 | "metadata": {}, 1860 | "execution_count": 73 1861 | } 1862 | ], 1863 | "source": [ 1864 | "vec1 = torch.FloatTensor(4).uniform_(0,1)\n", 1865 | "vec1" 1866 | ] 1867 | }, 1868 | { 1869 | "cell_type": "code", 1870 | "execution_count": null, 1871 | "metadata": { 1872 | "id": "eIdac4aelOxZ" 1873 | }, 1874 | "outputs": [], 1875 | "source": [ 1876 | "# scalar addition" 1877 | ] 1878 | }, 1879 | { 1880 | "cell_type": "code", 1881 | "execution_count": null, 1882 | "metadata": { 1883 | "colab": { 1884 | "base_uri": "https://localhost:8080/" 1885 | }, 1886 | "id": "YwAXVw-0lOxZ", 1887 | "outputId": "46032e73-88d2-4e04-f856-229bc59c7c7b" 1888 | }, 1889 | "outputs": [ 1890 | { 1891 | "output_type": "execute_result", 1892 | "data": { 1893 | "text/plain": [ 1894 | "tensor([[11.4002, 11.4188, 10.6386, 10.8701],\n", 1895 | " [10.6947, 10.7268, 11.4587, 10.7615],\n", 1896 | " [11.2256, 11.2673, 11.0667, 10.6863],\n", 1897 | " [10.9642, 10.9016, 11.4981, 11.3452]])" 1898 | ] 1899 | }, 1900 | "metadata": {}, 1901 | "execution_count": 75 1902 | } 1903 | ], 1904 | "source": [ 1905 | "mat1 + 10.5" 1906 | ] 1907 | }, 1908 | { 1909 | "cell_type": "code", 1910 | "execution_count": null, 1911 | "metadata": { 1912 | "id": "ksx-522ylOxZ" 1913 | }, 1914 | "outputs": [], 1915 | "source": [ 1916 | "# scalar subtraction" 1917 | ] 1918 | }, 1919 | { 1920 | "cell_type": "code", 1921 | "execution_count": null, 
1922 | "metadata": { 1923 | "colab": { 1924 | "base_uri": "https://localhost:8080/" 1925 | }, 1926 | "id": "lPQfHCJAlOxa", 1927 | "outputId": "e7e30d31-1028-4fd2-e89c-8451b75e21c2" 1928 | }, 1929 | "outputs": [ 1930 | { 1931 | "output_type": "execute_result", 1932 | "data": { 1933 | "text/plain": [ 1934 | "tensor([[ 0.2962, 0.2947, 0.6344, 0.4721],\n", 1935 | " [-0.0818, 0.3997, 0.6990, 0.6252],\n", 1936 | " [-0.0534, -0.0907, 0.6135, 0.7047],\n", 1937 | " [ 0.0486, -0.0127, 0.4159, 0.0471]])" 1938 | ] 1939 | }, 1940 | "metadata": {}, 1941 | "execution_count": 77 1942 | } 1943 | ], 1944 | "source": [ 1945 | "mat2 - 0.20" 1946 | ] 1947 | }, 1948 | { 1949 | "cell_type": "code", 1950 | "execution_count": null, 1951 | "metadata": { 1952 | "id": "uBMORUwalOxa" 1953 | }, 1954 | "outputs": [], 1955 | "source": [ 1956 | "# vector and matrix addition" 1957 | ] 1958 | }, 1959 | { 1960 | "cell_type": "code", 1961 | "execution_count": null, 1962 | "metadata": { 1963 | "colab": { 1964 | "base_uri": "https://localhost:8080/" 1965 | }, 1966 | "id": "e28tx8ailOxa", 1967 | "outputId": "3cd9c421-30c5-47a5-cbd1-b4bbf4a9b845" 1968 | }, 1969 | "outputs": [ 1970 | { 1971 | "output_type": "execute_result", 1972 | "data": { 1973 | "text/plain": [ 1974 | "tensor([[1.6584, 1.6067, 1.0335, 0.7695],\n", 1975 | " [0.9530, 0.9147, 1.8537, 0.6610],\n", 1976 | " [1.4839, 1.4553, 1.4616, 0.5858],\n", 1977 | " [1.2224, 1.0895, 1.8931, 1.2446]])" 1978 | ] 1979 | }, 1980 | "metadata": {}, 1981 | "execution_count": 79 1982 | } 1983 | ], 1984 | "source": [ 1985 | "mat1 + vec1" 1986 | ] 1987 | }, 1988 | { 1989 | "cell_type": "code", 1990 | "execution_count": null, 1991 | "metadata": { 1992 | "colab": { 1993 | "base_uri": "https://localhost:8080/" 1994 | }, 1995 | "id": "LG9e5d2NlOxa", 1996 | "outputId": "98e97de3-0db2-40ed-c7ca-4d272be97fed" 1997 | }, 1998 | "outputs": [ 1999 | { 2000 | "output_type": "execute_result", 2001 | "data": { 2002 | "text/plain": [ 2003 | "tensor([[1.2544, 1.1826, 1.7293, 
1.0716],\n", 2004 | " [0.8764, 1.2876, 1.7939, 1.2247],\n", 2005 | " [0.9049, 0.7972, 1.7084, 1.3042],\n", 2006 | " [1.0068, 0.8752, 1.5108, 0.6466]])" 2007 | ] 2008 | }, 2009 | "metadata": {}, 2010 | "execution_count": 80 2011 | } 2012 | ], 2013 | "source": [ 2014 | "mat2 + vec1" 2015 | ] 2016 | }, 2017 | { 2018 | "cell_type": "code", 2019 | "execution_count": null, 2020 | "metadata": { 2021 | "id": "WKVWHJpklOxa" 2022 | }, 2023 | "outputs": [], 2024 | "source": [ 2025 | "# matrix-matrix addition" 2026 | ] 2027 | }, 2028 | { 2029 | "cell_type": "code", 2030 | "execution_count": null, 2031 | "metadata": { 2032 | "colab": { 2033 | "base_uri": "https://localhost:8080/" 2034 | }, 2035 | "id": "23pMreZclOxa", 2036 | "outputId": "7d1f2523-f173-4120-df1c-aa0c02837449" 2037 | }, 2038 | "outputs": [ 2039 | { 2040 | "output_type": "execute_result", 2041 | "data": { 2042 | "text/plain": [ 2043 | "tensor([[1.3963, 1.4135, 0.9730, 1.0422],\n", 2044 | " [0.3129, 0.8265, 1.8577, 1.0867],\n", 2045 | " [0.8722, 0.8766, 1.3802, 1.0910],\n", 2046 | " [0.7127, 0.5888, 1.6141, 1.0923]])" 2047 | ] 2048 | }, 2049 | "metadata": {}, 2050 | "execution_count": 82 2051 | } 2052 | ], 2053 | "source": [ 2054 | "mat1 + mat2" 2055 | ] 2056 | }, 2057 | { 2058 | "cell_type": "code", 2059 | "execution_count": null, 2060 | "metadata": { 2061 | "colab": { 2062 | "base_uri": "https://localhost:8080/" 2063 | }, 2064 | "id": "UaFGpnDRlOxb", 2065 | "outputId": "97068dad-8a0a-4e3a-8f2a-ea96b485f413" 2066 | }, 2067 | "outputs": [ 2068 | { 2069 | "output_type": "execute_result", 2070 | "data": { 2071 | "text/plain": [ 2072 | "tensor([[0.8103, 0.8442, 0.0192, 0.1370],\n", 2073 | " [0.0379, 0.0514, 0.9192, 0.0684],\n", 2074 | " [0.5265, 0.5888, 0.3211, 0.0347],\n", 2075 | " [0.2155, 0.1613, 0.9963, 0.7143]])" 2076 | ] 2077 | }, 2078 | "metadata": {}, 2079 | "execution_count": 83 2080 | } 2081 | ], 2082 | "source": [ 2083 | "mat1 * mat1  # element-wise (Hadamard) product, not matrix multiplication" 2084 | ] 2085 | }, 2086 | { 2087 | "cell_type": "code", 2088 |
"execution_count": null, 2089 | "metadata": { 2090 | "id": "iB9qGUzjlOxb" 2091 | }, 2092 | "outputs": [], 2093 | "source": [ 2094 | "# Bernoulli distribution" 2095 | ] 2096 | }, 2097 | { 2098 | "cell_type": "code", 2099 | "execution_count": null, 2100 | "metadata": { 2101 | "id": "Le4cbcF-lOxb" 2102 | }, 2103 | "outputs": [], 2104 | "source": [ 2105 | "from torch.distributions.bernoulli import Bernoulli" 2106 | ] 2107 | }, 2108 | { 2109 | "cell_type": "code", 2110 | "execution_count": null, 2111 | "metadata": { 2112 | "id": "Z4Z_VOeilOxb" 2113 | }, 2114 | "outputs": [], 2115 | "source": [ 2116 | "dist = Bernoulli(torch.tensor([0.3,0.6,0.9]))" 2117 | ] 2118 | }, 2119 | { 2120 | "cell_type": "code", 2121 | "execution_count": null, 2122 | "metadata": { 2123 | "colab": { 2124 | "base_uri": "https://localhost:8080/" 2125 | }, 2126 | "id": "rC-Czw1dlOxb", 2127 | "outputId": "6a03e406-86aa-48fa-a697-e6fbcc4b672a" 2128 | }, 2129 | "outputs": [ 2130 | { 2131 | "output_type": "execute_result", 2132 | "data": { 2133 | "text/plain": [ 2134 | "tensor([0., 1., 0.])" 2135 | ] 2136 | }, 2137 | "metadata": {}, 2138 | "execution_count": 87 2139 | } 2140 | ], 2141 | "source": [ 2142 | "dist.sample()  # each element is 1 with probability p and 0 with probability 1 - p" 2143 | ] 2144 | }, 2145 | { 2146 | "cell_type": "code", 2147 | "execution_count": null, 2148 | "metadata": { 2149 | "id": "l89yVuwrlOxb" 2150 | }, 2151 | "outputs": [], 2152 | "source": [ 2153 | "# Creates a Bernoulli distribution parameterized by probs.\n", 2154 | "# Samples are binary (0 or 1): each takes the value 1 with probability p\n", 2155 | "# and 0 with probability 1 - p."
2156 | ] 2157 | }, 2158 | { 2159 | "cell_type": "code", 2160 | "execution_count": null, 2161 | "metadata": { 2162 | "id": "JI3ktUbDlOxb" 2163 | }, 2164 | "outputs": [], 2165 | "source": [ 2166 | "from torch.distributions.beta import Beta" 2167 | ] 2168 | }, 2169 | { 2170 | "cell_type": "code", 2171 | "execution_count": null, 2172 | "metadata": { 2173 | "colab": { 2174 | "base_uri": "https://localhost:8080/" 2175 | }, 2176 | "id": "b94WrCrWlOxc", 2177 | "outputId": "03f86a9b-e467-46ff-abec-e2b49ae05409" 2178 | }, 2179 | "outputs": [ 2180 | { 2181 | "output_type": "execute_result", 2182 | "data": { 2183 | "text/plain": [ 2184 | "Beta()" 2185 | ] 2186 | }, 2187 | "metadata": {}, 2188 | "execution_count": 90 2189 | } 2190 | ], 2191 | "source": [ 2192 | "dist = Beta(torch.tensor([0.5]), torch.tensor([0.5]))\n", 2193 | "dist" 2194 | ] 2195 | }, 2196 | { 2197 | "cell_type": "code", 2198 | "execution_count": null, 2199 | "metadata": { 2200 | "colab": { 2201 | "base_uri": "https://localhost:8080/" 2202 | }, 2203 | "id": "-dGCXUAMlOxc", 2204 | "outputId": "077f0007-2aee-4ad8-fb9e-7a724b0f8184" 2205 | }, 2206 | "outputs": [ 2207 | { 2208 | "output_type": "execute_result", 2209 | "data": { 2210 | "text/plain": [ 2211 | "tensor([0.6771])" 2212 | ] 2213 | }, 2214 | "metadata": {}, 2215 | "execution_count": 91 2216 | } 2217 | ], 2218 | "source": [ 2219 | "dist.sample()" 2220 | ] 2221 | }, 2222 | { 2223 | "cell_type": "code", 2224 | "execution_count": null, 2225 | "metadata": { 2226 | "id": "5GGiU-wwlOxc" 2227 | }, 2228 | "outputs": [], 2229 | "source": [ 2230 | "from torch.distributions.binomial import Binomial" 2231 | ] 2232 | }, 2233 | { 2234 | "cell_type": "code", 2235 | "execution_count": null, 2236 | "metadata": { 2237 | "id": "sf0ndW6tlOxc" 2238 | }, 2239 | "outputs": [], 2240 | "source": [ 2241 | "dist = Binomial(100, torch.tensor([0 , .2, .8, 1]))" 2242 | ] 2243 | }, 2244 | { 2245 | "cell_type": "code", 2246 | "execution_count": null, 2247 | "metadata": { 2248 | "colab": 
{ 2249 | "base_uri": "https://localhost:8080/" 2250 | }, 2251 | "id": "8EB23HpXlOxc", 2252 | "outputId": "623ebece-db7b-433e-ee0c-7cb0e0230c85" 2253 | }, 2254 | "outputs": [ 2255 | { 2256 | "output_type": "execute_result", 2257 | "data": { 2258 | "text/plain": [ 2259 | "tensor([ 0., 21., 83., 100.])" 2260 | ] 2261 | }, 2262 | "metadata": {}, 2263 | "execution_count": 94 2264 | } 2265 | ], 2266 | "source": [ 2267 | "dist.sample()" 2268 | ] 2269 | }, 2270 | { 2271 | "cell_type": "code", 2272 | "execution_count": null, 2273 | "metadata": { 2274 | "id": "W2jx1FfvlOxc" 2275 | }, 2276 | "outputs": [], 2277 | "source": [ 2278 | "# 100 is the number of trials (total_count)\n", 2279 | "# 0, 0.2, 0.8 and 1 are the per-trial success probabilities" 2280 | ] 2281 | }, 2282 | { 2283 | "cell_type": "code", 2284 | "execution_count": null, 2285 | "metadata": { 2286 | "id": "reuCxuaulOxd" 2287 | }, 2288 | "outputs": [], 2289 | "source": [ 2290 | "from torch.distributions.categorical import Categorical" 2291 | ] 2292 | }, 2293 | { 2294 | "cell_type": "code", 2295 | "execution_count": null, 2296 | "metadata": { 2297 | "colab": { 2298 | "base_uri": "https://localhost:8080/" 2299 | }, 2300 | "id": "fzFDxbdylOxd", 2301 | "outputId": "fa6d937d-023d-4cdd-a5e0-d82b416b94fa" 2302 | }, 2303 | "outputs": [ 2304 | { 2305 | "output_type": "execute_result", 2306 | "data": { 2307 | "text/plain": [ 2308 | "Categorical(probs: torch.Size([5]))" 2309 | ] 2310 | }, 2311 | "metadata": {}, 2312 | "execution_count": 97 2313 | } 2314 | ], 2315 | "source": [ 2316 | "dist = Categorical(torch.tensor([ 0.20, 0.20, 0.20, 0.20, 0.20 ]))\n", 2317 | "dist" 2318 | ] 2319 | }, 2320 | { 2321 | "cell_type": "code", 2322 | "execution_count": null, 2323 | "metadata": { 2324 | "colab": { 2325 | "base_uri": "https://localhost:8080/" 2326 | }, 2327 | "id": "MJhxIJwElOxd", 2328 | "outputId": "0ad16ed9-d260-4034-a38e-a6a2b0b36ba7" 2329 | }, 2330 | "outputs": [ 2331 | { 2332 | "output_type": "execute_result", 2333 | "data": { 2334 | "text/plain": [ 2335 |
"tensor(2)" 2336 | ] 2337 | }, 2338 | "metadata": {}, 2339 | "execution_count": 98 2340 | } 2341 | ], 2342 | "source": [ 2343 | "dist.sample()  # index of the sampled category (0 to 4)" 2344 | ] 2345 | }, 2346 | { 2347 | "cell_type": "code", 2348 | "execution_count": null, 2349 | "metadata": { 2350 | "id": "eSE7XhculOxd" 2351 | }, 2352 | "outputs": [], 2353 | "source": [ 2354 | "# 0.20, 0.20, 0.20, 0.20, 0.20 are the event probabilities of the five categories" 2355 | ] 2356 | }, 2357 | { 2358 | "cell_type": "code", 2359 | "execution_count": null, 2360 | "metadata": { 2361 | "id": "KVOXshFBlOxd" 2362 | }, 2363 | "outputs": [], 2364 | "source": [ 2365 | "# Laplace distribution parameterized by loc and scale." 2366 | ] 2367 | }, 2368 | { 2369 | "cell_type": "code", 2370 | "execution_count": null, 2371 | "metadata": { 2372 | "id": "yMv-qmfJlOxl" 2373 | }, 2374 | "outputs": [], 2375 | "source": [ 2376 | "from torch.distributions.laplace import Laplace" 2377 | ] 2378 | }, 2379 | { 2380 | "cell_type": "code", 2381 | "execution_count": null, 2382 | "metadata": { 2383 | "colab": { 2384 | "base_uri": "https://localhost:8080/" 2385 | }, 2386 | "id": "IvxaNdI-lOxl", 2387 | "outputId": "e14d7961-5b09-4d06-f6a1-67af1dc69465" 2388 | }, 2389 | "outputs": [ 2390 | { 2391 | "output_type": "execute_result", 2392 | "data": { 2393 | "text/plain": [ 2394 | "Laplace(loc: tensor([10.]), scale: tensor([0.9900]))" 2395 | ] 2396 | }, 2397 | "metadata": {}, 2398 | "execution_count": 102 2399 | } 2400 | ], 2401 | "source": [ 2402 | "dist = Laplace(torch.tensor([10.0]), torch.tensor([0.990]))\n", 2403 | "dist" 2404 | ] 2405 | }, 2406 | { 2407 | "cell_type": "code", 2408 | "execution_count": null, 2409 | "metadata": { 2410 | "colab": { 2411 | "base_uri": "https://localhost:8080/" 2412 | }, 2413 | "id": "y7aFsTJglOxl", 2414 | "outputId": "df9a0c7b-91dc-4ef0-a98d-4f17f1109f41" 2415 | }, 2416 | "outputs": [ 2417 | { 2418 | "output_type": "execute_result", 2419 | "data": { 2420 | "text/plain": [ 2421 | "tensor([9.6554])" 2422 | ] 2423 | }, 2424 | "metadata": {}, 2425 |
"execution_count": 103 2426 | } 2427 | ], 2428 | "source": [ 2429 | "dist.sample()" 2430 | ] 2431 | }, 2432 | { 2433 | "cell_type": "code", 2434 | "execution_count": null, 2435 | "metadata": { 2436 | "id": "4iIJEjrMlOxm" 2437 | }, 2438 | "outputs": [], 2439 | "source": [ 2440 | "# Normal (Gaussian) distribution parameterized by loc (mean) and scale (standard deviation)." 2441 | ] 2442 | }, 2443 | { 2444 | "cell_type": "code", 2445 | "execution_count": null, 2446 | "metadata": { 2447 | "id": "t9nWn9PQlOxm" 2448 | }, 2449 | "outputs": [], 2450 | "source": [ 2451 | "from torch.distributions.normal import Normal" 2452 | ] 2453 | }, 2454 | { 2455 | "cell_type": "code", 2456 | "execution_count": null, 2457 | "metadata": { 2458 | "colab": { 2459 | "base_uri": "https://localhost:8080/" 2460 | }, 2461 | "id": "wXcIsBNilOxm", 2462 | "outputId": "cbda0cf5-2973-4689-86bf-0e5a34d0e080" 2463 | }, 2464 | "outputs": [ 2465 | { 2466 | "output_type": "execute_result", 2467 | "data": { 2468 | "text/plain": [ 2469 | "Normal(loc: tensor([100.]), scale: tensor([10.]))" 2470 | ] 2471 | }, 2472 | "metadata": {}, 2473 | "execution_count": 106 2474 | } 2475 | ], 2476 | "source": [ 2477 | "dist = Normal(torch.tensor([100.0]), torch.tensor([10.0]))\n", 2478 | "dist" 2479 | ] 2480 | }, 2481 | { 2482 | "cell_type": "code", 2483 | "execution_count": null, 2484 | "metadata": { 2485 | "colab": { 2486 | "base_uri": "https://localhost:8080/" 2487 | }, 2488 | "id": "SMn_25hslOxm", 2489 | "outputId": "a356ca71-2da2-4d80-9839-73c29b97969f" 2490 | }, 2491 | "outputs": [ 2492 | { 2493 | "output_type": "execute_result", 2494 | "data": { 2495 | "text/plain": [ 2496 | "tensor([84.3435])" 2497 | ] 2498 | }, 2499 | "metadata": {}, 2500 | "execution_count": 107 2501 | } 2502 | ], 2503 | "source": [ 2504 | "dist.sample()" 2505 | ] 2506 | }, 2507 | { 2508 | "cell_type": "code", 2509 | "execution_count": null, 2510 | "metadata": { 2511 | "id": "QTmYev-XlOxm" 2512 | }, 2513 | "outputs": [], 2514 | "source": [ 2515 | "" 2516 | ] 2517 | }
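,
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (not from the book): the broadcasting behind the scalar, vector and matrix additions above,\n",
    "# and why `*` is not matrix multiplication.\n",
    "# Assumption: torch.rand(...) stands in for FloatTensor(...).uniform_(0,1); both draw uniformly from [0, 1).\n",
    "# The names m and v and the seed 0 are illustrative choices, so mat1, mat2 and vec1 are left untouched.\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "m = torch.rand(4, 4)  # same role as mat1\n",
    "v = torch.rand(4)     # same role as vec1\n",
    "\n",
    "# scalar addition: the scalar is added to every element\n",
    "assert torch.allclose(m + 10.5, m + torch.full((4, 4), 10.5))\n",
    "\n",
    "# matrix + vector: v (shape [4]) is broadcast to shape [4, 4], i.e. added to every row of m\n",
    "assert torch.allclose(m + v, m + v.unsqueeze(0).expand(4, 4))\n",
    "\n",
    "# `*` is the element-wise (Hadamard) product; `@` is the matrix product\n",
    "assert torch.allclose(m * m, m.pow(2))\n",
    "print((m * m).shape, (m @ m).shape)"
   ]
  }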
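,
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (not from the book): passing a sample_shape to sample() draws many values at once,\n",
    "# which makes it easy to compare empirical averages with the parameters of the discrete distributions above.\n",
    "# n = 10000 and the seed are arbitrary illustrative choices; empirical values only approximate the targets.\n",
    "import torch\n",
    "from torch.distributions.bernoulli import Bernoulli\n",
    "from torch.distributions.binomial import Binomial\n",
    "from torch.distributions.categorical import Categorical\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n = 10000\n",
    "\n",
    "bern = Bernoulli(torch.tensor([0.3, 0.6, 0.9]))\n",
    "b = bern.sample((n,))  # shape [n, 3]; column j uses probability p_j\n",
    "print(b.shape, b.mean(dim=0))  # close to 0.3, 0.6, 0.9\n",
    "\n",
    "binom = Binomial(100, torch.tensor([0, .2, .8, 1]))\n",
    "k = binom.sample((n,))  # success counts out of 100 trials, shape [n, 4]\n",
    "print(k.mean(dim=0), binom.mean)  # binom.mean is total_count * probs\n",
    "\n",
    "cat = Categorical(torch.tensor([0.20, 0.20, 0.20, 0.20, 0.20]))\n",
    "c = cat.sample((n,))  # integer category indices 0..4, shape [n]\n",
    "print(torch.bincount(c, minlength=5) / n)  # each frequency close to 0.20"
   ]
  }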
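,
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (not from the book): besides sample(), distribution objects expose moments and log densities.\n",
    "# The parameters mirror the Laplace and Normal cells above; the seed and the 5-sample batch are illustrative.\n",
    "import math\n",
    "import torch\n",
    "from torch.distributions.laplace import Laplace\n",
    "from torch.distributions.normal import Normal\n",
    "\n",
    "torch.manual_seed(0)\n",
    "lap = Laplace(torch.tensor([10.0]), torch.tensor([0.990]))\n",
    "norm = Normal(torch.tensor([100.0]), torch.tensor([10.0]))\n",
    "\n",
    "print(lap.mean, lap.stddev)    # loc, and sqrt(2) * scale\n",
    "print(norm.mean, norm.stddev)  # loc, and scale\n",
    "\n",
    "x = norm.sample((5,))  # shape [5, 1]: 5 draws from a 1-element distribution\n",
    "print(x.shape, norm.log_prob(x).shape)\n",
    "\n",
    "# at x = loc the Normal log density is -log(scale * sqrt(2 * pi))\n",
    "expected = -math.log(10.0 * math.sqrt(2 * math.pi))\n",
    "assert torch.allclose(norm.log_prob(torch.tensor([100.0])), torch.tensor([expected]))"
   ]
  }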
2518 | ], 2519 | "metadata": { 2520 | "kernelspec": { 2521 | "display_name": "Python 3", 2522 | "language": "python", 2523 | "name": "python3" 2524 | }, 2525 | "language_info": { 2526 | "codemirror_mode": { 2527 | "name": "ipython", 2528 | "version": 3 2529 | }, 2530 | "file_extension": ".py", 2531 | "mimetype": "text/x-python", 2532 | "name": "python", 2533 | "nbconvert_exporter": "python", 2534 | "pygments_lexer": "ipython3", 2535 | "version": "3.6.4" 2536 | }, 2537 | "colab": { 2538 | "name": "Torch_AI_2_2Ed.ipynb", 2539 | "provenance": [], 2540 | "collapsed_sections": [] 2541 | } 2542 | }, 2543 | "nbformat": 4, 2544 | "nbformat_minor": 0 2545 | } --------------------------------------------------------------------------------