├── 0_Deep_Learning's_Hello_World.ipynb ├── 10_RNN_Network.ipynb ├── 11_Text_Classification.ipynb ├── 12_Machine_Translation_From_Scratch_Part_1.ipynb ├── 13_Machine_Translation_From_Scratch_Part_2.ipynb ├── 14_Vanilla_GAN.ipynb ├── 15_DCGAN.ipynb ├── 16_Conditional_DCGAN.ipynb ├── 17_Pix2Pix_GAN.ipynb ├── 18_Cycle_GAN.ipynb ├── 19_Arbitrary_Style_Transfer_(AdaIN).ipynb ├── 1_Fashion_MNIST_and_CIFAR10.ipynb ├── 20_VAE.ipynb ├── 21_Diffusion_Model.ipynb ├── 22_Open_Source_NLU_models.ipynb ├── 23_Value_functions_and_policy_iteration.ipynb ├── 24_Double_Deep_Q_Learning_1_gym_intro.ipynb ├── 25_Double_Deep_Q_Learning_2.ipynb ├── 2_horse_or_human_workshop.ipynb ├── 3_VGG16_keras_applicarions.ipynb ├── 4_Residual_Networks.ipynb ├── 5_Pytorch_Introdution.ipynb ├── 6_CIFAR10_pytorch.ipynb ├── 7_Neural_Style_Transfer_.ipynb ├── 8_Siamese_Networks.ipynb ├── 9_Unet_and_Segmentation.ipynb └── README.md /13_Machine_Translation_From_Scratch_Part_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyMMzEZjTDTJxfI3u0MIefDO", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 35, 32 | "metadata": { 33 | "id": "cezeGfesoylg" 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "import torch\n", 38 | "import torch.nn as nn\n", 39 | "import torch.nn.functional as F\n", 40 | "import torch.optim as optim" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "source": [ 46 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" 47 | ], 48 | "metadata": { 
49 | "id": "RCpeVTNlvgxg" 50 | }, 51 | "execution_count": 36, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "source": [ 57 | "class Encoder(nn.Module):\n", 58 | " def __init__(self, num_tokens, embedding_dim, latent_dim):\n", 59 | " super().__init__()\n", 60 | "\n", 61 | " self.embedding = nn.Embedding(num_embeddings=num_tokens, embedding_dim=embedding_dim)\n", 62 | " self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=latent_dim, num_layers=1, batch_first=True, bidirectional=True)\n", 63 | "\n", 64 | " self.latent_dim = latent_dim\n", 65 | "\n", 66 | " def forward(self, x):\n", 67 | " x = self.embedding(x)\n", 68 | " batch_size, _, _ = x.size()\n", 69 | " h_0 = torch.zeros(2, batch_size, self.latent_dim).to(DEVICE)\n", 70 | " outputs, context_vector = self.rnn(x, h_0)\n", 71 | "\n", 72 | " return context_vector, outputs" 73 | ], 74 | "metadata": { 75 | "id": "l6drzN3TvYEP" 76 | }, 77 | "execution_count": 37, 78 | "outputs": [] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "source": [ 83 | "encoder = Encoder(num_tokens=100, embedding_dim=16, latent_dim=64).to(DEVICE)" 84 | ], 85 | "metadata": { 86 | "id": "uEfD71WEvcs0" 87 | }, 88 | "execution_count": 38, 89 | "outputs": [] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "source": [ 94 | "batch_size = 10\n", 95 | "seq_length = 30\n", 96 | "\n", 97 | "test_input = torch.zeros(batch_size, seq_length, dtype=torch.int64)\n", 98 | "test_output, _ = encoder(test_input)\n", 99 | "\n", 100 | "test_output.shape" 101 | ], 102 | "metadata": { 103 | "colab": { 104 | "base_uri": "https://localhost:8080/" 105 | }, 106 | "id": "nKKSzpd_voel", 107 | "outputId": "ed6adfc9-9c52-4b02-a511-8fff8f3988de" 108 | }, 109 | "execution_count": 39, 110 | "outputs": [ 111 | { 112 | "output_type": "execute_result", 113 | "data": { 114 | "text/plain": [ 115 | "torch.Size([2, 10, 64])" 116 | ] 117 | }, 118 | "metadata": {}, 119 | "execution_count": 39 120 | } 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "source": 
[ 126 | "class AttentionBlock(nn.Module):\n", 127 | " def __init__(self, hidden_dim):\n", 128 | " super().__init__()\n", 129 | " self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)\n", 130 | " self.U = nn.Linear(hidden_dim, hidden_dim, bias=False)\n", 131 | " self.V = nn.Linear(hidden_dim, 1, bias=False)\n", 132 | "\n", 133 | " def forward(self, query, keys):\n", 134 | " QK = self.W(query).unsqueeze(1) + self.U(keys)\n", 135 | " QK = torch.tanh(QK)\n", 136 | " scores = self.V(QK)\n", 137 | "\n", 138 | " scores = scores.squeeze(2).unsqueeze(1)\n", 139 | " weigths = F.softmax(scores, dim=-1)\n", 140 | "\n", 141 | " context = torch.bmm(weigths, keys)\n", 142 | "\n", 143 | " return context, weigths" 144 | ], 145 | "metadata": { 146 | "id": "QoAL81lzwBtT" 147 | }, 148 | "execution_count": 40, 149 | "outputs": [] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "source": [ 154 | "class AttentionGRU(nn.Module):\n", 155 | " def __init__(self, input_size, latent_dim):\n", 156 | " super().__init__()\n", 157 | " self.attention = AttentionBlock(2 * latent_dim)\n", 158 | " self.rnn = nn.GRU(input_size=input_size, hidden_size=2 * latent_dim)\n", 159 | " self.latent_dim = latent_dim\n", 160 | "\n", 161 | " def forward(self, predicted_label, encoder_outputs):\n", 162 | " batch_size, _, _ = predicted_label.size()\n", 163 | " h = torch.zeros(batch_size, 2 * self.latent_dim)\n", 164 | " predicted_label = predicted_label.permute(1, 0, 2)\n", 165 | " for token in predicted_label:\n", 166 | " context, weights = self.attention(h, encoder_outputs)\n", 167 | " context = context.permute(1, 0, 2)\n", 168 | " token = token.unsqueeze(1).permute(1, 0, 2)\n", 169 | " output, h = self.rnn(token, context)\n", 170 | " h = h.squeeze()\n", 171 | "\n", 172 | " return output, h" 173 | ], 174 | "metadata": { 175 | "id": "FHt-pZmz40N1" 176 | }, 177 | "execution_count": 41, 178 | "outputs": [] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "source": [ 183 | "class Decoder(nn.Module):\n", 184 | " 
def __init__(self, num_tokens, embedding_dim, latent_dim):\n", 185 | " super().__init__()\n", 186 | "\n", 187 | " self.embedding = nn.Embedding(num_embeddings=num_tokens, embedding_dim=embedding_dim)\n", 188 | " self.rnn = AttentionGRU(embedding_dim, latent_dim)\n", 189 | " self.fc = nn.Linear(in_features=2 * latent_dim, out_features=num_tokens)\n", 190 | " self.softmax = nn.LogSoftmax(dim=1)\n", 191 | "\n", 192 | " def forward(self, encoder_outputs, predicted_label):\n", 193 | " x = self.embedding(predicted_label)\n", 194 | " x, _ = self.rnn(x, encoder_outputs)\n", 195 | " x = self.fc(x)\n", 196 | " x = self.softmax(x)\n", 197 | "\n", 198 | " return x" 199 | ], 200 | "metadata": { 201 | "id": "nfrRgkJF62bE" 202 | }, 203 | "execution_count": 42, 204 | "outputs": [] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "source": [ 209 | "encoder = Encoder(num_tokens=100, embedding_dim=8, latent_dim=16)\n", 210 | "decoder = Decoder(num_tokens=100, embedding_dim=8, latent_dim=16)" 211 | ], 212 | "metadata": { 213 | "id": "lrVicVYv7LBr" 214 | }, 215 | "execution_count": 43, 216 | "outputs": [] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "source": [ 221 | "batch_size = 50\n", 222 | "seq_length = 20\n", 223 | "predicted_labels_count = 10\n", 224 | "test_input = torch.zeros(batch_size, seq_length, dtype=torch.int64)\n", 225 | "predicted_labels = torch.zeros(batch_size, predicted_labels_count, dtype=torch.int64)\n", 226 | "\n", 227 | "_, encoder_output = encoder(test_input)\n", 228 | "new_token = decoder(encoder_output, predicted_labels)\n", 229 | "new_token.size()" 230 | ], 231 | "metadata": { 232 | "colab": { 233 | "base_uri": "https://localhost:8080/" 234 | }, 235 | "id": "L-glyCsm7T29", 236 | "outputId": "4e784ab3-7cef-4924-ecb7-3d0ee3b7b8e5" 237 | }, 238 | "execution_count": 44, 239 | "outputs": [ 240 | { 241 | "output_type": "execute_result", 242 | "data": { 243 | "text/plain": [ 244 | "torch.Size([1, 50, 100])" 245 | ] 246 | }, 247 | "metadata": {}, 248 | 
"execution_count": 44 249 | } 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "source": [], 255 | "metadata": { 256 | "id": "ncvEzVsD79Zh" 257 | }, 258 | "execution_count": 44, 259 | "outputs": [] 260 | } 261 | ] 262 | } -------------------------------------------------------------------------------- /21_Diffusion_Model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4", 8 | "authorship_tag": "ABX9TyPp+28MPcewUc/8K+ah+M4u", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU", 19 | "widgets": { 20 | "application/vnd.jupyter.widget-state+json": { 21 | "b0fa66fc56e34465bb2a59f9ce916b71": { 22 | "model_module": "@jupyter-widgets/controls", 23 | "model_name": "HBoxModel", 24 | "model_module_version": "1.5.0", 25 | "state": { 26 | "_dom_classes": [], 27 | "_model_module": "@jupyter-widgets/controls", 28 | "_model_module_version": "1.5.0", 29 | "_model_name": "HBoxModel", 30 | "_view_count": null, 31 | "_view_module": "@jupyter-widgets/controls", 32 | "_view_module_version": "1.5.0", 33 | "_view_name": "HBoxView", 34 | "box_style": "", 35 | "children": [ 36 | "IPY_MODEL_14019560a279458cb775634b162b2407", 37 | "IPY_MODEL_4db4b93e47bd4c859b1d6685ed8cf76e", 38 | "IPY_MODEL_2ef34af8de2c4c7eab85fd3c83a5e1c5" 39 | ], 40 | "layout": "IPY_MODEL_1638fe599f2c409c8d75396bbad174b7" 41 | } 42 | }, 43 | "14019560a279458cb775634b162b2407": { 44 | "model_module": "@jupyter-widgets/controls", 45 | "model_name": "HTMLModel", 46 | "model_module_version": "1.5.0", 47 | "state": { 48 | "_dom_classes": [], 49 | "_model_module": "@jupyter-widgets/controls", 50 | "_model_module_version": "1.5.0", 51 | "_model_name": "HTMLModel", 52 | "_view_count": null, 53 | "_view_module": 
"@jupyter-widgets/controls", 54 | "_view_module_version": "1.5.0", 55 | "_view_name": "HTMLView", 56 | "description": "", 57 | "description_tooltip": null, 58 | "layout": "IPY_MODEL_0a2010f11ba345a385997b328168127a", 59 | "placeholder": "​", 60 | "style": "IPY_MODEL_5a39152a73cb4a87bd5bdf6f5a6888cc", 61 | "value": "100%" 62 | } 63 | }, 64 | "4db4b93e47bd4c859b1d6685ed8cf76e": { 65 | "model_module": "@jupyter-widgets/controls", 66 | "model_name": "FloatProgressModel", 67 | "model_module_version": "1.5.0", 68 | "state": { 69 | "_dom_classes": [], 70 | "_model_module": "@jupyter-widgets/controls", 71 | "_model_module_version": "1.5.0", 72 | "_model_name": "FloatProgressModel", 73 | "_view_count": null, 74 | "_view_module": "@jupyter-widgets/controls", 75 | "_view_module_version": "1.5.0", 76 | "_view_name": "ProgressView", 77 | "bar_style": "success", 78 | "description": "", 79 | "description_tooltip": null, 80 | "layout": "IPY_MODEL_b72fab00156243478d0619f3f8a56a1a", 81 | "max": 938, 82 | "min": 0, 83 | "orientation": "horizontal", 84 | "style": "IPY_MODEL_38548a20e7744f098a9118b82ff13f95", 85 | "value": 938 86 | } 87 | }, 88 | "2ef34af8de2c4c7eab85fd3c83a5e1c5": { 89 | "model_module": "@jupyter-widgets/controls", 90 | "model_name": "HTMLModel", 91 | "model_module_version": "1.5.0", 92 | "state": { 93 | "_dom_classes": [], 94 | "_model_module": "@jupyter-widgets/controls", 95 | "_model_module_version": "1.5.0", 96 | "_model_name": "HTMLModel", 97 | "_view_count": null, 98 | "_view_module": "@jupyter-widgets/controls", 99 | "_view_module_version": "1.5.0", 100 | "_view_name": "HTMLView", 101 | "description": "", 102 | "description_tooltip": null, 103 | "layout": "IPY_MODEL_47d16d89d79543ab8b893ccae490d2f0", 104 | "placeholder": "​", 105 | "style": "IPY_MODEL_05f22484b66641818d3288792ead2efa", 106 | "value": " 938/938 [03:40<00:00,  4.89it/s]" 107 | } 108 | }, 109 | "1638fe599f2c409c8d75396bbad174b7": { 110 | "model_module": "@jupyter-widgets/base", 111 | "model_name": 
"LayoutModel", 112 | "model_module_version": "1.2.0", 113 | "state": { 114 | "_model_module": "@jupyter-widgets/base", 115 | "_model_module_version": "1.2.0", 116 | "_model_name": "LayoutModel", 117 | "_view_count": null, 118 | "_view_module": "@jupyter-widgets/base", 119 | "_view_module_version": "1.2.0", 120 | "_view_name": "LayoutView", 121 | "align_content": null, 122 | "align_items": null, 123 | "align_self": null, 124 | "border": null, 125 | "bottom": null, 126 | "display": null, 127 | "flex": null, 128 | "flex_flow": null, 129 | "grid_area": null, 130 | "grid_auto_columns": null, 131 | "grid_auto_flow": null, 132 | "grid_auto_rows": null, 133 | "grid_column": null, 134 | "grid_gap": null, 135 | "grid_row": null, 136 | "grid_template_areas": null, 137 | "grid_template_columns": null, 138 | "grid_template_rows": null, 139 | "height": null, 140 | "justify_content": null, 141 | "justify_items": null, 142 | "left": null, 143 | "margin": null, 144 | "max_height": null, 145 | "max_width": null, 146 | "min_height": null, 147 | "min_width": null, 148 | "object_fit": null, 149 | "object_position": null, 150 | "order": null, 151 | "overflow": null, 152 | "overflow_x": null, 153 | "overflow_y": null, 154 | "padding": null, 155 | "right": null, 156 | "top": null, 157 | "visibility": null, 158 | "width": null 159 | } 160 | }, 161 | "0a2010f11ba345a385997b328168127a": { 162 | "model_module": "@jupyter-widgets/base", 163 | "model_name": "LayoutModel", 164 | "model_module_version": "1.2.0", 165 | "state": { 166 | "_model_module": "@jupyter-widgets/base", 167 | "_model_module_version": "1.2.0", 168 | "_model_name": "LayoutModel", 169 | "_view_count": null, 170 | "_view_module": "@jupyter-widgets/base", 171 | "_view_module_version": "1.2.0", 172 | "_view_name": "LayoutView", 173 | "align_content": null, 174 | "align_items": null, 175 | "align_self": null, 176 | "border": null, 177 | "bottom": null, 178 | "display": null, 179 | "flex": null, 180 | "flex_flow": null, 181 | 
"grid_area": null, 182 | "grid_auto_columns": null, 183 | "grid_auto_flow": null, 184 | "grid_auto_rows": null, 185 | "grid_column": null, 186 | "grid_gap": null, 187 | "grid_row": null, 188 | "grid_template_areas": null, 189 | "grid_template_columns": null, 190 | "grid_template_rows": null, 191 | "height": null, 192 | "justify_content": null, 193 | "justify_items": null, 194 | "left": null, 195 | "margin": null, 196 | "max_height": null, 197 | "max_width": null, 198 | "min_height": null, 199 | "min_width": null, 200 | "object_fit": null, 201 | "object_position": null, 202 | "order": null, 203 | "overflow": null, 204 | "overflow_x": null, 205 | "overflow_y": null, 206 | "padding": null, 207 | "right": null, 208 | "top": null, 209 | "visibility": null, 210 | "width": null 211 | } 212 | }, 213 | "5a39152a73cb4a87bd5bdf6f5a6888cc": { 214 | "model_module": "@jupyter-widgets/controls", 215 | "model_name": "DescriptionStyleModel", 216 | "model_module_version": "1.5.0", 217 | "state": { 218 | "_model_module": "@jupyter-widgets/controls", 219 | "_model_module_version": "1.5.0", 220 | "_model_name": "DescriptionStyleModel", 221 | "_view_count": null, 222 | "_view_module": "@jupyter-widgets/base", 223 | "_view_module_version": "1.2.0", 224 | "_view_name": "StyleView", 225 | "description_width": "" 226 | } 227 | }, 228 | "b72fab00156243478d0619f3f8a56a1a": { 229 | "model_module": "@jupyter-widgets/base", 230 | "model_name": "LayoutModel", 231 | "model_module_version": "1.2.0", 232 | "state": { 233 | "_model_module": "@jupyter-widgets/base", 234 | "_model_module_version": "1.2.0", 235 | "_model_name": "LayoutModel", 236 | "_view_count": null, 237 | "_view_module": "@jupyter-widgets/base", 238 | "_view_module_version": "1.2.0", 239 | "_view_name": "LayoutView", 240 | "align_content": null, 241 | "align_items": null, 242 | "align_self": null, 243 | "border": null, 244 | "bottom": null, 245 | "display": null, 246 | "flex": null, 247 | "flex_flow": null, 248 | "grid_area": null, 
249 | "grid_auto_columns": null, 250 | "grid_auto_flow": null, 251 | "grid_auto_rows": null, 252 | "grid_column": null, 253 | "grid_gap": null, 254 | "grid_row": null, 255 | "grid_template_areas": null, 256 | "grid_template_columns": null, 257 | "grid_template_rows": null, 258 | "height": null, 259 | "justify_content": null, 260 | "justify_items": null, 261 | "left": null, 262 | "margin": null, 263 | "max_height": null, 264 | "max_width": null, 265 | "min_height": null, 266 | "min_width": null, 267 | "object_fit": null, 268 | "object_position": null, 269 | "order": null, 270 | "overflow": null, 271 | "overflow_x": null, 272 | "overflow_y": null, 273 | "padding": null, 274 | "right": null, 275 | "top": null, 276 | "visibility": null, 277 | "width": null 278 | } 279 | }, 280 | "38548a20e7744f098a9118b82ff13f95": { 281 | "model_module": "@jupyter-widgets/controls", 282 | "model_name": "ProgressStyleModel", 283 | "model_module_version": "1.5.0", 284 | "state": { 285 | "_model_module": "@jupyter-widgets/controls", 286 | "_model_module_version": "1.5.0", 287 | "_model_name": "ProgressStyleModel", 288 | "_view_count": null, 289 | "_view_module": "@jupyter-widgets/base", 290 | "_view_module_version": "1.2.0", 291 | "_view_name": "StyleView", 292 | "bar_color": null, 293 | "description_width": "" 294 | } 295 | }, 296 | "47d16d89d79543ab8b893ccae490d2f0": { 297 | "model_module": "@jupyter-widgets/base", 298 | "model_name": "LayoutModel", 299 | "model_module_version": "1.2.0", 300 | "state": { 301 | "_model_module": "@jupyter-widgets/base", 302 | "_model_module_version": "1.2.0", 303 | "_model_name": "LayoutModel", 304 | "_view_count": null, 305 | "_view_module": "@jupyter-widgets/base", 306 | "_view_module_version": "1.2.0", 307 | "_view_name": "LayoutView", 308 | "align_content": null, 309 | "align_items": null, 310 | "align_self": null, 311 | "border": null, 312 | "bottom": null, 313 | "display": null, 314 | "flex": null, 315 | "flex_flow": null, 316 | "grid_area": null, 
317 | "grid_auto_columns": null, 318 | "grid_auto_flow": null, 319 | "grid_auto_rows": null, 320 | "grid_column": null, 321 | "grid_gap": null, 322 | "grid_row": null, 323 | "grid_template_areas": null, 324 | "grid_template_columns": null, 325 | "grid_template_rows": null, 326 | "height": null, 327 | "justify_content": null, 328 | "justify_items": null, 329 | "left": null, 330 | "margin": null, 331 | "max_height": null, 332 | "max_width": null, 333 | "min_height": null, 334 | "min_width": null, 335 | "object_fit": null, 336 | "object_position": null, 337 | "order": null, 338 | "overflow": null, 339 | "overflow_x": null, 340 | "overflow_y": null, 341 | "padding": null, 342 | "right": null, 343 | "top": null, 344 | "visibility": null, 345 | "width": null 346 | } 347 | }, 348 | "05f22484b66641818d3288792ead2efa": { 349 | "model_module": "@jupyter-widgets/controls", 350 | "model_name": "DescriptionStyleModel", 351 | "model_module_version": "1.5.0", 352 | "state": { 353 | "_model_module": "@jupyter-widgets/controls", 354 | "_model_module_version": "1.5.0", 355 | "_model_name": "DescriptionStyleModel", 356 | "_view_count": null, 357 | "_view_module": "@jupyter-widgets/base", 358 | "_view_module_version": "1.2.0", 359 | "_view_name": "StyleView", 360 | "description_width": "" 361 | } 362 | }, 363 | "7c76076cccf04326a82ec69c5b33c9af": { 364 | "model_module": "@jupyter-widgets/controls", 365 | "model_name": "HBoxModel", 366 | "model_module_version": "1.5.0", 367 | "state": { 368 | "_dom_classes": [], 369 | "_model_module": "@jupyter-widgets/controls", 370 | "_model_module_version": "1.5.0", 371 | "_model_name": "HBoxModel", 372 | "_view_count": null, 373 | "_view_module": "@jupyter-widgets/controls", 374 | "_view_module_version": "1.5.0", 375 | "_view_name": "HBoxView", 376 | "box_style": "", 377 | "children": [ 378 | "IPY_MODEL_326f80a6a9364bc4a37fd727649c5599", 379 | "IPY_MODEL_cd209532893744078f13583bc2765ef5", 380 | "IPY_MODEL_9933cef2fd0c44c2a308f24e5de0b1ea" 381 | ], 
382 | "layout": "IPY_MODEL_9d07802c82da4382966ee2bc58bff4e2" 383 | } 384 | }, 385 | "326f80a6a9364bc4a37fd727649c5599": { 386 | "model_module": "@jupyter-widgets/controls", 387 | "model_name": "HTMLModel", 388 | "model_module_version": "1.5.0", 389 | "state": { 390 | "_dom_classes": [], 391 | "_model_module": "@jupyter-widgets/controls", 392 | "_model_module_version": "1.5.0", 393 | "_model_name": "HTMLModel", 394 | "_view_count": null, 395 | "_view_module": "@jupyter-widgets/controls", 396 | "_view_module_version": "1.5.0", 397 | "_view_name": "HTMLView", 398 | "description": "", 399 | "description_tooltip": null, 400 | "layout": "IPY_MODEL_7c62a32ae5944b7584b0026533d3bea9", 401 | "placeholder": "​", 402 | "style": "IPY_MODEL_7d6defe1998a4083bbf9f4fdb4e5d599", 403 | "value": "100%" 404 | } 405 | }, 406 | "cd209532893744078f13583bc2765ef5": { 407 | "model_module": "@jupyter-widgets/controls", 408 | "model_name": "FloatProgressModel", 409 | "model_module_version": "1.5.0", 410 | "state": { 411 | "_dom_classes": [], 412 | "_model_module": "@jupyter-widgets/controls", 413 | "_model_module_version": "1.5.0", 414 | "_model_name": "FloatProgressModel", 415 | "_view_count": null, 416 | "_view_module": "@jupyter-widgets/controls", 417 | "_view_module_version": "1.5.0", 418 | "_view_name": "ProgressView", 419 | "bar_style": "success", 420 | "description": "", 421 | "description_tooltip": null, 422 | "layout": "IPY_MODEL_5a2ad31325dc45d0b795eb88daf147df", 423 | "max": 938, 424 | "min": 0, 425 | "orientation": "horizontal", 426 | "style": "IPY_MODEL_c981017d7c39426dbaf566505c817bf4", 427 | "value": 938 428 | } 429 | }, 430 | "9933cef2fd0c44c2a308f24e5de0b1ea": { 431 | "model_module": "@jupyter-widgets/controls", 432 | "model_name": "HTMLModel", 433 | "model_module_version": "1.5.0", 434 | "state": { 435 | "_dom_classes": [], 436 | "_model_module": "@jupyter-widgets/controls", 437 | "_model_module_version": "1.5.0", 438 | "_model_name": "HTMLModel", 439 | "_view_count": null, 
440 | "_view_module": "@jupyter-widgets/controls", 441 | "_view_module_version": "1.5.0", 442 | "_view_name": "HTMLView", 443 | "description": "", 444 | "description_tooltip": null, 445 | "layout": "IPY_MODEL_2ca3d70b283a4eaba7c683d03a53da07", 446 | "placeholder": "​", 447 | "style": "IPY_MODEL_51922432c661481da1ebe8192367c4dd", 448 | "value": " 938/938 [03:39<00:00,  4.95it/s]" 449 | } 450 | }, 451 | "9d07802c82da4382966ee2bc58bff4e2": { 452 | "model_module": "@jupyter-widgets/base", 453 | "model_name": "LayoutModel", 454 | "model_module_version": "1.2.0", 455 | "state": { 456 | "_model_module": "@jupyter-widgets/base", 457 | "_model_module_version": "1.2.0", 458 | "_model_name": "LayoutModel", 459 | "_view_count": null, 460 | "_view_module": "@jupyter-widgets/base", 461 | "_view_module_version": "1.2.0", 462 | "_view_name": "LayoutView", 463 | "align_content": null, 464 | "align_items": null, 465 | "align_self": null, 466 | "border": null, 467 | "bottom": null, 468 | "display": null, 469 | "flex": null, 470 | "flex_flow": null, 471 | "grid_area": null, 472 | "grid_auto_columns": null, 473 | "grid_auto_flow": null, 474 | "grid_auto_rows": null, 475 | "grid_column": null, 476 | "grid_gap": null, 477 | "grid_row": null, 478 | "grid_template_areas": null, 479 | "grid_template_columns": null, 480 | "grid_template_rows": null, 481 | "height": null, 482 | "justify_content": null, 483 | "justify_items": null, 484 | "left": null, 485 | "margin": null, 486 | "max_height": null, 487 | "max_width": null, 488 | "min_height": null, 489 | "min_width": null, 490 | "object_fit": null, 491 | "object_position": null, 492 | "order": null, 493 | "overflow": null, 494 | "overflow_x": null, 495 | "overflow_y": null, 496 | "padding": null, 497 | "right": null, 498 | "top": null, 499 | "visibility": null, 500 | "width": null 501 | } 502 | }, 503 | "7c62a32ae5944b7584b0026533d3bea9": { 504 | "model_module": "@jupyter-widgets/base", 505 | "model_name": "LayoutModel", 506 | 
"model_module_version": "1.2.0", 507 | "state": { 508 | "_model_module": "@jupyter-widgets/base", 509 | "_model_module_version": "1.2.0", 510 | "_model_name": "LayoutModel", 511 | "_view_count": null, 512 | "_view_module": "@jupyter-widgets/base", 513 | "_view_module_version": "1.2.0", 514 | "_view_name": "LayoutView", 515 | "align_content": null, 516 | "align_items": null, 517 | "align_self": null, 518 | "border": null, 519 | "bottom": null, 520 | "display": null, 521 | "flex": null, 522 | "flex_flow": null, 523 | "grid_area": null, 524 | "grid_auto_columns": null, 525 | "grid_auto_flow": null, 526 | "grid_auto_rows": null, 527 | "grid_column": null, 528 | "grid_gap": null, 529 | "grid_row": null, 530 | "grid_template_areas": null, 531 | "grid_template_columns": null, 532 | "grid_template_rows": null, 533 | "height": null, 534 | "justify_content": null, 535 | "justify_items": null, 536 | "left": null, 537 | "margin": null, 538 | "max_height": null, 539 | "max_width": null, 540 | "min_height": null, 541 | "min_width": null, 542 | "object_fit": null, 543 | "object_position": null, 544 | "order": null, 545 | "overflow": null, 546 | "overflow_x": null, 547 | "overflow_y": null, 548 | "padding": null, 549 | "right": null, 550 | "top": null, 551 | "visibility": null, 552 | "width": null 553 | } 554 | }, 555 | "7d6defe1998a4083bbf9f4fdb4e5d599": { 556 | "model_module": "@jupyter-widgets/controls", 557 | "model_name": "DescriptionStyleModel", 558 | "model_module_version": "1.5.0", 559 | "state": { 560 | "_model_module": "@jupyter-widgets/controls", 561 | "_model_module_version": "1.5.0", 562 | "_model_name": "DescriptionStyleModel", 563 | "_view_count": null, 564 | "_view_module": "@jupyter-widgets/base", 565 | "_view_module_version": "1.2.0", 566 | "_view_name": "StyleView", 567 | "description_width": "" 568 | } 569 | }, 570 | "5a2ad31325dc45d0b795eb88daf147df": { 571 | "model_module": "@jupyter-widgets/base", 572 | "model_name": "LayoutModel", 573 | 
"model_module_version": "1.2.0", 574 | "state": { 575 | "_model_module": "@jupyter-widgets/base", 576 | "_model_module_version": "1.2.0", 577 | "_model_name": "LayoutModel", 578 | "_view_count": null, 579 | "_view_module": "@jupyter-widgets/base", 580 | "_view_module_version": "1.2.0", 581 | "_view_name": "LayoutView", 582 | "align_content": null, 583 | "align_items": null, 584 | "align_self": null, 585 | "border": null, 586 | "bottom": null, 587 | "display": null, 588 | "flex": null, 589 | "flex_flow": null, 590 | "grid_area": null, 591 | "grid_auto_columns": null, 592 | "grid_auto_flow": null, 593 | "grid_auto_rows": null, 594 | "grid_column": null, 595 | "grid_gap": null, 596 | "grid_row": null, 597 | "grid_template_areas": null, 598 | "grid_template_columns": null, 599 | "grid_template_rows": null, 600 | "height": null, 601 | "justify_content": null, 602 | "justify_items": null, 603 | "left": null, 604 | "margin": null, 605 | "max_height": null, 606 | "max_width": null, 607 | "min_height": null, 608 | "min_width": null, 609 | "object_fit": null, 610 | "object_position": null, 611 | "order": null, 612 | "overflow": null, 613 | "overflow_x": null, 614 | "overflow_y": null, 615 | "padding": null, 616 | "right": null, 617 | "top": null, 618 | "visibility": null, 619 | "width": null 620 | } 621 | }, 622 | "c981017d7c39426dbaf566505c817bf4": { 623 | "model_module": "@jupyter-widgets/controls", 624 | "model_name": "ProgressStyleModel", 625 | "model_module_version": "1.5.0", 626 | "state": { 627 | "_model_module": "@jupyter-widgets/controls", 628 | "_model_module_version": "1.5.0", 629 | "_model_name": "ProgressStyleModel", 630 | "_view_count": null, 631 | "_view_module": "@jupyter-widgets/base", 632 | "_view_module_version": "1.2.0", 633 | "_view_name": "StyleView", 634 | "bar_color": null, 635 | "description_width": "" 636 | } 637 | }, 638 | "2ca3d70b283a4eaba7c683d03a53da07": { 639 | "model_module": "@jupyter-widgets/base", 640 | "model_name": "LayoutModel", 641 | 
"model_module_version": "1.2.0", 642 | "state": { 643 | "_model_module": "@jupyter-widgets/base", 644 | "_model_module_version": "1.2.0", 645 | "_model_name": "LayoutModel", 646 | "_view_count": null, 647 | "_view_module": "@jupyter-widgets/base", 648 | "_view_module_version": "1.2.0", 649 | "_view_name": "LayoutView", 650 | "align_content": null, 651 | "align_items": null, 652 | "align_self": null, 653 | "border": null, 654 | "bottom": null, 655 | "display": null, 656 | "flex": null, 657 | "flex_flow": null, 658 | "grid_area": null, 659 | "grid_auto_columns": null, 660 | "grid_auto_flow": null, 661 | "grid_auto_rows": null, 662 | "grid_column": null, 663 | "grid_gap": null, 664 | "grid_row": null, 665 | "grid_template_areas": null, 666 | "grid_template_columns": null, 667 | "grid_template_rows": null, 668 | "height": null, 669 | "justify_content": null, 670 | "justify_items": null, 671 | "left": null, 672 | "margin": null, 673 | "max_height": null, 674 | "max_width": null, 675 | "min_height": null, 676 | "min_width": null, 677 | "object_fit": null, 678 | "object_position": null, 679 | "order": null, 680 | "overflow": null, 681 | "overflow_x": null, 682 | "overflow_y": null, 683 | "padding": null, 684 | "right": null, 685 | "top": null, 686 | "visibility": null, 687 | "width": null 688 | } 689 | }, 690 | "51922432c661481da1ebe8192367c4dd": { 691 | "model_module": "@jupyter-widgets/controls", 692 | "model_name": "DescriptionStyleModel", 693 | "model_module_version": "1.5.0", 694 | "state": { 695 | "_model_module": "@jupyter-widgets/controls", 696 | "_model_module_version": "1.5.0", 697 | "_model_name": "DescriptionStyleModel", 698 | "_view_count": null, 699 | "_view_module": "@jupyter-widgets/base", 700 | "_view_module_version": "1.2.0", 701 | "_view_name": "StyleView", 702 | "description_width": "" 703 | } 704 | }, 705 | "61d30891744b4dcda25e8eeb71b54aad": { 706 | "model_module": "@jupyter-widgets/controls", 707 | "model_name": "HBoxModel", 708 | 
"model_module_version": "1.5.0", 709 | "state": { 710 | "_dom_classes": [], 711 | "_model_module": "@jupyter-widgets/controls", 712 | "_model_module_version": "1.5.0", 713 | "_model_name": "HBoxModel", 714 | "_view_count": null, 715 | "_view_module": "@jupyter-widgets/controls", 716 | "_view_module_version": "1.5.0", 717 | "_view_name": "HBoxView", 718 | "box_style": "", 719 | "children": [ 720 | "IPY_MODEL_718e38085b3b4f5eab8c481913bad4fb", 721 | "IPY_MODEL_3346c18e5aae424a9705689c62854830", 722 | "IPY_MODEL_1a782ba0602d4f7ebb0659295f7497f0" 723 | ], 724 | "layout": "IPY_MODEL_80ad93a8096a41f6ba7035bb483ee461" 725 | } 726 | }, 727 | "718e38085b3b4f5eab8c481913bad4fb": { 728 | "model_module": "@jupyter-widgets/controls", 729 | "model_name": "HTMLModel", 730 | "model_module_version": "1.5.0", 731 | "state": { 732 | "_dom_classes": [], 733 | "_model_module": "@jupyter-widgets/controls", 734 | "_model_module_version": "1.5.0", 735 | "_model_name": "HTMLModel", 736 | "_view_count": null, 737 | "_view_module": "@jupyter-widgets/controls", 738 | "_view_module_version": "1.5.0", 739 | "_view_name": "HTMLView", 740 | "description": "", 741 | "description_tooltip": null, 742 | "layout": "IPY_MODEL_bf1b70fc39db4494aee42122b14a0c8a", 743 | "placeholder": "​", 744 | "style": "IPY_MODEL_9cfbaf78195a4409aa3c751fef6377e1", 745 | "value": "100%" 746 | } 747 | }, 748 | "3346c18e5aae424a9705689c62854830": { 749 | "model_module": "@jupyter-widgets/controls", 750 | "model_name": "FloatProgressModel", 751 | "model_module_version": "1.5.0", 752 | "state": { 753 | "_dom_classes": [], 754 | "_model_module": "@jupyter-widgets/controls", 755 | "_model_module_version": "1.5.0", 756 | "_model_name": "FloatProgressModel", 757 | "_view_count": null, 758 | "_view_module": "@jupyter-widgets/controls", 759 | "_view_module_version": "1.5.0", 760 | "_view_name": "ProgressView", 761 | "bar_style": "success", 762 | "description": "", 763 | "description_tooltip": null, 764 | "layout": 
"IPY_MODEL_c8379e3f85c74bb9889bf49229de82b3", 765 | "max": 938, 766 | "min": 0, 767 | "orientation": "horizontal", 768 | "style": "IPY_MODEL_75125bd2901045ffa9c9b7d3583ee94c", 769 | "value": 938 770 | } 771 | }, 772 | "1a782ba0602d4f7ebb0659295f7497f0": { 773 | "model_module": "@jupyter-widgets/controls", 774 | "model_name": "HTMLModel", 775 | "model_module_version": "1.5.0", 776 | "state": { 777 | "_dom_classes": [], 778 | "_model_module": "@jupyter-widgets/controls", 779 | "_model_module_version": "1.5.0", 780 | "_model_name": "HTMLModel", 781 | "_view_count": null, 782 | "_view_module": "@jupyter-widgets/controls", 783 | "_view_module_version": "1.5.0", 784 | "_view_name": "HTMLView", 785 | "description": "", 786 | "description_tooltip": null, 787 | "layout": "IPY_MODEL_78a5fa8234ac4ff182f882908461559a", 788 | "placeholder": "​", 789 | "style": "IPY_MODEL_28abc7c5f1914aa0a97fd7d280d75ee0", 790 | "value": " 938/938 [03:40<00:00,  4.95it/s]" 791 | } 792 | }, 793 | "80ad93a8096a41f6ba7035bb483ee461": { 794 | "model_module": "@jupyter-widgets/base", 795 | "model_name": "LayoutModel", 796 | "model_module_version": "1.2.0", 797 | "state": { 798 | "_model_module": "@jupyter-widgets/base", 799 | "_model_module_version": "1.2.0", 800 | "_model_name": "LayoutModel", 801 | "_view_count": null, 802 | "_view_module": "@jupyter-widgets/base", 803 | "_view_module_version": "1.2.0", 804 | "_view_name": "LayoutView", 805 | "align_content": null, 806 | "align_items": null, 807 | "align_self": null, 808 | "border": null, 809 | "bottom": null, 810 | "display": null, 811 | "flex": null, 812 | "flex_flow": null, 813 | "grid_area": null, 814 | "grid_auto_columns": null, 815 | "grid_auto_flow": null, 816 | "grid_auto_rows": null, 817 | "grid_column": null, 818 | "grid_gap": null, 819 | "grid_row": null, 820 | "grid_template_areas": null, 821 | "grid_template_columns": null, 822 | "grid_template_rows": null, 823 | "height": null, 824 | "justify_content": null, 825 | "justify_items": 
null, 826 | "left": null, 827 | "margin": null, 828 | "max_height": null, 829 | "max_width": null, 830 | "min_height": null, 831 | "min_width": null, 832 | "object_fit": null, 833 | "object_position": null, 834 | "order": null, 835 | "overflow": null, 836 | "overflow_x": null, 837 | "overflow_y": null, 838 | "padding": null, 839 | "right": null, 840 | "top": null, 841 | "visibility": null, 842 | "width": null 843 | } 844 | }, 845 | "bf1b70fc39db4494aee42122b14a0c8a": { 846 | "model_module": "@jupyter-widgets/base", 847 | "model_name": "LayoutModel", 848 | "model_module_version": "1.2.0", 849 | "state": { 850 | "_model_module": "@jupyter-widgets/base", 851 | "_model_module_version": "1.2.0", 852 | "_model_name": "LayoutModel", 853 | "_view_count": null, 854 | "_view_module": "@jupyter-widgets/base", 855 | "_view_module_version": "1.2.0", 856 | "_view_name": "LayoutView", 857 | "align_content": null, 858 | "align_items": null, 859 | "align_self": null, 860 | "border": null, 861 | "bottom": null, 862 | "display": null, 863 | "flex": null, 864 | "flex_flow": null, 865 | "grid_area": null, 866 | "grid_auto_columns": null, 867 | "grid_auto_flow": null, 868 | "grid_auto_rows": null, 869 | "grid_column": null, 870 | "grid_gap": null, 871 | "grid_row": null, 872 | "grid_template_areas": null, 873 | "grid_template_columns": null, 874 | "grid_template_rows": null, 875 | "height": null, 876 | "justify_content": null, 877 | "justify_items": null, 878 | "left": null, 879 | "margin": null, 880 | "max_height": null, 881 | "max_width": null, 882 | "min_height": null, 883 | "min_width": null, 884 | "object_fit": null, 885 | "object_position": null, 886 | "order": null, 887 | "overflow": null, 888 | "overflow_x": null, 889 | "overflow_y": null, 890 | "padding": null, 891 | "right": null, 892 | "top": null, 893 | "visibility": null, 894 | "width": null 895 | } 896 | }, 897 | "9cfbaf78195a4409aa3c751fef6377e1": { 898 | "model_module": "@jupyter-widgets/controls", 899 | "model_name": 
"DescriptionStyleModel", 900 | "model_module_version": "1.5.0", 901 | "state": { 902 | "_model_module": "@jupyter-widgets/controls", 903 | "_model_module_version": "1.5.0", 904 | "_model_name": "DescriptionStyleModel", 905 | "_view_count": null, 906 | "_view_module": "@jupyter-widgets/base", 907 | "_view_module_version": "1.2.0", 908 | "_view_name": "StyleView", 909 | "description_width": "" 910 | } 911 | }, 912 | "c8379e3f85c74bb9889bf49229de82b3": { 913 | "model_module": "@jupyter-widgets/base", 914 | "model_name": "LayoutModel", 915 | "model_module_version": "1.2.0", 916 | "state": { 917 | "_model_module": "@jupyter-widgets/base", 918 | "_model_module_version": "1.2.0", 919 | "_model_name": "LayoutModel", 920 | "_view_count": null, 921 | "_view_module": "@jupyter-widgets/base", 922 | "_view_module_version": "1.2.0", 923 | "_view_name": "LayoutView", 924 | "align_content": null, 925 | "align_items": null, 926 | "align_self": null, 927 | "border": null, 928 | "bottom": null, 929 | "display": null, 930 | "flex": null, 931 | "flex_flow": null, 932 | "grid_area": null, 933 | "grid_auto_columns": null, 934 | "grid_auto_flow": null, 935 | "grid_auto_rows": null, 936 | "grid_column": null, 937 | "grid_gap": null, 938 | "grid_row": null, 939 | "grid_template_areas": null, 940 | "grid_template_columns": null, 941 | "grid_template_rows": null, 942 | "height": null, 943 | "justify_content": null, 944 | "justify_items": null, 945 | "left": null, 946 | "margin": null, 947 | "max_height": null, 948 | "max_width": null, 949 | "min_height": null, 950 | "min_width": null, 951 | "object_fit": null, 952 | "object_position": null, 953 | "order": null, 954 | "overflow": null, 955 | "overflow_x": null, 956 | "overflow_y": null, 957 | "padding": null, 958 | "right": null, 959 | "top": null, 960 | "visibility": null, 961 | "width": null 962 | } 963 | }, 964 | "75125bd2901045ffa9c9b7d3583ee94c": { 965 | "model_module": "@jupyter-widgets/controls", 966 | "model_name": 
"ProgressStyleModel", 967 | "model_module_version": "1.5.0", 968 | "state": { 969 | "_model_module": "@jupyter-widgets/controls", 970 | "_model_module_version": "1.5.0", 971 | "_model_name": "ProgressStyleModel", 972 | "_view_count": null, 973 | "_view_module": "@jupyter-widgets/base", 974 | "_view_module_version": "1.2.0", 975 | "_view_name": "StyleView", 976 | "bar_color": null, 977 | "description_width": "" 978 | } 979 | }, 980 | "78a5fa8234ac4ff182f882908461559a": { 981 | "model_module": "@jupyter-widgets/base", 982 | "model_name": "LayoutModel", 983 | "model_module_version": "1.2.0", 984 | "state": { 985 | "_model_module": "@jupyter-widgets/base", 986 | "_model_module_version": "1.2.0", 987 | "_model_name": "LayoutModel", 988 | "_view_count": null, 989 | "_view_module": "@jupyter-widgets/base", 990 | "_view_module_version": "1.2.0", 991 | "_view_name": "LayoutView", 992 | "align_content": null, 993 | "align_items": null, 994 | "align_self": null, 995 | "border": null, 996 | "bottom": null, 997 | "display": null, 998 | "flex": null, 999 | "flex_flow": null, 1000 | "grid_area": null, 1001 | "grid_auto_columns": null, 1002 | "grid_auto_flow": null, 1003 | "grid_auto_rows": null, 1004 | "grid_column": null, 1005 | "grid_gap": null, 1006 | "grid_row": null, 1007 | "grid_template_areas": null, 1008 | "grid_template_columns": null, 1009 | "grid_template_rows": null, 1010 | "height": null, 1011 | "justify_content": null, 1012 | "justify_items": null, 1013 | "left": null, 1014 | "margin": null, 1015 | "max_height": null, 1016 | "max_width": null, 1017 | "min_height": null, 1018 | "min_width": null, 1019 | "object_fit": null, 1020 | "object_position": null, 1021 | "order": null, 1022 | "overflow": null, 1023 | "overflow_x": null, 1024 | "overflow_y": null, 1025 | "padding": null, 1026 | "right": null, 1027 | "top": null, 1028 | "visibility": null, 1029 | "width": null 1030 | } 1031 | }, 1032 | "28abc7c5f1914aa0a97fd7d280d75ee0": { 1033 | "model_module": 
"@jupyter-widgets/controls", 1034 | "model_name": "DescriptionStyleModel", 1035 | "model_module_version": "1.5.0", 1036 | "state": { 1037 | "_model_module": "@jupyter-widgets/controls", 1038 | "_model_module_version": "1.5.0", 1039 | "_model_name": "DescriptionStyleModel", 1040 | "_view_count": null, 1041 | "_view_module": "@jupyter-widgets/base", 1042 | "_view_module_version": "1.2.0", 1043 | "_view_name": "StyleView", 1044 | "description_width": "" 1045 | } 1046 | }, 1047 | "c9305060d9e144a095a0d93b63355e0b": { 1048 | "model_module": "@jupyter-widgets/controls", 1049 | "model_name": "HBoxModel", 1050 | "model_module_version": "1.5.0", 1051 | "state": { 1052 | "_dom_classes": [], 1053 | "_model_module": "@jupyter-widgets/controls", 1054 | "_model_module_version": "1.5.0", 1055 | "_model_name": "HBoxModel", 1056 | "_view_count": null, 1057 | "_view_module": "@jupyter-widgets/controls", 1058 | "_view_module_version": "1.5.0", 1059 | "_view_name": "HBoxView", 1060 | "box_style": "", 1061 | "children": [ 1062 | "IPY_MODEL_2afc450d38f948ad89af08983065e9e2", 1063 | "IPY_MODEL_6dfbbde3bbe84ab9803ae729c6697aae", 1064 | "IPY_MODEL_58a1dace4c7b4ad98419fb961b920cc0" 1065 | ], 1066 | "layout": "IPY_MODEL_68db21ef3c204407ad9fa181fea152ae" 1067 | } 1068 | }, 1069 | "2afc450d38f948ad89af08983065e9e2": { 1070 | "model_module": "@jupyter-widgets/controls", 1071 | "model_name": "HTMLModel", 1072 | "model_module_version": "1.5.0", 1073 | "state": { 1074 | "_dom_classes": [], 1075 | "_model_module": "@jupyter-widgets/controls", 1076 | "_model_module_version": "1.5.0", 1077 | "_model_name": "HTMLModel", 1078 | "_view_count": null, 1079 | "_view_module": "@jupyter-widgets/controls", 1080 | "_view_module_version": "1.5.0", 1081 | "_view_name": "HTMLView", 1082 | "description": "", 1083 | "description_tooltip": null, 1084 | "layout": "IPY_MODEL_b01677a02efb420cb32c0957d77a2345", 1085 | "placeholder": "​", 1086 | "style": "IPY_MODEL_f8cfcf541c684ad889fda00b140e3ab1", 1087 | "value": "  
2%" 1088 | } 1089 | }, 1090 | "6dfbbde3bbe84ab9803ae729c6697aae": { 1091 | "model_module": "@jupyter-widgets/controls", 1092 | "model_name": "FloatProgressModel", 1093 | "model_module_version": "1.5.0", 1094 | "state": { 1095 | "_dom_classes": [], 1096 | "_model_module": "@jupyter-widgets/controls", 1097 | "_model_module_version": "1.5.0", 1098 | "_model_name": "FloatProgressModel", 1099 | "_view_count": null, 1100 | "_view_module": "@jupyter-widgets/controls", 1101 | "_view_module_version": "1.5.0", 1102 | "_view_name": "ProgressView", 1103 | "bar_style": "danger", 1104 | "description": "", 1105 | "description_tooltip": null, 1106 | "layout": "IPY_MODEL_de9dcf17fb4a4858a5b68922376272e2", 1107 | "max": 938, 1108 | "min": 0, 1109 | "orientation": "horizontal", 1110 | "style": "IPY_MODEL_a32ba491756d40f49fe5b7e795bd3f21", 1111 | "value": 22 1112 | } 1113 | }, 1114 | "58a1dace4c7b4ad98419fb961b920cc0": { 1115 | "model_module": "@jupyter-widgets/controls", 1116 | "model_name": "HTMLModel", 1117 | "model_module_version": "1.5.0", 1118 | "state": { 1119 | "_dom_classes": [], 1120 | "_model_module": "@jupyter-widgets/controls", 1121 | "_model_module_version": "1.5.0", 1122 | "_model_name": "HTMLModel", 1123 | "_view_count": null, 1124 | "_view_module": "@jupyter-widgets/controls", 1125 | "_view_module_version": "1.5.0", 1126 | "_view_name": "HTMLView", 1127 | "description": "", 1128 | "description_tooltip": null, 1129 | "layout": "IPY_MODEL_1354164d0f164f7dac7a26e7fa86ca61", 1130 | "placeholder": "​", 1131 | "style": "IPY_MODEL_02c67b9bb3904e71a68168c23418fa1c", 1132 | "value": " 22/938 [00:05<03:39,  4.17it/s]" 1133 | } 1134 | }, 1135 | "68db21ef3c204407ad9fa181fea152ae": { 1136 | "model_module": "@jupyter-widgets/base", 1137 | "model_name": "LayoutModel", 1138 | "model_module_version": "1.2.0", 1139 | "state": { 1140 | "_model_module": "@jupyter-widgets/base", 1141 | "_model_module_version": "1.2.0", 1142 | "_model_name": "LayoutModel", 1143 | "_view_count": null, 1144 
| "_view_module": "@jupyter-widgets/base", 1145 | "_view_module_version": "1.2.0", 1146 | "_view_name": "LayoutView", 1147 | "align_content": null, 1148 | "align_items": null, 1149 | "align_self": null, 1150 | "border": null, 1151 | "bottom": null, 1152 | "display": null, 1153 | "flex": null, 1154 | "flex_flow": null, 1155 | "grid_area": null, 1156 | "grid_auto_columns": null, 1157 | "grid_auto_flow": null, 1158 | "grid_auto_rows": null, 1159 | "grid_column": null, 1160 | "grid_gap": null, 1161 | "grid_row": null, 1162 | "grid_template_areas": null, 1163 | "grid_template_columns": null, 1164 | "grid_template_rows": null, 1165 | "height": null, 1166 | "justify_content": null, 1167 | "justify_items": null, 1168 | "left": null, 1169 | "margin": null, 1170 | "max_height": null, 1171 | "max_width": null, 1172 | "min_height": null, 1173 | "min_width": null, 1174 | "object_fit": null, 1175 | "object_position": null, 1176 | "order": null, 1177 | "overflow": null, 1178 | "overflow_x": null, 1179 | "overflow_y": null, 1180 | "padding": null, 1181 | "right": null, 1182 | "top": null, 1183 | "visibility": null, 1184 | "width": null 1185 | } 1186 | }, 1187 | "b01677a02efb420cb32c0957d77a2345": { 1188 | "model_module": "@jupyter-widgets/base", 1189 | "model_name": "LayoutModel", 1190 | "model_module_version": "1.2.0", 1191 | "state": { 1192 | "_model_module": "@jupyter-widgets/base", 1193 | "_model_module_version": "1.2.0", 1194 | "_model_name": "LayoutModel", 1195 | "_view_count": null, 1196 | "_view_module": "@jupyter-widgets/base", 1197 | "_view_module_version": "1.2.0", 1198 | "_view_name": "LayoutView", 1199 | "align_content": null, 1200 | "align_items": null, 1201 | "align_self": null, 1202 | "border": null, 1203 | "bottom": null, 1204 | "display": null, 1205 | "flex": null, 1206 | "flex_flow": null, 1207 | "grid_area": null, 1208 | "grid_auto_columns": null, 1209 | "grid_auto_flow": null, 1210 | "grid_auto_rows": null, 1211 | "grid_column": null, 1212 | "grid_gap": null, 
1213 | "grid_row": null, 1214 | "grid_template_areas": null, 1215 | "grid_template_columns": null, 1216 | "grid_template_rows": null, 1217 | "height": null, 1218 | "justify_content": null, 1219 | "justify_items": null, 1220 | "left": null, 1221 | "margin": null, 1222 | "max_height": null, 1223 | "max_width": null, 1224 | "min_height": null, 1225 | "min_width": null, 1226 | "object_fit": null, 1227 | "object_position": null, 1228 | "order": null, 1229 | "overflow": null, 1230 | "overflow_x": null, 1231 | "overflow_y": null, 1232 | "padding": null, 1233 | "right": null, 1234 | "top": null, 1235 | "visibility": null, 1236 | "width": null 1237 | } 1238 | }, 1239 | "f8cfcf541c684ad889fda00b140e3ab1": { 1240 | "model_module": "@jupyter-widgets/controls", 1241 | "model_name": "DescriptionStyleModel", 1242 | "model_module_version": "1.5.0", 1243 | "state": { 1244 | "_model_module": "@jupyter-widgets/controls", 1245 | "_model_module_version": "1.5.0", 1246 | "_model_name": "DescriptionStyleModel", 1247 | "_view_count": null, 1248 | "_view_module": "@jupyter-widgets/base", 1249 | "_view_module_version": "1.2.0", 1250 | "_view_name": "StyleView", 1251 | "description_width": "" 1252 | } 1253 | }, 1254 | "de9dcf17fb4a4858a5b68922376272e2": { 1255 | "model_module": "@jupyter-widgets/base", 1256 | "model_name": "LayoutModel", 1257 | "model_module_version": "1.2.0", 1258 | "state": { 1259 | "_model_module": "@jupyter-widgets/base", 1260 | "_model_module_version": "1.2.0", 1261 | "_model_name": "LayoutModel", 1262 | "_view_count": null, 1263 | "_view_module": "@jupyter-widgets/base", 1264 | "_view_module_version": "1.2.0", 1265 | "_view_name": "LayoutView", 1266 | "align_content": null, 1267 | "align_items": null, 1268 | "align_self": null, 1269 | "border": null, 1270 | "bottom": null, 1271 | "display": null, 1272 | "flex": null, 1273 | "flex_flow": null, 1274 | "grid_area": null, 1275 | "grid_auto_columns": null, 1276 | "grid_auto_flow": null, 1277 | "grid_auto_rows": null, 1278 | 
"grid_column": null, 1279 | "grid_gap": null, 1280 | "grid_row": null, 1281 | "grid_template_areas": null, 1282 | "grid_template_columns": null, 1283 | "grid_template_rows": null, 1284 | "height": null, 1285 | "justify_content": null, 1286 | "justify_items": null, 1287 | "left": null, 1288 | "margin": null, 1289 | "max_height": null, 1290 | "max_width": null, 1291 | "min_height": null, 1292 | "min_width": null, 1293 | "object_fit": null, 1294 | "object_position": null, 1295 | "order": null, 1296 | "overflow": null, 1297 | "overflow_x": null, 1298 | "overflow_y": null, 1299 | "padding": null, 1300 | "right": null, 1301 | "top": null, 1302 | "visibility": null, 1303 | "width": null 1304 | } 1305 | }, 1306 | "a32ba491756d40f49fe5b7e795bd3f21": { 1307 | "model_module": "@jupyter-widgets/controls", 1308 | "model_name": "ProgressStyleModel", 1309 | "model_module_version": "1.5.0", 1310 | "state": { 1311 | "_model_module": "@jupyter-widgets/controls", 1312 | "_model_module_version": "1.5.0", 1313 | "_model_name": "ProgressStyleModel", 1314 | "_view_count": null, 1315 | "_view_module": "@jupyter-widgets/base", 1316 | "_view_module_version": "1.2.0", 1317 | "_view_name": "StyleView", 1318 | "bar_color": null, 1319 | "description_width": "" 1320 | } 1321 | }, 1322 | "1354164d0f164f7dac7a26e7fa86ca61": { 1323 | "model_module": "@jupyter-widgets/base", 1324 | "model_name": "LayoutModel", 1325 | "model_module_version": "1.2.0", 1326 | "state": { 1327 | "_model_module": "@jupyter-widgets/base", 1328 | "_model_module_version": "1.2.0", 1329 | "_model_name": "LayoutModel", 1330 | "_view_count": null, 1331 | "_view_module": "@jupyter-widgets/base", 1332 | "_view_module_version": "1.2.0", 1333 | "_view_name": "LayoutView", 1334 | "align_content": null, 1335 | "align_items": null, 1336 | "align_self": null, 1337 | "border": null, 1338 | "bottom": null, 1339 | "display": null, 1340 | "flex": null, 1341 | "flex_flow": null, 1342 | "grid_area": null, 1343 | "grid_auto_columns": null, 
1344 | "grid_auto_flow": null, 1345 | "grid_auto_rows": null, 1346 | "grid_column": null, 1347 | "grid_gap": null, 1348 | "grid_row": null, 1349 | "grid_template_areas": null, 1350 | "grid_template_columns": null, 1351 | "grid_template_rows": null, 1352 | "height": null, 1353 | "justify_content": null, 1354 | "justify_items": null, 1355 | "left": null, 1356 | "margin": null, 1357 | "max_height": null, 1358 | "max_width": null, 1359 | "min_height": null, 1360 | "min_width": null, 1361 | "object_fit": null, 1362 | "object_position": null, 1363 | "order": null, 1364 | "overflow": null, 1365 | "overflow_x": null, 1366 | "overflow_y": null, 1367 | "padding": null, 1368 | "right": null, 1369 | "top": null, 1370 | "visibility": null, 1371 | "width": null 1372 | } 1373 | }, 1374 | "02c67b9bb3904e71a68168c23418fa1c": { 1375 | "model_module": "@jupyter-widgets/controls", 1376 | "model_name": "DescriptionStyleModel", 1377 | "model_module_version": "1.5.0", 1378 | "state": { 1379 | "_model_module": "@jupyter-widgets/controls", 1380 | "_model_module_version": "1.5.0", 1381 | "_model_name": "DescriptionStyleModel", 1382 | "_view_count": null, 1383 | "_view_module": "@jupyter-widgets/base", 1384 | "_view_module_version": "1.2.0", 1385 | "_view_name": "StyleView", 1386 | "description_width": "" 1387 | } 1388 | } 1389 | } 1390 | } 1391 | }, 1392 | "cells": [ 1393 | { 1394 | "cell_type": "markdown", 1395 | "metadata": { 1396 | "id": "view-in-github", 1397 | "colab_type": "text" 1398 | }, 1399 | "source": [ 1400 | "\"Open" 1401 | ] 1402 | }, 1403 | { 1404 | "cell_type": "code", 1405 | "execution_count": 1, 1406 | "metadata": { 1407 | "id": "Ee1NBwatIJlR" 1408 | }, 1409 | "outputs": [], 1410 | "source": [ 1411 | "import torch\n", 1412 | "import torch.nn as nn\n", 1413 | "import torch.nn.functional as F\n", 1414 | "import torch.optim as optim\n", 1415 | "\n", 1416 | "from torch.utils.data import Dataset, DataLoader, ConcatDataset\n", 1417 | "\n", 1418 | "import torchvision as tv\n", 
1419 | "import torchvision.transforms as T\n", 1420 | "\n", 1421 | "from PIL import Image\n", 1422 | "\n", 1423 | "import numpy as np\n", 1424 | "import matplotlib.pyplot as plt\n", 1425 | "\n", 1426 | "from tqdm.notebook import tqdm" 1427 | ] 1428 | }, 1429 | { 1430 | "cell_type": "code", 1431 | "source": [ 1432 | "START = 1e-4\n", 1433 | "END = .02\n", 1434 | "TIMESTEPS = 300\n", 1435 | "\n", 1436 | "IMAGE_SIZE = 64\n", 1437 | "BATCH_SIZE = 64\n", 1438 | "EPOCHS = 5\n", 1439 | "\n", 1440 | "LR = 1e-3\n", 1441 | "\n", 1442 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" 1443 | ], 1444 | "metadata": { 1445 | "id": "nzHOnslOMzvU" 1446 | }, 1447 | "execution_count": 2, 1448 | "outputs": [] 1449 | }, 1450 | { 1451 | "cell_type": "code", 1452 | "source": [ 1453 | "betas = torch.linspace(start=START, end=END, steps=TIMESTEPS)\n", 1454 | "alphas = 1 - betas\n", 1455 | "\n", 1456 | "alpha_bars = torch.cumprod(alphas, dim=0)\n", 1457 | "alpha_bars_prev = F.pad(alpha_bars[:-1], (1, 0), value=1.0)\n", 1458 | "sqrt_one_over_alpha_bars = torch.sqrt(1. / alpha_bars)\n", 1459 | "sqrt_alpha_bars = torch.sqrt(alpha_bars)\n", 1460 | "sqrt_one_minus_alpha_bars = torch.sqrt(1 - alpha_bars)\n", 1461 | "\n", 1462 | "posterior_variance = betas * (1. - alpha_bars_prev) / (1. 
- alpha_bars)" 1463 | ], 1464 | "metadata": { 1465 | "id": "_zBev601NSRS" 1466 | }, 1467 | "execution_count": 3, 1468 | "outputs": [] 1469 | }, 1470 | { 1471 | "cell_type": "code", 1472 | "source": [ 1473 | "class ConvBlock(nn.Module):\n", 1474 | " def __init__(self, in_channels, out_channels):\n", 1475 | " super().__init__()\n", 1476 | "\n", 1477 | " self.conv_blocks = nn.Sequential(\n", 1478 | " nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),\n", 1479 | " nn.ReLU(inplace=True),\n", 1480 | " nn.BatchNorm2d(out_channels),\n", 1481 | " nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),\n", 1482 | " nn.ReLU(inplace=True),\n", 1483 | " nn.BatchNorm2d(out_channels),\n", 1484 | " )\n", 1485 | "\n", 1486 | " def forward(self, x):\n", 1487 | " return self.conv_blocks(x)" 1488 | ], 1489 | "metadata": { 1490 | "id": "GMK460xiUdg7" 1491 | }, 1492 | "execution_count": 4, 1493 | "outputs": [] 1494 | }, 1495 | { 1496 | "cell_type": "code", 1497 | "source": [ 1498 | "class DownBlock(nn.Module):\n", 1499 | " def __init__(self, in_channels, out_channels):\n", 1500 | " super().__init__()\n", 1501 | "\n", 1502 | " self.conv_blocks = nn.Sequential(\n", 1503 | " nn.MaxPool2d(2),\n", 1504 | " ConvBlock(in_channels, out_channels)\n", 1505 | " )\n", 1506 | "\n", 1507 | " def forward(self, x):\n", 1508 | " return self.conv_blocks(x)" 1509 | ], 1510 | "metadata": { 1511 | "id": "DYvOCCVRUtNS" 1512 | }, 1513 | "execution_count": 5, 1514 | "outputs": [] 1515 | }, 1516 | { 1517 | "cell_type": "code", 1518 | "source": [ 1519 | "class UpBlock(nn.Module):\n", 1520 | " def __init__(self, in_channels, out_channels):\n", 1521 | " super().__init__()\n", 1522 | "\n", 1523 | " self.up = nn.ConvTranspose2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=2, stride=2)\n", 1524 | " self.conv_blocks = ConvBlock(in_channels, out_channels)\n", 1525 | "\n", 1526 | " def forward(self, x, 
residual_inputs):\n", 1527 | " x = self.up(x)\n", 1528 | "\n", 1529 | " diff_y = residual_inputs.size()[2] - x.size()[2]\n", 1530 | " diff_x = residual_inputs.size()[3] - x.size()[3]\n", 1531 | "\n", 1532 | " x = F.pad(x, [diff_x // 2, diff_x - diff_x // 2,\n", 1533 | " diff_y // 2, diff_y - diff_y // 2])\n", 1534 | "\n", 1535 | " x = torch.cat([residual_inputs, x], dim=1)\n", 1536 | " x = self.conv_blocks(x)\n", 1537 | "\n", 1538 | " return x" 1539 | ], 1540 | "metadata": { 1541 | "id": "3QFPrpawUtP6" 1542 | }, 1543 | "execution_count": 6, 1544 | "outputs": [] 1545 | }, 1546 | { 1547 | "cell_type": "code", 1548 | "source": [ 1549 | "class OutBlock(nn.Module):\n", 1550 | " def __init__(self, in_channels, num_classes):\n", 1551 | " super().__init__()\n", 1552 | "\n", 1553 | " self.conv = nn.Conv2d(in_channels=in_channels, out_channels=num_classes, kernel_size=1)\n", 1554 | "\n", 1555 | " def forward(self, x):\n", 1556 | " return self.conv(x)" 1557 | ], 1558 | "metadata": { 1559 | "id": "s-V4VnnUUtTM" 1560 | }, 1561 | "execution_count": 7, 1562 | "outputs": [] 1563 | }, 1564 | { 1565 | "cell_type": "code", 1566 | "source": [ 1567 | "class UNet(nn.Module):\n", 1568 | " def __init__(self, in_channels, num_classes):\n", 1569 | " super().__init__()\n", 1570 | "\n", 1571 | " self.input_block = ConvBlock(in_channels, 64)\n", 1572 | "\n", 1573 | " self.down_1 = DownBlock(64, 128)\n", 1574 | " self.down_2 = DownBlock(128, 256)\n", 1575 | " self.down_3 = DownBlock(256, 512)\n", 1576 | " self.down_4 = DownBlock(512, 1024)\n", 1577 | "\n", 1578 | " self.up_4 = UpBlock(1024, 512)\n", 1579 | " self.up_3 = UpBlock(512, 256)\n", 1580 | " self.up_2 = UpBlock(256, 128)\n", 1581 | " self.up_1 = UpBlock(128, 64)\n", 1582 | "\n", 1583 | " self.output_block = OutBlock(64, num_classes)\n", 1584 | "\n", 1585 | " self.embedding_up_1 = nn.Linear(1, 128)\n", 1586 | " self.embedding_up_2 = nn.Linear(1, 256)\n", 1587 | " self.embedding_up_3 = nn.Linear(1, 512)\n", 1588 | " self.embedding_up_4 = 
nn.Linear(1, 1024)\n", 1589 | "\n", 1590 | " def forward(self, x, t):\n", 1591 | " batch_size = x.size(0)\n", 1592 | " down_cache_1 = self.input_block(x)\n", 1593 | "\n", 1594 | " down_cache_2 = self.down_1(down_cache_1)\n", 1595 | " down_cache_3 = self.down_2(down_cache_2)\n", 1596 | " down_cache_4 = self.down_3(down_cache_3)\n", 1597 | " down_cache_5 = self.down_4(down_cache_4)\n", 1598 | "\n", 1599 | " t_embed = self.embedding_up_4(t).view(batch_size, -1, 1, 1)\n", 1600 | " x = self.up_4(down_cache_5 + t_embed, down_cache_4)\n", 1601 | "\n", 1602 | " t_embed = self.embedding_up_3(t).view(batch_size, -1, 1, 1)\n", 1603 | " x = self.up_3(x + t_embed, down_cache_3)\n", 1604 | "\n", 1605 | " t_embed = self.embedding_up_2(t).view(batch_size, -1, 1, 1)\n", 1606 | " x = self.up_2(x + t_embed, down_cache_2)\n", 1607 | "\n", 1608 | " t_embed = self.embedding_up_1(t).view(batch_size, -1, 1, 1)\n", 1609 | " x = self.up_1(x + t_embed, down_cache_1)\n", 1610 | "\n", 1611 | " x = self.output_block(x)\n", 1612 | "\n", 1613 | " return x" 1614 | ], 1615 | "metadata": { 1616 | "id": "nQd2F863UtWj" 1617 | }, 1618 | "execution_count": 8, 1619 | "outputs": [] 1620 | }, 1621 | { 1622 | "cell_type": "code", 1623 | "source": [ 1624 | "def get_index_for_batch(values, t, x_shape):\n", 1625 | " batch_size = t.size(0)\n", 1626 | " out = values.gather(-1, t.cpu())\n", 1627 | "\n", 1628 | " return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(DEVICE)" 1629 | ], 1630 | "metadata": { 1631 | "id": "0d_mFHm6N0mp" 1632 | }, 1633 | "execution_count": 9, 1634 | "outputs": [] 1635 | }, 1636 | { 1637 | "cell_type": "code", 1638 | "source": [ 1639 | "def forward_process(x_0, t):\n", 1640 | " noise = torch.randn_like(x_0)\n", 1641 | "\n", 1642 | " sqrt_alpha_bars_for_batch = get_index_for_batch(sqrt_alpha_bars, t, x_0.shape).to(DEVICE)\n", 1643 | " sqrt_one_minus_alpha_bars_for_batch = get_index_for_batch(sqrt_one_minus_alpha_bars, t, x_0.shape).to(DEVICE)\n", 1644 | "\n", 1645 | " z = 
sqrt_one_minus_alpha_bars_for_batch * noise + sqrt_alpha_bars_for_batch * x_0\n", 1646 | "\n", 1647 | " return z, noise" 1648 | ], 1649 | "metadata": { 1650 | "id": "uhOmFQeZN5XY" 1651 | }, 1652 | "execution_count": 10, 1653 | "outputs": [] 1654 | }, 1655 | { 1656 | "cell_type": "code", 1657 | "source": [ 1658 | "reverse_transforms = T.Compose([\n", 1659 | " T.Lambda(lambda x: (x + 1) / 2),\n", 1660 | " T.Lambda(lambda x: x.permute(1, 2, 0)),\n", 1661 | " T.Lambda(lambda x: x * 255),\n", 1662 | " T.Lambda(lambda x: x.cpu().numpy().astype(np.uint8)),\n", 1663 | " T.ToPILImage()\n", 1664 | "])\n", 1665 | "\n", 1666 | "def convert_tensor_image(image):\n", 1667 | " if len(image.shape) == 4:\n", 1668 | " image = image[0, :, :, :]\n", 1669 | "\n", 1670 | " image = reverse_transforms(image)\n", 1671 | " return image" 1672 | ], 1673 | "metadata": { 1674 | "id": "-oWTY3w9SxDK" 1675 | }, 1676 | "execution_count": 11, 1677 | "outputs": [] 1678 | }, 1679 | { 1680 | "cell_type": "code", 1681 | "source": [ 1682 | "transforms = T.Compose([\n", 1683 | " T.Resize((IMAGE_SIZE, IMAGE_SIZE)),\n", 1684 | " T.RandomHorizontalFlip(),\n", 1685 | " T.ToTensor(),\n", 1686 | " T.Lambda(lambda x: (x * 2) - 1)\n", 1687 | "])\n", 1688 | "\n", 1689 | "train = tv.datasets.CIFAR10(root='./dataset', download=True, transform=transforms, train=True)\n", 1690 | "val = tv.datasets.CIFAR10(root='./dataset', download=True, transform=transforms, train=False)\n", 1691 | "\n", 1692 | "dataset = ConcatDataset([train, val])\n", 1693 | "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)" 1694 | ], 1695 | "metadata": { 1696 | "colab": { 1697 | "base_uri": "https://localhost:8080/" 1698 | }, 1699 | "id": "aRMbMaGMRrgS", 1700 | "outputId": "763ce410-8c10-42dd-9468-f922c4055924" 1701 | }, 1702 | "execution_count": 12, 1703 | "outputs": [ 1704 | { 1705 | "output_type": "stream", 1706 | "name": "stdout", 1707 | "text": [ 1708 | "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 
to ./dataset/cifar-10-python.tar.gz\n" 1709 | ] 1710 | }, 1711 | { 1712 | "output_type": "stream", 1713 | "name": "stderr", 1714 | "text": [ 1715 | "100%|██████████| 170498071/170498071 [00:13<00:00, 12645277.96it/s]\n" 1716 | ] 1717 | }, 1718 | { 1719 | "output_type": "stream", 1720 | "name": "stdout", 1721 | "text": [ 1722 | "Extracting ./dataset/cifar-10-python.tar.gz to ./dataset\n", 1723 | "Files already downloaded and verified\n" 1724 | ] 1725 | } 1726 | ] 1727 | }, 1728 | { 1729 | "cell_type": "code", 1730 | "source": [ 1731 | "network = UNet(3, 3).to(DEVICE)" 1732 | ], 1733 | "metadata": { 1734 | "id": "5WW4S17DSgsI" 1735 | }, 1736 | "execution_count": 13, 1737 | "outputs": [] 1738 | }, 1739 | { 1740 | "cell_type": "code", 1741 | "source": [ 1742 | "optimizer = optim.Adam(network.parameters(), lr=LR)" 1743 | ], 1744 | "metadata": { 1745 | "id": "DNCJALJwToDx" 1746 | }, 1747 | "execution_count": 14, 1748 | "outputs": [] 1749 | }, 1750 | { 1751 | "cell_type": "code", 1752 | "source": [ 1753 | "criterion = nn.MSELoss()" 1754 | ], 1755 | "metadata": { 1756 | "id": "nnYKe7ZwXPth" 1757 | }, 1758 | "execution_count": 15, 1759 | "outputs": [] 1760 | }, 1761 | { 1762 | "cell_type": "code", 1763 | "source": [ 1764 | "for epoch in range(1, EPOCHS + 1):\n", 1765 | " print(f'Epoch {epoch} / {EPOCHS}')\n", 1766 | " total_loss = .0\n", 1767 | " for images, _ in tqdm(dataloader):\n", 1768 | " optimizer.zero_grad()\n", 1769 | "\n", 1770 | " images = images.to(DEVICE)\n", 1771 | "\n", 1772 | " batch_size = images.size(0)\n", 1773 | " t = torch.randint(0, TIMESTEPS, (batch_size,)).to(DEVICE).long()\n", 1774 | " noisy_image, noise = forward_process(images, t)\n", 1775 | " noise_preds = network(noisy_image, t.unsqueeze(-1).float())\n", 1776 | "\n", 1777 | " loss = criterion(noise_preds, noise)\n", 1778 | " loss.backward()\n", 1779 | "\n", 1780 | " optimizer.step()\n", 1781 | "\n", 1782 | " total_loss += loss.detach().cpu().item()\n", 1783 | "\n", 1784 | " print(f'Loss: 
{total_loss:.2f}')" 1785 | ], 1786 | "metadata": { 1787 | "colab": { 1788 | "base_uri": "https://localhost:8080/", 1789 | "height": 474, 1790 | "referenced_widgets": [ 1791 | "b0fa66fc56e34465bb2a59f9ce916b71", 1792 | "14019560a279458cb775634b162b2407", 1793 | "4db4b93e47bd4c859b1d6685ed8cf76e", 1794 | "2ef34af8de2c4c7eab85fd3c83a5e1c5", 1795 | "1638fe599f2c409c8d75396bbad174b7", 1796 | "0a2010f11ba345a385997b328168127a", 1797 | "5a39152a73cb4a87bd5bdf6f5a6888cc", 1798 | "b72fab00156243478d0619f3f8a56a1a", 1799 | "38548a20e7744f098a9118b82ff13f95", 1800 | "47d16d89d79543ab8b893ccae490d2f0", 1801 | "05f22484b66641818d3288792ead2efa", 1802 | "7c76076cccf04326a82ec69c5b33c9af", 1803 | "326f80a6a9364bc4a37fd727649c5599", 1804 | "cd209532893744078f13583bc2765ef5", 1805 | "9933cef2fd0c44c2a308f24e5de0b1ea", 1806 | "9d07802c82da4382966ee2bc58bff4e2", 1807 | "7c62a32ae5944b7584b0026533d3bea9", 1808 | "7d6defe1998a4083bbf9f4fdb4e5d599", 1809 | "5a2ad31325dc45d0b795eb88daf147df", 1810 | "c981017d7c39426dbaf566505c817bf4", 1811 | "2ca3d70b283a4eaba7c683d03a53da07", 1812 | "51922432c661481da1ebe8192367c4dd", 1813 | "61d30891744b4dcda25e8eeb71b54aad", 1814 | "718e38085b3b4f5eab8c481913bad4fb", 1815 | "3346c18e5aae424a9705689c62854830", 1816 | "1a782ba0602d4f7ebb0659295f7497f0", 1817 | "80ad93a8096a41f6ba7035bb483ee461", 1818 | "bf1b70fc39db4494aee42122b14a0c8a", 1819 | "9cfbaf78195a4409aa3c751fef6377e1", 1820 | "c8379e3f85c74bb9889bf49229de82b3", 1821 | "75125bd2901045ffa9c9b7d3583ee94c", 1822 | "78a5fa8234ac4ff182f882908461559a", 1823 | "28abc7c5f1914aa0a97fd7d280d75ee0", 1824 | "c9305060d9e144a095a0d93b63355e0b", 1825 | "2afc450d38f948ad89af08983065e9e2", 1826 | "6dfbbde3bbe84ab9803ae729c6697aae", 1827 | "58a1dace4c7b4ad98419fb961b920cc0", 1828 | "68db21ef3c204407ad9fa181fea152ae", 1829 | "b01677a02efb420cb32c0957d77a2345", 1830 | "f8cfcf541c684ad889fda00b140e3ab1", 1831 | "de9dcf17fb4a4858a5b68922376272e2", 1832 | "a32ba491756d40f49fe5b7e795bd3f21", 1833 | 
@torch.no_grad()
def sample_timestep(x, t):
    """Run one reverse-diffusion step on ``x`` for timestep(s) ``t``.

    Looks up the per-timestep coefficients from the module-level schedule
    tables (``betas``, ``sqrt_one_minus_alpha_bars``,
    ``sqrt_one_over_alpha_bars``, ``posterior_variance``) through
    ``get_index_for_batch`` and combines them with the output of
    ``network(x, t.float())``.

    Args:
        x: current noisy sample tensor.
        t: integer tensor of timestep indices (the callers in this notebook
            pass a tensor of shape ``(1,)``).

    Returns:
        The mean of the reverse step where ``t == 0``, otherwise the mean
        plus Gaussian noise scaled by ``sqrt(posterior_variance[t])``.
    """
    betas_t = get_index_for_batch(betas, t, x.shape)
    sqrt_one_minus_alpha_bar_t = get_index_for_batch(
        sqrt_one_minus_alpha_bars, t, x.shape
    )
    # NOTE(review): the DDPM sampling rule scales by 1/sqrt(alpha_t); the
    # table name suggests alpha-bar. The table is built outside this cell --
    # confirm it holds 1/sqrt(alpha_t).
    mean_scale_t = get_index_for_batch(sqrt_one_over_alpha_bars, t, x.shape)

    # The network output is used where the sampling formula has the predicted
    # noise (presumably what the network was trained to output -- confirm
    # against the training cell).
    predicted = network(x, t.float())
    model_mean = mean_scale_t * (
        x - betas_t * predicted / sqrt_one_minus_alpha_bar_t
    )

    posterior_variance_t = get_index_for_batch(posterior_variance, t, x.shape)

    # `if t == 0` is only well defined for a one-element tensor; a batch with
    # several timesteps would raise. Decide per sample instead.
    is_final_step = (t == 0)
    if bool(is_final_step.all()):
        # Final step for every sample: return the mean, draw no noise
        # (same as the original behaviour for a single timestep).
        return model_mean

    noise = torch.randn_like(x)
    # 1.0 where noise is added, 0.0 where t == 0; shaped to broadcast over x.
    noise_mask = (~is_final_step).to(x.dtype).reshape(-1, *([1] * (x.dim() - 1)))
    return model_mean + noise_mask * torch.sqrt(posterior_variance_t) * noise


@torch.no_grad()
def sample_plot_image():
    """Generate one image by running the full reverse process.

    Starts from Gaussian noise of shape ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)`` on
    ``DEVICE`` and applies ``sample_timestep`` for ``TIMESTEPS - 1`` down to 0.
    Despite the name, nothing is plotted here; the final tensor is returned
    and the caller displays it.
    """
    img = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE), device=DEVICE)

    for step in reversed(range(TIMESTEPS)):
        t = torch.full((1,), step, device=DEVICE, dtype=torch.long)
        img = sample_timestep(img, t)

    return img


