├── 0_Deep_Learning's_Hello_World.ipynb
├── 10_RNN_Network.ipynb
├── 11_Text_Classification.ipynb
├── 12_Machine_Translation_From_Scratch_Part_1.ipynb
├── 13_Machine_Translation_From_Scratch_Part_2.ipynb
├── 14_Vanilla_GAN.ipynb
├── 15_DCGAN.ipynb
├── 16_Conditional_DCGAN.ipynb
├── 17_Pix2Pix_GAN.ipynb
├── 18_Cycle_GAN.ipynb
├── 19_Arbitrary_Style_Transfer_(AdaIN).ipynb
├── 1_Fashion_MNIST_and_CIFAR10.ipynb
├── 20_VAE.ipynb
├── 21_Diffusion_Model.ipynb
├── 22_Open_Source_NLU_models.ipynb
├── 23_Value_functions_and_policy_iteration.ipynb
├── 24_Double_Deep_Q_Learning_1_gym_intro.ipynb
├── 25_Double_Deep_Q_Learning_2.ipynb
├── 2_horse_or_human_workshop.ipynb
├── 3_VGG16_keras_applicarions.ipynb
├── 4_Residual_Networks.ipynb
├── 5_Pytorch_Introdution.ipynb
├── 6_CIFAR10_pytorch.ipynb
├── 7_Neural_Style_Transfer_.ipynb
├── 8_Siamese_Networks.ipynb
├── 9_Unet_and_Segmentation.ipynb
└── README.md
/13_Machine_Translation_From_Scratch_Part_2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyMMzEZjTDTJxfI3u0MIefDO",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 35,
32 | "metadata": {
33 | "id": "cezeGfesoylg"
34 | },
35 | "outputs": [],
36 | "source": [
37 | "import torch\n",
38 | "import torch.nn as nn\n",
39 | "import torch.nn.functional as F\n",
40 | "import torch.optim as optim"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "source": [
46 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'"
47 | ],
48 | "metadata": {
49 | "id": "RCpeVTNlvgxg"
50 | },
51 | "execution_count": 36,
52 | "outputs": []
53 | },
54 | {
55 | "cell_type": "code",
56 | "source": [
57 | "class Encoder(nn.Module):\n",
58 | " def __init__(self, num_tokens, embedding_dim, latent_dim):\n",
59 | " super().__init__()\n",
60 | "\n",
61 | " self.embedding = nn.Embedding(num_embeddings=num_tokens, embedding_dim=embedding_dim)\n",
62 | " self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=latent_dim, num_layers=1, batch_first=True, bidirectional=True)\n",
63 | "\n",
64 | " self.latent_dim = latent_dim\n",
65 | "\n",
66 | " def forward(self, x):\n",
67 | " x = self.embedding(x)\n",
68 | " batch_size, _, _ = x.size()\n",
69 | " h_0 = torch.zeros(2, batch_size, self.latent_dim).to(DEVICE)\n",
70 | " outputs, context_vector = self.rnn(x, h_0)\n",
71 | "\n",
72 | " return context_vector, outputs"
73 | ],
74 | "metadata": {
75 | "id": "l6drzN3TvYEP"
76 | },
77 | "execution_count": 37,
78 | "outputs": []
79 | },
80 | {
81 | "cell_type": "code",
82 | "source": [
83 | "encoder = Encoder(num_tokens=100, embedding_dim=16, latent_dim=64).to(DEVICE)"
84 | ],
85 | "metadata": {
86 | "id": "uEfD71WEvcs0"
87 | },
88 | "execution_count": 38,
89 | "outputs": []
90 | },
91 | {
92 | "cell_type": "code",
93 | "source": [
94 | "batch_size = 10\n",
95 | "seq_length = 30\n",
96 | "# Smoke test: create the dummy batch on DEVICE so it sits on the same device as the encoder.\n",
97 | "test_input = torch.zeros(batch_size, seq_length, dtype=torch.int64, device=DEVICE)\n",
98 | "test_output, _ = encoder(test_input)\n",
99 | "\n",
100 | "test_output.shape"
101 | ],
102 | "metadata": {
103 | "colab": {
104 | "base_uri": "https://localhost:8080/"
105 | },
106 | "id": "nKKSzpd_voel",
107 | "outputId": "ed6adfc9-9c52-4b02-a511-8fff8f3988de"
108 | },
109 | "execution_count": 39,
110 | "outputs": [
111 | {
112 | "output_type": "execute_result",
113 | "data": {
114 | "text/plain": [
115 | "torch.Size([2, 10, 64])"
116 | ]
117 | },
118 | "metadata": {},
119 | "execution_count": 39
120 | }
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "source": [
126 | "class AttentionBlock(nn.Module):\n",
127 | " def __init__(self, hidden_dim):\n",
128 | " super().__init__()\n",
129 | " self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)\n",
130 | " self.U = nn.Linear(hidden_dim, hidden_dim, bias=False)\n",
131 | " self.V = nn.Linear(hidden_dim, 1, bias=False)\n",
132 | "\n",
133 | " def forward(self, query, keys):\n",
134 | " QK = self.W(query).unsqueeze(1) + self.U(keys)\n",
135 | " QK = torch.tanh(QK)\n",
136 | " scores = self.V(QK)\n",
137 | "\n",
138 | " scores = scores.squeeze(2).unsqueeze(1)\n",
139 | " weigths = F.softmax(scores, dim=-1)\n",
140 | "\n",
141 | " context = torch.bmm(weigths, keys)\n",
142 | "\n",
143 | " return context, weigths"
144 | ],
145 | "metadata": {
146 | "id": "QoAL81lzwBtT"
147 | },
148 | "execution_count": 40,
149 | "outputs": []
150 | },
151 | {
152 | "cell_type": "code",
153 | "source": [
154 | "class AttentionGRU(nn.Module):\n",
155 | "    \"\"\"GRU whose initial state at every step is an attention context over the encoder outputs.\"\"\"\n",
156 | "    def __init__(self, input_size, latent_dim):\n",
157 | "        super().__init__()\n",
158 | "        self.attention = AttentionBlock(2 * latent_dim)\n",
159 | "        self.rnn = nn.GRU(input_size=input_size, hidden_size=2 * latent_dim)\n",
160 | "        self.latent_dim = latent_dim\n",
161 | "\n",
162 | "    def forward(self, predicted_label, encoder_outputs):\n",
163 | "        \"\"\"Consume embedded tokens (batch, steps, input_size); return the last step's (output, h).\"\"\"\n",
164 | "        # Zero query on the encoder outputs' device/dtype so GPU runs do not mix devices.\n",
165 | "        h = encoder_outputs.new_zeros(predicted_label.size(0), 2 * self.latent_dim)\n",
166 | "        for token in predicted_label.permute(1, 0, 2):\n",
167 | "            context, _ = self.attention(h, encoder_outputs)\n",
168 | "            # The context, reshaped to (1, batch, hidden), is passed to the GRU as its initial state.\n",
169 | "            output, h = self.rnn(token.unsqueeze(0), context.permute(1, 0, 2))\n",
170 | "            # Drop only the layer axis; a bare squeeze() would also drop a batch of size 1.\n",
171 | "            h = h.squeeze(0)\n",
172 | "        return output, h"
173 | ],
174 | "metadata": {
175 | "id": "FHt-pZmz40N1"
176 | },
177 | "execution_count": 41,
178 | "outputs": []
179 | },
180 | {
181 | "cell_type": "code",
182 | "source": [
183 | "class Decoder(nn.Module):\n",
184 | "    \"\"\"Embeds the tokens decoded so far and predicts log-probabilities for the next token.\"\"\"\n",
185 | "\n",
186 | "    def __init__(self, num_tokens, embedding_dim, latent_dim):\n",
187 | "        super().__init__()\n",
188 | "        self.embedding = nn.Embedding(num_embeddings=num_tokens, embedding_dim=embedding_dim)\n",
189 | "        self.rnn = AttentionGRU(embedding_dim, latent_dim)\n",
190 | "        self.fc = nn.Linear(in_features=2 * latent_dim, out_features=num_tokens)\n",
191 | "        # Normalise over the vocabulary (last) axis; dim=1 is the batch axis of (1, batch, num_tokens).\n",
192 | "        self.softmax = nn.LogSoftmax(dim=-1)\n",
193 | "\n",
194 | "    def forward(self, encoder_outputs, predicted_label):\n",
195 | "        \"\"\"Return next-token log-probabilities of shape (1, batch, num_tokens).\"\"\"\n",
196 | "        x, _ = self.rnn(self.embedding(predicted_label), encoder_outputs)\n",
197 | "        x = self.fc(x)\n",
198 | "        return self.softmax(x)"
199 | ],
200 | "metadata": {
201 | "id": "nfrRgkJF62bE"
202 | },
203 | "execution_count": 42,
204 | "outputs": []
205 | },
206 | {
207 | "cell_type": "code",
208 | "source": [
209 | "encoder = Encoder(num_tokens=100, embedding_dim=8, latent_dim=16)\n",
210 | "decoder = Decoder(num_tokens=100, embedding_dim=8, latent_dim=16)"
211 | ],
212 | "metadata": {
213 | "id": "lrVicVYv7LBr"
214 | },
215 | "execution_count": 43,
216 | "outputs": []
217 | },
218 | {
219 | "cell_type": "code",
220 | "source": [
221 | "batch_size = 50\n",
222 | "seq_length = 20\n",
223 | "predicted_labels_count = 10\n",
224 | "test_input = torch.zeros(batch_size, seq_length, dtype=torch.int64)\n",
225 | "predicted_labels = torch.zeros(batch_size, predicted_labels_count, dtype=torch.int64)\n",
226 | "\n",
227 | "_, encoder_output = encoder(test_input)\n",
228 | "new_token = decoder(encoder_output, predicted_labels)\n",
229 | "new_token.size()"
230 | ],
231 | "metadata": {
232 | "colab": {
233 | "base_uri": "https://localhost:8080/"
234 | },
235 | "id": "L-glyCsm7T29",
236 | "outputId": "4e784ab3-7cef-4924-ecb7-3d0ee3b7b8e5"
237 | },
238 | "execution_count": 44,
239 | "outputs": [
240 | {
241 | "output_type": "execute_result",
242 | "data": {
243 | "text/plain": [
244 | "torch.Size([1, 50, 100])"
245 | ]
246 | },
247 | "metadata": {},
248 | "execution_count": 44
249 | }
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "source": [],
255 | "metadata": {
256 | "id": "ncvEzVsD79Zh"
257 | },
258 | "execution_count": 44,
259 | "outputs": []
260 | }
261 | ]
262 | }
--------------------------------------------------------------------------------
/21_Diffusion_Model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4",
8 | "authorship_tag": "ABX9TyPp+28MPcewUc/8K+ah+M4u",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | },
18 | "accelerator": "GPU",
19 | "widgets": {
20 | "application/vnd.jupyter.widget-state+json": {
21 | "b0fa66fc56e34465bb2a59f9ce916b71": {
22 | "model_module": "@jupyter-widgets/controls",
23 | "model_name": "HBoxModel",
24 | "model_module_version": "1.5.0",
25 | "state": {
26 | "_dom_classes": [],
27 | "_model_module": "@jupyter-widgets/controls",
28 | "_model_module_version": "1.5.0",
29 | "_model_name": "HBoxModel",
30 | "_view_count": null,
31 | "_view_module": "@jupyter-widgets/controls",
32 | "_view_module_version": "1.5.0",
33 | "_view_name": "HBoxView",
34 | "box_style": "",
35 | "children": [
36 | "IPY_MODEL_14019560a279458cb775634b162b2407",
37 | "IPY_MODEL_4db4b93e47bd4c859b1d6685ed8cf76e",
38 | "IPY_MODEL_2ef34af8de2c4c7eab85fd3c83a5e1c5"
39 | ],
40 | "layout": "IPY_MODEL_1638fe599f2c409c8d75396bbad174b7"
41 | }
42 | },
43 | "14019560a279458cb775634b162b2407": {
44 | "model_module": "@jupyter-widgets/controls",
45 | "model_name": "HTMLModel",
46 | "model_module_version": "1.5.0",
47 | "state": {
48 | "_dom_classes": [],
49 | "_model_module": "@jupyter-widgets/controls",
50 | "_model_module_version": "1.5.0",
51 | "_model_name": "HTMLModel",
52 | "_view_count": null,
53 | "_view_module": "@jupyter-widgets/controls",
54 | "_view_module_version": "1.5.0",
55 | "_view_name": "HTMLView",
56 | "description": "",
57 | "description_tooltip": null,
58 | "layout": "IPY_MODEL_0a2010f11ba345a385997b328168127a",
59 | "placeholder": "",
60 | "style": "IPY_MODEL_5a39152a73cb4a87bd5bdf6f5a6888cc",
61 | "value": "100%"
62 | }
63 | },
64 | "4db4b93e47bd4c859b1d6685ed8cf76e": {
65 | "model_module": "@jupyter-widgets/controls",
66 | "model_name": "FloatProgressModel",
67 | "model_module_version": "1.5.0",
68 | "state": {
69 | "_dom_classes": [],
70 | "_model_module": "@jupyter-widgets/controls",
71 | "_model_module_version": "1.5.0",
72 | "_model_name": "FloatProgressModel",
73 | "_view_count": null,
74 | "_view_module": "@jupyter-widgets/controls",
75 | "_view_module_version": "1.5.0",
76 | "_view_name": "ProgressView",
77 | "bar_style": "success",
78 | "description": "",
79 | "description_tooltip": null,
80 | "layout": "IPY_MODEL_b72fab00156243478d0619f3f8a56a1a",
81 | "max": 938,
82 | "min": 0,
83 | "orientation": "horizontal",
84 | "style": "IPY_MODEL_38548a20e7744f098a9118b82ff13f95",
85 | "value": 938
86 | }
87 | },
88 | "2ef34af8de2c4c7eab85fd3c83a5e1c5": {
89 | "model_module": "@jupyter-widgets/controls",
90 | "model_name": "HTMLModel",
91 | "model_module_version": "1.5.0",
92 | "state": {
93 | "_dom_classes": [],
94 | "_model_module": "@jupyter-widgets/controls",
95 | "_model_module_version": "1.5.0",
96 | "_model_name": "HTMLModel",
97 | "_view_count": null,
98 | "_view_module": "@jupyter-widgets/controls",
99 | "_view_module_version": "1.5.0",
100 | "_view_name": "HTMLView",
101 | "description": "",
102 | "description_tooltip": null,
103 | "layout": "IPY_MODEL_47d16d89d79543ab8b893ccae490d2f0",
104 | "placeholder": "",
105 | "style": "IPY_MODEL_05f22484b66641818d3288792ead2efa",
106 | "value": " 938/938 [03:40<00:00, 4.89it/s]"
107 | }
108 | },
109 | "1638fe599f2c409c8d75396bbad174b7": {
110 | "model_module": "@jupyter-widgets/base",
111 | "model_name": "LayoutModel",
112 | "model_module_version": "1.2.0",
113 | "state": {
114 | "_model_module": "@jupyter-widgets/base",
115 | "_model_module_version": "1.2.0",
116 | "_model_name": "LayoutModel",
117 | "_view_count": null,
118 | "_view_module": "@jupyter-widgets/base",
119 | "_view_module_version": "1.2.0",
120 | "_view_name": "LayoutView",
121 | "align_content": null,
122 | "align_items": null,
123 | "align_self": null,
124 | "border": null,
125 | "bottom": null,
126 | "display": null,
127 | "flex": null,
128 | "flex_flow": null,
129 | "grid_area": null,
130 | "grid_auto_columns": null,
131 | "grid_auto_flow": null,
132 | "grid_auto_rows": null,
133 | "grid_column": null,
134 | "grid_gap": null,
135 | "grid_row": null,
136 | "grid_template_areas": null,
137 | "grid_template_columns": null,
138 | "grid_template_rows": null,
139 | "height": null,
140 | "justify_content": null,
141 | "justify_items": null,
142 | "left": null,
143 | "margin": null,
144 | "max_height": null,
145 | "max_width": null,
146 | "min_height": null,
147 | "min_width": null,
148 | "object_fit": null,
149 | "object_position": null,
150 | "order": null,
151 | "overflow": null,
152 | "overflow_x": null,
153 | "overflow_y": null,
154 | "padding": null,
155 | "right": null,
156 | "top": null,
157 | "visibility": null,
158 | "width": null
159 | }
160 | },
161 | "0a2010f11ba345a385997b328168127a": {
162 | "model_module": "@jupyter-widgets/base",
163 | "model_name": "LayoutModel",
164 | "model_module_version": "1.2.0",
165 | "state": {
166 | "_model_module": "@jupyter-widgets/base",
167 | "_model_module_version": "1.2.0",
168 | "_model_name": "LayoutModel",
169 | "_view_count": null,
170 | "_view_module": "@jupyter-widgets/base",
171 | "_view_module_version": "1.2.0",
172 | "_view_name": "LayoutView",
173 | "align_content": null,
174 | "align_items": null,
175 | "align_self": null,
176 | "border": null,
177 | "bottom": null,
178 | "display": null,
179 | "flex": null,
180 | "flex_flow": null,
181 | "grid_area": null,
182 | "grid_auto_columns": null,
183 | "grid_auto_flow": null,
184 | "grid_auto_rows": null,
185 | "grid_column": null,
186 | "grid_gap": null,
187 | "grid_row": null,
188 | "grid_template_areas": null,
189 | "grid_template_columns": null,
190 | "grid_template_rows": null,
191 | "height": null,
192 | "justify_content": null,
193 | "justify_items": null,
194 | "left": null,
195 | "margin": null,
196 | "max_height": null,
197 | "max_width": null,
198 | "min_height": null,
199 | "min_width": null,
200 | "object_fit": null,
201 | "object_position": null,
202 | "order": null,
203 | "overflow": null,
204 | "overflow_x": null,
205 | "overflow_y": null,
206 | "padding": null,
207 | "right": null,
208 | "top": null,
209 | "visibility": null,
210 | "width": null
211 | }
212 | },
213 | "5a39152a73cb4a87bd5bdf6f5a6888cc": {
214 | "model_module": "@jupyter-widgets/controls",
215 | "model_name": "DescriptionStyleModel",
216 | "model_module_version": "1.5.0",
217 | "state": {
218 | "_model_module": "@jupyter-widgets/controls",
219 | "_model_module_version": "1.5.0",
220 | "_model_name": "DescriptionStyleModel",
221 | "_view_count": null,
222 | "_view_module": "@jupyter-widgets/base",
223 | "_view_module_version": "1.2.0",
224 | "_view_name": "StyleView",
225 | "description_width": ""
226 | }
227 | },
228 | "b72fab00156243478d0619f3f8a56a1a": {
229 | "model_module": "@jupyter-widgets/base",
230 | "model_name": "LayoutModel",
231 | "model_module_version": "1.2.0",
232 | "state": {
233 | "_model_module": "@jupyter-widgets/base",
234 | "_model_module_version": "1.2.0",
235 | "_model_name": "LayoutModel",
236 | "_view_count": null,
237 | "_view_module": "@jupyter-widgets/base",
238 | "_view_module_version": "1.2.0",
239 | "_view_name": "LayoutView",
240 | "align_content": null,
241 | "align_items": null,
242 | "align_self": null,
243 | "border": null,
244 | "bottom": null,
245 | "display": null,
246 | "flex": null,
247 | "flex_flow": null,
248 | "grid_area": null,
249 | "grid_auto_columns": null,
250 | "grid_auto_flow": null,
251 | "grid_auto_rows": null,
252 | "grid_column": null,
253 | "grid_gap": null,
254 | "grid_row": null,
255 | "grid_template_areas": null,
256 | "grid_template_columns": null,
257 | "grid_template_rows": null,
258 | "height": null,
259 | "justify_content": null,
260 | "justify_items": null,
261 | "left": null,
262 | "margin": null,
263 | "max_height": null,
264 | "max_width": null,
265 | "min_height": null,
266 | "min_width": null,
267 | "object_fit": null,
268 | "object_position": null,
269 | "order": null,
270 | "overflow": null,
271 | "overflow_x": null,
272 | "overflow_y": null,
273 | "padding": null,
274 | "right": null,
275 | "top": null,
276 | "visibility": null,
277 | "width": null
278 | }
279 | },
280 | "38548a20e7744f098a9118b82ff13f95": {
281 | "model_module": "@jupyter-widgets/controls",
282 | "model_name": "ProgressStyleModel",
283 | "model_module_version": "1.5.0",
284 | "state": {
285 | "_model_module": "@jupyter-widgets/controls",
286 | "_model_module_version": "1.5.0",
287 | "_model_name": "ProgressStyleModel",
288 | "_view_count": null,
289 | "_view_module": "@jupyter-widgets/base",
290 | "_view_module_version": "1.2.0",
291 | "_view_name": "StyleView",
292 | "bar_color": null,
293 | "description_width": ""
294 | }
295 | },
296 | "47d16d89d79543ab8b893ccae490d2f0": {
297 | "model_module": "@jupyter-widgets/base",
298 | "model_name": "LayoutModel",
299 | "model_module_version": "1.2.0",
300 | "state": {
301 | "_model_module": "@jupyter-widgets/base",
302 | "_model_module_version": "1.2.0",
303 | "_model_name": "LayoutModel",
304 | "_view_count": null,
305 | "_view_module": "@jupyter-widgets/base",
306 | "_view_module_version": "1.2.0",
307 | "_view_name": "LayoutView",
308 | "align_content": null,
309 | "align_items": null,
310 | "align_self": null,
311 | "border": null,
312 | "bottom": null,
313 | "display": null,
314 | "flex": null,
315 | "flex_flow": null,
316 | "grid_area": null,
317 | "grid_auto_columns": null,
318 | "grid_auto_flow": null,
319 | "grid_auto_rows": null,
320 | "grid_column": null,
321 | "grid_gap": null,
322 | "grid_row": null,
323 | "grid_template_areas": null,
324 | "grid_template_columns": null,
325 | "grid_template_rows": null,
326 | "height": null,
327 | "justify_content": null,
328 | "justify_items": null,
329 | "left": null,
330 | "margin": null,
331 | "max_height": null,
332 | "max_width": null,
333 | "min_height": null,
334 | "min_width": null,
335 | "object_fit": null,
336 | "object_position": null,
337 | "order": null,
338 | "overflow": null,
339 | "overflow_x": null,
340 | "overflow_y": null,
341 | "padding": null,
342 | "right": null,
343 | "top": null,
344 | "visibility": null,
345 | "width": null
346 | }
347 | },
348 | "05f22484b66641818d3288792ead2efa": {
349 | "model_module": "@jupyter-widgets/controls",
350 | "model_name": "DescriptionStyleModel",
351 | "model_module_version": "1.5.0",
352 | "state": {
353 | "_model_module": "@jupyter-widgets/controls",
354 | "_model_module_version": "1.5.0",
355 | "_model_name": "DescriptionStyleModel",
356 | "_view_count": null,
357 | "_view_module": "@jupyter-widgets/base",
358 | "_view_module_version": "1.2.0",
359 | "_view_name": "StyleView",
360 | "description_width": ""
361 | }
362 | },
363 | "7c76076cccf04326a82ec69c5b33c9af": {
364 | "model_module": "@jupyter-widgets/controls",
365 | "model_name": "HBoxModel",
366 | "model_module_version": "1.5.0",
367 | "state": {
368 | "_dom_classes": [],
369 | "_model_module": "@jupyter-widgets/controls",
370 | "_model_module_version": "1.5.0",
371 | "_model_name": "HBoxModel",
372 | "_view_count": null,
373 | "_view_module": "@jupyter-widgets/controls",
374 | "_view_module_version": "1.5.0",
375 | "_view_name": "HBoxView",
376 | "box_style": "",
377 | "children": [
378 | "IPY_MODEL_326f80a6a9364bc4a37fd727649c5599",
379 | "IPY_MODEL_cd209532893744078f13583bc2765ef5",
380 | "IPY_MODEL_9933cef2fd0c44c2a308f24e5de0b1ea"
381 | ],
382 | "layout": "IPY_MODEL_9d07802c82da4382966ee2bc58bff4e2"
383 | }
384 | },
385 | "326f80a6a9364bc4a37fd727649c5599": {
386 | "model_module": "@jupyter-widgets/controls",
387 | "model_name": "HTMLModel",
388 | "model_module_version": "1.5.0",
389 | "state": {
390 | "_dom_classes": [],
391 | "_model_module": "@jupyter-widgets/controls",
392 | "_model_module_version": "1.5.0",
393 | "_model_name": "HTMLModel",
394 | "_view_count": null,
395 | "_view_module": "@jupyter-widgets/controls",
396 | "_view_module_version": "1.5.0",
397 | "_view_name": "HTMLView",
398 | "description": "",
399 | "description_tooltip": null,
400 | "layout": "IPY_MODEL_7c62a32ae5944b7584b0026533d3bea9",
401 | "placeholder": "",
402 | "style": "IPY_MODEL_7d6defe1998a4083bbf9f4fdb4e5d599",
403 | "value": "100%"
404 | }
405 | },
406 | "cd209532893744078f13583bc2765ef5": {
407 | "model_module": "@jupyter-widgets/controls",
408 | "model_name": "FloatProgressModel",
409 | "model_module_version": "1.5.0",
410 | "state": {
411 | "_dom_classes": [],
412 | "_model_module": "@jupyter-widgets/controls",
413 | "_model_module_version": "1.5.0",
414 | "_model_name": "FloatProgressModel",
415 | "_view_count": null,
416 | "_view_module": "@jupyter-widgets/controls",
417 | "_view_module_version": "1.5.0",
418 | "_view_name": "ProgressView",
419 | "bar_style": "success",
420 | "description": "",
421 | "description_tooltip": null,
422 | "layout": "IPY_MODEL_5a2ad31325dc45d0b795eb88daf147df",
423 | "max": 938,
424 | "min": 0,
425 | "orientation": "horizontal",
426 | "style": "IPY_MODEL_c981017d7c39426dbaf566505c817bf4",
427 | "value": 938
428 | }
429 | },
430 | "9933cef2fd0c44c2a308f24e5de0b1ea": {
431 | "model_module": "@jupyter-widgets/controls",
432 | "model_name": "HTMLModel",
433 | "model_module_version": "1.5.0",
434 | "state": {
435 | "_dom_classes": [],
436 | "_model_module": "@jupyter-widgets/controls",
437 | "_model_module_version": "1.5.0",
438 | "_model_name": "HTMLModel",
439 | "_view_count": null,
440 | "_view_module": "@jupyter-widgets/controls",
441 | "_view_module_version": "1.5.0",
442 | "_view_name": "HTMLView",
443 | "description": "",
444 | "description_tooltip": null,
445 | "layout": "IPY_MODEL_2ca3d70b283a4eaba7c683d03a53da07",
446 | "placeholder": "",
447 | "style": "IPY_MODEL_51922432c661481da1ebe8192367c4dd",
448 | "value": " 938/938 [03:39<00:00, 4.95it/s]"
449 | }
450 | },
451 | "9d07802c82da4382966ee2bc58bff4e2": {
452 | "model_module": "@jupyter-widgets/base",
453 | "model_name": "LayoutModel",
454 | "model_module_version": "1.2.0",
455 | "state": {
456 | "_model_module": "@jupyter-widgets/base",
457 | "_model_module_version": "1.2.0",
458 | "_model_name": "LayoutModel",
459 | "_view_count": null,
460 | "_view_module": "@jupyter-widgets/base",
461 | "_view_module_version": "1.2.0",
462 | "_view_name": "LayoutView",
463 | "align_content": null,
464 | "align_items": null,
465 | "align_self": null,
466 | "border": null,
467 | "bottom": null,
468 | "display": null,
469 | "flex": null,
470 | "flex_flow": null,
471 | "grid_area": null,
472 | "grid_auto_columns": null,
473 | "grid_auto_flow": null,
474 | "grid_auto_rows": null,
475 | "grid_column": null,
476 | "grid_gap": null,
477 | "grid_row": null,
478 | "grid_template_areas": null,
479 | "grid_template_columns": null,
480 | "grid_template_rows": null,
481 | "height": null,
482 | "justify_content": null,
483 | "justify_items": null,
484 | "left": null,
485 | "margin": null,
486 | "max_height": null,
487 | "max_width": null,
488 | "min_height": null,
489 | "min_width": null,
490 | "object_fit": null,
491 | "object_position": null,
492 | "order": null,
493 | "overflow": null,
494 | "overflow_x": null,
495 | "overflow_y": null,
496 | "padding": null,
497 | "right": null,
498 | "top": null,
499 | "visibility": null,
500 | "width": null
501 | }
502 | },
503 | "7c62a32ae5944b7584b0026533d3bea9": {
504 | "model_module": "@jupyter-widgets/base",
505 | "model_name": "LayoutModel",
506 | "model_module_version": "1.2.0",
507 | "state": {
508 | "_model_module": "@jupyter-widgets/base",
509 | "_model_module_version": "1.2.0",
510 | "_model_name": "LayoutModel",
511 | "_view_count": null,
512 | "_view_module": "@jupyter-widgets/base",
513 | "_view_module_version": "1.2.0",
514 | "_view_name": "LayoutView",
515 | "align_content": null,
516 | "align_items": null,
517 | "align_self": null,
518 | "border": null,
519 | "bottom": null,
520 | "display": null,
521 | "flex": null,
522 | "flex_flow": null,
523 | "grid_area": null,
524 | "grid_auto_columns": null,
525 | "grid_auto_flow": null,
526 | "grid_auto_rows": null,
527 | "grid_column": null,
528 | "grid_gap": null,
529 | "grid_row": null,
530 | "grid_template_areas": null,
531 | "grid_template_columns": null,
532 | "grid_template_rows": null,
533 | "height": null,
534 | "justify_content": null,
535 | "justify_items": null,
536 | "left": null,
537 | "margin": null,
538 | "max_height": null,
539 | "max_width": null,
540 | "min_height": null,
541 | "min_width": null,
542 | "object_fit": null,
543 | "object_position": null,
544 | "order": null,
545 | "overflow": null,
546 | "overflow_x": null,
547 | "overflow_y": null,
548 | "padding": null,
549 | "right": null,
550 | "top": null,
551 | "visibility": null,
552 | "width": null
553 | }
554 | },
555 | "7d6defe1998a4083bbf9f4fdb4e5d599": {
556 | "model_module": "@jupyter-widgets/controls",
557 | "model_name": "DescriptionStyleModel",
558 | "model_module_version": "1.5.0",
559 | "state": {
560 | "_model_module": "@jupyter-widgets/controls",
561 | "_model_module_version": "1.5.0",
562 | "_model_name": "DescriptionStyleModel",
563 | "_view_count": null,
564 | "_view_module": "@jupyter-widgets/base",
565 | "_view_module_version": "1.2.0",
566 | "_view_name": "StyleView",
567 | "description_width": ""
568 | }
569 | },
570 | "5a2ad31325dc45d0b795eb88daf147df": {
571 | "model_module": "@jupyter-widgets/base",
572 | "model_name": "LayoutModel",
573 | "model_module_version": "1.2.0",
574 | "state": {
575 | "_model_module": "@jupyter-widgets/base",
576 | "_model_module_version": "1.2.0",
577 | "_model_name": "LayoutModel",
578 | "_view_count": null,
579 | "_view_module": "@jupyter-widgets/base",
580 | "_view_module_version": "1.2.0",
581 | "_view_name": "LayoutView",
582 | "align_content": null,
583 | "align_items": null,
584 | "align_self": null,
585 | "border": null,
586 | "bottom": null,
587 | "display": null,
588 | "flex": null,
589 | "flex_flow": null,
590 | "grid_area": null,
591 | "grid_auto_columns": null,
592 | "grid_auto_flow": null,
593 | "grid_auto_rows": null,
594 | "grid_column": null,
595 | "grid_gap": null,
596 | "grid_row": null,
597 | "grid_template_areas": null,
598 | "grid_template_columns": null,
599 | "grid_template_rows": null,
600 | "height": null,
601 | "justify_content": null,
602 | "justify_items": null,
603 | "left": null,
604 | "margin": null,
605 | "max_height": null,
606 | "max_width": null,
607 | "min_height": null,
608 | "min_width": null,
609 | "object_fit": null,
610 | "object_position": null,
611 | "order": null,
612 | "overflow": null,
613 | "overflow_x": null,
614 | "overflow_y": null,
615 | "padding": null,
616 | "right": null,
617 | "top": null,
618 | "visibility": null,
619 | "width": null
620 | }
621 | },
622 | "c981017d7c39426dbaf566505c817bf4": {
623 | "model_module": "@jupyter-widgets/controls",
624 | "model_name": "ProgressStyleModel",
625 | "model_module_version": "1.5.0",
626 | "state": {
627 | "_model_module": "@jupyter-widgets/controls",
628 | "_model_module_version": "1.5.0",
629 | "_model_name": "ProgressStyleModel",
630 | "_view_count": null,
631 | "_view_module": "@jupyter-widgets/base",
632 | "_view_module_version": "1.2.0",
633 | "_view_name": "StyleView",
634 | "bar_color": null,
635 | "description_width": ""
636 | }
637 | },
638 | "2ca3d70b283a4eaba7c683d03a53da07": {
639 | "model_module": "@jupyter-widgets/base",
640 | "model_name": "LayoutModel",
641 | "model_module_version": "1.2.0",
642 | "state": {
643 | "_model_module": "@jupyter-widgets/base",
644 | "_model_module_version": "1.2.0",
645 | "_model_name": "LayoutModel",
646 | "_view_count": null,
647 | "_view_module": "@jupyter-widgets/base",
648 | "_view_module_version": "1.2.0",
649 | "_view_name": "LayoutView",
650 | "align_content": null,
651 | "align_items": null,
652 | "align_self": null,
653 | "border": null,
654 | "bottom": null,
655 | "display": null,
656 | "flex": null,
657 | "flex_flow": null,
658 | "grid_area": null,
659 | "grid_auto_columns": null,
660 | "grid_auto_flow": null,
661 | "grid_auto_rows": null,
662 | "grid_column": null,
663 | "grid_gap": null,
664 | "grid_row": null,
665 | "grid_template_areas": null,
666 | "grid_template_columns": null,
667 | "grid_template_rows": null,
668 | "height": null,
669 | "justify_content": null,
670 | "justify_items": null,
671 | "left": null,
672 | "margin": null,
673 | "max_height": null,
674 | "max_width": null,
675 | "min_height": null,
676 | "min_width": null,
677 | "object_fit": null,
678 | "object_position": null,
679 | "order": null,
680 | "overflow": null,
681 | "overflow_x": null,
682 | "overflow_y": null,
683 | "padding": null,
684 | "right": null,
685 | "top": null,
686 | "visibility": null,
687 | "width": null
688 | }
689 | },
690 | "51922432c661481da1ebe8192367c4dd": {
691 | "model_module": "@jupyter-widgets/controls",
692 | "model_name": "DescriptionStyleModel",
693 | "model_module_version": "1.5.0",
694 | "state": {
695 | "_model_module": "@jupyter-widgets/controls",
696 | "_model_module_version": "1.5.0",
697 | "_model_name": "DescriptionStyleModel",
698 | "_view_count": null,
699 | "_view_module": "@jupyter-widgets/base",
700 | "_view_module_version": "1.2.0",
701 | "_view_name": "StyleView",
702 | "description_width": ""
703 | }
704 | },
705 | "61d30891744b4dcda25e8eeb71b54aad": {
706 | "model_module": "@jupyter-widgets/controls",
707 | "model_name": "HBoxModel",
708 | "model_module_version": "1.5.0",
709 | "state": {
710 | "_dom_classes": [],
711 | "_model_module": "@jupyter-widgets/controls",
712 | "_model_module_version": "1.5.0",
713 | "_model_name": "HBoxModel",
714 | "_view_count": null,
715 | "_view_module": "@jupyter-widgets/controls",
716 | "_view_module_version": "1.5.0",
717 | "_view_name": "HBoxView",
718 | "box_style": "",
719 | "children": [
720 | "IPY_MODEL_718e38085b3b4f5eab8c481913bad4fb",
721 | "IPY_MODEL_3346c18e5aae424a9705689c62854830",
722 | "IPY_MODEL_1a782ba0602d4f7ebb0659295f7497f0"
723 | ],
724 | "layout": "IPY_MODEL_80ad93a8096a41f6ba7035bb483ee461"
725 | }
726 | },
727 | "718e38085b3b4f5eab8c481913bad4fb": {
728 | "model_module": "@jupyter-widgets/controls",
729 | "model_name": "HTMLModel",
730 | "model_module_version": "1.5.0",
731 | "state": {
732 | "_dom_classes": [],
733 | "_model_module": "@jupyter-widgets/controls",
734 | "_model_module_version": "1.5.0",
735 | "_model_name": "HTMLModel",
736 | "_view_count": null,
737 | "_view_module": "@jupyter-widgets/controls",
738 | "_view_module_version": "1.5.0",
739 | "_view_name": "HTMLView",
740 | "description": "",
741 | "description_tooltip": null,
742 | "layout": "IPY_MODEL_bf1b70fc39db4494aee42122b14a0c8a",
743 | "placeholder": "",
744 | "style": "IPY_MODEL_9cfbaf78195a4409aa3c751fef6377e1",
745 | "value": "100%"
746 | }
747 | },
748 | "3346c18e5aae424a9705689c62854830": {
749 | "model_module": "@jupyter-widgets/controls",
750 | "model_name": "FloatProgressModel",
751 | "model_module_version": "1.5.0",
752 | "state": {
753 | "_dom_classes": [],
754 | "_model_module": "@jupyter-widgets/controls",
755 | "_model_module_version": "1.5.0",
756 | "_model_name": "FloatProgressModel",
757 | "_view_count": null,
758 | "_view_module": "@jupyter-widgets/controls",
759 | "_view_module_version": "1.5.0",
760 | "_view_name": "ProgressView",
761 | "bar_style": "success",
762 | "description": "",
763 | "description_tooltip": null,
764 | "layout": "IPY_MODEL_c8379e3f85c74bb9889bf49229de82b3",
765 | "max": 938,
766 | "min": 0,
767 | "orientation": "horizontal",
768 | "style": "IPY_MODEL_75125bd2901045ffa9c9b7d3583ee94c",
769 | "value": 938
770 | }
771 | },
772 | "1a782ba0602d4f7ebb0659295f7497f0": {
773 | "model_module": "@jupyter-widgets/controls",
774 | "model_name": "HTMLModel",
775 | "model_module_version": "1.5.0",
776 | "state": {
777 | "_dom_classes": [],
778 | "_model_module": "@jupyter-widgets/controls",
779 | "_model_module_version": "1.5.0",
780 | "_model_name": "HTMLModel",
781 | "_view_count": null,
782 | "_view_module": "@jupyter-widgets/controls",
783 | "_view_module_version": "1.5.0",
784 | "_view_name": "HTMLView",
785 | "description": "",
786 | "description_tooltip": null,
787 | "layout": "IPY_MODEL_78a5fa8234ac4ff182f882908461559a",
788 | "placeholder": "",
789 | "style": "IPY_MODEL_28abc7c5f1914aa0a97fd7d280d75ee0",
790 | "value": " 938/938 [03:40<00:00, 4.95it/s]"
791 | }
792 | },
793 | "80ad93a8096a41f6ba7035bb483ee461": {
794 | "model_module": "@jupyter-widgets/base",
795 | "model_name": "LayoutModel",
796 | "model_module_version": "1.2.0",
797 | "state": {
798 | "_model_module": "@jupyter-widgets/base",
799 | "_model_module_version": "1.2.0",
800 | "_model_name": "LayoutModel",
801 | "_view_count": null,
802 | "_view_module": "@jupyter-widgets/base",
803 | "_view_module_version": "1.2.0",
804 | "_view_name": "LayoutView",
805 | "align_content": null,
806 | "align_items": null,
807 | "align_self": null,
808 | "border": null,
809 | "bottom": null,
810 | "display": null,
811 | "flex": null,
812 | "flex_flow": null,
813 | "grid_area": null,
814 | "grid_auto_columns": null,
815 | "grid_auto_flow": null,
816 | "grid_auto_rows": null,
817 | "grid_column": null,
818 | "grid_gap": null,
819 | "grid_row": null,
820 | "grid_template_areas": null,
821 | "grid_template_columns": null,
822 | "grid_template_rows": null,
823 | "height": null,
824 | "justify_content": null,
825 | "justify_items": null,
826 | "left": null,
827 | "margin": null,
828 | "max_height": null,
829 | "max_width": null,
830 | "min_height": null,
831 | "min_width": null,
832 | "object_fit": null,
833 | "object_position": null,
834 | "order": null,
835 | "overflow": null,
836 | "overflow_x": null,
837 | "overflow_y": null,
838 | "padding": null,
839 | "right": null,
840 | "top": null,
841 | "visibility": null,
842 | "width": null
843 | }
844 | },
845 | "bf1b70fc39db4494aee42122b14a0c8a": {
846 | "model_module": "@jupyter-widgets/base",
847 | "model_name": "LayoutModel",
848 | "model_module_version": "1.2.0",
849 | "state": {
850 | "_model_module": "@jupyter-widgets/base",
851 | "_model_module_version": "1.2.0",
852 | "_model_name": "LayoutModel",
853 | "_view_count": null,
854 | "_view_module": "@jupyter-widgets/base",
855 | "_view_module_version": "1.2.0",
856 | "_view_name": "LayoutView",
857 | "align_content": null,
858 | "align_items": null,
859 | "align_self": null,
860 | "border": null,
861 | "bottom": null,
862 | "display": null,
863 | "flex": null,
864 | "flex_flow": null,
865 | "grid_area": null,
866 | "grid_auto_columns": null,
867 | "grid_auto_flow": null,
868 | "grid_auto_rows": null,
869 | "grid_column": null,
870 | "grid_gap": null,
871 | "grid_row": null,
872 | "grid_template_areas": null,
873 | "grid_template_columns": null,
874 | "grid_template_rows": null,
875 | "height": null,
876 | "justify_content": null,
877 | "justify_items": null,
878 | "left": null,
879 | "margin": null,
880 | "max_height": null,
881 | "max_width": null,
882 | "min_height": null,
883 | "min_width": null,
884 | "object_fit": null,
885 | "object_position": null,
886 | "order": null,
887 | "overflow": null,
888 | "overflow_x": null,
889 | "overflow_y": null,
890 | "padding": null,
891 | "right": null,
892 | "top": null,
893 | "visibility": null,
894 | "width": null
895 | }
896 | },
897 | "9cfbaf78195a4409aa3c751fef6377e1": {
898 | "model_module": "@jupyter-widgets/controls",
899 | "model_name": "DescriptionStyleModel",
900 | "model_module_version": "1.5.0",
901 | "state": {
902 | "_model_module": "@jupyter-widgets/controls",
903 | "_model_module_version": "1.5.0",
904 | "_model_name": "DescriptionStyleModel",
905 | "_view_count": null,
906 | "_view_module": "@jupyter-widgets/base",
907 | "_view_module_version": "1.2.0",
908 | "_view_name": "StyleView",
909 | "description_width": ""
910 | }
911 | },
912 | "c8379e3f85c74bb9889bf49229de82b3": {
913 | "model_module": "@jupyter-widgets/base",
914 | "model_name": "LayoutModel",
915 | "model_module_version": "1.2.0",
916 | "state": {
917 | "_model_module": "@jupyter-widgets/base",
918 | "_model_module_version": "1.2.0",
919 | "_model_name": "LayoutModel",
920 | "_view_count": null,
921 | "_view_module": "@jupyter-widgets/base",
922 | "_view_module_version": "1.2.0",
923 | "_view_name": "LayoutView",
924 | "align_content": null,
925 | "align_items": null,
926 | "align_self": null,
927 | "border": null,
928 | "bottom": null,
929 | "display": null,
930 | "flex": null,
931 | "flex_flow": null,
932 | "grid_area": null,
933 | "grid_auto_columns": null,
934 | "grid_auto_flow": null,
935 | "grid_auto_rows": null,
936 | "grid_column": null,
937 | "grid_gap": null,
938 | "grid_row": null,
939 | "grid_template_areas": null,
940 | "grid_template_columns": null,
941 | "grid_template_rows": null,
942 | "height": null,
943 | "justify_content": null,
944 | "justify_items": null,
945 | "left": null,
946 | "margin": null,
947 | "max_height": null,
948 | "max_width": null,
949 | "min_height": null,
950 | "min_width": null,
951 | "object_fit": null,
952 | "object_position": null,
953 | "order": null,
954 | "overflow": null,
955 | "overflow_x": null,
956 | "overflow_y": null,
957 | "padding": null,
958 | "right": null,
959 | "top": null,
960 | "visibility": null,
961 | "width": null
962 | }
963 | },
964 | "75125bd2901045ffa9c9b7d3583ee94c": {
965 | "model_module": "@jupyter-widgets/controls",
966 | "model_name": "ProgressStyleModel",
967 | "model_module_version": "1.5.0",
968 | "state": {
969 | "_model_module": "@jupyter-widgets/controls",
970 | "_model_module_version": "1.5.0",
971 | "_model_name": "ProgressStyleModel",
972 | "_view_count": null,
973 | "_view_module": "@jupyter-widgets/base",
974 | "_view_module_version": "1.2.0",
975 | "_view_name": "StyleView",
976 | "bar_color": null,
977 | "description_width": ""
978 | }
979 | },
980 | "78a5fa8234ac4ff182f882908461559a": {
981 | "model_module": "@jupyter-widgets/base",
982 | "model_name": "LayoutModel",
983 | "model_module_version": "1.2.0",
984 | "state": {
985 | "_model_module": "@jupyter-widgets/base",
986 | "_model_module_version": "1.2.0",
987 | "_model_name": "LayoutModel",
988 | "_view_count": null,
989 | "_view_module": "@jupyter-widgets/base",
990 | "_view_module_version": "1.2.0",
991 | "_view_name": "LayoutView",
992 | "align_content": null,
993 | "align_items": null,
994 | "align_self": null,
995 | "border": null,
996 | "bottom": null,
997 | "display": null,
998 | "flex": null,
999 | "flex_flow": null,
1000 | "grid_area": null,
1001 | "grid_auto_columns": null,
1002 | "grid_auto_flow": null,
1003 | "grid_auto_rows": null,
1004 | "grid_column": null,
1005 | "grid_gap": null,
1006 | "grid_row": null,
1007 | "grid_template_areas": null,
1008 | "grid_template_columns": null,
1009 | "grid_template_rows": null,
1010 | "height": null,
1011 | "justify_content": null,
1012 | "justify_items": null,
1013 | "left": null,
1014 | "margin": null,
1015 | "max_height": null,
1016 | "max_width": null,
1017 | "min_height": null,
1018 | "min_width": null,
1019 | "object_fit": null,
1020 | "object_position": null,
1021 | "order": null,
1022 | "overflow": null,
1023 | "overflow_x": null,
1024 | "overflow_y": null,
1025 | "padding": null,
1026 | "right": null,
1027 | "top": null,
1028 | "visibility": null,
1029 | "width": null
1030 | }
1031 | },
1032 | "28abc7c5f1914aa0a97fd7d280d75ee0": {
1033 | "model_module": "@jupyter-widgets/controls",
1034 | "model_name": "DescriptionStyleModel",
1035 | "model_module_version": "1.5.0",
1036 | "state": {
1037 | "_model_module": "@jupyter-widgets/controls",
1038 | "_model_module_version": "1.5.0",
1039 | "_model_name": "DescriptionStyleModel",
1040 | "_view_count": null,
1041 | "_view_module": "@jupyter-widgets/base",
1042 | "_view_module_version": "1.2.0",
1043 | "_view_name": "StyleView",
1044 | "description_width": ""
1045 | }
1046 | },
1047 | "c9305060d9e144a095a0d93b63355e0b": {
1048 | "model_module": "@jupyter-widgets/controls",
1049 | "model_name": "HBoxModel",
1050 | "model_module_version": "1.5.0",
1051 | "state": {
1052 | "_dom_classes": [],
1053 | "_model_module": "@jupyter-widgets/controls",
1054 | "_model_module_version": "1.5.0",
1055 | "_model_name": "HBoxModel",
1056 | "_view_count": null,
1057 | "_view_module": "@jupyter-widgets/controls",
1058 | "_view_module_version": "1.5.0",
1059 | "_view_name": "HBoxView",
1060 | "box_style": "",
1061 | "children": [
1062 | "IPY_MODEL_2afc450d38f948ad89af08983065e9e2",
1063 | "IPY_MODEL_6dfbbde3bbe84ab9803ae729c6697aae",
1064 | "IPY_MODEL_58a1dace4c7b4ad98419fb961b920cc0"
1065 | ],
1066 | "layout": "IPY_MODEL_68db21ef3c204407ad9fa181fea152ae"
1067 | }
1068 | },
1069 | "2afc450d38f948ad89af08983065e9e2": {
1070 | "model_module": "@jupyter-widgets/controls",
1071 | "model_name": "HTMLModel",
1072 | "model_module_version": "1.5.0",
1073 | "state": {
1074 | "_dom_classes": [],
1075 | "_model_module": "@jupyter-widgets/controls",
1076 | "_model_module_version": "1.5.0",
1077 | "_model_name": "HTMLModel",
1078 | "_view_count": null,
1079 | "_view_module": "@jupyter-widgets/controls",
1080 | "_view_module_version": "1.5.0",
1081 | "_view_name": "HTMLView",
1082 | "description": "",
1083 | "description_tooltip": null,
1084 | "layout": "IPY_MODEL_b01677a02efb420cb32c0957d77a2345",
1085 | "placeholder": "",
1086 | "style": "IPY_MODEL_f8cfcf541c684ad889fda00b140e3ab1",
1087 | "value": " 2%"
1088 | }
1089 | },
1090 | "6dfbbde3bbe84ab9803ae729c6697aae": {
1091 | "model_module": "@jupyter-widgets/controls",
1092 | "model_name": "FloatProgressModel",
1093 | "model_module_version": "1.5.0",
1094 | "state": {
1095 | "_dom_classes": [],
1096 | "_model_module": "@jupyter-widgets/controls",
1097 | "_model_module_version": "1.5.0",
1098 | "_model_name": "FloatProgressModel",
1099 | "_view_count": null,
1100 | "_view_module": "@jupyter-widgets/controls",
1101 | "_view_module_version": "1.5.0",
1102 | "_view_name": "ProgressView",
1103 | "bar_style": "danger",
1104 | "description": "",
1105 | "description_tooltip": null,
1106 | "layout": "IPY_MODEL_de9dcf17fb4a4858a5b68922376272e2",
1107 | "max": 938,
1108 | "min": 0,
1109 | "orientation": "horizontal",
1110 | "style": "IPY_MODEL_a32ba491756d40f49fe5b7e795bd3f21",
1111 | "value": 22
1112 | }
1113 | },
1114 | "58a1dace4c7b4ad98419fb961b920cc0": {
1115 | "model_module": "@jupyter-widgets/controls",
1116 | "model_name": "HTMLModel",
1117 | "model_module_version": "1.5.0",
1118 | "state": {
1119 | "_dom_classes": [],
1120 | "_model_module": "@jupyter-widgets/controls",
1121 | "_model_module_version": "1.5.0",
1122 | "_model_name": "HTMLModel",
1123 | "_view_count": null,
1124 | "_view_module": "@jupyter-widgets/controls",
1125 | "_view_module_version": "1.5.0",
1126 | "_view_name": "HTMLView",
1127 | "description": "",
1128 | "description_tooltip": null,
1129 | "layout": "IPY_MODEL_1354164d0f164f7dac7a26e7fa86ca61",
1130 | "placeholder": "",
1131 | "style": "IPY_MODEL_02c67b9bb3904e71a68168c23418fa1c",
1132 | "value": " 22/938 [00:05<03:39, 4.17it/s]"
1133 | }
1134 | },
1135 | "68db21ef3c204407ad9fa181fea152ae": {
1136 | "model_module": "@jupyter-widgets/base",
1137 | "model_name": "LayoutModel",
1138 | "model_module_version": "1.2.0",
1139 | "state": {
1140 | "_model_module": "@jupyter-widgets/base",
1141 | "_model_module_version": "1.2.0",
1142 | "_model_name": "LayoutModel",
1143 | "_view_count": null,
1144 | "_view_module": "@jupyter-widgets/base",
1145 | "_view_module_version": "1.2.0",
1146 | "_view_name": "LayoutView",
1147 | "align_content": null,
1148 | "align_items": null,
1149 | "align_self": null,
1150 | "border": null,
1151 | "bottom": null,
1152 | "display": null,
1153 | "flex": null,
1154 | "flex_flow": null,
1155 | "grid_area": null,
1156 | "grid_auto_columns": null,
1157 | "grid_auto_flow": null,
1158 | "grid_auto_rows": null,
1159 | "grid_column": null,
1160 | "grid_gap": null,
1161 | "grid_row": null,
1162 | "grid_template_areas": null,
1163 | "grid_template_columns": null,
1164 | "grid_template_rows": null,
1165 | "height": null,
1166 | "justify_content": null,
1167 | "justify_items": null,
1168 | "left": null,
1169 | "margin": null,
1170 | "max_height": null,
1171 | "max_width": null,
1172 | "min_height": null,
1173 | "min_width": null,
1174 | "object_fit": null,
1175 | "object_position": null,
1176 | "order": null,
1177 | "overflow": null,
1178 | "overflow_x": null,
1179 | "overflow_y": null,
1180 | "padding": null,
1181 | "right": null,
1182 | "top": null,
1183 | "visibility": null,
1184 | "width": null
1185 | }
1186 | },
1187 | "b01677a02efb420cb32c0957d77a2345": {
1188 | "model_module": "@jupyter-widgets/base",
1189 | "model_name": "LayoutModel",
1190 | "model_module_version": "1.2.0",
1191 | "state": {
1192 | "_model_module": "@jupyter-widgets/base",
1193 | "_model_module_version": "1.2.0",
1194 | "_model_name": "LayoutModel",
1195 | "_view_count": null,
1196 | "_view_module": "@jupyter-widgets/base",
1197 | "_view_module_version": "1.2.0",
1198 | "_view_name": "LayoutView",
1199 | "align_content": null,
1200 | "align_items": null,
1201 | "align_self": null,
1202 | "border": null,
1203 | "bottom": null,
1204 | "display": null,
1205 | "flex": null,
1206 | "flex_flow": null,
1207 | "grid_area": null,
1208 | "grid_auto_columns": null,
1209 | "grid_auto_flow": null,
1210 | "grid_auto_rows": null,
1211 | "grid_column": null,
1212 | "grid_gap": null,
1213 | "grid_row": null,
1214 | "grid_template_areas": null,
1215 | "grid_template_columns": null,
1216 | "grid_template_rows": null,
1217 | "height": null,
1218 | "justify_content": null,
1219 | "justify_items": null,
1220 | "left": null,
1221 | "margin": null,
1222 | "max_height": null,
1223 | "max_width": null,
1224 | "min_height": null,
1225 | "min_width": null,
1226 | "object_fit": null,
1227 | "object_position": null,
1228 | "order": null,
1229 | "overflow": null,
1230 | "overflow_x": null,
1231 | "overflow_y": null,
1232 | "padding": null,
1233 | "right": null,
1234 | "top": null,
1235 | "visibility": null,
1236 | "width": null
1237 | }
1238 | },
1239 | "f8cfcf541c684ad889fda00b140e3ab1": {
1240 | "model_module": "@jupyter-widgets/controls",
1241 | "model_name": "DescriptionStyleModel",
1242 | "model_module_version": "1.5.0",
1243 | "state": {
1244 | "_model_module": "@jupyter-widgets/controls",
1245 | "_model_module_version": "1.5.0",
1246 | "_model_name": "DescriptionStyleModel",
1247 | "_view_count": null,
1248 | "_view_module": "@jupyter-widgets/base",
1249 | "_view_module_version": "1.2.0",
1250 | "_view_name": "StyleView",
1251 | "description_width": ""
1252 | }
1253 | },
1254 | "de9dcf17fb4a4858a5b68922376272e2": {
1255 | "model_module": "@jupyter-widgets/base",
1256 | "model_name": "LayoutModel",
1257 | "model_module_version": "1.2.0",
1258 | "state": {
1259 | "_model_module": "@jupyter-widgets/base",
1260 | "_model_module_version": "1.2.0",
1261 | "_model_name": "LayoutModel",
1262 | "_view_count": null,
1263 | "_view_module": "@jupyter-widgets/base",
1264 | "_view_module_version": "1.2.0",
1265 | "_view_name": "LayoutView",
1266 | "align_content": null,
1267 | "align_items": null,
1268 | "align_self": null,
1269 | "border": null,
1270 | "bottom": null,
1271 | "display": null,
1272 | "flex": null,
1273 | "flex_flow": null,
1274 | "grid_area": null,
1275 | "grid_auto_columns": null,
1276 | "grid_auto_flow": null,
1277 | "grid_auto_rows": null,
1278 | "grid_column": null,
1279 | "grid_gap": null,
1280 | "grid_row": null,
1281 | "grid_template_areas": null,
1282 | "grid_template_columns": null,
1283 | "grid_template_rows": null,
1284 | "height": null,
1285 | "justify_content": null,
1286 | "justify_items": null,
1287 | "left": null,
1288 | "margin": null,
1289 | "max_height": null,
1290 | "max_width": null,
1291 | "min_height": null,
1292 | "min_width": null,
1293 | "object_fit": null,
1294 | "object_position": null,
1295 | "order": null,
1296 | "overflow": null,
1297 | "overflow_x": null,
1298 | "overflow_y": null,
1299 | "padding": null,
1300 | "right": null,
1301 | "top": null,
1302 | "visibility": null,
1303 | "width": null
1304 | }
1305 | },
1306 | "a32ba491756d40f49fe5b7e795bd3f21": {
1307 | "model_module": "@jupyter-widgets/controls",
1308 | "model_name": "ProgressStyleModel",
1309 | "model_module_version": "1.5.0",
1310 | "state": {
1311 | "_model_module": "@jupyter-widgets/controls",
1312 | "_model_module_version": "1.5.0",
1313 | "_model_name": "ProgressStyleModel",
1314 | "_view_count": null,
1315 | "_view_module": "@jupyter-widgets/base",
1316 | "_view_module_version": "1.2.0",
1317 | "_view_name": "StyleView",
1318 | "bar_color": null,
1319 | "description_width": ""
1320 | }
1321 | },
1322 | "1354164d0f164f7dac7a26e7fa86ca61": {
1323 | "model_module": "@jupyter-widgets/base",
1324 | "model_name": "LayoutModel",
1325 | "model_module_version": "1.2.0",
1326 | "state": {
1327 | "_model_module": "@jupyter-widgets/base",
1328 | "_model_module_version": "1.2.0",
1329 | "_model_name": "LayoutModel",
1330 | "_view_count": null,
1331 | "_view_module": "@jupyter-widgets/base",
1332 | "_view_module_version": "1.2.0",
1333 | "_view_name": "LayoutView",
1334 | "align_content": null,
1335 | "align_items": null,
1336 | "align_self": null,
1337 | "border": null,
1338 | "bottom": null,
1339 | "display": null,
1340 | "flex": null,
1341 | "flex_flow": null,
1342 | "grid_area": null,
1343 | "grid_auto_columns": null,
1344 | "grid_auto_flow": null,
1345 | "grid_auto_rows": null,
1346 | "grid_column": null,
1347 | "grid_gap": null,
1348 | "grid_row": null,
1349 | "grid_template_areas": null,
1350 | "grid_template_columns": null,
1351 | "grid_template_rows": null,
1352 | "height": null,
1353 | "justify_content": null,
1354 | "justify_items": null,
1355 | "left": null,
1356 | "margin": null,
1357 | "max_height": null,
1358 | "max_width": null,
1359 | "min_height": null,
1360 | "min_width": null,
1361 | "object_fit": null,
1362 | "object_position": null,
1363 | "order": null,
1364 | "overflow": null,
1365 | "overflow_x": null,
1366 | "overflow_y": null,
1367 | "padding": null,
1368 | "right": null,
1369 | "top": null,
1370 | "visibility": null,
1371 | "width": null
1372 | }
1373 | },
1374 | "02c67b9bb3904e71a68168c23418fa1c": {
1375 | "model_module": "@jupyter-widgets/controls",
1376 | "model_name": "DescriptionStyleModel",
1377 | "model_module_version": "1.5.0",
1378 | "state": {
1379 | "_model_module": "@jupyter-widgets/controls",
1380 | "_model_module_version": "1.5.0",
1381 | "_model_name": "DescriptionStyleModel",
1382 | "_view_count": null,
1383 | "_view_module": "@jupyter-widgets/base",
1384 | "_view_module_version": "1.2.0",
1385 | "_view_name": "StyleView",
1386 | "description_width": ""
1387 | }
1388 | }
1389 | }
1390 | }
1391 | },
1392 | "cells": [
1393 | {
1394 | "cell_type": "markdown",
1395 | "metadata": {
1396 | "id": "view-in-github",
1397 | "colab_type": "text"
1398 | },
1399 | "source": [
1400 | "<img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>"
1401 | ]
1402 | },
1403 | {
1404 | "cell_type": "code",
1405 | "execution_count": 1,
1406 | "metadata": {
1407 | "id": "Ee1NBwatIJlR"
1408 | },
1409 | "outputs": [],
1410 | "source": [
1411 | "import torch\n",
1412 | "import torch.nn as nn\n",
1413 | "import torch.nn.functional as F\n",
1414 | "import torch.optim as optim\n",
1415 | "\n",
1416 | "from torch.utils.data import Dataset, DataLoader, ConcatDataset\n",
1417 | "\n",
1418 | "import torchvision as tv\n",
1419 | "import torchvision.transforms as T\n",
1420 | "\n",
1421 | "from PIL import Image\n",
1422 | "\n",
1423 | "import numpy as np\n",
1424 | "import matplotlib.pyplot as plt\n",
1425 | "\n",
1426 | "from tqdm.notebook import tqdm"
1427 | ]
1428 | },
1429 | {
1430 | "cell_type": "code",
1431 | "source": [
1432 | "START = 1e-4\n",
1433 | "END = .02\n",
1434 | "TIMESTEPS = 300\n",
1435 | "\n",
1436 | "IMAGE_SIZE = 64\n",
1437 | "BATCH_SIZE = 64\n",
1438 | "EPOCHS = 5\n",
1439 | "\n",
1440 | "LR = 1e-3\n",
1441 | "\n",
1442 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'"
1443 | ],
1444 | "metadata": {
1445 | "id": "nzHOnslOMzvU"
1446 | },
1447 | "execution_count": 2,
1448 | "outputs": []
1449 | },
1450 | {
1451 | "cell_type": "code",
1452 | "source": [
1453 | "betas = torch.linspace(start=START, end=END, steps=TIMESTEPS)\n",
1454 | "alphas = 1 - betas\n",
1455 | "\n",
1456 | "alpha_bars = torch.cumprod(alphas, dim=0)\n",
1457 | "alpha_bars_prev = F.pad(alpha_bars[:-1], (1, 0), value=1.0)\n",
1458 | "sqrt_one_over_alpha_bars = torch.sqrt(1. / alpha_bars)\n",
1459 | "sqrt_alpha_bars = torch.sqrt(alpha_bars)\n",
1460 | "sqrt_one_minus_alpha_bars = torch.sqrt(1 - alpha_bars)\n",
1461 | "\n",
1462 | "posterior_variance = betas * (1. - alpha_bars_prev) / (1. - alpha_bars)"
1463 | ],
1464 | "metadata": {
1465 | "id": "_zBev601NSRS"
1466 | },
1467 | "execution_count": 3,
1468 | "outputs": []
1469 | },
1470 | {
1471 | "cell_type": "code",
1472 | "source": [
1473 | "class ConvBlock(nn.Module):\n",
1474 | " def __init__(self, in_channels, out_channels):\n",
1475 | " super().__init__()\n",
1476 | "\n",
1477 | " self.conv_blocks = nn.Sequential(\n",
1478 | " nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),\n",
1479 | " nn.ReLU(inplace=True),\n",
1480 | " nn.BatchNorm2d(out_channels),\n",
1481 | " nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),\n",
1482 | " nn.ReLU(inplace=True),\n",
1483 | " nn.BatchNorm2d(out_channels),\n",
1484 | " )\n",
1485 | "\n",
1486 | " def forward(self, x):\n",
1487 | " return self.conv_blocks(x)"
1488 | ],
1489 | "metadata": {
1490 | "id": "GMK460xiUdg7"
1491 | },
1492 | "execution_count": 4,
1493 | "outputs": []
1494 | },
1495 | {
1496 | "cell_type": "code",
1497 | "source": [
1498 | "class DownBlock(nn.Module):\n",
1499 | " def __init__(self, in_channels, out_channels):\n",
1500 | " super().__init__()\n",
1501 | "\n",
1502 | " self.conv_blocks = nn.Sequential(\n",
1503 | " nn.MaxPool2d(2),\n",
1504 | " ConvBlock(in_channels, out_channels)\n",
1505 | " )\n",
1506 | "\n",
1507 | " def forward(self, x):\n",
1508 | " return self.conv_blocks(x)"
1509 | ],
1510 | "metadata": {
1511 | "id": "DYvOCCVRUtNS"
1512 | },
1513 | "execution_count": 5,
1514 | "outputs": []
1515 | },
1516 | {
1517 | "cell_type": "code",
1518 | "source": [
1519 | "class UpBlock(nn.Module):\n",
1520 | " def __init__(self, in_channels, out_channels):\n",
1521 | " super().__init__()\n",
1522 | "\n",
1523 | " self.up = nn.ConvTranspose2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=2, stride=2)\n",
1524 | " self.conv_blocks = ConvBlock(in_channels, out_channels)\n",
1525 | "\n",
1526 | " def forward(self, x, residual_inputs):\n",
1527 | " x = self.up(x)\n",
1528 | "\n",
1529 | " diff_y = residual_inputs.size()[2] - x.size()[2]\n",
1530 | " diff_x = residual_inputs.size()[3] - x.size()[3]\n",
1531 | "\n",
1532 | " x = F.pad(x, [diff_x // 2, diff_x - diff_x // 2,\n",
1533 | " diff_y // 2, diff_y - diff_y // 2])\n",
1534 | "\n",
1535 | " x = torch.cat([residual_inputs, x], dim=1)\n",
1536 | " x = self.conv_blocks(x)\n",
1537 | "\n",
1538 | " return x"
1539 | ],
1540 | "metadata": {
1541 | "id": "3QFPrpawUtP6"
1542 | },
1543 | "execution_count": 6,
1544 | "outputs": []
1545 | },
1546 | {
1547 | "cell_type": "code",
1548 | "source": [
1549 | "class OutBlock(nn.Module):\n",
1550 | " def __init__(self, in_channels, num_classes):\n",
1551 | " super().__init__()\n",
1552 | "\n",
1553 | " self.conv = nn.Conv2d(in_channels=in_channels, out_channels=num_classes, kernel_size=1)\n",
1554 | "\n",
1555 | " def forward(self, x):\n",
1556 | " return self.conv(x)"
1557 | ],
1558 | "metadata": {
1559 | "id": "s-V4VnnUUtTM"
1560 | },
1561 | "execution_count": 7,
1562 | "outputs": []
1563 | },
1564 | {
1565 | "cell_type": "code",
1566 | "source": [
1567 | "class UNet(nn.Module):\n",
1568 | " def __init__(self, in_channels, num_classes):\n",
1569 | " super().__init__()\n",
1570 | "\n",
1571 | " self.input_block = ConvBlock(in_channels, 64)\n",
1572 | "\n",
1573 | " self.down_1 = DownBlock(64, 128)\n",
1574 | " self.down_2 = DownBlock(128, 256)\n",
1575 | " self.down_3 = DownBlock(256, 512)\n",
1576 | " self.down_4 = DownBlock(512, 1024)\n",
1577 | "\n",
1578 | " self.up_4 = UpBlock(1024, 512)\n",
1579 | " self.up_3 = UpBlock(512, 256)\n",
1580 | " self.up_2 = UpBlock(256, 128)\n",
1581 | " self.up_1 = UpBlock(128, 64)\n",
1582 | "\n",
1583 | " self.output_block = OutBlock(64, num_classes)\n",
1584 | "\n",
1585 | " self.embedding_up_1 = nn.Linear(1, 128)\n",
1586 | " self.embedding_up_2 = nn.Linear(1, 256)\n",
1587 | " self.embedding_up_3 = nn.Linear(1, 512)\n",
1588 | " self.embedding_up_4 = nn.Linear(1, 1024)\n",
1589 | "\n",
1590 | " def forward(self, x, t):\n",
1591 | " batch_size = x.size(0)\n",
1592 | " down_cache_1 = self.input_block(x)\n",
1593 | "\n",
1594 | " down_cache_2 = self.down_1(down_cache_1)\n",
1595 | " down_cache_3 = self.down_2(down_cache_2)\n",
1596 | " down_cache_4 = self.down_3(down_cache_3)\n",
1597 | " down_cache_5 = self.down_4(down_cache_4)\n",
1598 | "\n",
1599 | " t_embed = self.embedding_up_4(t).view(batch_size, -1, 1, 1)\n",
1600 | " x = self.up_4(down_cache_5 + t_embed, down_cache_4)\n",
1601 | "\n",
1602 | " t_embed = self.embedding_up_3(t).view(batch_size, -1, 1, 1)\n",
1603 | " x = self.up_3(x + t_embed, down_cache_3)\n",
1604 | "\n",
1605 | " t_embed = self.embedding_up_2(t).view(batch_size, -1, 1, 1)\n",
1606 | " x = self.up_2(x + t_embed, down_cache_2)\n",
1607 | "\n",
1608 | " t_embed = self.embedding_up_1(t).view(batch_size, -1, 1, 1)\n",
1609 | " x = self.up_1(x + t_embed, down_cache_1)\n",
1610 | "\n",
1611 | " x = self.output_block(x)\n",
1612 | "\n",
1613 | " return x"
1614 | ],
1615 | "metadata": {
1616 | "id": "nQd2F863UtWj"
1617 | },
1618 | "execution_count": 8,
1619 | "outputs": []
1620 | },
1621 | {
1622 | "cell_type": "code",
1623 | "source": [
1624 | "def get_index_for_batch(values, t, x_shape):\n",
1625 | " batch_size = t.size(0)\n",
1626 | " out = values.gather(-1, t.cpu())\n",
1627 | "\n",
1628 | " return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(DEVICE)"
1629 | ],
1630 | "metadata": {
1631 | "id": "0d_mFHm6N0mp"
1632 | },
1633 | "execution_count": 9,
1634 | "outputs": []
1635 | },
1636 | {
1637 | "cell_type": "code",
1638 | "source": [
1639 | "def forward_process(x_0, t):\n",
1640 | " noise = torch.randn_like(x_0)\n",
1641 | "\n",
1642 | " sqrt_alpha_bars_for_batch = get_index_for_batch(sqrt_alpha_bars, t, x_0.shape).to(DEVICE)\n",
1643 | " sqrt_one_minus_alpha_bars_for_batch = get_index_for_batch(sqrt_one_minus_alpha_bars, t, x_0.shape).to(DEVICE)\n",
1644 | "\n",
1645 | " z = sqrt_one_minus_alpha_bars_for_batch * noise + sqrt_alpha_bars_for_batch * x_0\n",
1646 | "\n",
1647 | " return z, noise"
1648 | ],
1649 | "metadata": {
1650 | "id": "uhOmFQeZN5XY"
1651 | },
1652 | "execution_count": 10,
1653 | "outputs": []
1654 | },
1655 | {
1656 | "cell_type": "code",
1657 | "source": [
1658 | "reverse_transforms = T.Compose([\n",
1659 | " T.Lambda(lambda x: (x + 1) / 2),\n",
1660 | " T.Lambda(lambda x: x.permute(1, 2, 0)),\n",
1661 | " T.Lambda(lambda x: x * 255),\n",
1662 | " T.Lambda(lambda x: x.cpu().numpy().astype(np.uint8)),\n",
1663 | " T.ToPILImage()\n",
1664 | "])\n",
1665 | "\n",
1666 | "def convert_tensor_image(image):\n",
1667 | " if len(image.shape) == 4:\n",
1668 | " image = image[0, :, :, :]\n",
1669 | "\n",
1670 | " image = reverse_transforms(image)\n",
1671 | " return image"
1672 | ],
1673 | "metadata": {
1674 | "id": "-oWTY3w9SxDK"
1675 | },
1676 | "execution_count": 11,
1677 | "outputs": []
1678 | },
1679 | {
1680 | "cell_type": "code",
1681 | "source": [
1682 | "transforms = T.Compose([\n",
1683 | " T.Resize((IMAGE_SIZE, IMAGE_SIZE)),\n",
1684 | " T.RandomHorizontalFlip(),\n",
1685 | " T.ToTensor(),\n",
1686 | " T.Lambda(lambda x: (x * 2) - 1)\n",
1687 | "])\n",
1688 | "\n",
1689 | "train = tv.datasets.CIFAR10(root='./dataset', download=True, transform=transforms, train=True)\n",
1690 | "val = tv.datasets.CIFAR10(root='./dataset', download=True, transform=transforms, train=False)\n",
1691 | "\n",
1692 | "dataset = ConcatDataset([train, val])\n",
1693 | "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)"
1694 | ],
1695 | "metadata": {
1696 | "colab": {
1697 | "base_uri": "https://localhost:8080/"
1698 | },
1699 | "id": "aRMbMaGMRrgS",
1700 | "outputId": "763ce410-8c10-42dd-9468-f922c4055924"
1701 | },
1702 | "execution_count": 12,
1703 | "outputs": [
1704 | {
1705 | "output_type": "stream",
1706 | "name": "stdout",
1707 | "text": [
1708 | "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataset/cifar-10-python.tar.gz\n"
1709 | ]
1710 | },
1711 | {
1712 | "output_type": "stream",
1713 | "name": "stderr",
1714 | "text": [
1715 | "100%|██████████| 170498071/170498071 [00:13<00:00, 12645277.96it/s]\n"
1716 | ]
1717 | },
1718 | {
1719 | "output_type": "stream",
1720 | "name": "stdout",
1721 | "text": [
1722 | "Extracting ./dataset/cifar-10-python.tar.gz to ./dataset\n",
1723 | "Files already downloaded and verified\n"
1724 | ]
1725 | }
1726 | ]
1727 | },
1728 | {
1729 | "cell_type": "code",
1730 | "source": [
1731 | "network = UNet(3, 3).to(DEVICE)"
1732 | ],
1733 | "metadata": {
1734 | "id": "5WW4S17DSgsI"
1735 | },
1736 | "execution_count": 13,
1737 | "outputs": []
1738 | },
1739 | {
1740 | "cell_type": "code",
1741 | "source": [
1742 | "optimizer = optim.Adam(network.parameters(), lr=LR)"
1743 | ],
1744 | "metadata": {
1745 | "id": "DNCJALJwToDx"
1746 | },
1747 | "execution_count": 14,
1748 | "outputs": []
1749 | },
1750 | {
1751 | "cell_type": "code",
1752 | "source": [
1753 | "criterion = nn.MSELoss()"
1754 | ],
1755 | "metadata": {
1756 | "id": "nnYKe7ZwXPth"
1757 | },
1758 | "execution_count": 15,
1759 | "outputs": []
1760 | },
1761 | {
1762 | "cell_type": "code",
1763 | "source": [
1764 | "for epoch in range(1, EPOCHS + 1):\n",
1765 | " print(f'Epoch {epoch} / {EPOCHS}')\n",
1766 | " total_loss = .0\n",
1767 | " for images, _ in tqdm(dataloader):\n",
1768 | " optimizer.zero_grad()\n",
1769 | "\n",
1770 | " images = images.to(DEVICE)\n",
1771 | "\n",
1772 | " batch_size = images.size(0)\n",
1773 | " t = torch.randint(0, TIMESTEPS, (batch_size,)).to(DEVICE).long()\n",
1774 | " noisy_image, noise = forward_process(images, t)\n",
1775 | " noise_preds = network(noisy_image, t.unsqueeze(-1).float())\n",
1776 | "\n",
1777 | " loss = criterion(noise_preds, noise)\n",
1778 | " loss.backward()\n",
1779 | "\n",
1780 | " optimizer.step()\n",
1781 | "\n",
1782 | " total_loss += loss.detach().cpu().item()\n",
1783 | "\n",
1784 | " print(f'Loss: {total_loss:.2f}')"
1785 | ],
1786 | "metadata": {
1787 | "colab": {
1788 | "base_uri": "https://localhost:8080/",
1789 | "height": 474,
1790 | "referenced_widgets": [
1791 | "b0fa66fc56e34465bb2a59f9ce916b71",
1792 | "14019560a279458cb775634b162b2407",
1793 | "4db4b93e47bd4c859b1d6685ed8cf76e",
1794 | "2ef34af8de2c4c7eab85fd3c83a5e1c5",
1795 | "1638fe599f2c409c8d75396bbad174b7",
1796 | "0a2010f11ba345a385997b328168127a",
1797 | "5a39152a73cb4a87bd5bdf6f5a6888cc",
1798 | "b72fab00156243478d0619f3f8a56a1a",
1799 | "38548a20e7744f098a9118b82ff13f95",
1800 | "47d16d89d79543ab8b893ccae490d2f0",
1801 | "05f22484b66641818d3288792ead2efa",
1802 | "7c76076cccf04326a82ec69c5b33c9af",
1803 | "326f80a6a9364bc4a37fd727649c5599",
1804 | "cd209532893744078f13583bc2765ef5",
1805 | "9933cef2fd0c44c2a308f24e5de0b1ea",
1806 | "9d07802c82da4382966ee2bc58bff4e2",
1807 | "7c62a32ae5944b7584b0026533d3bea9",
1808 | "7d6defe1998a4083bbf9f4fdb4e5d599",
1809 | "5a2ad31325dc45d0b795eb88daf147df",
1810 | "c981017d7c39426dbaf566505c817bf4",
1811 | "2ca3d70b283a4eaba7c683d03a53da07",
1812 | "51922432c661481da1ebe8192367c4dd",
1813 | "61d30891744b4dcda25e8eeb71b54aad",
1814 | "718e38085b3b4f5eab8c481913bad4fb",
1815 | "3346c18e5aae424a9705689c62854830",
1816 | "1a782ba0602d4f7ebb0659295f7497f0",
1817 | "80ad93a8096a41f6ba7035bb483ee461",
1818 | "bf1b70fc39db4494aee42122b14a0c8a",
1819 | "9cfbaf78195a4409aa3c751fef6377e1",
1820 | "c8379e3f85c74bb9889bf49229de82b3",
1821 | "75125bd2901045ffa9c9b7d3583ee94c",
1822 | "78a5fa8234ac4ff182f882908461559a",
1823 | "28abc7c5f1914aa0a97fd7d280d75ee0",
1824 | "c9305060d9e144a095a0d93b63355e0b",
1825 | "2afc450d38f948ad89af08983065e9e2",
1826 | "6dfbbde3bbe84ab9803ae729c6697aae",
1827 | "58a1dace4c7b4ad98419fb961b920cc0",
1828 | "68db21ef3c204407ad9fa181fea152ae",
1829 | "b01677a02efb420cb32c0957d77a2345",
1830 | "f8cfcf541c684ad889fda00b140e3ab1",
1831 | "de9dcf17fb4a4858a5b68922376272e2",
1832 | "a32ba491756d40f49fe5b7e795bd3f21",
1833 | "1354164d0f164f7dac7a26e7fa86ca61",
1834 | "02c67b9bb3904e71a68168c23418fa1c"
1835 | ]
1836 | },
1837 | "id": "re16Q5X9Xjij",
1838 | "outputId": "177d142f-c50b-4ced-9ef8-7b0119d95643"
1839 | },
1840 | "execution_count": 20,
1841 | "outputs": [
1842 | {
1843 | "output_type": "stream",
1844 | "name": "stdout",
1845 | "text": [
1846 | "Epoch 1 / 5\n"
1847 | ]
1848 | },
1849 | {
1850 | "output_type": "display_data",
1851 | "data": {
1852 | "text/plain": [
1853 | " 0%| | 0/938 [00:00, ?it/s]"
1854 | ],
1855 | "application/vnd.jupyter.widget-view+json": {
1856 | "version_major": 2,
1857 | "version_minor": 0,
1858 | "model_id": "b0fa66fc56e34465bb2a59f9ce916b71"
1859 | }
1860 | },
1861 | "metadata": {}
1862 | },
1863 | {
1864 | "output_type": "stream",
1865 | "name": "stdout",
1866 | "text": [
1867 | "Loss: 135.68\n",
1868 | "Epoch 2 / 5\n"
1869 | ]
1870 | },
1871 | {
1872 | "output_type": "display_data",
1873 | "data": {
1874 | "text/plain": [
1875 | " 0%| | 0/938 [00:00, ?it/s]"
1876 | ],
1877 | "application/vnd.jupyter.widget-view+json": {
1878 | "version_major": 2,
1879 | "version_minor": 0,
1880 | "model_id": "7c76076cccf04326a82ec69c5b33c9af"
1881 | }
1882 | },
1883 | "metadata": {}
1884 | },
1885 | {
1886 | "output_type": "stream",
1887 | "name": "stdout",
1888 | "text": [
1889 | "Loss: 44.75\n",
1890 | "Epoch 3 / 5\n"
1891 | ]
1892 | },
1893 | {
1894 | "output_type": "display_data",
1895 | "data": {
1896 | "text/plain": [
1897 | " 0%| | 0/938 [00:00, ?it/s]"
1898 | ],
1899 | "application/vnd.jupyter.widget-view+json": {
1900 | "version_major": 2,
1901 | "version_minor": 0,
1902 | "model_id": "61d30891744b4dcda25e8eeb71b54aad"
1903 | }
1904 | },
1905 | "metadata": {}
1906 | },
1907 | {
1908 | "output_type": "stream",
1909 | "name": "stdout",
1910 | "text": [
1911 | "Loss: 39.01\n",
1912 | "Epoch 4 / 5\n"
1913 | ]
1914 | },
1915 | {
1916 | "output_type": "display_data",
1917 | "data": {
1918 | "text/plain": [
1919 | " 0%| | 0/938 [00:00, ?it/s]"
1920 | ],
1921 | "application/vnd.jupyter.widget-view+json": {
1922 | "version_major": 2,
1923 | "version_minor": 0,
1924 | "model_id": "c9305060d9e144a095a0d93b63355e0b"
1925 | }
1926 | },
1927 | "metadata": {}
1928 | },
1929 | {
1930 | "output_type": "error",
1931 | "ename": "KeyboardInterrupt",
1932 | "evalue": "",
1933 | "traceback": [
1934 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1935 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1936 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0mtotal_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'Loss: {total_loss:.2f}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
1937 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
1938 | ]
1939 | }
1940 | ]
1941 | },
1942 | {
1943 | "cell_type": "code",
1944 | "source": [
@torch.no_grad()
def sample_timestep(x, t):
    """Run one reverse-diffusion step on a batch of noisy samples.

    Args:
        x: Tensor of current samples; the first dimension is the batch.
        t: Integer tensor holding one timestep index per sample in ``x``.

    Returns:
        A tensor shaped like ``x``. Samples whose timestep is 0 get the model
        mean unchanged; all others get the mean plus Gaussian noise scaled by
        the square root of the posterior variance for their timestep.
    """
    betas_for_batch = get_index_for_batch(betas, t, x.shape)
    sqrt_one_minus_alpha_for_batch = get_index_for_batch(
        sqrt_one_minus_alpha_bars, t, x.shape
    )
    # NOTE(review): the table name suggests 1/sqrt(alpha_bar_t); the DDPM mean
    # uses 1/sqrt(alpha_t). Confirm against the cell that builds this table.
    sqrt_one_over_alpha_bars_for_batch = get_index_for_batch(
        sqrt_one_over_alpha_bars, t, x.shape
    )

    # Feed the timestep in the same (batch, 1) float layout the training loop
    # uses, so sampling and training exercise the network identically.
    predicted_noise = network(x, t.unsqueeze(-1).float())

    model_mean = sqrt_one_over_alpha_bars_for_batch * (
        x - betas_for_batch * predicted_noise / sqrt_one_minus_alpha_for_batch
    )

    posterior_variance_for_batch = get_index_for_batch(
        posterior_variance, t, x.shape
    )

    noise = torch.randn_like(x)
    noised = model_mean + torch.sqrt(posterior_variance_for_batch) * noise

    # Per-sample selection instead of `if t == 0`, which is only defined for a
    # single-element `t`. The mask is reshaped to broadcast over x's trailing dims.
    is_final_step = (t == 0).reshape(-1, *([1] * (x.dim() - 1)))
    return torch.where(is_final_step, model_mean, noised)
1968 | ],
1969 | "metadata": {
1970 | "id": "WO7VBqT3YYei"
1971 | },
1972 | "execution_count": 64,
1973 | "outputs": []
1974 | },
1975 | {
1976 | "cell_type": "code",
1977 | "source": [
1978 | "@torch.no_grad()\n",
@torch.no_grad()
def sample_plot_image():
    """Generate one image by running the whole reverse process from noise.

    Starts from a single Gaussian-noise tensor of shape
    (1, 3, IMAGE_SIZE, IMAGE_SIZE) on DEVICE and applies ``sample_timestep``
    for every timestep from TIMESTEPS - 1 down to 0.

    Despite the name, nothing is plotted here.

    Returns:
        The tensor returned by the final ``sample_timestep`` call.
    """
    img = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE), device=DEVICE)

    for i in reversed(range(TIMESTEPS)):
        t = torch.full((1,), i, device=DEVICE, dtype=torch.long)
        img = sample_timestep(img, t)

    return img
1990 | ],
1991 | "metadata": {
1992 | "id": "WhA9Wk8Mcerx"
1993 | },
1994 | "execution_count": 73,
1995 | "outputs": []
1996 | },
1997 | {
1998 | "cell_type": "code",
1999 | "source": [
2000 | "img = sample_plot_image()"
2001 | ],
2002 | "metadata": {
2003 | "id": "8DGzIGuXeJf-"
2004 | },
2005 | "execution_count": 79,
2006 | "outputs": []
2007 | },
2008 | {
2009 | "cell_type": "code",
2010 | "source": [
2011 | "img = convert_tensor_image(img)\n",
2012 | "plt.imshow(img)\n",
2013 | "plt.show()"
2014 | ],
2015 | "metadata": {
2016 | "id": "Kbc0j49gnQ0I"
2017 | },
2018 | "execution_count": null,
2019 | "outputs": []
2020 | },
2021 | {
2022 | "cell_type": "code",
2023 | "source": [
2024 | "i = 0\n",
2025 | "img = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE), device=DEVICE)\n",
2026 | "t = torch.full((1,), i, device=DEVICE, dtype=torch.long)\n",
2027 | "img = sample_timestep(img, t)"
2028 | ],
2029 | "metadata": {
2030 | "id": "jK--uaBboI5a"
2031 | },
2032 | "execution_count": 81,
2033 | "outputs": []
2034 | },
2035 | {
2036 | "cell_type": "code",
2037 | "source": [
2038 | "img"
2039 | ],
2040 | "metadata": {
2041 | "id": "iFnbbmNXocA2"
2042 | },
2043 | "execution_count": null,
2044 | "outputs": []
2045 | },
2046 | {
2047 | "cell_type": "code",
2048 | "source": [],
2049 | "metadata": {
2050 | "id": "3NgzKu8mov3g"
2051 | },
2052 | "execution_count": 45,
2053 | "outputs": []
2054 | },
2055 | {
2056 | "cell_type": "code",
2057 | "source": [],
2058 | "metadata": {
2059 | "id": "qJxom48CpbRj"
2060 | },
2061 | "execution_count": null,
2062 | "outputs": []
2063 | }
2064 | ]
2065 | }
--------------------------------------------------------------------------------
/23_Value_functions_and_policy_iteration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyN0jZMs8LETkT8dHP5fS1rg",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | " "
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "metadata": {
33 | "id": "Fz5lLqGQzR5M"
34 | },
35 | "outputs": [],
36 | "source": [
37 | "import numpy as np"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "source": [
class Environment:
    """Square grid world with two special cells, A and B, that teleport the agent.

    Cells are addressed as (y, x) = (row, column), both starting at 0.
    Every move from A yields reward 10 and every move from B yields reward 5,
    regardless of direction; a move that would leave the grid keeps the agent
    in place with reward -1; any other move yields reward 0.
    """

    # Accepted spellings of each move and the (dy, dx) offset it applies.
    # Integer codes follow the order U, D, R, L.
    _MOVES = (
        (('u', 'up', 0), (-1, 0)),
        (('d', 'down', 1), (1, 0)),
        (('r', 'right', 2), (0, 1)),
        (('l', 'left', 3), (0, -1)),
    )

    def __init__(self):
        # Current agent position.
        self.x = 2
        self.y = 2
        # Special cells, stored as (y, x).
        self.A = (0, 1)
        self.B = (0, 3)
        # Teleport targets. NOTE: unlike A and B these are stored as (x, y).
        self.A_next = (1, 4)
        self.B_next = (3, 2)
        self.edge_size = 5

    def calculate_next(self, y, x, move):
        """Return ``(new_y, new_x, reward)`` for taking ``move`` from ``(y, x)``.

        Does not modify the environment.

        Args:
            y: Row of the starting cell.
            x: Column of the starting cell.
            move: 'u'/'up'/0, 'd'/'down'/1, 'r'/'right'/2 or 'l'/'left'/3;
                strings are case-insensitive.

        Raises:
            ValueError: if ``move`` is not one of the accepted spellings
                (previously such a move was silently treated as a no-op).
        """
        if isinstance(move, str):
            move = move.lower()

        for aliases, (dy, dx) in self._MOVES:
            if move in aliases:
                break
        else:
            raise ValueError(f'unknown move: {move!r}')

        # The special cells override the requested direction entirely.
        if (y, x) == self.A:
            new_x, new_y = self.A_next
            return new_y, new_x, 10
        if (y, x) == self.B:
            new_x, new_y = self.B_next
            return new_y, new_x, 5

        new_y, new_x = y + dy, x + dx
        inside = 0 <= new_x < self.edge_size and 0 <= new_y < self.edge_size
        if not inside:
            # Bumping into the wall: stay put and pay a penalty.
            return y, x, -1
        return new_y, new_x, 0

    def step(self, move):
        """Apply ``move`` to the agent's current position and return the reward."""
        self.y, self.x, reward = self.calculate_next(self.y, self.x, move)
        return reward

    def predict_reward(self, y, x, move):
        """Return ``(new_y, new_x, reward)`` for a hypothetical move; state is untouched."""
        return self.calculate_next(y, x, move)

    def reset(self):
        """Put the agent back on the starting cell (2, 2)."""
        self.x = 2
        self.y = 2

    @property
    def moves(self):
        """Integer codes of the four available moves."""
        return range(4)

    def __repr__(self):
        """Draw the grid, one text line per row: '*' marks the agent, '_' every other cell."""
        rows = (
            ''.join(
                '*' if (i == self.y and j == self.x) else '_'
                for j in range(self.edge_size)
            )
            for i in range(self.edge_size)
        )
        return ''.join(row + '\n' for row in rows)
114 | ],
115 | "metadata": {
116 | "id": "7fxGRyVVzf1U"
117 | },
118 | "execution_count": 126,
119 | "outputs": []
120 | },
121 | {
122 | "cell_type": "code",
123 | "source": [
def argmax(values, tol=0.0):
    """Return the indices of every entry that attains the maximum of ``values``.

    Args:
        values: Iterable of numbers (it is materialised, so one-shot
            iterators are fine).
        tol: Entries within ``tol`` of the maximum also count as maximal.
            The default 0.0 keeps exact-equality ties only; pass a small
            positive value when the inputs are floats that may differ by
            rounding error.

    Returns:
        List of indices in ascending order; empty when ``values`` is empty.
        NaN entries are never selected.
    """
    values = list(values)

    maximum = float('-inf')
    for value in values:
        if value > maximum:
            maximum = value

    # The equality test comes first so that infinite maxima still match
    # themselves (inf - inf is NaN and would fail the tolerance test).
    return [
        i for i, value in enumerate(values)
        if value == maximum or maximum - value <= tol
    ]
136 | ],
137 | "metadata": {
138 | "id": "SBIJaIUV-ELw"
139 | },
140 | "execution_count": 161,
141 | "outputs": []
142 | },
143 | {
144 | "cell_type": "code",
145 | "source": [
def calculate_value_function(env, policy, gamma=.9, n_sweeps=50):
    """Evaluate ``policy`` with a fixed number of in-place Bellman sweeps.

    Args:
        env: Object exposing ``moves`` (iterable of action indices) and
            ``predict_reward(y, x, move) -> (next_y, next_x, reward)``.
        policy: Array of shape (rows, cols, n_actions); ``policy[y, x, a]`` is
            the probability of taking action ``a`` in cell ``(y, x)``.
        gamma: Discount factor.
        n_sweeps: Number of full passes over the grid (no convergence test).

    Returns:
        Array of shape (rows, cols) with the estimated state values.
    """
    # The grid size is taken from the policy instead of a hard-coded 5x5.
    n_rows, n_cols = policy.shape[:2]
    value_function = np.zeros((n_rows, n_cols))

    for _ in range(n_sweeps):
        for i in range(n_rows):
            for j in range(n_cols):
                expected = 0
                for a in env.moves:
                    next_y, next_x, reward = env.predict_reward(i, j, a)
                    expected += policy[i, j, a] * (
                        reward + gamma * value_function[next_y, next_x]
                    )
                # Written back immediately, so cells visited later in the same
                # sweep already see this updated value.
                value_function[i, j] = expected

    return value_function
159 | ],
160 | "metadata": {
161 | "id": "v4175Xl_1h84"
162 | },
163 | "execution_count": 134,
164 | "outputs": []
165 | },
166 | {
167 | "cell_type": "code",
168 | "source": [
def update_policy(value_function, env=None, gamma=.9):
    """Return the policy that is greedy with respect to ``value_function``.

    For every cell the action value ``reward + gamma * V(next_state)`` is
    computed for each move, and probability is split evenly among all moves
    whose value matches the best one (compared with ``np.isclose`` so that
    ties separated only by floating-point rounding are kept).

    Args:
        value_function: 2-D array of state values, indexed ``[y, x]``.
        env: Object exposing ``moves`` and
            ``predict_reward(y, x, move) -> (next_y, next_x, reward)``.
            When omitted, the module-level ``env`` is used, which is what the
            original implementation always did.
        gamma: Discount factor; should match the one used to compute
            ``value_function``.

    Returns:
        Array of shape (rows, cols, n_moves) with action probabilities.

    Raises:
        NameError: if ``env`` is omitted and no module-level ``env`` exists.
    """
    if env is None:
        try:
            env = globals()['env']
        except KeyError:
            raise NameError("name 'env' is not defined") from None

    moves = list(env.moves)
    n_rows, n_cols = value_function.shape
    new_policy = np.zeros((n_rows, n_cols, len(moves)))

    for i in range(n_rows):
        for j in range(n_cols):
            action_values = np.empty(len(moves))
            for k, a in enumerate(moves):
                new_y, new_x, reward = env.predict_reward(i, j, a)
                # Include the immediate reward: the greedy step maximises
                # r + gamma * V(s'), not V(s') alone.
                action_values[k] = reward + gamma * value_function[new_y, new_x]

            best = np.flatnonzero(np.isclose(action_values, action_values.max()))
            new_policy[i, j, best] = 1 / len(best)

    return new_policy
183 | ],
184 | "metadata": {
185 | "id": "Vd9Wp36NAAfM"
186 | },
187 | "execution_count": 171,
188 | "outputs": []
189 | },
190 | {
191 | "cell_type": "code",
192 | "source": [
193 | "env = Environment()"
194 | ],
195 | "metadata": {
196 | "id": "AdXjN_rG1fVy"
197 | },
198 | "execution_count": 172,
199 | "outputs": []
200 | },
201 | {
202 | "cell_type": "code",
203 | "source": [
204 | "policy = np.ones((5, 5, 4)) * .25"
205 | ],
206 | "metadata": {
207 | "id": "XQdAZ2KQ8nn5"
208 | },
209 | "execution_count": 173,
210 | "outputs": []
211 | },
212 | {
213 | "cell_type": "code",
214 | "source": [
215 | "for _ in range(10):\n",
216 | " value_function = calculate_value_function(env, policy)\n",
217 | " policy = update_policy(value_function)"
218 | ],
219 | "metadata": {
220 | "id": "ZEUZexM15bAb"
221 | },
222 | "execution_count": 174,
223 | "outputs": []
224 | },
225 | {
226 | "cell_type": "code",
227 | "source": [
228 | "value_function # state-value table V(s), one entry per grid cell (not a Q table)"
229 | ],
230 | "metadata": {
231 | "colab": {
232 | "base_uri": "https://localhost:8080/"
233 | },
234 | "id": "Q1ey6vjm9Jgw",
235 | "outputId": "4072c658-1386-4fce-82e4-cb718f90e9a2"
236 | },
237 | "execution_count": 175,
238 | "outputs": [
239 | {
240 | "output_type": "execute_result",
241 | "data": {
242 | "text/plain": [
243 | "array([[21.97748529, 24.4194281 , 21.97748529, 19.4194281 , 17.47748529],\n",
244 | " [19.77973676, 21.97748529, 19.77973676, 17.80176308, 16.02158677],\n",
245 | " [17.80176308, 19.77973676, 17.80176308, 16.02158677, 14.4194281 ],\n",
246 | " [16.02158677, 17.80176308, 16.02158677, 14.4194281 , 12.97748529],\n",
247 | " [14.4194281 , 16.02158677, 14.4194281 , 12.97748529, 11.67973676]])"
248 | ]
249 | },
250 | "metadata": {},
251 | "execution_count": 175
252 | }
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "source": [
moves = ['U', 'D', 'R', 'L']

# Print the policy as a table: one text line per grid row, cells separated by
# '|'. Each cell has one character slot per action, showing the action's letter
# when its probability is non-zero and a blank otherwise.
for i in range(5):
    cells = [
        ''.join(moves[a] if policy[i, j, a] != 0 else ' ' for a in env.moves)
        for j in range(5)
    ]
    print('|' + ''.join(cell + '|' for cell in cells))
270 | ],
271 | "metadata": {
272 | "colab": {
273 | "base_uri": "https://localhost:8080/"
274 | },
275 | "id": "N9X8m-LH9UfU",
276 | "outputId": "b619cb70-3ca9-4aef-d4c5-e60917e2ba90"
277 | },
278 | "execution_count": 181,
279 | "outputs": [
280 | {
281 | "output_type": "stream",
282 | "name": "stdout",
283 | "text": [
284 | "| R |UDRL| L|UDRL| L|\n",
285 | "| R |U |U L| L| L|\n",
286 | "| R |U |U L|U L|U L|\n",
287 | "| R |U |U L|U L|U L|\n",
288 | "| R |U |U L|U L|U L|\n"
289 | ]
290 | }
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "source": [],
296 | "metadata": {
297 | "id": "OXFIubz3_IyX"
298 | },
299 | "execution_count": null,
300 | "outputs": []
301 | }
302 | ]
303 | }
--------------------------------------------------------------------------------
/24_Double_Deep_Q_Learning_1_gym_intro.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyM4+5u2E899RPT5eXzO/42D",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | " "
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "source": [
32 | "!pip install -q pyvirtualdisplay\n",
33 | "!pip install -q swig\n",
34 | "!pip install -q gymnasium[all]"
35 | ],
36 | "metadata": {
37 | "colab": {
38 | "base_uri": "https://localhost:8080/"
39 | },
40 | "id": "OqttDVvk0LTA",
41 | "outputId": "5247d6e3-b297-47fd-efde-1fbf7f30090a"
42 | },
43 | "execution_count": 12,
44 | "outputs": [
45 | {
46 | "output_type": "stream",
47 | "name": "stdout",
48 | "text": [
49 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
50 | " Building wheel for box2d-py (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
51 | ]
52 | }
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 117,
58 | "metadata": {
59 | "id": "ufYQ8d50z_6t"
60 | },
61 | "outputs": [],
62 | "source": [
63 | "import gymnasium as gym\n",
64 | "import numpy as np\n",
65 | "import matplotlib.pyplot as plt\n",
66 | "\n",
67 | "import random\n",
68 | "from collections import namedtuple, deque\n",
69 | "\n",
70 | "from IPython.display import clear_output"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "source": [
76 | "env = gym.make(\"LunarLander-v2\", render_mode=\"rgb_array\")"
77 | ],
78 | "metadata": {
79 | "id": "jNSDYwlI0HL4"
80 | },
81 | "execution_count": 28,
82 | "outputs": []
83 | },
84 | {
85 | "cell_type": "code",
86 | "source": [
87 | "observation, info = env.reset()"
88 | ],
89 | "metadata": {
90 | "id": "G8V5e4Lb0Zgj"
91 | },
92 | "execution_count": 29,
93 | "outputs": []
94 | },
95 | {
96 | "cell_type": "code",
97 | "source": [
98 | "observation"
99 | ],
100 | "metadata": {
101 | "colab": {
102 | "base_uri": "https://localhost:8080/"
103 | },
104 | "id": "Bg3ih6gk27ll",
105 | "outputId": "cf92c132-f634-4f83-9c72-c5b3df1e52cd"
106 | },
107 | "execution_count": 32,
108 | "outputs": [
109 | {
110 | "output_type": "execute_result",
111 | "data": {
112 | "text/plain": [
113 | "array([-0.00484533, 1.408985 , -0.4908019 , -0.08600599, 0.00562137,\n",
114 | " 0.11117391, 0. , 0. ], dtype=float32)"
115 | ]
116 | },
117 | "metadata": {},
118 | "execution_count": 32
119 | }
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "source": [
125 | "env.action_space.sample()"
126 | ],
127 | "metadata": {
128 | "colab": {
129 | "base_uri": "https://localhost:8080/"
130 | },
131 | "id": "z5qq_0PW5YXC",
132 | "outputId": "ee96bf80-3dce-4141-8f79-4ac61bd0ce92"
133 | },
134 | "execution_count": 57,
135 | "outputs": [
136 | {
137 | "output_type": "execute_result",
138 | "data": {
139 | "text/plain": [
140 | "1"
141 | ]
142 | },
143 | "metadata": {},
144 | "execution_count": 57
145 | }
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "source": [
151 | "fig, axs = plt.subplots(1, 1, figsize=(5, 5))\n",
152 | "\n",
153 | "for _ in range(100):\n",
154 | " action = env.action_space.sample() # agent policy that uses the observation and info\n",
155 | " observation, reward, terminated, truncated, info = env.step(action)\n",
156 | "\n",
157 | " if terminated or truncated:\n",
158 | " observation, info = env.reset()\n",
159 | "\n",
160 | " axs.imshow(env.render())\n",
161 | " axs.axis('off')\n",
162 | " plt.pause(.01)\n",
163 | "\n",
164 | "env.close()"
165 | ],
166 | "metadata": {
167 | "colab": {
168 | "base_uri": "https://localhost:8080/",
169 | "height": 295
170 | },
171 | "id": "xKmBOW9_29UQ",
172 | "outputId": "e6851b6b-8cac-4260-ccef-88b47bc5a763"
173 | },
174 | "execution_count": 52,
175 | "outputs": [
176 | {
177 | "output_type": "display_data",
178 | "data": {
179 | "text/plain": [
180 | ""
181 | ],
182 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAEWCAYAAACqitpwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWYUlEQVR4nO3de3CU5aHH8d+7l2w2uxty3QByC6AogiXgUREvnVPl1EqtY0vtOFo7xzmdTosd/UsdZ+w4c+aIh3qmc04RRS5tgZSbB4TjJYiUWkC5KBCu4WJISMIl93uyt/f8sSYEDBjgSTaB72fmnXd3k+w+2ZnsN8/7vvuuZdu2LQAADHIkegAAgGsPcQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYJyrp99oWVZvjgMAMED05MQuzFwAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wA4AIOy9LoG27Qd266SUlud6KHMyARFwC4QJLbrWE5ORrk9ys7PT3RwxmQLNu27R59o2X19lgAoN8YnJkpf0qKviorU6xnL5PXjZ5kg7gAAC5LT7LBZjEAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAQAYR1wAAMYRFwCAccQFAGAccQEAGEdcAADGERcAgHHEBQBgHHEBABhHXAAAxhEXAIBxxAUAYBxxAYABwuVwyO0YGC/bA2OUAHCdC3q9WvPQQ1o/Y4aGpKQkejjfypXoAQAAvt2EzExNHzFCkpSXna1TJSUJHtGlWbZt2z36Rsvq7bEAAC7CYVn6t/Hj5XY69ea+fYr17KW7V/QkG8QFAHBZepIN9rkAAIwjLgBwDXC5kuRw9J/d6MQFAAYwy7KUnTVGk/Me14gRUxI9nE79J3MAcB1zu1OU7AnI6w0oGouqurr4kt/v8fiVkT5Kt9w4XSODUxVMHa9jpz9RdXWxGhvP9tGoL464AEBCWEpKSlFO9jilpuYoK32MhufcoRRPunYdXqLa2lLFYtFv/JQnya9hwybpppH/rNwh9yrNm6skZ/x9LzcOma6Km/do+xd/USwW6etf6DzEBQD6gNPpVoo3XUken4JZ4zRiyBSl+nOU5stVmneEkpx+OSyXbMU0JGuCDiZ9qLa2RkmSw+FUauoQjRg6ReNGP6ihaXkKJA+V00o670hejzOg2296WqfO7NeJ0h2J+lUlERcA6BNDB09U3oSZGpIxSSnuDCW70pTk9EmyZCumSKxN9e0n
1dR+Wk6XU05nkiRLmZmjNHrk3bp5xL8oOGi8fEk5cljObh/DsixlpIzVlFufVHVNsRqbKvv0d+yKuABAH/A4UxVICWqwf6KidlihaLPq20+qLVyv6sajqq4vVmtzvcpP79Opyn1yOJyaOP6HmjTuJ7oh7XYluQIXjcqFbhz8gConFmnL9rcTtnmMuABAH6iuO6GGljM6Wf+5wtEWtbU1qrRip85WH1Fzc7XqGsoUCjdLspSamqO8CTM1eeyTCniG9jgqUnz2kuxO022jf6rik5/pZNmXvfdLXWocvEMfAPpG3q0zFY2EVXZqj+qbKxSLRWXb53bau93JGjn8Dt096Vcann6HXA7vFb/2xuyoik69r4/+8arq60+Z+hUkcfoXADAm4HbrN+PHa2N5uXZVVRm/f6czSXff8YxuG/UTZfpvuqzZysW0Req09cB/a+vOBUY3j3H6FwAw5KmxYzU1GNTzEycqqRc+U8VhORUcdItcLo8sQy/NHucg5d34pEaOuF1S304Q2OeCfud3v5NGjpTa26V166TCwvjtTU1SfX1ix9ZfjRolvfKKZNtSTY00b178+YvFpKoqKRxO9AgHvvdKSjQuLU0FZWUKx2LG7z8cadPu/auUPCVV3q+PJrtalmUpzTtS/3Tr06qpLTG+eeySj81mMfQ3S5dK48adu27b8RfJffukbdvi18+elT74IHFj7G9uuUX6y1/Ov822pba2+PN05owUiUg7dkhFRYkZI3pm2qRf69ZbHlKOb6KcjqSrvj/bthWKNmnb4f/RPz6f1+0bM6/kPr8NMxf0S13/l7EsyeGQ8vKkSZPit7W1SU8/HX8BbWiQ5s+Pz2piMen0aam5OSHDTqgL//+zLCklRfrxj+PXbVuqrj73PO3ZI61ZE/9ae7tUWtqnw8VF7ClaqaysXCW70pSePPqq/7G3LEseV0C35f5UpRVf6ETJ5z2Kw9UiLhhQOv7OvF5pzJhzt8+bF1+HQtLf/ia9+iqbgjp0PGeWJWVnxxdJuukmaebMeHRqa6U5c6SNGxM3TsQ1t1Zpz/7VCviz5XXFN4+Z2HKUnjJKd932jKpritXQcNrASC+NuGDA6vjnKxyWjhyJh6WtTdqwIb4JCN/U9R/W6ur4bMW245vNvkzM2yHQjdIzu1R0bLO8EzI02H+bnNbVbx5zWC7lZt2nyRMe19ad8xUOtxoY6cURF/RrXV8MYzEpGo3fVlEh/fWv5/YrbN4stfbu38qA0vV5i0Ti18Ph+PO0d2/8uSwtJSj9lW1Htf/oexocvFkp7gylJecamb24nT6NueF+HSv9u8rK9lz9QC+BuKBfsu14SI4di++8j0bjO6M3bIh/PRqNHz2GczqC0toq7d4df46am+MRLi8/97VQKHFjRM+1ttfqi33LlRrIUZIrIJ87+6ruz7ZjOtt4SNv2vqWKiv2GRnlxxAX90G/08stzFYlIBw/Gd9Dj0jyeXG3derM++OBDtbRIn30WjwsGtorKPSo6tkm+iVlKdqbJ6XBf0f3YdkyVTYe1addsHf3q731yvjHign5oqjZsmJvoQQwoTmeGSkpu1oYNHyZ6KDDswLH/U1bmaLmH+5Ttu/myf962Y6qo261Nu/5TX53Y0idHikm8Qx8A+rXW9lrt3r9K9c1laglXXVYcYnZUpxsKtXHHf6i4ZFufhUUiLgDQ752pOaT9h9erqqlItnq2vdO2YzpZs10fbfudiks+N/LmyctBXADAALfDoRt8Pjl65Wwmtg5/tUHlpwtV21os27706WdidlRltTv18Wf/rtKyLyT13YylA3EBAAN+deut2vDDH2pm13f3GtQWqtcX+/6q6objaglXX/T7bNvWiapP9eHWV1RWsfdbQ9RbiAsAGDApM1OOr9e9paruuAoPvaealuOKxs4/pty2bbWG63T49Hpt3P6aTp3er0TMWDpwtBgAGPDC55/r0dxc5R892ouPYutoySYNDU6Qxx1Qdsp4WZalmB1V
bWuxjpVv0vbdi1RTl/gTxREXADCgqq1NCw4d6vXHaQ81avvexcrOulE+d7acDo/ONOxTYdH/6vDxjZJly7IcCdsc1oG4AMAA09B8WoUH18qeEFOovVl7DqzWsZK/KzV1sL4z8REdL96miooDYrMYAKDHbDumA8fflyfZrxNl21RdW6y0wDA9eM+Lyg3ep1uGPaKdhxdr34H3FYm0JWSMxAUABqBYLKxd+5ZKkvwpQT1470saGZwqrztTreE6tbc3KxptT9j4OFoMAAa4ITnjFUy/RT53UG2ROm07MFeHj2zs03fkX6jHcXnnnXc0bdo0+Xw+PvIYAPqJ0cPv1T1Tfq1072hFYyEdO71Rh4983Ccnp7yUHsflmWee0ccff6wVK1boqaeeUjAY7M1xAQC+hcvl0YjheRqcepuclltltTv1j11/VHNLjdHHCbhcujMrS67LmFj0eJ+LZVnyer16+OGH9d3vfldFRUVasWKFVq1apYqKCrW3J27bHgBcfyzdlfevyhv7hJKcfjW0l2nXoSWqrDpu/JEeHzVK41JTFUxO1vqysh79zBXtc/H5fMrLy9Ps2bO1ZcsWvf7667r//vvldDqv5O4AAJcpI2O4Rg+7V6me4QpFW7Tnq+UqOvZJr7y/paq9XfbX65664h36lmXJsiwNHTpUs2bN0sqVK7V27Vo98sgjCgaDcjg4VgAAeoPb7dV9d85SzqAJitkRHT+zUV/szVc43Duf9f1+ebn+69AhHWxr05tvvtmjnzFyKLLT6VQwGNTDDz+shx56SFu3btXatWu1cuVKlXd8vioAwBBLDTWVqvIXybIsfXlwuRoaz/Tao4VjMVWGQnpn7lz94he/6NHPGH2fi2VZcjqduu+++3TnnXdq1qxZWrNmjZYtW6ajR4+qiQ89B4CrFg636NNd/61Rp6YqNT1HJWU7evXxAoGA3njjDf385z/v+dHCdh84e/asvWTJEnvGjBl2cnKyrfg5CVhYul2WLl2a8DEMtGXKlCn2888/n/BxsFx7yw9+8AN71apVl/263yfv0M/OztYTTzyhRx99VIWFhVq0aJE2bdqk0tJSRaN9++loAICeefzxx7VgwQL5fL7L/tk+O/2Lw+GQ3+/X3Xffrbvuukt79uzRRx99pD//+c8qLi5WOBzuq6EAAC7Bsiw9+uijmjdvnvx+/xXdR0IO6XI4HJo8ebJeeOEFbdmyRQsXLtSDDz6oQYMGJWI4AICvDR8+XL/85S+Vn5+v9PT0K76fhJ640ul0Kjs7W08++aRmzJihHTt2aMmSJSooKFBtbS2bzACgDw0dOlQrV67U7bffLpfr6vLQL86KbFmW0tPTNX36dH3ve99TUVGR8vPztX79eh04cECxWGI/9AYArnVjxozRypUrlZeXZ+T8kZZtJ/C0mZcQjUZ14sQJbd68We+8844OHjyoxsbGRA8LV8nhcMjpdF50cblcGj9+vIqKihI91AHF7/fL4/Fo7969/DOGy5KUlKRhw4Zp9erVysvLM3a//TYuHWzbVnNzszZv3qxly5Z1bjJD/+JyuRQIBOTz+RQIBOT3+7tdAoGAUlNTNWjQIAUCgW7Xfr+fM29fgaNHj2r+/PlavXq1Tp48mejhYABwOBx65ZVX9Nxzzyk1NdXo312/j0sH27bV1NSk48eP609/+pPef/99lZaWKhQKJXpoA1bHTMHtdp+37nrZ4/EoLS1N6enpl1x8Pp+SkpLkcrk61263+xsLpwXqXaFQSMXFxVq0aJHeffddlZSUKBJJ7KnX0T95vV69+uqrevbZZ5WcnGz8/gdMXDp0DLe4uFhr167VihUrtGNH7747daDyer3KyMjoNg5paWlKTU1VIBDonGl0XO56vafHtzPT6F9s29bx48e1ZMkS5efn69ixY4keEvoRy7L0+9//Xs8//3yv/e0OuLh0FY1GVVNTo8rKStXW1qqurk41NTXnrWtrazuXlpYWhcNhhUKhzqW76/2FZVlK
SkqSx+ORx+NRcnJy52WPxyO/36/s7OzzlmAwqKysLAWDQfl8vs4ZQ3czFGYR175wOKzy8nLl5+dr+fLlKioqYrZ/nZs6dapmz56tqVOnyu1299rjDOi4XA7bttXe3q7m5ubOpampqdvrDQ0Nqq+vP29dV1d33vXm5ubOWdSF6wsvS+f+s+84m3TH5UGDBikjI0OZmZnKyMg473J6evo39k+kpqZ2Ll6vlxkDeqy4uFhr1qzR4sWLdfDgQXb8X4ceeOABvf3228rNze31147rJi49Zdu2YrGYYrGYotHoRdehUKhzZtR1hnThbW63u3NWkZWV1bkOBoPKzMyUx+Pp9ogpp9Mph8NBPGBUNBrV2bNntW7dOi1atEj79u1Ta2vvnKYd/cv3v/99LV++3PiO+4shLsB1yLZtlZeXq6CgQPPmzVNhYWG/2iQMc3w+n5599lk999xzysnJ6bPHJS7AdSwWi6m+vl4FBQV6++23tX37dmYy15ChQ4dqzpw5mjlzZq/uX+kOcQEg27ZVWVmpTz/9VH/4wx/05ZdfEpkBLiMjQ++++67uvffehHwEPXEB0Mm2bYVCIRUUFGjhwoX65JNP1NzcnOhh4TJNmzZNr732mu65556E7bclLgC6VVdXp88++0xz587Vli1b1NDQ8I2jING/uN1uPfbYY5o9e7ZGjRqV0LEQFwAX1XH05MaNG7V06VKtW7dODQ0NiR4WLuLFF1/USy+9pEAgkPAjTYkLgB5pbGzU7t27tWDBAhUUFKiqqor3yvQT6enpmjVrll566SV5vd5ED0cScQFwGTpmMtu2bdOyZcu0atUq1dTUJHpY17Xc3FzNmTNHjz32WMJnK10RFwBXpKWlRYcPH9bChQu1fv16lZeXM5PpY3l5eXrrrbc0efLkq/5wL9OIC4CrEolEVFhYqGXLlik/P19nzpxhx38fmDZtmhYvXqyxY8f2qxlLB+IC4Kp1HMJcVlam+fPn67333tORI0eITC9ISUnRE088odmzZyszMzPRw7ko4gLAqGg0qqKiIq1YsUKLFi1SRUUFm8sMSU5O1ssvv6xZs2YpLS0t0cO5JOICwDjbthWNRlVdXa358+dr586dam5uVktLy3lnIu+4jfh8u+zsbM2ZM0c/+9nP5PF4Ej2cb0VcAPS6UCh0Xly6RqalpUU1NTWqqqpSZWWlqqurVVlZqaqqKlVVVammpkaRSKTzbOVdz1Aei8Wu+U1vlmVp9OjR+uMf/6jp06cPmM9hIi4AEqrr5yF1t0QiEdXW1naGpuu6urpatbW1amxsVFNTkxobGzuXpqYmNTQ0DPizPT/wwAN64403NHHixH654/5iiAuAAS0ajaqtrU2tra1qbW3tvNyxbmhoUGVlZbdLVVWVQqHQNwJ3sct9+X1Op1M/+tGP9Prrr2vs2LG9/jyaRlwAXNO6+4TYrrdFo1FFo1FFIpHOddfLPfna1dzW3dcikYjS0tL029/+Vn6/f0DNWDoQFwCAcQNjzxAAYEAhLgAA44gLAMA44gIAMI64AACMIy4AAOOICwDAOOICADCOuAAAjCMuAADjiAsAwDjiAgAwjrgAAIwjLgAA44gLAMA44gIAMI64AACMIy4AAOOICwDAOOICADCOuAAAjCMuAADjiAsAwDjiAgAwjrgAAIwjLgAA44gLAMA44gIAMI64AACMIy4AAOOICwDAOOICADCOuAAAjCMuAADjiAsAwDjiAgAwjrgAAIwjLgAA44gLAMA44gIAMI64AACMIy4AAOOICwDAOOICADCOuAAAjCMuAADjiAsAwDjiAgAwjrgAAIwjLgAA4/4fpN7J9XcZ78wAAAAASUVORK5CYII=\n"
183 | },
184 | "metadata": {}
185 | }
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "source": [
191 | "MemoryBlock = namedtuple('MemoryBlock', ('current', 'action', 'reward', 'next'))"
192 | ],
193 | "metadata": {
194 | "id": "CatEHV-C5Ard"
195 | },
196 | "execution_count": 62,
197 | "outputs": []
198 | },
199 | {
200 | "cell_type": "code",
201 | "source": [
202 | "block = MemoryBlock(1, 1, 2, 1)"
203 | ],
204 | "metadata": {
205 | "id": "awpIYsZg6Phm"
206 | },
207 | "execution_count": 64,
208 | "outputs": []
209 | },
210 | {
211 | "cell_type": "code",
212 | "source": [
213 | "block.action"
214 | ],
215 | "metadata": {
216 | "colab": {
217 | "base_uri": "https://localhost:8080/"
218 | },
219 | "id": "j0Xnvz_C6d9B",
220 | "outputId": "f2a5c582-9655-4fbb-95ee-39a561cecc76"
221 | },
222 | "execution_count": 68,
223 | "outputs": [
224 | {
225 | "output_type": "execute_result",
226 | "data": {
227 | "text/plain": [
228 | "1"
229 | ]
230 | },
231 | "metadata": {},
232 | "execution_count": 68
233 | }
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "source": [
239 | "memory = deque(maxlen=10)"
240 | ],
241 | "metadata": {
242 | "id": "LlcxJRes6gH4"
243 | },
244 | "execution_count": 80,
245 | "outputs": []
246 | },
247 | {
248 | "cell_type": "code",
249 | "source": [
250 | "memory"
251 | ],
252 | "metadata": {
253 | "colab": {
254 | "base_uri": "https://localhost:8080/"
255 | },
256 | "id": "SaUTuWou6r81",
257 | "outputId": "703aa9af-8471-4647-a8eb-4270d08266ee"
258 | },
259 | "execution_count": 81,
260 | "outputs": [
261 | {
262 | "output_type": "execute_result",
263 | "data": {
264 | "text/plain": [
265 | "deque([])"
266 | ]
267 | },
268 | "metadata": {},
269 | "execution_count": 81
270 | }
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "source": [
276 | "memory.append(100)"
277 | ],
278 | "metadata": {
279 | "id": "fSD6S3pK61Dh"
280 | },
281 | "execution_count": 82,
282 | "outputs": []
283 | },
284 | {
285 | "cell_type": "code",
286 | "source": [
287 | "memory"
288 | ],
289 | "metadata": {
290 | "colab": {
291 | "base_uri": "https://localhost:8080/"
292 | },
293 | "id": "AX6uGERC62dL",
294 | "outputId": "f26a33f1-4838-45a3-a1b7-06cf638f1ab2"
295 | },
296 | "execution_count": 83,
297 | "outputs": [
298 | {
299 | "output_type": "execute_result",
300 | "data": {
301 | "text/plain": [
302 | "deque([100])"
303 | ]
304 | },
305 | "metadata": {},
306 | "execution_count": 83
307 | }
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "source": [
313 | "for i in range(10):\n",
314 | " memory.append(i)\n",
315 | " print(memory)"
316 | ],
317 | "metadata": {
318 | "colab": {
319 | "base_uri": "https://localhost:8080/"
320 | },
321 | "id": "W87WmNCE65fe",
322 | "outputId": "7f177f5b-a448-4fb4-b04d-aae531d306b2"
323 | },
324 | "execution_count": 84,
325 | "outputs": [
326 | {
327 | "output_type": "stream",
328 | "name": "stdout",
329 | "text": [
330 | "deque([100, 0], maxlen=10)\n",
331 | "deque([100, 0, 1], maxlen=10)\n",
332 | "deque([100, 0, 1, 2], maxlen=10)\n",
333 | "deque([100, 0, 1, 2, 3], maxlen=10)\n",
334 | "deque([100, 0, 1, 2, 3, 4], maxlen=10)\n",
335 | "deque([100, 0, 1, 2, 3, 4, 5], maxlen=10)\n",
336 | "deque([100, 0, 1, 2, 3, 4, 5, 6], maxlen=10)\n",
337 | "deque([100, 0, 1, 2, 3, 4, 5, 6, 7], maxlen=10)\n",
338 | "deque([100, 0, 1, 2, 3, 4, 5, 6, 7, 8], maxlen=10)\n",
339 | "deque([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], maxlen=10)\n"
340 | ]
341 | }
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "source": [
347 | "env = gym.make(\"LunarLander-v2\", render_mode=\"rgb_array\")\n",
348 | "observation, info = env.reset()\n",
349 | "\n",
350 | "memory = deque(maxlen=150)\n",
351 | "\n",
352 | "for _ in range(100):\n",
353 | " action = env.action_space.sample() # agent policy that uses the observation and info\n",
354 | "\n",
355 | " current = observation.copy()\n",
356 | " observation, reward, terminated, truncated, info = env.step(action)\n",
357 | " block = MemoryBlock(current, action, reward, observation)\n",
358 | " memory.append(block)\n",
359 | "\n",
360 | " if terminated or truncated:\n",
361 | " observation, info = env.reset()\n",
362 | "\n",
363 | "\n",
364 | "env.close()"
365 | ],
366 | "metadata": {
367 | "id": "V0EB3J8F6-GP"
368 | },
369 | "execution_count": 87,
370 | "outputs": []
371 | },
372 | {
373 | "cell_type": "code",
374 | "source": [
375 | "b1 = memory[0]\n",
376 | "b2 = memory[1]"
377 | ],
378 | "metadata": {
379 | "id": "Hd4Jcpkn7gS1"
380 | },
381 | "execution_count": 96,
382 | "outputs": []
383 | },
384 | {
385 | "cell_type": "code",
386 | "source": [
387 | "b1"
388 | ],
389 | "metadata": {
390 | "colab": {
391 | "base_uri": "https://localhost:8080/"
392 | },
393 | "id": "MKW7UpO87hYY",
394 | "outputId": "5338c33f-7b14-4bf1-94f7-3e09ae6dac7f"
395 | },
396 | "execution_count": 97,
397 | "outputs": [
398 | {
399 | "output_type": "execute_result",
400 | "data": {
401 | "text/plain": [
402 | "MemoryBlock(current=array([ 0.0032485 , 1.4054555 , 0.32901463, -0.24287517, -0.00375733,\n",
403 | " -0.07452664, 0. , 0. ], dtype=float32), action=0, reward=-1.3044605623392442, next=array([ 0.006497 , 1.3994141 , 0.32857022, -0.26852667, -0.00744009,\n",
404 | " -0.07366162, 0. , 0. ], dtype=float32))"
405 | ]
406 | },
407 | "metadata": {},
408 | "execution_count": 97
409 | }
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "source": [
415 | "b2"
416 | ],
417 | "metadata": {
418 | "colab": {
419 | "base_uri": "https://localhost:8080/"
420 | },
421 | "id": "FuteA5op788U",
422 | "outputId": "fdca541a-a4a7-47dd-e211-93400c3ce372"
423 | },
424 | "execution_count": 98,
425 | "outputs": [
426 | {
427 | "output_type": "execute_result",
428 | "data": {
429 | "text/plain": [
430 | "MemoryBlock(current=array([ 0.006497 , 1.3994141 , 0.32857022, -0.26852667, -0.00744009,\n",
431 | " -0.07366162, 0. , 0. ], dtype=float32), action=2, reward=-1.3291067603536988, next=array([ 0.00991163, 1.3933325 , 0.34438816, -0.2703012 , -0.01033648,\n",
432 | " -0.05793293, 0. , 0. ], dtype=float32))"
433 | ]
434 | },
435 | "metadata": {},
436 | "execution_count": 98
437 | }
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "source": [
443 | "batch = MemoryBlock(*zip(*memory))"
444 | ],
445 | "metadata": {
446 | "id": "b0Z0mrcu79K8"
447 | },
448 | "execution_count": 115,
449 | "outputs": []
450 | },
451 | {
452 | "cell_type": "code",
453 | "source": [
454 | "batch.current"
455 | ],
456 | "metadata": {
457 | "colab": {
458 | "base_uri": "https://localhost:8080/"
459 | },
460 | "id": "t4YaNqJK8C4S",
461 | "outputId": "706880f1-a7d6-453e-f28e-3aa8917b9f99"
462 | },
463 | "execution_count": 116,
464 | "outputs": [
465 | {
466 | "output_type": "execute_result",
467 | "data": {
468 | "text/plain": [
469 | "(array([ 0.0032485 , 1.4054555 , 0.32901463, -0.24287517, -0.00375733,\n",
470 | " -0.07452664, 0. , 0. ], dtype=float32),\n",
471 | " array([ 0.006497 , 1.3994141 , 0.32857022, -0.26852667, -0.00744009,\n",
472 | " -0.07366162, 0. , 0. ], dtype=float32),\n",
473 | " array([ 0.00991163, 1.3933325 , 0.34438816, -0.2703012 , -0.01033648,\n",
474 | " -0.05793293, 0. , 0. ], dtype=float32),\n",
475 | " array([ 0.01323576, 1.3866549 , 0.33303937, -0.29678914, -0.01095505,\n",
476 | " -0.0123724 , 0. , 0. ], dtype=float32),\n",
477 | " array([ 0.01642313, 1.3809104 , 0.32008106, -0.25532517, -0.01228218,\n",
478 | " -0.02654489, 0. , 0. ], dtype=float32),\n",
479 | " array([ 0.0196105 , 1.3745655 , 0.32008517, -0.28200608, -0.01360798,\n",
480 | " -0.02651851, 0. , 0. ], dtype=float32),\n",
481 | " array([ 0.02286119, 1.3676276 , 0.32800844, -0.30837443, -0.01652008,\n",
482 | " -0.05824757, 0. , 0. ], dtype=float32),\n",
483 | " array([ 0.0260045 , 1.3614005 , 0.31786153, -0.27680793, -0.0200112 ,\n",
484 | " -0.06982894, 0. , 0. ], dtype=float32),\n",
485 | " array([ 0.02913609, 1.3556799 , 0.31680828, -0.25429958, -0.02362819,\n",
486 | " -0.07234631, 0. , 0. ], dtype=float32),\n",
487 | " array([ 0.03228216, 1.350129 , 0.3182506 , -0.24676459, -0.02723356,\n",
488 | " -0.07211356, 0. , 0. ], dtype=float32),\n",
489 | " array([ 0.03551693, 1.3439851 , 0.32936293, -0.2731835 , -0.03306038,\n",
490 | " -0.11654727, 0. , 0. ], dtype=float32),\n",
491 | " array([ 0.03875179, 1.3372415 , 0.32938066, -0.29985175, -0.03888537,\n",
492 | " -0.11651033, 0. , 0. ], dtype=float32),\n",
493 | " array([ 0.04192553, 1.3298929 , 0.3217017 , -0.32672888, -0.04317166,\n",
494 | " -0.08573384, 0. , 0. ], dtype=float32),\n",
495 | " array([ 0.04518175, 1.3219506 , 0.33203125, -0.353187 , -0.04952039,\n",
496 | " -0.12698598, 0. , 0. ], dtype=float32),\n",
497 | " array([ 0.04841232, 1.314907 , 0.32982442, -0.31328657, -0.05622081,\n",
498 | " -0.13402088, 0. , 0. ], dtype=float32),\n",
499 | " array([ 0.05172186, 1.3072615 , 0.33972174, -0.34014696, -0.06490061,\n",
500 | " -0.17361203, 0. , 0. ], dtype=float32),\n",
501 | " array([ 0.05503168, 1.2990172 , 0.33974722, -0.36681646, -0.07357845,\n",
502 | " -0.17357238, 0. , 0. ], dtype=float32),\n",
503 | " array([ 0.05834246, 1.2912889 , 0.3400916 , -0.34394732, -0.08250807,\n",
504 | " -0.17860876, 0. , 0. ], dtype=float32),\n",
505 | " array([ 0.06172209, 1.2829638 , 0.34867343, -0.37062198, -0.09314454,\n",
506 | " -0.21274868, 0. , 0. ], dtype=float32),\n",
507 | " array([ 0.06502628, 1.2740395 , 0.33922154, -0.39720652, -0.10188421,\n",
508 | " -0.17480874, 0. , 0. ], dtype=float32),\n",
509 | " array([ 0.06852102, 1.2659017 , 0.35781997, -0.36226854, -0.11017855,\n",
510 | " -0.1659019 , 0. , 0. ], dtype=float32),\n",
511 | " array([ 0.07213764, 1.2581213 , 0.3697724 , -0.34640688, -0.11823454,\n",
512 | " -0.16113424, 0. , 0. ], dtype=float32),\n",
513 | " array([ 0.07599831, 1.250621 , 0.39334607, -0.33393574, -0.1254857 ,\n",
514 | " -0.14503631, 0. , 0. ], dtype=float32),\n",
515 | " array([ 0.07978897, 1.2425411 , 0.3845337 , -0.35957322, -0.13093206,\n",
516 | " -0.10893674, 0. , 0. ], dtype=float32),\n",
517 | " array([ 0.08357992, 1.2338617 , 0.38454834, -0.38624236, -0.136378 ,\n",
518 | " -0.10892855, 0. , 0. ], dtype=float32),\n",
519 | " array([ 0.08736897, 1.2245612 , 0.3843667 , -0.41385975, -0.14182337,\n",
520 | " -0.10890688, 0. , 0. ], dtype=float32),\n",
521 | " array([ 0.09117527, 1.2154578 , 0.3863613 , -0.4051519 , -0.14753653,\n",
522 | " -0.11426322, 0. , 0. ], dtype=float32),\n",
523 | " array([ 0.09506798, 1.2057257 , 0.39723632, -0.43334863, -0.15548898,\n",
524 | " -0.15904924, 0. , 0. ], dtype=float32),\n",
525 | " array([ 0.09896078, 1.1953943 , 0.39723513, -0.4600205 , -0.16344142,\n",
526 | " -0.15904859, 0. , 0. ], dtype=float32),\n",
527 | " array([ 0.10305023, 1.1857693 , 0.41663918, -0.42865118, -0.1711515 ,\n",
528 | " -0.15420182, 0. , 0. ], dtype=float32),\n",
529 | " array([ 0.10713968, 1.1755449 , 0.41663796, -0.45532274, -0.17886156,\n",
530 | " -0.15420106, 0. , 0. ], dtype=float32),\n",
531 | " array([ 0.11116524, 1.1647334 , 0.40861574, -0.4812524 , -0.18494107,\n",
532 | " -0.12159048, 0. , 0. ], dtype=float32),\n",
533 | " array([ 0.1151638 , 1.1540335 , 0.40642828, -0.47639313, -0.19154665,\n",
534 | " -0.13211167, 0. , 0. ], dtype=float32),\n",
535 | " array([ 0.11916237, 1.1427339 , 0.4064273 , -0.50306344, -0.19815221,\n",
536 | " -0.13211125, 0. , 0. ], dtype=float32),\n",
537 | " array([ 0.12323084, 1.1308116 , 0.41521496, -0.53102475, -0.20657948,\n",
538 | " -0.16854541, 0. , 0. ], dtype=float32),\n",
539 | " array([ 0.12759057, 1.1193523 , 0.44367886, -0.5104016 , -0.21436031,\n",
540 | " -0.15561649, 0. , 0. ], dtype=float32),\n",
541 | " array([ 0.1321722 , 1.1080383 , 0.46535143, -0.50390726, -0.22162652,\n",
542 | " -0.1453242 , 0. , 0. ], dtype=float32),\n",
543 | " array([ 0.13710518, 1.0973196 , 0.49971023, -0.4773649 , -0.2281169 ,\n",
544 | " -0.1298075 , 0. , 0. ], dtype=float32),\n",
545 | " array([ 0.14234333, 1.086952 , 0.52950966, -0.46167386, -0.23389536,\n",
546 | " -0.11556929, 0. , 0. ], dtype=float32),\n",
547 | " array([ 0.14751807, 1.0759965 , 0.52158654, -0.4875675 , -0.23806275,\n",
548 | " -0.08334794, 0. , 0. ], dtype=float32),\n",
549 | " array([ 0.15269288, 1.0644413 , 0.52158606, -0.51423556, -0.24223015,\n",
550 | " -0.08334783, 0. , 0. ], dtype=float32),\n",
551 | " array([ 0.15786782, 1.0522863 , 0.5215855 , -0.5409037 , -0.24639754,\n",
552 | " -0.08334772, 0. , 0. ], dtype=float32),\n",
553 | " array([ 0.16310024, 1.0408137 , 0.52803177, -0.5107082 , -0.25127062,\n",
554 | " -0.09746158, 0. , 0. ], dtype=float32),\n",
555 | " array([ 0.16864958, 1.029617 , 0.5589772 , -0.49832523, -0.25539538,\n",
556 | " -0.08249549, 0. , 0. ], dtype=float32),\n",
557 | " array([ 0.17437668, 1.019299 , 0.57709444, -0.45934814, -0.259876 ,\n",
558 | " -0.08961239, 0. , 0. ], dtype=float32),\n",
559 | " array([ 0.18010378, 1.0083812 , 0.5770937 , -0.48601642, -0.2643566 ,\n",
560 | " -0.08961224, 0. , 0. ], dtype=float32),\n",
561 | " array([ 0.1859231 , 0.99682903, 0.588689 , -0.5146646 , -0.2712742 ,\n",
562 | " -0.1383522 , 0. , 0. ], dtype=float32),\n",
563 | " array([ 0.19174251, 0.98467755, 0.5886875 , -0.5413351 , -0.2781918 ,\n",
564 | " -0.13835177, 0. , 0. ], dtype=float32),\n",
565 | " array([ 0.19747925, 0.9719698 , 0.5782013 , -0.5656596 , -0.2828468 ,\n",
566 | " -0.09309985, 0. , 0. ], dtype=float32),\n",
567 | " array([ 0.20321599, 0.9586623 , 0.5782005 , -0.592328 , -0.28750178,\n",
568 | " -0.09309975, 0. , 0. ], dtype=float32),\n",
569 | " array([ 0.20902376, 0.9447341 , 0.5870816 , -0.62028587, -0.29400042,\n",
570 | " -0.12997194, 0. , 0. ], dtype=float32),\n",
571 | " array([ 0.21475688, 0.9302393 , 0.57767045, -0.6450995 , -0.2984922 ,\n",
572 | " -0.08983554, 0. , 0. ], dtype=float32),\n",
573 | " array([ 0.22040614, 0.9151806 , 0.56708866, -0.6697237 , -0.3007316 ,\n",
574 | " -0.04478817, 0. , 0. ], dtype=float32),\n",
575 | " array([ 0.22598 , 0.89956385, 0.5575346 , -0.69411033, -0.30088818,\n",
576 | " -0.00313152, 0. , 0. ], dtype=float32),\n",
577 | " array([ 0.23146506, 0.8833899 , 0.54632086, -0.7183911 , -0.2986327 ,\n",
578 | " 0.0451091 , 0. , 0. ], dtype=float32),\n",
579 | " array([ 0.23701553, 0.866591 , 0.5545404 , -0.7465191 , -0.2981128 ,\n",
580 | " 0.01039831, 0. , 0. ], dtype=float32),\n",
581 | " array([ 0.24256602, 0.84919196, 0.5545404 , -0.77318585, -0.29759288,\n",
582 | " 0.01039837, 0. , 0. ], dtype=float32),\n",
583 | " array([ 0.24834327, 0.8317434 , 0.5767339 , -0.77528656, -0.29656732,\n",
584 | " 0.02051135, 0. , 0. ], dtype=float32),\n",
585 | " array([ 0.25418806, 0.8136553 , 0.5852955 , -0.80408293, -0.29741687,\n",
586 | " -0.01699063, 0. , 0. ], dtype=float32),\n",
587 | " array([ 0.26003274, 0.7949673 , 0.5852954 , -0.83074963, -0.2982664 ,\n",
588 | " -0.01699061, 0. , 0. ], dtype=float32),\n",
589 | " array([ 0.26631337, 0.77691966, 0.62809825, -0.80212414, -0.2983126 ,\n",
590 | " -0.00092363, 0. , 0. ], dtype=float32),\n",
591 | " array([ 0.2725939 , 0.7582721 , 0.62809837, -0.82879084, -0.2983588 ,\n",
592 | " -0.00092367, 0. , 0. ], dtype=float32),\n",
593 | " array([ 2.7907389e-01, 7.3992133e-01, 6.4797944e-01, -8.1558305e-01,\n",
594 | " -2.9833561e-01, 4.6390909e-04, 0.0000000e+00, 0.0000000e+00],\n",
595 | " dtype=float32),\n",
596 | " array([ 0.28546923, 0.7210037 , 0.63733035, -0.8403326 , -0.29606122,\n",
597 | " 0.04548761, 0. , 0. ], dtype=float32),\n",
598 | " array([ 0.29180604, 0.70152026, 0.6298959 , -0.8651683 , -0.29216075,\n",
599 | " 0.07800949, 0. , 0. ], dtype=float32),\n",
600 | " array([ 0.29806557, 0.6814782 , 0.6201067 , -0.8896006 , -0.2861396 ,\n",
601 | " 0.12042297, 0. , 0. ], dtype=float32),\n",
602 | " array([ 0.304393 , 0.66080976, 0.62864447, -0.917799 , -0.28192195,\n",
603 | " 0.08435254, 0. , 0. ], dtype=float32),\n",
604 | " array([ 0.31065854, 0.63955736, 0.62088794, -0.9434684 , -0.27610767,\n",
605 | " 0.11628503, 0. , 0. ], dtype=float32),\n",
606 | " array([ 0.31683215, 0.6177396 , 0.60932267, -0.96817976, -0.26786372,\n",
607 | " 0.16487893, 0. , 0. ], dtype=float32),\n",
608 | " array([ 0.32300606, 0.5953227 , 0.60932034, -0.9948518 , -0.2596198 ,\n",
609 | " 0.16487816, 0. , 0. ], dtype=float32),\n",
610 | " array([ 0.32926998, 0.5733222 , 0.61870646, -0.9764615 , -0.25177723,\n",
611 | " 0.1568514 , 0. , 0. ], dtype=float32),\n",
612 | " array([ 0.335534 , 0.5507224 , 0.6187045 , -1.0031332 , -0.24393468,\n",
613 | " 0.15685079, 0. , 0. ], dtype=float32),\n",
614 | " array([ 0.34193307, 0.5280501 , 0.63197047, -1.0063579 , -0.23587935,\n",
615 | " 0.16110703, 0. , 0. ], dtype=float32),\n",
616 | " array([ 0.3483321 , 0.5047788 , 0.6319685 , -1.0330298 , -0.22782403,\n",
617 | " 0.16110618, 0. , 0. ], dtype=float32),\n",
618 | " array([ 0.35482207, 0.4808718 , 0.6434083 , -1.061682 , -0.22217 ,\n",
619 | " 0.11308068, 0. , 0. ], dtype=float32),\n",
620 | " array([ 0.36123332, 0.45639145, 0.63350135, -1.0868869 , -0.21445887,\n",
621 | " 0.1542222 , 0. , 0. ], dtype=float32),\n",
622 | " array([ 0.36755657, 0.4313467 , 0.6223889 , -1.1116874 , -0.2044231 ,\n",
623 | " 0.20071515, 0. , 0. ], dtype=float32),\n",
624 | " array([ 0.3738801 , 0.4057033 , 0.62238616, -1.1383624 , -0.1943874 ,\n",
625 | " 0.20071383, 0. , 0. ], dtype=float32),\n",
626 | " array([ 0.3802759 , 0.37943146, 0.63148636, -1.1665958 , -0.18625432,\n",
627 | " 0.16266184, 0. , 0. ], dtype=float32),\n",
628 | " array([ 0.3866108 , 0.35258853, 0.6237731 , -1.1918262 , -0.17650253,\n",
629 | " 0.19503552, 0. , 0. ], dtype=float32),\n",
630 | " array([ 0.3928669 , 0.32515943, 0.6138825 , -1.2177227 , -0.1647498 ,\n",
631 | " 0.23505464, 0. , 0. ], dtype=float32),\n",
632 | " array([ 0.39912328, 0.29713205, 0.6138795 , -1.2444009 , -0.15299718,\n",
633 | " 0.23505235, 0. , 0. ], dtype=float32),\n",
634 | " array([ 0.40555716, 0.26928914, 0.63117206, -1.23625 , -0.14080116,\n",
635 | " 0.24392083, 0. , 0. ], dtype=float32),\n",
636 | " array([ 0.41192016, 0.24086352, 0.62222606, -1.262094 , -0.12678766,\n",
637 | " 0.28026995, 0. , 0. ], dtype=float32),\n",
638 | " array([ 0.4183545 , 0.2118345 , 0.63115126, -1.2891777 , -0.11456715,\n",
639 | " 0.24441049, 0. , 0. ], dtype=float32),\n",
640 | " array([ 0.42478913, 0.18220748, 0.63114893, -1.3158567 , -0.10234676,\n",
641 | " 0.24440798, 0. , 0. ], dtype=float32),\n",
642 | " array([ 0.43113318, 0.15199342, 0.6197608 , -1.3419081 , -0.08783399,\n",
643 | " 0.29025573, 0. , 0. ], dtype=float32),\n",
644 | " array([ 0.43754464, 0.12116834, 0.6282196 , -1.369294 , -0.07503272,\n",
645 | " 0.25602564, 0. , 0. ], dtype=float32),\n",
646 | " array([ 0.44387072, 0.08976682, 0.6174868 , -1.3949317 , -0.06005894,\n",
647 | " 0.2994754 , 0. , 0. ], dtype=float32),\n",
648 | " array([ 0.45019692, 0.05776836, 0.6174848 , -1.4216172 , -0.04508539,\n",
649 | " 0.29947075, 0. , 0. ], dtype=float32),\n",
650 | " array([ 0.45643815, 0.0251823 , 0.60678774, -1.4478359 , -0.02796434,\n",
651 | " 0.34242067, 0. , 0. ], dtype=float32),\n",
652 | " array([ 0.46254796, -0.00653446, 0.5943854 , -1.4094068 , -0.01157715,\n",
653 | " 0.3277441 , 0. , 0. ], dtype=float32),\n",
654 | " array([ 0.46856374, -0.03884045, 0.5825814 , -1.4357822 , 0.00717474,\n",
655 | " 0.37503785, 1. , 0. ], dtype=float32),\n",
656 | " array([ 0.47394055, -0.07029414, 0.5007308 , -1.398636 , 0.03890656,\n",
657 | " 0.62863874, 1. , 0. ], dtype=float32),\n",
658 | " array([ 6.3772203e-04, 1.4035777e+00, 6.4582005e-02, -3.2633755e-01,\n",
659 | " -7.3219155e-04, -1.4628743e-02, 0.0000000e+00, 0.0000000e+00],\n",
660 | " dtype=float32),\n",
661 | " array([ 1.2557984e-03, 1.3968295e+00, 6.2634937e-02, -2.9991859e-01,\n",
662 | " -1.5520940e-03, -1.6398780e-02, 0.0000000e+00, 0.0000000e+00],\n",
663 | " dtype=float32),\n",
664 | " array([ 2.0483017e-03, 1.3907650e+00, 7.9237178e-02, -2.6953679e-01,\n",
665 | " -1.5398865e-03, 2.4391804e-04, 0.0000000e+00, 0.0000000e+00],\n",
666 | " dtype=float32),\n",
667 | " array([ 2.8409003e-03, 1.3841002e+00, 7.9236843e-02, -2.9621407e-01,\n",
668 | " -1.5283929e-03, 2.3040108e-04, 0.0000000e+00, 0.0000000e+00],\n",
669 | " dtype=float32),\n",
670 | " array([ 0.00369711, 1.3768382 , 0.08724354, -0.3227533 , -0.00312197,\n",
671 | " -0.03187469, 0. , 0. ], dtype=float32),\n",
672 | " array([ 0.00447617, 1.3689772 , 0.0775428 , -0.3493837 , -0.00276848,\n",
673 | " 0.00707043, 0. , 0. ], dtype=float32))"
674 | ]
675 | },
676 | "metadata": {},
677 | "execution_count": 116
678 | }
679 | ]
680 | },
681 | {
682 | "cell_type": "code",
683 | "source": [
684 | "s = random.sample(memory, 10)\n",
685 | "batch = MemoryBlock(*zip(*s))"
686 | ],
687 | "metadata": {
688 | "id": "dkcpe-Qw9E2X"
689 | },
690 | "execution_count": 119,
691 | "outputs": []
692 | },
693 | {
694 | "cell_type": "code",
695 | "source": [
696 | "batch.reward"
697 | ],
698 | "metadata": {
699 | "colab": {
700 | "base_uri": "https://localhost:8080/"
701 | },
702 | "id": "Z2dSnIvO9P62",
703 | "outputId": "93da11b9-a8a7-4090-8862-dbf46a493ace"
704 | },
705 | "execution_count": 120,
706 | "outputs": [
707 | {
708 | "output_type": "execute_result",
709 | "data": {
710 | "text/plain": [
711 | "(0.7480714012465353,\n",
712 | " -0.6803533446248753,\n",
713 | " 3.16454145519175,\n",
714 | " -0.4614322955537659,\n",
715 | " -1.2174461058552595,\n",
716 | " 0.24074994760946994,\n",
717 | " -0.9185988699621135,\n",
718 | " -0.9376932553099369,\n",
719 | " -0.575740884578579,\n",
720 | " -0.37997100287804986)"
721 | ]
722 | },
723 | "metadata": {},
724 | "execution_count": 120
725 | }
726 | ]
727 | },
728 | {
729 | "cell_type": "code",
730 | "source": [],
731 | "metadata": {
732 | "id": "Ej2l8P2R9U3E"
733 | },
734 | "execution_count": null,
735 | "outputs": []
736 | }
737 | ]
738 | }
--------------------------------------------------------------------------------
/25_Double_Deep_Q_Learning_2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4",
8 | "authorship_tag": "ABX9TyO5oExkA5B4ZMegA5eJNX6v",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | },
18 | "accelerator": "GPU"
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "view-in-github",
25 | "colab_type": "text"
26 | },
27 | "source": [
28 | " "
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "source": [
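34 | "Main source: the official PyTorch tutorial \"Reinforcement Learning (DQN) Tutorial\" (the code below follows its DQN with a soft-updated target network): https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html"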
35 | ],
36 | "metadata": {
37 | "id": "e2c9wneh2jNf"
38 | }
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 1,
43 | "metadata": {
44 | "colab": {
45 | "base_uri": "https://localhost:8080/"
46 | },
47 | "id": "kMrwaCQ804lM",
48 | "outputId": "51b22a64-3639-4925-b883-7affe72d375f"
49 | },
50 | "outputs": [
51 | {
52 | "output_type": "stream",
53 | "name": "stdout",
54 | "text": [
55 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m953.9/953.9 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
56 | "\u001b[?25h"
57 | ]
58 | }
59 | ],
60 | "source": [
61 | "!pip install -q gymnasium[classic_control]"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "source": [
67 | "import math\n",
68 | "import random\n",
69 | "from collections import namedtuple, deque\n",
70 | "from itertools import count\n",
71 | "\n",
72 | "import gymnasium as gym\n",
73 | "\n",
74 | "import torch\n",
75 | "import torch.nn as nn\n",
76 | "import torch.nn.functional as F\n",
77 | "import torch.optim as optim\n",
78 | "\n",
79 | "from IPython import display\n",
80 | "\n",
81 | "import numpy as np\n",
82 | "import matplotlib.pyplot as plt"
83 | ],
84 | "metadata": {
85 | "id": "S-VcJSlJ2hM3"
86 | },
87 | "execution_count": 2,
88 | "outputs": []
89 | },
90 | {
91 | "cell_type": "code",
92 | "source": [
93 | "env = gym.make('CartPole-v1')"
94 | ],
95 | "metadata": {
96 | "id": "d4JAWIO23I6V"
97 | },
98 | "execution_count": 3,
99 | "outputs": []
100 | },
101 | {
102 | "cell_type": "code",
103 | "source": [
104 | "BATCH_SIZE = 128\n",
105 | "NUM_EPISODES = 600\n",
106 | "\n",
107 | "GAMMA = .99 # Discount factor\n",
108 | "\n",
109 | "# epsilon-greedy parameters:\n",
110 | "EPS_START = .9\n",
111 | "EPS_END = .05\n",
112 | "EPS_DECAY = 1000\n",
113 | "\n",
114 | "TAU = 5e-3\n",
115 | "\n",
116 | "LR = 1e-4\n",
117 | "CLIP_VALUE = 100\n",
118 | "\n",
119 | "MEMORY_SIZE = 10000\n",
120 | "\n",
121 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'"
122 | ],
123 | "metadata": {
124 | "id": "-tf9lVbc3CsC"
125 | },
126 | "execution_count": 4,
127 | "outputs": []
128 | },
129 | {
130 | "cell_type": "code",
131 | "source": [
132 | "state, info = env.reset()\n",
133 | "N_OBSERVATIONS = len(state)\n",
134 | "N_ACTIONS = env.action_space.n"
135 | ],
136 | "metadata": {
137 | "id": "7LuXu5sU5qGW"
138 | },
139 | "execution_count": 5,
140 | "outputs": []
141 | },
142 | {
143 | "cell_type": "code",
144 | "source": [
145 | "plt.ion()"
146 | ],
147 | "metadata": {
148 | "colab": {
149 | "base_uri": "https://localhost:8080/"
150 | },
151 | "id": "9Py_Q-6j3M9n",
152 | "outputId": "f3bfeeed-f12b-44b3-fa5b-f7dc89c5af2c"
153 | },
154 | "execution_count": 6,
155 | "outputs": [
156 | {
157 | "output_type": "execute_result",
158 | "data": {
159 | "text/plain": [
160 | ""
161 | ]
162 | },
163 | "metadata": {},
164 | "execution_count": 6
165 | }
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "source": [
171 | "MemoryBlock = namedtuple('MemoryBlock', ('state', 'action', 'reward', 'next_state'))\n",
172 | "\n",
173 | "class ReplayMemory:\n",
174 | " def __init__(self, capacity):\n",
175 | " self.memory = deque(maxlen=capacity)\n",
176 | "\n",
177 | " def __len__(self):\n",
178 | " return len(self.memory)\n",
179 | "\n",
180 | " def push(self, state, action, reward, next_state):\n",
181 | " block = MemoryBlock(state, action, reward, next_state)\n",
182 | " self.memory.append(block)\n",
183 | "\n",
184 | " def sample(self, batch_size):\n",
185 | " return random.sample(self.memory, batch_size)"
186 | ],
187 | "metadata": {
188 | "id": "A-qtTzAV3ZY8"
189 | },
190 | "execution_count": 7,
191 | "outputs": []
192 | },
193 | {
194 | "cell_type": "code",
195 | "source": [
196 | "class PolicyNetwork(nn.Module):\n",
197 | " def __init__(self, n_observations, n_actions, latent_dim=128):\n",
198 | " super().__init__()\n",
199 | "\n",
200 | " self.layers = nn.Sequential(\n",
201 | " nn.Linear(n_observations, latent_dim),\n",
202 | " nn.ReLU(),\n",
203 | " nn.Linear(latent_dim, latent_dim),\n",
204 | " nn.ReLU(),\n",
205 | " nn.Linear(latent_dim, n_actions)\n",
206 | " )\n",
207 | "\n",
208 | " def forward(self, x):\n",
209 | " return self.layers(x)"
210 | ],
211 | "metadata": {
212 | "id": "3EQxRKlX4gPh"
213 | },
214 | "execution_count": 8,
215 | "outputs": []
216 | },
217 | {
218 | "cell_type": "code",
219 | "source": [
220 | "policy_net = PolicyNetwork(N_OBSERVATIONS, N_ACTIONS).to(DEVICE)\n",
221 | "target_net = PolicyNetwork(N_OBSERVATIONS, N_ACTIONS).to(DEVICE)\n",
222 | "\n",
223 | "target_net.load_state_dict(policy_net.state_dict())"
224 | ],
225 | "metadata": {
226 | "colab": {
227 | "base_uri": "https://localhost:8080/"
228 | },
229 | "id": "y3MogvRV5MmR",
230 | "outputId": "e3742ebb-08ad-4d5d-cb64-c9a0333e7f68"
231 | },
232 | "execution_count": 9,
233 | "outputs": [
234 | {
235 | "output_type": "execute_result",
236 | "data": {
237 | "text/plain": [
238 | ""
239 | ]
240 | },
241 | "metadata": {},
242 | "execution_count": 9
243 | }
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "source": [
249 | "optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)\n",
250 | "criterion = nn.SmoothL1Loss()"
251 | ],
252 | "metadata": {
253 | "id": "vZTxKNpG6IAT"
254 | },
255 | "execution_count": 10,
256 | "outputs": []
257 | },
258 | {
259 | "cell_type": "code",
260 | "source": [
261 | "memory = ReplayMemory(MEMORY_SIZE)"
262 | ],
263 | "metadata": {
264 | "id": "qK5Uinxk6Uxp"
265 | },
266 | "execution_count": 11,
267 | "outputs": []
268 | },
269 | {
270 | "cell_type": "code",
271 | "source": [
272 | "x = np.array(range(6000))\n",
273 | "\n",
274 | "eps = EPS_END + (EPS_START - EPS_END) * np.exp(-1 * x / EPS_DECAY)\n",
275 | "\n",
276 | "plt.plot(x, eps)\n",
277 | "plt.show()"
278 | ],
279 | "metadata": {
280 | "colab": {
281 | "base_uri": "https://localhost:8080/",
282 | "height": 430
283 | },
284 | "id": "ln4qZVoT7Dhi",
285 | "outputId": "66003bd4-5de0-453c-dff1-5f7174cb12f3"
286 | },
287 | "execution_count": 12,
288 | "outputs": [
289 | {
290 | "output_type": "display_data",
291 | "data": {
292 | "text/plain": [
293 | ""
294 | ],
295 | "image/png": "\n"
296 | },
297 | "metadata": {}
298 | }
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "source": [
304 | "steps_done = 0\n",
305 | "\n",
306 | "def select_action(state):\n",
307 | " global steps_done\n",
308 | "\n",
309 | " eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1 * steps_done / EPS_DECAY)\n",
310 | " steps_done += 1\n",
311 | "\n",
312 | " if random.random() > eps:\n",
313 | " with torch.no_grad():\n",
314 | " action = policy_net(state).max(1).indices.view(1, 1)\n",
315 | " else:\n",
316 | " action = torch.tensor([[env.action_space.sample()]], device=DEVICE, dtype=torch.long)\n",
317 | "\n",
318 | " return action"
319 | ],
320 | "metadata": {
321 | "id": "YijGFqvN6b6E"
322 | },
323 | "execution_count": 13,
324 | "outputs": []
325 | },
326 | {
327 | "cell_type": "code",
328 | "source": [
329 | "def optimize_model():\n",
330 | " if len(memory) < BATCH_SIZE:\n",
331 | " return\n",
332 | "\n",
333 | " optimizer.zero_grad()\n",
334 | "\n",
335 | " history = memory.sample(BATCH_SIZE)\n",
336 | " batch = MemoryBlock(*zip(*history))\n",
337 | "\n",
338 | " state_batch = torch.cat(batch.state)\n",
339 | " action_batch = torch.cat(batch.action)\n",
340 | " reward_batch = torch.cat(batch.reward)\n",
341 | "\n",
342 | " state_action_values = policy_net(state_batch).gather(1, action_batch)\n",
343 | "\n",
344 | " non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=DEVICE, dtype=torch.bool)\n",
345 | " non_final_next_state = torch.cat([s for s in batch.next_state if s is not None])\n",
346 | " next_state_values = torch.zeros(BATCH_SIZE, device=DEVICE)\n",
347 | "\n",
348 | " with torch.no_grad():\n",
349 | " next_state_values[non_final_mask] = target_net(non_final_next_state).max(1).values\n",
350 | "\n",
351 | " expected_state_action_values = (next_state_values * GAMMA) + reward_batch\n",
352 | " loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))\n",
353 | "\n",
354 | " loss.backward()\n",
355 | "\n",
356 | " nn.utils.clip_grad_value_(policy_net.parameters(), CLIP_VALUE)\n",
357 | " optimizer.step()"
358 | ],
359 | "metadata": {
360 | "id": "KtBs3WkB79yg"
361 | },
362 | "execution_count": 14,
363 | "outputs": []
364 | },
365 | {
366 | "cell_type": "code",
367 | "source": [
368 | "episode_durations = []"
369 | ],
370 | "metadata": {
371 | "id": "WsKcEq4u6_5T"
372 | },
373 | "execution_count": 15,
374 | "outputs": []
375 | },
376 | {
377 | "cell_type": "code",
378 | "source": [
379 | "for epoch in range(1, NUM_EPISODES + 1):\n",
380 | " state, _ = env.reset()\n",
381 | " state = torch.tensor(state, dtype=torch.float32, device=DEVICE).unsqueeze(0)\n",
382 | "\n",
383 | " for t in count():\n",
384 | " action = select_action(state)\n",
385 | "\n",
386 | " observation, reward, terminated, truncated, _ = env.step(action.item())\n",
387 | " reward = torch.tensor([reward], device=DEVICE)\n",
388 | "\n",
389 | " if terminated:\n",
390 | " next_state = None\n",
391 | " else:\n",
392 | " next_state = torch.tensor(observation, dtype=torch.float32, device=DEVICE).unsqueeze(0)\n",
393 | "\n",
394 | " memory.push(state, action, reward, next_state)\n",
395 | " state = next_state\n",
396 | "\n",
397 | " optimize_model()\n",
398 | "\n",
399 | " target_state_dict = target_net.state_dict()\n",
400 | " policy_state_dict = policy_net.state_dict()\n",
401 | "\n",
402 | " for key in policy_state_dict:\n",
403 | " target_state_dict[key] = policy_state_dict[key] * TAU + target_state_dict[key] * (1 - TAU)\n",
404 | "\n",
405 | " target_net.load_state_dict(target_state_dict)\n",
406 | "\n",
407 | " if terminated or truncated:\n",
408 | " episode_durations.append(t + 1)\n",
409 | "\n",
410 | " if epoch % 50 == 0:\n",
411 | " print(f'Episode {epoch} duration: {t + 1}')\n",
412 | "\n",
413 | " break"
414 | ],
415 | "metadata": {
416 | "colab": {
417 | "base_uri": "https://localhost:8080/"
418 | },
419 | "id": "wO7NprE-_TQa",
420 | "outputId": "8ab15210-408f-452b-f081-926f7d4679e3"
421 | },
422 | "execution_count": 16,
423 | "outputs": [
424 | {
425 | "output_type": "stream",
426 | "name": "stdout",
427 | "text": [
428 | "Episode 50 duration: 10\n",
429 | "Episode 100 duration: 17\n",
430 | "Episode 150 duration: 8\n",
431 | "Episode 200 duration: 74\n",
432 | "Episode 250 duration: 119\n",
433 | "Episode 300 duration: 120\n",
434 | "Episode 350 duration: 145\n",
435 | "Episode 400 duration: 181\n",
436 | "Episode 450 duration: 318\n",
437 | "Episode 500 duration: 500\n",
438 | "Episode 550 duration: 500\n",
439 | "Episode 600 duration: 500\n"
440 | ]
441 | }
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "source": [
447 | "plt.plot(episode_durations)\n",
448 | "\n",
449 | "plt.show()"
450 | ],
451 | "metadata": {
452 | "colab": {
453 | "base_uri": "https://localhost:8080/",
454 | "height": 430
455 | },
456 | "id": "s7t56VlbBr60",
457 | "outputId": "227e4320-e867-4da1-aedc-450b397db796"
458 | },
459 | "execution_count": 17,
460 | "outputs": [
461 | {
462 | "output_type": "display_data",
463 | "data": {
464 | "text/plain": [
465 | ""
466 | ],
467 | "image/png": "\n"
468 | },
469 | "metadata": {}
470 | }
471 | ]
472 | },
473 | {
474 | "cell_type": "code",
475 | "source": [],
476 | "metadata": {
477 | "id": "SkiYqlimB-lB"
478 | },
479 | "execution_count": 17,
480 | "outputs": []
481 | }
482 | ]
483 | }
--------------------------------------------------------------------------------
/5_Pytorch_Introdution.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyMa4UQgnjr1l61FgngsPT2O",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | " "
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "metadata": {
33 | "id": "wqWzgaihwob-"
34 | },
35 | "outputs": [],
36 | "source": [
37 | "import torch\n",
38 | "import torch.nn as nn\n",
39 | "import torch.nn.functional as F\n",
40 | "import torch.optim as optim"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "source": [
46 | "scaler = torch.tensor([1.])"
47 | ],
48 | "metadata": {
49 | "id": "fSQwTeFAxRPu"
50 | },
51 | "execution_count": 4,
52 | "outputs": []
53 | },
54 | {
55 | "cell_type": "code",
56 | "source": [
57 | "scaler"
58 | ],
59 | "metadata": {
60 | "colab": {
61 | "base_uri": "https://localhost:8080/"
62 | },
63 | "id": "JbmosPEXxwBI",
64 | "outputId": "a4a5fcdc-a0b6-4c41-f084-7d18cdd0525b"
65 | },
66 | "execution_count": 5,
67 | "outputs": [
68 | {
69 | "output_type": "execute_result",
70 | "data": {
71 | "text/plain": [
72 | "tensor([1.])"
73 | ]
74 | },
75 | "metadata": {},
76 | "execution_count": 5
77 | }
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "source": [
83 | "scaler.ndim"
84 | ],
85 | "metadata": {
86 | "colab": {
87 | "base_uri": "https://localhost:8080/"
88 | },
89 | "id": "HjeDdpDOxy11",
90 | "outputId": "94029959-2384-4e69-c806-971e23e97bf1"
91 | },
92 | "execution_count": 6,
93 | "outputs": [
94 | {
95 | "output_type": "execute_result",
96 | "data": {
97 | "text/plain": [
98 | "1"
99 | ]
100 | },
101 | "metadata": {},
102 | "execution_count": 6
103 | }
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "source": [
109 | "scaler.dtype"
110 | ],
111 | "metadata": {
112 | "colab": {
113 | "base_uri": "https://localhost:8080/"
114 | },
115 | "id": "87I5lTcLx0D1",
116 | "outputId": "ae93d0d4-db54-4f30-99c4-edafb9d0d47e"
117 | },
118 | "execution_count": 7,
119 | "outputs": [
120 | {
121 | "output_type": "execute_result",
122 | "data": {
123 | "text/plain": [
124 | "torch.float32"
125 | ]
126 | },
127 | "metadata": {},
128 | "execution_count": 7
129 | }
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "source": [
135 | "scaler + 2"
136 | ],
137 | "metadata": {
138 | "colab": {
139 | "base_uri": "https://localhost:8080/"
140 | },
141 | "id": "3HELsLg5x1Jl",
142 | "outputId": "dc28f56f-5cb3-422a-c930-d0c17fadd0f4"
143 | },
144 | "execution_count": 8,
145 | "outputs": [
146 | {
147 | "output_type": "execute_result",
148 | "data": {
149 | "text/plain": [
150 | "tensor([3.])"
151 | ]
152 | },
153 | "metadata": {},
154 | "execution_count": 8
155 | }
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "source": [
161 | "scaler.requires_grad"
162 | ],
163 | "metadata": {
164 | "colab": {
165 | "base_uri": "https://localhost:8080/"
166 | },
167 | "id": "JQSUrajux7ZN",
168 | "outputId": "376f1b1c-e46b-49a8-b41f-9c1580feff32"
169 | },
170 | "execution_count": 9,
171 | "outputs": [
172 | {
173 | "output_type": "execute_result",
174 | "data": {
175 | "text/plain": [
176 | "False"
177 | ]
178 | },
179 | "metadata": {},
180 | "execution_count": 9
181 | }
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "source": [
187 | "scaler.device"
188 | ],
189 | "metadata": {
190 | "colab": {
191 | "base_uri": "https://localhost:8080/"
192 | },
193 | "id": "wWIFoa5xyGCB",
194 | "outputId": "d702ab3d-9f86-423d-ce53-f58da116e6fd"
195 | },
196 | "execution_count": 10,
197 | "outputs": [
198 | {
199 | "output_type": "execute_result",
200 | "data": {
201 | "text/plain": [
202 | "device(type='cpu')"
203 | ]
204 | },
205 | "metadata": {},
206 | "execution_count": 10
207 | }
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "source": [
213 | "torch.cuda.is_available()"
214 | ],
215 | "metadata": {
216 | "colab": {
217 | "base_uri": "https://localhost:8080/"
218 | },
219 | "id": "kTm0UlECzZCp",
220 | "outputId": "ebf469e0-95c2-4e89-ce60-1290abdd3b79"
221 | },
222 | "execution_count": 15,
223 | "outputs": [
224 | {
225 | "output_type": "execute_result",
226 | "data": {
227 | "text/plain": [
228 | "False"
229 | ]
230 | },
231 | "metadata": {},
232 | "execution_count": 15
233 | }
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "source": [
239 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
240 | ],
241 | "metadata": {
242 | "id": "l1lqFQnVyNXF"
243 | },
244 | "execution_count": 12,
245 | "outputs": []
246 | },
247 | {
248 | "cell_type": "code",
249 | "source": [
250 | "device"
251 | ],
252 | "metadata": {
253 | "colab": {
254 | "base_uri": "https://localhost:8080/",
255 | "height": 36
256 | },
257 | "id": "XZLSWKHmyW4I",
258 | "outputId": "d09aa7bf-a53c-44cc-9bd2-fd129e2fa0ec"
259 | },
260 | "execution_count": 13,
261 | "outputs": [
262 | {
263 | "output_type": "execute_result",
264 | "data": {
265 | "text/plain": [
266 | "'cpu'"
267 | ],
268 | "application/vnd.google.colaboratory.intrinsic+json": {
269 | "type": "string"
270 | }
271 | },
272 | "metadata": {},
273 | "execution_count": 13
274 | }
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "source": [
280 | "scaler = scaler.to(device)"
281 | ],
282 | "metadata": {
283 | "id": "F5Ekj_ITydS-"
284 | },
285 | "execution_count": 14,
286 | "outputs": []
287 | },
288 | {
289 | "cell_type": "code",
290 | "source": [
291 | "class Network(nn.Module):\n",
292 | "    \"\"\"Fully connected net: 2 inputs -> three 10-unit ReLU layers -> 1 output.\"\"\"\n",
293 | "\n",
294 | "    def __init__(self):\n",
295 | "        super().__init__()\n",
296 | "        self.fc1 = nn.Linear(in_features=2, out_features=10)\n",
297 | "        self.fc2 = nn.Linear(in_features=10, out_features=10)\n",
298 | "        self.fc3 = nn.Linear(in_features=10, out_features=10)\n",
299 | "        self.fc4 = nn.Linear(in_features=10, out_features=1)\n",
300 | "\n",
301 | "    def forward(self, x):\n",
302 | "        x = F.relu(self.fc1(x))\n",
303 | "        x = F.relu(self.fc2(x))\n",
304 | "        x = F.relu(self.fc3(x))\n",
305 | "        # No activation on the output layer: fc4's raw linear values are returned.\n",
306 | "        return self.fc4(x)"
307 | ],
308 | "metadata": {
309 | "id": "tC8YnzIwyf6n"
310 | },
311 | "execution_count": 22,
312 | "outputs": []
313 | },
314 | {
315 | "cell_type": "code",
316 | "source": [
317 | "net = Network().to(device)"
318 | ],
319 | "metadata": {
320 | "id": "jmt6SYAZ0wPa"
321 | },
322 | "execution_count": 23,
323 | "outputs": []
324 | },
325 | {
326 | "cell_type": "code",
327 | "source": [
328 | "test_input = torch.rand((3, 2), device=device)  # create on the same device as net\n",
329 | "test_input"
330 | ],
331 | "metadata": {
332 | "colab": {
333 | "base_uri": "https://localhost:8080/"
334 | },
335 | "id": "dT8GYeRl00ZX",
336 | "outputId": "9990a297-41e1-4512-c15f-9f29e2b0e1bc"
337 | },
338 | "execution_count": 24,
339 | "outputs": [
340 | {
341 | "output_type": "execute_result",
342 | "data": {
343 | "text/plain": [
344 | "tensor([[0.7418, 0.8436],\n",
345 | " [0.8179, 0.3327],\n",
346 | " [0.5869, 0.7119]])"
347 | ]
348 | },
349 | "metadata": {},
350 | "execution_count": 24
351 | }
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "source": [
357 | "test_output = net(test_input)\n",
358 | "test_output"
359 | ],
360 | "metadata": {
361 | "colab": {
362 | "base_uri": "https://localhost:8080/"
363 | },
364 | "id": "D0BmAZwc0_rJ",
365 | "outputId": "6aa5025c-71ad-4de5-d711-190eb2f76ada"
366 | },
367 | "execution_count": 25,
368 | "outputs": [
369 | {
370 | "output_type": "execute_result",
371 | "data": {
372 | "text/plain": [
373 | "tensor([[0.2814],\n",
374 | " [0.2967],\n",
375 | " [0.2830]], grad_fn=<AddmmBackward0>)"
376 | ]
377 | },
378 | "metadata": {},
379 | "execution_count": 25
380 | }
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "source": [
386 | "net"
387 | ],
388 | "metadata": {
389 | "colab": {
390 | "base_uri": "https://localhost:8080/"
391 | },
392 | "id": "x6_FwU7R1E8B",
393 | "outputId": "9b217aa5-2416-45d6-96de-88f4dfe9511f"
394 | },
395 | "execution_count": 26,
396 | "outputs": [
397 | {
398 | "output_type": "execute_result",
399 | "data": {
400 | "text/plain": [
401 | "Network(\n",
402 | " (fc1): Linear(in_features=2, out_features=10, bias=True)\n",
403 | " (fc2): Linear(in_features=10, out_features=10, bias=True)\n",
404 | " (fc3): Linear(in_features=10, out_features=10, bias=True)\n",
405 | " (fc4): Linear(in_features=10, out_features=1, bias=True)\n",
406 | ")"
407 | ]
408 | },
409 | "metadata": {},
410 | "execution_count": 26
411 | }
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "source": [
417 | "list(net.fc1.parameters())"
418 | ],
419 | "metadata": {
420 | "colab": {
421 | "base_uri": "https://localhost:8080/"
422 | },
423 | "id": "0Zidy5vm1McJ",
424 | "outputId": "b5fbf595-8a8a-458f-e533-f6e1601ac526"
425 | },
426 | "execution_count": 30,
427 | "outputs": [
428 | {
429 | "output_type": "execute_result",
430 | "data": {
431 | "text/plain": [
432 | "[Parameter containing:\n",
433 | " tensor([[-0.0678, 0.4129],\n",
434 | " [-0.2128, 0.5182],\n",
435 | " [ 0.3328, -0.5278],\n",
436 | " [-0.0189, 0.6389],\n",
437 | " [ 0.3837, 0.2405],\n",
438 | " [-0.2676, -0.2395],\n",
439 | " [ 0.2778, 0.0328],\n",
440 | " [-0.4963, 0.5281],\n",
441 | " [ 0.1273, -0.0240],\n",
442 | " [-0.4277, -0.5793]], requires_grad=True),\n",
443 | " Parameter containing:\n",
444 | " tensor([-0.5801, 0.0506, 0.2750, -0.6558, -0.2582, -0.3598, -0.5241, -0.0149,\n",
445 | " -0.2104, -0.5190], requires_grad=True)]"
446 | ]
447 | },
448 | "metadata": {},
449 | "execution_count": 30
450 | }
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "source": [],
456 | "metadata": {
457 | "id": "_Ai44HNl1PgN"
458 | },
459 | "execution_count": null,
460 | "outputs": []
461 | }
462 | ]
463 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # deep_learning_class_notebooks
2 |
--------------------------------------------------------------------------------
|