"
61 | ]
62 | },
63 | "execution_count": 5,
64 | "metadata": {
65 | "tags": []
66 | },
67 | "output_type": "execute_result"
68 | }
69 | ],
70 | "source": [
71 | "from pyvirtualdisplay import Display\n",
72 | "\n",
73 | "display = Display(visible=0, size=(1400, 900))\n",
74 | "display.start()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {
81 | "id": "RfCA1XwgF7R6"
82 | },
83 | "outputs": [],
84 | "source": [
85 | "# Ten kod tworzy wirtualny ekran, aby toczyć grę na nim\n",
86 | "# Jeśli uruchamiasz lokalnie, zignoruj to\n",
87 | "import os\n",
88 | "if type(os.environ.get('DISPLAY')) is not str or len(os.environ.get('DISPLAY')) == 0:\n",
89 | " !bash ../xvfb start\n",
90 | " %env DISPLAY=:1"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {
97 | "id": "dVCvYQcQGs1T",
98 | "colab": {
99 | "base_uri": "https://localhost:8080/",
100 | "height": 64.0
101 | },
102 | "outputId": "da6aea81-5d5e-4575-aae6-f9b40d883dda"
103 | },
104 | "outputs": [
105 | {
106 | "data": {
107 | "text/html": [
108 | "\n",
109 | "The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x. \n",
110 | "We recommend you upgrade now \n",
111 | "or ensure your notebook will continue to use TensorFlow 1.x via the %tensorflow_version 1.x magic:\n",
112 | "more info .
\n"
113 | ],
114 | "text/plain": [
115 | ""
116 | ]
117 | },
118 | "metadata": {
119 | "tags": []
120 | },
121 | "output_type": "execute_result"
122 | }
123 | ],
124 | "source": [
125 | "import gym\n",
126 | "from gym import logger as gymlogger\n",
127 | "from gym.wrappers import Monitor\n",
128 | "gymlogger.set_level(40)\n",
129 | "import tensorflow as tf\n",
130 | "import numpy as np\n",
131 | "import random\n",
132 | "import matplotlib\n",
133 | "import matplotlib.pyplot as plt\n",
134 | "%matplotlib inline\n",
135 | "import math\n",
136 | "import glob\n",
137 | "import io\n",
138 | "import base64\n",
139 | "from IPython.display import HTML\n",
140 | "\n",
141 | "from IPython import display as ipythondisplay"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {
148 | "id": "nTIYTvAQIJtr"
149 | },
150 | "outputs": [],
151 | "source": [
152 | "\"\"\"\n",
153 | "Użyteczne funkcje umożliwiające nagranie środowiska gym i wyświetlenie go\n",
154 | "Aby zezwolić na wideo wystarczy \"env = wrap_env\"\n",
155 | "\"\"\"\n",
156 | "\n",
157 | "def show_video():\n",
158 | " mp4list = glob.glob('video/*.mp4')\n",
159 | " if len(mp4list) > 0:\n",
160 | " mp4 = mp4list[0]\n",
161 | " video = io.open(mp4, 'r+b').read()\n",
162 | " encoded = base64.b64encode(video)\n",
163 | " ipythondisplay.display(HTML(data='''\n",
164 | " \n",
165 | " '''.format(encoded.decode('ascii'))))\n",
166 | " else:\n",
167 | " print('Could not find video')\n",
168 | "\n",
169 | "def wrap_env(env):\n",
170 | " env = Monitor(env, './video', force=True)\n",
171 | " return env"
172 | ]
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "metadata": {
177 | "id": "azBUgbnCUgv6"
178 | },
179 | "source": [
180 | "## CartPole"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "metadata": {
187 | "id": "cijHHKS0UmJn",
188 | "colab": {
189 | "base_uri": "https://localhost:8080/",
190 | "height": 139.0
191 | },
192 | "outputId": "6caa9eb5-457d-42d1-88bd-d87680c7fa20"
193 | },
194 | "outputs": [
195 | {
196 | "name": "stdout",
197 | "output_type": "stream",
198 | "text": [
199 | "Observation space: Box(4,)\n",
200 | "Action space: Discrete(2)\n",
201 | "Initial observation: [0.01463859 0.04037087 0.04288382 0.03510227]\n",
202 | "Next observation: [ 0.01544601 0.23485245 0.04358586 -0.24374792]\n",
203 | "Reward: 1.0\n",
204 | "Done: False\n",
205 | "Info: {}\n"
206 | ]
207 | }
208 | ],
209 | "source": [
210 | "import gym \n",
211 | "env = gym.make('CartPole-v0')\n",
212 | "env = wrap_env(env)\n",
213 | "\n",
214 | "print('Observation space: ', env.observation_space)\n",
215 | "print('Action space: ', env.action_space)\n",
216 | "\n",
217 | "obs = env.reset()\n",
218 | "\n",
219 | "print('Initial observation: ', obs)\n",
220 | "\n",
221 | "action = env.action_space.sample() # podujmuje losową akcję\n",
222 | "\n",
223 | "obs, r, done, info = env.step(action)\n",
224 | "print('Next observation: ', obs)\n",
225 | "print('Reward: ', r)\n",
226 | "print('Done: ', done)\n",
227 | "print('Info: ', info)"
228 | ]
229 | },
230 | {
231 | "cell_type": "markdown",
232 | "metadata": {
233 | "id": "4MEuGxqqWGzc"
234 | },
235 | "source": [
236 | "### Wyświetlenie wideo"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "metadata": {
243 | "id": "grah6EfMWJmU",
244 | "colab": {
245 | "base_uri": "https://localhost:8080/",
246 | "height": 976.0
247 | },
248 | "outputId": "7e22d1b6-c795-4c37-bb4e-508baac42e35"
249 | },
250 | "outputs": [
251 | {
252 | "name": "stdout",
253 | "output_type": "stream",
254 | "text": [
255 | "1.0\n",
256 | "1.0\n",
257 | "1.0\n",
258 | "1.0\n",
259 | "1.0\n",
260 | "1.0\n",
261 | "1.0\n",
262 | "1.0\n",
263 | "1.0\n",
264 | "1.0\n",
265 | "1.0\n",
266 | "1.0\n",
267 | "1.0\n",
268 | "1.0\n",
269 | "1.0\n",
270 | "1.0\n",
271 | "1.0\n",
272 | "1.0\n",
273 | "1.0\n",
274 | "1.0\n",
275 | "1.0\n",
276 | "1.0\n",
277 | "1.0\n",
278 | "1.0\n",
279 | "1.0\n",
280 | "1.0\n",
281 | "1.0\n",
282 | "1.0\n",
283 | "1.0\n",
284 | "1.0\n",
285 | "1.0\n",
286 | "1.0\n"
287 | ]
288 | },
289 | {
290 | "data": {
291 | "text/html": [
292 | "\n",
293 | " \n",
294 | " "
295 | ],
296 | "text/plain": [
297 | ""
298 | ]
299 | },
300 | "metadata": {
301 | "tags": []
302 | },
303 | "output_type": "execute_result"
304 | }
305 | ],
306 | "source": [
307 | "'''CartPole z użyciem losowej akcji'''\n",
308 | "import gym\n",
309 | "env = gym.make('CartPole-v0')\n",
310 | "env = wrap_env(env)\n",
311 | "\n",
312 | "observation = env.reset()\n",
313 | "\n",
314 | "while True:\n",
315 | " env.render()\n",
316 | "\n",
317 | " action = env.action_space.sample() # podujmuje losową akcję\n",
318 | " observation, reward, done, info = env.step(action)\n",
319 | " print(reward)\n",
320 | "\n",
321 | " if done:\n",
322 | " break;\n",
323 | "\n",
324 | "env.close()\n",
325 | "show_video()"
326 | ]
327 | }
328 | ],
329 | "metadata": {
330 | "colab": {
331 | "name": "cart_pole.ipynb",
332 | "provenance": [],
333 | "authorship_tag": "ABX9TyNNnZD5dGWsLcj8Q3zwW7oa",
334 | "include_colab_link": true
335 | },
336 | "kernelspec": {
337 | "name": "python3",
338 | "display_name": "Python 3"
339 | },
340 | "accelerator": "GPU"
341 | },
342 | "nbformat": 4,
343 | "nbformat_minor": 0
344 | }
345 |
--------------------------------------------------------------------------------
/machine_learning/reinforcement_learning/frozen_lake/frozen_lake.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | " "
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "zp6ERkCfJaDf",
17 | "colab_type": "text"
18 | },
19 | "source": [
20 | "# Q-Learning\n",
21 | "### Podstawowe funkcje"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 0,
27 | "metadata": {
28 | "id": "BCFZaqgSI37U",
29 | "colab_type": "code",
30 | "colab": {}
31 | },
32 | "outputs": [],
33 | "source": [
34 | "import gym\n",
35 | "\n",
36 | "env = gym.make('FrozenLake-v0') # będziemy używać środowiska FrozenLake"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 3,
42 | "metadata": {
43 | "id": "5_iAxJjVJ5D8",
44 | "colab_type": "code",
45 | "colab": {
46 | "base_uri": "https://localhost:8080/",
47 | "height": 52.0
48 | },
49 | "outputId": "a3a76324-911f-4dac-c0d4-c73b9bcae934"
50 | },
51 | "outputs": [
52 | {
53 | "name": "stdout",
54 | "output_type": "stream",
55 | "text": [
56 | "16\n",
57 | "4\n"
58 | ]
59 | }
60 | ],
61 | "source": [
62 | "print(env.observation_space.n) # zwraca liczbę stanów\n",
63 | "print(env.action_space.n) # zwraca liczbę akcji"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 4,
69 | "metadata": {
70 | "id": "rSI2CxbMKSWE",
71 | "colab_type": "code",
72 | "colab": {
73 | "base_uri": "https://localhost:8080/",
74 | "height": 35.0
75 | },
76 | "outputId": "23000dd3-2562-42bb-9c83-eeef874df311"
77 | },
78 | "outputs": [
79 | {
80 | "data": {
81 | "text/plain": [
82 | "0"
83 | ]
84 | },
85 | "execution_count": 4,
86 | "metadata": {
87 | "tags": []
88 | },
89 | "output_type": "execute_result"
90 | }
91 | ],
92 | "source": [
93 | "env.reset() # resetuje środowisko do stanu domyślnego"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 9,
99 | "metadata": {
100 | "id": "qtnJV9NJKc_L",
101 | "colab_type": "code",
102 | "colab": {
103 | "base_uri": "https://localhost:8080/",
104 | "height": 35.0
105 | },
106 | "outputId": "0b774ece-a43d-4209-8b3a-506d3c916fe1"
107 | },
108 | "outputs": [
109 | {
110 | "name": "stdout",
111 | "output_type": "stream",
112 | "text": [
113 | "2\n"
114 | ]
115 | }
116 | ],
117 | "source": [
118 | "action = env.action_space.sample() # zwraca losową akcję\n",
119 | "print(action)"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 0,
125 | "metadata": {
126 | "id": "pHe8tnxtLhb6",
127 | "colab_type": "code",
128 | "colab": {}
129 | },
130 | "outputs": [],
131 | "source": [
132 | "new_state, reward, done, info = env.step(action) # podejmuje akcję"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 11,
138 | "metadata": {
139 | "id": "iqjEeQASKwr0",
140 | "colab_type": "code",
141 | "colab": {
142 | "base_uri": "https://localhost:8080/",
143 | "height": 104.0
144 | },
145 | "outputId": "32035c87-4788-4756-daa8-4539dd6b656d"
146 | },
147 | "outputs": [
148 | {
149 | "name": "stdout",
150 | "output_type": "stream",
151 | "text": [
152 | " (Up)\n",
153 | "S\u001b[41mF\u001b[0mFF\n",
154 | "FHFH\n",
155 | "FFFH\n",
156 | "HFFG\n"
157 | ]
158 | }
159 | ],
160 | "source": [
161 | "env.render() # renderuje GUI środowiska"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {
167 | "id": "X0E9UWKaMScI",
168 | "colab_type": "text"
169 | },
170 | "source": [
171 | "### Środowisko FrozenLake\n",
172 | "\n",
173 | "`Frozenlake-v0` to jedno z najprostszych środowisk w Open AI Gym. Celem jest nawigowanie agenta po zamarzniętym jeziorze bez wpadnięcia do wody. Jest tu:\n",
174 | "\n",
175 | "* 16 stanów (jeden dla każdego pola)\n",
176 | "* 4 możliwe akcje (LEFT, RIGHT, DOWN, UP)\n",
177 | "* 4 różne typy pól (F: frozen, H: hole, S: start, G: goal)\n",
178 | "\n",
179 | "### Budowa Q-Table\n",
180 | "\n",
181 | "Pierwszą rzeczą jakiej potrzebujemy jest budowa pustej Q-tabeli, której możemy użyć do przechowywania i uaktualniania naszych wartości."
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 0,
187 | "metadata": {
188 | "id": "HziaZ5EKPFUD",
189 | "colab_type": "code",
190 | "colab": {}
191 | },
192 | "outputs": [],
193 | "source": [
194 | "# import gym\n",
195 | "import numpy as np\n",
196 | "import time\n",
197 | "\n",
198 | "env = gym.make('FrozenLake-v0')\n",
199 | "STATES = env.observation_space.n\n",
200 | "ACTIONS = env.action_space.n"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 14,
206 | "metadata": {
207 | "id": "rLe4nK01QB2C",
208 | "colab_type": "code",
209 | "colab": {
210 | "base_uri": "https://localhost:8080/",
211 | "height": 295.0
212 | },
213 | "outputId": "22166874-9c8b-46e0-b6dd-1b5a53526e4a"
214 | },
215 | "outputs": [
216 | {
217 | "data": {
218 | "text/plain": [
219 | "array([[0., 0., 0., 0.],\n",
220 | " [0., 0., 0., 0.],\n",
221 | " [0., 0., 0., 0.],\n",
222 | " [0., 0., 0., 0.],\n",
223 | " [0., 0., 0., 0.],\n",
224 | " [0., 0., 0., 0.],\n",
225 | " [0., 0., 0., 0.],\n",
226 | " [0., 0., 0., 0.],\n",
227 | " [0., 0., 0., 0.],\n",
228 | " [0., 0., 0., 0.],\n",
229 | " [0., 0., 0., 0.],\n",
230 | " [0., 0., 0., 0.],\n",
231 | " [0., 0., 0., 0.],\n",
232 | " [0., 0., 0., 0.],\n",
233 | " [0., 0., 0., 0.],\n",
234 | " [0., 0., 0., 0.]])"
235 | ]
236 | },
237 | "execution_count": 14,
238 | "metadata": {
239 | "tags": []
240 | },
241 | "output_type": "execute_result"
242 | }
243 | ],
244 | "source": [
245 | "Q = np.zeros((STATES, ACTIONS)) # stworzenie macierzy zer\n",
246 | "Q"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {
252 | "id": "0iq-HT3gQcHh",
253 | "colab_type": "text"
254 | },
255 | "source": [
256 | "### Stałe\n",
257 | "Musimy zdefinować stałe, które będą użyte do aktualizowania Q-tabeli i powiedzą agentowi kiedy przerwać trening."
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 0,
263 | "metadata": {
264 | "id": "H2SDz6x6QWwZ",
265 | "colab_type": "code",
266 | "colab": {}
267 | },
268 | "outputs": [],
269 | "source": [
270 | "EPISODES = 2000 # ile razy odpalić środowisko od początku\n",
271 | "MAX_STEPS = 100 # maksymalna ilość kroków dozwolonych na każde uruchomienie środowiska\n",
272 | "\n",
273 | "LEARNING_RATE = 0.81 # współczynnik uczenia\n",
274 | "GAMMA = 0.96"
275 | ]
276 | },
277 | {
278 | "cell_type": "markdown",
279 | "metadata": {
280 | "id": "X-9dvyenSKsz",
281 | "colab_type": "text"
282 | },
283 | "source": [
284 | "### Podjęcie akcji\n",
285 | "Możemy podjąć akcję używając jednej z dwóch metod:\n",
286 | "\n",
287 | "1. Wybierając losowo dozwoloną akcję\n",
288 | "2. Używając obecnej Q-tabeli do znalezienia najlepszej akcji"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 0,
294 | "metadata": {
295 | "id": "82x4y9KSQUy4",
296 | "colab_type": "code",
297 | "colab": {}
298 | },
299 | "outputs": [],
300 | "source": [
301 | "epsilon = 0.9 # zaczynamy z 90% szans na podjęcie losowej akcji\n",
302 | "\n",
303 | "# kod do podjęcia akcji\n",
304 | "if np.random.uniform(0, 1) < epsilon: # sprawdza czy losowo wybrana wartość jest mniejsza niż epsilon\n",
305 | " action = env.action_space.sample() # podejmuje losową akcję\n",
306 | "else:\n",
307 | " action = np.argmax(Q[state, :]) # używa Q-tabeli do podjęcia najlepszej akcji bazując na obecnych wartościach"
308 | ]
309 | },
310 | {
311 | "cell_type": "markdown",
312 | "metadata": {
313 | "id": "O5KZgya-Ugny",
314 | "colab_type": "text"
315 | },
316 | "source": [
317 | "### Aktualizacja wartości Q"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": 0,
323 | "metadata": {
324 | "id": "Xv6-B2sRUoMI",
325 | "colab_type": "code",
326 | "colab": {}
327 | },
328 | "outputs": [],
329 | "source": [
330 | "#Q[state, action] = Q[state, action] + LEARNING_RATE * (reward + GAMMA * np.max(Q[new_state, :]) - Q[state, action])"
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {
336 | "id": "gOQV0406WE7P",
337 | "colab_type": "text"
338 | },
339 | "source": [
340 | "### Gotowy program złożony w całość"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 0,
346 | "metadata": {
347 | "id": "YieDu-0UWl2e",
348 | "colab_type": "code",
349 | "colab": {}
350 | },
351 | "outputs": [],
352 | "source": [
353 | "import gym\n",
354 | "import numpy as np\n",
355 | "import time\n",
356 | "\n",
357 | "env = gym.make('FrozenLake-v0')\n",
358 | "STATES = env.observation_space.n\n",
359 | "ACTIONS = env.action_space.n\n",
360 | "\n",
361 | "Q = np.zeros((STATES, ACTIONS)) \n",
362 | "\n",
363 | "EPISODES = 1500 # ile razy odpalić środowisko od początku\n",
364 | "MAX_STEPS = 100 # maksymalna ilość kroków dozwolonych na każde uruchomienie środowiska\n",
365 | "\n",
366 | "LEARNING_RATE = 0.81 # współczynnik uczenia\n",
367 | "GAMMA = 0.96\n",
368 | "\n",
369 | "RENDER = False # jeśli chcesz zobaczyć trening ustaw na True\n",
370 | "\n",
371 | "epsilon = 0.9"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": 24,
377 | "metadata": {
378 | "id": "LjQ6BcGtXJb7",
379 | "colab_type": "code",
380 | "colab": {
381 | "base_uri": "https://localhost:8080/",
382 | "height": 312.0
383 | },
384 | "outputId": "4e8203a5-d064-4eb6-ca10-517dc865d395"
385 | },
386 | "outputs": [
387 | {
388 | "name": "stdout",
389 | "output_type": "stream",
390 | "text": [
391 | "[[2.39426868e-01 1.62526398e-02 1.54562519e-02 1.59202830e-02]\n",
392 | " [1.98547854e-03 6.78832722e-03 2.04410091e-03 2.05880712e-01]\n",
393 | " [1.17939011e-01 6.88565988e-03 5.78809926e-03 6.92566173e-03]\n",
394 | " [6.13191495e-03 3.05238706e-03 2.96899524e-03 6.79343079e-03]\n",
395 | " [2.33621805e-01 1.00620775e-02 7.89787111e-03 1.35511328e-02]\n",
396 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
397 | " [7.19855689e-02 1.73767651e-04 1.74687284e-04 1.28973108e-04]\n",
398 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
399 | " [3.44127917e-03 4.78581497e-03 3.20147771e-03 3.35419577e-01]\n",
400 | " [3.29275604e-03 8.13401945e-01 1.53087650e-02 4.71493137e-03]\n",
401 | " [1.85896330e-01 2.24225014e-03 1.32615571e-03 2.17757207e-03]\n",
402 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
403 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
404 | " [6.86168854e-02 6.01188267e-02 6.41986796e-01 4.52771138e-03]\n",
405 | " [2.39216419e-01 4.87311938e-01 1.45339891e-01 2.17005939e-01]\n",
406 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n",
407 | "Average reward: 0.2986666666666667:\n"
408 | ]
409 | }
410 | ],
411 | "source": [
412 | "rewards = [] \n",
413 | "for episode in range(EPISODES):\n",
414 | "\n",
415 | " state = env.reset()\n",
416 | " for _ in range(MAX_STEPS):\n",
417 | "\n",
418 | " if RENDER:\n",
419 | " env.render()\n",
420 | "\n",
421 | " if np.random.uniform(0, 1) < epsilon:\n",
422 | " action = env.action_space.sample()\n",
423 | " else:\n",
424 | " action = np.argmax(Q[state, :])\n",
425 | "\n",
426 | " next_state, reward, done, _ = env.step(action)\n",
427 | "\n",
428 | " Q[state, action] = Q[state, action] + LEARNING_RATE * (reward + GAMMA * np.max(Q[next_state, :]) - Q[state, action])\n",
429 | "\n",
430 | " state = next_state\n",
431 | "\n",
432 | " if done:\n",
433 | " rewards.append(reward)\n",
434 | " epsilon -= 0.001\n",
435 | " break # reached goal\n",
436 | "\n",
437 | "print(Q)\n",
438 | "print(f'Average reward: {sum(rewards)/len(rewards)}:')\n",
439 | "# teraz możemy zobaczyć nasze wartośći Q!"
440 | ]
441 | },
442 | {
443 | "cell_type": "code",
444 | "execution_count": 26,
445 | "metadata": {
446 | "id": "eVZ8HfGQZy1y",
447 | "colab_type": "code",
448 | "colab": {
449 | "base_uri": "https://localhost:8080/",
450 | "height": 279.0
451 | },
452 | "outputId": "1e767d37-5342-4f58-d391-8b0e98af65ed"
453 | },
454 | "outputs": [
455 | {
456 | "data": {
457 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXgV9dn/8fdNWMKasIR938WFxYh7\nVVCK1UpbtYLaRdtaqxSr/mqxtbbS9rFWq9Vqn7rbRaSKqNjiVrBafdSC7BBARJSwJSAJaxKS3L8/\nzoQeQ0JO4Ewmyfm8ritXzsyZM/NJrmTumfnOfL/m7oiISOpqEnUAERGJlgqBiEiKUyEQEUlxKgQi\nIilOhUBEJMU1jTpAbXXq1Mn79u0bdQwRkQbl/fff3+buWVW91+AKQd++fVmwYEHUMUREGhQz+7i6\n93RpSEQkxakQiIikOBUCEZEUp0IgIpLiVAhERFKcCoGISIpTIRARSXEqBCIiSeLuvLRsM0tzC6KO\nUiuhFgIzG29mq81srZlNreL93mb2upktMrOlZvaFMPOIiIRlS2ERVz4xn+89uZCvPfofNhbsizpS\nwkIrBGaWBjwAnAsMAyaZ2bBKi90CPO3uI4GJwB/CyiMiEgZ3Z9bCXMbd8wbvrNvOD84eRFm5M+Wp\nRewvK486XkLC7GJiNLDW3dcBmNkMYAKwMm4ZB9oFrzOATSHmERFJqrxdRfx41nL+mbOV7D7tuevi\n4fTt1Jr+WW2Y8tQi7nltDTeNHxp1zBqFWQh6ABvipnOBEyst83PgVTP7PtAaOLuqFZnZVcBVAL17\n9056UBGR2nB3Xly6mVtfWM7ekjJuOe8orji1H2lNDIALhnfnnQ+38Yd/fciJ/TtyxuAq+3qrN6Ju\nLJ4EPOHuPYEvAH8xs4MyuftD7p7t7tlZWfX7Fyoijdv23cVc8+RCpjy1iD4dWzNnyul8+/T+B4pA\nhVvPP5ohXdpyw98Wk7ezKKK0iQmzEGwEesVN9wzmxfsW8DSAu78DpAOdQswkInLYXlq2mXH3vMnc\nnDxuGj+EZ68+mYGd21S5bMvmadx/6Uj2lpRx3YzFlJV7HadNXJiFYD4wyMz6mVlzYo3Bsyst8wkw\nFsDMjiJWCPJDzCQiUms79pQw5alFfO/JhXTLTOfF75/GNWcOpGnaoXehg7q0ZdqEo3ln3Xbun7e2\njtLWXmhtBO5eamaTgVeANOAxd19hZtOABe4+G7gReNjMrifWcPxNd6+/ZVNEUs5rK7dy86xlFO4r\n4cZzBnP1mQNoVkMBiHfR8T1558Pt3Dt3DaP7deDkAR1DTHt4rKHtd7Ozs10D04hI2Ar37ue2v69g\n1sKNHNWtHb+9eDjDurer+YNV2FNcyhd//xa7i0uZc93pdGrTIslpa2Zm77t7dlXvRd1YLCJS77y+\nOo9xv3uDFxZvYsqYgbxw7amHXQQAWrdoyv2XjqJg335ueHoJ5fWsvUCFQEQksKtoPz+auZQrHp9P\nRstmPH/NqdwwbgjNmx75rnJY93bcev4w3lyTz4NvrktC2uRpcGMWi4iE4a0PtnHTzCVs2VnE984c\nwA/OHkSLpmlJ3cZlJ/bmnQ+3c9erqxndrz3H9+mQ1PUfLp0RiEhK21Ncyi3PL+PyR98jvXkaM793\nCj8aPzTpRQDAzLj9wmPpnpnO96cvomBvSdK3cThUCEQkZb27bjvj732TJ9/7hO+c3o85U05nVO/2\noW6zXXoz7p80ivzdxfy/Z5ZSH27YUSEQkZR0/7wPmPjQu6SZ8cx3T+Yn5w0jvVnyzwKqMrxXJlPP\nPYp/5mzl8bfX18k2D0VtBCKScu6b+wF3v7aGL4/swa++fAytmtf9rvDKU/vyzofbuf2lHLL7tue4\nnpl1nqGCzghEJKXcPy9WBC4c1ZPfXjw8kiIAsfaCuy4+jqw2LZg8fRE7i/ZHkgNUCEQkhfzhX2u5\n69U1fGVkD35z0XE0qdRRXF3LbNWc+yaNZGPBPm6etSyy9gIVAhFJCX9840N+8/JqvjSiO3dePPyg\n3kKjkt23AzeOG8w/lm5m+n8+iSSDCoGINHoPv7mOX7+0ii8O785d9agIVLj6cwM4fVAnbntxJTmb\nd9b59lUIRKRRe+Tf6/jVnBzOO64b93x1eI09hkahSRPjnktGkNGyGddOX8ie4tK63X6dbk1EpA49\n/vZH/PIfOXzh2K7ce8mIelkEKnRq04J7J47go217+OkLy+t02/X3tyIicgT+9H/rue3FlYw/uiv3\nThxZr4tAhVMGdGLKmEHMWriRme/n1tl26/9vRkSklv7yznp+NnsF44Z14feXjqzV+AFRmzJ2ECf1\n78BPn1/O2rxddbLNUH87ZjbezFab2Vozm1rF+/eY2eLga42ZFYSZR0Qavyff+5ifvrCCs4/qwv2X\njmpQRQAgrYlx78SRtGqexrVPLqJof1no2wztN2RmacADwLnAMGCSmQ2LX8bdr3f3Ee4+Avg9MCus\nPCLS+D31n0/4yXPLGTu0Mw9cNjIp3UdHoUu7dO6+ZASrt+7ithdXhr69MH9Lo4G17r7O3UuAGcCE\nQyw/CXgqxDwi0oj9bf4n3DxrGWcNyeIPl48KpffQunTG4Cy+d+YAnvrPJ8xesinUbYVZCHoAG+Km\nc4N5BzGzPkA/YF41719lZgvMbEF+vsa2F5HPembBBqbOWsYZg7P438uPb/BFoMIN5wzm+D7t+fGs\nZazftie07dSX86aJwEx3r/JimLs/5O7Z7p6dlZVVx9FEpD579v1cbnp2KacN7MSDXzu+znoQrQvN\n0ppw36SRpDUxrp2+kOLScNoLwiwEG4FecdM9g3lVmYguC4lILT23KJf/N3MJpw7oxMNfz25URaBC\nj8yW3HXxcFZs2smjb30UyjbC7HZvPjDIzPoRKwATgUsrL2RmQ4H2wDshZhGRRuaFxRu58eklnNy/\nY6MtAhXOGdaFP14+ijOHdA5l/aGdEbh7KTAZeAXIAZ529xVmNs3MLohbdCIww+vDMD0i0iDMXrKJ\n6/+2mNH9OvDIN7Jp2bzxFoEK44/pFlqxC7UjbnefA8ypNO/WStM/DzODiDQuf18aKwLZfTvw2DdP\niGw8gcakvjQWi4jUaM6yzVw3YzGjemfyuIpA0qgQiEiD8PLyLUx5ahEjemXy+BWjad1CRSBZVAhE\npN57dcUWJk9fyHE9M3jiihNooyKQVPptikjkikvL2FpYzObCfWwuLGJT4T62FBaxqaCIzYX7WL1l\nF8f0yOCJK0fTNr1Z1HEbHRUCEQnV/rJytu4siu3gC2I7+AOvd8Z29tt2Fx/0uXbpTeme2ZJuGemc\n1L8j1509iHYqAqFQIRCRpJm1MJflG3eyuXAfmwqL2Fywj/zdxVS+ObxNi6Z0y0inW2ZLhnVrR7eM\nlsF0+oHXagOoO/pNi0hSfJi/mxueXkLLZml0y0yne0ZLBg/OoltwVN8tI/3AEb4u79QvKgQikhTz\ncvIAeO2Gz9GzfauI00ht6K4hEUmKuau2MrRrWxWBBkiFQESOWOG+/cxfv4MxQ8PpC0fCpUIgIkfs\nzTX5lJU7Y49SIWiIVAhE5IjNW5VH+1bNGNGrfdRR5DCoEIjIESkrd15fncdZQzqT1sSijiOHQYVA\nRI7Iok92ULB3P2N0WajBUiEQkSMyd1UeTZsYpw/SMLINVaiFwMzGm9lqM1trZlOrWearZrbSzFaY\n2fQw84hI8s3LyeOEvh3IaKmHxBqq0AqBmaUBDwDnAsOASWY2rNIyg4CbgVPd/WjgB2HlEZHk2/Dp\nXlZv3aW7hRq4MM8IRgNr3X2du5cAM4AJlZb5DvCAu+8AcPe8EPOISJK9vjr2L6vnBxq2MAtBD2BD\n3HRuMC/eYGCwmb1tZu+a2fiqVmRmV5nZAjNbkJ+fH1JcEamtuTl59OvUmv5ZbaKOIkcg6sbipsAg\n4ExgEvCwmWVWXsjdH3L3bHfPzspSg5RIfbCnuJR3Ptyus4FGIMxCsBHoFTfdM5gXLxeY7e773f0j\nYA2xwiAi9dzba7dRUlbOWBWCBi/MQjAfGGRm/cysOTARmF1pmeeJnQ1gZp2IXSpaF2ImEUmSeavy\naNuiKdl9O0QdRY5QaIXA3UuBycArQA7wtLuvMLNpZnZBsNgrwHYzWwm8DvzQ3beHlUlEkqO83Jm3\nKo/PDc6iedOorzDLkQp1PAJ3nwPMqTTv1rjXDtwQfIlIA7Fi007ydhWrfaCRUCkXkVqbu2orZnDm\nEN280RioEIhIrc1blcfIXpl0bNMi6iiSBCoEIlIreTuLWJpbyNijukQdRZJEhUBEakVPEzc+1TYW\nm9kh7wlz90+TH0dE6ru5OXl0z0hnaNe2UUeRJDnUXUPvAw4Y0BvYEbzOBD4B+oWeTkTqlaL9Zby1\ndhtfGdUDMw1C01hUe2nI3fu5e3/gn8AX3b2Tu3cEzgderauAIlJ/vPfRp+wtKWPsULUPNCaJtBGc\nFDwPAIC7vwScEl4kEamv5uVsJb1ZE04e0DHqKJJEiTxQtsnMbgH+GkxfBmwKL5KI1EfuztxVeZw2\nsBPpzdKijiNJlMgZwSQgC3gOmBW8nhRmKBGpfz7I203ujn2M0WWhRueQZwTBKGM/dvfr6iiPiNRT\nc3N022hjdcgzAncvA06roywiUo/NW7WVo7u3o2tGetRRJMkSaSNYZGazgWeAPRUz3X1WaKlEpF7Z\nsaeE9z/eweSzBkYdRUKQSCFIB7YDY+LmObH2AhFJAW+syafcYYy6lWiUaiwE7n5FXQQRkfpr7qo8\nOrVpwXE9MqKOIiGosRCYWTrwLeBoYmcHALj7lSHmEpF6Yn9ZOW+szmP8MV1p0kRPEzdGidw++heg\nK/B54A1iYw/vSmTlZjbezFab2Vozm1rF+980s3wzWxx8fbs24UUkfO9/vIOdRaW6bbQRS6SNYKC7\nX2xmE9z9T2Y2Hfh3TR8Kbj19ADiH2CD1881struvrLTo39x9cq2Ti0idmLcqj+ZpTThtUKeoo0hI\nEjkj2B98LzCzY4AMIJEbiUcDa919nbuXADOACYcXU0SiMjdnKyf270CbFqGObCsRSqQQPGRm7YGf\nArOBlcAdCXyuB7Ahbjo3mFfZhWa21MxmmlmvqlZkZleZ2QIzW5Cfn5/ApkUkGdZv28OH+XsYq4fI\nGrUaC4G7P+LuO9z9DXfv7+6d3f3BJG3/RaCvux8HvAb8qZoMD7l7trtnZ2VpjFSRujJvVcXTxGof\naMxqLARm9qGZPWlmV5vZ0bVY90Yg/gi/ZzDvAHff7u7FweQjwPG1WL+IhGzeqjwGdW5D746too4i\nIUrk0tAw4EGgI3BnUBieS+Bz84FBZtbPzJoDE4ldWjrAzLrFTV4A5CQWW0TCtqtoP+99tJ0xR+my\nUGOXSOtPGbEG4zKgHMgLvg7J3UvNbDLwCpAGPObuK8xsGrDA3WcDU8zsAqAU+BT45mH9FCKSdG99\nsI39Za5BaFJAIoVgJ7AMuBt42N23J7ryYECbOZXm3Rr3+mbg5kTXJyJ1Z+6qPDJaNmNU78yoo0jI\nEh2P4E3gGmCGmd1mZmPDjSUiUSord15flceZQ7JompbIbkIaskT6GnoBeMHMhgLnAj8AbgJahpxN\nRCKyJLeA7XtKNPZAikjkrqFnzWwtcC/QCvg60D7sYCISnXk5eaQ1Mc4YrNu1U0EibQS3A4uCQWpE\nJAXMXZXH8X3ak9mqedRRpA4kcvFvJXCzmT0EYGaDzOz8cGOJSFQ2FewjZ/NOPU2cQhIpBI8DJcAp\nwfRG4JehJRKRSFU8TTxWzw+kjEQKwQB3/w1B53PuvhdQp+QijdS8VXn07tCKAVltoo4idSSRQlBi\nZi2JDU+JmQ0Aig/9ERFpiPaVlPH22m2MGdoZMx3vpYpEGot/BrwM9DKzJ4FT0RPAIo3S/324jeLS\ncl0WSjGHLAQWOyRYBXwFOInYJaHr3H1bHWQTkTo2d1UerZunMbpfh6ijSB06ZCFwdzezOe5+LPCP\nOsokIhFwd+bl5HH6oCxaNE2LOo7UoUTaCBaa2QmhJxGRSK3cvJMtO4vU22gKSqSN4ETgMjP7GNhD\n7PKQB4PJiEgjMS8ndtvoWUNUCFJNIoXg86GnEJHIzV2Vx/BemWS1bRF1FKljiXQ693FdBBGR6OTv\nKmZJbgHXnz046igSgVD7lzWz8Wa22szWmtnUQyx3oZm5mWWHmUdEqvav1Xm4o95GU1RohcDM0oAH\niHVdPQyYZGbDqliuLXAd8F5YWUTk0OatyqNLuxYc3b1d1FEkAgkVAjPrY2ZnB69bBjvvmowG1rr7\nOncvAWYAE6pY7hfAHUBRgplFJIlKSst5c00+Y4Z20dPEKSqR8Qi+A8wkNoA9QE/g+QTW3QPYEDed\nG8yLX/cooJe7H/IZBTO7yswWmNmC/Pz8BDYtIon6z0efsqekTL2NprBEzgiuJdatxE4Ad/8AOOK/\nGDNrQmwc5BtrWtbdH3L3bHfPzsrSQBkiyTR31VZaNG3CqQM7RR1FIpJIISgOLu0AYGZNCTqgq8FG\noFfcdM9gXoW2wDHAv8xsPbEuLGarwVik7rg7c3PyOGVAR1o219PEqSqRQvCGmf0YaGlm5wDPAC8m\n8Ln5wCAz62dmzYGJwOyKN9290N07uXtfd+8LvAtc4O4Lav1TiMhh+TB/D598upcxR3WJOopEKJFC\nMBXIB5YB3wXmALfU9CF3LwUmA68AOcDT7r7CzKaZ2QWHH1lEkmXeqq2AbhtNdYk8UFYOPBx81Yq7\nzyFWOOLn3VrNsmfWdv0icmTm5uQxtGtbemS2jDqKRKjGQmBmyzi4TaAQWAD80t23hxFMRMJVuHc/\nCz7ewdVn9I86ikQskb6GXgLKgOnB9ESgFbAFeAL4YijJRCRUb3yQT1m5M2ao2gdSXSKF4Gx3HxU3\nvczMFrr7KDO7PKxgIhKueTlb6dC6OSN6ZUYdRSKWSGNxmpmNrpgIxiaouM+sNJRUIhKq0rJy/rUm\nnzOHZJHWRE8Tp7pEzgi+DTxmZm2IjUWwE/i2mbUGbg8znIiEY9GGAgr27mesLgsJid01NB841swy\ngunCuLefDiuYiIRnbk4eTZsYpw/W08SS2BkBZnYecDSQXtEplbtPCzGXiISgvNz58zvreeL/PuLk\nAR1pl94s6khSDyRy++gfid0ldBbwCHAR8J+Qc4lIkm34dC8/nLmEd9d9yplDsrjjQo02KzGJnBGc\n4u7HmdlSd7/NzH5L7JZSEWkA3J0n3/uE/5mTQxMz7rjwWL6a3UtdTssBiRSCinEC9ppZd2A70C28\nSCKSLBsL9vGjmUt5a+02ThvYiTsuOk5PEctBEikEL5pZJnAnsJDYU8a17m5CROqOu/PMglx+8feV\nlLnzyy8dw2Un9tZZgFTpkIUgGDNgrrsXAM+a2d+B9Ep3DolIPbKlsIibZy3l9dX5nNivA3deNJze\nHVtFHUvqsUMWAncvN7MHgJHBdDFQXBfBRKR23J3nFm3k57NXUFJWzs+/OIyvn9yXJnpgTGqQyKWh\nuWZ2ITDL3RMZkEZE6ljeriJ+8txyXlu5lew+7bnr4uH07dQ66ljSQCRSCL4L3ACUmdk+Yk8Xu7u3\nCzWZiNTI3Xlx6WZufWE5e0vKuOW8o7ji1H7qNkJqJZEni9vWRRARqZ3tu4v56QvLmbNsC8N7ZfLb\ni4czsHObqGNJA1Rjp3MWc7mZ/TSY7hXfCV0Nnx1vZqvNbK2ZTa3i/avNbJmZLTazt8xsWO1/BJHU\n8/LyzYy7503+uTKPm8YP4dmrT1YRkMOWyKWhPwDlwBjgF8Bu4AHghEN9yMzSguXOAXKB+WY2291X\nxi023d3/GCx/AXA3ML62P4RIqtixp4SfzV7B7CWbOKZHO6ZfPIIhXXXSLkcmkUJwYjD2wCIAd98R\nDEZfk9HAWndfB2BmM4AJwIFC4O4745ZvzcEjoYlI4J8rt3Lzc8vYsaeEG84ZzPfOHECztER6khc5\ntEQKwf7g6N4BzCyL2BlCTXoAG+Kmc4ETKy9kZtcSa4xuTuys4yBmdhVwFUDv3r0T2LRI41G4bz/T\nXlzJswtzGdq1LU9ccQJHd8+IOpY0IokcTtwHPAd0NrNfAW8B/5OsAO7+gLsPAH4E3FLNMg+5e7a7\nZ2dlZSVr0yL13sJPdjDunjd4fvFGpowZyOzJp6kISNIlctfQk2b2PjCW2K2jX3L3nATWvRHoFTfd\nM5hXnRnA/yawXpGUsG13Md/9y/ukN2vCc9ecwnE9NaSkhCORbqjvA2a4+wO1XPd8YJCZ9SNWACYC\nl1Za9yB3/yCYPA/4ABGhvNy5/m+LKdy3nz9feSpHddNjOxKeRNoI3gduMbMhxC4RzXD3BTV9yN1L\nzWwy8AqxMY4fc/cVZjYNWODus4HJZnY2sB/YAXzjcH8Qkcbkj29+yL8/2MavvnyMioCEzhLtNcLM\nOgAXEjuy7+3ug8IMVp3s7GxfsKDGOiTSYC1Y/ymXPPQu5x7Tld9PGqkeQyUpzOx9d8+u6r3a3Hs2\nEBgK9AFWJSOYiHzWjj0lTHlqET3bt+T2rxyrIiB1IpEni39jZh8A04DlQLa7fzH0ZCIpxt354cwl\n5O8u5v5Jo2ir8YSljiTSRvAhcLK7bws7jEgqe+zt9fwzJ4+ffXEYx/bULaJSdxK5ffRBM2sf9C+U\nHjf/zVCTiaSQJRsK+PVLOZwzrAvfPKVv1HEkxSRy++i3geuIPQewGDgJeIdqngIWkdrZWbSfyU8t\npHPbdO686Di1C0idS6Sx+DpiHcx97O5nERutrCDUVCIpwt2Z+uxSNhUUcd+kkWS2SqQbL5HkSqQQ\nFLl7EYCZtXD3VcCQcGOJpIYn3/uEOcu28MPPD+H4Pu2jjiMpKpHG4lwzywSeB14zsx3Ax+HGEmn8\nVmwqZNrfV3LG4CyuOr1/1HEkhSXSWPzl4OXPzex1IAN4OdRUIo3c7uJSvj99Ee1bNePurw7XAPMS\nqUTOCA5w9zfCCiKSKtydW55bxvrte5j+nZPo2KZF1JEkxWlUC5E69sz7uTy/eBPXjR3MSf07Rh1H\nRIVApC59sHUXt76wnJP7d2TymIFRxxEBVAhE6sy+kjKunb6Q1s2bcu/EEaSpXUDqiVq1EYjI4bvt\nxRWs2bqbP185ms7t0mv+gEgd0RmBSB14YfFGZszfwDVnDuBzgzXcqtQvoRYCMxtvZqvNbK2ZTa3i\n/RvMbKWZLTWzuWbWJ8w8IlH4aNsefjxrGdl92nPDOYOjjiNykNAKgZmlAQ8A5wLDgElmNqzSYouI\ndWt9HDAT+E1YeUSiULS/jGufXEizpk24b9JImqbpJFzqnzD/KkcDa919nbuXEBucfkL8Au7+urvv\nDSbfJdaxnUijcfucHFZu3sldFw2ne2bLqOOIVCnMQtAD2BA3nRvMq863gJeqesPMrjKzBWa2ID8/\nP4kRRcLz8vLN/Omdj/n2af04e1iXqOOIVKtenKea2eVANnBnVe+7+0Punu3u2VlZamiT+m/Dp3v5\n4cylDO+ZwU3jh0YdR+SQwrx9dCPQK266ZzDvM8zsbOAnwBnuXhxiHpE6UVJazuSnFgFw/6WjaN60\nXhxviVQrzL/Q+cAgM+tnZs2BicDs+AXMbCTwIHCBu+eFmEWkztz5yiqWbCjgjguPo1eHVlHHEalR\naIXA3UuBycArQA7wtLuvMLNpZnZBsNidQBvgGTNbbGazq1mdSIMwN2crD//7I752Uh++cGy3qOOI\nJCTUJ4vdfQ4wp9K8W+Nenx3m9kXq0ubCfdz4zBKO6taOn5x3VNRxRBKmi5ciSVBaVs6UpxZRUlrO\nA5eOJL1ZWtSRRBKmvoZEkuDhf3/E/PU7uOeS4fTPahN1HJFa0RmByBFav20Pv/vnGsYN68KXR+qZ\nSGl4VAhEjoC785Pnl9E8rQnTJhwTdRyRw6JCIHIEnl24kbfXbuemc4fSNUNdS0vDpEIgcpi27S7m\nl/9YSXaf9lw2unfUcUQOmwqByGH6xd9Xsqe4lNu/cixNNNqYNGAqBCKH4fXVebyweBPXnDmQQV3a\nRh1H5IioEIjU0p7iUm55bjkDslpzzVkDoo4jcsT0HIFILd392ho2FuzjmatPpkVTPTgmDZ/OCERq\nYcmGAh5/+yMuPbE3J/TtEHUckaRQIRBJ0P6ycqbOWkanNi2Yeq7GGJDGQ5eGRBL06FsfkbN5J3+8\nfBTt0ptFHUckaXRGIJKAj7fv4Z7XYt1IjD9G3UtL46JCIFIDd+fHz6kbCWm8Qi0EZjbezFab2Voz\nm1rF+58zs4VmVmpmF4WZReRwqRsJaexCKwRmlgY8AJwLDAMmmdmwSot9AnwTmB5WDpEjoW4kJBWE\n2Vg8Gljr7usAzGwGMAFYWbGAu68P3isPMYfIYVM3EpIKwrw01APYEDedG8wTaRDUjYSkigbRWGxm\nV5nZAjNbkJ+fH3UcSQEV3UgM7NxG3UhIoxdmIdgI9Iqb7hnMqzV3f8jds909OysrKynhRA6lohuJ\n279yrLqRkEYvzEIwHxhkZv3MrDkwEZgd4vZEkmJprrqRkNQSWiFw91JgMvAKkAM87e4rzGyamV0A\nYGYnmFkucDHwoJmtCCuPSCL2l5Uz9Vl1IyGpJdQuJtx9DjCn0rxb417PJ3bJSKReePStj1ipbiQk\nxTSIxmKRuqBuJCRVqRCIoG4kJLWpEIigbiQktakQSMpTNxKS6lQIJOWpGwlJdSoEktLUjYSICoGk\nsL0l6kZCBDRUpaSwu1+NdSPxzNUnqxsJSWk6I5CUtDS3gMfUjYQIoEIgKUjdSIh8li4NSb1UVu7k\n7ypmd3Fp0tc9e8kmdSMhEkeFQOpcebmzbXcxmwuL2Fy4j00Fse+x6SI2F+xj665iyso9tAzqRkLk\nv1QIJKncne17SthcUMSmwn1sLtjH5p1FbI7b2W/dWcT+ss/u5Fs0bUL3zJZ0bZfOSQM60j2jJd0y\n02nToilmyb23v1kT46yhnZO6TpGGTIWgkdpTXHrw0XZBEfm7wznS3re/jC2FRWwpLKKk7LNDUDdP\na0LXjHS6ZaST3ac93TJb0j0jnW4ZLemakU73zJa0b9Us6Tt8EUmMCkEDtK+k7MDOfVNB3CWVwn0H\njsR3FR18bT2rbQs6t21B04V3axIAAAo7SURBVLTk3yPQIq0JI3pl0u2Y2A4/trOP7eg7tm6uJ3ZF\n6jEVgnqmKDiyPrBjD3b2WwqL2BTMK9i7/6DPdWzdnG6Z6fTu2IoT+3egW0ZLumfGjrq7ZaTTpV06\nzZvqJjEROViohcDMxgP3AmnAI+7+60rvtwD+DBwPbAcucff1YWaKUklpOVt3Bjv2nUUHLttsKihi\ny87Y0fz2PSUHfS6zVbPYjj0jnVG9M+meGdu5V+zsu7RLJ72ZHogSkcMTWiEwszTgAeAcIBeYb2az\n3X1l3GLfAna4+0AzmwjcAVwSVqYwlZaVs3VXMZsL9rGpsIgtVdwNs213MV7p8ny79Kaxo/bMdI7t\nkUn3jPQD1827Ba9bNdeJm4iEJ8w9zGhgrbuvAzCzGcAEIL4QTAB+HryeCdxvZuZeeXd55J6ev4GH\n/70u2avFgV1F+8nfVUzlNtg2LZoeuF4+rFu72A4+2OlXHNG3bqGdvIhEK8y9UA9gQ9x0LnBidcu4\ne6mZFQIdgW3xC5nZVcBVAL17H15/8ZmtmjGoS5vD+mxNWjdveuBOmPij+bZ6WElEGoAGcTjq7g8B\nDwFkZ2cf1tnCuKO7Mu7orknNJSLSGIR5G8lGoFfcdM9gXpXLmFlTIINYo7GIiNSRMAvBfGCQmfUz\ns+bARGB2pWVmA98IXl8EzAujfUBERKoX2qWh4Jr/ZOAVYrePPubuK8xsGrDA3WcDjwJ/MbO1wKfE\nioWIiNShUNsI3H0OMKfSvFvjXhcBF4eZQUREDk2PmoqIpDgVAhGRFKdCICKS4lQIRERSnDW0uzXN\nLB/4+DA/3olKTy3Xcw0pb0PKCg0rb0PKCg0rb0PKCkeWt4+7Z1X1RoMrBEfCzBa4e3bUORLVkPI2\npKzQsPI2pKzQsPI2pKwQXl5dGhIRSXEqBCIiKS7VCsFDUQeopYaUtyFlhYaVtyFlhYaVtyFlhZDy\nplQbgYiIHCzVzghERKQSFQIRkRSXMoXAzMab2WozW2tmU6POUx0z62Vmr5vZSjNbYWbXRZ0pEWaW\nZmaLzOzvUWc5FDPLNLOZZrbKzHLM7OSoMx2KmV0f/B0sN7OnzCw96kzxzOwxM8szs+Vx8zqY2Wtm\n9kHwvX2UGStUk/XO4G9hqZk9Z2aZUWasUFXWuPduNDM3s07J2l5KFAIzSwMeAM4FhgGTzGxYtKmq\nVQrc6O7DgJOAa+tx1njXATlRh0jAvcDL7j4UGE49zmxmPYApQLa7H0OsO/f61lX7E8D4SvOmAnPd\nfRAwN5iuD57g4KyvAce4+3HAGuDmug5VjSc4OCtm1gsYB3ySzI2lRCEARgNr3X2du5cAM4AJEWeq\nkrtvdveFwetdxHZUPaJNdWhm1hM4D3gk6iyHYmYZwOeIjYOBu5e4e0G0qWrUFGgZjODXCtgUcZ7P\ncPc3iY0lEm8C8Kfg9Z+AL9VpqGpUldXdX3X30mDyXWIjKUaumt8rwD3ATUBS7/JJlULQA9gQN51L\nPd+5AphZX2Ak8F60SWr0O2J/nOVRB6lBPyAfeDy4jPWImbWOOlR13H0jcBexo7/NQKG7vxptqoR0\ncffNwestQJcow9TClcBLUYeojplNADa6+5JkrztVCkGDY2ZtgGeBH7j7zqjzVMfMzgfy3P39qLMk\noCkwCvhfdx8J7KH+XLY4SHBtfQKxAtYdaG1ml0ebqnaCoWfr/T3qZvYTYpdln4w6S1XMrBXwY+DW\nmpY9HKlSCDYCveKmewbz6iUza0asCDzp7rOizlODU4ELzGw9sUtuY8zsr9FGqlYukOvuFWdYM4kV\nhvrqbOAjd8939/3ALOCUiDMlYquZdQMIvudFnOeQzOybwPnAZfV4zPQBxA4IlgT/az2BhWbWNRkr\nT5VCMB8YZGb9zKw5sQa32RFnqpKZGbFr2DnufnfUeWri7je7e09370vs9zrP3evlUau7bwE2mNmQ\nYNZYYGWEkWryCXCSmbUK/i7GUo8bt+PMBr4RvP4G8EKEWQ7JzMYTu6x5gbvvjTpPddx9mbt3dve+\nwf9aLjAq+Js+YilRCILGoMnAK8T+kZ529xXRpqrWqcDXiB1ZLw6+vhB1qEbk+8CTZrYUGAH8T8R5\nqhWcucwEFgLLiP2/1qsuEczsKeAdYIiZ5ZrZt4BfA+eY2QfEzmp+HWXGCtVkvR9oC7wW/K/9MdKQ\ngWqyhre9+nsmJCIidSElzghERKR6KgQiIilOhUBEJMWpEIiIpDgVAhGRFKdCII2WmU0zs7OTsJ7d\nScrzOzP7XPB6ctAT7md6kbSY+4L3lprZqLj3vhH06PmBmX0jbv76GrY7w8wGJeNnkMZJt4+K1MDM\ndrt7myNcR0fgH+5+UjA9EtgB/ItY76LbgvlfIPaswxeAE4F73f1EM+sALACyiXXZ8D5wvLvvMLP1\nwUNG1W37DOByd//OkfwM0njpjEAaDDO73Mz+Ezz482DQvThmttvM7gn67Z9rZlnB/CfM7KLg9a8t\nNsbDUjO7K5jX18zmBfPmmlnvYH4/M3vHzJaZ2S8rZfihmc0PPnNbMK+1mf3DzJZYbNyAS6qIfyHw\ncsWEuy9y9/VVLDcB+LPHvAtkBt00fB54zd0/dfcdxLpPruimOL+GHP8Gzg56MBU5iAqBNAhmdhRw\nCXCqu48AyoDLgrdbAwvc/WjgDeBnlT7bEfgycHTQ73zFzv33wJ+CeU8C9wXz7yXWMd2xxHr9rFjP\nOGAQsW7NRwDHB5d6xgOb3H14MG7AgR1+nFOJHcXXpLqecqvtQdfdTwjmVZnD3cuBtcTGXxA5iAqB\nNBRjgeOB+Wa2OJjuH7xXDvwteP1X4LRKny0EioBHzewrQEWfMicD04PXf4n73KnAU3HzK4wLvhYR\n6/ZhKLHCsIxYlwp3mNnp7l5YRf5uBEfuITpUjjxiPZiKHESFQBoKI3b0PiL4GuLuP69m2c80fAV9\nTY0m1m/P+VR9xH7IdcRluD0uw0B3f9Td1xDrxXQZ8Eszq6qr4H1AIsNMVtdTbo096NaQIz3IIHIQ\nFQJpKOYCF5lZZzgwLm6f4L0mwEXB60uBt+I/GIztkOHuc4Dr+e8lkv/jv0M/XkbsWjrA25XmV3gF\nuDJYH2bWw8w6m1l3YK+7/xW4k6q7ts4BBibwc84Gvh7cPXQSscFoNgfbHmdm7S02TsG4YF78z3mo\nHIOBg8a/FYHYQB0i9Z67rzSzW4BXzawJsB+4FviY2AAzo4P384i1JcRrC7xgsYHfDbghmP99YqOV\n/ZDYZZsrgvnXAdPN7EfEdaHs7q8GbRXvxHqFZjdwObEd/J1mVh7k+l4VP8I/gO8SDOdpZlOIdX/c\nFVhqZnPc/dvAHGJ3DK0ldgnrimDbn5rZL4h1qQ4wzd0rD2V4bFU5zKwLsC9ZXRZL46PbR6XBS8bt\nnXXBzN4Czq/rcZLN7Hpgp7s/WpfblYZDl4ZE6s6NQO8ItlvAfweTFzmIzghERFKczghERFKcCoGI\nSIpTIRARSXEqBCIiKU6FQEQkxf1/wWUIKH9sIAUAAAAASUVORK5CYII=\n",
458 | "text/plain": [
459 | ""
460 | ]
461 | },
462 | "metadata": {
463 | "tags": []
464 | },
465 | "output_type": "display_data"
466 | }
467 | ],
468 | "source": [
469 | "# możemy narysować postęp trenowania i zobaczyć jak agent się polepsza\n",
470 | "import matplotlib.pyplot as plt\n",
471 | "\n",
472 | "def get_average(values):\n",
473 | " return sum(values) / len(values)\n",
474 | "\n",
475 | "avg_rewards = []\n",
476 | "for i in range(0, len(rewards), 100):\n",
477 | " avg_rewards.append(get_average(rewards[i:i+100]))\n",
478 | "\n",
479 | "plt.plot(avg_rewards)\n",
480 | "plt.ylabel('average reward')\n",
481 | "plt.xlabel('episodes (100\\'s)')\n",
482 | "plt.show()"
483 | ]
484 | },
485 | {
486 | "cell_type": "markdown",
487 | "metadata": {
488 | "id": "sL5_fLdadMw6",
489 | "colab_type": "text"
490 | },
491 | "source": [
492 | "### Źródło:\n",
493 | "[https://www.youtube.com/watch?v=tPYj3fFJGjk](https://www.youtube.com/watch?v=tPYj3fFJGjk)"
494 | ]
495 | }
496 | ],
497 | "metadata": {
498 | "colab": {
499 | "name": "q_learning.ipynb",
500 | "provenance": [],
501 | "collapsed_sections": [],
502 | "authorship_tag": "ABX9TyMShg/BChmNRnh4JPpurMig",
503 | "include_colab_link": true
504 | },
505 | "kernelspec": {
506 | "name": "python3",
507 | "display_name": "Python 3"
508 | }
509 | },
510 | "nbformat": 4,
511 | "nbformat_minor": 0
512 | }
513 |
--------------------------------------------------------------------------------
/machine_learning/reinforcement_learning/gridworld/gridworld.py:
--------------------------------------------------------------------------------
1 | ACTION_SPACE = {'u', 'd', 'l', 'r'}
2 |
3 |
4 | class Grid:
5 | def __init__(self, rows, cols, start):
6 | self.rows = rows
7 | self.cols = cols
8 | self.i = start[0]
9 | self.j = start[1]
10 |
11 | def set(self, rewards, actions):
12 | self.rewards = rewards
13 | self.actions = actions
14 |
15 | def set_state(self, s):
16 | self.i = s[0]
17 | self.j = s[1]
18 |
19 | def current_state (self):
20 | return self.i, self.j
21 |
22 | def reset(self):
23 | self.i = 2
24 | self.j = 0
25 | return (self.i, self.j)
26 |
27 | def is_terminal(self, state):
28 | return state not in self.actions
29 |
30 | def get_next_state(self, state, action):
31 | i = state[0]
32 | j = state[1]
33 | if action in self.actions[(i, j)]:
34 | if action == 'u':
35 | i -= 1
36 | elif action == 'd':
37 | i += 1
38 | elif action == 'r':
39 | j += 1
40 | elif action == 'l':
41 | j -= 1
42 | return i, j
43 |
44 | def move(self, action):
45 | # Check if move is legal
46 | if action in self.actions[(self.i, self.j)]:
47 | if action == 'u':
48 | self.i -= 1
49 | elif action == 'd':
50 | self.i += 1
51 | elif action == 'r':
52 | self.j += 1
53 | elif action == 'l':
54 | self.j -= 1
55 | return self.rewards.get((self.i, self.j), 0)
56 |
57 | def undo_move(self, action):
58 | if action == 'u':
59 | self.i += 1
60 | elif action == 'd':
61 | self.i -= 1
62 | elif action == 'r':
63 | self.j -= 1
64 | elif action == 'l':
65 | self.j += 1
66 |
67 | def game_over(self):
68 | return (self.i, self.j) not in self.actions
69 |
70 | def all_states(self):
71 | return set(self.actions.keys()) | set(self.rewards.keys())
72 |
73 |
74 | def standard_grid():
75 | g = Grid(3, 4, (2, 0))
76 | rewards = {(0, 3): 1, (1, 3): -1}
77 | actions = {
78 | (0, 0): ('d', 'r'),
79 | (0, 1): ('l', 'r'),
80 | (0, 2): ('l', 'd', 'r'),
81 | (1, 0): ('u', 'd'),
82 | (1, 2): ('u', 'd', 'r'),
83 | (2, 0): ('u', 'r'),
84 | (2, 1): ('l', 'r'),
85 | (2, 2): ('l', 'r', 'u'),
86 | (2, 3): ('l', 'u')
87 | }
88 | g.set(rewards, actions)
89 | return g
90 |
91 |
92 | def negative_grid(step_cost=-0.1):
93 | g = standard_grid()
94 | g.rewards.update({
95 | (0, 0): step_cost,
96 | (0, 1): step_cost,
97 | (0, 2): step_cost,
98 | (1, 0): step_cost,
99 | (1, 2): step_cost,
100 | (2, 0): step_cost,
101 | (2, 1): step_cost,
102 | (2, 2): step_cost,
103 | (2, 3): step_cost
104 | })
105 | return g
106 |
--------------------------------------------------------------------------------
/machine_learning/reinforcement_learning/gridworld/iterative_policy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from machine_learning.reinforcement_learning.gridworld.gridworld import (
3 | standard_grid, ACTION_SPACE
4 | )
5 | converge_thresh = 0.001
6 |
7 |
8 | def print_values(V, grid):
9 | for i in range(grid.rows):
10 | print('------------------------')
11 | for j in range(grid.cols):
12 | v = V.get((i, j), 0)
13 | if v >= 0:
14 | print(' %.2f|' % v, end='')
15 | else:
16 | print('%.2f|' % v, end='')
17 | print()
18 |
19 |
20 | def print_policy(P, grid):
21 | for i in range(grid.rows):
22 | print('------------------------')
23 | for j in range(grid.cols):
24 | a = P.get((i, j), ' ')
25 | print(' %s |' % a, end='')
26 | print()
27 |
28 |
29 | if __name__ == '__main__':
30 | # Define transition probabilities and grid
31 | transition_probs = {}
32 | rewards = {}
33 |
34 | grid = standard_grid()
35 | for i in range(grid.rows):
36 | for j in range(grid.cols):
37 | s = (i, j)
38 | if not grid.is_terminal(s):
39 | for a in ACTION_SPACE:
40 | s2 = grid.get_next_state(s, a)
41 | transition_probs[(s, a, s2)] = 1
42 | if s2 in grid.rewards:
43 | rewards[(s, a, s2)] = grid.rewards[s2]
44 |
45 | policy = {
46 | (2, 0): 'u',
47 | (1, 0): 'u',
48 | (0, 0): 'r',
49 | (0, 1): 'r',
50 | (0, 2): 'r',
51 | (1, 2): 'u',
52 | (2, 1): 'r',
53 | (2, 2): 'u',
54 | (2, 3): 'l'
55 | }
56 | print_policy(policy, grid)
57 |
58 | V = {} # Initialize V(s) = 0
59 | for s in grid.all_states():
60 | V[s] = 0
61 |
62 | gamma = 0.9 # Discount factor
63 |
64 | # Repeat until convergence
65 | iter = 1
66 | while True:
67 | biggest_change = 0
68 | for s in grid.all_states():
69 | if not grid.is_terminal(s):
70 | old_v = V[s]
71 | new_v = 0
72 | for a in ACTION_SPACE:
73 | for s2 in grid.all_states():
74 | action_prob = 1 if policy.get(s) == a else 0
75 | r = rewards.get((s, a, s2), 0)
76 | new_v += action_prob * transition_probs.get((s, a, s2), 0) * (r + gamma * V[s2])
77 |
78 | V[s] = new_v
79 | biggest_change = max(biggest_change, np.abs(old_v - V[s]))
80 |
81 | print('Iter:', iter, 'biggest change:', biggest_change)
82 | print_values(V, grid)
83 | iter += 1
84 |
85 | if biggest_change < converge_thresh:
86 | break
87 | print()
88 |
--------------------------------------------------------------------------------
/machine_learning/reinforcement_learning/gridworld/monte_carlo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def max_dict(d):
5 | # Returns the argmax (key) and max (value) from a dict
6 | max_val = max(d.values())
7 | max_keys = [key for key, val in d.items() if val == max_val]
8 | return np.random.choice(max_keys), max_val
9 |
--------------------------------------------------------------------------------
/machine_learning/reinforcement_learning/gridworld/q_learning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from machine_learning.reinforcement_learning.gridworld.gridworld import (
4 | negative_grid
5 | )
6 | from machine_learning.reinforcement_learning.gridworld.iterative_policy import (
7 | print_values, print_policy
8 | )
9 | from machine_learning.reinforcement_learning.gridworld.monte_carlo import (
10 | max_dict
11 | )
12 | gamma = 0.9
13 | alpha = 0.1
14 | action_space = ['u', 'd', 'l', 'r']
15 |
16 |
17 | def epsilon_greedy(Q, s, eps=0.1):
18 | if np.random.random() < eps:
19 | return np.random.choice(action_space)
20 | else:
21 | a_optimal = max_dict(Q[s])[0]
22 | return a_optimal
23 |
24 |
25 | if __name__ == '__main__':
26 | grid = negative_grid(step_cost=-0.1)
27 |
28 | print('Rewards')
29 | print_values(grid.rewards, grid)
30 |
31 | # Initialize Q(s, a) = 0
32 | Q = {}
33 | states = grid.all_states()
34 | for s in states:
35 | Q[s] = {}
36 | for a in action_space:
37 | Q[s][a] = 0
38 |
39 | update_counts = {}
40 |
41 | reward_per_episode = []
42 | for i in range(10000):
43 | if i % 2000 == 0:
44 | print('Iter:', i)
45 |
46 | # Begin a new episode
47 | s = grid.reset()
48 | a = epsilon_greedy(Q, s, eps=0.1)
49 | episode_reward = 0
50 | while not grid.game_over():
51 | # Perform action and get next state + reward
52 | a = epsilon_greedy(Q, s, eps=0.1)
53 | r = grid.move(a)
54 | s2 = grid.current_state()
55 | # Update reward
56 | episode_reward += r
57 | # Get next action
58 | maxQ = max_dict(Q[s2])[1]
59 | a2 = epsilon_greedy(Q, s2, eps=0.1)
60 | # Update Q(s, a)
61 | Q[s][a] += alpha * (r + gamma * maxQ - Q [s][a])
62 | # Check how often Q(s) is updated
63 | update_counts[s] = update_counts.get(s, 0) + 1
64 | # Next state becomes current state
65 | s = s2
66 |
67 | # Log the reward for this episode
68 | reward_per_episode.append(episode_reward)
69 |
70 | plt.plot(reward_per_episode)
71 | plt.title('Reward per episode')
72 | plt.show()
73 |
74 | # Determine the policy from Q*
75 | # Find V* from Q*
76 | policy = {}
77 | V = {}
78 | for s in grid.actions.keys():
79 | a, max_q = max_dict(Q[s])
80 | policy[s] = a
81 | V[s] = max_q
82 |
83 | # The proportion of time we spend updating each part of Q
84 | print('Update counts:')
85 | total = np.sum(list(update_counts.values()))
86 | for k, v in update_counts.items():
87 | update_counts[k] = float(v) / total
88 |
89 | print_values(update_counts, grid)
90 | print('Values:')
91 | print_values(V, grid)
92 | print('Policy:')
93 | print_policy(policy, grid)
94 |
--------------------------------------------------------------------------------
/machine_learning/reinforcement_learning/gridworld/sarsa.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from machine_learning.reinforcement_learning.gridworld.gridworld import (
4 | negative_grid
5 | )
6 | from machine_learning.reinforcement_learning.gridworld.iterative_policy import (
7 | print_values, print_policy
8 | )
9 | from machine_learning.reinforcement_learning.gridworld.monte_carlo import (
10 | max_dict
11 | )
12 | gamma = 0.9
13 | alpha = 0.1
14 | action_space = ['u', 'd', 'l', 'r']
15 |
16 |
17 | def epsilon_greedy(Q, s, eps=0.1):
18 | if np.random.random() < eps:
19 | return np.random.choice(action_space)
20 | else:
21 | a_optimal = max_dict(Q[s])[0]
22 | return a_optimal
23 |
24 |
25 | if __name__ == '__main__':
26 | grid = negative_grid(step_cost=-0.1)
27 |
28 | print('Rewards')
29 | print_values(grid.rewards, grid)
30 |
31 | # Initialize Q(s, a) = 0
32 | Q = {}
33 | states = grid.all_states()
34 | for s in states:
35 | Q[s] = {}
36 | for a in action_space:
37 | Q[s][a] = 0
38 |
39 | update_counts = {}
40 |
41 | reward_per_episode = []
42 | for i in range(10000):
43 | if i % 2000 == 0:
44 | print('Iter:', i)
45 |
46 | # Begin a new episode
47 | s = grid.reset()
48 | a = epsilon_greedy(Q, s, eps=0.1)
49 | episode_reward = 0
50 | while not grid.game_over():
51 | # Perform action and get next state + reward
52 | r = grid.move(a)
53 | s2 = grid.current_state()
54 | # Update reward
55 | episode_reward += r
56 | # Get next action
57 | a2 = epsilon_greedy(Q, s2, eps=0.1)
58 | # Update Q(s, a)
59 | Q[s][a] += alpha * (r + gamma * Q[s2][a2] - Q [s][a])
60 | # Check how often Q(s) is updated
61 | update_counts[s] = update_counts.get(s, 0) + 1
62 | # Next state becomes current state
63 | s = s2
64 | a = a2
65 |
66 | # Log the reward for this episode
67 | reward_per_episode.append(episode_reward)
68 |
69 | plt.plot(reward_per_episode)
70 | plt.title('Reward per episode')
71 | plt.show()
72 |
73 | # Determine the policy from Q*
74 | # Find V* from Q*
75 | policy = {}
76 | V = {}
77 | for s in grid.actions.keys():
78 | a, max_q = max_dict(Q[s])
79 | policy[s] = a
80 | V[s] = max_q
81 |
82 | # The proportion of time we spend updating each part of Q
83 | print('Update counts:')
84 | total = np.sum(list(update_counts.values()))
85 | for k, v in update_counts.items():
86 | update_counts[k] = float(v) / total
87 |
88 | print_values(update_counts, grid)
89 | print('Values:')
90 | print_values(V, grid)
91 | print('Policy:')
92 | print_policy(policy, grid)
93 |
--------------------------------------------------------------------------------
/machine_learning/supervised_learning/linear_algebra/cramer_rule.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | " "
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "MgP9hj2kWbfE"
17 | },
18 | "source": [
19 | "#Solving ML problems using Cramer's rule"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 19,
25 | "metadata": {
26 | "id": "fbqkBf7_SXwD"
27 | },
28 | "outputs": [],
29 | "source": [
30 | "import numpy as np\n",
31 | "import scipy\n",
32 | "from sklearn.datasets import load_iris"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 33,
38 | "metadata": {
39 | "colab": {
40 | "base_uri": "https://localhost:8080/"
41 | },
42 | "id": "eIQQgjrXSkrp",
43 | "outputId": "1abc3830-36b0-40df-b97e-3060354ea142"
44 | },
45 | "outputs": [
46 | {
47 | "name": "stdout",
48 | "output_type": "stream",
49 | "text": [
50 | "(4, 4)\n",
51 | "(4,)\n"
52 | ]
53 | }
54 | ],
55 | "source": [
56 | "iris = load_iris()\n",
57 | "X = iris.data[-4:]\n",
58 | "y = iris.target[-4:]\n",
59 | "\n",
60 | "print(X.shape)\n",
61 | "print(y.shape)"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 17,
67 | "metadata": {
68 | "colab": {
69 | "base_uri": "https://localhost:8080/"
70 | },
71 | "id": "xXAlqluKUe6D",
72 | "outputId": "1070973d-52ef-45b2-aa83-9257c4a1c421"
73 | },
74 | "outputs": [
75 | {
76 | "data": {
77 | "text/plain": [
78 | "array([-0.03868472, -0.1934236 , 0.61895551, -0.1934236 ])"
79 | ]
80 | },
81 | "execution_count": 17,
82 | "metadata": {},
83 | "output_type": "execute_result"
84 | }
85 | ],
86 | "source": [
87 | "w = np.linalg.solve(X, y)\n",
88 | "w"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 37,
94 | "metadata": {
95 | "colab": {
96 | "base_uri": "https://localhost:8080/"
97 | },
98 | "id": "UKb6KbjWjUCF",
99 | "outputId": "913fd481-5c51-4b9e-f414-642040421738"
100 | },
101 | "outputs": [
102 | {
103 | "name": "stderr",
104 | "output_type": "stream",
105 | "text": [
106 | "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:1: FutureWarning: `rcond` parameter will change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions.\n",
107 | "To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`.\n",
108 | " \"\"\"Entry point for launching an IPython kernel.\n"
109 | ]
110 | },
111 | {
112 | "data": {
113 | "text/plain": [
114 | "array([-0.03868472, -0.1934236 , 0.61895551, -0.1934236 ])"
115 | ]
116 | },
117 | "execution_count": 37,
118 | "metadata": {},
119 | "output_type": "execute_result"
120 | }
121 | ],
122 | "source": [
123 | "v = np.linalg.lstsq(X, y)\n",
124 | "v[0]"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 36,
130 | "metadata": {
131 | "colab": {
132 | "base_uri": "https://localhost:8080/"
133 | },
134 | "id": "HUMqhE72ew17",
135 | "outputId": "3649039a-a2ac-4320-e950-65b8e08e48e6"
136 | },
137 | "outputs": [
138 | {
139 | "data": {
140 | "text/plain": [
141 | "1.9999999999999996"
142 | ]
143 | },
144 | "execution_count": 36,
145 | "metadata": {},
146 | "output_type": "execute_result"
147 | }
148 | ],
149 | "source": [
150 | "y_pred = 0\n",
151 | "for i in range(4):\n",
152 | " y_pred += X[0][i] * w[i]\n",
153 | "\n",
154 | "y_pred"
155 | ]
156 | }
157 | ],
158 | "metadata": {
159 | "colab": {
160 | "name": "cramer_rule.ipynb",
161 | "provenance": [],
162 | "authorship_tag": "ABX9TyNNdU+ywNicUIyC5YaQYV2W",
163 | "include_colab_link": true
164 | },
165 | "kernelspec": {
166 | "name": "python3",
167 | "display_name": "Python 3"
168 | },
169 | "language_info": {
170 | "name": "python"
171 | }
172 | },
173 | "nbformat": 4,
174 | "nbformat_minor": 0
175 | }
176 |
--------------------------------------------------------------------------------
/machine_learning/unsupervised_learning/dimensionality_reduction/some_play_with_svd.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "some_play_with_svd.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyPQEYfVDZBXXdRFHs3wrTL+",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "language_info": {
17 | "name": "python"
18 | }
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "view-in-github",
25 | "colab_type": "text"
26 | },
27 | "source": [
28 | " "
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {
34 | "id": "XMJDsJWKB1xM"
35 | },
36 | "source": [
37 | "#Singular Value Decomposition"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "id": "W-QTRSr29MTe"
44 | },
45 | "source": [
46 | "import numpy as np\n",
47 | "import matplotlib.pyplot as plt"
48 | ],
49 | "execution_count": null,
50 | "outputs": []
51 | },
52 | {
53 | "cell_type": "code",
54 | "metadata": {
55 | "colab": {
56 | "base_uri": "https://localhost:8080/"
57 | },
58 | "id": "Y1mPvwgg_dsa",
59 | "outputId": "53331e55-0c82-4f8f-9895-5e9f2355c504"
60 | },
61 | "source": [
62 | "from tensorflow.keras.datasets.mnist import load_data\n",
63 | "(X_train, _), (X_test, _) = load_data()"
64 | ],
65 | "execution_count": 26,
66 | "outputs": [
67 | {
68 | "output_type": "stream",
69 | "text": [
70 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
71 | "11493376/11490434 [==============================] - 0s 0us/step\n"
72 | ],
73 | "name": "stdout"
74 | }
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "metadata": {
80 | "colab": {
81 | "base_uri": "https://localhost:8080/",
82 | "height": 282
83 | },
84 | "id": "pss2K3y3_mpC",
85 | "outputId": "e744e489-6ff3-4526-e983-8e05360cdf66"
86 | },
87 | "source": [
88 | "plt.imshow(X_train[0], cmap='viridis')"
89 | ],
90 | "execution_count": 56,
91 | "outputs": [
92 | {
93 | "output_type": "execute_result",
94 | "data": {
95 | "text/plain": [
96 | ""
97 | ]
98 | },
99 | "metadata": {
100 | "tags": []
101 | },
102 | "execution_count": 56
103 | },
104 | {
105 | "output_type": "display_data",
106 | "data": {
107 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOZ0lEQVR4nO3dbYxc5XnG8euKbezamMQbB9chLjjgFAg0Jl0ZEBZQobgOqgSoCsSKIkJpnSY4Ca0rQWlV3IpWbpUQUUqRTHExFS+BBIQ/0CTUQpCowWWhBgwEDMY0NmaNWYENIX5Z3/2w42iBnWeXmTMv3vv/k1Yzc+45c24NXD5nznNmHkeEAIx/H+p0AwDag7ADSRB2IAnCDiRB2IEkJrZzY4d5ckzRtHZuEkjlV3pbe2OPR6o1FXbbiyVdJ2mCpH+LiJWl50/RNJ3qc5rZJICC9bGubq3hw3jbEyTdIOnzkk6UtMT2iY2+HoDWauYz+wJJL0TE5ojYK+lOSedV0xaAqjUT9qMk/WLY4621Ze9ie6ntPtt9+7Snic0BaEbLz8ZHxKqI6I2I3kma3OrNAaijmbBvkzRn2ONP1JYB6ELNhP1RSfNsz7V9mKQvSlpbTVsAqtbw0FtE7Le9TNKPNDT0tjoinq6sMwCVamqcPSLul3R/Rb0AaCEulwWSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJpmZxRffzxPJ/4gkfm9nS7T/3F8fUrQ1OPVBc9+hjdxTrU7/uYv3Vaw+rW3u893vFdXcOvl2sn3r38mL9uD9/pFjvhKbCbnuLpN2SBiXtj4jeKpoCUL0q9uy/FxE7K3gdAC3EZ3YgiWbDHpJ+bPsx20tHeoLtpbb7bPft054mNwegUc0exi+MiG22j5T0gO2fR8TDw58QEaskrZKkI9wTTW4PQIOa2rNHxLba7Q5J90paUEVTAKrXcNhtT7M9/eB9SYskbayqMQDVauYwfpake20ffJ3bI+KHlXQ1zkw4YV6xHpMnFeuvnPWRYv2d0+qPCfd8uDxe/JPPlMebO+k/fzm9WP/Hf1lcrK8/+fa6tZf2vVNcd2X/54r1j//k0PtE2nDYI2KzpM9U2AuAFmLoDUiCsANJEHYgCcIOJEHYgST4imsFBs/+bLF+7S03FOufmlT/q5jj2b4YLNb/5vqvFOsT3y4Pf51+97K6tenb9hfXnbyzPDQ3tW99sd6N2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs1dg8nOvFOuP/WpOsf6pSf1VtlOp5dtPK9Y3v1X+Kepbjv1+3dqbB8rj5LP++b+L9VY69L7AOjr27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQhCPaN6J4hHviVJ/Ttu11i4FLTi/Wdy0u/9zzhCcPL9af+Pr1H7ing67Z+TvF+qNnlcfRB994s1iP0+v/APGWbxZX1dwlT5SfgPdZH+u0KwZGnMuaPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4exeYMPOjxfrg6wPF+ku31x8rf/rM1cV1F/zDN4r1I2/o3HfK8cE1Nc5ue7XtHbY3DlvWY/sB25tqtzOqbBhA9cZyGH+LpPfOen+lpHURMU/SutpjAF1s1LBHxMOS3nsceZ6kNbX7aySdX3FfACrW6G/QzYqI7bX7r0qaVe+JtpdKWipJUzS1wc0BaFbTZ+Nj6Axf3bN8EbEqInojoneSJje7OQANajTs/bZnS1Ltdkd1LQFohUbDvlbSxbX7F0u6r5p2ALTKqJ/Zbd8h6WxJM21vlXS1pJWS7rJ9qaSXJV3YyibHu8Gdrze1/r5djc/v/ukvPVOsv3bjhPILHCjPsY7uMWrYI2JJnRJXxwCHEC6XBZIg7EAShB1IgrADSRB2IAmmbB4HTrji+bq1S04uD5r8+9HrivWzvnBZsT79e48U6+ge7NmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2ceB0rTJr3/thOK6/7f2nWL9ymtuLdb/8sILivX43w/Xrc35+58V11Ubf+Y8A/bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEUzYnN/BHpxfrt1397WJ97sQpDW/707cuK9bn3bS9WN+/eUvD2x6vmpqyGcD4QNiBJAg7kARhB5Ig7EAShB1IgrADSTDOjqI4Y36xfsTKrcX6HZ/8UcPbPv7BPy7Wf/tv63+PX5IGN21ueNuHqqbG2W2vtr3D9sZhy1bY3mZ7Q+3v3CobBlC9sRzG3yJp8QjLvxsR82t/91fbFoCqjRr2iHhY0kAbegHQQs2coFtm+8naYf6Mek+yvdR2n+2+fdrTxOYANKPRsN8o6VhJ8yVtl/Sdek+MiFUR0RsRvZM0ucHNAWhWQ2GPiP6IGIyIA5JukrSg2rYAVK2hsNuePezhBZI21nsugO4w6ji77TsknS1ppqR+SVfXHs+XFJK2SPpqRJS/fCzG2cejCbOOLNZfuei4urX1V1xXXPdDo+yLvvTSomL9zYWvF+vjUWmcfdRJIiJiyQiLb266KwBtxeWyQBKEHUiCsANJEHYgCcIOJMFXXNExd20tT9k81YcV67+MvcX6H3zj8vqvfe/64rqHKn5KGgBhB7Ig7EAShB1IgrADSRB2IAnCDiQx6rfekNuBheWfkn7xC+Upm0+av6VubbRx9NFcP3BKsT71vr6mXn+8Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzj7OufekYv35b5bHum86Y02xfuaU8nfKm7En9hXrjwzMLb/AgVF/3TwV9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7IeAiXOPLtZfvOTjdWsrLrqzuO4fHr6zoZ6qcFV/b7H+0HWnFesz1pR/dx7vNuqe3fYc2w/afsb207a/VVveY/sB25tqtzNa3y6ARo3lMH6/pOURcaKk0yRdZvtESVdKWhcR8yStqz0G0KVGDXtEbI+Ix2v3d0t6VtJRks6TdPBayjWSzm9VkwCa94E+s9s+RtIpktZLmhURBy8+flXSrDrrLJW0VJKmaGqjfQJo0pjPxts+XNIPJF0eEbuG12JodsgRZ4iMiFUR0RsRvZM0ualmATRuTGG3PUlDQb8tIu6pLe63PbtWny1pR2taBFCFUQ/jbVvSzZKejYhrh5XWSrpY0sra7X0t6XAcmHjMbxXrb/7u7GL9or/7YbH+px+5p1hvpeXby8NjP/vX+sNrPbf8T3HdGQcYWqvSWD6znyHpy5Kesr2htuwqDYX8LtuXSnpZ0oWtaRFAFUYNe0T8VNKIk7tLOqfadgC0CpfLAkkQdiAJwg4kQdiBJAg7kARfcR2jibN/s25tYPW04rpfm/tQsb5ken9DPVVh2baFxfrjN5anbJ75/Y3Fes9uxsq7BXt2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUgizTj73t8v/2zx3j8bKNavOu7+urVFv/F2Qz1VpX/wnbq1M9cuL657/F//vFjveaM8Tn6gWEU3Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0mkGWffcn7537XnT767Zdu+4Y1ji/XrHlpUrHuw3o/7Djn+mpfq1ub1ry+uO1isYjxhzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTgiyk+w50i6VdIsSSFpVURcZ3uFpD+R9FrtqVdFRP0vfUs6wj1xqpn4FWiV9bFOu2JgxAszxnJRzX5JyyPicdvTJT1m+4Fa7bsR8e2qGgXQOmOZn327pO21+7ttPyvpqFY3BqBaH+gzu+1jJJ0i6eA1mMtsP2l7te0ZddZZarvPdt8+7WmqWQCNG3PYbR8u6QeSLo+IXZJulHSspPka2vN/Z6T1ImJVRPRGRO8kTa6gZQCNGFPYbU/SUNBvi4h7JCki+iNiMCIOSLpJ0oLWtQmgWaOG3bYl3Szp2Yi4dtjy2cOedoGk8nSeADpqLGfjz5D0ZUlP2d5QW3aVpCW252toOG6LpK+2pEMAlRjL2fifShpp3K44pg6gu3AFHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlRf0q60o3Zr0l6ediimZJ2tq2BD6Zbe+vWviR6a1SVvR0dER8bqdDWsL9v43ZfRPR2rIGCbu2tW/uS6K1R7eqNw3ggCcIOJNHpsK/q8PZLurW3bu1LordGtaW3jn5mB9A+nd6zA2gTwg4k0ZGw215s+znbL9i+shM91GN7i+2nbG+w3dfhXlbb3mF747BlPbYfsL2pdjviHHsd6m2F7W21926D7XM71Nsc2w/afsb207a/VVve0feu0Fdb3re2f2a3PUHS85I+J2mrpEclLYmIZ9raSB22t0jqjYiOX4Bh+0xJb0m6NSJOqi37J0kDEbGy9g/ljIi4okt6WyHprU5P412brWj28GnGJZ0v6Svq4HtX6OtCteF968SefYGkFyJic0TslXSnpPM60EfXi4iHJQ28Z/F5ktbU7q/R0P8sbVent64QEdsj4vHa/d2SDk4z3tH3rtBXW3Qi7EdJ+sWwx1vVXfO9h6Qf237M9tJONzOCWRGxvXb/VUmzOtnMCEadxrud3jPNeNe8d41Mf94sTtC938KI+Kykz0u6rHa42pVi6DNYN42djmka73YZYZrxX+vke9fo9OfN6kTYt0maM+zxJ2rLukJEbKvd7pB0r7pvKur+gzPo1m53dLifX+umabxHmmZcXfDedXL6806E/VFJ82zPtX2YpC9KWtuBPt7H9rTaiRPZniZpkbpvKuq1ki6u3b9Y0n0d7OVdumUa73rTjKvD713Hpz+PiLb/STpXQ2fkX5T0V53ooU5fn5T0RO3v6U73JukODR3W7dPQuY1LJX1U0jpJmyT9l6SeLurtPyQ9JelJDQVrdod6W6ihQ/QnJW2o/Z3b6feu0Fdb3jculwWS4AQdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTx/65XcTNOWsh5AAAAAElFTkSuQmCC\n",
108 | "text/plain": [
109 | ""
110 | ]
111 | },
112 | "metadata": {
113 | "tags": [],
114 | "needs_background": "light"
115 | }
116 | }
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "metadata": {
122 | "colab": {
123 | "base_uri": "https://localhost:8080/"
124 | },
125 | "id": "XCtPgKUUALGk",
126 | "outputId": "f97e666d-6e8f-481c-9b09-13447338f29f"
127 | },
128 | "source": [
129 | "U, D, V = np.linalg.svd(X_train[0])\n",
130 | "\n",
131 | "print(U.shape)\n",
132 | "print(D.shape)\n",
133 | "print(V.shape)"
134 | ],
135 | "execution_count": 35,
136 | "outputs": [
137 | {
138 | "output_type": "stream",
139 | "text": [
140 | "(28, 28)\n",
141 | "(28,)\n",
142 | "(28, 28)\n"
143 | ],
144 | "name": "stdout"
145 | }
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "metadata": {
151 | "colab": {
152 | "base_uri": "https://localhost:8080/",
153 | "height": 282
154 | },
155 | "id": "X7DWdYrRAQQ5",
156 | "outputId": "1460239c-cc61-4439-9d84-98d851213d3d"
157 | },
158 | "source": [
159 | "plt.imshow(U, cmap='viridis')"
160 | ],
161 | "execution_count": 53,
162 | "outputs": [
163 | {
164 | "output_type": "execute_result",
165 | "data": {
166 | "text/plain": [
167 | ""
168 | ]
169 | },
170 | "metadata": {
171 | "tags": []
172 | },
173 | "execution_count": 53
174 | },
175 | {
176 | "output_type": "display_data",
177 | "data": {
178 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAUNklEQVR4nO3dfWxV93kH8O/ji238im1ejLHNa1kITTeSuqwbqGVFS1ImjbTTsrIpomoWqjbRWrXThpI/mn+qoWpNl0lLK7qg0KyjadVkQQtaw1Ba1lZLcDLCe2JeDMb4BWyD37F977M/fKgM+Pcc576cc+nv+5Es2/fx8Xnu8X3uub7P+f1+oqogot9+BXEnQETRYLETeYLFTuQJFjuRJ1jsRJ6YFenOSsq0qLImyl0STWv1wstm/ETn/Igyya6x/l5MjAzJdLGMil1EHgTwLIAEgH9V1R3WzxdV1uBDW76WyS6JsuKt7c+Z8TU7vhxRJtl1es8zzljaL+NFJAHgXwB8GsBqAFtEZHW6v4+IciuT/9nXAjitqmdVdQzAjwBszk5aRJRtmRR7PYC2Kd9fDG67iYhsE5FmEWmeGBnKYHdElImcvxuvqjtVtUlVm2aVlOV6d0TkkEmxtwNonPJ9Q3AbEeWhTIr9EICVIrJMRIoAfA7A3uykRUTZlnbrTVUnROQJAD/DZOttl6oez1pmRBk6bLTX7tTWWiYy6rOr6j4A+7KUCxHlEC+XJfIEi53IEyx2Ik+w2Ik8wWIn8gSLncgTkY5nJ8omq48O+NlLt/DMTuQJFjuRJ1jsRJ5gsRN5gsVO5AkWO5En2HqjvMXWWnbxzE7kCRY7kSdY7ESeYLETeYLFTuQJFjuRJ1jsRJ6ItM+uBcB4uTt+vSZlbl/c435uKuq3993/8REzXnGoxN7+rqQztuzlCXPb9k8Um/GJMvt+l16yn5Pnnhh3xvp+p9DctuSKve/BRfa+F/+kzYxf+mdjFaDX7eW72Uefnn6qzx37T/fjlGd2Ik+w2Ik8wWIn8gSLncgTLHYiT7DYiTzBYifyRLR99gQwXq7u+MLr5vapAXcvXJLu3wsAxSXuXvQku89e9wtxxs79qd3Lrjli5zb/0FUzfvapIjM+eNXdyy65bPfRB5aEPN/bqePis8aFEwAm3qx2xsYa7V9edsl9zHNtZIGdW0l3fLkNv1fljKVGE85YRsUuIq0ABgAkAUyoalMmv4+IcicbZ/Y/UtUrWfg9RJRD/J+dyBOZFrsCeF1E3haRbdP9gIhsE5FmEWlODg1luDsiSlemL+PXq2q7iCwAsF9ETqnqwak/oKo7AewEgNkNIe/IEFHOZHRmV9X24HM3gFcArM1GUkSUfWkXu4iUiUjFja8B3A/gWLYSI6LsyuRlfC2AV0Tkxu/5d1X9L2sDFSBV5H4ln0jYPeGCMSPmHsYLABg/U2HG54SM665oGXDGBhrnmNvOvmon17HBHte9/AvvmvH3nrvLGas4PNvcdnj1qBlfvMfdtwWA1Cn7uA40umM1p0KuAWiw951Ls4bj66OHsXr8BcbUCmkXu6qeBfB76W5PRNFi643IEyx2Ik+w2Ik8wWIn8gSLncgT0S7ZXKBIlrvbLcl+eyhnmTFddOU5oy8HoKLNfl4bq7DjqVJ7GKulY53dxknU2ZcRn1n0u2a87jV3O3POMXuM0siJSjPe+fv2/S4IGTlc5O5YovyMPf/3QIN7eGyu1Zyy26VxtgXLLrlryPp78MxO5AkWO5EnWOxEnmCxE3mCxU7kCRY7kSdY7ESeiLbPngCk1D0GTwftdMovuXufOst+3tKQp7WBRrtv2r+k1Bmbe9xuNpd02/ereo891PPC/XafvqzdvRx12zftfRe/ZvfRG35hL3U90GgvR118zX3fWh+y++jF7pWJc67sgn3tw0CDfX1CLo1Wux/MajyMeWYn8gSLncgTLHYiT7DYiTzBYifyBIudyBMsdiJPRNtnTwE67n5+Keq1e91F/e4x68nZ9vPWeGnIePYqe7GasSXu5aTr37CXmr640Z5u+XqVHa9ttsfqt3zB/Wese8meSvraMruH37s+5PqFYfsagfJz7txGF9nXJxT3pT+HQKZa/speirriXESJTEM0vYWVeGYn8gSLncgTLHYiT7DYiTzBYifyBIudyBMsdiJPRNtnV4GMunvpNSfs/mFi1D2ePVVoP2/1L7V7+GPVIWs+X3P3fAe/edXctOon9tjn4UX2rlsftuNFl9y5jbuH4QMARu62l2wuO2r36Rv39Zjxq/e4x6yXdtoPv+vxTRuPspB1BuI0Xua+NsKatyH0HonILhHpFpFjU26rEZH9ItISfI7xz0JEMzGTp68XADx4y23bARxQ1ZUADgTfE1EeCy12VT0IoPeWmzcD2B18vRvAQ1nOi4iyLN1/TGpVtSP4uhNAresHRWSbiDSLSHNycDDN3RFRpjJ+F0JVFYDznTVV3amqTaralCi3BxcQUe6kW+xdIlIHAMHn7uylRES5kG6x7wWwNfh6K4BXs5MOEeVKaJ9dRPYA2ABgnohcBPANADsA/FhEHgVwHkBIJzigACbcPcKyDnvc9liVu58sIW3yoUb7Bwr77ee94l533j1XFprbLn3z1vc3b3ZpY40ZX/iSGcZojfv6hP6l9v26++/azfi19UvNeMsjdu6pIndui193ryEAANer4xvPPnC3PdZ+ztH4cqs6634sJ4ypFUKLXVW3OEIbw7YlovyRv5cJEVFWsdiJPMFiJ/IEi53IEyx2Ik9EPMQVKDC6LQXX7fZY733u4ZZVZ+w2jlbY8UX7QqaiLnPH++6yh8+ee9geFFjVZF+TNNA/34wnjVWTx+bYw4bb/nKFGS+5Ym9v/T0BoPaQe/uCv+2yN97bYMdzaPGr9hTb15ZHlMg0+la6H2/JX7m345mdyBMsdiJPsNiJPMFiJ/IEi53IEyx2Ik+w2Ik8EW2fHYAkjWlwQ6aDHl7o7tlWttp90aL2IjNefLnfjA8tdE8HXfuWvWRzx3qjEQ4g8cI8M971qZDhucZS14kP2VOBjSTs2YMGPmwP9ax5yx7qefk+99902XZ7im38oR3Ope4muzSK7VHLObXg/9xDwduGjesacpEMEeUfFjuRJ1jsRJ5gsRN5gsVO5AkWO5EnWOxEnoi8z65GO1wL7F55stI9eDpZbI8pL7tohjG6wF7beKjOnVtJj73vBc32oO+iAbuXPbujxIwnS4zpmv/8qLlt3+f/wIxXtqbMOMS+xmCs0j0HwfuP2D3+yjP2rnOp/II9jn+83H6s5tLIPHfZpma58+KZncgTLHYiT7DYiTzBYifyBIudyBMsdiJPsNiJPBF5n92SChnPjkJ37zNZaPc9S3rsfvFQnX0orPnXe1fZ29acsvvsV5e7e9EAUPcru5c9Mt89prznMbuPPjLPPm49a+x5AMIs+49RZ0z+1+5lX77Xvr4gl4oG87fPvuTx952xlrfdxzv0zC4iu0SkW0SOTbntaRFpF5HDwcemD5owEUVrJi/jXwDw4DS3f0dV1wQf+7KbFhFlW2ixq+pBADFOwkNE2ZDJG3RPiMiR4GW+czEzEdkmIs0i0pwcGspgd0SUiXSL/bsAVgBYA6ADwLddP6iqO1W1SVWbEmVlae6OiDKVVrGrapeqJlU1BeD7ANZmNy0iyra0il1E6qZ8+xkAx1w/S0T5IbTPLiJ7AGwAME9ELgL4BoANIrIGgAJoBfDFGe1NATGmQB+tCUlH3OO+UyGbllyxx4x3fNKe/7z6qPt5sX+F3ZO9cLfd45/7a3s8/ES5Ha8+5F7nvOOBOmcMAK5/ZNiM179k99l7VtsHvu1L7msMZr1rj2cvtKe8z6mee+w++uwrESUyjbZ/WumMjXW5r9kILXZV3TLNzc/PKCsiyhu8XJbIEyx2Ik+w2Ik8wWIn8gSLncgTkQ5xleDDZbTKbncUFLpbWBOl9rYj8+0W0rxlPWa8+t/cwy2vV9stJKjd1kvZYVz5iP1n6vwL99LHyQ677bfqqT4z3vbZejM+96Q9fHek033V5FCDuWmsJBXfENYwXWvd5+jxX7u345mdyBMsdiJPsNiJPMFiJ/IEi53IEyx2Ik+w2Ik8Ee1U0goUjLnDI7V2bzM16h7qOdRgDzMdarB/d+FIsRkf+5i7l76g2Z7queSMPR7y5NcWmvHy1pAloV90X0MwWG/f79N/vciMj9UZfzAARf32RQKll91jmkdq82om85tUtNqPp7HK+PrwK190XxvR0+M+3jyzE3mCxU7kCRY7kSdY7ESeYLETeYLFTuQJFjuRJyJtdKoASWNYuaTs3iYm3M9NyVJ73HZi2H5eK/65e0w4ANT9Wasz1nJoibltstiezrn0op3brKGQqao3ubef3Wn3g+cesY9b6X473vonZhiphPsagXlH7em9r64IGeifQ1UtI2a8+6OlEWVyu7H57jkCUmfdjwWe2Yk8wWIn8gSLncgTLHYiT7DYiTzBYifyBIudyBPRDihOKCaq3ONtZ3fa6RSed8ev32P3RWt/bs8bX3mi14yfXeDupUvIUawM6aOPu9umAIBksd0rX/yauxd+eY09Fl5Dnu67P+peAhgA6g/a88YX97rHw194wD0XPxDvssg137pgxrtfWhVRJrfrW+l+LCffdT9WQs/sItIoIm+IyAkROS4iXwlurxGR/SLSEnyuTidxIorGTF7GTwD4uqquBvBxAI+LyGoA2wEcUNWVAA4E3xNRngotdlXtUNV3gq8HAJwEUA9gM4DdwY/tBvBQrpIkosx9oDfoRGQpgHsBvAmgVlU7glAngFrHNttEpFlEmpODQxmkSkSZmHGxi0g5gJ8C+Kqq9k+NqaoCmHa0hqruVNUmVW1KlIe8E0VEOTOjYheRQkwW+g9V9eXg5i4RqQvidQC6c5MiEWVDaOtNRATA8wBOquozU0J7AWwFsCP4/Gro3goAFLtbb+Xn7TaRNS3x8Dr7X4SCpD1csm/NXDM+95h7mGnHOnNTVJ2x21NDIVMqa8isxV0fc9+3WaP2tuMhS10PrLKHoSZG7OMqj11zxvSI3XqL07UvLbB/YEMkaUzLasVaj5WZ9NnXAXgEwFERORzc9iQmi/zHIvIogPMAHp5hrkQUg9BiV9VfAnA9X2zMbjpElCu8XJbIEyx2Ik+w2Ik8wWIn8gSLncgT0Q5xTQEYcffSrT46AEjS3esevm73eycW2Xd1qNGernnuEXd81fd6zG3bH5hvxocX2fueqLCPS+kF930r7rV/d9+H7fjqf7hsxk/9jb3cdOXP3NcvLLxgX39wbVl8Szq3frbGjBfbI6JzanCZe0hzylh5nGd2Ik+w2Ik8wWIn8gSLncgTLHYiT7DYiTzBYifyRKSNTEkKivrcfXZNhPXZ3bHR8xXmtmPrh814osVegjcx5u5tdn5ynrntrI32nMgrt9tjykfrys14qtjoldsrLkNS9kOg5TF7uem6/wlZ8rnDPc9A//L8Hc8+57R9v0Zr4jtPzj3sfrx0Gw9zntmJPMFiJ/IEi53IEyx2Ik+w2Ik8wWIn8gSLncgTkfbZC8aBki53j7B/SfrpVLSG/ECr3UcPM1hvz2lv2m/34S/l8Ry9pR12POy4DNbfmasAhfXRD29/zoyv2fHlbKZzk7FKY95448/BMzuRJ1jsRJ5gsRN5gsVO5AkWO5EnWOxEnmCxE3liJuuzNwL4AYBaAApgp6o+KyJPA3gMwI2JxZ9U1X25SpQon4T10ePsw7vM5CqWCQBfV9V3RKQCwNsisj+IfUdV/zF36RFRtsxkffYOAB3B1wMichJAfa4TI6Ls+kD/s4vIUgD3AngzuOkJETkiIrtEpNqxzTYRaRaR5okR9xRFRJRbMy52ESkH8FMAX1XVfgDfBbACwBpMnvm/Pd12qrpTVZtUtWlWyZ15nTTRb4MZFbuIFGKy0H+oqi8DgKp2qWpSVVMAvg9gbe7SJKJMhRa7iAiA5wGcVNVnptw+ddrRzwA4lv30iChbZvJu/DoAjwA4KiKHg9ueBLBFRNZgsh3XCuCLOcmQ6A6USWsuV225mbwb/0sA0w2gZU+d6A7CK+iIPMFiJ/IEi53IEyx2Ik+w2Ik8wWIn8kSkU0kT0SSrl56r4bE8sxN5gsVO5AkWO5EnWOxEnmCxE3mCxU7kCRY7kSdEVaPbmchlAOen3DQPwJXIEvhg8jW3fM0LYG7pymZuS1R1/nSBSIv9tp2LNKtqU2wJGPI1t3zNC2Bu6YoqN76MJ/IEi53IE3EX+86Y92/J19zyNS+AuaUrktxi/Z+diKIT95mdiCLCYifyRCzFLiIPish7InJaRLbHkYOLiLSKyFEROSwizTHnsktEukXk2JTbakRkv4i0BJ+nXWMvptyeFpH24NgdFpFNMeXWKCJviMgJETkuIl8Jbo/12Bl5RXLcIv+fXUQSAN4H8McALgI4BGCLqp6INBEHEWkF0KSqsV+AISKfADAI4Aeqek9w27cA9KrqjuCJslpV/z5PcnsawGDcy3gHqxXVTV1mHMBDAD6PGI+dkdfDiOC4xXFmXwvgtKqeVdUxAD8CsDmGPPKeqh4E0HvLzZsB7A6+3o3JB0vkHLnlBVXtUNV3gq8HANxYZjzWY2fkFYk4ir0eQNuU7y8iv9Z7VwCvi8jbIrIt7mSmUauqHcHXnQBq40xmGqHLeEfplmXG8+bYpbP8eab4Bt3t1qvqfQA+DeDx4OVqXtLJ/8HyqXc6o2W8ozLNMuO/EeexS3f580zFUeztABqnfN8Q3JYXVLU9+NwN4BXk31LUXTdW0A0+d8ecz2/k0zLe0y0zjjw4dnEufx5HsR8CsFJElolIEYDPAdgbQx63EZGy4I0TiEgZgPuRf0tR7wWwNfh6K4BXY8zlJvmyjLdrmXHEfOxiX/5cVSP/ALAJk+/InwHwVBw5OPJaDuDd4ON43LkB2IPJl3XjmHxv41EAcwEcANAC4L8B1ORRbi8COArgCCYLqy6m3NZj8iX6EQCHg49NcR87I69IjhsvlyXyBN+gI/IEi53IEyx2Ik+w2Ik8wWIn8gSLncgTLHYiT/w/xSX3gpUMTO4AAAAASUVORK5CYII=\n",
179 | "text/plain": [
180 | ""
181 | ]
182 | },
183 | "metadata": {
184 | "tags": [],
185 | "needs_background": "light"
186 | }
187 | }
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "metadata": {
193 | "colab": {
194 | "base_uri": "https://localhost:8080/",
195 | "height": 82
196 | },
197 | "id": "YkR97H0DAc8D",
198 | "outputId": "1aac799d-f96a-423f-81c4-70b70a34c170"
199 | },
200 | "source": [
201 | "plt.imshow(D.reshape(1, 28), cmap='viridis')"
202 | ],
203 | "execution_count": 54,
204 | "outputs": [
205 | {
206 | "output_type": "execute_result",
207 | "data": {
208 | "text/plain": [
209 | ""
210 | ]
211 | },
212 | "metadata": {
213 | "tags": []
214 | },
215 | "execution_count": 54
216 | },
217 | {
218 | "output_type": "display_data",
219 | "data": {
220 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAAvCAYAAADginEnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAGjklEQVR4nO3dXYxcZR3H8e+PvoTYEtwC1gYRRQyQGHzbkJgQ0/iKXtAaItqrcqHlwka9w6rBBmNsiBrvTFAbNFHQ1LdqiAhRo4kvsCVNX4CWgjWy2bZSFFlE6u7+vDhnYVhntnt2zu7MnPl9ks3MnPnPOc8/T/Z/Zp55zjOyTURENN85vW5AREQsjxT8iIghkYIfETEkUvAjIoZECn5ExJBIwY+IGBJdFXxJ6yTdJ+mx8nakQ9y0pP3l395ujhkREYujbubhS7odeNr2LkmfAUZs39ImbtL22i7aGRERXeq24B8BNtqekLQB+K3tK9rEpeBHRPRYt2P4621PlPdPAOs7xJ0raUzSnyRt7vKYERGxCCvPFiDpfuDVbZ76XOsD25bU6ePCpbbHJV0G/FrSQduPtznWNmAbwJpX6O1XXr76rAnMOvzcugXHAsy8sKJSvKYrhbPy+WrxqvhJSzMV939mqtoLqn7yqxo/s8T7r6rC/qu3pGqulQ8Q8aJn+cdTti9q99yyDOnMec2dwC9s75kvbvTN5/qBey9ZcFuufmDLgmMB/v34+ZXiVz1T7cPQhYeqnSHOOVOtH1ZNVivgq8f/WSle/zlTKZ6pau3x89XOiJ6ueIabrniGnln4/iv/z1Rsi6ueDGcq5hqNdr/37LM92u65bod09gJbJV0HHABeVX55+yJJI5LOk/QDSU8AHwWqVZ+IiOhatwV/F/BeisJ/FLgS2CLpw5K+VcZcBTwKvBt4DrgT+HiXx42IiIrOOoY/H9unJX0e2Gn7/QCS7gYut/2xMuYPkg6VMX+UtBI4IUnO2swREcumjittLwb+1vL4yXJb2xjbU8AzwAU1HDsiIhaor5ZWkLStnL459vfT+SIqIqJOdRT8caB1Os1rym1tY8ohnfOB03N3ZPsO26O2Ry+6oNq0yYiImF9XY/ilB4Gryxk4M8Aa4D1zYk5SzL8/ArwS+EvG7yMillcd7/BnC7fKPwBLuk3S9eXj31O8y18LnAI+UsNxIyKigjre4V8DHGiZpbMD2GT71paY/wK/tL29huNFRMQiLNcsHYAbJB2QtEfSwi+hjYiIWtTxDn8hfg7cZfsFSTcD3wHeNTeodS0dYHLFhmNH2uzrQuCp/9/8xdoaW4ej9e2qQ76NlXybbZjy7VWul3Z6oqu1dAAkvYOXX3i1A8D2lzvEr6BYQ7/aYjYvvX6s0zoRTZR8my35Nlc/5lrHkM6DwBslvV7Saoq1cl72q1blwmqzrgceqeG4ERFRQddDOranJG0H7gVWALttH5Z0GzBmey/wyXLGzhTwNHBTt8eNiIhqahnDt30PcM+cbbe23N8B7KjjWMAdNe1nUCTfZku+zdV3uXY9hh8REYOhr9bSiYiIpTNQBV/SdZKOSDo294dWmkjScUkHJe2XNNbr9tRN0m5Jp8rls2e3rZN0n6THytuRXraxTh3y3SlpvOzj/ZI+2Ms21kXSJZJ+I+lhSYclfarc3sj+nSffvurfgRnSKadzHqX4wZUnKWYHbbH9cE8btoQkHQdGbTdy3rKkdwKTwHdtv6ncdjvFtN1d5Ul9xPYtvWxnXTrkuxOYtP2VXratbuXMvA22H5J0HrAP2EwxYaNx/TtPvjfSR/07SO/wrwGO2X7C9hngbmBTj9sUXbD9O4pZW602UVyYR3m7eVkbtYQ65NtItidsP1Tef5ZiKvbFNLR/58m3rwxSwV/oEg5NYuBXkvaVVyEPg/W2J8r7J4D1vWzMMtleLjuyuylDHK0kvQ54K/BnhqB/5+QLfdS/g1Twh9G1tt8GfAD4RDkkMDTKJbQHY8xx8b4BvAF4CzABfLW3zamXpLXAj4BP2/5X63NN7N82+fZV/w5SwV/ID600iu3x8vYU8BOKYa2mOzl7ZXZ5e6rH7VlStk/anrY9A3yTBvWxpFUUxe97tn9cbm5s/7bLt9/6d5AK/lmXcGgSSWvKL3+QtAZ4H3Bo/lc1wl5ga3l/K/CzHrZlyc1ZduRDNKSPJQn4NvCI7a+1PNXI/u2Ub7/178DM0gEopzR9nZeWcPhSj5u0ZCRdRvGuHooror/ftHwl3QVspFhV8CTwBeCnwA+B1wJ/BW603YgvOjvku5Hi476B48DNLWPcA0vStRQ/fHSQ4pfwAD5LMa7duP6dJ98t9FH/DlTBj4iIxRukIZ2IiOhCCn5ExJBIwY+IGBIp+BERQyIFPyJiSKTgR0QMiRT8iIghkYIfETEk/getTpto5Zz1CwAAAABJRU5ErkJggg==\n",
221 | "text/plain": [
222 | ""
223 | ]
224 | },
225 | "metadata": {
226 | "tags": [],
227 | "needs_background": "light"
228 | }
229 | }
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "metadata": {
235 | "colab": {
236 | "base_uri": "https://localhost:8080/",
237 | "height": 282
238 | },
239 | "id": "NNSBp4GSAjUj",
240 | "outputId": "67f6af3b-33d8-411a-b0a4-09096771e1a0"
241 | },
242 | "source": [
243 | "plt.imshow(V, cmap='viridis')"
244 | ],
245 | "execution_count": 55,
246 | "outputs": [
247 | {
248 | "output_type": "execute_result",
249 | "data": {
250 | "text/plain": [
251 | ""
252 | ]
253 | },
254 | "metadata": {
255 | "tags": []
256 | },
257 | "execution_count": 55
258 | },
259 | {
260 | "output_type": "display_data",
261 | "data": {
262 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAARe0lEQVR4nO3dfWxV530H8O8Xv2BjzIvBGMJbIKXRKFuAemwKUZSsL6KpJuhURaFTRNdoVEqjNW00NcomNfsvm1aibGoz0YJCozRVpyYKW1EaSqPRaGsWh1FeAimEQMEF28QQG+N3//aHTyo38fM7N/f1wPP9SMjX9/G558exvz7X93ef89DMICLXvymVLkBEykNhF4mEwi4SCYVdJBIKu0gkqsu6s2kNVjOjqZy7LBpzfi1abUpHI+1X6mja3pky7uw/bdNS8/Y/lrLtWErxacM14R3YqL8xR9LG/X1XynBPN0au9k1afEFhJ7kBwBMAqgB8z8we876+ZkYTlm/5eiG7rJiRaeGxgYXD7ras89NsV/1vQ9oPnlU7Ya9OSVTaL4OUfaeqcWob8n8LVvdUueNjKb9kq+ZfDY4N99W629ZeqPHHL1f6t+jkTu3aFhzL+2k8ySoA3wbwGQArAWwmuTLfxxOR0irkb/Z1AE6a2SkzGwLwQwAbi1OWiBRbIWFfCODshM/PJff9HpJbSbaRbBvt7ytgdyJSiJK/Gm9m282s1cxaq+obSr07EQkoJOztABZP+HxRcp+IZFAhYX8NwAqSy0jWArgHwO7ilCUixZZ3683MRkg+AOCnGG+97TSzo0WrLGOqnZcb5rzqH8bpF/zHrur3W3PDjX4Lqn9OeHx4ul/b4Gx3GIPNfm2cOeSO19aFG9LV1f5jX6lx+p2A+/YCAJj3Qnj7mqt+S3LM6dEDwLvL/O9JFhXUZzezPQD2FKkWESkhvV1WJBIKu0gkFHaRSCjsIpFQ2EUiobCLRKKs89mvZYNzw03dmtu73W07h1KmS+6f4Y7PftPvZc85FJ7KCfpTMYdm+lM9r873ax9oqnfHx5yfsMHZKY3yOSk9/mH/XHVxTfj/PmXY75NbygzWqZf88SzSmV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQq23HBWy/OXwkH+YBz7it5j6bvDbX9UDTvss7XLNKS2mtCu4elN/AaBqMDxW946/86oB/7g1nvVrm/Mfx8ODzf4lzS+vaXbHe5dk8+qyHp3ZRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIqM+eoxlvh8caU6aojt7s98mvrO13x+uXDLjj/cdnBcdmnnQ3xew3/ceG+b3srrX+5Z77W8Lbj6T8v2bOdKbuAugbneOOV3/yo8GxnqX+FNe+Zf57HxpOu8OZpDO7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJ9dlz1HtjeKzbuWQxADQd8HvVH/m239N9e1O4jw4AXBruR78zx79U9NAM/1LQc476l7Ge97rfC7+yqC441tfl7/vSSv/Hs2Hdu+5414zw+x8a2t1NseQ//QsBvLMqsiWbSZ4G0AtgFMCImbUWoygRKb5inNnvNLOLRXgcESkh/c0uEolCw24AXiL5Osmtk30Bya0k20i2jfanXLBMREqm0Kfxt5lZO8l5APaSPG5m+yd+gZltB7AdAOrnLy7kuo0iUoCCzuxm1p587ATwPIB1xShKRIov77CTbCDZ+N5tAJ8GcKRYhYlIcRXyNL4FwPMcXxK4GsAPzOzFolSVQVWD4V76yFDKssiz0sanuuNNh/2/frrHwnPKp9T52475U+3BlD+8Lq/we+UX7wxfOL6uwe/hz/q5f52Aed856o5f+NqtwbGem/w++tAM/8BMGXGHMynvsJvZKQC3FLEWESkhtd5EIqGwi0RCYReJhMIuEgmFXSQSmuKao8Hm8DTUqj7/d+bUS37/ylJmSw7f0+2ON/x0bnDMWzIZAGp7/RbU6FT//3bDX51yx7t+vSQ4tuRf3U3RtdY/bh1/E26tAcBIeHYtqvv8dmhdt7/voRlasllEMkphF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpFQnz1HjSfCzfDZJ4bdbX97r9/s7j3R4I7zQLiPDgDmLIvc0uZfprpnif8jUJfyHoD2p5b7X3BreC7oiS9Mdzdd8UyPO356oz8FdnBR+PuycI//H5v5y3Pu+G82L3XHs0hndpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEuqz5+jq/HAve6Tev+xw48v+YZ562Z87Pe9+f854/zdagmMj0/zaajdccscvHp/jji/bPeCOD80KX2q6b6E/lz6tjz79jDsMjoWXq+4JT7MHAFxa4ffRr8VLSevMLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQn32HC38RbixeuUG/zBeWun30Wvf9X/nDv7AnzPeVBPudZ/6vD9ve03jZXd87j90uOMnH/mYO17jTEmfdcy/9vqVlCnjPSv88RlvhY/7QJO/7yn+JQqAa++y8elndpI7SXaSPDLhviaSe0meSD7OLm2ZIlKoXJ7GPwVgw/vuexjAPjNbAWBf8rmIZFhq2M1sP4D3rz+0EcCu5PYuAJuKXJeIFFm+L9C1mNn55PYFAME3Z5PcSrKNZNtof1+euxORQhX8aryZGYDgKyFmtt3MWs2starev7CiiJROvmHvILkAAJKPncUrSURKId+w7wawJbm9BcALxSlHREoltc9O8lkAdwCYS/IcgG8CeAzAj0jeB+AMgLtLWWQW2BSnsepPy8ZYg3/t9ub/8h/g3aX+t+nMZ8MLkTe+5TeET77lN6v7/95/j8D8X/q1D3wpPF+++W/9x+6+vdEdb9o/1R2/7Ly/YemeIXfbU1/wj9v04+G58lmVGnYz2xwY+kSRaxGREtLbZUUiobCLREJhF4mEwi4SCYVdJBKa4pqjofvfCY7Ne9Bv0wx81m8RjdT5LaaRlDcezv+fcGvv0hevuNtOe26mOz6YMp+xc61/vmj53qzg2Om/8H/8eMlv603xO5qo6Q3XdnGV/z2ZedB/7NFwtzOzdGYXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhPnuunm4ODnWt9/vsV3/r94trZ6Vc1jhleeCLq8Lfxmk/8fvoV1v8fQ/P8JvZcw/454uOj4cvZT37uH9cOue5w+if69e+9Ce9wbFbth92t/3Vl/xLZLd/0j+uWaQzu0gkFHaRSCjsIpFQ2EUiobCLREJhF4mEwi4SCfXZc/QnD7UFx175tz92t134c/+xu//SWdcYQPOOae74+fXhb2P1QMrlmpf5ve6m//PPB0y5jPZQc7hPXxs+pAAAS7lac9WgX9vZT4WvE3D23291t63/uH/crkU6s4tEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVCfPUe7j/xRcKx5wN+2/U5/fNHT/oXhu//A/zaNNISb3UON/pzvmsv++GidPw7z+9FVveH57F23+I998w7//Qe/edgdxpJ/DD/+25umu9vWfb7DHR/cN9/feQalntlJ7iTZSfLIhPseJdlO8mDy767SlikihcrlafxTADZMcv/jZrY6+benuGWJSLGlht3M9gPoLkMtIlJChbxA9wDJQ8nT/OCKYCS3kmwj2Tba31fA7kSkEPmG/UkANwFYDeA8gG+FvtDMtptZq5m1VtWnrFAoIiWTV9jNrMPMRs1sDMB3AawrblkiUmx5hZ3kggmffg7AkdDXikg2pPbZST4L4A4Ac0meA/BNAHeQXA3AAJwG8OUS1pgJ046FF+S+tNLvNbf8t99PPvdn/r5revzHv/nJi8Gxnif8CeczH29yx8+vr3HHF7/U744P/flgcKzqZ/7i773L/V744Dl3GB3rwueyqqv+tr0vpvTR/cOSSalhN7PNk9y9owS1iEgJ6e2yIpFQ2EUiobCLREJhF4mEwi4SCVrKFMViqp+/2JZv+XrZ9jdR2iWPDz30HXf8Dx+/v4jViJTGqV3b0H/h7KS9Xp3ZRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIXDeXkj78tcL65Oqjy/VOZ3aRSCjsIpFQ2EUiobCLREJhF4mEwi4SCYVdJBLXVJ/d66WrTy7i05ldJBIKu0gkFHaRSCjsIpFQ2EUiobCLREJhF4lEpvrshc5JF5Gw1DM7ycUkXyb5BsmjJL+a3N9Eci/JE8lHf7FtEamoXJ7GjwB4yMxWAvhTAF8huRLAwwD2mdkKAPuSz0Uko1LDbmbnzexAcrsXwDEACwFsBLAr+bJdADaVqkgRKdyHeoGO5I0A1gB4FUCLmZ1Phi4AaAlss5VkG8m20f6+AkoVkULkHHaS0wH8GMCDZtYzcczGV4ecdIVIM9tuZq1m1lpV31BQsSKSv5zCTrIG40F/xsyeS+7uILkgGV8AoLM0JYpIMaS23kgSwA4Ax8xs24Sh3QC2AHgs+fhC2mN9rKUL/6tpqiIVkUuffT2AewEcJnkwue8RjIf8RyTvA3AGwN2lKVFEiiE17Gb2CoBJF3cH8IniliMipaK3y4pEQmEXiYTCLhIJhV0kEgq7SCTKOsX1aEezeulyXbgWp2PrzC4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUiobCLRCJTl5IWuVak9dGz2IfXmV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiYT67CIlUEgfvlQ9eJ3ZRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFI5LI++2IA3wfQAsAAbDezJ0g+CuCvAXQlX/qIme0pVaEi1xOvl16qufC5vKlmBMBDZnaAZCOA10nuTcYeN7N/zmvPIlJWuazPfh7A+eR2L8ljABaWujARKa4P9Tc7yRsBrAHwanLXAyQPkdxJcnZgm60k20i2jfb3FVSsiOQv57CTnA7gxwAeNLMeAE8CuAnAaoyf+b812XZmtt3MWs2staq+oQgli0g+cgo7yRqMB/0ZM3sOAMysw8xGzWwMwHcBrCtdmSJSqNSwkySAHQCOmdm2CfcvmPBlnwNwpPjliUix5PJq/HoA9wI4TPJgct8jADaTXI3xdtxpAF8uSYUikSlkeuy6F7uCY7m8Gv8KAE4ypJ66yDVE76ATiYTCLhIJhV0kEgq7SCQUdpFIKOwikYjmUtJp0wZX/Yvf2+RoMasRyZ/Xhz/VsS04pjO7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJmln5dkZ2ATgz4a65AC6WrYAPJ6u1ZbUuQLXlq5i1LTWz5skGyhr2D+ycbDOz1ooV4MhqbVmtC1Bt+SpXbXoaLxIJhV0kEpUO+/YK79+T1dqyWheg2vJVltoq+je7iJRPpc/sIlImCrtIJCoSdpIbSL5J8iTJhytRQwjJ0yQPkzxIsq3Ctewk2UnyyIT7mkjuJXki+TjpGnsVqu1Rku3JsTtI8q4K1baY5Msk3yB5lORXk/sreuycuspy3Mr+NzvJKgC/BvApAOcAvAZgs5m9UdZCAkieBtBqZhV/AwbJ2wFcAfB9M1uV3PdPALrN7LHkF+VsM/tGRmp7FMCVSi/jnaxWtGDiMuMANgH4Iip47Jy67kYZjlslzuzrAJw0s1NmNgTghwA2VqCOzDOz/QC633f3RgC7ktu7MP7DUnaB2jLBzM6b2YHkdi+A95YZr+ixc+oqi0qEfSGAsxM+P4dsrfduAF4i+TrJrZUuZhItZnY+uX0BQEsli5lE6jLe5fS+ZcYzc+zyWf68UHqB7oNuM7O1AD4D4CvJ09VMsvG/wbLUO81pGe9ymWSZ8d+p5LHLd/nzQlUi7O0AFk/4fFFyXyaYWXvysRPA88jeUtQd762gm3zsrHA9v5OlZbwnW2YcGTh2lVz+vBJhfw3ACpLLSNYCuAfA7grU8QEkG5IXTkCyAcCnkb2lqHcD2JLc3gLghQrW8nuysox3aJlxVPjYVXz5czMr+z8Ad2H8Ffm3APxdJWoI1LUcwK+Sf0crXRuAZzH+tG4Y469t3AdgDoB9AE4A+BmApgzV9jSAwwAOYTxYCypU220Yf4p+CMDB5N9dlT52Tl1lOW56u6xIJPQCnUgkFHaRSCjsIpFQ2EUiobCLREJhF4mEwi4Sif8Hwv4m0oErNXcAAAAASUVORK5CYII=\n",
263 | "text/plain": [
264 | ""
265 | ]
266 | },
267 | "metadata": {
268 | "tags": [],
269 | "needs_background": "light"
270 | }
271 | }
272 | ]
273 | }
274 | ]
275 | }
276 |
--------------------------------------------------------------------------------
/modern_approach/forward_forward/forward_forward_pytorch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyOGnxSridVLcFl/Z7lVc4J1",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | },
17 | "accelerator": "GPU",
18 | "gpuClass": "standard"
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "view-in-github",
25 | "colab_type": "text"
26 | },
27 | "source": [
28 | " "
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "source": [
34 | "#Forward-Forward propagation"
35 | ],
36 | "metadata": {
37 | "id": "ZAovmTs66ICD"
38 | }
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 1,
43 | "metadata": {
44 | "id": "gYFTLW4K5qay"
45 | },
46 | "outputs": [],
47 | "source": [
48 | "import torch\n",
49 | "import torch.nn as nn\n",
50 | "from tqdm import tqdm\n",
51 | "from torch.optim import Adam\n",
52 | "from torchvision.datasets import MNIST\n",
53 | "from torchvision.transforms import Compose, ToTensor, Normalize, Lambda\n",
54 | "from torch.utils.data import DataLoader"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "source": [
60 | "def MNIST_loaders(train_batch_size=50000, test_batch_size=10000):\n",
61 | " transform = Compose([\n",
62 | " ToTensor(),\n",
63 | " Normalize((0.1307,), (0.3081,)),\n",
64 | " Lambda(lambda x: torch.flatten(x))\n",
65 | " ])\n",
66 | " train_loader = DataLoader(\n",
67 | " MNIST('./data/', train=True, download=True, transform=transform),\n",
68 | " batch_size=train_batch_size, shuffle=True\n",
69 | " )\n",
70 | " test_loader = DataLoader(\n",
71 | " MNIST('./data/', train=False, download=True, transform=transform),\n",
72 | " batch_size=test_batch_size, shuffle=False\n",
73 | " )\n",
74 | " return train_loader, test_loader\n",
75 | "\n",
76 | "def overlay_y_on_x(x, y):\n",
77 | " x_ = x.clone()\n",
78 | " x_[:, :10] *= 0.0\n",
79 | " x_[range(x.shape[0]), y] = x.max()\n",
80 | " return x_"
81 | ],
82 | "metadata": {
83 | "id": "BeQ6BjzS9JLZ"
84 | },
85 | "execution_count": 3,
86 | "outputs": []
87 | },
88 | {
89 | "cell_type": "code",
90 | "source": [
91 | "class Net(torch.nn.Module):\n",
92 | " def __init__(self, dims):\n",
93 | " super().__init__()\n",
94 | " self.layers = []\n",
95 | " for d in range(len(dims) - 1):\n",
96 | " self.layers += [Layer(dims[d], dims[d + 1]).cuda()]\n",
97 | "\n",
98 | " def predict(self, x):\n",
99 | " goodness_per_label = []\n",
100 | " for label in range(10):\n",
101 | " h = overlay_y_on_x(x, label)\n",
102 | " goodness = []\n",
103 | " for layer in self.layers:\n",
104 | " h = layer(h)\n",
105 | " goodness += [h.pow(2).mean(1)]\n",
106 | " goodness_per_label += [sum(goodness).unsqueeze(1)]\n",
107 | " goodness_per_label = torch.cat(goodness_per_label, 1)\n",
108 | " return goodness_per_label.argmax(1)\n",
109 | "\n",
110 | " def train(self, x_pos, x_neg):\n",
111 | " h_pos, h_neg = x_pos, x_neg\n",
112 | " for i, layer in enumerate(self.layers):\n",
113 | " print('Training layer', i, '...')\n",
114 | " h_pos, h_neg = layer.train(h_pos, h_neg)"
115 | ],
116 | "metadata": {
117 | "id": "qRvPhl89--ap"
118 | },
119 | "execution_count": 4,
120 | "outputs": []
121 | },
122 | {
123 | "cell_type": "code",
124 | "source": [
125 | "class Layer(nn.Linear):\n",
126 | " def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):\n",
127 | " super().__init__(in_features, out_features, bias, device, dtype)\n",
128 | " self.relu = torch.nn.ReLU()\n",
129 | " self.opt = Adam(self.parameters(), lr=0.03)\n",
130 | " self.threshold = 2.0\n",
131 | " self.epochs = 1000\n",
132 | "\n",
133 | " def forward(self, x):\n",
134 | " x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)\n",
135 | " return self.relu(\n",
136 | " torch.mm(x_direction, self.weight.T) +\n",
137 | " self.bias.unsqueeze(0)\n",
138 | " )\n",
139 | "\n",
140 | " def train(self, x_pos, x_neg):\n",
141 | " for i in tqdm(range(self.epochs)):\n",
142 | " g_pos = self.forward(x_pos).pow(2).mean(1)\n",
143 | " g_neg = self.forward(x_neg).pow(2).mean(1)\n",
144 | " loss = torch.log(1 + torch.exp(torch.cat([\n",
145 | " -g_pos + self.threshold,\n",
146 | " g_neg - self.threshold\n",
147 | " ]))).mean()\n",
148 | " self.opt.zero_grad()\n",
149 | " loss.backward()\n",
150 | " self.opt.step()\n",
151 | " return self.forward(x_pos).detach(), self.forward(x_neg).detach()"
152 | ],
153 | "metadata": {
154 | "id": "MC3qw48jCR4P"
155 | },
156 | "execution_count": 5,
157 | "outputs": []
158 | },
159 | {
160 | "cell_type": "code",
161 | "source": [
162 | "torch.manual_seed(42)\n",
163 | "train_loader, test_loader = MNIST_loaders()\n",
164 | "\n",
165 | "net = Net([784, 500, 500])\n",
166 | "x, y = next(iter(train_loader))\n",
167 | "x, y = x.cuda(), y.cuda()\n",
168 | "x_pos = overlay_y_on_x(x, y)\n",
169 | "rnd = torch.randperm(x.size(0))\n",
170 | "x_neg = overlay_y_on_x(x, y[rnd])\n",
171 | "net.train(x_pos, x_neg)\n",
172 | "\n",
173 | "print('Train error:', 1.0 - net.predict(x).eq(y).float().mean().item())\n",
174 | "\n",
175 | "x_test, y_test = next(iter(test_loader))\n",
176 | "x_test, y_test = x_test.cuda(), y_test.cuda()\n",
177 | "\n",
178 | "print('Test error:', 1.0 - net.predict(x_test).eq(y_test).float().mean().item())"
179 | ],
180 | "metadata": {
181 | "colab": {
182 | "base_uri": "https://localhost:8080/"
183 | },
184 | "id": "YPfnhB-CE1fR",
185 | "outputId": "4a97be71-9d65-43ac-b568-760ee6abe41f"
186 | },
187 | "execution_count": 7,
188 | "outputs": [
189 | {
190 | "output_type": "stream",
191 | "name": "stdout",
192 | "text": [
193 | "Training layer 0 ...\n"
194 | ]
195 | },
196 | {
197 | "output_type": "stream",
198 | "name": "stderr",
199 | "text": [
200 | "100%|██████████| 1000/1000 [00:59<00:00, 16.67it/s]\n"
201 | ]
202 | },
203 | {
204 | "output_type": "stream",
205 | "name": "stdout",
206 | "text": [
207 | "Training layer 1 ...\n"
208 | ]
209 | },
210 | {
211 | "output_type": "stream",
212 | "name": "stderr",
213 | "text": [
214 | "100%|██████████| 1000/1000 [00:39<00:00, 25.17it/s]\n"
215 | ]
216 | },
217 | {
218 | "output_type": "stream",
219 | "name": "stdout",
220 | "text": [
221 | "Train error: 0.07084000110626221\n",
222 | "Test error: 0.06929999589920044\n"
223 | ]
224 | }
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "source": [
230 | "Credits: Mohammad Pezeshki (https://github.com/mohammadpz)"
231 | ],
232 | "metadata": {
233 | "id": "iR-dwsMj8zzC"
234 | }
235 | }
236 | ]
237 | }
--------------------------------------------------------------------------------
/modern_approach/transformer/self_attention.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "self_attention.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyPmS2YULIuLDzTcuwBBAsob",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "language_info": {
17 | "name": "python"
18 | }
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "view-in-github",
25 | "colab_type": "text"
26 | },
27 | "source": [
28 | " "
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {
34 | "id": "yOFn0WjjXr2B"
35 | },
36 | "source": [
37 | "#Self-attention"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "id": "tyk3ybvYXnHu"
44 | },
45 | "source": [
46 | "import numpy as n\n",
47 | "import tensorflow as tf"
48 | ],
49 | "execution_count": 18,
50 | "outputs": []
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {
55 | "id": "XtRRq7wmX7ys"
56 | },
57 | "source": [
58 | "##Krok 1. Przygotowanie wejść"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "metadata": {
64 | "id": "P0c54R30YEZk"
65 | },
66 | "source": [
67 | "x = [[1, 0, 1, 0], # wejście 1\n",
68 | " [0, 2, 0, 2], # wejście 2\n",
69 | " [1, 1, 1, 1]] # wejście 3\n",
70 | "\n",
71 | "x = np.array(x, dtype=np.float32)"
72 | ],
73 | "execution_count": 3,
74 | "outputs": []
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {
79 | "id": "cire_KEMYkIt"
80 | },
81 | "source": [
82 | "##Krok 2. Inicjalizacja wag"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "metadata": {
88 | "id": "MzQnEQyHYmpN"
89 | },
90 | "source": [
91 | "w_key = [[0, 0, 1],\n",
92 | " [1, 1, 0],\n",
93 | " [0, 1, 0],\n",
94 | " [1, 1, 0]]\n",
95 | "\n",
96 | "w_query = [[1, 0, 1],\n",
97 | " [1, 0, 0],\n",
98 | " [0, 0, 1],\n",
99 | " [0, 1, 1]]\n",
100 | "\n",
101 | "w_value = [[0, 2, 1],\n",
102 | " [0, 3, 0],\n",
103 | " [1, 0, 3],\n",
104 | " [1, 1, 0]]\n",
105 | "\n",
106 | "w_key = np.array(w_key, dtype=np.float32)\n",
107 | "w_query = np.array(w_query, dtype=np.float32)\n",
108 | "w_value = np.array(w_value, dtype=np.float32)"
109 | ],
110 | "execution_count": 4,
111 | "outputs": []
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {
116 | "id": "ecF6Vrq1b3kk"
117 | },
118 | "source": [
119 | "##Krok 3. Wyznaczenie key, query i value"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "metadata": {
125 | "colab": {
126 | "base_uri": "https://localhost:8080/"
127 | },
128 | "id": "fmspQnWfcFHH",
129 | "outputId": "4f9d0877-c4da-4eae-d373-fcaa129f72c7"
130 | },
131 | "source": [
132 | "keys = x @ w_key\n",
133 | "querys = x @ w_query\n",
134 | "values = x @ w_value\n",
135 | "\n",
136 | "print('Keys: \\n', keys)\n",
137 | "print('Querys: \\n', querys)\n",
138 | "print('Values: \\n', values)"
139 | ],
140 | "execution_count": 5,
141 | "outputs": [
142 | {
143 | "output_type": "stream",
144 | "text": [
145 | "Keys: \n",
146 | " [[0. 1. 1.]\n",
147 | " [4. 4. 0.]\n",
148 | " [2. 3. 1.]]\n",
149 | "Querys: \n",
150 | " [[1. 0. 2.]\n",
151 | " [2. 2. 2.]\n",
152 | " [2. 1. 3.]]\n",
153 | "Values: \n",
154 | " [[1. 2. 4.]\n",
155 | " [2. 8. 0.]\n",
156 | " [2. 6. 4.]]\n"
157 | ],
158 | "name": "stdout"
159 | }
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {
165 | "id": "hTu3f5PSfLQ-"
166 | },
167 | "source": [
168 | "##Krok 4. Obliczenie attention scores"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "metadata": {
174 | "colab": {
175 | "base_uri": "https://localhost:8080/"
176 | },
177 | "id": "wHW48xpYfcZF",
178 | "outputId": "0c24d183-43c5-46e3-a46c-45c3c4b7ef83"
179 | },
180 | "source": [
181 | "attn_scores = querys @ keys.T\n",
182 | "print(attn_scores)"
183 | ],
184 | "execution_count": 14,
185 | "outputs": [
186 | {
187 | "output_type": "stream",
188 | "text": [
189 | "[[ 2. 4. 4.]\n",
190 | " [ 4. 16. 12.]\n",
191 | " [ 4. 12. 10.]]\n"
192 | ],
193 | "name": "stdout"
194 | }
195 | ]
196 | },
197 | {
198 | "cell_type": "markdown",
199 | "metadata": {
200 | "id": "RPvnjAmmgHUO"
201 | },
202 | "source": [
203 | "##Krok 5. Obliczenie softmax"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "metadata": {
209 | "colab": {
210 | "base_uri": "https://localhost:8080/"
211 | },
212 | "id": "qAW5QWTIgT_T",
213 | "outputId": "4932d038-e787-4530-d5ec-75fe794493fc"
214 | },
215 | "source": [
216 | "attn_scores_softmax = np.round_(tf.nn.softmax(attn_scores, axis=-1), decimals=1)\n",
217 | "print(attn_scores_softmax)"
218 | ],
219 | "execution_count": 30,
220 | "outputs": [
221 | {
222 | "output_type": "stream",
223 | "text": [
224 | "[[0.1 0.5 0.5]\n",
225 | " [0. 1. 0. ]\n",
226 | " [0. 0.9 0.1]]\n"
227 | ],
228 | "name": "stdout"
229 | }
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "metadata": {
235 | "id": "Jd2bE_wzkz-6"
236 | },
237 | "source": [
238 | "##Krok 6. Mnożenie scores i values"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "metadata": {
244 | "colab": {
245 | "base_uri": "https://localhost:8080/"
246 | },
247 | "id": "OsRCJnFFlKjb",
248 | "outputId": "66df39f5-7183-48ca-8d75-ab03b3b5159a"
249 | },
250 | "source": [
251 | "weighted_values = values[:, None] * attn_scores_softmax.T[:, :, None]\n",
252 | "print(weighted_values)"
253 | ],
254 | "execution_count": 31,
255 | "outputs": [
256 | {
257 | "output_type": "stream",
258 | "text": [
259 | "[[[0.1 0.2 0.4]\n",
260 | " [0. 0. 0. ]\n",
261 | " [0. 0. 0. ]]\n",
262 | "\n",
263 | " [[1. 4. 0. ]\n",
264 | " [2. 8. 0. ]\n",
265 | " [1.8 7.2 0. ]]\n",
266 | "\n",
267 | " [[1. 3. 2. ]\n",
268 | " [0. 0. 0. ]\n",
269 | " [0.2 0.6 0.4]]]\n"
270 | ],
271 | "name": "stdout"
272 | }
273 | ]
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "metadata": {
278 | "id": "toP8zcCLl7Yo"
279 | },
280 | "source": [
281 | "##Krok 7. Suma zważonych wartości"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "metadata": {
287 | "colab": {
288 | "base_uri": "https://localhost:8080/"
289 | },
290 | "id": "1G0hGuCMmDTs",
291 | "outputId": "16140ec9-885f-4f4d-8247-c056b6d0a503"
292 | },
293 | "source": [
294 | "outputs = np.sum(weighted_values, axis=0)\n",
295 | "print(outputs)"
296 | ],
297 | "execution_count": 32,
298 | "outputs": [
299 | {
300 | "output_type": "stream",
301 | "text": [
302 | "[[2.1 7.2 2.4 ]\n",
303 | " [2. 8. 0. ]\n",
304 | " [2. 7.7999997 0.4 ]]\n"
305 | ],
306 | "name": "stdout"
307 | }
308 | ]
309 | }
310 | ]
311 | }
--------------------------------------------------------------------------------
/modern_approach/zero_shot_learning/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PsorTheDoctor/artificial-intelligence/82516ac57eb13f14e8214633a0960bea0cd9e0fb/modern_approach/zero_shot_learning/example.png
--------------------------------------------------------------------------------
/neural_networks/CNN/pixelcnn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "pixelcnn.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyPfd36wg6CHF4lkcJIwlq3z",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "language_info": {
17 | "name": "python"
18 | },
19 | "accelerator": "GPU"
20 | },
21 | "cells": [
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {
25 | "id": "view-in-github",
26 | "colab_type": "text"
27 | },
28 | "source": [
29 | " "
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {
35 | "id": "WWCYgQAABKid"
36 | },
37 | "source": [
38 | "#PixelCNN\n",
39 | "##Import bibliotek"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "metadata": {
45 | "id": "DTLjE01TAf4J"
46 | },
47 | "source": [
48 | "import numpy as np\n",
49 | "import tensorflow as tf\n",
50 | "from tensorflow import keras\n",
51 | "from tensorflow.keras import layers\n",
52 | "from tqdm import tqdm"
53 | ],
54 | "execution_count": 1,
55 | "outputs": []
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {
60 | "id": "Lug7PJ_xBQI9"
61 | },
62 | "source": [
63 | "##Załadowanie danych"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "metadata": {
69 | "colab": {
70 | "base_uri": "https://localhost:8080/"
71 | },
72 | "id": "YI2AIzVEBIrW",
73 | "outputId": "4881db92-4ec5-4374-f97c-7ac8e33f7a0d"
74 | },
75 | "source": [
76 | "num_classes = 10\n",
77 | "input_shape = (28, 28, 1)\n",
78 | "n_residual_blocks = 5\n",
79 | "# Podział danych na traningowe i testowe\n",
80 | "(x, _), (y, _) = keras.datasets.mnist.load_data()\n",
81 | "# Zaokrąglenie wszystkich pikseli mniejszych od 33% z 256 do 0\n",
82 | "# Wszystko powyżej tej wartości zostanie zaokrąglone do 1, więc wartości\n",
83 | "# będą równe 0 lub 1\n",
84 | "data = np.concatenate((x, y), axis=0)\n",
85 | "data = np.where(data < (0.33 * 256), 0, 1)\n",
86 | "data = data.astype(np.float32)"
87 | ],
88 | "execution_count": 2,
89 | "outputs": [
90 | {
91 | "output_type": "stream",
92 | "text": [
93 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
94 | "11493376/11490434 [==============================] - 0s 0us/step\n"
95 | ],
96 | "name": "stdout"
97 | }
98 | ]
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "metadata": {
103 | "id": "NdhDl5CNKdNm"
104 | },
105 | "source": [
106 | "##Stworzenie warstw modelu"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "metadata": {
112 | "id": "R8F0wRltBLGc"
113 | },
114 | "source": [
115 | "class PixelConvLayer(layers.Layer):\n",
116 | " def __init__(self, mask_type, **kwargs):\n",
117 | " super(PixelConvLayer, self).__init__()\n",
118 | " self.mask_type = mask_type\n",
119 | " self.conv = layers.Conv2D(**kwargs)\n",
120 | "\n",
121 | " def build(self, input_shape):\n",
122 | " self.conv.build(input_shape)\n",
123 | " kernel_shape = self.conv.kernel.get_shape()\n",
124 | " self.mask = np.zeros(shape=kernel_shape)\n",
125 | " self.mask[: kernel_shape[0] // 2, ...] = 1.0\n",
126 | " self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0\n",
127 | " if self.mask_type == 'B':\n",
128 | " self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0\n",
129 | "\n",
130 | " def call(self, inputs):\n",
131 | " self.conv.kernel.assign(self.conv.kernel * self.mask)\n",
132 | " return self.conv(inputs)\n",
133 | "\n",
134 | "\n",
135 | "class ResidualBlock(layers.Layer):\n",
136 | " def __init__(self, filters, **kwargs):\n",
137 | " super(ResidualBlock, self).__init__(kwargs);\n",
138 | " self.conv1 = layers.Conv2D(\n",
139 | " filters=filters, kernel_size=1, activation='relu'\n",
140 | " )\n",
141 | " self.pixel_conv = PixelConvLayer(\n",
142 | " mask_type='B',\n",
143 | " filters = filters // 2,\n",
144 | " kernel_size=3,\n",
145 | " activation='relu',\n",
146 | " padding='same'\n",
147 | " )\n",
148 | " self.conv2 = layers.Conv2D(\n",
149 | " filters=filters, kernel_size=1, activation='relu'\n",
150 | " )\n",
151 | "\n",
152 | " def call(self, inputs):\n",
153 | " x = self.conv1(inputs)\n",
154 | " x = self.pixel_conv(x)\n",
155 | " x = self.conv2(x)\n",
156 | " return layers.add([inputs, x])"
157 | ],
158 | "execution_count": 13,
159 | "outputs": []
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "metadata": {
164 | "id": "-6jAS8QuMhVt"
165 | },
166 | "source": [
167 | "##Budowa modelu"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "metadata": {
173 | "colab": {
174 | "base_uri": "https://localhost:8080/"
175 | },
176 | "id": "98GvFMF4MkTz",
177 | "outputId": "fc7a0fd7-99c7-40e3-c416-ed6dd7a9eadf"
178 | },
179 | "source": [
180 | "inputs = keras.Input(shape=input_shape)\n",
181 | "x = PixelConvLayer(\n",
182 | " mask_type='A', filters=128, kernel_size=7, activation='relu', padding='same'\n",
183 | ")(inputs)\n",
184 | "\n",
185 | "for _ in range(n_residual_blocks):\n",
186 | " x = ResidualBlock(filters=128)(x)\n",
187 | "\n",
188 | "for _ in range(2):\n",
189 | " x = PixelConvLayer(\n",
190 | " mask_type='B',\n",
191 | " filters=128,\n",
192 | " kernel_size=1,\n",
193 | " strides=1,\n",
194 | " activation='relu',\n",
195 | " padding='valid'\n",
196 | " )(x)\n",
197 | "\n",
198 | "out = layers.Conv2D(\n",
199 | " filters=1, kernel_size=1, strides=1, activation='sigmoid', padding='valid'\n",
200 | ")(x)\n",
201 | "\n",
202 | "pixel_cnn = keras.Model(inputs, out)\n",
203 | "adam = keras.optimizers.Adam(learning_rate=0.0005)\n",
204 | "pixel_cnn.compile(optimizer=adam, loss='binary_crossentropy')\n",
205 | "\n",
206 | "pixel_cnn.summary()\n",
207 | "pixel_cnn.fit(\n",
208 | " x=data, y=data, batch_size=128, epochs=50, validation_split=0.1, verbose=2\n",
209 | ")"
210 | ],
211 | "execution_count": 14,
212 | "outputs": [
213 | {
214 | "output_type": "stream",
215 | "text": [
216 | "Model: \"model\"\n",
217 | "_________________________________________________________________\n",
218 | "Layer (type) Output Shape Param # \n",
219 | "=================================================================\n",
220 | "input_6 (InputLayer) [(None, 28, 28, 1)] 0 \n",
221 | "_________________________________________________________________\n",
222 | "pixel_conv_layer_5 (PixelCon (None, 28, 28, 128) 6400 \n",
223 | "_________________________________________________________________\n",
224 | "residual_block_1 (ResidualBl (None, 28, 28, 128) 98624 \n",
225 | "_________________________________________________________________\n",
226 | "residual_block_2 (ResidualBl (None, 28, 28, 128) 98624 \n",
227 | "_________________________________________________________________\n",
228 | "residual_block_3 (ResidualBl (None, 28, 28, 128) 98624 \n",
229 | "_________________________________________________________________\n",
230 | "residual_block_4 (ResidualBl (None, 28, 28, 128) 98624 \n",
231 | "_________________________________________________________________\n",
232 | "residual_block_5 (ResidualBl (None, 28, 28, 128) 98624 \n",
233 | "_________________________________________________________________\n",
234 | "pixel_conv_layer_11 (PixelCo (None, 28, 28, 128) 16512 \n",
235 | "_________________________________________________________________\n",
236 | "pixel_conv_layer_12 (PixelCo (None, 28, 28, 128) 16512 \n",
237 | "_________________________________________________________________\n",
238 | "conv2d_24 (Conv2D) (None, 28, 28, 1) 129 \n",
239 | "=================================================================\n",
240 | "Total params: 532,673\n",
241 | "Trainable params: 39,553\n",
242 | "Non-trainable params: 493,120\n",
243 | "_________________________________________________________________\n",
244 | "Epoch 1/50\n",
245 | "493/493 - 77s - loss: 0.1387 - val_loss: 0.0975\n",
246 | "Epoch 2/50\n",
247 | "493/493 - 46s - loss: 0.0965 - val_loss: 0.0954\n",
248 | "Epoch 3/50\n",
249 | "493/493 - 46s - loss: 0.0949 - val_loss: 0.0946\n",
250 | "Epoch 4/50\n",
251 | "493/493 - 47s - loss: 0.0942 - val_loss: 0.0939\n",
252 | "Epoch 5/50\n",
253 | "493/493 - 47s - loss: 0.0938 - val_loss: 0.0941\n",
254 | "Epoch 6/50\n",
255 | "493/493 - 47s - loss: 0.0935 - val_loss: 0.0936\n",
256 | "Epoch 7/50\n",
257 | "493/493 - 48s - loss: 0.0932 - val_loss: 0.0932\n",
258 | "Epoch 8/50\n",
259 | "493/493 - 48s - loss: 0.0930 - val_loss: 0.0931\n",
260 | "Epoch 9/50\n",
261 | "493/493 - 48s - loss: 0.0929 - val_loss: 0.0934\n",
262 | "Epoch 10/50\n",
263 | "493/493 - 48s - loss: 0.0928 - val_loss: 0.0930\n",
264 | "Epoch 11/50\n",
265 | "493/493 - 48s - loss: 0.0927 - val_loss: 0.0928\n",
266 | "Epoch 12/50\n",
267 | "493/493 - 48s - loss: 0.0926 - val_loss: 0.0927\n",
268 | "Epoch 13/50\n",
269 | "493/493 - 48s - loss: 0.0925 - val_loss: 0.0928\n",
270 | "Epoch 14/50\n",
271 | "493/493 - 48s - loss: 0.0925 - val_loss: 0.0926\n",
272 | "Epoch 15/50\n",
273 | "493/493 - 48s - loss: 0.0924 - val_loss: 0.0925\n",
274 | "Epoch 16/50\n",
275 | "493/493 - 48s - loss: 0.0924 - val_loss: 0.0927\n",
276 | "Epoch 17/50\n",
277 | "493/493 - 48s - loss: 0.0922 - val_loss: 0.0927\n",
278 | "Epoch 18/50\n",
279 | "493/493 - 48s - loss: 0.0922 - val_loss: 0.0932\n",
280 | "Epoch 19/50\n",
281 | "493/493 - 48s - loss: 0.0922 - val_loss: 0.0923\n",
282 | "Epoch 20/50\n",
283 | "493/493 - 48s - loss: 0.0921 - val_loss: 0.0923\n",
284 | "Epoch 21/50\n",
285 | "493/493 - 48s - loss: 0.0921 - val_loss: 0.0922\n",
286 | "Epoch 22/50\n",
287 | "493/493 - 48s - loss: 0.0921 - val_loss: 0.0924\n",
288 | "Epoch 23/50\n",
289 | "493/493 - 48s - loss: 0.0920 - val_loss: 0.0922\n",
290 | "Epoch 24/50\n",
291 | "493/493 - 48s - loss: 0.0920 - val_loss: 0.0925\n",
292 | "Epoch 25/50\n",
293 | "493/493 - 48s - loss: 0.0920 - val_loss: 0.0921\n",
294 | "Epoch 26/50\n",
295 | "493/493 - 48s - loss: 0.0919 - val_loss: 0.0933\n",
296 | "Epoch 27/50\n",
297 | "493/493 - 48s - loss: 0.0919 - val_loss: 0.0920\n",
298 | "Epoch 28/50\n",
299 | "493/493 - 48s - loss: 0.0918 - val_loss: 0.0919\n",
300 | "Epoch 29/50\n",
301 | "493/493 - 48s - loss: 0.0918 - val_loss: 0.0920\n",
302 | "Epoch 30/50\n",
303 | "493/493 - 48s - loss: 0.0918 - val_loss: 0.0920\n",
304 | "Epoch 31/50\n",
305 | "493/493 - 48s - loss: 0.0918 - val_loss: 0.0923\n",
306 | "Epoch 32/50\n",
307 | "493/493 - 48s - loss: 0.0917 - val_loss: 0.0919\n",
308 | "Epoch 33/50\n",
309 | "493/493 - 48s - loss: 0.0917 - val_loss: 0.0920\n",
310 | "Epoch 34/50\n",
311 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0919\n",
312 | "Epoch 35/50\n",
313 | "493/493 - 48s - loss: 0.0917 - val_loss: 0.0918\n",
314 | "Epoch 36/50\n",
315 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0917\n",
316 | "Epoch 37/50\n",
317 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0919\n",
318 | "Epoch 38/50\n",
319 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0917\n",
320 | "Epoch 39/50\n",
321 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0922\n",
322 | "Epoch 40/50\n",
323 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0919\n",
324 | "Epoch 41/50\n",
325 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0917\n",
326 | "Epoch 42/50\n",
327 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0916\n",
328 | "Epoch 43/50\n",
329 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0917\n",
330 | "Epoch 44/50\n",
331 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0916\n",
332 | "Epoch 45/50\n",
333 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0918\n",
334 | "Epoch 46/50\n",
335 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0916\n",
336 | "Epoch 47/50\n",
337 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0916\n",
338 | "Epoch 48/50\n",
339 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0917\n",
340 | "Epoch 49/50\n",
341 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0919\n",
342 | "Epoch 50/50\n",
343 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0922\n"
344 | ],
345 | "name": "stdout"
346 | },
347 | {
348 | "output_type": "execute_result",
349 | "data": {
350 | "text/plain": [
351 | ""
352 | ]
353 | },
354 | "metadata": {
355 | "tags": []
356 | },
357 | "execution_count": 14
358 | }
359 | ]
360 | },
361 | {
362 | "cell_type": "markdown",
363 | "metadata": {
364 | "id": "N24P8Z9-T7XX"
365 | },
366 | "source": [
367 | "##Demonstracja"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "metadata": {
373 | "colab": {
374 | "base_uri": "https://localhost:8080/",
375 | "height": 146
376 | },
377 | "id": "rw3u7OYUTxWp",
378 | "outputId": "d3fb76c3-c49f-44a3-e801-2a527a98343e"
379 | },
380 | "source": [
381 | "from IPython.display import Image, display\n",
382 | "\n",
383 | "batch = 4\n",
384 | "pixels = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])\n",
385 | "batch, rows, cols, channels = pixels.shape\n",
386 | "\n",
387 | "for row in tqdm(range(rows)):\n",
388 | " for col in range(cols):\n",
389 | " for channel in range(channels):\n",
390 | " \n",
391 | " probs = pixel_cnn.predict(pixels)[:, row, col, channel]\n",
392 | " \n",
393 | " pixels[:, row, col, channel] = tf.math.ceil(\n",
394 | " probs - tf.random.uniform(probs.shape)\n",
395 | " )\n",
396 | "\n",
397 | "def deprocess_image(x):\n",
398 | " x = np.stack((x, x, x), 2)\n",
399 | "\n",
400 | " x *= 255.0\n",
401 | "\n",
402 | " x = np.clip(x, 0, 255).astype('uint8')\n",
403 | " return x\n",
404 | "\n",
405 | "for i, pic in enumerate(pixels):\n",
406 | " keras.preprocessing.image.save_img(\n",
407 | " 'generated_image_{}.png'.format(i), deprocess_image(np.squeeze(pic, -1))\n",
408 | " )\n",
409 | "\n",
410 | "display(Image('generated_image_0.png'))\n",
411 | "display(Image('generated_image_1.png'))\n",
412 | "display(Image('generated_image_2.png'))\n",
413 | "display(Image('generated_image_3.png'))"
414 | ],
415 | "execution_count": 18,
416 | "outputs": [
417 | {
418 | "output_type": "stream",
419 | "text": [
420 | "100%|██████████| 28/28 [00:32<00:00, 1.18s/it]\n"
421 | ],
422 | "name": "stderr"
423 | },
424 | {
425 | "output_type": "display_data",
426 | "data": {
427 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAAh0lEQVR4nO2UwQ7AIAhDwez/f5kdXKYSdKHgTvZmnG+1gERHR6litRaRtsd61w3tccMXfnTBvKx1qTUzV8vw3al3WimzEFwy7MRrZZ9RfiNR2HQgkC3V/x0Kd8Lm5p9VPHNMI1PwMUVYh5aUt05DF0T4H0OmWa4bNDEH4z2N4B4Iduy9lmniBgARMCy+4/TbAAAAAElFTkSuQmCC\n",
428 | "text/plain": [
429 | ""
430 | ]
431 | },
432 | "metadata": {
433 | "tags": []
434 | }
435 | },
436 | {
437 | "output_type": "display_data",
438 | "data": {
439 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAAZklEQVR4nO2USxYAIARF1f73rFlO6UM9M3fY5x5CREmS4Ch6iZlluywOXKkH4zNDIN34FmBHIkUZSaf/bxQp5CkHKTBxkaJcSSC+Ek2dt6vw3PwQHFL7gCykxsuH7g75+qzpu0a5AQqvIRRWxCD2AAAAAElFTkSuQmCC\n",
440 | "text/plain": [
441 | ""
442 | ]
443 | },
444 | "metadata": {
445 | "tags": []
446 | }
447 | },
448 | {
449 | "output_type": "display_data",
450 | "data": {
451 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAAbUlEQVR4nO2QMQ7AMAgD7Sj//zIdWlVtAqQKDB24ESUn20BRFEUSTLGIyKUjAfQs3ZOWaDxj4nt9NZFqVKTDOr53fvOSqkGsP0uaZXTua6nzM5Q0nTbEIXlftut3BGpaePV/tul82p7Sk8YnPgA8jCQivImO7AAAAABJRU5ErkJggg==\n",
452 | "text/plain": [
453 | ""
454 | ]
455 | },
456 | "metadata": {
457 | "tags": []
458 | }
459 | },
460 | {
461 | "output_type": "display_data",
462 | "data": {
463 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAAYUlEQVR4nO2SMQ7AMAgDTdX/f9kZKnUCh9BE6sCtiAMLgKZpmhVIknRLV9koqhXpazSzPdKpsSLN8CepyA7g1s36yuHIgkuv6Ug/6h7C+Ml+l/PXzzz2mnSXEWfj1/4xYgBOgCEY9BznlAAAAABJRU5ErkJggg==\n",
464 | "text/plain": [
465 | ""
466 | ]
467 | },
468 | "metadata": {
469 | "tags": []
470 | }
471 | }
472 | ]
473 | }
474 | ]
475 | }
--------------------------------------------------------------------------------
/neural_networks/GAN/draggan.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4",
8 | "authorship_tag": "ABX9TyNxNhZKkvQ0+dfoe14tVQd8",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | },
18 | "accelerator": "GPU"
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "view-in-github",
25 | "colab_type": "text"
26 | },
27 | "source": [
28 | " "
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "source": [
34 | "#DragGAN"
35 | ],
36 | "metadata": {
37 | "id": "FiZr6NwNrF2H"
38 | }
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {
44 | "id": "9I5eZYJKq0no"
45 | },
46 | "outputs": [],
47 | "source": [
48 | "!git clone https://github.com/Zeqiang-Lai/DragGAN.git\n",
49 | "\n",
50 | "import sys\n",
51 | "sys.path.append(\".\")\n",
52 | "sys.path.append('./DragGAN')\n",
53 | "\n",
54 | "!pip install -r DragGAN/requirements.txt\n",
55 | "\n",
56 | "from gradio_app import main"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "source": [
62 | "demo = main()\n",
63 | "demo.queue(concurrency_count=1, max_size=20).launch()"
64 | ],
65 | "metadata": {
66 | "id": "VPgSsKPsrKKl"
67 | },
68 | "execution_count": null,
69 | "outputs": []
70 | }
71 | ]
72 | }
--------------------------------------------------------------------------------
/neural_networks/GAN/ersgan_for_vdm.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyOlGtAwGTMODsdOeXhygfwK",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | " "
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "source": [
32 | "#ESRGAN for video diffusion"
33 | ],
34 | "metadata": {
35 | "id": "754mykfUJ1U7"
36 | }
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 14,
41 | "metadata": {
42 | "id": "4B0zUQQCJza4"
43 | },
44 | "outputs": [],
45 | "source": [
46 | "import numpy as np\n",
47 | "import tensorflow_hub as hub\n",
48 | "import tensorflow as tf\n",
49 | "import cv2\n",
50 | "from google.colab.patches import cv2_imshow"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "source": [
56 | "model = hub.load('https://tfhub.dev/captain-pool/esrgan-tf2/1')"
57 | ],
58 | "metadata": {
59 | "id": "yMVVA5dcLPVt"
60 | },
61 | "execution_count": 2,
62 | "outputs": []
63 | },
64 | {
65 | "cell_type": "code",
66 | "source": [
67 | "video_file = 'example.mp4'\n",
68 | "cap = cv2.VideoCapture(video_file)\n",
69 | "n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
70 | "frames = []\n",
71 | "\n",
72 | "for i in range(n_frames):\n",
73 | " _, frame = cap.read()\n",
74 | " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
75 | "\n",
76 | " frame = tf.expand_dims(frame, 0)\n",
77 | " low_resolution = tf.cast(frame, tf.float32)\n",
78 | " super_resolution = model(low_resolution)\n",
79 | "\n",
80 | " frames.append(super_resolution)\n",
81 | "\n",
82 | " print(f'{i + 1}/{n_frames} frames processed.')\n",
83 | "\n",
84 | "cap.release()\n",
85 | "frames = np.array(frames)\n",
86 | "frames.shape"
87 | ],
88 | "metadata": {
89 | "colab": {
90 | "base_uri": "https://localhost:8080/"
91 | },
92 | "id": "5WRnIyHXKOAv",
93 | "outputId": "ff12e830-1c1f-4158-e065-1f1f8a9eabbd"
94 | },
95 | "execution_count": 6,
96 | "outputs": [
97 | {
98 | "output_type": "execute_result",
99 | "data": {
100 | "text/plain": [
101 | "(16, 1, 1024, 1024, 3)"
102 | ]
103 | },
104 | "metadata": {},
105 | "execution_count": 6
106 | }
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "source": [
112 | "output_frames = frames.reshape(16, 1024, 1024, 3)\n",
113 | "cv2_imshow(output_frames[0])"
114 | ],
115 | "metadata": {
116 | "id": "nR8Z7xMMPTMC"
117 | },
118 | "execution_count": null,
119 | "outputs": []
120 | },
121 | {
122 | "cell_type": "code",
123 | "source": [
124 | "output_filename = 'result.mp4'\n",
125 | "fps = 16\n",
126 | "resolution = (1024, 1024)\n",
127 | "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
128 | "writer = cv2.VideoWriter(output_filename, fourcc, fps, resolution)\n",
129 | "\n",
130 | "for frame in output_frames:\n",
131 | " frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n",
132 | " frame = (frame * 255).astype(np.uint8)\n",
133 | " writer.write(frame)\n",
134 | "\n",
135 | "writer.release()"
136 | ],
137 | "metadata": {
138 | "id": "yY5aHcMyNa4T"
139 | },
140 | "execution_count": 16,
141 | "outputs": []
142 | }
143 | ]
144 | }
--------------------------------------------------------------------------------
/neural_networks/MLP/experimental/zip_learning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "zip_learning.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyMopxhbnZCQHibT6Wyxr/mi",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | }
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | " "
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "c1ACAYy4Iuev",
32 | "colab_type": "text"
33 | },
34 | "source": [
35 | "# Zip Learning - Uczenie skompresowane\n",
36 | "Eksperymentalna metoda uczenia na skompresowanym zbiorze Mnist."
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {
42 | "id": "6ljDzo5Lupqs",
43 | "colab_type": "text"
44 | },
45 | "source": [
46 | "## Import bibliotek"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "metadata": {
52 | "id": "P5lGUmImIt48",
53 | "colab_type": "code",
54 | "colab": {}
55 | },
56 | "source": [
57 | "%tensorflow_version 2.x\n",
58 | "from datetime import datetime\n",
59 | "import numpy as np\n",
60 | "\n",
61 | "import tensorflow as tf\n",
62 | "from tensorflow.keras.datasets.mnist import load_data\n",
63 | "from tensorflow.keras.models import Sequential\n",
64 | "from tensorflow.keras.layers import InputLayer, Dense, Dropout"
65 | ],
66 | "execution_count": 0,
67 | "outputs": []
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {
72 | "id": "0RfpmBAhvBMm",
73 | "colab_type": "text"
74 | },
75 | "source": [
76 | "## Funkcje kompresujące\n",
77 | "Zastosowany algorytm kompresji jest podobny algorytmu kompresji bezstratnej RLE. Funkcje kompresji dla tablic jednowymiarowych (wektorów) i dwuwymiarowych (macierzy) to odpowiednio `zip1D` i `zip2D`. Obie operują na obiektach `numpy.ndarray`. Funkcje zwracają po 2 tablice, z których pierwsza `unique_vals` zawiera serię wartości wystepujących w argumencie (jeśli wiele takich samych wartości stoi po sobie to są zapisywanie jako jedna), a druga `vals_ctr` - ilość wystąpień każdej z nich. \n",
78 | "\n",
79 | "\n"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "metadata": {
85 | "id": "VkIMCZ07K935",
86 | "colab_type": "code",
87 | "colab": {}
88 | },
89 | "source": [
90 | "# Kompresja 1-wymiarowa\n",
91 | "def zip1D(array):\n",
92 | " unique_vals = []\n",
93 | " vals_ctr = []\n",
94 | " current_val = None\n",
95 | " idx = -1\n",
96 | "\n",
97 | " for i in range(len(array)):\n",
98 | " if array[i] != current_val: # \"is not\" doesn't work with numpy arrays!\n",
99 | " current_val = array[i]\n",
100 | " unique_vals.append(current_val)\n",
101 | " vals_ctr.append(0)\n",
102 | " idx += 1\n",
103 | " vals_ctr[idx] += 1\n",
104 | "\n",
105 | " return unique_vals, vals_ctr\n",
106 | "\n",
107 | "# Kompresja 2-wymiarowa\n",
108 | "def zip2D(mat):\n",
109 | " if type(mat) is np.ndarray:\n",
110 | " array = mat.flatten(order='C')\n",
111 | " return zip1D(array)"
112 | ],
113 | "execution_count": 0,
114 | "outputs": []
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {
119 | "id": "ea_iVUta02EB",
120 | "colab_type": "text"
121 | },
122 | "source": [
123 | "Przykład działania funkcji `zip1D`:"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "metadata": {
129 | "id": "tuW_cOd80TFD",
130 | "colab_type": "code",
131 | "colab": {
132 | "base_uri": "https://localhost:8080/",
133 | "height": 35
134 | },
135 | "outputId": "c580479b-3c55-4ecd-855d-286dbc2367bb"
136 | },
137 | "source": [
138 | "array = np.array(['A','A','B','A','A','A','A'])\n",
139 | "\n",
140 | "unique_vals, vals_ctr = zip1D(array)\n",
141 | "unique_vals, vals_ctr"
142 | ],
143 | "execution_count": 114,
144 | "outputs": [
145 | {
146 | "output_type": "execute_result",
147 | "data": {
148 | "text/plain": [
149 | "(['A', 'B', 'A'], [2, 1, 4])"
150 | ]
151 | },
152 | "metadata": {
153 | "tags": []
154 | },
155 | "execution_count": 114
156 | }
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {
162 | "id": "C37RSqzi1COH",
163 | "colab_type": "text"
164 | },
165 | "source": [
166 | "Przykład działania funkcji `zip2D`:"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "metadata": {
172 | "id": "FgU4HUN60gYv",
173 | "colab_type": "code",
174 | "colab": {
175 | "base_uri": "https://localhost:8080/",
176 | "height": 35
177 | },
178 | "outputId": "3f58125f-2a5b-42f8-d560-606c4e8e2005"
179 | },
180 | "source": [
181 | "mat = np.array([['A','B','A'], \n",
182 | " ['B','B','B'], \n",
183 | " ['A','B','A']])\n",
184 | "\n",
185 | "unique_vals, vals_ctr = zip2D(mat)\n",
186 | "unique_vals, vals_ctr"
187 | ],
188 | "execution_count": 115,
189 | "outputs": [
190 | {
191 | "output_type": "execute_result",
192 | "data": {
193 | "text/plain": [
194 | "(['A', 'B', 'A', 'B', 'A', 'B', 'A'], [1, 1, 1, 3, 1, 1, 1])"
195 | ]
196 | },
197 | "metadata": {
198 | "tags": []
199 | },
200 | "execution_count": 115
201 | }
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "metadata": {
207 | "id": "N-GGUJ2WvYB6",
208 | "colab_type": "text"
209 | },
210 | "source": [
211 | "## Załadowanie danych"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "metadata": {
217 | "id": "bGk5m10QJ8CT",
218 | "colab_type": "code",
219 | "colab": {}
220 | },
221 | "source": [
222 | "(X_train, y_train), (X_test, y_test) = load_data()"
223 | ],
224 | "execution_count": 0,
225 | "outputs": []
226 | },
227 | {
228 | "cell_type": "code",
229 | "metadata": {
230 | "id": "859C0hRsqXE4",
231 | "colab_type": "code",
232 | "colab": {
233 | "base_uri": "https://localhost:8080/",
234 | "height": 52
235 | },
236 | "outputId": "79cd53cc-6de0-47bf-ada5-c9f52880e8cb"
237 | },
238 | "source": [
239 | "print(X_train.shape)\n",
240 | "print(y_train.shape)"
241 | ],
242 | "execution_count": 93,
243 | "outputs": [
244 | {
245 | "output_type": "stream",
246 | "text": [
247 | "(60000, 28, 28)\n",
248 | "(60000,)\n"
249 | ],
250 | "name": "stdout"
251 | }
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {
257 | "id": "PNKvvLvAvczo",
258 | "colab_type": "text"
259 | },
260 | "source": [
261 | "## Normalizacja"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "metadata": {
267 | "id": "KsN4OFKKRPqU",
268 | "colab_type": "code",
269 | "colab": {}
270 | },
271 | "source": [
272 | "# X_train = X_train / 255.0\n",
273 | "n_samples = len(X_train)\n",
274 | "width = len(X_train[0])\n",
275 | "height = len(X_train[0][0])\n",
276 | "threshold = 128\n",
277 | "\n",
278 | "for n in range(n_samples):\n",
279 | " for x in range(width):\n",
280 | " for y in range(height):\n",
281 | " if X_train[n][x][y] < threshold:\n",
282 | " X_train[n][x][y] = 0\n",
283 | " else:\n",
284 | " X_train[n][x][y] = 1"
285 | ],
286 | "execution_count": 0,
287 | "outputs": []
288 | },
289 | {
290 | "cell_type": "markdown",
291 | "metadata": {
292 | "id": "zrIYdbeUvqgg",
293 | "colab_type": "text"
294 | },
295 | "source": [
296 | "## Kompresja"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "metadata": {
302 | "id": "eOX2awg3Krpe",
303 | "colab_type": "code",
304 | "colab": {
305 | "base_uri": "https://localhost:8080/",
306 | "height": 35
307 | },
308 | "outputId": "c3e6571c-fbdb-402d-94aa-786be5427796"
309 | },
310 | "source": [
311 | "X_train_unique_vals = []\n",
312 | "X_train_vals_ctr = []\n",
313 | "\n",
314 | "start = datetime.now()\n",
315 | "for i in range(len(X_train)):\n",
316 | " current_val, val_ctr = zip2D(X_train[i])\n",
317 | " X_train_unique_vals.append(current_val)\n",
318 | " X_train_vals_ctr.append(val_ctr)\n",
319 | "\n",
320 | "zipping_time = datetime.now() - start\n",
321 | "print(zipping_time)"
322 | ],
323 | "execution_count": 97,
324 | "outputs": [
325 | {
326 | "output_type": "stream",
327 | "text": [
328 | "0:00:13.333470\n"
329 | ],
330 | "name": "stdout"
331 | }
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "metadata": {
337 | "id": "lZ3eWuCZMdOg",
338 | "colab_type": "code",
339 | "colab": {
340 | "base_uri": "https://localhost:8080/",
341 | "height": 87
342 | },
343 | "outputId": "d85136e6-11e1-4582-c1c8-9e15d8252ec0"
344 | },
345 | "source": [
346 | "print(len(X_train_vals_ctr))\n",
347 | "print(len(X_train_unique_vals))\n",
348 | "\n",
349 | "print(len(X_train_unique_vals[0]))\n",
350 | "print(len(X_train_vals_ctr[0]))"
351 | ],
352 | "execution_count": 98,
353 | "outputs": [
354 | {
355 | "output_type": "stream",
356 | "text": [
357 | "60000\n",
358 | "60000\n",
359 | "47\n",
360 | "47\n"
361 | ],
362 | "name": "stdout"
363 | }
364 | ]
365 | },
366 | {
367 | "cell_type": "markdown",
368 | "metadata": {
369 | "id": "71lTm-MW1XCy",
370 | "colab_type": "text"
371 | },
372 | "source": [
373 | "## Znalezienie wektora o największej długości"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "metadata": {
379 | "id": "M3qinvVjVABc",
380 | "colab_type": "code",
381 | "colab": {
382 | "base_uri": "https://localhost:8080/",
383 | "height": 35
384 | },
385 | "outputId": "769a8b07-dc9f-41df-8a69-06fe67608608"
386 | },
387 | "source": [
388 | "best_length = 0\n",
389 | "for i in range(len(X_train_vals_ctr)):\n",
390 | " if len(X_train_vals_ctr[i]) > best_length:\n",
391 | " best_length = len(X_train_vals_ctr[i])\n",
392 | "\n",
393 | "print(best_length)"
394 | ],
395 | "execution_count": 99,
396 | "outputs": [
397 | {
398 | "output_type": "stream",
399 | "text": [
400 | "101\n"
401 | ],
402 | "name": "stdout"
403 | }
404 | ]
405 | },
406 | {
407 | "cell_type": "markdown",
408 | "metadata": {
409 | "id": "NHn-aUW9v7QB",
410 | "colab_type": "text"
411 | },
412 | "source": [
413 | "## Padding"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "metadata": {
419 | "id": "9LPL27F6Vspe",
420 | "colab_type": "code",
421 | "colab": {
422 | "base_uri": "https://localhost:8080/",
423 | "height": 35
424 | },
425 | "outputId": "7172c958-a973-4c8f-b27b-b53a979d50eb"
426 | },
427 | "source": [
428 | "padded_X_train_vals_ctr = np.zeros([len(X_train_vals_ctr), best_length])\n",
429 | "\n",
430 | "for i in range(len(X_train_vals_ctr)):\n",
431 | " padded = np.pad(X_train_vals_ctr[i], \n",
432 | " pad_width=(best_length - len(X_train_vals_ctr[i]), 0), \n",
433 | " mode='constant')\n",
434 | " padded_X_train_vals_ctr[i] = padded\n",
435 | "\n",
436 | "print(padded_X_train_vals_ctr[0].shape)"
437 | ],
438 | "execution_count": 100,
439 | "outputs": [
440 | {
441 | "output_type": "stream",
442 | "text": [
443 | "(101,)\n"
444 | ],
445 | "name": "stdout"
446 | }
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "metadata": {
452 | "id": "WtvStaNPegMa",
453 | "colab_type": "code",
454 | "colab": {
455 | "base_uri": "https://localhost:8080/",
456 | "height": 52
457 | },
458 | "outputId": "05f76082-fb36-456f-f1b7-a5a8cdee709d"
459 | },
460 | "source": [
461 | "print(padded_X_train_vals_ctr.shape)\n",
462 | "print(y_train.shape)"
463 | ],
464 | "execution_count": 101,
465 | "outputs": [
466 | {
467 | "output_type": "stream",
468 | "text": [
469 | "(60000, 101)\n",
470 | "(60000,)\n"
471 | ],
472 | "name": "stdout"
473 | }
474 | ]
475 | },
476 | {
477 | "cell_type": "markdown",
478 | "metadata": {
479 | "id": "yEyr9xC-v0SK",
480 | "colab_type": "text"
481 | },
482 | "source": [
483 | "## Budowa sieci MLP"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "metadata": {
489 | "id": "jiUJ8hknSwqD",
490 | "colab_type": "code",
491 | "colab": {
492 | "base_uri": "https://localhost:8080/",
493 | "height": 260
494 | },
495 | "outputId": "6e8bc998-713c-478d-ec05-26878556088d"
496 | },
497 | "source": [
498 | "model = Sequential()\n",
499 | "model.add(InputLayer(input_shape=(padded_X_train_vals_ctr.shape)))\n",
500 | "model.add(Dense(units=128, activation='relu'))\n",
501 | "model.add(Dropout(0.2))\n",
502 | "model.add(Dense(units=10, activation='softmax'))\n",
503 | "\n",
504 | "model.compile(optimizer='adam',\n",
505 | " loss='sparse_categorical_crossentropy',\n",
506 | " metrics=['accuracy'])\n",
507 | "\n",
508 | "model.summary()"
509 | ],
510 | "execution_count": 102,
511 | "outputs": [
512 | {
513 | "output_type": "stream",
514 | "text": [
515 | "Model: \"sequential_8\"\n",
516 | "_________________________________________________________________\n",
517 | "Layer (type) Output Shape Param # \n",
518 | "=================================================================\n",
519 | "dense_12 (Dense) (None, 60000, 128) 13056 \n",
520 | "_________________________________________________________________\n",
521 | "dropout_6 (Dropout) (None, 60000, 128) 0 \n",
522 | "_________________________________________________________________\n",
523 | "dense_13 (Dense) (None, 60000, 10) 1290 \n",
524 | "=================================================================\n",
525 | "Total params: 14,346\n",
526 | "Trainable params: 14,346\n",
527 | "Non-trainable params: 0\n",
528 | "_________________________________________________________________\n"
529 | ],
530 | "name": "stdout"
531 | }
532 | ]
533 | },
534 | {
535 | "cell_type": "markdown",
536 | "metadata": {
537 | "id": "d5Lh2G7DwCPx",
538 | "colab_type": "text"
539 | },
540 | "source": [
541 | "## Trening modelu"
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "metadata": {
547 | "id": "I-1P3W2ociwS",
548 | "colab_type": "code",
549 | "colab": {
550 | "base_uri": "https://localhost:8080/",
551 | "height": 191
552 | },
553 | "outputId": "a978bab7-2e0d-4074-d654-45e45f20c9ee"
554 | },
555 | "source": [
556 | "history = model.fit(padded_X_train_vals_ctr, y_train, epochs=5)"
557 | ],
558 | "execution_count": 106,
559 | "outputs": [
560 | {
561 | "output_type": "stream",
562 | "text": [
563 | "Epoch 1/5\n",
564 | "1875/1875 [==============================] - 3s 1ms/step - loss: 0.3783 - accuracy: 0.8725\n",
565 | "Epoch 2/5\n",
566 | "1875/1875 [==============================] - 3s 1ms/step - loss: 0.3756 - accuracy: 0.8718\n",
567 | "Epoch 3/5\n",
568 | "1875/1875 [==============================] - 3s 1ms/step - loss: 0.3752 - accuracy: 0.8728\n",
569 | "Epoch 4/5\n",
570 | "1875/1875 [==============================] - 3s 1ms/step - loss: 0.3760 - accuracy: 0.8723\n",
571 | "Epoch 5/5\n",
572 | "1875/1875 [==============================] - 3s 1ms/step - loss: 0.3738 - accuracy: 0.8734\n"
573 | ],
574 | "name": "stdout"
575 | }
576 | ]
577 | }
578 | ]
579 | }
--------------------------------------------------------------------------------
/neural_networks/NSL/adversarial_regularization/adversarial_regularization_mnist.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "adversarial_regularization_mnist.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyNcJ8TeO0Jlrt9MiDyM2qgm",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | " "
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "JhX2byHcpx-w",
33 | "colab_type": "text"
34 | },
35 | "source": [
36 | "# Adversarial regularization for image classification"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "metadata": {
42 | "id": "PItAtRzGpfRV",
43 | "colab_type": "code",
44 | "colab": {
45 | "base_uri": "https://localhost:8080/",
46 | "height": 35
47 | },
48 | "outputId": "23e280c2-25e0-4f63-da62-feeb2be97e83"
49 | },
50 | "source": [
51 | "!pip install -q neural-structured-learning"
52 | ],
53 | "execution_count": 1,
54 | "outputs": [
55 | {
56 | "output_type": "stream",
57 | "text": [
58 | "\u001b[?25l\r\u001b[K |███▏ | 10kB 23.5MB/s eta 0:00:01\r\u001b[K |██████▎ | 20kB 30.4MB/s eta 0:00:01\r\u001b[K |█████████▍ | 30kB 22.0MB/s eta 0:00:01\r\u001b[K |████████████▌ | 40kB 22.3MB/s eta 0:00:01\r\u001b[K |███████████████▋ | 51kB 20.5MB/s eta 0:00:01\r\u001b[K |██████████████████▉ | 61kB 23.2MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 71kB 18.5MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 81kB 18.0MB/s eta 0:00:01\r\u001b[K |████████████████████████████▏ | 92kB 18.0MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▎| 102kB 18.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 112kB 18.3MB/s \n",
59 | "\u001b[?25h"
60 | ],
61 | "name": "stdout"
62 | }
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "metadata": {
68 | "id": "69h_NWQMqJEd",
69 | "colab_type": "code",
70 | "colab": {
71 | "base_uri": "https://localhost:8080/",
72 | "height": 64
73 | },
74 | "outputId": "2afed11e-8a08-4ccc-ca43-5e7dd37ede86"
75 | },
76 | "source": [
77 | "from __future__ import absolute_import, division, print_function, unicode_literals\n",
78 | "\n",
79 | "import matplotlib.pyplot as plt\n",
80 | "import neural_structured_learning as nsl\n",
81 | "import numpy as np\n",
82 | "import tensorflow as tf\n",
83 | "import tensorflow_datasets as tfds"
84 | ],
85 | "execution_count": 2,
86 | "outputs": [
87 | {
88 | "output_type": "display_data",
89 | "data": {
90 | "text/html": [
91 | "\n",
92 | "The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x. \n",
93 | "We recommend you upgrade now \n",
94 | "or ensure your notebook will continue to use TensorFlow 1.x via the %tensorflow_version 1.x magic:\n",
95 | "more info .
\n"
96 | ],
97 | "text/plain": [
98 | ""
99 | ]
100 | },
101 | "metadata": {
102 | "tags": []
103 | }
104 | }
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {
110 | "id": "0K0XGMSPqs22",
111 | "colab_type": "text"
112 | },
113 | "source": [
114 | "### Hiperparametry"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "metadata": {
120 | "id": "hZIyxe8Hqmbp",
121 | "colab_type": "code",
122 | "colab": {}
123 | },
124 | "source": [
125 | "class HParams(object):\n",
126 | " def __init__(self):\n",
127 | " self.input_shape = [28, 28, 1]\n",
128 | " self.num_classes = 10\n",
129 | " self.conv_filters = [32, 64, 64]\n",
130 | " self.kernel_size = (3, 3)\n",
131 | " self.pool_size = (2, 2)\n",
132 | " self.num_fc_units = [64]\n",
133 | " self.batch_size = 32\n",
134 | " self.epochs = 5\n",
135 | " self.adv_multiplier = 0.2\n",
136 | " self.adv_step_size = 0.2\n",
137 | " self.adv_grad_norm = 'infinity'\n",
138 | "\n",
139 | "HPARAMS = HParams()"
140 | ],
141 | "execution_count": 0,
142 | "outputs": []
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {
147 | "id": "PHE9rXs4rsLh",
148 | "colab_type": "text"
149 | },
150 | "source": [
151 | "### Zbiór danych MNIST"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "metadata": {
157 | "id": "TmJQNSqgrlaE",
158 | "colab_type": "code",
159 | "colab": {}
160 | },
161 | "source": [
162 | "datasets = tfds.load('mnist')\n",
163 | "\n",
164 | "train_dataset = datasets['train']\n",
165 | "test_dataset = datasets['test']\n",
166 | "\n",
167 | "IMAGE_INPUT_NAME = 'image'\n",
168 | "LABEL_INPUT_NAME = 'label'"
169 | ],
170 | "execution_count": 0,
171 | "outputs": []
172 | },
173 | {
174 | "cell_type": "code",
175 | "metadata": {
176 | "id": "jnm1ycvysQNX",
177 | "colab_type": "code",
178 | "colab": {}
179 | },
180 | "source": [
181 | "def normalize(features):\n",
182 | " features[IMAGE_INPUT_NAME] = tf.cast(\n",
183 | " features[IMAGE_INPUT_NAME], dtype=tf.float32) / 255.0\n",
184 | " return features\n",
185 | "\n",
186 | "def convert_to_tuples(features):\n",
187 | " return features[IMAGE_INPUT_NAME], features[LABEL_INPUT_NAME]\n",
188 | "\n",
189 | "def convert_to_dictionaries(image, label):\n",
190 | " return {IMAGE_INPUT_NAME: image, LABEL_INPUT_NAME: label}\n",
191 | "\n",
192 | "train_dataset = train_dataset.map(normalize).shuffle(10000).batch(HPARAMS.batch_size).map(convert_to_tuples)\n",
193 | "test_dataset = test_dataset.map(normalize).batch(HPARAMS.batch_size).map(convert_to_tuples)"
194 | ],
195 | "execution_count": 0,
196 | "outputs": []
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {
201 | "id": "wzBCUYYSw2-V",
202 | "colab_type": "text"
203 | },
204 | "source": [
205 | "### Model bazowy"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "metadata": {
211 | "id": "ihr_UqDWtgHf",
212 | "colab_type": "code",
213 | "colab": {}
214 | },
215 | "source": [
216 | "def build_base_model(hparams):\n",
217 | " inputs = tf.keras.Input(\n",
218 | " shape=hparams.input_shape, dtype=tf.float32, name=IMAGE_INPUT_NAME)\n",
219 | " \n",
220 | " x = inputs\n",
221 | " for i, num_filters in enumerate(hparams.conv_filters):\n",
222 | " x = tf.keras.layers.Conv2D(\n",
223 | " num_filters, hparams.kernel_size, activation='relu')(x)\n",
224 | " if i < len(hparams.conv_filters) - 1:\n",
225 | " # max pooling między warstwami splotu\n",
226 | " x = tf.keras.layers.MaxPooling2D(hparams.pool_size)(x)\n",
227 | " x = tf.keras.layers.Flatten()(x)\n",
228 | " for num_units in hparams.num_fc_units:\n",
229 | " x = tf.keras.layers.Dense(num_units, activation='relu')(x)\n",
230 | " pred = tf.keras.layers.Dense(hparams.num_classes, activation='softmax')(x)\n",
231 | " model = tf.keras.Model(inputs=inputs, outputs=pred)\n",
232 | " return model"
233 | ],
234 | "execution_count": 0,
235 | "outputs": []
236 | },
237 | {
238 | "cell_type": "code",
239 | "metadata": {
240 | "id": "E0-Ix8KUvUnY",
241 | "colab_type": "code",
242 | "colab": {
243 | "base_uri": "https://localhost:8080/",
244 | "height": 592
245 | },
246 | "outputId": "69884dd5-f06a-485a-823a-e89c6c31cf33"
247 | },
248 | "source": [
249 | "base_model = build_base_model(HPARAMS)\n",
250 | "base_model.summary()"
251 | ],
252 | "execution_count": 9,
253 | "outputs": [
254 | {
255 | "output_type": "stream",
256 | "text": [
257 | "WARNING:tensorflow:From /tensorflow-1.15.0/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
258 | "Instructions for updating:\n",
259 | "If using Keras pass *_constraint arguments to layers.\n"
260 | ],
261 | "name": "stdout"
262 | },
263 | {
264 | "output_type": "stream",
265 | "text": [
266 | "WARNING:tensorflow:From /tensorflow-1.15.0/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
267 | "Instructions for updating:\n",
268 | "If using Keras pass *_constraint arguments to layers.\n"
269 | ],
270 | "name": "stderr"
271 | },
272 | {
273 | "output_type": "stream",
274 | "text": [
275 | "Model: \"model\"\n",
276 | "_________________________________________________________________\n",
277 | "Layer (type) Output Shape Param # \n",
278 | "=================================================================\n",
279 | "image (InputLayer) [(None, 28, 28, 1)] 0 \n",
280 | "_________________________________________________________________\n",
281 | "conv2d (Conv2D) (None, 26, 26, 32) 320 \n",
282 | "_________________________________________________________________\n",
283 | "max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0 \n",
284 | "_________________________________________________________________\n",
285 | "conv2d_1 (Conv2D) (None, 11, 11, 64) 18496 \n",
286 | "_________________________________________________________________\n",
287 | "max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64) 0 \n",
288 | "_________________________________________________________________\n",
289 | "conv2d_2 (Conv2D) (None, 3, 3, 64) 36928 \n",
290 | "_________________________________________________________________\n",
291 | "flatten (Flatten) (None, 576) 0 \n",
292 | "_________________________________________________________________\n",
293 | "dense (Dense) (None, 64) 36928 \n",
294 | "_________________________________________________________________\n",
295 | "dense_1 (Dense) (None, 10) 650 \n",
296 | "=================================================================\n",
297 | "Total params: 93,322\n",
298 | "Trainable params: 93,322\n",
299 | "Non-trainable params: 0\n",
300 | "_________________________________________________________________\n"
301 | ],
302 | "name": "stdout"
303 | }
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "metadata": {
309 | "id": "i8SJt6H7vmSt",
310 | "colab_type": "code",
311 | "colab": {
312 | "base_uri": "https://localhost:8080/",
313 | "height": 225
314 | },
315 | "outputId": "73be9d68-23ed-45cd-974a-3fe8bccb435b"
316 | },
317 | "source": [
318 | "base_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',\n",
319 | " metrics=['acc'])\n",
320 | "base_model.fit(train_dataset, epochs=HPARAMS.epochs)"
321 | ],
322 | "execution_count": 11,
323 | "outputs": [
324 | {
325 | "output_type": "stream",
326 | "text": [
327 | "Train on None steps\n",
328 | "Epoch 1/5\n",
329 | "1875/1875 [==============================] - 25s 13ms/step - loss: 0.1426 - acc: 0.9550\n",
330 | "Epoch 2/5\n",
331 | "1875/1875 [==============================] - 16s 9ms/step - loss: 0.0463 - acc: 0.9855\n",
332 | "Epoch 3/5\n",
333 | "1875/1875 [==============================] - 16s 9ms/step - loss: 0.0334 - acc: 0.9896\n",
334 | "Epoch 4/5\n",
335 | "1875/1875 [==============================] - 16s 9ms/step - loss: 0.0239 - acc: 0.9923\n",
336 | "Epoch 5/5\n",
337 | "1875/1875 [==============================] - 17s 9ms/step - loss: 0.0197 - acc: 0.9940\n"
338 | ],
339 | "name": "stdout"
340 | },
341 | {
342 | "output_type": "execute_result",
343 | "data": {
344 | "text/plain": [
345 | ""
346 | ]
347 | },
348 | "metadata": {
349 | "tags": []
350 | },
351 | "execution_count": 11
352 | }
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "metadata": {
358 | "id": "n3x2EpLawHDD",
359 | "colab_type": "code",
360 | "colab": {
361 | "base_uri": "https://localhost:8080/",
362 | "height": 52
363 | },
364 | "outputId": "8a725a8c-28cd-47ea-819c-3c9a0e32814e"
365 | },
366 | "source": [
367 | "results = base_model.evaluate(test_dataset)\n",
368 | "named_results = dict(zip(base_model.metrics_names, results))\n",
369 | "print('\\naccuracy:', named_results['acc'])"
370 | ],
371 | "execution_count": 12,
372 | "outputs": [
373 | {
374 | "output_type": "stream",
375 | "text": [
376 | " 313/Unknown - 3s 10ms/step - loss: 0.0325 - acc: 0.9906\n",
377 | "accuracy: 0.9906\n"
378 | ],
379 | "name": "stdout"
380 | }
381 | ]
382 | },
383 | {
384 | "cell_type": "markdown",
385 | "metadata": {
386 | "id": "dTY2e7VYw_tp",
387 | "colab_type": "text"
388 | },
389 | "source": [
390 | "### Adversarial-regularized model"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "metadata": {
396 | "id": "NfyH08YDw_Ui",
397 | "colab_type": "code",
398 | "colab": {}
399 | },
400 | "source": [
401 | "adv_config = nsl.configs.make_adv_reg_config(\n",
402 | " multiplier = HPARAMS.adv_multiplier,\n",
403 | " adv_step_size = HPARAMS.adv_step_size,\n",
404 | " adv_grad_norm = HPARAMS.adv_grad_norm\n",
405 | ")"
406 | ],
407 | "execution_count": 0,
408 | "outputs": []
409 | },
410 | {
411 | "cell_type": "code",
412 | "metadata": {
413 | "id": "67W6vp_Rypry",
414 | "colab_type": "code",
415 | "colab": {}
416 | },
417 | "source": [
418 | "base_adv_model = build_base_model(HPARAMS)\n",
419 | "adv_model = nsl.keras.AdversarialRegularization(\n",
420 | " base_adv_model,\n",
421 | " label_keys = [LABEL_INPUT_NAME],\n",
422 | " adv_config = adv_config\n",
423 | ")\n",
424 | "\n",
425 | "train_set_for_adv_model = train_dataset.map(convert_to_dictionaries)\n",
426 | "test_set_for_adv_model = test_dataset.map(convert_to_dictionaries)"
427 | ],
428 | "execution_count": 0,
429 | "outputs": []
430 | },
431 | {
432 | "cell_type": "code",
433 | "metadata": {
434 | "id": "4asz8sS_0_70",
435 | "colab_type": "code",
436 | "colab": {
437 | "base_uri": "https://localhost:8080/",
438 | "height": 453
439 | },
440 | "outputId": "cdb08b25-8ed1-418f-c86b-360399bb22ff"
441 | },
442 | "source": [
443 | "adv_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',\n",
444 | " metrics=['acc'])\n",
445 | "adv_model.fit(train_set_for_adv_model, epochs=HPARAMS.epochs)"
446 | ],
447 | "execution_count": 18,
448 | "outputs": [
449 | {
450 | "output_type": "stream",
451 | "text": [
452 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/neural_structured_learning/keras/adversarial_regularization.py:167: The name tf.losses.Reduction is deprecated. Please use tf.compat.v1.losses.Reduction instead.\n",
453 | "\n"
454 | ],
455 | "name": "stdout"
456 | },
457 | {
458 | "output_type": "stream",
459 | "text": [
460 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/neural_structured_learning/keras/adversarial_regularization.py:167: The name tf.losses.Reduction is deprecated. Please use tf.compat.v1.losses.Reduction instead.\n",
461 | "\n"
462 | ],
463 | "name": "stderr"
464 | },
465 | {
466 | "output_type": "stream",
467 | "text": [
468 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/neural_structured_learning/lib/adversarial_neighbor.py:97: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
469 | "Instructions for updating:\n",
470 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
471 | ],
472 | "name": "stdout"
473 | },
474 | {
475 | "output_type": "stream",
476 | "text": [
477 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/neural_structured_learning/lib/adversarial_neighbor.py:97: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
478 | "Instructions for updating:\n",
479 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
480 | ],
481 | "name": "stderr"
482 | },
483 | {
484 | "output_type": "stream",
485 | "text": [
486 | "WARNING:tensorflow:Output output_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to output_1.\n"
487 | ],
488 | "name": "stdout"
489 | },
490 | {
491 | "output_type": "stream",
492 | "text": [
493 | "WARNING:tensorflow:Output output_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to output_1.\n"
494 | ],
495 | "name": "stderr"
496 | },
497 | {
498 | "output_type": "stream",
499 | "text": [
500 | "Train on None steps\n",
501 | "Epoch 1/5\n",
502 | "1875/1875 [==============================] - 24s 13ms/step - loss: 0.2993 - sparse_categorical_crossentropy: 0.1377 - sparse_categorical_accuracy: 0.9599 - adversarial_loss: 0.8081\n",
503 | "Epoch 2/5\n",
504 | "1875/1875 [==============================] - 20s 11ms/step - loss: 0.1227 - sparse_categorical_crossentropy: 0.0431 - sparse_categorical_accuracy: 0.9870 - adversarial_loss: 0.3981\n",
505 | "Epoch 3/5\n",
506 | "1875/1875 [==============================] - 20s 11ms/step - loss: 0.0796 - sparse_categorical_crossentropy: 0.0325 - sparse_categorical_accuracy: 0.9897 - adversarial_loss: 0.2356\n",
507 | "Epoch 4/5\n",
508 | "1875/1875 [==============================] - 20s 11ms/step - loss: 0.0547 - sparse_categorical_crossentropy: 0.0246 - sparse_categorical_accuracy: 0.9920 - adversarial_loss: 0.1504\n",
509 | "Epoch 5/5\n",
510 | "1875/1875 [==============================] - 20s 11ms/step - loss: 0.0463 - sparse_categorical_crossentropy: 0.0184 - sparse_categorical_accuracy: 0.9942 - adversarial_loss: 0.1393\n"
511 | ],
512 | "name": "stdout"
513 | },
514 | {
515 | "output_type": "execute_result",
516 | "data": {
517 | "text/plain": [
518 | ""
519 | ]
520 | },
521 | "metadata": {
522 | "tags": []
523 | },
524 | "execution_count": 18
525 | }
526 | ]
527 | },
528 | {
529 | "cell_type": "code",
530 | "metadata": {
531 | "id": "LcOB61b-1fiZ",
532 | "colab_type": "code",
533 | "colab": {
534 | "base_uri": "https://localhost:8080/",
535 | "height": 72
536 | },
537 | "outputId": "2f6b85a9-8d56-4a9b-c35b-aef5aef78ee8"
538 | },
539 | "source": [
540 | "results = adv_model.evaluate(test_set_for_adv_model)\n",
541 | "named_results = dict(zip(adv_model.metrics_names, results))\n",
542 | "print('\\naccuracy:', named_results['sparse_categorical_accuracy'])"
543 | ],
544 | "execution_count": 20,
545 | "outputs": [
546 | {
547 | "output_type": "stream",
548 | "text": [
549 | " 313/Unknown - 3s 11ms/step - loss: 0.0593 - sparse_categorical_crossentropy: 0.0343 - sparse_categorical_accuracy: 0.9895 - adversarial_loss: 0.1249\n",
550 | "accuracy: 0.9895\n"
551 | ],
552 | "name": "stdout"
553 | }
554 | ]
555 | }
556 | ]
557 | }
558 |
--------------------------------------------------------------------------------
/neural_networks/RBN/rbn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "rbn.ipynb",
7 | "provenance": [],
8 | "authorship_tag": "ABX9TyNMkvjA2+wmCQMUd75RE5uR",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | }
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "id": "view-in-github",
24 | "colab_type": "text"
25 | },
26 | "source": [
27 | " "
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "source": [
33 | "#Radial Basis Network"
34 | ],
35 | "metadata": {
36 | "id": "q5daW45oZGBn"
37 | }
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 2,
42 | "metadata": {
43 | "id": "55FRqfZ3O7Wy"
44 | },
45 | "outputs": [],
46 | "source": [
47 | "import numpy as np\n",
48 | "import pandas as pd\n",
49 | "from keras.layers import Layer, Dense, Flatten\n",
50 | "from keras import backend as K\n",
51 | "from keras.models import Sequential\n",
52 | "from keras.losses import binary_crossentropy"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "source": [
58 | "X = np.load('./k49-train-imgs.npz')['arr_0']\n",
59 | "y = np.load('./k49-train-labels.npz')['arr_0']\n",
60 | "y = (y <= 25).astype(int)\n",
61 | "\n",
62 | "print(X.shape)\n",
63 | "print(y.shape)"
64 | ],
65 | "metadata": {
66 | "colab": {
67 | "base_uri": "https://localhost:8080/"
68 | },
69 | "id": "QT-n9_P4cmWn",
70 | "outputId": "f2c966ea-1137-4579-e0ab-2a19cf910422"
71 | },
72 | "execution_count": 11,
73 | "outputs": [
74 | {
75 | "output_type": "stream",
76 | "name": "stdout",
77 | "text": [
78 | "(232365, 28, 28)\n",
79 | "(232365,)\n"
80 | ]
81 | }
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "source": [
87 | "class RBFLayer(Layer):\n",
88 | " def __init__(self, units, gamma, **kwargs):\n",
89 | " super(RBFLayer, self).__init__(**kwargs)\n",
90 | " self.units = units\n",
91 | " self.gamma = K.cast_to_floatx(gamma)\n",
92 | "\n",
93 | " def build(self, input_shape):\n",
94 | " self.mu = self.add_weight(name='mu',\n",
95 | " shape=(int(input_shape[1]), self.units),\n",
96 | " initializer='uniform',\n",
97 | " trainable=True)\n",
98 | " super(RBFLayer, self).build(input_shape)\n",
99 | " \n",
100 | " def call(self, inputs):\n",
101 | " diff = K.expand_dims(inputs) - self.mu\n",
102 | " l2 = K.sum(K.pow(diff, 2), axis=1)\n",
103 | " res = K.exp(-1 * self.gamma * l2)\n",
104 | " return res"
105 | ],
106 | "metadata": {
107 | "id": "TfuUdz2bbxIc"
108 | },
109 | "execution_count": null,
110 | "outputs": []
111 | },
112 | {
113 | "cell_type": "code",
114 | "source": [
115 | "model = Sequential()\n",
116 | "model.add(Flatten(input_shape=(28, 28)))\n",
117 | "model.add(RBFLayer(units=10, gamma=0.5))\n",
118 | "model.add(Dense(1, activation='sigmoid'))\n",
119 | "\n",
120 | "model.compile(optimizer='rmsprop', loss=binary_crossentropy)"
121 | ],
122 | "metadata": {
123 | "id": "ye0X9YpfcYJZ"
124 | },
125 | "execution_count": 3,
126 | "outputs": []
127 | },
128 | {
129 | "cell_type": "code",
130 | "source": [
131 | "model.fit(X, y, batch_size=256, epochs=3)"
132 | ],
133 | "metadata": {
134 | "colab": {
135 | "base_uri": "https://localhost:8080/"
136 | },
137 | "id": "_p7Q3CgJeK_A",
138 | "outputId": "d38d9d79-49f7-481d-d587-67d3ec846ed2"
139 | },
140 | "execution_count": 12,
141 | "outputs": [
142 | {
143 | "output_type": "stream",
144 | "name": "stdout",
145 | "text": [
146 | "Epoch 1/3\n",
147 | "908/908 [==============================] - 20s 21ms/step - loss: 0.6823\n",
148 | "Epoch 2/3\n",
149 | "908/908 [==============================] - 19s 21ms/step - loss: 0.6806\n",
150 | "Epoch 3/3\n",
151 | "908/908 [==============================] - 19s 21ms/step - loss: 0.6806\n"
152 | ]
153 | },
154 | {
155 | "output_type": "execute_result",
156 | "data": {
157 | "text/plain": [
158 | ""
159 | ]
160 | },
161 | "metadata": {},
162 | "execution_count": 12
163 | }
164 | ]
165 | }
166 | ]
167 | }
--------------------------------------------------------------------------------
/neural_networks/RNN/seq2seq_sorting.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "seq2seq_sorting.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyPQj847hUoM1ycujsbogaLF",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "language_info": {
17 | "name": "python"
18 | },
19 | "accelerator": "GPU"
20 | },
21 | "cells": [
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {
25 | "id": "view-in-github",
26 | "colab_type": "text"
27 | },
28 | "source": [
29 | " "
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {
35 | "id": "VDYwQHxP923y"
36 | },
37 | "source": [
38 | "#Seq2Seq: Sortowanie\n",
39 | "##Import bibliotek"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "metadata": {
45 | "id": "0pb-Wrij92WN"
46 | },
47 | "source": [
48 | "from tensorflow import keras\n",
49 | "from tensorflow.keras import layers\n",
50 | "import numpy as np\n",
51 | "\n",
52 | "TRAINING_SIZE = 50000\n",
53 | "NUMBERS_TO_SORT = 10"
54 | ],
55 | "execution_count": 3,
56 | "outputs": []
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {
61 | "id": "8SY7nCYXHF9w"
62 | },
63 | "source": [
64 | "##Generowanie danych"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "metadata": {
70 | "colab": {
71 | "base_uri": "https://localhost:8080/"
72 | },
73 | "id": "vktABSDC-Jgg",
74 | "outputId": "8da5a9c9-66f9-4871-b827-43f5badd9b37"
75 | },
76 | "source": [
77 | "class CharacterTable:\n",
78 | " \n",
79 | " def __init__(self, chars):\n",
80 | " self.chars = sorted(set(chars))\n",
81 | " self.char_indices = dict((c, i) for i, c in enumerate(self.chars))\n",
82 | " self.indices_char = dict((i, c) for i, c in enumerate(self.chars))\n",
83 | "\n",
84 | " def encode(self, C, num_rows):\n",
85 | " x = np.zeros((num_rows, len(self.chars)))\n",
86 | " for i, c in enumerate(C):\n",
87 | " x[i, self.char_indices[c]] = 1\n",
88 | " return x\n",
89 | "\n",
90 | " def decode(self, x, calc_argmax=True):\n",
91 | " if calc_argmax:\n",
92 | " x = x.argmax(axis=-1)\n",
93 | " return ''.join(self.indices_char[x] for x in x)\n",
94 | "\n",
95 | "# Wszystkie liczby, znaki i spacja\n",
96 | "chars = '0123456789'\n",
97 | "ctable = CharacterTable(chars)\n",
98 | "\n",
99 | "questions = []\n",
100 | "expected = []\n",
101 | "\n",
102 | "while len(questions) < TRAINING_SIZE:\n",
103 | " randomize_string = lambda: str(\n",
104 | " ''.join(\n",
105 | " np.random.choice(list('0123456789'))\n",
106 | " for i in range(NUMBERS_TO_SORT)\n",
107 | " )\n",
108 | " )\n",
109 | " query = randomize_string()\n",
110 | "\n",
111 | " # String to list\n",
112 | " ans = [int(q) for q in query]\n",
113 | " ans = sorted(ans)\n",
114 | "\n",
115 | " # Sorted list to string\n",
116 | " answer = ''\n",
117 | " for num in ans:\n",
118 | " answer += str(num)\n",
119 | " \n",
120 | " questions.append(query)\n",
121 | " expected.append(answer)\n",
122 | "\n",
123 | "print('Liczba przykładów:', len(questions))\n",
124 | "print('Questions: ', questions[:5])\n",
125 | "print('Answers: ', expected[:5])"
126 | ],
127 | "execution_count": 26,
128 | "outputs": [
129 | {
130 | "output_type": "stream",
131 | "name": "stdout",
132 | "text": [
133 | "Liczba przykładów: 50000\n",
134 | "Questions: ['8847644026', '3025793945', '1664364979', '2849673969', '3700480495']\n",
135 | "Answers: ['0244466788', '0233455799', '1344666799', '2346678999', '0003445789']\n"
136 | ]
137 | }
138 | ]
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {
143 | "id": "B_665xMOHtSA"
144 | },
145 | "source": [
146 | "##Wektoryzacja danych"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "metadata": {
152 | "colab": {
153 | "base_uri": "https://localhost:8080/"
154 | },
155 | "id": "1R4OKuJVG68v",
156 | "outputId": "549c6fab-6dfc-4b2a-a841-cad771634470"
157 | },
158 | "source": [
159 | "x = np.zeros((len(questions), NUMBERS_TO_SORT, len(chars)), dtype=np.bool)\n",
160 | "y = np.zeros((len(questions), NUMBERS_TO_SORT, len(chars)), dtype=np.bool)\n",
161 | "\n",
162 | "for i, sentence in enumerate(questions):\n",
163 | " x[i] = ctable.encode(sentence, NUMBERS_TO_SORT)\n",
164 | "for i, sentence in enumerate(expected):\n",
165 | " y[i] = ctable.encode(sentence, NUMBERS_TO_SORT)\n",
166 | "\n",
167 | "indices = np.arange(len(y))\n",
168 | "np.random.shuffle(indices)\n",
169 | "x = x[indices]\n",
170 | "y = y[indices]\n",
171 | "\n",
172 | "split_at = len(x) - len(x) // 10\n",
173 | "(x_train, x_val) = x[:split_at], x[split_at:]\n",
174 | "(y_train, y_val) = y[:split_at], y[split_at:]\n",
175 | "\n",
176 | "print('Dane treningowe:')\n",
177 | "print(x_train.shape)\n",
178 | "print(y_train.shape)\n",
179 | "\n",
180 | "print('Dane walidacyjne:')\n",
181 | "print(x_val.shape)\n",
182 | "print(y_val.shape)"
183 | ],
184 | "execution_count": 27,
185 | "outputs": [
186 | {
187 | "output_type": "stream",
188 | "name": "stdout",
189 | "text": [
190 | "Dane treningowe:\n",
191 | "(45000, 10, 10)\n",
192 | "(45000, 10, 10)\n",
193 | "Dane walidacyjne:\n",
194 | "(5000, 10, 10)\n",
195 | "(5000, 10, 10)\n"
196 | ]
197 | }
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {
203 | "id": "s4Z6pNxiH2gx"
204 | },
205 | "source": [
206 | "##Budowa modelu"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "metadata": {
212 | "colab": {
213 | "base_uri": "https://localhost:8080/"
214 | },
215 | "id": "2uRxKHkk_9bL",
216 | "outputId": "a41b8a76-0c2e-4030-e40f-ccb75e67ff6e"
217 | },
218 | "source": [
219 | "num_layers = 1\n",
220 | "\n",
221 | "model = keras.Sequential()\n",
222 | "model.add(layers.LSTM(16, input_shape=(NUMBERS_TO_SORT, len(chars))))\n",
223 | "model.add(layers.RepeatVector(NUMBERS_TO_SORT))\n",
224 | "\n",
225 | "for _ in range(num_layers):\n",
226 | " model.add(layers.LSTM(16, return_sequences=True))\n",
227 | "\n",
228 | "model.add(layers.Dense(len(chars), activation='softmax'))\n",
229 | "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
230 | "model.summary()"
231 | ],
232 | "execution_count": 38,
233 | "outputs": [
234 | {
235 | "output_type": "stream",
236 | "name": "stdout",
237 | "text": [
238 | "Model: \"sequential_7\"\n",
239 | "_________________________________________________________________\n",
240 | "Layer (type) Output Shape Param # \n",
241 | "=================================================================\n",
242 | "lstm_20 (LSTM) (None, 16) 1728 \n",
243 | "_________________________________________________________________\n",
244 | "repeat_vector_7 (RepeatVecto (None, 10, 16) 0 \n",
245 | "_________________________________________________________________\n",
246 | "lstm_21 (LSTM) (None, 10, 16) 2112 \n",
247 | "_________________________________________________________________\n",
248 | "dense_7 (Dense) (None, 10, 10) 170 \n",
249 | "=================================================================\n",
250 | "Total params: 4,010\n",
251 | "Trainable params: 4,010\n",
252 | "Non-trainable params: 0\n",
253 | "_________________________________________________________________\n"
254 | ]
255 | }
256 | ]
257 | },
258 | {
259 | "cell_type": "markdown",
260 | "metadata": {
261 | "id": "b8Y3x8I3H5SG"
262 | },
263 | "source": [
264 | "##Trening modelu"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "metadata": {
270 | "colab": {
271 | "base_uri": "https://localhost:8080/"
272 | },
273 | "id": "J7j6Pq8xIB_0",
274 | "outputId": "a2dc2968-f20b-4043-afc1-9270126f973d"
275 | },
276 | "source": [
277 | "epochs = 5\n",
278 | "batch_size = 32\n",
279 | "\n",
280 | "for epoch in range(1, epochs + 1):\n",
281 | " print()\n",
282 | " print('Iteracja', epoch)\n",
283 | " model.fit(\n",
284 | " x_train,\n",
285 | " y_train,\n",
286 | " batch_size=batch_size,\n",
287 | " epochs=1,\n",
288 | " validation_data=(x_val, y_val)\n",
289 | " )\n",
290 | " # Wybór 10 losowych próbek ze zbioru walidacyjnego, \n",
291 | " # abyśmy mogli zobaczyć błędy\n",
292 | " for i in range(10):\n",
293 | " ind = np.random.randint(0, len(x_val))\n",
294 | " rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]\n",
295 | " preds = np.argmax(model.predict(rowx), axis=-1)\n",
296 | " q = ctable.decode(rowx[0])\n",
297 | " correct = ctable.decode(rowy[0])\n",
298 | " guess = ctable.decode(preds[0], calc_argmax=False)\n",
299 | " print('Q', q, end=' ')\n",
300 | " print('T', correct, end=' ')\n",
301 | " if correct == guess:\n",
302 | " print('☑ ' + guess)\n",
303 | " else:\n",
304 | " print('☒ ' + guess)"
305 | ],
306 | "execution_count": 39,
307 | "outputs": [
308 | {
309 | "output_type": "stream",
310 | "name": "stdout",
311 | "text": [
312 | "\n",
313 | "Iteracja 1\n",
314 | "1407/1407 [==============================] - 17s 10ms/step - loss: 0.8683 - accuracy: 0.7232 - val_loss: 0.3988 - val_accuracy: 0.8993\n",
315 | "Q 7543839439 T 3334457899 ☒ 2334557899\n",
316 | "Q 2589669405 T 0245566899 ☒ 0145566899\n",
317 | "Q 3739211020 T 0011223379 ☑ 0011223379\n",
318 | "Q 4230435049 T 0023344459 ☒ 0123344459\n",
319 | "Q 8195226339 T 1223356899 ☒ 0223356899\n",
320 | "Q 5887675366 T 3556667788 ☒ 1556667788\n",
321 | "Q 5740260425 T 0022445567 ☒ 0122445567\n",
322 | "Q 9658427089 T 0245678899 ☑ 0245678899\n",
323 | "Q 2872732756 T 2223567778 ☒ 2233567777\n",
324 | "Q 6575054576 T 0455556677 ☒ 0455566677\n",
325 | "\n",
326 | "Iteracja 2\n",
327 | "1407/1407 [==============================] - 13s 9ms/step - loss: 0.2551 - accuracy: 0.9442 - val_loss: 0.1640 - val_accuracy: 0.9680\n",
328 | "Q 8464997804 T 0444678899 ☑ 0444678899\n",
329 | "Q 2124903773 T 0122334779 ☑ 0122334779\n",
330 | "Q 6137035694 T 0133456679 ☑ 0133456679\n",
331 | "Q 7719639560 T 0135667799 ☑ 0135667799\n",
332 | "Q 1797418294 T 1124477899 ☒ 0124477899\n",
333 | "Q 2732399463 T 2233346799 ☑ 2233346799\n",
334 | "Q 6638838210 T 0123366888 ☑ 0123366888\n",
335 | "Q 3799983347 T 3334778999 ☑ 3334778999\n",
336 | "Q 0519453235 T 0123345559 ☑ 0123345559\n",
337 | "Q 3277231476 T 1223346777 ☒ 0223346777\n",
338 | "\n",
339 | "Iteracja 3\n",
340 | "1407/1407 [==============================] - 13s 9ms/step - loss: 0.0940 - accuracy: 0.9908 - val_loss: 0.0543 - val_accuracy: 0.9963\n",
341 | "Q 1764144907 T 0114446779 ☑ 0114446779\n",
342 | "Q 1481556210 T 0111245568 ☑ 0111245568\n",
343 | "Q 8801835874 T 0134578888 ☑ 0134578888\n",
344 | "Q 8141561751 T 1111455678 ☑ 1111455678\n",
345 | "Q 3746504061 T 0013445667 ☑ 0013445667\n",
346 | "Q 7179021402 T 0011224779 ☑ 0011224779\n",
347 | "Q 6905020396 T 0002356699 ☑ 0002356699\n",
348 | "Q 6346478826 T 2344666788 ☑ 2344666788\n",
349 | "Q 7245078965 T 0245567789 ☑ 0245567789\n",
350 | "Q 2627281325 T 1222235678 ☑ 1222235678\n",
351 | "\n",
352 | "Iteracja 4\n",
353 | "1407/1407 [==============================] - 13s 10ms/step - loss: 0.0370 - accuracy: 0.9977 - val_loss: 0.0254 - val_accuracy: 0.9984\n",
354 | "Q 3998260233 T 0223336899 ☑ 0223336899\n",
355 | "Q 7663668555 T 3555666678 ☑ 3555666678\n",
356 | "Q 8607591008 T 0001567889 ☑ 0001567889\n",
357 | "Q 2011430331 T 0011123334 ☑ 0011123334\n",
358 | "Q 2654766745 T 2445566677 ☑ 2445566677\n",
359 | "Q 8297424699 T 2244678999 ☑ 2244678999\n",
360 | "Q 6398993191 T 1133689999 ☑ 1133689999\n",
361 | "Q 9180138065 T 0011356889 ☑ 0011356889\n",
362 | "Q 2649767969 T 2466677999 ☑ 2466677999\n",
363 | "Q 0694320242 T 0022234469 ☑ 0022234469\n",
364 | "\n",
365 | "Iteracja 5\n",
366 | "1407/1407 [==============================] - 13s 10ms/step - loss: 0.0188 - accuracy: 0.9988 - val_loss: 0.0137 - val_accuracy: 0.9990\n",
367 | "Q 5852344260 T 0223445568 ☑ 0223445568\n",
368 | "Q 4426186161 T 1112446668 ☑ 1112446668\n",
369 | "Q 5216904308 T 0012345689 ☑ 0012345689\n",
370 | "Q 8159999274 T 1245789999 ☑ 1245789999\n",
371 | "Q 8701108145 T 0011145788 ☑ 0011145788\n",
372 | "Q 4473650311 T 0113344567 ☑ 0113344567\n",
373 | "Q 1234457849 T 1234445789 ☑ 1234445789\n",
374 | "Q 5719021444 T 0112444579 ☑ 0112444579\n",
375 | "Q 4772336309 T 0233346779 ☑ 0233346779\n",
376 | "Q 7848340530 T 0033445788 ☑ 0033445788\n"
377 | ]
378 | }
379 | ]
380 | }
381 | ]
382 | }
--------------------------------------------------------------------------------