# --- notebook cells -------------------------------------------------------

img = sample_plot_image()

img = convert_tensor_image(img)
plt.imshow(img)
plt.show()

# Scratch check: a single reverse step at t = 0 on fresh noise.
i = 0
img = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE), device=DEVICE)
t = torch.full((1,), i, device=DEVICE, dtype=torch.long)
img = sample_timestep(img, t)

img
import numpy as np


class Environment:
    """Square gridworld with two special cells, A and B.

    Rules implemented by ``calculate_next``:

    * From cell A (row 0, column 1) every move lands on row 4, column 1 with
      reward 10.
    * From cell B (row 0, column 3) every move lands on row 2, column 3 with
      reward 5.
    * A move that would leave the grid keeps the agent in place, reward -1.
    * Any other move has reward 0.

    Moves may be given as ``0..3`` or as case-insensitive strings:
    ``0/'u'/'up'``, ``1/'d'/'down'``, ``2/'r'/'right'``, ``3/'l'/'left'``.
    Any other value leaves the position unchanged.
    """

    def __init__(self):
        # Agent starts in the centre cell.
        self.x = 2
        self.y = 2
        # NOTE: the two pairs below use different conventions.
        # A and B are compared against (y, x) in calculate_next, while
        # A_next and B_next are unpacked there as (x, y).
        self.A = (0, 1)
        self.B = (0, 3)
        self.A_next = (1, 4)
        self.B_next = (3, 2)
        self.edge_size = 5

    def calculate_next(self, y, x, move):
        """Return ``(new_y, new_x, reward)`` for taking ``move`` at ``(y, x)``.

        Pure function of its arguments: the agent's stored position is not
        read or modified.
        """
        if isinstance(move, str):
            move = move.lower()
        new_x, new_y = x, y

        if move in ('u', 'up', 0):
            new_y -= 1
        elif move in ('d', 'down', 1):
            new_y += 1
        elif move in ('r', 'right', 2):
            new_x += 1
        elif move in ('l', 'left', 3):
            new_x -= 1

        # The special cells take precedence over the requested move.
        if (y, x) == self.A:
            new_x, new_y = self.A_next
            reward = 10
        elif (y, x) == self.B:
            new_x, new_y = self.B_next
            reward = 5
        elif new_x < 0 or new_x >= self.edge_size:
            new_x, new_y = x, y
            reward = -1
        elif new_y < 0 or new_y >= self.edge_size:
            new_x, new_y = x, y
            reward = -1
        else:
            reward = 0

        return new_y, new_x, reward

    def step(self, move):
        """Apply ``move`` to the agent's stored position and return the reward."""
        new_y, new_x, reward = self.calculate_next(self.y, self.x, move)

        self.y = new_y
        self.x = new_x

        return reward

    def predict_reward(self, y, x, move):
        """Return ``(new_y, new_x, reward)`` for ``move`` at ``(y, x)`` without moving the agent."""
        new_y, new_x, reward = self.calculate_next(y, x, move)
        return new_y, new_x, reward

    def reset(self):
        """Put the agent back on the centre cell."""
        self.x = 2
        self.y = 2

    @property
    def moves(self):
        """The integer action ids, ``range(4)``."""
        return range(4)

    def __repr__(self):
        """Draw the grid: ``*`` marks the agent, ``_`` every other cell, one row per line."""
        size = self.edge_size
        rows = (
            ''.join('*' if (i == self.y and j == self.x) else '_' for j in range(size))
            for i in range(size)
        )
        return ''.join(row + '\n' for row in rows)


def argmax(values):
    """Return the indices of all maximal entries of ``values``.

    Ties are detected with exact equality, so every index holding the
    maximum is returned, in increasing order. An empty input gives ``[]``.
    """
    maximum = float('-inf')
    moves = []

    for i, value in enumerate(values):
        if value > maximum:
            moves = [i]
            maximum = value
        elif value == maximum:
            moves.append(i)

    return moves


def calculate_value_function(env, policy, gamma=.9, sweeps=50):
    """Estimate the state values of ``policy`` by iterative policy evaluation.

    Args:
        env: environment exposing ``edge_size``, ``moves`` and
            ``predict_reward(y, x, move)``.
        policy: array indexed as ``policy[y, x, action]`` holding action
            probabilities.
        gamma: discount factor.
        sweeps: number of full passes over the grid (50, the value that was
            previously hard-coded).

    Returns:
        ``(edge_size, edge_size)`` array of state values, started from zeros.
    """
    size = env.edge_size
    value_function = np.zeros((size, size))

    for _ in range(sweeps):
        for i in range(size):
            for j in range(size):
                # Updates are in place: cells visited later in the same sweep
                # already see the values written earlier in that sweep.
                expected = 0
                for a in env.moves:
                    next_y, next_x, reward = env.predict_reward(i, j, a)
                    expected += policy[i, j, a] * (reward + gamma * value_function[next_y, next_x])
                value_function[i, j] = expected

    return value_function


def update_policy(value_function, environment=None, gamma=.9):
    """Return the policy that is greedy with respect to ``value_function``.

    Each action is scored by its one-step return
    ``reward + gamma * value_function[next_state]``. Previously only the
    successor's value was compared, which ignores the immediate reward (for
    example the -1 for walking into a wall). Probability is split evenly
    between tied best actions.

    Args:
        value_function: array indexed as ``value_function[y, x]``.
        environment: environment to query; when omitted, the module-level
            ``env`` is used, as before.
        gamma: discount factor applied to the successor value.

    Returns:
        Array indexed as ``policy[y, x, action]``.
    """
    if environment is None:
        environment = env  # module-level instance created below

    size = environment.edge_size
    actions = list(environment.moves)
    new_policy = np.zeros((size, size, len(actions)))

    for i in range(size):
        for j in range(size):
            action_returns = []
            for a in actions:
                new_y, new_x, reward = environment.predict_reward(i, j, a)
                action_returns.append(reward + gamma * value_function[new_y, new_x])
            best = argmax(action_returns)

            new_policy[i, j, best] = 1 / len(best)

    return new_policy


env = Environment()
# Start from the uniform random policy: each of the 4 moves has probability 0.25.
policy = np.ones((5, 5, 4)) * .25

# Policy iteration: evaluate the current policy, then make it greedy; 10 rounds.
for _ in range(10):
    value_function = calculate_value_function(env, policy)
    policy = update_policy(value_function)

# State values V(s), one per grid cell (this is not a Q table: there is no
# action axis). The run recorded in the notebook shows 24.4194281 at
# row 0, column 1.
value_function

# Print the policy as a grid. Each cell is 4 characters wide, one slot per
# move in the order U, D, R, L; a slot is blank when that move has
# probability 0.
moves = ['U', 'D', 'R', 'L']

for i in range(5):
    cells = []
    for j in range(5):
        cells.append(
            ''.join(moves[a] if policy[i, j, a] != 0 else ' ' for a in env.moves)
        )
    print('|' + '|'.join(cells) + '|')
"b619cb70-3ca9-4aef-d4c5-e60917e2ba90" 277 | }, 278 | "execution_count": 181, 279 | "outputs": [ 280 | { 281 | "output_type": "stream", 282 | "name": "stdout", 283 | "text": [ 284 | "| R |UDRL| L|UDRL| L|\n", 285 | "| R |U |U L| L| L|\n", 286 | "| R |U |U L|U L|U L|\n", 287 | "| R |U |U L|U L|U L|\n", 288 | "| R |U |U L|U L|U L|\n" 289 | ] 290 | } 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "source": [], 296 | "metadata": { 297 | "id": "OXFIubz3_IyX" 298 | }, 299 | "execution_count": null, 300 | "outputs": [] 301 | } 302 | ] 303 | } -------------------------------------------------------------------------------- /24_Double_Deep_Q_Learning_1_gym_intro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyM4+5u2E899RPT5eXzO/42D", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "source": [ 32 | "!pip install -q pyvirtualdisplay\n", 33 | "!pip install -q swig\n", 34 | "!pip install -q gymnasium[all]" 35 | ], 36 | "metadata": { 37 | "colab": { 38 | "base_uri": "https://localhost:8080/" 39 | }, 40 | "id": "OqttDVvk0LTA", 41 | "outputId": "5247d6e3-b297-47fd-efde-1fbf7f30090a" 42 | }, 43 | "execution_count": 12, 44 | "outputs": [ 45 | { 46 | "output_type": "stream", 47 | "name": "stdout", 48 | "text": [ 49 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 50 | " Building wheel for box2d-py (setup.py) ... 
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

import random
from collections import namedtuple, deque

from IPython.display import clear_output

# render_mode="rgb_array" makes env.render() return the frame as an array.
env = gym.make("LunarLander-v2", render_mode="rgb_array")

observation, info = env.reset()

observation

env.action_space.sample()

# Roll out 100 random actions and show the rendered frame after each step.
fig, axs = plt.subplots(1, 1, figsize=(5, 5))

# Keep one image artist and update its pixels. Calling axs.imshow() on every
# step stacks a new image on the axes each time.
frame_artist = None

for _ in range(100):
    action = env.action_space.sample()  # random action; the observation is not used to choose it
    observation, reward, terminated, truncated, info = env.step(action)

    if terminated or truncated:
        observation, info = env.reset()

    frame = env.render()
    if frame_artist is None:
        frame_artist = axs.imshow(frame)
        axs.axis('off')
    else:
        frame_artist.set_data(frame)
    plt.pause(.01)

env.close()
" 181 | ], 182 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAEWCAYAAACqitpwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWYUlEQVR4nO3de3CU5aHH8d+7l2w2uxty3QByC6AogiXgUREvnVPl1EqtY0vtOFo7xzmdTosd/UsdZ+w4c+aIh3qmc04RRS5tgZSbB4TjJYiUWkC5KBCu4WJISMIl93uyt/f8sSYEDBjgSTaB72fmnXd3k+w+2ZnsN8/7vvuuZdu2LQAADHIkegAAgGsPcQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYJyrp99oWVZvjgMAMED05MQuzFwAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wA4AIOy9LoG27Qd266SUlud6KHMyARFwC4QJLbrWE5ORrk9ys7PT3RwxmQLNu27R59o2X19lgAoN8YnJkpf0qKviorU6xnL5PXjZ5kg7gAAC5LT7LBZjEAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAYABwuVwyO0YGC/bA2OUAHCdC3q9WvPQQ1o/Y4aGpKQkejjfypXoAQAAvt2EzExNHzFCkpSXna1TJSUJHtGlWbZt2z36Rsvq7bEAAC7CYVn6t/Hj5XY69ea+fYr17KW7V/QkG8QFAHBZepIN9rkAAIwjLgBwDXC5kuRw9J/d6MQFAAYwy7KUnTVGk/Me14gRUxI9nE79J3MAcB1zu1OU7AnI6w0oGouqurr4kt/v8fiVkT5Kt9w4XSODUxVMHa9jpz9RdXWxGhvP9tGoL464AEBCWEpKSlFO9jilpuYoK32MhufcoRRPunYdXqLa2lLFYtFv/JQnya9hwybpppH/rNwh9yrNm6skZ/x9LzcOma6Km/do+xd/USwW6etf6DzEBQD6gNPpVoo3XUken4JZ4zRiyBSl+nOU5stVmneEkpx+OSyXbMU0JGuCDiZ9qLa2RkmSw+FUauoQjRg6ReNGP6ihaXkKJA+V00o670hejzOg2296WqfO7NeJ0h2J+lUlERcA6BNDB09U3oSZGpIxSSnuDCW70pTk9EmyZ
CumSKxN9e0n1dR+Wk6XU05nkiRLmZmjNHrk3bp5xL8oOGi8fEk5cljObh/DsixlpIzVlFufVHVNsRqbKvv0d+yKuABAH/A4UxVICWqwf6KidlihaLPq20+qLVyv6sajqq4vVmtzvcpP79Opyn1yOJyaOP6HmjTuJ7oh7XYluQIXjcqFbhz8gConFmnL9rcTtnmMuABAH6iuO6GGljM6Wf+5wtEWtbU1qrRip85WH1Fzc7XqGsoUCjdLspSamqO8CTM1eeyTCniG9jgqUnz2kuxO022jf6rik5/pZNmXvfdLXWocvEMfAPpG3q0zFY2EVXZqj+qbKxSLRWXb53bau93JGjn8Dt096Vcann6HXA7vFb/2xuyoik69r4/+8arq60+Z+hUkcfoXADAm4HbrN+PHa2N5uXZVVRm/f6czSXff8YxuG/UTZfpvuqzZysW0Req09cB/a+vOBUY3j3H6FwAw5KmxYzU1GNTzEycqqRc+U8VhORUcdItcLo8sQy/NHucg5d34pEaOuF1S304Q2OeCfud3v5NGjpTa26V166TCwvjtTU1SfX1ix9ZfjRolvfKKZNtSTY00b178+YvFpKoqKRxO9AgHvvdKSjQuLU0FZWUKx2LG7z8cadPu/auUPCVV3q+PJrtalmUpzTtS/3Tr06qpLTG+eeySj81mMfQ3S5dK48adu27b8RfJffukbdvi18+elT74IHFj7G9uuUX6y1/Ov822pba2+PN05owUiUg7dkhFRYkZI3pm2qRf69ZbHlKOb6KcjqSrvj/bthWKNmnb4f/RPz6f1+0bM6/kPr8NMxf0S13/l7EsyeGQ8vKkSZPit7W1SU8/HX8BbWiQ5s+Pz2piMen0aam5OSHDTqgL//+zLCklRfrxj+PXbVuqrj73PO3ZI61ZE/9ae7tUWtqnw8VF7ClaqaysXCW70pSePPqq/7G3LEseV0C35f5UpRVf6ETJ5z2Kw9UiLhhQOv7OvF5pzJhzt8+bF1+HQtLf/ia9+iqbgjp0PGeWJWVnxxdJuukmaebMeHRqa6U5c6SNGxM3TsQ1t1Zpz/7VCviz5XXFN4+Z2HKUnjJKd932jKpritXQcNrASC+NuGDA6vjnKxyWjhyJh6WtTdqwIb4JCN/U9R/W6ur4bMW245vNvkzM2yHQjdIzu1R0bLO8EzI02H+bnNbVbx5zWC7lZt2nyRMe19ad8xUOtxoY6cURF/RrXV8MYzEpGo3fVlEh/fWv5/YrbN4stfbu38qA0vV5i0Ti18Ph+PO0d2/8uSwtJSj9lW1Htf/oexocvFkp7gylJecamb24nT6NueF+HSv9u8rK9lz9QC+BuKBfsu14SI4di++8j0bjO6M3bIh/PRqNHz2GczqC0toq7d4df46am+MRLi8/97VQKHFjRM+1ttfqi33LlRrIUZIrIJ87+6ruz7ZjOtt4SNv2vqWKiv2GRnlxxAX90G/08stzFYlIBw/Gd9Dj0jyeXG3derM++OBDtbRIn30WjwsGtorKPSo6tkm+iVlKdqbJ6XBf0f3YdkyVTYe1addsHf3q731yvjHign5oqjZsmJvoQQwoTmeGSkpu1oYNHyZ6KDDswLH/U1bmaLmH+5Ttu/myf962Y6qo261Nu/5TX53Y0idHikm8Qx8A+rXW9lrt3r9K9c1laglXXVYcYnZUpxsKtXHHf6i4ZFufhUUiLgDQ752pOaT9h9erqqlItnq2vdO2YzpZs10fbfudiks+N/LmyctBXADAALfDoRt8Pjl65Wwmtg5/tUHlpwtV21os27706WdidlRltTv18Wf/rtKyLyT13YylA3EBAAN+deut2vDDH2pm13f3GtQWqtcX+/6q6objaglXX/T7bNvWiapP9eHWV1RWsfdbQ9RbiAsAGDApM1OOr9e9paruuAoPvaealuOKxs4/pty2bbWG63T49Hpt3P6aTp3er0TMWDpwtBgAGPDC55/r0dxc5R892ouPYutoySYNDU6Qxx1Qd
sp4WZalmB1VbWuxjpVv0vbdi1RTl/gTxREXADCgqq1NCw4d6vXHaQ81avvexcrOulE+d7acDo/ONOxTYdH/6vDxjZJly7IcCdsc1oG4AMAA09B8WoUH18qeEFOovVl7DqzWsZK/KzV1sL4z8REdL96miooDYrMYAKDHbDumA8fflyfZrxNl21RdW6y0wDA9eM+Lyg3ep1uGPaKdhxdr34H3FYm0JWSMxAUABqBYLKxd+5ZKkvwpQT1470saGZwqrztTreE6tbc3KxptT9j4OFoMAAa4ITnjFUy/RT53UG2ROm07MFeHj2zs03fkX6jHcXnnnXc0bdo0+Xw+PvIYAPqJ0cPv1T1Tfq1072hFYyEdO71Rh4983Ccnp7yUHsflmWee0ccff6wVK1boqaeeUjAY7M1xAQC+hcvl0YjheRqcepuclltltTv1j11/VHNLjdHHCbhcujMrS67LmFj0eJ+LZVnyer16+OGH9d3vfldFRUVasWKFVq1apYqKCrW3J27bHgBcfyzdlfevyhv7hJKcfjW0l2nXoSWqrDpu/JEeHzVK41JTFUxO1vqysh79zBXtc/H5fMrLy9Ps2bO1ZcsWvf7667r//vvldDqv5O4AAJcpI2O4Rg+7V6me4QpFW7Tnq+UqOvZJr7y/paq9XfbX65664h36lmXJsiwNHTpUs2bN0sqVK7V27Vo98sgjCgaDcjg4VgAAeoPb7dV9d85SzqAJitkRHT+zUV/szVc43Duf9f1+ebn+69AhHWxr05tvvtmjnzFyKLLT6VQwGNTDDz+shx56SFu3btXatWu1cuVKlXd8vioAwBBLDTWVqvIXybIsfXlwuRoaz/Tao4VjMVWGQnpn7lz94he/6NHPGH2fi2VZcjqduu+++3TnnXdq1qxZWrNmjZYtW6ajR4+qiQ89B4CrFg636NNd/61Rp6YqNT1HJWU7evXxAoGA3njjDf385z/v+dHCdh84e/asvWTJEnvGjBl2cnKyrfg5CVhYul2WLl2a8DEMtGXKlCn2888/n/BxsFx7yw9+8AN71apVl/263yfv0M/OztYTTzyhRx99VIWFhVq0aJE2bdqk0tJSRaN9++loAICeefzxx7VgwQL5fL7L/tk+O/2Lw+GQ3+/X3Xffrbvuukt79uzRRx99pD//+c8qLi5WOBzuq6EAAC7Bsiw9+uijmjdvnvx+/xXdR0IO6XI4HJo8ebJeeOEFbdmyRQsXLtSDDz6oQYMGJWI4AICvDR8+XL/85S+Vn5+v9PT0K76fhJ640ul0Kjs7W08++aRmzJihHTt2aMmSJSooKFBtbS2bzACgDw0dOlQrV67U7bffLpfr6vLQL86KbFmW0tPTNX36dH3ve99TUVGR8vPztX79eh04cECxWGI/9AYArnVjxozRypUrlZeXZ+T8kZZtJ/C0mZcQjUZ14sQJbd68We+8844OHjyoxsbGRA8LV8nhcMjpdF50cblcGj9+vIqKihI91AHF7/fL4/Fo7969/DOGy5KUlKRhw4Zp9erVysvLM3a//TYuHWzbVnNzszZv3qxly5Z1bjJD/+JyuRQIBOTz+RQIBOT3+7tdAoGAUlNTNWjQIAUCgW7Xfr+fM29fgaNHj2r+/PlavXq1Tp48mejhYABwOBx65ZVX9Nxzzyk1NdXo312/j0sH27bV1NSk48eP609/+pPef/99lZaWKhQKJXpoA1bHTMHtdp+37nrZ4/EoLS1N6enpl1x8Pp+SkpLkcrk61263+xsLpwXqXaFQSMXFxVq0aJHeffddlZSUKBJJ7KnX0T95vV69+uqrevbZZ5WcnGz8/gdMXDp0DLe4uFhr167VihUrtGNH7747daDyer3KyMjoNg5paWlKTU1VIBDonGl0XO56vafHtzPT6F9s29bx48e1ZMkS5efn69ixY4keEvoRy7L0+9//Xs8//3yv/e0OuLh0FY1GVVNTo8rKStXW1qqurk41NTXnrWtrazuXlpYWhcNhhUKhz
qW76/2FZVlKSkqSx+ORx+NRcnJy52WPxyO/36/s7OzzlmAwqKysLAWDQfl8vs4ZQ3czFGYR175wOKzy8nLl5+dr+fLlKioqYrZ/nZs6dapmz56tqVOnyu1299rjDOi4XA7bttXe3q7m5ubOpampqdvrDQ0Nqq+vP29dV1d33vXm5ubOWdSF6wsvS+f+s+84m3TH5UGDBikjI0OZmZnKyMg473J6evo39k+kpqZ2Ll6vlxkDeqy4uFhr1qzR4sWLdfDgQXb8X4ceeOABvf3228rNze31147rJi49Zdu2YrGYYrGYotHoRdehUKhzZtR1hnThbW63u3NWkZWV1bkOBoPKzMyUx+Pp9ogpp9Mph8NBPGBUNBrV2bNntW7dOi1atEj79u1Ta2vvnKYd/cv3v/99LV++3PiO+4shLsB1yLZtlZeXq6CgQPPmzVNhYWG/2iQMc3w+n5599lk999xzysnJ6bPHJS7AdSwWi6m+vl4FBQV6++23tX37dmYy15ChQ4dqzpw5mjlzZq/uX+kOcQEg27ZVWVmpTz/9VH/4wx/05ZdfEpkBLiMjQ++++67uvffehHwEPXEB0Mm2bYVCIRUUFGjhwoX65JNP1NzcnOhh4TJNmzZNr732mu65556E7bclLgC6VVdXp88++0xz587Vli1b1NDQ8I2jING/uN1uPfbYY5o9e7ZGjRqV0LEQFwAX1XH05MaNG7V06VKtW7dODQ0NiR4WLuLFF1/USy+9pEAgkPAjTYkLgB5pbGzU7t27tWDBAhUUFKiqqor3yvQT6enpmjVrll566SV5vd5ED0cScQFwGTpmMtu2bdOyZcu0atUq1dTUJHpY17Xc3FzNmTNHjz32WMJnK10RFwBXpKWlRYcPH9bChQu1fv16lZeXM5PpY3l5eXrrrbc0efLkq/5wL9OIC4CrEolEVFhYqGXLlik/P19nzpxhx38fmDZtmhYvXqyxY8f2qxlLB+IC4Kp1HMJcVlam+fPn67333tORI0eITC9ISUnRE088odmzZyszMzPRw7ko4gLAqGg0qqKiIq1YsUKLFi1SRUUFm8sMSU5O1ssvv6xZs2YpLS0t0cO5JOICwDjbthWNRlVdXa358+dr586dam5uVktLy3lnIu+4jfh8u+zsbM2ZM0c/+9nP5PF4Ej2cb0VcAPS6UCh0Xly6RqalpUU1NTWqqqpSZWWlqqurVVlZqaqqKlVVVammpkaRSKTzbOVdz1Aei8Wu+U1vlmVp9OjR+uMf/6jp06cPmM9hIi4AEqrr5yF1t0QiEdXW1naGpuu6urpatbW1amxsVFNTkxobGzuXpqYmNTQ0DPizPT/wwAN64403NHHixH654/5iiAuAAS0ajaqtrU2tra1qbW3tvNyxbmhoUGVlZbdLVVWVQqHQNwJ3sct9+X1Op1M/+tGP9Prrr2vs2LG9/jyaRlwAXNO6+4TYrrdFo1FFo1FFIpHOddfLPfna1dzW3dcikYjS0tL029/+Vn6/f0DNWDoQFwCAcQNjzxAAYEAhLgAA44gLAMA44gIAMI64AACMIy4AAOOICwDAOOICADCOuAAAjCMuAADjiAsAwDjiAgAwjrgAAIwjLgAA44gLAMA44gIAMI64AACMIy4AAOOICwDAOOICADCOuAAAjCMuAADjiAsAwDjiAgAwjrgAAIwjLgAA44gLAMA44gIAMI64AACMIy4AAOOICwDAOOICADCOuAAAjCMuAADjiAsAwDjiAgAwjrgAAIwjLgAA44gLAMA44gIAMI64AACMIy4AAOOICwDAOOICADCOuAAAjCMuAADjiAsAwDjiAgAwjrgAAIwjLgAA4/4fpN7J9XcZ78wAAAAASUVORK5CYII=\n" 183 | }, 184 | "metadata": {} 185 | } 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "source": [ 191 | "MemoryBlock = 
namedtuple('MemoryBlock', ('current', 'action', 'reward', 'next'))" 192 | ], 193 | "metadata": { 194 | "id": "CatEHV-C5Ard" 195 | }, 196 | "execution_count": 62, 197 | "outputs": [] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "source": [ 202 | "block = MemoryBlock(1, 1, 2, 1)" 203 | ], 204 | "metadata": { 205 | "id": "awpIYsZg6Phm" 206 | }, 207 | "execution_count": 64, 208 | "outputs": [] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "source": [ 213 | "block.action" 214 | ], 215 | "metadata": { 216 | "colab": { 217 | "base_uri": "https://localhost:8080/" 218 | }, 219 | "id": "j0Xnvz_C6d9B", 220 | "outputId": "f2a5c582-9655-4fbb-95ee-39a561cecc76" 221 | }, 222 | "execution_count": 68, 223 | "outputs": [ 224 | { 225 | "output_type": "execute_result", 226 | "data": { 227 | "text/plain": [ 228 | "1" 229 | ] 230 | }, 231 | "metadata": {}, 232 | "execution_count": 68 233 | } 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "source": [ 239 | "memory = deque(maxlen=10)" 240 | ], 241 | "metadata": { 242 | "id": "LlcxJRes6gH4" 243 | }, 244 | "execution_count": 80, 245 | "outputs": [] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "source": [ 250 | "memory" 251 | ], 252 | "metadata": { 253 | "colab": { 254 | "base_uri": "https://localhost:8080/" 255 | }, 256 | "id": "SaUTuWou6r81", 257 | "outputId": "703aa9af-8471-4647-a8eb-4270d08266ee" 258 | }, 259 | "execution_count": 81, 260 | "outputs": [ 261 | { 262 | "output_type": "execute_result", 263 | "data": { 264 | "text/plain": [ 265 | "deque([])" 266 | ] 267 | }, 268 | "metadata": {}, 269 | "execution_count": 81 270 | } 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "source": [ 276 | "memory.append(100)" 277 | ], 278 | "metadata": { 279 | "id": "fSD6S3pK61Dh" 280 | }, 281 | "execution_count": 82, 282 | "outputs": [] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "source": [ 287 | "memory" 288 | ], 289 | "metadata": { 290 | "colab": { 291 | "base_uri": "https://localhost:8080/" 292 | }, 293 | 
"id": "AX6uGERC62dL", 294 | "outputId": "f26a33f1-4838-45a3-a1b7-06cf638f1ab2" 295 | }, 296 | "execution_count": 83, 297 | "outputs": [ 298 | { 299 | "output_type": "execute_result", 300 | "data": { 301 | "text/plain": [ 302 | "deque([100])" 303 | ] 304 | }, 305 | "metadata": {}, 306 | "execution_count": 83 307 | } 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "source": [ 313 | "for i in range(10):\n", 314 | " memory.append(i)\n", 315 | " print(memory)" 316 | ], 317 | "metadata": { 318 | "colab": { 319 | "base_uri": "https://localhost:8080/" 320 | }, 321 | "id": "W87WmNCE65fe", 322 | "outputId": "7f177f5b-a448-4fb4-b04d-aae531d306b2" 323 | }, 324 | "execution_count": 84, 325 | "outputs": [ 326 | { 327 | "output_type": "stream", 328 | "name": "stdout", 329 | "text": [ 330 | "deque([100, 0], maxlen=10)\n", 331 | "deque([100, 0, 1], maxlen=10)\n", 332 | "deque([100, 0, 1, 2], maxlen=10)\n", 333 | "deque([100, 0, 1, 2, 3], maxlen=10)\n", 334 | "deque([100, 0, 1, 2, 3, 4], maxlen=10)\n", 335 | "deque([100, 0, 1, 2, 3, 4, 5], maxlen=10)\n", 336 | "deque([100, 0, 1, 2, 3, 4, 5, 6], maxlen=10)\n", 337 | "deque([100, 0, 1, 2, 3, 4, 5, 6, 7], maxlen=10)\n", 338 | "deque([100, 0, 1, 2, 3, 4, 5, 6, 7, 8], maxlen=10)\n", 339 | "deque([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], maxlen=10)\n" 340 | ] 341 | } 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "source": [ 347 | "env = gym.make(\"LunarLander-v2\", render_mode=\"rgb_array\")\n", 348 | "observation, info = env.reset()\n", 349 | "\n", 350 | "memory = deque(maxlen=150)\n", 351 | "\n", 352 | "for _ in range(100):\n", 353 | " action = env.action_space.sample() # agent policy that uses the observation and info\n", 354 | "\n", 355 | " current = observation.copy()\n", 356 | " observation, reward, terminated, truncated, info = env.step(action)\n", 357 | " block = MemoryBlock(current, action, reward, observation)\n", 358 | " memory.append(block)\n", 359 | "\n", 360 | " if terminated or truncated:\n", 361 | " 
observation, info = env.reset()\n", 362 | "\n", 363 | "\n", 364 | "env.close()" 365 | ], 366 | "metadata": { 367 | "id": "V0EB3J8F6-GP" 368 | }, 369 | "execution_count": 87, 370 | "outputs": [] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "source": [ 375 | "b1 = memory[0]\n", 376 | "b2 = memory[1]" 377 | ], 378 | "metadata": { 379 | "id": "Hd4Jcpkn7gS1" 380 | }, 381 | "execution_count": 96, 382 | "outputs": [] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "source": [ 387 | "b1" 388 | ], 389 | "metadata": { 390 | "colab": { 391 | "base_uri": "https://localhost:8080/" 392 | }, 393 | "id": "MKW7UpO87hYY", 394 | "outputId": "5338c33f-7b14-4bf1-94f7-3e09ae6dac7f" 395 | }, 396 | "execution_count": 97, 397 | "outputs": [ 398 | { 399 | "output_type": "execute_result", 400 | "data": { 401 | "text/plain": [ 402 | "MemoryBlock(current=array([ 0.0032485 , 1.4054555 , 0.32901463, -0.24287517, -0.00375733,\n", 403 | " -0.07452664, 0. , 0. ], dtype=float32), action=0, reward=-1.3044605623392442, next=array([ 0.006497 , 1.3994141 , 0.32857022, -0.26852667, -0.00744009,\n", 404 | " -0.07366162, 0. , 0. ], dtype=float32))" 405 | ] 406 | }, 407 | "metadata": {}, 408 | "execution_count": 97 409 | } 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "source": [ 415 | "b2" 416 | ], 417 | "metadata": { 418 | "colab": { 419 | "base_uri": "https://localhost:8080/" 420 | }, 421 | "id": "FuteA5op788U", 422 | "outputId": "fdca541a-a4a7-47dd-e211-93400c3ce372" 423 | }, 424 | "execution_count": 98, 425 | "outputs": [ 426 | { 427 | "output_type": "execute_result", 428 | "data": { 429 | "text/plain": [ 430 | "MemoryBlock(current=array([ 0.006497 , 1.3994141 , 0.32857022, -0.26852667, -0.00744009,\n", 431 | " -0.07366162, 0. , 0. ], dtype=float32), action=2, reward=-1.3291067603536988, next=array([ 0.00991163, 1.3933325 , 0.34438816, -0.2703012 , -0.01033648,\n", 432 | " -0.05793293, 0. , 0. 
], dtype=float32))" 433 | ] 434 | }, 435 | "metadata": {}, 436 | "execution_count": 98 437 | } 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "source": [ 443 | "batch = MemoryBlock(*zip(*memory))" 444 | ], 445 | "metadata": { 446 | "id": "b0Z0mrcu79K8" 447 | }, 448 | "execution_count": 115, 449 | "outputs": [] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "source": [ 454 | "batch.current" 455 | ], 456 | "metadata": { 457 | "colab": { 458 | "base_uri": "https://localhost:8080/" 459 | }, 460 | "id": "t4YaNqJK8C4S", 461 | "outputId": "706880f1-a7d6-453e-f28e-3aa8917b9f99" 462 | }, 463 | "execution_count": 116, 464 | "outputs": [ 465 | { 466 | "output_type": "execute_result", 467 | "data": { 468 | "text/plain": [ 469 | "(array([ 0.0032485 , 1.4054555 , 0.32901463, -0.24287517, -0.00375733,\n", 470 | " -0.07452664, 0. , 0. ], dtype=float32),\n", 471 | " array([ 0.006497 , 1.3994141 , 0.32857022, -0.26852667, -0.00744009,\n", 472 | " -0.07366162, 0. , 0. ], dtype=float32),\n", 473 | " array([ 0.00991163, 1.3933325 , 0.34438816, -0.2703012 , -0.01033648,\n", 474 | " -0.05793293, 0. , 0. ], dtype=float32),\n", 475 | " array([ 0.01323576, 1.3866549 , 0.33303937, -0.29678914, -0.01095505,\n", 476 | " -0.0123724 , 0. , 0. ], dtype=float32),\n", 477 | " array([ 0.01642313, 1.3809104 , 0.32008106, -0.25532517, -0.01228218,\n", 478 | " -0.02654489, 0. , 0. ], dtype=float32),\n", 479 | " array([ 0.0196105 , 1.3745655 , 0.32008517, -0.28200608, -0.01360798,\n", 480 | " -0.02651851, 0. , 0. ], dtype=float32),\n", 481 | " array([ 0.02286119, 1.3676276 , 0.32800844, -0.30837443, -0.01652008,\n", 482 | " -0.05824757, 0. , 0. ], dtype=float32),\n", 483 | " array([ 0.0260045 , 1.3614005 , 0.31786153, -0.27680793, -0.0200112 ,\n", 484 | " -0.06982894, 0. , 0. ], dtype=float32),\n", 485 | " array([ 0.02913609, 1.3556799 , 0.31680828, -0.25429958, -0.02362819,\n", 486 | " -0.07234631, 0. , 0. 
], dtype=float32),\n", 487 | " array([ 0.03228216, 1.350129 , 0.3182506 , -0.24676459, -0.02723356,\n", 488 | " -0.07211356, 0. , 0. ], dtype=float32),\n", 489 | " array([ 0.03551693, 1.3439851 , 0.32936293, -0.2731835 , -0.03306038,\n", 490 | " -0.11654727, 0. , 0. ], dtype=float32),\n", 491 | " array([ 0.03875179, 1.3372415 , 0.32938066, -0.29985175, -0.03888537,\n", 492 | " -0.11651033, 0. , 0. ], dtype=float32),\n", 493 | " array([ 0.04192553, 1.3298929 , 0.3217017 , -0.32672888, -0.04317166,\n", 494 | " -0.08573384, 0. , 0. ], dtype=float32),\n", 495 | " array([ 0.04518175, 1.3219506 , 0.33203125, -0.353187 , -0.04952039,\n", 496 | " -0.12698598, 0. , 0. ], dtype=float32),\n", 497 | " array([ 0.04841232, 1.314907 , 0.32982442, -0.31328657, -0.05622081,\n", 498 | " -0.13402088, 0. , 0. ], dtype=float32),\n", 499 | " array([ 0.05172186, 1.3072615 , 0.33972174, -0.34014696, -0.06490061,\n", 500 | " -0.17361203, 0. , 0. ], dtype=float32),\n", 501 | " array([ 0.05503168, 1.2990172 , 0.33974722, -0.36681646, -0.07357845,\n", 502 | " -0.17357238, 0. , 0. ], dtype=float32),\n", 503 | " array([ 0.05834246, 1.2912889 , 0.3400916 , -0.34394732, -0.08250807,\n", 504 | " -0.17860876, 0. , 0. ], dtype=float32),\n", 505 | " array([ 0.06172209, 1.2829638 , 0.34867343, -0.37062198, -0.09314454,\n", 506 | " -0.21274868, 0. , 0. ], dtype=float32),\n", 507 | " array([ 0.06502628, 1.2740395 , 0.33922154, -0.39720652, -0.10188421,\n", 508 | " -0.17480874, 0. , 0. ], dtype=float32),\n", 509 | " array([ 0.06852102, 1.2659017 , 0.35781997, -0.36226854, -0.11017855,\n", 510 | " -0.1659019 , 0. , 0. ], dtype=float32),\n", 511 | " array([ 0.07213764, 1.2581213 , 0.3697724 , -0.34640688, -0.11823454,\n", 512 | " -0.16113424, 0. , 0. ], dtype=float32),\n", 513 | " array([ 0.07599831, 1.250621 , 0.39334607, -0.33393574, -0.1254857 ,\n", 514 | " -0.14503631, 0. , 0. 
], dtype=float32),\n", 515 | " array([ 0.07978897, 1.2425411 , 0.3845337 , -0.35957322, -0.13093206,\n", 516 | " -0.10893674, 0. , 0. ], dtype=float32),\n", 517 | " array([ 0.08357992, 1.2338617 , 0.38454834, -0.38624236, -0.136378 ,\n", 518 | " -0.10892855, 0. , 0. ], dtype=float32),\n", 519 | " array([ 0.08736897, 1.2245612 , 0.3843667 , -0.41385975, -0.14182337,\n", 520 | " -0.10890688, 0. , 0. ], dtype=float32),\n", 521 | " array([ 0.09117527, 1.2154578 , 0.3863613 , -0.4051519 , -0.14753653,\n", 522 | " -0.11426322, 0. , 0. ], dtype=float32),\n", 523 | " array([ 0.09506798, 1.2057257 , 0.39723632, -0.43334863, -0.15548898,\n", 524 | " -0.15904924, 0. , 0. ], dtype=float32),\n", 525 | " array([ 0.09896078, 1.1953943 , 0.39723513, -0.4600205 , -0.16344142,\n", 526 | " -0.15904859, 0. , 0. ], dtype=float32),\n", 527 | " array([ 0.10305023, 1.1857693 , 0.41663918, -0.42865118, -0.1711515 ,\n", 528 | " -0.15420182, 0. , 0. ], dtype=float32),\n", 529 | " array([ 0.10713968, 1.1755449 , 0.41663796, -0.45532274, -0.17886156,\n", 530 | " -0.15420106, 0. , 0. ], dtype=float32),\n", 531 | " array([ 0.11116524, 1.1647334 , 0.40861574, -0.4812524 , -0.18494107,\n", 532 | " -0.12159048, 0. , 0. ], dtype=float32),\n", 533 | " array([ 0.1151638 , 1.1540335 , 0.40642828, -0.47639313, -0.19154665,\n", 534 | " -0.13211167, 0. , 0. ], dtype=float32),\n", 535 | " array([ 0.11916237, 1.1427339 , 0.4064273 , -0.50306344, -0.19815221,\n", 536 | " -0.13211125, 0. , 0. ], dtype=float32),\n", 537 | " array([ 0.12323084, 1.1308116 , 0.41521496, -0.53102475, -0.20657948,\n", 538 | " -0.16854541, 0. , 0. ], dtype=float32),\n", 539 | " array([ 0.12759057, 1.1193523 , 0.44367886, -0.5104016 , -0.21436031,\n", 540 | " -0.15561649, 0. , 0. ], dtype=float32),\n", 541 | " array([ 0.1321722 , 1.1080383 , 0.46535143, -0.50390726, -0.22162652,\n", 542 | " -0.1453242 , 0. , 0. 
], dtype=float32),\n", 543 | " array([ 0.13710518, 1.0973196 , 0.49971023, -0.4773649 , -0.2281169 ,\n", 544 | " -0.1298075 , 0. , 0. ], dtype=float32),\n", 545 | " array([ 0.14234333, 1.086952 , 0.52950966, -0.46167386, -0.23389536,\n", 546 | " -0.11556929, 0. , 0. ], dtype=float32),\n", 547 | " array([ 0.14751807, 1.0759965 , 0.52158654, -0.4875675 , -0.23806275,\n", 548 | " -0.08334794, 0. , 0. ], dtype=float32),\n", 549 | " array([ 0.15269288, 1.0644413 , 0.52158606, -0.51423556, -0.24223015,\n", 550 | " -0.08334783, 0. , 0. ], dtype=float32),\n", 551 | " array([ 0.15786782, 1.0522863 , 0.5215855 , -0.5409037 , -0.24639754,\n", 552 | " -0.08334772, 0. , 0. ], dtype=float32),\n", 553 | " array([ 0.16310024, 1.0408137 , 0.52803177, -0.5107082 , -0.25127062,\n", 554 | " -0.09746158, 0. , 0. ], dtype=float32),\n", 555 | " array([ 0.16864958, 1.029617 , 0.5589772 , -0.49832523, -0.25539538,\n", 556 | " -0.08249549, 0. , 0. ], dtype=float32),\n", 557 | " array([ 0.17437668, 1.019299 , 0.57709444, -0.45934814, -0.259876 ,\n", 558 | " -0.08961239, 0. , 0. ], dtype=float32),\n", 559 | " array([ 0.18010378, 1.0083812 , 0.5770937 , -0.48601642, -0.2643566 ,\n", 560 | " -0.08961224, 0. , 0. ], dtype=float32),\n", 561 | " array([ 0.1859231 , 0.99682903, 0.588689 , -0.5146646 , -0.2712742 ,\n", 562 | " -0.1383522 , 0. , 0. ], dtype=float32),\n", 563 | " array([ 0.19174251, 0.98467755, 0.5886875 , -0.5413351 , -0.2781918 ,\n", 564 | " -0.13835177, 0. , 0. ], dtype=float32),\n", 565 | " array([ 0.19747925, 0.9719698 , 0.5782013 , -0.5656596 , -0.2828468 ,\n", 566 | " -0.09309985, 0. , 0. ], dtype=float32),\n", 567 | " array([ 0.20321599, 0.9586623 , 0.5782005 , -0.592328 , -0.28750178,\n", 568 | " -0.09309975, 0. , 0. ], dtype=float32),\n", 569 | " array([ 0.20902376, 0.9447341 , 0.5870816 , -0.62028587, -0.29400042,\n", 570 | " -0.12997194, 0. , 0. ], dtype=float32),\n", 571 | " array([ 0.21475688, 0.9302393 , 0.57767045, -0.6450995 , -0.2984922 ,\n", 572 | " -0.08983554, 0. 
, 0. ], dtype=float32),\n", 573 | " array([ 0.22040614, 0.9151806 , 0.56708866, -0.6697237 , -0.3007316 ,\n", 574 | " -0.04478817, 0. , 0. ], dtype=float32),\n", 575 | " array([ 0.22598 , 0.89956385, 0.5575346 , -0.69411033, -0.30088818,\n", 576 | " -0.00313152, 0. , 0. ], dtype=float32),\n", 577 | " array([ 0.23146506, 0.8833899 , 0.54632086, -0.7183911 , -0.2986327 ,\n", 578 | " 0.0451091 , 0. , 0. ], dtype=float32),\n", 579 | " array([ 0.23701553, 0.866591 , 0.5545404 , -0.7465191 , -0.2981128 ,\n", 580 | " 0.01039831, 0. , 0. ], dtype=float32),\n", 581 | " array([ 0.24256602, 0.84919196, 0.5545404 , -0.77318585, -0.29759288,\n", 582 | " 0.01039837, 0. , 0. ], dtype=float32),\n", 583 | " array([ 0.24834327, 0.8317434 , 0.5767339 , -0.77528656, -0.29656732,\n", 584 | " 0.02051135, 0. , 0. ], dtype=float32),\n", 585 | " array([ 0.25418806, 0.8136553 , 0.5852955 , -0.80408293, -0.29741687,\n", 586 | " -0.01699063, 0. , 0. ], dtype=float32),\n", 587 | " array([ 0.26003274, 0.7949673 , 0.5852954 , -0.83074963, -0.2982664 ,\n", 588 | " -0.01699061, 0. , 0. ], dtype=float32),\n", 589 | " array([ 0.26631337, 0.77691966, 0.62809825, -0.80212414, -0.2983126 ,\n", 590 | " -0.00092363, 0. , 0. ], dtype=float32),\n", 591 | " array([ 0.2725939 , 0.7582721 , 0.62809837, -0.82879084, -0.2983588 ,\n", 592 | " -0.00092367, 0. , 0. ], dtype=float32),\n", 593 | " array([ 2.7907389e-01, 7.3992133e-01, 6.4797944e-01, -8.1558305e-01,\n", 594 | " -2.9833561e-01, 4.6390909e-04, 0.0000000e+00, 0.0000000e+00],\n", 595 | " dtype=float32),\n", 596 | " array([ 0.28546923, 0.7210037 , 0.63733035, -0.8403326 , -0.29606122,\n", 597 | " 0.04548761, 0. , 0. ], dtype=float32),\n", 598 | " array([ 0.29180604, 0.70152026, 0.6298959 , -0.8651683 , -0.29216075,\n", 599 | " 0.07800949, 0. , 0. ], dtype=float32),\n", 600 | " array([ 0.29806557, 0.6814782 , 0.6201067 , -0.8896006 , -0.2861396 ,\n", 601 | " 0.12042297, 0. , 0. 
], dtype=float32),\n", 602 | " array([ 0.304393 , 0.66080976, 0.62864447, -0.917799 , -0.28192195,\n", 603 | " 0.08435254, 0. , 0. ], dtype=float32),\n", 604 | " array([ 0.31065854, 0.63955736, 0.62088794, -0.9434684 , -0.27610767,\n", 605 | " 0.11628503, 0. , 0. ], dtype=float32),\n", 606 | " array([ 0.31683215, 0.6177396 , 0.60932267, -0.96817976, -0.26786372,\n", 607 | " 0.16487893, 0. , 0. ], dtype=float32),\n", 608 | " array([ 0.32300606, 0.5953227 , 0.60932034, -0.9948518 , -0.2596198 ,\n", 609 | " 0.16487816, 0. , 0. ], dtype=float32),\n", 610 | " array([ 0.32926998, 0.5733222 , 0.61870646, -0.9764615 , -0.25177723,\n", 611 | " 0.1568514 , 0. , 0. ], dtype=float32),\n", 612 | " array([ 0.335534 , 0.5507224 , 0.6187045 , -1.0031332 , -0.24393468,\n", 613 | " 0.15685079, 0. , 0. ], dtype=float32),\n", 614 | " array([ 0.34193307, 0.5280501 , 0.63197047, -1.0063579 , -0.23587935,\n", 615 | " 0.16110703, 0. , 0. ], dtype=float32),\n", 616 | " array([ 0.3483321 , 0.5047788 , 0.6319685 , -1.0330298 , -0.22782403,\n", 617 | " 0.16110618, 0. , 0. ], dtype=float32),\n", 618 | " array([ 0.35482207, 0.4808718 , 0.6434083 , -1.061682 , -0.22217 ,\n", 619 | " 0.11308068, 0. , 0. ], dtype=float32),\n", 620 | " array([ 0.36123332, 0.45639145, 0.63350135, -1.0868869 , -0.21445887,\n", 621 | " 0.1542222 , 0. , 0. ], dtype=float32),\n", 622 | " array([ 0.36755657, 0.4313467 , 0.6223889 , -1.1116874 , -0.2044231 ,\n", 623 | " 0.20071515, 0. , 0. ], dtype=float32),\n", 624 | " array([ 0.3738801 , 0.4057033 , 0.62238616, -1.1383624 , -0.1943874 ,\n", 625 | " 0.20071383, 0. , 0. ], dtype=float32),\n", 626 | " array([ 0.3802759 , 0.37943146, 0.63148636, -1.1665958 , -0.18625432,\n", 627 | " 0.16266184, 0. , 0. ], dtype=float32),\n", 628 | " array([ 0.3866108 , 0.35258853, 0.6237731 , -1.1918262 , -0.17650253,\n", 629 | " 0.19503552, 0. , 0. ], dtype=float32),\n", 630 | " array([ 0.3928669 , 0.32515943, 0.6138825 , -1.2177227 , -0.1647498 ,\n", 631 | " 0.23505464, 0. , 0. 
], dtype=float32),\n", 632 | " array([ 0.39912328, 0.29713205, 0.6138795 , -1.2444009 , -0.15299718,\n", 633 | " 0.23505235, 0. , 0. ], dtype=float32),\n", 634 | " array([ 0.40555716, 0.26928914, 0.63117206, -1.23625 , -0.14080116,\n", 635 | " 0.24392083, 0. , 0. ], dtype=float32),\n", 636 | " array([ 0.41192016, 0.24086352, 0.62222606, -1.262094 , -0.12678766,\n", 637 | " 0.28026995, 0. , 0. ], dtype=float32),\n", 638 | " array([ 0.4183545 , 0.2118345 , 0.63115126, -1.2891777 , -0.11456715,\n", 639 | " 0.24441049, 0. , 0. ], dtype=float32),\n", 640 | " array([ 0.42478913, 0.18220748, 0.63114893, -1.3158567 , -0.10234676,\n", 641 | " 0.24440798, 0. , 0. ], dtype=float32),\n", 642 | " array([ 0.43113318, 0.15199342, 0.6197608 , -1.3419081 , -0.08783399,\n", 643 | " 0.29025573, 0. , 0. ], dtype=float32),\n", 644 | " array([ 0.43754464, 0.12116834, 0.6282196 , -1.369294 , -0.07503272,\n", 645 | " 0.25602564, 0. , 0. ], dtype=float32),\n", 646 | " array([ 0.44387072, 0.08976682, 0.6174868 , -1.3949317 , -0.06005894,\n", 647 | " 0.2994754 , 0. , 0. ], dtype=float32),\n", 648 | " array([ 0.45019692, 0.05776836, 0.6174848 , -1.4216172 , -0.04508539,\n", 649 | " 0.29947075, 0. , 0. ], dtype=float32),\n", 650 | " array([ 0.45643815, 0.0251823 , 0.60678774, -1.4478359 , -0.02796434,\n", 651 | " 0.34242067, 0. , 0. ], dtype=float32),\n", 652 | " array([ 0.46254796, -0.00653446, 0.5943854 , -1.4094068 , -0.01157715,\n", 653 | " 0.3277441 , 0. , 0. ], dtype=float32),\n", 654 | " array([ 0.46856374, -0.03884045, 0.5825814 , -1.4357822 , 0.00717474,\n", 655 | " 0.37503785, 1. , 0. ], dtype=float32),\n", 656 | " array([ 0.47394055, -0.07029414, 0.5007308 , -1.398636 , 0.03890656,\n", 657 | " 0.62863874, 1. , 0. 
], dtype=float32),\n", 658 | " array([ 6.3772203e-04, 1.4035777e+00, 6.4582005e-02, -3.2633755e-01,\n", 659 | " -7.3219155e-04, -1.4628743e-02, 0.0000000e+00, 0.0000000e+00],\n", 660 | " dtype=float32),\n", 661 | " array([ 1.2557984e-03, 1.3968295e+00, 6.2634937e-02, -2.9991859e-01,\n", 662 | " -1.5520940e-03, -1.6398780e-02, 0.0000000e+00, 0.0000000e+00],\n", 663 | " dtype=float32),\n", 664 | " array([ 2.0483017e-03, 1.3907650e+00, 7.9237178e-02, -2.6953679e-01,\n", 665 | " -1.5398865e-03, 2.4391804e-04, 0.0000000e+00, 0.0000000e+00],\n", 666 | " dtype=float32),\n", 667 | " array([ 2.8409003e-03, 1.3841002e+00, 7.9236843e-02, -2.9621407e-01,\n", 668 | " -1.5283929e-03, 2.3040108e-04, 0.0000000e+00, 0.0000000e+00],\n", 669 | " dtype=float32),\n", 670 | " array([ 0.00369711, 1.3768382 , 0.08724354, -0.3227533 , -0.00312197,\n", 671 | " -0.03187469, 0. , 0. ], dtype=float32),\n", 672 | " array([ 0.00447617, 1.3689772 , 0.0775428 , -0.3493837 , -0.00276848,\n", 673 | " 0.00707043, 0. , 0. ], dtype=float32))" 674 | ] 675 | }, 676 | "metadata": {}, 677 | "execution_count": 116 678 | } 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "source": [ 684 | "s = random.sample(memory, 10)\n", 685 | "batch = MemoryBlock(*zip(*s))" 686 | ], 687 | "metadata": { 688 | "id": "dkcpe-Qw9E2X" 689 | }, 690 | "execution_count": 119, 691 | "outputs": [] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "source": [ 696 | "batch.reward" 697 | ], 698 | "metadata": { 699 | "colab": { 700 | "base_uri": "https://localhost:8080/" 701 | }, 702 | "id": "Z2dSnIvO9P62", 703 | "outputId": "93da11b9-a8a7-4090-8862-dbf46a493ace" 704 | }, 705 | "execution_count": 120, 706 | "outputs": [ 707 | { 708 | "output_type": "execute_result", 709 | "data": { 710 | "text/plain": [ 711 | "(0.7480714012465353,\n", 712 | " -0.6803533446248753,\n", 713 | " 3.16454145519175,\n", 714 | " -0.4614322955537659,\n", 715 | " -1.2174461058552595,\n", 716 | " 0.24074994760946994,\n", 717 | " -0.9185988699621135,\n", 
718 | " -0.9376932553099369,\n", 719 | " -0.575740884578579,\n", 720 | " -0.37997100287804986)" 721 | ] 722 | }, 723 | "metadata": {}, 724 | "execution_count": 120 725 | } 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "source": [], 731 | "metadata": { 732 | "id": "Ej2l8P2R9U3E" 733 | }, 734 | "execution_count": null, 735 | "outputs": [] 736 | } 737 | ] 738 | } -------------------------------------------------------------------------------- /25_Double_Deep_Q_Learning_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4", 8 | "authorship_tag": "ABX9TyO5oExkA5B4ZMegA5eJNX6v", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "source": [ 34 | "Main source: https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html" 35 | ], 36 | "metadata": { 37 | "id": "e2c9wneh2jNf" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "kMrwaCQ804lM", 48 | "outputId": "51b22a64-3639-4925-b883-7affe72d375f" 49 | }, 50 | "outputs": [ 51 | { 52 | "output_type": "stream", 53 | "name": "stdout", 54 | "text": [ 55 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m953.9/953.9 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 56 | "\u001b[?25h" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "!pip install -q gymnasium[classic_control]" 62 | ] 63 | }, 64 | { 65 | 
"cell_type": "code", 66 | "source": [ 67 | "import math\n", 68 | "import random\n", 69 | "from collections import namedtuple, deque\n", 70 | "from itertools import count\n", 71 | "\n", 72 | "import gymnasium as gym\n", 73 | "\n", 74 | "import torch\n", 75 | "import torch.nn as nn\n", 76 | "import torch.nn.functional as F\n", 77 | "import torch.optim as optim\n", 78 | "\n", 79 | "from IPython import display\n", 80 | "\n", 81 | "import numpy as np\n", 82 | "import matplotlib.pyplot as plt" 83 | ], 84 | "metadata": { 85 | "id": "S-VcJSlJ2hM3" 86 | }, 87 | "execution_count": 2, 88 | "outputs": [] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "source": [ 93 | "env = gym.make('CartPole-v1')" 94 | ], 95 | "metadata": { 96 | "id": "d4JAWIO23I6V" 97 | }, 98 | "execution_count": 3, 99 | "outputs": [] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "source": [ 104 | "BATCH_SIZE = 128\n", 105 | "NUM_EPISODES = 600\n", 106 | "\n", 107 | "GAMMA = .99 # Discount factor\n", 108 | "\n", 109 | "# epsilon-greedy parameters:\n", 110 | "EPS_START = .9\n", 111 | "EPS_END = .05\n", 112 | "EPS_DECAY = 1000\n", 113 | "\n", 114 | "TAU = 5e-3\n", 115 | "\n", 116 | "LR = 1e-4\n", 117 | "CLIP_VALUE = 100\n", 118 | "\n", 119 | "MEMORY_SIZE = 10000\n", 120 | "\n", 121 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" 122 | ], 123 | "metadata": { 124 | "id": "-tf9lVbc3CsC" 125 | }, 126 | "execution_count": 4, 127 | "outputs": [] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "source": [ 132 | "state, info = env.reset()\n", 133 | "N_OBSERVATIONS = len(state)\n", 134 | "N_ACTIONS = env.action_space.n" 135 | ], 136 | "metadata": { 137 | "id": "7LuXu5sU5qGW" 138 | }, 139 | "execution_count": 5, 140 | "outputs": [] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "source": [ 145 | "plt.ion()" 146 | ], 147 | "metadata": { 148 | "colab": { 149 | "base_uri": "https://localhost:8080/" 150 | }, 151 | "id": "9Py_Q-6j3M9n", 152 | "outputId": "f3bfeeed-f12b-44b3-fa5b-f7dc89c5af2c" 153 | }, 154 
| "execution_count": 6, 155 | "outputs": [ 156 | { 157 | "output_type": "execute_result", 158 | "data": { 159 | "text/plain": [ 160 | "" 161 | ] 162 | }, 163 | "metadata": {}, 164 | "execution_count": 6 165 | } 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "source": [ 171 | "MemoryBlock = namedtuple('MemoryBlock', ('state', 'action', 'reward', 'next_state'))\n", 172 | "\n", 173 | "class ReplayMemory:\n", 174 | " def __init__(self, capacity):\n", 175 | " self.memory = deque(maxlen=capacity)\n", 176 | "\n", 177 | " def __len__(self):\n", 178 | " return len(self.memory)\n", 179 | "\n", 180 | " def push(self, state, action, reward, next_state):\n", 181 | " block = MemoryBlock(state, action, reward, next_state)\n", 182 | " self.memory.append(block)\n", 183 | "\n", 184 | " def sample(self, batch_size):\n", 185 | " return random.sample(self.memory, batch_size)" 186 | ], 187 | "metadata": { 188 | "id": "A-qtTzAV3ZY8" 189 | }, 190 | "execution_count": 7, 191 | "outputs": [] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "source": [ 196 | "class PolicyNetwork(nn.Module):\n", 197 | " def __init__(self, n_observations, n_actions, latent_dim=128):\n", 198 | " super().__init__()\n", 199 | "\n", 200 | " self.layers = nn.Sequential(\n", 201 | " nn.Linear(n_observations, latent_dim),\n", 202 | " nn.ReLU(),\n", 203 | " nn.Linear(latent_dim, latent_dim),\n", 204 | " nn.ReLU(),\n", 205 | " nn.Linear(latent_dim, n_actions)\n", 206 | " )\n", 207 | "\n", 208 | " def forward(self, x):\n", 209 | " return self.layers(x)" 210 | ], 211 | "metadata": { 212 | "id": "3EQxRKlX4gPh" 213 | }, 214 | "execution_count": 8, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "source": [ 220 | "policy_net = PolicyNetwork(N_OBSERVATIONS, N_ACTIONS).to(DEVICE)\n", 221 | "target_net = PolicyNetwork(N_OBSERVATIONS, N_ACTIONS).to(DEVICE)\n", 222 | "\n", 223 | "target_net.load_state_dict(policy_net.state_dict())" 224 | ], 225 | "metadata": { 226 | "colab": { 227 | "base_uri": 
"https://localhost:8080/" 228 | }, 229 | "id": "y3MogvRV5MmR", 230 | "outputId": "e3742ebb-08ad-4d5d-cb64-c9a0333e7f68" 231 | }, 232 | "execution_count": 9, 233 | "outputs": [ 234 | { 235 | "output_type": "execute_result", 236 | "data": { 237 | "text/plain": [ 238 | "" 239 | ] 240 | }, 241 | "metadata": {}, 242 | "execution_count": 9 243 | } 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "source": [ 249 | "optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)\n", 250 | "criterion = nn.SmoothL1Loss()" 251 | ], 252 | "metadata": { 253 | "id": "vZTxKNpG6IAT" 254 | }, 255 | "execution_count": 10, 256 | "outputs": [] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "source": [ 261 | "memory = ReplayMemory(MEMORY_SIZE)" 262 | ], 263 | "metadata": { 264 | "id": "qK5Uinxk6Uxp" 265 | }, 266 | "execution_count": 11, 267 | "outputs": [] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "source": [ 272 | "x = np.array(range(6000))\n", 273 | "\n", 274 | "eps = EPS_END + (EPS_START - EPS_END) * np.exp(-1 * x / EPS_DECAY)\n", 275 | "\n", 276 | "plt.plot(x, eps)\n", 277 | "plt.show()" 278 | ], 279 | "metadata": { 280 | "colab": { 281 | "base_uri": "https://localhost:8080/", 282 | "height": 430 283 | }, 284 | "id": "ln4qZVoT7Dhi", 285 | "outputId": "66003bd4-5de0-453c-dff1-5f7174cb12f3" 286 | }, 287 | "execution_count": 12, 288 | "outputs": [ 289 | { 290 | "output_type": "display_data", 291 | "data": { 292 | "text/plain": [ 293 | "
" 294 | ], 295 | "image/png": "\n" 296 | }, 297 | "metadata": {} 298 | } 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "source": [ 304 | "steps_done = 0\n", 305 | "\n", 306 | "def select_action(state):\n", 307 | " global steps_done\n", 308 | "\n", 309 | " eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1 * steps_done / EPS_DECAY)\n", 310 | " steps_done += 1\n", 311 | "\n", 312 | " if random.random() > eps:\n", 313 | " with torch.no_grad():\n", 314 | " action = policy_net(state).max(1).indices.view(1, 1)\n", 315 | " else:\n", 316 | " action = torch.tensor([[env.action_space.sample()]], device=DEVICE, dtype=torch.long)\n", 317 | "\n", 318 | " return action" 319 | ], 320 | "metadata": { 321 | "id": "YijGFqvN6b6E" 322 | }, 323 | "execution_count": 13, 324 | "outputs": [] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "source": [ 329 | "def optimize_model():\n", 330 | " if len(memory) < BATCH_SIZE:\n", 331 | " return\n", 332 | "\n", 333 | " optimizer.zero_grad()\n", 334 | "\n", 335 | " history = memory.sample(BATCH_SIZE)\n", 336 | " batch = MemoryBlock(*zip(*history))\n", 337 | "\n", 338 | " state_batch = torch.cat(batch.state)\n", 339 | " action_batch = torch.cat(batch.action)\n", 340 | " reward_batch = torch.cat(batch.reward)\n", 341 | "\n", 342 | " state_action_values = policy_net(state_batch).gather(1, action_batch)\n", 343 | "\n", 344 | " non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=DEVICE, dtype=torch.bool)\n", 345 | " non_final_next_state = torch.cat([s for s in batch.next_state if s is not None])\n", 346 | " next_state_values = torch.zeros(BATCH_SIZE, device=DEVICE)\n", 347 | "\n", 348 | " with torch.no_grad():\n", 349 | " next_state_values[non_final_mask] = target_net(non_final_next_state).max(1).values\n", 350 | "\n", 351 | " expected_state_action_values = (next_state_values * GAMMA) + reward_batch\n", 352 | " loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))\n", 353 
| "\n", 354 | " loss.backward()\n", 355 | "\n", 356 | " nn.utils.clip_grad_value_(policy_net.parameters(), CLIP_VALUE)\n", 357 | " optimizer.step()" 358 | ], 359 | "metadata": { 360 | "id": "KtBs3WkB79yg" 361 | }, 362 | "execution_count": 14, 363 | "outputs": [] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "source": [ 368 | "episode_durations = []" 369 | ], 370 | "metadata": { 371 | "id": "WsKcEq4u6_5T" 372 | }, 373 | "execution_count": 15, 374 | "outputs": [] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "source": [ 379 | "for epoch in range(1, NUM_EPISODES + 1):\n", 380 | " state, _ = env.reset()\n", 381 | " state = torch.tensor(state, dtype=torch.float32, device=DEVICE).unsqueeze(0)\n", 382 | "\n", 383 | " for t in count():\n", 384 | " action = select_action(state)\n", 385 | "\n", 386 | " observation, reward, terminated, truncated, _ = env.step(action.item())\n", 387 | " reward = torch.tensor([reward], device=DEVICE)\n", 388 | "\n", 389 | " if terminated:\n", 390 | " next_state = None\n", 391 | " else:\n", 392 | " next_state = torch.tensor(observation, dtype=torch.float32, device=DEVICE).unsqueeze(0)\n", 393 | "\n", 394 | " memory.push(state, action, reward, next_state)\n", 395 | " state = next_state\n", 396 | "\n", 397 | " optimize_model()\n", 398 | "\n", 399 | " target_state_dict = target_net.state_dict()\n", 400 | " policy_state_dict = policy_net.state_dict()\n", 401 | "\n", 402 | " for key in policy_state_dict:\n", 403 | " target_state_dict[key] = policy_state_dict[key] * TAU + target_state_dict[key] * (1 - TAU)\n", 404 | "\n", 405 | " target_net.load_state_dict(target_state_dict)\n", 406 | "\n", 407 | " if terminated or truncated:\n", 408 | " episode_durations.append(t + 1)\n", 409 | "\n", 410 | " if epoch % 50 == 0:\n", 411 | " print(f'Episode {epoch} duration: {t + 1}')\n", 412 | "\n", 413 | " break" 414 | ], 415 | "metadata": { 416 | "colab": { 417 | "base_uri": "https://localhost:8080/" 418 | }, 419 | "id": "wO7NprE-_TQa", 420 | "outputId": 
"8ab15210-408f-452b-f081-926f7d4679e3" 421 | }, 422 | "execution_count": 16, 423 | "outputs": [ 424 | { 425 | "output_type": "stream", 426 | "name": "stdout", 427 | "text": [ 428 | "Episode 50 duration: 10\n", 429 | "Episode 100 duration: 17\n", 430 | "Episode 150 duration: 8\n", 431 | "Episode 200 duration: 74\n", 432 | "Episode 250 duration: 119\n", 433 | "Episode 300 duration: 120\n", 434 | "Episode 350 duration: 145\n", 435 | "Episode 400 duration: 181\n", 436 | "Episode 450 duration: 318\n", 437 | "Episode 500 duration: 500\n", 438 | "Episode 550 duration: 500\n", 439 | "Episode 600 duration: 500\n" 440 | ] 441 | } 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "source": [ 447 | "plt.plot(episode_durations)\n", 448 | "\n", 449 | "plt.show()" 450 | ], 451 | "metadata": { 452 | "colab": { 453 | "base_uri": "https://localhost:8080/", 454 | "height": 430 455 | }, 456 | "id": "s7t56VlbBr60", 457 | "outputId": "227e4320-e867-4da1-aedc-450b397db796" 458 | }, 459 | "execution_count": 17, 460 | "outputs": [ 461 | { 462 | "output_type": "display_data", 463 | "data": { 464 | "text/plain": [ 465 | "
" 466 | ], 467 | "image/png": "\n" 468 | }, 469 | "metadata": {} 470 | } 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "source": [], 476 | "metadata": { 477 | "id": "SkiYqlimB-lB" 478 | }, 479 | "execution_count": 17, 480 | "outputs": [] 481 | } 482 | ] 483 | } -------------------------------------------------------------------------------- /5_Pytorch_Introdution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyMa4UQgnjr1l61FgngsPT2O", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": { 33 | "id": "wqWzgaihwob-" 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "import torch\n", 38 | "import torch.nn as nn\n", 39 | "import torch.nn.functional as F\n", 40 | "import torch.optim as optim" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "source": [ 46 | "scaler = torch.tensor([1.])" 47 | ], 48 | "metadata": { 49 | "id": "fSQwTeFAxRPu" 50 | }, 51 | "execution_count": 4, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "source": [ 57 | "scaler" 58 | ], 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/" 62 | }, 63 | "id": "JbmosPEXxwBI", 64 | "outputId": "a4a5fcdc-a0b6-4c41-f084-7d18cdd0525b" 65 | }, 66 | "execution_count": 5, 67 | "outputs": [ 68 | { 69 | "output_type": "execute_result", 70 | "data": { 71 | "text/plain": [ 72 | "tensor([1.])" 73 | ] 74 | }, 75 | "metadata": {}, 76 | "execution_count": 5 77 | } 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 
82 | "source": [ 83 | "scaler.ndim" 84 | ], 85 | "metadata": { 86 | "colab": { 87 | "base_uri": "https://localhost:8080/" 88 | }, 89 | "id": "HjeDdpDOxy11", 90 | "outputId": "94029959-2384-4e69-c806-971e23e97bf1" 91 | }, 92 | "execution_count": 6, 93 | "outputs": [ 94 | { 95 | "output_type": "execute_result", 96 | "data": { 97 | "text/plain": [ 98 | "1" 99 | ] 100 | }, 101 | "metadata": {}, 102 | "execution_count": 6 103 | } 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "source": [ 109 | "scaler.dtype" 110 | ], 111 | "metadata": { 112 | "colab": { 113 | "base_uri": "https://localhost:8080/" 114 | }, 115 | "id": "87I5lTcLx0D1", 116 | "outputId": "ae93d0d4-db54-4f30-99c4-edafb9d0d47e" 117 | }, 118 | "execution_count": 7, 119 | "outputs": [ 120 | { 121 | "output_type": "execute_result", 122 | "data": { 123 | "text/plain": [ 124 | "torch.float32" 125 | ] 126 | }, 127 | "metadata": {}, 128 | "execution_count": 7 129 | } 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "source": [ 135 | "scaler + 2" 136 | ], 137 | "metadata": { 138 | "colab": { 139 | "base_uri": "https://localhost:8080/" 140 | }, 141 | "id": "3HELsLg5x1Jl", 142 | "outputId": "dc28f56f-5cb3-422a-c930-d0c17fadd0f4" 143 | }, 144 | "execution_count": 8, 145 | "outputs": [ 146 | { 147 | "output_type": "execute_result", 148 | "data": { 149 | "text/plain": [ 150 | "tensor([3.])" 151 | ] 152 | }, 153 | "metadata": {}, 154 | "execution_count": 8 155 | } 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "source": [ 161 | "scaler.requires_grad" 162 | ], 163 | "metadata": { 164 | "colab": { 165 | "base_uri": "https://localhost:8080/" 166 | }, 167 | "id": "JQSUrajux7ZN", 168 | "outputId": "376f1b1c-e46b-49a8-b41f-9c1580feff32" 169 | }, 170 | "execution_count": 9, 171 | "outputs": [ 172 | { 173 | "output_type": "execute_result", 174 | "data": { 175 | "text/plain": [ 176 | "False" 177 | ] 178 | }, 179 | "metadata": {}, 180 | "execution_count": 9 181 | } 182 | ] 183 | }, 184 | { 185 | 
"cell_type": "code", 186 | "source": [ 187 | "scaler.device" 188 | ], 189 | "metadata": { 190 | "colab": { 191 | "base_uri": "https://localhost:8080/" 192 | }, 193 | "id": "wWIFoa5xyGCB", 194 | "outputId": "d702ab3d-9f86-423d-ce53-f58da116e6fd" 195 | }, 196 | "execution_count": 10, 197 | "outputs": [ 198 | { 199 | "output_type": "execute_result", 200 | "data": { 201 | "text/plain": [ 202 | "device(type='cpu')" 203 | ] 204 | }, 205 | "metadata": {}, 206 | "execution_count": 10 207 | } 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "source": [ 213 | "torch.cuda.is_available()" 214 | ], 215 | "metadata": { 216 | "colab": { 217 | "base_uri": "https://localhost:8080/" 218 | }, 219 | "id": "kTm0UlECzZCp", 220 | "outputId": "ebf469e0-95c2-4e89-ce60-1290abdd3b79" 221 | }, 222 | "execution_count": 15, 223 | "outputs": [ 224 | { 225 | "output_type": "execute_result", 226 | "data": { 227 | "text/plain": [ 228 | "False" 229 | ] 230 | }, 231 | "metadata": {}, 232 | "execution_count": 15 233 | } 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "source": [ 239 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'" 240 | ], 241 | "metadata": { 242 | "id": "l1lqFQnVyNXF" 243 | }, 244 | "execution_count": 12, 245 | "outputs": [] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "source": [ 250 | "device" 251 | ], 252 | "metadata": { 253 | "colab": { 254 | "base_uri": "https://localhost:8080/", 255 | "height": 36 256 | }, 257 | "id": "XZLSWKHmyW4I", 258 | "outputId": "d09aa7bf-a53c-44cc-9bd2-fd129e2fa0ec" 259 | }, 260 | "execution_count": 13, 261 | "outputs": [ 262 | { 263 | "output_type": "execute_result", 264 | "data": { 265 | "text/plain": [ 266 | "'cpu'" 267 | ], 268 | "application/vnd.google.colaboratory.intrinsic+json": { 269 | "type": "string" 270 | } 271 | }, 272 | "metadata": {}, 273 | "execution_count": 13 274 | } 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "source": [ 280 | "scaler = scaler.to(device)" 281 | ], 282 | "metadata": { 283 
| "id": "F5Ekj_ITydS-" 284 | }, 285 | "execution_count": 14, 286 | "outputs": [] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "source": [ 291 | "class Network(nn.Module):\n", 292 | " def __init__(self):\n", 293 | " super().__init__()\n", 294 | "\n", 295 | " self.fc1 = nn.Linear(in_features=2, out_features=10)\n", 296 | " self.fc2 = nn.Linear(in_features=10, out_features=10)\n", 297 | " self.fc3 = nn.Linear(in_features=10, out_features=10)\n", 298 | " self.fc4 = nn.Linear(in_features=10, out_features=1)\n", 299 | "\n", 300 | " def forward(self, x):\n", 301 | " x = F.relu(self.fc1(x))\n", 302 | " x = F.relu(self.fc2(x))\n", 303 | " x = F.relu(self.fc3(x))\n", 304 | " x = self.fc4(x)\n", 305 | "\n", 306 | " return x" 307 | ], 308 | "metadata": { 309 | "id": "tC8YnzIwyf6n" 310 | }, 311 | "execution_count": 22, 312 | "outputs": [] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "source": [ 317 | "net = Network().to(device)" 318 | ], 319 | "metadata": { 320 | "id": "jmt6SYAZ0wPa" 321 | }, 322 | "execution_count": 23, 323 | "outputs": [] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "source": [ 328 | "test_input = torch.rand((3, 2))\n", 329 | "test_input" 330 | ], 331 | "metadata": { 332 | "colab": { 333 | "base_uri": "https://localhost:8080/" 334 | }, 335 | "id": "dT8GYeRl00ZX", 336 | "outputId": "9990a297-41e1-4512-c15f-9f29e2b0e1bc" 337 | }, 338 | "execution_count": 24, 339 | "outputs": [ 340 | { 341 | "output_type": "execute_result", 342 | "data": { 343 | "text/plain": [ 344 | "tensor([[0.7418, 0.8436],\n", 345 | " [0.8179, 0.3327],\n", 346 | " [0.5869, 0.7119]])" 347 | ] 348 | }, 349 | "metadata": {}, 350 | "execution_count": 24 351 | } 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "source": [ 357 | "test_output = net(test_input)\n", 358 | "test_output" 359 | ], 360 | "metadata": { 361 | "colab": { 362 | "base_uri": "https://localhost:8080/" 363 | }, 364 | "id": "D0BmAZwc0_rJ", 365 | "outputId": "6aa5025c-71ad-4de5-d711-190eb2f76ada" 366 | }, 
367 | "execution_count": 25, 368 | "outputs": [ 369 | { 370 | "output_type": "execute_result", 371 | "data": { 372 | "text/plain": [ 373 | "tensor([[0.2814],\n", 374 | " [0.2967],\n", 375 | " [0.2830]], grad_fn=)" 376 | ] 377 | }, 378 | "metadata": {}, 379 | "execution_count": 25 380 | } 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "source": [ 386 | "net" 387 | ], 388 | "metadata": { 389 | "colab": { 390 | "base_uri": "https://localhost:8080/" 391 | }, 392 | "id": "x6_FwU7R1E8B", 393 | "outputId": "9b217aa5-2416-45d6-96de-88f4dfe9511f" 394 | }, 395 | "execution_count": 26, 396 | "outputs": [ 397 | { 398 | "output_type": "execute_result", 399 | "data": { 400 | "text/plain": [ 401 | "Network(\n", 402 | " (fc1): Linear(in_features=2, out_features=10, bias=True)\n", 403 | " (fc2): Linear(in_features=10, out_features=10, bias=True)\n", 404 | " (fc3): Linear(in_features=10, out_features=10, bias=True)\n", 405 | " (fc4): Linear(in_features=10, out_features=1, bias=True)\n", 406 | ")" 407 | ] 408 | }, 409 | "metadata": {}, 410 | "execution_count": 26 411 | } 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "source": [ 417 | "list(net.fc1.parameters())" 418 | ], 419 | "metadata": { 420 | "colab": { 421 | "base_uri": "https://localhost:8080/" 422 | }, 423 | "id": "0Zidy5vm1McJ", 424 | "outputId": "b5fbf595-8a8a-458f-e533-f6e1601ac526" 425 | }, 426 | "execution_count": 30, 427 | "outputs": [ 428 | { 429 | "output_type": "execute_result", 430 | "data": { 431 | "text/plain": [ 432 | "[Parameter containing:\n", 433 | " tensor([[-0.0678, 0.4129],\n", 434 | " [-0.2128, 0.5182],\n", 435 | " [ 0.3328, -0.5278],\n", 436 | " [-0.0189, 0.6389],\n", 437 | " [ 0.3837, 0.2405],\n", 438 | " [-0.2676, -0.2395],\n", 439 | " [ 0.2778, 0.0328],\n", 440 | " [-0.4963, 0.5281],\n", 441 | " [ 0.1273, -0.0240],\n", 442 | " [-0.4277, -0.5793]], requires_grad=True),\n", 443 | " Parameter containing:\n", 444 | " tensor([-0.5801, 0.0506, 0.2750, -0.6558, -0.2582, -0.3598, -0.5241, 
-0.0149,\n", 445 | " -0.2104, -0.5190], requires_grad=True)]" 446 | ] 447 | }, 448 | "metadata": {}, 449 | "execution_count": 30 450 | } 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "source": [], 456 | "metadata": { 457 | "id": "_Ai44HNl1PgN" 458 | }, 459 | "execution_count": null, 460 | "outputs": [] 461 | } 462 | ] 463 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | "# deep_learning_class_notebooks" 2 | --------------------------------------------------------------------------------