├── .idea └── vcs.xml ├── README.md ├── augmentation ├── cutmix.ipynb ├── cutout.ipynb ├── cutout_cutmix_mixup.ipynb └── mixup.ipynb ├── machine_learning ├── reinforcement_learning │ ├── alpha_tensor │ │ ├── alpha_tensor.ipynb │ │ ├── factorizations_f2.npz │ │ └── factorizations_r.npz │ ├── card_game │ │ └── card_game_env.ipynb │ ├── cartpole │ │ ├── cart_pole.ipynb │ │ ├── dqn_cartpole.ipynb │ │ └── reinforce_cartpole.ipynb │ ├── frozen_lake │ │ └── frozen_lake.ipynb │ └── gridworld │ │ ├── gridworld.py │ │ ├── iterative_policy.py │ │ ├── monte_carlo.py │ │ ├── q_learning.py │ │ └── sarsa.py ├── supervised_learning │ ├── decision_trees │ │ ├── heart_disease.ipynb │ │ └── iris_decision_tree.ipynb │ ├── ensemble_learning │ │ ├── ensemble_learning.ipynb │ │ └── iris_random_forest.ipynb │ ├── linear_algebra │ │ └── cramer_rule.ipynb │ ├── naive_bayes │ │ └── naive_bayes.ipynb │ └── time_series_prediction │ │ ├── prophet_coronavirus.ipynb │ │ └── time_series_forecasting.ipynb └── unsupervised_learning │ ├── anomaly_detection │ ├── anomaly_detection_time_series.ipynb │ ├── isolation_forest.ipynb │ └── local_outlier_factor.ipynb │ ├── association_rules │ └── apriori.ipynb │ ├── clustering │ ├── dbscan.ipynb │ ├── hierarchical_clustering.ipynb │ ├── k_means.ipynb │ ├── k_means_elbow_method.ipynb │ ├── k_means_image_segmentation.ipynb │ └── k_means_implementation.ipynb │ └── dimensionality_reduction │ ├── pca.ipynb │ ├── some_play_with_svd.ipynb │ └── t_sne.ipynb ├── modern_approach ├── KAN │ └── kan_mnist.ipynb ├── adversarial_attack │ ├── adversarial_fgsm.ipynb │ └── adversarial_local_fgsm.ipynb ├── diffusion_models │ └── ddpm_pytorch.ipynb ├── few_shot_learning │ └── few_shot_learning.ipynb ├── forward_forward │ ├── forward_forward.ipynb │ └── forward_forward_pytorch.ipynb ├── one_shot_learning │ ├── siamese_network_with_contrastive_loss.ipynb │ └── siamese_network_with_triplet_loss.ipynb ├── transformer │ ├── neural_machine_translation.ipynb │ ├── self_attention.ipynb 
│ └── solving_math_word_problems.ipynb ├── video_interpolation │ ├── film.ipynb │ ├── film_for_vdm.ipynb │ └── video_inbetweening.ipynb └── zero_shot_learning │ ├── example.png │ ├── owl_vit.ipynb │ └── owl_vit_image_conditioning.ipynb ├── neural_networks ├── CNN │ ├── digit_recognition_cnn.ipynb │ ├── dogs_vs_cats.ipynb │ ├── drone_vs_plane.ipynb │ ├── pixelcnn.ipynb │ └── sigmoid_vs_tanh_lenet.ipynb ├── ConvRNN │ └── video_prediction_conv_lstm.ipynb ├── GAN │ ├── StarGAN │ │ ├── stargan_eval.ipynb │ │ └── stargan_transfer_learning.ipynb │ ├── cycle_gan.ipynb │ ├── dcgan_digits.ipynb │ ├── draggan.ipynb │ ├── ersgan_for_vdm.ipynb │ ├── esrgan.ipynb │ ├── pix2pix.ipynb │ ├── stylegan3.ipynb │ └── vanilla_gan.ipynb ├── MLP │ ├── digit_recognition.ipynb │ ├── experimental │ │ ├── trainable_sigmoid.ipynb │ │ └── zip_learning.ipynb │ └── fashion_mnist.ipynb ├── NSL │ ├── adversarial_regularization │ │ └── adversarial_regularization_mnist.ipynb │ └── graph_regularization │ │ └── document_classification.ipynb ├── RBM │ └── digit_recognition_rbm.ipynb ├── RBN │ └── rbn.ipynb ├── RNN │ ├── drama_generator.ipynb │ ├── performance_rnn.ipynb │ ├── review_classifier.ipynb │ ├── seq2seq_math.ipynb │ └── seq2seq_sorting.ipynb └── SOM │ └── digit_recognition_som.ipynb └── quantum_neural_networks └── quantum_cnn.ipynb /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # artificial-intelligence 2 | AI projects in python, mostly Jupyter notebooks. 
3 | -------------------------------------------------------------------------------- /augmentation/cutout.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "cutout.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyPmpmYZEHOzemSOBObcLtlQ", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "bakLL6CELver" 35 | }, 36 | "source": [ 37 | "#CutOut: Augmentacja danych" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "iYsgVB44GCdI" 44 | }, 45 | "source": [ 46 | "import numpy as np\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "from tensorflow.keras.datasets.fashion_mnist import load_data" 49 | ], 50 | "execution_count": 38, 51 | "outputs": [] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "metadata": { 56 | "colab": { 57 | "base_uri": "https://localhost:8080/" 58 | }, 59 | "id": "OrUklpSiJfxY", 60 | "outputId": "9f66a97c-f192-442a-8807-4e25bd02b760" 61 | }, 62 | "source": [ 63 | "(X_train, y_train), (X_test, y_test) = load_data()" 64 | ], 65 | "execution_count": 39, 66 | "outputs": [ 67 | { 68 | "output_type": "stream", 69 | "text": [ 70 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz\n", 71 | "32768/29515 [=================================] - 0s 0us/step\n", 72 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz\n", 73 | "26427392/26421880 [==============================] 
- 0s 0us/step\n", 74 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz\n", 75 | "8192/5148 [===============================================] - 0s 0us/step\n", 76 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz\n", 77 | "4423680/4422102 [==============================] - 0s 0us/step\n" 78 | ], 79 | "name": "stdout" 80 | } 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "colab": { 87 | "base_uri": "https://localhost:8080/" 88 | }, 89 | "id": "v0ZXubNyKCus", 90 | "outputId": "e6fd5f07-c6e4-48e5-8338-2ddf1c4ed67a" 91 | }, 92 | "source": [ 93 | "X_train[0].shape" 94 | ], 95 | "execution_count": 40, 96 | "outputs": [ 97 | { 98 | "output_type": "execute_result", 99 | "data": { 100 | "text/plain": [ 101 | "(28, 28)" 102 | ] 103 | }, 104 | "metadata": { 105 | "tags": [] 106 | }, 107 | "execution_count": 40 108 | } 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "metadata": { 114 | "id": "nt3RM-sZG-1c" 115 | }, 116 | "source": [ 117 | "def cutout(img, n_holes=1, length=16):\n", 118 | " h = img.shape[0]\n", 119 | " w = img.shape[1]\n", 120 | "\n", 121 | " mask = np.ones((h, w), np.float32)\n", 122 | "\n", 123 | " for _ in range(n_holes):\n", 124 | " y = np.random.randint(h)\n", 125 | " x = np.random.randint(w)\n", 126 | "\n", 127 | " y1 = np.clip(y - length // 2, 0, h)\n", 128 | " y2 = np.clip(y + length // 2, 0, h)\n", 129 | " x1 = np.clip(x - length // 2, 0, w)\n", 130 | " x2 = np.clip(x + length // 2, 0, w)\n", 131 | "\n", 132 | " mask[y1: y2, x1: x2] = 0\n", 133 | "\n", 134 | " img = img * mask\n", 135 | " return img" 136 | ], 137 | "execution_count": 41, 138 | "outputs": [] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "metadata": { 143 | "colab": { 144 | "base_uri": "https://localhost:8080/", 145 | "height": 282 146 | }, 147 | "id": "cqFegHuYJxET", 148 | "outputId": "919338b4-846a-4fff-d51d-5096472861eb" 149 | 
}, 150 | "source": [ 151 | "cut = cutout(X_train[0])\n", 152 | "plt.imshow(cut, cmap='gray')" 153 | ], 154 | "execution_count": 43, 155 | "outputs": [ 156 | { 157 | "output_type": "execute_result", 158 | "data": { 159 | "text/plain": [ 160 | "" 161 | ] 162 | }, 163 | "metadata": { 164 | "tags": [] 165 | }, 166 | "execution_count": 43 167 | }, 168 | { 169 | "output_type": "display_data", 170 | "data": { 171 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOg0lEQVR4nO3db4hd9Z3H8c83k8Q/SYyJo+OYjJtYhBAW18oQFle0i7S4IiQVCcmDJQt1pw/i0kLBFfdBfSRFti37qDBFabp0LYXUNWBZmw0Fd31QHSWbxJHWGBKTYTJjiKKJfzIz+e6DOZEx3vP7Xe899547ft8vGObO+c6595s788m5c3/n/H7m7gLw1bek7gYAdAdhB4Ig7EAQhB0IgrADQSzt5oOZGW/9Ax3m7tZoe1tHdjO7z8z+ZGZHzeyxdu4LQGdZq+PsZtYn6c+SvinplKRXJe109/HEPhzZgQ7rxJF9i6Sj7n7M3S9I+rWkrW3cH4AOaifs6ySdXPD1qWLb55jZiJmNmdlYG48FoE0df4PO3UcljUq8jAfq1M6RfULS0IKv1xfbAPSgdsL+qqRbzWyjmS2XtEPSvmraAlC1ll/Gu/usmT0i6UVJfZKecfc3KusMQKVaHnpr6cH4mx3ouI6cVANg8SDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiCIrk4lje4za3gB1Gfavepx1apVyfrdd99dWnvhhRfaeuzcv62vr6+0Njs729ZjtyvXe0qrPzOO7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBLPLAotM6vyBubk5ZpcFoiPsQBCEHQiCsANBEHYgCMIOBEHYgSC4nh1YZObm5lrar62wm9lxSR9KmpM06+7D7dwfgM6p4sj+t+5+poL7AdBB/M0OBNFu2F3S783sNTMbafQNZjZiZmNmNtbmYwFoQ1sXwpjZOnefMLMbJO2X9E/u/lLi+7kQBuiwjlwI4+4TxedpSc9J2tLO/QHonJbDbmYrzGzVpduSviXpSFWNAahWO+/GD0h6rpj/eqmk/3D3/6qkKwClVq9eXVo7d+5caa3lsLv7MUl/1er+ALqLoTcgCMIOBEHYgSAIOxAEYQeC4BJXYJEZHi6/uPSVV14prXFkB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgWLK5BxSXCbesmz9D9D6WbAaCI+xAEIQdCIKwA0EQdiAIwg4EQdiBILievQcwTo5u4MgOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiCIbNjN7BkzmzazIwu2rTWz/Wb2VvF5TWfbBNCuZo7sv5B032XbHpN0wN1vlXSg+BpAD8uG3d1fknT2ss1bJe0pbu+RtK3ivgBUrNVz4wfcfbK4fVrSQNk3mtmIpJEWHwdARdq+EMbdPTWRpLuPShqVmHASqFOr78ZPmdmgJBWfp6trCUAntBr2fZJ2Fbd3SXq+mnYAdEp23ngze1bSNyT1S5qS9ENJ/ynpN5JulnRC0nZ3v/xNvEb3xcv4BnLzxi9Zkv4/eW
5ursp20ONS67OPj4/r/PnzDX+hsn+zu/vOktK9zbUGoBdwBh0QBGEHgiDsQBCEHQiCsANBMJV0D8gNf/b19SXrDL3FsnRpeWxTw7gc2YEgCDsQBGEHgiDsQBCEHQiCsANBEHYgiOwlrpU+GJe4NpQaN5Wk2dnZLnWCrwJ3bzjYzpEdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4JYVOPsqWt1c9d856Zjzk3nPDMzU1q7ePFicl+gmxhnB4Ij7EAQhB0IgrADQRB2IAjCDgRB2IEgemqcPer86Pfcc0+yvn379mR99+7dVbaDRa7lcXYze8bMps3syIJtT5jZhJkdLD7ur7JZANVr5mX8LyTd12D7T9399uLjd9W2BaBq2bC7+0uSznahFwAd1M4bdI+Y2aHiZf6asm8ysxEzGzOzsTYeC0CbWg37zyR9TdLtkiYl/bjsG9191N2H3X24xccCUIGWwu7uU+4+5+4XJf1c0pZq2wJQtZbCbmaDC778tqQjZd8LoDdkx9nN7FlJ35DUL2lK0g+Lr2+X5JKOS/quu09mH6zGeePXrl2brN90003J+qZNm1re96GHHmr5viXpk08+SdZvvvnmZB2xlI2zp1cnmN9xZ4PNT7fdEYCu4nRZIAjCDgRB2IEgCDsQBGEHgsi+G99Nd955Z7L+5JNPltauv/765L5r1pSe0Sspf/ls6vLb9957L7lvbsnlDz74IFn/9NNPk3WgGRzZgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiCIrk8lnRqvHhtLz1y1bt260lpuLDs3jn7+/PlkPWXp0vTpCh999FHL992M2267raP3j8WFJZuB4Ag7EARhB4Ig7EAQhB0IgrADQRB2IIiuXs/e39+vbdu2ldY3bNiQ3P/o0aOltZUrVyb3zdWvu+66ZD1l2bJlyfrq1auT9ZMnTybrExMTX7onfHUNDQ2V1k6fPl1a48gOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0F0dZx9ZmZGU1NTpfV33nknuf8111xTWssta5wby86Nwy9fvry0lhtHP3v2bLJ+/PjxZD3X2+HDh0truecl1/v69euT9UOHDiXrGzduLK3lltG+cOFCsv7+++8n6zMzM6W13Fz8ufkPVq1alaw3sRR6ae2qq65K7rtjx47SWuo5yR7ZzWzIzP5gZuNm9oaZfa/YvtbM9pvZW8Xn9CoMAGrVzMv4WUk/cPfNkv5a0m4z2yzpMUkH3P1WSQeKrwH0qGzY3X3S3V8vbn8o6U1J6yRtlbSn+LY9ksrPgwVQuy/1Bp2ZbZD0dUl/lDTg7pNF6bSkgZJ9RsxszMzGcn+DAeicpsNuZisl7ZX0fXf/3EqEPv9uRMN3JNx91N2H3X049SYXgM5qKuxmtkzzQf+Vu/+22DxlZoNFfVDSdGdaBFCF7NCbzY8RPC3pTXf/yYLSPkm7JP2o+Px87r4uXLiQHALLDVek9l2xYkVy3/7+/mQ9N4xz5syZ0tq7776b3Dc31fQVV1yRrOcuoU0N1aSGKyVpyZL0//epf7ckbd68OVlPTdGdGw7NDVnmnrdU76lhuSrqV199dbJ+4403ltZyv4t33HFHaS01DNvMOPvfSPp7SYfN7GCx7XHNh/w3ZvYdSSckbW/ivgDUJBt2d/9fSWVnANxbbTsAOoXTZYEgCDsQBGEHgiDsQBCEHQiiq5e4fvzxxzp48GBpfe/evcn9H3744dJabrrlY8eOZXtLSV3SmDsz8Morr0zWc/unlrmW0r3nlrLOnduQW256cnIyWb948WJpLddb7vyC3M8sdWlw7tTt3Bh/O5fX5uq33HJLct/UdNGp55QjOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EYblx1kofzKytB3vggQdKa48++mhy34GBhrNmfWZ6Oj33RmpcNTftcG6cPHdddu56+NT9p6YslvLj7Lmx7lw9dQ5Bbt9c7zmp/VNTmjcjd25E6vwCSRocHCytpc5Fka
QHH3wwWXf3hv9wjuxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EETXx9lT85Tnxibbce+96Ylwn3rqqWT9hhtuKK1de+21yX1zc7PnxuFz4+ypcf7cWHVuvDn3+3Hq1KlkPdVbak55Kf+85KR6zy3ZnLtWPvczffHFF5P18fHx0trLL7+c3DeHcXYgOMIOBEHYgSAIOxAEYQeCIOxAEIQdCCI7zm5mQ5J+KWlAkksadfd/M7MnJP2jpEuLkz/u7r/L3Ff3BvW7aNOmTcl6aoxeys9RPjQ0lKyfOHGitJYbT3777beTdSw+ZePszSwSMSvpB+7+upmtkvSame0vaj9193+tqkkAndPM+uyTkiaL2x+a2ZuS1nW6MQDV+lJ/s5vZBklfl/THYtMjZnbIzJ4xszUl+4yY2ZiZjbXVKYC2NB12M1spaa+k77v7B5J+Julrkm7X/JH/x432c/dRdx929+EK+gXQoqbCbmbLNB/0X7n7byXJ3afcfc7dL0r6uaQtnWsTQLuyYbf5y6aelvSmu/9kwfaF02N+W9KR6tsDUJVmht7ukvQ/kg5LunQN6uOSdmr+JbxLOi7pu8Wbean7+koOvQG9pGzobVHNGw8gj+vZgeAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQTQzu2yVzkhaOO9xf7GtF/Vqb73al0Rvraqyt78oK3T1evYvPLjZWK/OTdervfVqXxK9tapbvfEyHgiCsANB1B320ZofP6VXe+vVviR6a1VXeqv1b3YA3VP3kR1AlxB2IIhawm5m95nZn8zsqJk9VkcPZczsuJkdNrODda9PV6yhN21mRxZsW2tm+83sreJzwzX2aurtCTObKJ67g2Z2f029DZnZH8xs3MzeMLPvFdtrfe4SfXXleev63+xm1ifpz5K+KemUpFcl7XT38a42UsLMjksadvfaT8Aws7slnZP0S3f/y2LbU5LOuvuPiv8o17j7P/dIb09IOlf3Mt7FakWDC5cZl7RN0j+oxucu0dd2deF5q+PIvkXSUXc/5u4XJP1a0tYa+uh57v6SpLOXbd4qaU9xe4/mf1m6rqS3nuDuk+7+enH7Q0mXlhmv9blL9NUVdYR9naSTC74+pd5a790l/d7MXjOzkbqbaWBgwTJbpyUN1NlMA9llvLvpsmXGe+a5a2X583bxBt0X3eXud0j6O0m7i5erPcnn/wbrpbHTppbx7pYGy4x/ps7nrtXlz9tVR9gnJA0t+Hp9sa0nuPtE8Xla0nPqvaWopy6toFt8nq65n8/00jLejZYZVw88d3Uuf15H2F+VdKuZbTSz5ZJ2SNpXQx9fYGYrijdOZGYrJH1LvbcU9T5Ju4rbuyQ9X2Mvn9Mry3iXLTOump+72pc/d/euf0i6X/PvyL8t6V/q6KGkr1sk/V/x8UbdvUl6VvMv62Y0/97GdyRdJ+mApLck/bektT3U279rfmnvQ5oP1mBNvd2l+ZfohyQdLD7ur/u5S/TVleeN02WBIHiDDgiCsANBEHYgCMIOBEHYgSAIOxAEYQeC+H/u3vLKCnctgQAAAABJRU5ErkJggg==\n", 172 | "text/plain": [ 173 | "
" 174 | ] 175 | }, 176 | "metadata": { 177 | "tags": [], 178 | "needs_background": "light" 179 | } 180 | } 181 | ] 182 | } 183 | ] 184 | } -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/alpha_tensor/alpha_tensor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyNqvRd4SEoe9Ndz/CSnUc9m", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "#Alpha Tensor: Exploring factorizations" 33 | ], 34 | "metadata": { 35 | "id": "SKf6xGKpHovo" 36 | } 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 5, 41 | "metadata": { 42 | "colab": { 43 | "base_uri": "https://localhost:8080/", 44 | "height": 108 45 | }, 46 | "id": "Amd-ax5D8h0b", 47 | "outputId": "bf784504-6171-4f80-c1f4-22a7ad9604aa" 48 | }, 49 | "outputs": [ 50 | { 51 | "output_type": "display_data", 52 | "data": { 53 | "text/plain": [ 54 | "" 55 | ], 56 | "text/html": [ 57 | "\n", 58 | " \n", 60 | " \n", 61 | " Upload widget is only available when the cell has been executed in the\n", 62 | " current browser session. 
Please rerun this cell to enable.\n", 63 | " \n", 64 | " " 240 | ] 241 | }, 242 | "metadata": {} 243 | }, 244 | { 245 | "output_type": "stream", 246 | "name": "stdout", 247 | "text": [ 248 | "Saving factorizations_r.npz to factorizations_r (1).npz\n", 249 | "Saving factorizations_f2.npz to factorizations_f2 (1).npz\n" 250 | ] 251 | } 252 | ], 253 | "source": [ 254 | "import numpy as np\n", 255 | "from google.colab import files\n", 256 | "\n", 257 | "uploaded = files.upload()" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "source": [ 263 | "filename = list(uploaded.keys())[0]\n", 264 | "with open(filename, 'rb') as f:\n", 265 | " factorizations = dict(np.load(f, allow_pickle=True))" 266 | ], 267 | "metadata": { 268 | "id": "BDGFR-ijOUma" 269 | }, 270 | "execution_count": 6, 271 | "outputs": [] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "source": [ 276 | "for key in factorizations:\n", 277 | " u, v, w = factorizations[key]\n", 278 | " rank = u.shape[-1]\n", 279 | " assert rank == v.shape[-1] and rank == w.shape[-1]" 280 | ], 281 | "metadata": { 282 | "id": "dhFXUDXZH2Mx" 283 | }, 284 | "execution_count": 9, 285 | "outputs": [] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "source": [ 290 | "def get_mamu_tensor_rectangular(a: int, b: int, c: int) -> np.ndarray:\n", 291 | " result = np.full((a*b, b*c, c*a), 0, dtype=np.int32)\n", 292 | " for i in range(a):\n", 293 | " for j in range(b):\n", 294 | " for k in range(c):\n", 295 | " result[i * b + j][j * c + k][k * a + i] = 1\n", 296 | " return result\n", 297 | "\n", 298 | "tensor = get_mamu_tensor_rectangular(3, 4, 5)\n", 299 | "u, v, w = factorizations['3,4,5']\n", 300 | "reconstruction = np.einsum('ir,jr,kr->ijk', u, v, w)\n", 301 | "if np.array_equal(tensor, reconstruction):\n", 302 | " print('Factorization is correct in R (standard arithmetic).')\n", 303 | "elif np.array_equal(tensor, np.mod(reconstruction, 2)):\n", 304 | " print('Factorization is correct in F2 (modular arithmetic).')\n", 305 | 
"else:\n", 306 | " print('Factorization is incorrect.')" 307 | ], 308 | "metadata": { 309 | "colab": { 310 | "base_uri": "https://localhost:8080/" 311 | }, 312 | "id": "s_bxqSfgISCF", 313 | "outputId": "24988cb5-d8f9-4689-87d3-f9f1af1faf72" 314 | }, 315 | "execution_count": 10, 316 | "outputs": [ 317 | { 318 | "output_type": "stream", 319 | "name": "stdout", 320 | "text": [ 321 | "Factorization is correct in R (standard arithmetic).\n" 322 | ] 323 | } 324 | ] 325 | } 326 | ] 327 | } -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/alpha_tensor/factorizations_f2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PsorTheDoctor/artificial-intelligence/82516ac57eb13f14e8214633a0960bea0cd9e0fb/machine_learning/reinforcement_learning/alpha_tensor/factorizations_f2.npz -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/alpha_tensor/factorizations_r.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PsorTheDoctor/artificial-intelligence/82516ac57eb13f14e8214633a0960bea0cd9e0fb/machine_learning/reinforcement_learning/alpha_tensor/factorizations_r.npz -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/card_game/card_game_env.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "En02FxRwzPCD" 17 | }, 18 | "source": [ 19 | "# Card Game Env: Własne środowisko Pythona" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | 
"metadata": { 26 | "id": "hjV1uDG7q9_C", 27 | "colab": { 28 | "base_uri": "https://localhost:8080/", 29 | "height": 35.0 30 | }, 31 | "outputId": "cb653b22-0ac7-4151-d196-6f8335f6897c" 32 | }, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "TensorFlow 2.x selected.\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "try:\n", 44 | " %tensorflow_version 2.x\n", 45 | "except:\n", 46 | " pass" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": { 53 | "id": "ScyzGCBqq-zF" 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "!pip install tf-agents\n", 58 | "!pip install 'gym==0.10.11'" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "id": "d2_UaUwrqnev" 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "from __future__ import absolute_import, division, print_function\n", 70 | "\n", 71 | "import abc\n", 72 | "import tensorflow as tf\n", 73 | "import numpy as np\n", 74 | "\n", 75 | "from tf_agents.environments import py_environment\n", 76 | "from tf_agents.environments import tf_environment\n", 77 | "from tf_agents.environments import tf_py_environment\n", 78 | "from tf_agents.environments import utils\n", 79 | "from tf_agents.specs import array_spec\n", 80 | "from tf_agents.environments import wrappers\n", 81 | "from tf_agents.environments import suite_gym\n", 82 | "from tf_agents.trajectories import time_step as ts\n", 83 | "\n", 84 | "tf.compat.v1.enable_v2_behavior()" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": { 90 | "id": "vGE3CN6FbTwB" 91 | }, 92 | "source": [ 93 | "### Tworzenie własnego środowiska Pythona\n", 94 | "Załóżmy, że chcemy wytrenować agenta do poniższej gry karcianej (inspirowanej blackjackiem):\n", 95 | "\n", 96 | "1. W grze używa się nieskończonej talii kart ponumerowanych 1...10 (karty są losowane ze zwracaniem).\n", 97 | "2.
W każdej turze agent musi zrobić jedną z 2 rzeczy: dobrać losową kartę albo zakończyć bieżącą rundę.\n", 98 | "3. Celem gry jest uzyskanie na końcu każdej rundy sumy kart jak najbliższej 21, ale bez jej przekroczenia.\n", 99 | "\n", 100 | "Środowisko przedstawiające grę mogłoby wyglądać tak:\n", 101 | "\n", 102 | "1. Akcje: Mamy 2 akcje. Akcja 0: weź nową kartę i Akcja 1: zakończ bieżącą rundę.\n", 103 | "2. Obserwacje: Suma kart w bieżącej rundzie.\n", 104 | "3. Nagroda: `sum_of_cards - 21 if sum_of_cards <= 21, else -21`\n", 105 | "\n", 106 | "\n", 107 | "\n", 108 | "\n", 109 | "\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "id": "B7xMqltja9Q-" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "class CardGameEnv(py_environment.PyEnvironment):\n", 121 | "\n", 122 | " def __init__(self):\n", 123 | " self._action_spec = array_spec.BoundedArraySpec(\n", 124 | " shape=(), dtype=np.int32, minimum=0, maximum=1, name='action')\n", 125 | " self._observation_spec = array_spec.BoundedArraySpec(\n", 126 | " shape=(1,), dtype=np.int32, minimum=0, name='observation')\n", 127 | " self._state = 0\n", 128 | " self._episode_ended = False\n", 129 | "\n", 130 | " def action_spec(self):\n", 131 | " return self._action_spec\n", 132 | "\n", 133 | " def observation_spec(self):\n", 134 | " return self._observation_spec\n", 135 | "\n", 136 | " def _reset(self):\n", 137 | " self._state = 0\n", 138 | " self._episode_ended = False\n", 139 | " return ts.restart(np.array([self._state], dtype=np.int32))\n", 140 | "\n", 141 | " def _step(self, action):\n", 142 | "\n", 143 | " if self._episode_ended:\n", 144 | " return self.reset()\n", 145 | "\n", 146 | " if action == 1:\n", 147 | " self._episode_ended = True\n", 148 | " elif action == 0:\n", 149 | " new_card = np.random.randint(1, 11)\n", 150 | " self._state += new_card\n", 151 | " else:\n", 152 | " raise ValueError('`akcja` powinna być 0 lub 1.')\n", 153 | "\n", 154 | " if
self._episode_ended or self._state >= 21:\n", 155 | " reward = self._state - 21 if self._state <= 21 else -21\n", 156 | " return ts.termination(np.array([self._state], dtype=np.int32), reward)\n", 157 | " else:\n", 158 | " return ts.transition(\n", 159 | " np.array([self._state], dtype=np.int32), reward=0.0, discount=1.0)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | "id": "i6dCEYl_r1Oa" 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "environment = CardGameEnv()\n", 171 | "utils.validate_py_environment(environment, episodes=5)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": { 178 | "id": "SVeaTtCvuhqc", 179 | "colab": { 180 | "base_uri": "https://localhost:8080/", 181 | "height": 107.0 182 | }, 183 | "outputId": "5e61a3e4-4613-438f-ef57-88a0ce049477" 184 | }, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([0], dtype=int32))\n", 191 | "TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([4], dtype=int32))\n", 192 | "TimeStep(step_type=array(2, dtype=int32), reward=array(-17., dtype=float32), discount=array(0., dtype=float32), observation=array([4], dtype=int32))\n", 193 | "Final Reward = -17.0\n" 194 | ] 195 | } 196 | ], 197 | "source": [ 198 | "get_new_card_action = np.array(0, dtype=np.int32)\n", 199 | "end_round_action = np.array(1, dtype=np.int32)\n", 200 | "\n", 201 | "environment = CardGameEnv()\n", 202 | "time_step = environment.reset()\n", 203 | "print(time_step)\n", 204 | "cumulative_reward = time_step.reward\n", 205 | "\n", 206 | "for _ in range(1):\n", 207 | " time_step = environment.step(get_new_card_action)\n", 208 | " print(time_step)\n", 209 | " cumulative_reward += 
time_step.reward\n", 210 | "\n", 211 | "time_step = environment.step(end_round_action)\n", 212 | "print(time_step)\n", 213 | "cumulative_reward += time_step.reward\n", 214 | "print('Final Reward = ', cumulative_reward)" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "colab": { 220 | "name": "card_game_env.ipynb", 221 | "provenance": [], 222 | "collapsed_sections": [], 223 | "authorship_tag": "ABX9TyO0WRwBcg80OLisWET/RExB", 224 | "include_colab_link": true 225 | }, 226 | "kernelspec": { 227 | "name": "python3", 228 | "display_name": "Python 3" 229 | } 230 | }, 231 | "nbformat": 4, 232 | "nbformat_minor": 0 233 | } 234 | -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/cartpole/cart_pole.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "id": "wpSBFRJ1DBgM" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "!pip install gym\n", 22 | "!apt-get install python-opengl -y\n", 23 | "!apt install xvfb -y" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": { 30 | "id": "f_o3geGdFBNN" 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "!pip install pyvirtualdisplay\n", 35 | "!pip install piglet" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "id": "M4rQl_zNFM6e", 43 | "colab": { 44 | "base_uri": "https://localhost:8080/", 45 | "height": 72.0 46 | }, 47 | "outputId": "1d3e047e-5573-4e80-f388-847b54386803" 48 | }, 49 | "outputs": [ 50 | { 51 | "name": "stderr", 52 | "output_type": "stream", 53 | "text": [ 54 | "xdpyinfo was not found, X start can not be checked! 
Please install xdpyinfo!\n" 55 | ] 56 | }, 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "" 61 | ] 62 | }, 63 | "execution_count": 5, 64 | "metadata": { 65 | "tags": [] 66 | }, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "from pyvirtualdisplay import Display\n", 72 | "\n", 73 | "display = Display(visible=0, size=(1400, 900))\n", 74 | "display.start()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": { 81 | "id": "RfCA1XwgF7R6" 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "# Ten kod tworzy wirtualny ekran, aby toczyć grę na nim\n", 86 | "# Jeśli uruchamiasz lokalnie, zignoruj to\n", 87 | "import os\n", 88 | "if type(os.environ.get('DISPLAY')) is not str or len(os.environ.get('DISPLAY')) == 0:\n", 89 | " !bash ../xvfb start\n", 90 | " %env DISPLAY=:1" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "id": "dVCvYQcQGs1T", 98 | "colab": { 99 | "base_uri": "https://localhost:8080/", 100 | "height": 64.0 101 | }, 102 | "outputId": "da6aea81-5d5e-4575-aae6-f9b40d883dda" 103 | }, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/html": [ 108 | "

\n", 109 | "The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.
\n", 110 | "We recommend you upgrade now \n", 111 | "or ensure your notebook will continue to use TensorFlow 1.x via the %tensorflow_version 1.x magic:\n", 112 | "more info.

\n" 113 | ], 114 | "text/plain": [ 115 | "" 116 | ] 117 | }, 118 | "metadata": { 119 | "tags": [] 120 | }, 121 | "output_type": "execute_result" 122 | } 123 | ], 124 | "source": [ 125 | "import gym\n", 126 | "from gym import logger as gymlogger\n", 127 | "from gym.wrappers import Monitor\n", 128 | "gymlogger.set_level(40)\n", 129 | "import tensorflow as tf\n", 130 | "import numpy as np\n", 131 | "import random\n", 132 | "import matplotlib\n", 133 | "import matplotlib.pyplot as plt\n", 134 | "%matplotlib inline\n", 135 | "import math\n", 136 | "import glob\n", 137 | "import io\n", 138 | "import base64\n", 139 | "from IPython.display import HTML\n", 140 | "\n", 141 | "from IPython import display as ipythondisplay" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": { 148 | "id": "nTIYTvAQIJtr" 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "\"\"\"\n", 153 | "Użyteczne funkcje umożliwiające nagranie środowiska gym i wyświetlenie go\n", 154 | "Aby zezwolić na wideo wystarczy \"env = wrap_env\"\n", 155 | "\"\"\"\n", 156 | "\n", 157 | "def show_video():\n", 158 | " mp4list = glob.glob('video/*.mp4')\n", 159 | " if len(mp4list) > 0:\n", 160 | " mp4 = mp4list[0]\n", 161 | " video = io.open(mp4, 'r+b').read()\n", 162 | " encoded = base64.b64encode(video)\n", 163 | " ipythondisplay.display(HTML(data=''''''.format(encoded.decode('ascii'))))\n", 166 | " else:\n", 167 | " print('Could not find video')\n", 168 | "\n", 169 | "def wrap_env(env):\n", 170 | " env = Monitor(env, './video', force=True)\n", 171 | " return env" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": { 177 | "id": "azBUgbnCUgv6" 178 | }, 179 | "source": [ 180 | "## CartPole" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": { 187 | "id": "cijHHKS0UmJn", 188 | "colab": { 189 | "base_uri": "https://localhost:8080/", 190 | "height": 139.0 191 | }, 192 | "outputId": 
"6caa9eb5-457d-42d1-88bd-d87680c7fa20" 193 | }, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "Observation space: Box(4,)\n", 200 | "Action space: Discrete(2)\n", 201 | "Initial observation: [0.01463859 0.04037087 0.04288382 0.03510227]\n", 202 | "Next observation: [ 0.01544601 0.23485245 0.04358586 -0.24374792]\n", 203 | "Reward: 1.0\n", 204 | "Done: False\n", 205 | "Info: {}\n" 206 | ] 207 | } 208 | ], 209 | "source": [ 210 | "import gym \n", 211 | "env = gym.make('CartPole-v0')\n", 212 | "env = wrap_env(env)\n", 213 | "\n", 214 | "print('Observation space: ', env.observation_space)\n", 215 | "print('Action space: ', env.action_space)\n", 216 | "\n", 217 | "obs = env.reset()\n", 218 | "\n", 219 | "print('Initial observation: ', obs)\n", 220 | "\n", 221 | "action = env.action_space.sample() # podujmuje losową akcję\n", 222 | "\n", 223 | "obs, r, done, info = env.step(action)\n", 224 | "print('Next observation: ', obs)\n", 225 | "print('Reward: ', r)\n", 226 | "print('Done: ', done)\n", 227 | "print('Info: ', info)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": { 233 | "id": "4MEuGxqqWGzc" 234 | }, 235 | "source": [ 236 | "### Wyświetlenie wideo" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": { 243 | "id": "grah6EfMWJmU", 244 | "colab": { 245 | "base_uri": "https://localhost:8080/", 246 | "height": 976.0 247 | }, 248 | "outputId": "7e22d1b6-c795-4c37-bb4e-508baac42e35" 249 | }, 250 | "outputs": [ 251 | { 252 | "name": "stdout", 253 | "output_type": "stream", 254 | "text": [ 255 | "1.0\n", 256 | "1.0\n", 257 | "1.0\n", 258 | "1.0\n", 259 | "1.0\n", 260 | "1.0\n", 261 | "1.0\n", 262 | "1.0\n", 263 | "1.0\n", 264 | "1.0\n", 265 | "1.0\n", 266 | "1.0\n", 267 | "1.0\n", 268 | "1.0\n", 269 | "1.0\n", 270 | "1.0\n", 271 | "1.0\n", 272 | "1.0\n", 273 | "1.0\n", 274 | "1.0\n", 275 | "1.0\n", 276 | "1.0\n", 277 | "1.0\n", 278 | "1.0\n", 
279 | "1.0\n", 280 | "1.0\n", 281 | "1.0\n", 282 | "1.0\n", 283 | "1.0\n", 284 | "1.0\n", 285 | "1.0\n", 286 | "1.0\n" 287 | ] 288 | }, 289 | { 290 | "data": { 291 | "text/html": [ 292 | "" 295 | ], 296 | "text/plain": [ 297 | "" 298 | ] 299 | }, 300 | "metadata": { 301 | "tags": [] 302 | }, 303 | "output_type": "execute_result" 304 | } 305 | ], 306 | "source": [ 307 | "'''CartPole z użyciem losowej akcji'''\n", 308 | "import gym\n", 309 | "env = gym.make('CartPole-v0')\n", 310 | "env = wrap_env(env)\n", 311 | "\n", 312 | "observation = env.reset()\n", 313 | "\n", 314 | "while True:\n", 315 | " env.render()\n", 316 | "\n", 317 | " action = env.action_space.sample() # podujmuje losową akcję\n", 318 | " observation, reward, done, info = env.step(action)\n", 319 | " print(reward)\n", 320 | "\n", 321 | " if done:\n", 322 | " break;\n", 323 | "\n", 324 | "env.close()\n", 325 | "show_video()" 326 | ] 327 | } 328 | ], 329 | "metadata": { 330 | "colab": { 331 | "name": "cart_pole.ipynb", 332 | "provenance": [], 333 | "authorship_tag": "ABX9TyNNnZD5dGWsLcj8Q3zwW7oa", 334 | "include_colab_link": true 335 | }, 336 | "kernelspec": { 337 | "name": "python3", 338 | "display_name": "Python 3" 339 | }, 340 | "accelerator": "GPU" 341 | }, 342 | "nbformat": 4, 343 | "nbformat_minor": 0 344 | } 345 | -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/frozen_lake/frozen_lake.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "zp6ERkCfJaDf", 17 | "colab_type": "text" 18 | }, 19 | "source": [ 20 | "# Q-Learning\n", 21 | "### Podstawowe funkcje" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 0, 27 | "metadata": { 28 
| "id": "BCFZaqgSI37U", 29 | "colab_type": "code", 30 | "colab": {} 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "import gym\n", 35 | "\n", 36 | "env = gym.make('FrozenLake-v0') # będziemy używać środowiska FrozenLake" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": { 43 | "id": "5_iAxJjVJ5D8", 44 | "colab_type": "code", 45 | "colab": { 46 | "base_uri": "https://localhost:8080/", 47 | "height": 52.0 48 | }, 49 | "outputId": "a3a76324-911f-4dac-c0d4-c73b9bcae934" 50 | }, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "16\n", 57 | "4\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "print(env.observation_space.n) # zwraca liczbę stanów\n", 63 | "print(env.action_space.n) # zwraca liczbę akcji" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": { 70 | "id": "rSI2CxbMKSWE", 71 | "colab_type": "code", 72 | "colab": { 73 | "base_uri": "https://localhost:8080/", 74 | "height": 35.0 75 | }, 76 | "outputId": "23000dd3-2562-42bb-9c83-eeef874df311" 77 | }, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/plain": [ 82 | "0" 83 | ] 84 | }, 85 | "execution_count": 4, 86 | "metadata": { 87 | "tags": [] 88 | }, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "env.reset() # resetuje środowisko do stanu domyślnego" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 9, 99 | "metadata": { 100 | "id": "qtnJV9NJKc_L", 101 | "colab_type": "code", 102 | "colab": { 103 | "base_uri": "https://localhost:8080/", 104 | "height": 35.0 105 | }, 106 | "outputId": "0b774ece-a43d-4209-8b3a-506d3c916fe1" 107 | }, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "2\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "action = env.action_space.sample() # zwraca losową akcję\n", 119 | "print(action)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | 
"execution_count": 0, 125 | "metadata": { 126 | "id": "pHe8tnxtLhb6", 127 | "colab_type": "code", 128 | "colab": {} 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "new_state, reward, done, info = env.step(action) # podejmuje akcję" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 11, 138 | "metadata": { 139 | "id": "iqjEeQASKwr0", 140 | "colab_type": "code", 141 | "colab": { 142 | "base_uri": "https://localhost:8080/", 143 | "height": 104.0 144 | }, 145 | "outputId": "32035c87-4788-4756-daa8-4539dd6b656d" 146 | }, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | " (Up)\n", 153 | "S\u001b[41mF\u001b[0mFF\n", 154 | "FHFH\n", 155 | "FFFH\n", 156 | "HFFG\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "env.render() # renderuje GUI środowiska" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": { 167 | "id": "X0E9UWKaMScI", 168 | "colab_type": "text" 169 | }, 170 | "source": [ 171 | "### Środowisko FrozenLake\n", 172 | "\n", 173 | "`FrozenLake-v0` to jedno z najprostszych środowisk w OpenAI Gym. Celem jest nawigowanie agenta po zamarzniętym jeziorze bez wpadnięcia do wody. Jest tu:\n", 174 | "\n", 175 | "* 16 stanów (jeden dla każdego pola)\n", 176 | "* 4 możliwe akcje (LEFT, RIGHT, DOWN, UP)\n", 177 | "* 4 różne typy pól (F: frozen, H: hole, S: start, G: goal)\n", 178 | "\n", 179 | "### Budowa Q-tabeli\n", 180 | "\n", 181 | "Pierwszą rzeczą, jakiej potrzebujemy, jest zbudowanie pustej Q-tabeli, której możemy użyć do przechowywania i aktualizowania naszych wartości."
182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 0, 187 | "metadata": { 188 | "id": "HziaZ5EKPFUD", 189 | "colab_type": "code", 190 | "colab": {} 191 | }, 192 | "outputs": [], 193 | "source": [ 194 | "# import gym\n", 195 | "import numpy as np\n", 196 | "import time\n", 197 | "\n", 198 | "env = gym.make('FrozenLake-v0')\n", 199 | "STATES = env.observation_space.n\n", 200 | "ACTIONS = env.action_space.n" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 14, 206 | "metadata": { 207 | "id": "rLe4nK01QB2C", 208 | "colab_type": "code", 209 | "colab": { 210 | "base_uri": "https://localhost:8080/", 211 | "height": 295.0 212 | }, 213 | "outputId": "22166874-9c8b-46e0-b6dd-1b5a53526e4a" 214 | }, 215 | "outputs": [ 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "array([[0., 0., 0., 0.],\n", 220 | " [0., 0., 0., 0.],\n", 221 | " [0., 0., 0., 0.],\n", 222 | " [0., 0., 0., 0.],\n", 223 | " [0., 0., 0., 0.],\n", 224 | " [0., 0., 0., 0.],\n", 225 | " [0., 0., 0., 0.],\n", 226 | " [0., 0., 0., 0.],\n", 227 | " [0., 0., 0., 0.],\n", 228 | " [0., 0., 0., 0.],\n", 229 | " [0., 0., 0., 0.],\n", 230 | " [0., 0., 0., 0.],\n", 231 | " [0., 0., 0., 0.],\n", 232 | " [0., 0., 0., 0.],\n", 233 | " [0., 0., 0., 0.],\n", 234 | " [0., 0., 0., 0.]])" 235 | ] 236 | }, 237 | "execution_count": 14, 238 | "metadata": { 239 | "tags": [] 240 | }, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "Q = np.zeros((STATES, ACTIONS)) # stworzenie macierzy zer\n", 246 | "Q" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": { 252 | "id": "0iq-HT3gQcHh", 253 | "colab_type": "text" 254 | }, 255 | "source": [ 256 | "### Stałe\n", 257 | "Musimy zdefiniować stałe, które będą użyte do aktualizowania Q-tabeli i powiedzą agentowi, kiedy przerwać trening."
258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 0, 263 | "metadata": { 264 | "id": "H2SDz6x6QWwZ", 265 | "colab_type": "code", 266 | "colab": {} 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "EPISODES = 2000 # ile razy odpalić środowisko od początku\n", 271 | "MAX_STEPS = 100 # maksymalna ilość kroków dozwolonych na każde uruchomienie środowiska\n", 272 | "\n", 273 | "LEARNING_RATE = 0.81 # współczynnik uczenia\n", 274 | "GAMMA = 0.96" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "metadata": { 280 | "id": "X-9dvyenSKsz", 281 | "colab_type": "text" 282 | }, 283 | "source": [ 284 | "### Podjęcie akcji\n", 285 | "Możemy podjąć akcję używając jednej z dwóch metod:\n", 286 | "\n", 287 | "1. Wybierając losowo dozwoloną akcję\n", 288 | "2. Używając obecnej Q-tabeli do znalezienia najlepszej akcji" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 0, 294 | "metadata": { 295 | "id": "82x4y9KSQUy4", 296 | "colab_type": "code", 297 | "colab": {} 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "epsilon = 0.9 # zaczynamy z 90% szans na podjęcie losowej akcji\n", 302 | "\n", 303 | "# kod do podjęcia akcji\n", 304 | "if np.random.uniform(0, 1) < epsilon: # sprawdza czy losowo wybrana wartość jest mniejsza niż epsilon\n", 305 | " action = env.action_space.sample() # podejmuje losową akcję\n", 306 | "else:\n", 307 | " action = np.argmax(Q[state, :]) # używa Q-tabeli do podjęcia najlepszej akcji bazując na obecnych wartościach" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": { 313 | "id": "O5KZgya-Ugny", 314 | "colab_type": "text" 315 | }, 316 | "source": [ 317 | "### Aktualizacja wartości Q" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 0, 323 | "metadata": { 324 | "id": "Xv6-B2sRUoMI", 325 | "colab_type": "code", 326 | "colab": {} 327 | }, 328 | "outputs": [], 329 | "source": [ 330 | "#Q[state, action] = Q[state, action] + LEARNING_RATE * 
(reward + GAMMA * np.max(Q[new_state, :]) - Q[state, action])" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": { 336 | "id": "gOQV0406WE7P", 337 | "colab_type": "text" 338 | }, 339 | "source": [ 340 | "### Gotowy program złożony w całość" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 0, 346 | "metadata": { 347 | "id": "YieDu-0UWl2e", 348 | "colab_type": "code", 349 | "colab": {} 350 | }, 351 | "outputs": [], 352 | "source": [ 353 | "import gym\n", 354 | "import numpy as np\n", 355 | "import time\n", 356 | "\n", 357 | "env = gym.make('FrozenLake-v0')\n", 358 | "STATES = env.observation_space.n\n", 359 | "ACTIONS = env.action_space.n\n", 360 | "\n", 361 | "Q = np.zeros((STATES, ACTIONS)) \n", 362 | "\n", 363 | "EPISODES = 1500 # ile razy odpalić środowisko od początku\n", 364 | "MAX_STEPS = 100 # maksymalna ilość kroków dozwolonych na każde uruchomienie środowiska\n", 365 | "\n", 366 | "LEARNING_RATE = 0.81 # współczynnik uczenia\n", 367 | "GAMMA = 0.96\n", 368 | "\n", 369 | "RENDER = False # jeśli chcesz zobaczyć trening ustaw na True\n", 370 | "\n", 371 | "epsilon = 0.9" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 24, 377 | "metadata": { 378 | "id": "LjQ6BcGtXJb7", 379 | "colab_type": "code", 380 | "colab": { 381 | "base_uri": "https://localhost:8080/", 382 | "height": 312.0 383 | }, 384 | "outputId": "4e8203a5-d064-4eb6-ca10-517dc865d395" 385 | }, 386 | "outputs": [ 387 | { 388 | "name": "stdout", 389 | "output_type": "stream", 390 | "text": [ 391 | "[[2.39426868e-01 1.62526398e-02 1.54562519e-02 1.59202830e-02]\n", 392 | " [1.98547854e-03 6.78832722e-03 2.04410091e-03 2.05880712e-01]\n", 393 | " [1.17939011e-01 6.88565988e-03 5.78809926e-03 6.92566173e-03]\n", 394 | " [6.13191495e-03 3.05238706e-03 2.96899524e-03 6.79343079e-03]\n", 395 | " [2.33621805e-01 1.00620775e-02 7.89787111e-03 1.35511328e-02]\n", 396 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 
0.00000000e+00]\n", 397 | " [7.19855689e-02 1.73767651e-04 1.74687284e-04 1.28973108e-04]\n", 398 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", 399 | " [3.44127917e-03 4.78581497e-03 3.20147771e-03 3.35419577e-01]\n", 400 | " [3.29275604e-03 8.13401945e-01 1.53087650e-02 4.71493137e-03]\n", 401 | " [1.85896330e-01 2.24225014e-03 1.32615571e-03 2.17757207e-03]\n", 402 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", 403 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", 404 | " [6.86168854e-02 6.01188267e-02 6.41986796e-01 4.52771138e-03]\n", 405 | " [2.39216419e-01 4.87311938e-01 1.45339891e-01 2.17005939e-01]\n", 406 | " [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", 407 | "Average reward: 0.2986666666666667:\n" 408 | ] 409 | } 410 | ], 411 | "source": [ 412 | "rewards = [] \n", 413 | "for episode in range(EPISODES):\n", 414 | "\n", 415 | " state = env.reset()\n", 416 | " for _ in range(MAX_STEPS):\n", 417 | "\n", 418 | " if RENDER:\n", 419 | " env.render()\n", 420 | "\n", 421 | " if np.random.uniform(0, 1) < epsilon:\n", 422 | " action = env.action_space.sample()\n", 423 | " else:\n", 424 | " action = np.argmax(Q[state, :])\n", 425 | "\n", 426 | " next_state, reward, done, _ = env.step(action)\n", 427 | "\n", 428 | " Q[state, action] = Q[state, action] + LEARNING_RATE * (reward + GAMMA * np.max(Q[next_state, :]) - Q[state, action])\n", 429 | "\n", 430 | " state = next_state\n", 431 | "\n", 432 | " if done:\n", 433 | " rewards.append(reward)\n", 434 | " epsilon -= 0.001\n", 435 | " break # reached goal\n", 436 | "\n", 437 | "print(Q)\n", 438 | "print(f'Average reward: {sum(rewards)/len(rewards)}:')\n", 439 | "# teraz możemy zobaczyć nasze wartośći Q!" 
440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 26, 445 | "metadata": { 446 | "id": "eVZ8HfGQZy1y", 447 | "colab_type": "code", 448 | "colab": { 449 | "base_uri": "https://localhost:8080/", 450 | "height": 279.0 451 | }, 452 | "outputId": "1e767d37-5342-4f58-d391-8b0e98af65ed" 453 | }, 454 | "outputs": [ 455 | { 456 | "data": { 457 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXgV9dn/8fdNWMKasIR938WFxYh7\nVVCK1UpbtYLaRdtaqxSr/mqxtbbS9rFWq9Vqn7rbRaSKqNjiVrBafdSC7BBARJSwJSAJaxKS3L8/\nzoQeQ0JO4Ewmyfm8ritXzsyZM/NJrmTumfnOfL/m7oiISOpqEnUAERGJlgqBiEiKUyEQEUlxKgQi\nIilOhUBEJMU1jTpAbXXq1Mn79u0bdQwRkQbl/fff3+buWVW91+AKQd++fVmwYEHUMUREGhQz+7i6\n93RpSEQkxakQiIikOBUCEZEUp0IgIpLiVAhERFKcCoGISIpTIRARSXEqBCIiSeLuvLRsM0tzC6KO\nUiuhFgIzG29mq81srZlNreL93mb2upktMrOlZvaFMPOIiIRlS2ERVz4xn+89uZCvPfofNhbsizpS\nwkIrBGaWBjwAnAsMAyaZ2bBKi90CPO3uI4GJwB/CyiMiEgZ3Z9bCXMbd8wbvrNvOD84eRFm5M+Wp\nRewvK486XkLC7GJiNLDW3dcBmNkMYAKwMm4ZB9oFrzOATSHmERFJqrxdRfx41nL+mbOV7D7tuevi\n4fTt1Jr+WW2Y8tQi7nltDTeNHxp1zBqFWQh6ABvipnOBEyst83PgVTP7PtAaOLuqFZnZVcBVAL17\n9056UBGR2nB3Xly6mVtfWM7ekjJuOe8orji1H2lNDIALhnfnnQ+38Yd/fciJ/TtyxuAq+3qrN6Ju\nLJ4EPOHuPYEvAH8xs4MyuftD7p7t7tlZWfX7Fyoijdv23cVc8+RCpjy1iD4dWzNnyul8+/T+B4pA\nhVvPP5ohXdpyw98Wk7ezKKK0iQmzEGwEesVN9wzmxfsW8DSAu78DpAOdQswkInLYXlq2mXH3vMnc\nnDxuGj+EZ68+mYGd21S5bMvmadx/6Uj2lpRx3YzFlJV7HadNXJiFYD4wyMz6mVlzYo3Bsyst8wkw\nFsDMjiJWCPJDzCQiUms79pQw5alFfO/JhXTLTOfF75/GNWcOpGnaoXehg7q0ZdqEo3ln3Xbun7e2\njtLWXmhtBO5eamaTgVeANOAxd19hZtOABe4+G7gReNjMrifWcPxNd6+/ZVNEUs5rK7dy86xlFO4r\n4cZzBnP1mQNoVkMBiHfR8T1558Pt3Dt3DaP7deDkAR1DTHt4rKHtd7Ozs10D04hI2Ar37ue2v69g\n1sKNHNWtHb+9eDjDurer+YNV2FNcyhd//xa7i0uZc93pdGrTIslpa2Zm77t7dlXvRd1YLCJS77y+\nOo9xv3uDFxZvYsqYgbxw7amHXQQAWrdoyv2XjqJg335ueHoJ5fWsvUCFQEQksKtoPz+auZQrHp9P\nRstmPH/NqdwwbgjNmx75rnJY93bcev4w3lyTz4NvrktC2uRpcGMWi4iE4a0Pt
nHTzCVs2VnE984c\nwA/OHkSLpmlJ3cZlJ/bmnQ+3c9erqxndrz3H9+mQ1PUfLp0RiEhK21Ncyi3PL+PyR98jvXkaM793\nCj8aPzTpRQDAzLj9wmPpnpnO96cvomBvSdK3cThUCEQkZb27bjvj732TJ9/7hO+c3o85U05nVO/2\noW6zXXoz7p80ivzdxfy/Z5ZSH27YUSEQkZR0/7wPmPjQu6SZ8cx3T+Yn5w0jvVnyzwKqMrxXJlPP\nPYp/5mzl8bfX18k2D0VtBCKScu6b+wF3v7aGL4/swa++fAytmtf9rvDKU/vyzofbuf2lHLL7tue4\nnpl1nqGCzghEJKXcPy9WBC4c1ZPfXjw8kiIAsfaCuy4+jqw2LZg8fRE7i/ZHkgNUCEQkhfzhX2u5\n69U1fGVkD35z0XE0qdRRXF3LbNWc+yaNZGPBPm6etSyy9gIVAhFJCX9840N+8/JqvjSiO3dePPyg\n3kKjkt23AzeOG8w/lm5m+n8+iSSDCoGINHoPv7mOX7+0ii8O785d9agIVLj6cwM4fVAnbntxJTmb\nd9b59lUIRKRRe+Tf6/jVnBzOO64b93x1eI09hkahSRPjnktGkNGyGddOX8ie4tK63X6dbk1EpA49\n/vZH/PIfOXzh2K7ce8mIelkEKnRq04J7J47go217+OkLy+t02/X3tyIicgT+9H/rue3FlYw/uiv3\nThxZr4tAhVMGdGLKmEHMWriRme/n1tl26/9vRkSklv7yznp+NnsF44Z14feXjqzV+AFRmzJ2ECf1\n78BPn1/O2rxddbLNUH87ZjbezFab2Vozm1rF+/eY2eLga42ZFYSZR0Qavyff+5ifvrCCs4/qwv2X\njmpQRQAgrYlx78SRtGqexrVPLqJof1no2wztN2RmacADwLnAMGCSmQ2LX8bdr3f3Ee4+Avg9MCus\nPCLS+D31n0/4yXPLGTu0Mw9cNjIp3UdHoUu7dO6+ZASrt+7ithdXhr69MH9Lo4G17r7O3UuAGcCE\nQyw/CXgqxDwi0oj9bf4n3DxrGWcNyeIPl48KpffQunTG4Cy+d+YAnvrPJ8xesinUbYVZCHoAG+Km\nc4N5BzGzPkA/YF41719lZgvMbEF+vsa2F5HPembBBqbOWsYZg7P438uPb/BFoMIN5wzm+D7t+fGs\nZazftie07dSX86aJwEx3r/JimLs/5O7Z7p6dlZVVx9FEpD579v1cbnp2KacN7MSDXzu+znoQrQvN\n0ppw36SRpDUxrp2+kOLScNoLwiwEG4FecdM9g3lVmYguC4lILT23KJf/N3MJpw7oxMNfz25URaBC\nj8yW3HXxcFZs2smjb30UyjbC7HZvPjDIzPoRKwATgUsrL2RmQ4H2wDshZhGRRuaFxRu58eklnNy/\nY6MtAhXOGdaFP14+ijOHdA5l/aGdEbh7KTAZeAXIAZ529xVmNs3MLohbdCIww+vDMD0i0iDMXrKJ\n6/+2mNH9OvDIN7Jp2bzxFoEK44/pFlqxC7UjbnefA8ypNO/WStM/DzODiDQuf18aKwLZfTvw2DdP\niGw8gcakvjQWi4jUaM6yzVw3YzGjemfyuIpA0qgQiEiD8PLyLUx5ahEjemXy+BWjad1CRSBZVAhE\npN57dcUWJk9fyHE9M3jiihNooyKQVPptikjkikvL2FpYzObCfWwuLGJT4T62FBaxqaCIzYX7WL1l\nF8f0yOCJK0fTNr1Z1HEbHRUCEQnV/rJytu4siu3gC2I7+AOvd8Z29tt2Fx/0uXbpTeme2ZJuGemc\n1L8j1509iHYqAqFQIRCRpJm1MJflG3eyuXAfmwqL2Fywj/zdxVS+ObxNi6Z0y0inW2ZLhnVrR7eM\nlsF0+oHXagOoO/pNi0hSfJi/mxueXkLLZml0y0yne0ZLBg/OoltwVN8tI/3AEb4u79QvKgQikhTz\ncvIAeO2Gz9GzfauI00ht6K4hEUmKuau2M
rRrWxWBBkiFQESOWOG+/cxfv4MxQ8PpC0fCpUIgIkfs\nzTX5lJU7Y49SIWiIVAhE5IjNW5VH+1bNGNGrfdRR5DCoEIjIESkrd15fncdZQzqT1sSijiOHQYVA\nRI7Iok92ULB3P2N0WajBUiEQkSMyd1UeTZsYpw/SMLINVaiFwMzGm9lqM1trZlOrWearZrbSzFaY\n2fQw84hI8s3LyeOEvh3IaKmHxBqq0AqBmaUBDwDnAsOASWY2rNIyg4CbgVPd/WjgB2HlEZHk2/Dp\nXlZv3aW7hRq4MM8IRgNr3X2du5cAM4AJlZb5DvCAu+8AcPe8EPOISJK9vjr2L6vnBxq2MAtBD2BD\n3HRuMC/eYGCwmb1tZu+a2fiqVmRmV5nZAjNbkJ+fH1JcEamtuTl59OvUmv5ZbaKOIkcg6sbipsAg\n4ExgEvCwmWVWXsjdH3L3bHfPzspSg5RIfbCnuJR3Ptyus4FGIMxCsBHoFTfdM5gXLxeY7e773f0j\nYA2xwiAi9dzba7dRUlbOWBWCBi/MQjAfGGRm/cysOTARmF1pmeeJnQ1gZp2IXSpaF2ImEUmSeavy\naNuiKdl9O0QdRY5QaIXA3UuBycArQA7wtLuvMLNpZnZBsNgrwHYzWwm8DvzQ3beHlUlEkqO83Jm3\nKo/PDc6iedOorzDLkQp1PAJ3nwPMqTTv1rjXDtwQfIlIA7Fi007ydhWrfaCRUCkXkVqbu2orZnDm\nEN280RioEIhIrc1blcfIXpl0bNMi6iiSBCoEIlIreTuLWJpbyNijukQdRZJEhUBEakVPEzc+1TYW\nm9kh7wlz90+TH0dE6ru5OXl0z0hnaNe2UUeRJDnUXUPvAw4Y0BvYEbzOBD4B+oWeTkTqlaL9Zby1\ndhtfGdUDMw1C01hUe2nI3fu5e3/gn8AX3b2Tu3cEzgderauAIlJ/vPfRp+wtKWPsULUPNCaJtBGc\nFDwPAIC7vwScEl4kEamv5uVsJb1ZE04e0DHqKJJEiTxQtsnMbgH+GkxfBmwKL5KI1EfuztxVeZw2\nsBPpzdKijiNJlMgZwSQgC3gOmBW8nhRmKBGpfz7I203ujn2M0WWhRueQZwTBKGM/dvfr6iiPiNRT\nc3N022hjdcgzAncvA06roywiUo/NW7WVo7u3o2tGetRRJMkSaSNYZGazgWeAPRUz3X1WaKlEpF7Z\nsaeE9z/eweSzBkYdRUKQSCFIB7YDY+LmObH2AhFJAW+syafcYYy6lWiUaiwE7n5FXQQRkfpr7qo8\nOrVpwXE9MqKOIiGosRCYWTrwLeBoYmcHALj7lSHmEpF6Yn9ZOW+szmP8MV1p0kRPEzdGidw++heg\nK/B54A1iYw/vSmTlZjbezFab2Vozm1rF+980s3wzWxx8fbs24UUkfO9/vIOdRaW6bbQRS6SNYKC7\nX2xmE9z9T2Y2Hfh3TR8Kbj19ADiH2CD1881struvrLTo39x9cq2Ti0idmLcqj+ZpTThtUKeoo0hI\nEjkj2B98LzCzY4AMIJEbiUcDa919nbuXADOACYcXU0SiMjdnKyf270CbFqGObCsRSqQQPGRm7YGf\nArOBlcAdCXyuB7Ahbjo3mFfZhWa21MxmmlmvqlZkZleZ2QIzW5Cfn5/ApkUkGdZv28OH+XsYq4fI\nGrUaC4G7P+LuO9z9DXfv7+6d3f3BJG3/RaCvux8HvAb8qZoMD7l7trtnZ2VpjFSRujJvVcXTxGof\naMxqLARm9qGZPWlmV5vZ0bVY90Yg/gi/ZzDvAHff7u7FweQjwPG1WL+IhGzeqjwGdW5D746too4i\nIUrk0tAw4EGgI3BnUBieS+Bz84FBZtbPzJoDE4ldWjrAzLrFTV4A5CQWW0TCtqtoP+99tJ0xR+my\nUGOXSOtPGbEG4zKgHMgLvg7J3UvNbDLwCpAGPObuK8xsGrDA3WcDU8zsAqAU+BT45mH9FCKSdG99\nsI39Z
a5BaFJAIoVgJ7AMuBt42N23J7ryYECbOZXm3Rr3+mbg5kTXJyJ1Z+6qPDJaNmNU78yoo0jI\nEh2P4E3gGmCGmd1mZmPDjSUiUSord15flceZQ7JompbIbkIaskT6GnoBeMHMhgLnAj8AbgJahpxN\nRCKyJLeA7XtKNPZAikjkrqFnzWwtcC/QCvg60D7sYCISnXk5eaQ1Mc4YrNu1U0EibQS3A4uCQWpE\nJAXMXZXH8X3ak9mqedRRpA4kcvFvJXCzmT0EYGaDzOz8cGOJSFQ2FewjZ/NOPU2cQhIpBI8DJcAp\nwfRG4JehJRKRSFU8TTxWzw+kjEQKwQB3/w1B53PuvhdQp+QijdS8VXn07tCKAVltoo4idSSRQlBi\nZi2JDU+JmQ0Aig/9ERFpiPaVlPH22m2MGdoZMx3vpYpEGot/BrwM9DKzJ4FT0RPAIo3S/324jeLS\ncl0WSjGHLAQWOyRYBXwFOInYJaHr3H1bHWQTkTo2d1UerZunMbpfh6ijSB06ZCFwdzezOe5+LPCP\nOsokIhFwd+bl5HH6oCxaNE2LOo7UoUTaCBaa2QmhJxGRSK3cvJMtO4vU22gKSqSN4ETgMjP7GNhD\n7PKQB4PJiEgjMS8ndtvoWUNUCFJNIoXg86GnEJHIzV2Vx/BemWS1bRF1FKljiXQ693FdBBGR6OTv\nKmZJbgHXnz046igSgVD7lzWz8Wa22szWmtnUQyx3oZm5mWWHmUdEqvav1Xm4o95GU1RohcDM0oAH\niHVdPQyYZGbDqliuLXAd8F5YWUTk0OatyqNLuxYc3b1d1FEkAgkVAjPrY2ZnB69bBjvvmowG1rr7\nOncvAWYAE6pY7hfAHUBRgplFJIlKSst5c00+Y4Z20dPEKSqR8Qi+A8wkNoA9QE/g+QTW3QPYEDed\nG8yLX/cooJe7H/IZBTO7yswWmNmC/Pz8BDYtIon6z0efsqekTL2NprBEzgiuJdatxE4Ad/8AOOK/\nGDNrQmwc5BtrWtbdH3L3bHfPzsrSQBkiyTR31VZaNG3CqQM7RR1FIpJIISgOLu0AYGZNCTqgq8FG\noFfcdM9gXoW2wDHAv8xsPbEuLGarwVik7rg7c3PyOGVAR1o219PEqSqRQvCGmf0YaGlm5wDPAC8m\n8Ln5wCAz62dmzYGJwOyKN9290N07uXtfd+8LvAtc4O4Lav1TiMhh+TB/D598upcxR3WJOopEKJFC\nMBXIB5YB3wXmALfU9CF3LwUmA68AOcDT7r7CzKaZ2QWHH1lEkmXeqq2AbhtNdYk8UFYOPBx81Yq7\nzyFWOOLn3VrNsmfWdv0icmTm5uQxtGtbemS2jDqKRKjGQmBmyzi4TaAQWAD80t23hxFMRMJVuHc/\nCz7ewdVn9I86ikQskb6GXgLKgOnB9ESgFbAFeAL4YijJRCRUb3yQT1m5M2ao2gdSXSKF4Gx3HxU3\nvczMFrr7KDO7PKxgIhKueTlb6dC6OSN6ZUYdRSKWSGNxmpmNrpgIxiaouM+sNJRUIhKq0rJy/rUm\nnzOHZJHWRE8Tp7pEzgi+DTxmZm2IjUWwE/i2mbUGbg8znIiEY9GGAgr27mesLgsJid01NB841swy\ngunCuLefDiuYiIRnbk4eTZsYpw/W08SS2BkBZnYecDSQXtEplbtPCzGXiISgvNz58zvreeL/PuLk\nAR1pl94s6khSDyRy++gfid0ldBbwCHAR8J+Qc4lIkm34dC8/nLmEd9d9yplDsrjjQo02KzGJnBGc\n4u7HmdlSd7/NzH5L7JZSEWkA3J0n3/uE/5mTQxMz7rjwWL6a3UtdTssBiRSCinEC9ppZd2A70C28\nSCKSLBsL9vGjmUt5a+02ThvYiTsuOk5PEctBEikEL5pZJnAnsJDYU8a17m5CROqOu/PMglx+8feV\nlLnzyy8dw2Un9tZZgFTpkIUgGDNgrrsXAM+a2d+B9Ep3DolIPbKlsIi
bZy3l9dX5nNivA3deNJze\nHVtFHUvqsUMWAncvN7MHgJHBdDFQXBfBRKR23J3nFm3k57NXUFJWzs+/OIyvn9yXJnpgTGqQyKWh\nuWZ2ITDL3RMZkEZE6ljeriJ+8txyXlu5lew+7bnr4uH07dQ66ljSQCRSCL4L3ACUmdk+Yk8Xu7u3\nCzWZiNTI3Xlx6WZufWE5e0vKuOW8o7ji1H7qNkJqJZEni9vWRRARqZ3tu4v56QvLmbNsC8N7ZfLb\ni4czsHObqGNJA1Rjp3MWc7mZ/TSY7hXfCV0Nnx1vZqvNbK2ZTa3i/avNbJmZLTazt8xsWO1/BJHU\n8/LyzYy7503+uTKPm8YP4dmrT1YRkMOWyKWhPwDlwBjgF8Bu4AHghEN9yMzSguXOAXKB+WY2291X\nxi023d3/GCx/AXA3ML62P4RIqtixp4SfzV7B7CWbOKZHO6ZfPIIhXXXSLkcmkUJwYjD2wCIAd98R\nDEZfk9HAWndfB2BmM4AJwIFC4O4745ZvzcEjoYlI4J8rt3Lzc8vYsaeEG84ZzPfOHECztER6khc5\ntEQKwf7g6N4BzCyL2BlCTXoAG+Kmc4ETKy9kZtcSa4xuTuys4yBmdhVwFUDv3r0T2LRI41G4bz/T\nXlzJswtzGdq1LU9ccQJHd8+IOpY0IokcTtwHPAd0NrNfAW8B/5OsAO7+gLsPAH4E3FLNMg+5e7a7\nZ2dlZSVr0yL13sJPdjDunjd4fvFGpowZyOzJp6kISNIlctfQk2b2PjCW2K2jX3L3nATWvRHoFTfd\nM5hXnRnA/yawXpGUsG13Md/9y/ukN2vCc9ecwnE9NaSkhCORbqjvA2a4+wO1XPd8YJCZ9SNWACYC\nl1Za9yB3/yCYPA/4ABGhvNy5/m+LKdy3nz9feSpHddNjOxKeRNoI3gduMbMhxC4RzXD3BTV9yN1L\nzWwy8AqxMY4fc/cVZjYNWODus4HJZnY2sB/YAXzjcH8Qkcbkj29+yL8/2MavvnyMioCEzhLtNcLM\nOgAXEjuy7+3ug8IMVp3s7GxfsKDGOiTSYC1Y/ymXPPQu5x7Tld9PGqkeQyUpzOx9d8+u6r3a3Hs2\nEBgK9AFWJSOYiHzWjj0lTHlqET3bt+T2rxyrIiB1IpEni39jZh8A04DlQLa7fzH0ZCIpxt354cwl\n5O8u5v5Jo2ir8YSljiTSRvAhcLK7bws7jEgqe+zt9fwzJ4+ffXEYx/bULaJSdxK5ffRBM2sf9C+U\nHjf/zVCTiaSQJRsK+PVLOZwzrAvfPKVv1HEkxSRy++i3geuIPQewGDgJeIdqngIWkdrZWbSfyU8t\npHPbdO686Di1C0idS6Sx+DpiHcx97O5nERutrCDUVCIpwt2Z+uxSNhUUcd+kkWS2SqQbL5HkSqQQ\nFLl7EYCZtXD3VcCQcGOJpIYn3/uEOcu28MPPD+H4Pu2jjiMpKpHG4lwzywSeB14zsx3Ax+HGEmn8\nVmwqZNrfV3LG4CyuOr1/1HEkhSXSWPzl4OXPzex1IAN4OdRUIo3c7uJSvj99Ee1bNePurw7XAPMS\nqUTOCA5w9zfCCiKSKtydW55bxvrte5j+nZPo2KZF1JEkxWlUC5E69sz7uTy/eBPXjR3MSf07Rh1H\nRIVApC59sHUXt76wnJP7d2TymIFRxxEBVAhE6sy+kjKunb6Q1s2bcu/EEaSpXUDqiVq1EYjI4bvt\nxRWs2bqbP185ms7t0mv+gEgd0RmBSB14YfFGZszfwDVnDuBzgzXcqtQvoRYCMxtvZqvNbK2ZTa3i\n/RvMbKWZLTWzuWbWJ8w8IlH4aNsefjxrGdl92nPDOYOjjiNykNAKgZmlAQ8A5wLDgElmNqzSYouI\ndWt9HDAT+E1YeUSiULS/jGufXEizpk24b9JImqbpJFzqnzD/KkcDa919nbuXEBucfkL8Au7+urvv\nDSbfJdaxnUijcfucHFZu3sldFw2
ne2bLqOOIVCnMQtAD2BA3nRvMq863gJeqesPMrjKzBWa2ID8/\nP4kRRcLz8vLN/Omdj/n2af04e1iXqOOIVKtenKea2eVANnBnVe+7+0Punu3u2VlZamiT+m/Dp3v5\n4cylDO+ZwU3jh0YdR+SQwrx9dCPQK266ZzDvM8zsbOAnwBnuXhxiHpE6UVJazuSnFgFw/6WjaN60\nXhxviVQrzL/Q+cAgM+tnZs2BicDs+AXMbCTwIHCBu+eFmEWkztz5yiqWbCjgjguPo1eHVlHHEalR\naIXA3UuBycArQA7wtLuvMLNpZnZBsNidQBvgGTNbbGazq1mdSIMwN2crD//7I752Uh++cGy3qOOI\nJCTUJ4vdfQ4wp9K8W+Nenx3m9kXq0ubCfdz4zBKO6taOn5x3VNRxRBKmi5ciSVBaVs6UpxZRUlrO\nA5eOJL1ZWtSRRBKmvoZEkuDhf3/E/PU7uOeS4fTPahN1HJFa0RmByBFav20Pv/vnGsYN68KXR+qZ\nSGl4VAhEjoC785Pnl9E8rQnTJhwTdRyRw6JCIHIEnl24kbfXbuemc4fSNUNdS0vDpEIgcpi27S7m\nl/9YSXaf9lw2unfUcUQOmwqByGH6xd9Xsqe4lNu/cixNNNqYNGAqBCKH4fXVebyweBPXnDmQQV3a\nRh1H5IioEIjU0p7iUm55bjkDslpzzVkDoo4jcsT0HIFILd392ho2FuzjmatPpkVTPTgmDZ/OCERq\nYcmGAh5/+yMuPbE3J/TtEHUckaRQIRBJ0P6ycqbOWkanNi2Yeq7GGJDGQ5eGRBL06FsfkbN5J3+8\nfBTt0ptFHUckaXRGIJKAj7fv4Z7XYt1IjD9G3UtL46JCIFIDd+fHz6kbCWm8Qi0EZjbezFab2Voz\nm1rF+58zs4VmVmpmF4WZReRwqRsJaexCKwRmlgY8AJwLDAMmmdmwSot9AnwTmB5WDpEjoW4kJBWE\n2Vg8Gljr7usAzGwGMAFYWbGAu68P3isPMYfIYVM3EpIKwrw01APYEDedG8wTaRDUjYSkigbRWGxm\nV5nZAjNbkJ+fH3UcSQEV3UgM7NxG3UhIoxdmIdgI9Iqb7hnMqzV3f8jds909OysrKynhRA6lohuJ\n279yrLqRkEYvzEIwHxhkZv3MrDkwEZgd4vZEkmJprrqRkNQSWiFw91JgMvAKkAM87e4rzGyamV0A\nYGYnmFkucDHwoJmtCCuPSCL2l5Uz9Vl1IyGpJdQuJtx9DjCn0rxb417PJ3bJSKReePStj1ipbiQk\nxTSIxmKRuqBuJCRVqRCIoG4kJLWpEIigbiQktakQSMpTNxKS6lQIJOWpGwlJdSoEktLUjYSICoGk\nsL0l6kZCBDRUpaSwu1+NdSPxzNUnqxsJSWk6I5CUtDS3gMfUjYQIoEIgKUjdSIh8li4NSb1UVu7k\n7ypmd3Fp0tc9e8kmdSMhEkeFQOpcebmzbXcxmwuL2Fy4j00Fse+x6SI2F+xj665iyso9tAzqRkLk\nv1QIJKncne17SthcUMSmwn1sLtjH5p1FbI7b2W/dWcT+ss/u5Fs0bUL3zJZ0bZfOSQM60j2jJd0y\n02nToilmyb23v1kT46yhnZO6TpGGTIWgkdpTXHrw0XZBEfm7wznS3re/jC2FRWwpLKKk7LNDUDdP\na0LXjHS6ZaST3ac93TJb0j0jnW4ZLemakU73zJa0b9Us6Tt8EUmMCkEDtK+k7MDOfVNB3CWVwn0H\njsR3FR18bT2rbQs6t21B04V3axIAAAo7SURBVLTk3yPQIq0JI3pl0u2Y2A4/trOP7eg7tm6uJ3ZF\n6jEVgnqmKDiyPrBjD3b2WwqL2BTMK9i7/6DPdWzdnG6Z6fTu2IoT+3egW0ZLumfGjrq7ZaTTpV06\nzZvqJjEROViohcDMxgP3AmnAI+7+60rvtwD+DBwPbAcucff1YWaKUklpOVt3Bjv2nUUHLttsKihi\
ny87Y0fz2PSUHfS6zVbPYjj0jnVG9M+meGdu5V+zsu7RLJ72ZHogSkcMTWiEwszTgAeAcIBeYb2az\n3X1l3GLfAna4+0AzmwjcAVwSVqYwlZaVs3VXMZsL9rGpsIgtVdwNs213MV7p8ny79Kaxo/bMdI7t\nkUn3jPQD1827Ba9bNdeJm4iEJ8w9zGhgrbuvAzCzGcAEIL4QTAB+HryeCdxvZuZeeXd55J6ev4GH\n/70u2avFgV1F+8nfVUzlNtg2LZoeuF4+rFu72A4+2OlXHNG3bqGdvIhEK8y9UA9gQ9x0LnBidcu4\ne6mZFQIdgW3xC5nZVcBVAL17H15/8ZmtmjGoS5vD+mxNWjdveuBOmPij+bZ6WElEGoAGcTjq7g8B\nDwFkZ2cf1tnCuKO7Mu7orknNJSLSGIR5G8lGoFfcdM9gXpXLmFlTIINYo7GIiNSRMAvBfGCQmfUz\ns+bARGB2pWVmA98IXl8EzAujfUBERKoX2qWh4Jr/ZOAVYrePPubuK8xsGrDA3WcDjwJ/MbO1wKfE\nioWIiNShUNsI3H0OMKfSvFvjXhcBF4eZQUREDk2PmoqIpDgVAhGRFKdCICKS4lQIRERSnDW0uzXN\nLB/4+DA/3olKTy3Xcw0pb0PKCg0rb0PKCg0rb0PKCkeWt4+7Z1X1RoMrBEfCzBa4e3bUORLVkPI2\npKzQsPI2pKzQsPI2pKwQXl5dGhIRSXEqBCIiKS7VCsFDUQeopYaUtyFlhYaVtyFlhYaVtyFlhZDy\nplQbgYiIHCzVzghERKQSFQIRkRSXMoXAzMab2WozW2tmU6POUx0z62Vmr5vZSjNbYWbXRZ0pEWaW\nZmaLzOzvUWc5FDPLNLOZZrbKzHLM7OSoMx2KmV0f/B0sN7OnzCw96kzxzOwxM8szs+Vx8zqY2Wtm\n9kHwvX2UGStUk/XO4G9hqZk9Z2aZUWasUFXWuPduNDM3s07J2l5KFAIzSwMeAM4FhgGTzGxYtKmq\nVQrc6O7DgJOAa+tx1njXATlRh0jAvcDL7j4UGE49zmxmPYApQLa7H0OsO/f61lX7E8D4SvOmAnPd\nfRAwN5iuD57g4KyvAce4+3HAGuDmug5VjSc4OCtm1gsYB3ySzI2lRCEARgNr3X2du5cAM4AJEWeq\nkrtvdveFwetdxHZUPaJNdWhm1hM4D3gk6iyHYmYZwOeIjYOBu5e4e0G0qWrUFGgZjODXCtgUcZ7P\ncPc3iY0lEm8C8Kfg9Z+AL9VpqGpUldXdX3X30mDyXWIjKUaumt8rwD3ATUBS7/JJlULQA9gQN51L\nPd+5AphZX2Ak8F60SWr0O2J/nOVRB6lBPyAfeDy4jPWImbWOOlR13H0jcBexo7/NQKG7vxptqoR0\ncffNwestQJcow9TClcBLUYeojplNADa6+5JkrztVCkGDY2ZtgGeBH7j7zqjzVMfMzgfy3P39qLMk\noCkwCvhfdx8J7KH+XLY4SHBtfQKxAtYdaG1ml0ebqnaCoWfr/T3qZvYTYpdln4w6S1XMrBXwY+DW\nmpY9HKlSCDYCveKmewbz6iUza0asCDzp7rOizlODU4ELzGw9sUtuY8zsr9FGqlYukOvuFWdYM4kV\nhvrqbOAjd8939/3ALOCUiDMlYquZdQMIvudFnOeQzOybwPnAZfV4zPQBxA4IlgT/az2BhWbWNRkr\nT5VCMB8YZGb9zKw5sQa32RFnqpKZGbFr2DnufnfUeWri7je7e09370vs9zrP3evlUau7bwE2mNmQ\nYNZYYGWEkWryCXCSmbUK/i7GUo8bt+PMBr4RvP4G8EKEWQ7JzMYTu6x5gbvvjTpPddx9mbt3dve+\nwf9aLjAq+Js+YilRCILGoMnAK8T+kZ529xXRpqrWqcDXiB1ZLw6+vhB1qEbk+8CTZrYUGAH8T8R5\nqhWcucwEFgLLiP2/1qsuEczsKeAdYIiZ5ZrZt4BfA+eY2QfEz
mp+HWXGCtVkvR9oC7wW/K/9MdKQ\ngWqyhre9+nsmJCIidSElzghERKR6KgQiIilOhUBEJMWpEIiIpDgVAhGRFKdCII2WmU0zs7OTsJ7d\nScrzOzP7XPB6ctAT7md6kbSY+4L3lprZqLj3vhH06PmBmX0jbv76GrY7w8wGJeNnkMZJt4+K1MDM\ndrt7myNcR0fgH+5+UjA9EtgB/ItY76LbgvlfIPaswxeAE4F73f1EM+sALACyiXXZ8D5wvLvvMLP1\nwUNG1W37DOByd//OkfwM0njpjEAaDDO73Mz+Ezz482DQvThmttvM7gn67Z9rZlnB/CfM7KLg9a8t\nNsbDUjO7K5jX18zmBfPmmlnvYH4/M3vHzJaZ2S8rZfihmc0PPnNbMK+1mf3DzJZYbNyAS6qIfyHw\ncsWEuy9y9/VVLDcB+LPHvAtkBt00fB54zd0/dfcdxLpPruimOL+GHP8Gzg56MBU5iAqBNAhmdhRw\nCXCqu48AyoDLgrdbAwvc/WjgDeBnlT7bEfgycHTQ73zFzv33wJ+CeU8C9wXz7yXWMd2xxHr9rFjP\nOGAQsW7NRwDHB5d6xgOb3H14MG7AgR1+nFOJHcXXpLqecqvtQdfdTwjmVZnD3cuBtcTGXxA5iAqB\nNBRjgeOB+Wa2OJjuH7xXDvwteP1X4LRKny0EioBHzewrQEWfMicD04PXf4n73KnAU3HzK4wLvhYR\n6/ZhKLHCsIxYlwp3mNnp7l5YRf5uBEfuITpUjjxiPZiKHESFQBoKI3b0PiL4GuLuP69m2c80fAV9\nTY0m1m/P+VR9xH7IdcRluD0uw0B3f9Td1xDrxXQZ8Eszq6qr4H1AIsNMVtdTbo096NaQIz3IIHIQ\nFQJpKOYCF5lZZzgwLm6f4L0mwEXB60uBt+I/GIztkOHuc4Dr+e8lkv/jv0M/XkbsWjrA25XmV3gF\nuDJYH2bWw8w6m1l3YK+7/xW4k6q7ts4BBibwc84Gvh7cPXQSscFoNgfbHmdm7S02TsG4YF78z3mo\nHIOBg8a/FYHYQB0i9Z67rzSzW4BXzawJsB+4FviY2AAzo4P384i1JcRrC7xgsYHfDbghmP99YqOV\n/ZDYZZsrgvnXAdPN7EfEdaHs7q8GbRXvxHqFZjdwObEd/J1mVh7k+l4VP8I/gO8SDOdpZlOIdX/c\nFVhqZnPc/dvAHGJ3DK0ldgnrimDbn5rZL4h1qQ4wzd0rD2V4bFU5zKwLsC9ZXRZL46PbR6XBS8bt\nnXXBzN4Czq/rcZLN7Hpgp7s/WpfblYZDl4ZE6s6NQO8ItlvAfweTFzmIzghERFKczghERFKcCoGI\nSIpTIRARSXEqBCIiKU6FQEQkxf1/wWUIKH9sIAUAAAAASUVORK5CYII=\n", 458 | "text/plain": [ 459 | "
" 460 | ] 461 | }, 462 | "metadata": { 463 | "tags": [] 464 | }, 465 | "output_type": "display_data" 466 | } 467 | ], 468 | "source": [ 469 | "# możemy narysować postęp trenowania i zobaczyć jak agent się polepsza\n", 470 | "import matplotlib.pyplot as plt\n", 471 | "\n", 472 | "def get_average(values):\n", 473 | " return sum(values) / len(values)\n", 474 | "\n", 475 | "avg_rewards = []\n", 476 | "for i in range(0, len(rewards), 100):\n", 477 | " avg_rewards.append(get_average(rewards[i:i+100]))\n", 478 | "\n", 479 | "plt.plot(avg_rewards)\n", 480 | "plt.ylabel('average reward')\n", 481 | "plt.xlabel('episodes (100\\'s)')\n", 482 | "plt.show()" 483 | ] 484 | }, 485 | { 486 | "cell_type": "markdown", 487 | "metadata": { 488 | "id": "sL5_fLdadMw6", 489 | "colab_type": "text" 490 | }, 491 | "source": [ 492 | "### Źródło:\n", 493 | "[https://www.youtube.com/watch?v=tPYj3fFJGjk](https://www.youtube.com/watch?v=tPYj3fFJGjk)" 494 | ] 495 | } 496 | ], 497 | "metadata": { 498 | "colab": { 499 | "name": "q_learning.ipynb", 500 | "provenance": [], 501 | "collapsed_sections": [], 502 | "authorship_tag": "ABX9TyMShg/BChmNRnh4JPpurMig", 503 | "include_colab_link": true 504 | }, 505 | "kernelspec": { 506 | "name": "python3", 507 | "display_name": "Python 3" 508 | } 509 | }, 510 | "nbformat": 4, 511 | "nbformat_minor": 0 512 | } 513 | -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/gridworld/gridworld.py: -------------------------------------------------------------------------------- 1 | ACTION_SPACE = {'u', 'd', 'l', 'r'} 2 | 3 | 4 | class Grid: 5 | def __init__(self, rows, cols, start): 6 | self.rows = rows 7 | self.cols = cols 8 | self.i = start[0] 9 | self.j = start[1] 10 | 11 | def set(self, rewards, actions): 12 | self.rewards = rewards 13 | self.actions = actions 14 | 15 | def set_state(self, s): 16 | self.i = s[0] 17 | self.j = s[1] 18 | 19 | def current_state (self): 20 | return self.i, self.j 21 | 
22 | def reset(self): 23 | self.i = 2 24 | self.j = 0 25 | return (self.i, self.j) 26 | 27 | def is_terminal(self, state): 28 | return state not in self.actions 29 | 30 | def get_next_state(self, state, action): 31 | i = state[0] 32 | j = state[1] 33 | if action in self.actions[(i, j)]: 34 | if action == 'u': 35 | i -= 1 36 | elif action == 'd': 37 | i += 1 38 | elif action == 'r': 39 | j += 1 40 | elif action == 'l': 41 | j -= 1 42 | return i, j 43 | 44 | def move(self, action): 45 | # Check if move is legal 46 | if action in self.actions[(self.i, self.j)]: 47 | if action == 'u': 48 | self.i -= 1 49 | elif action == 'd': 50 | self.i += 1 51 | elif action == 'r': 52 | self.j += 1 53 | elif action == 'l': 54 | self.j -= 1 55 | return self.rewards.get((self.i, self.j), 0) 56 | 57 | def undo_move(self, action): 58 | if action == 'u': 59 | self.i += 1 60 | elif action == 'd': 61 | self.i -= 1 62 | elif action == 'r': 63 | self.j -= 1 64 | elif action == 'l': 65 | self.j += 1 66 | 67 | def game_over(self): 68 | return (self.i, self.j) not in self.actions 69 | 70 | def all_states(self): 71 | return set(self.actions.keys()) | set(self.rewards.keys()) 72 | 73 | 74 | def standard_grid(): 75 | g = Grid(3, 4, (2, 0)) 76 | rewards = {(0, 3): 1, (1, 3): -1} 77 | actions = { 78 | (0, 0): ('d', 'r'), 79 | (0, 1): ('l', 'r'), 80 | (0, 2): ('l', 'd', 'r'), 81 | (1, 0): ('u', 'd'), 82 | (1, 2): ('u', 'd', 'r'), 83 | (2, 0): ('u', 'r'), 84 | (2, 1): ('l', 'r'), 85 | (2, 2): ('l', 'r', 'u'), 86 | (2, 3): ('l', 'u') 87 | } 88 | g.set(rewards, actions) 89 | return g 90 | 91 | 92 | def negative_grid(step_cost=-0.1): 93 | g = standard_grid() 94 | g.rewards.update({ 95 | (0, 0): step_cost, 96 | (0, 1): step_cost, 97 | (0, 2): step_cost, 98 | (1, 0): step_cost, 99 | (1, 2): step_cost, 100 | (2, 0): step_cost, 101 | (2, 1): step_cost, 102 | (2, 2): step_cost, 103 | (2, 3): step_cost 104 | }) 105 | return g 106 | -------------------------------------------------------------------------------- 
/machine_learning/reinforcement_learning/gridworld/iterative_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from machine_learning.reinforcement_learning.gridworld.gridworld import ( 3 | standard_grid, ACTION_SPACE 4 | ) 5 | converge_thresh = 0.001 6 | 7 | 8 | def print_values(V, grid): 9 | for i in range(grid.rows): 10 | print('------------------------') 11 | for j in range(grid.cols): 12 | v = V.get((i, j), 0) 13 | if v >= 0: 14 | print(' %.2f|' % v, end='') 15 | else: 16 | print('%.2f|' % v, end='') 17 | print() 18 | 19 | 20 | def print_policy(P, grid): 21 | for i in range(grid.rows): 22 | print('------------------------') 23 | for j in range(grid.cols): 24 | a = P.get((i, j), ' ') 25 | print(' %s |' % a, end='') 26 | print() 27 | 28 | 29 | if __name__ == '__main__': 30 | # Define transition probabilities and grid 31 | transition_probs = {} 32 | rewards = {} 33 | 34 | grid = standard_grid() 35 | for i in range(grid.rows): 36 | for j in range(grid.cols): 37 | s = (i, j) 38 | if not grid.is_terminal(s): 39 | for a in ACTION_SPACE: 40 | s2 = grid.get_next_state(s, a) 41 | transition_probs[(s, a, s2)] = 1 42 | if s2 in grid.rewards: 43 | rewards[(s, a, s2)] = grid.rewards[s2] 44 | 45 | policy = { 46 | (2, 0): 'u', 47 | (1, 0): 'u', 48 | (0, 0): 'r', 49 | (0, 1): 'r', 50 | (0, 2): 'r', 51 | (1, 2): 'u', 52 | (2, 1): 'r', 53 | (2, 2): 'u', 54 | (2, 3): 'l' 55 | } 56 | print_policy(policy, grid) 57 | 58 | V = {} # Initialize V(s) = 0 59 | for s in grid.all_states(): 60 | V[s] = 0 61 | 62 | gamma = 0.9 # Discount factor 63 | 64 | # Repeat until convergence 65 | iter = 1 66 | while True: 67 | biggest_change = 0 68 | for s in grid.all_states(): 69 | if not grid.is_terminal(s): 70 | old_v = V[s] 71 | new_v = 0 72 | for a in ACTION_SPACE: 73 | for s2 in grid.all_states(): 74 | action_prob = 1 if policy.get(s) == a else 0 75 | r = rewards.get((s, a, s2), 0) 76 | new_v += action_prob * 
transition_probs.get((s, a, s2), 0) * (r + gamma * V[s2]) 77 | 78 | V[s] = new_v 79 | biggest_change = max(biggest_change, np.abs(old_v - V[s])) 80 | 81 | print('Iter:', iter, 'biggest change:', biggest_change) 82 | print_values(V, grid) 83 | iter += 1 84 | 85 | if biggest_change < converge_thresh: 86 | break 87 | print() 88 | -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/gridworld/monte_carlo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def max_dict(d): 5 | # Returns the argmax (key) and max (value) from a dict 6 | max_val = max(d.values()) 7 | max_keys = [key for key, val in d.items() if val == max_val] 8 | return np.random.choice(max_keys), max_val 9 | -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/gridworld/q_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from machine_learning.reinforcement_learning.gridworld.gridworld import ( 4 | negative_grid 5 | ) 6 | from machine_learning.reinforcement_learning.gridworld.iterative_policy import ( 7 | print_values, print_policy 8 | ) 9 | from machine_learning.reinforcement_learning.gridworld.monte_carlo import ( 10 | max_dict 11 | ) 12 | gamma = 0.9 13 | alpha = 0.1 14 | action_space = ['u', 'd', 'l', 'r'] 15 | 16 | 17 | def epsilon_greedy(Q, s, eps=0.1): 18 | if np.random.random() < eps: 19 | return np.random.choice(action_space) 20 | else: 21 | a_optimal = max_dict(Q[s])[0] 22 | return a_optimal 23 | 24 | 25 | if __name__ == '__main__': 26 | grid = negative_grid(step_cost=-0.1) 27 | 28 | print('Rewards') 29 | print_values(grid.rewards, grid) 30 | 31 | # Initialize Q(s, a) = 0 32 | Q = {} 33 | states = grid.all_states() 34 | for s in states: 35 | Q[s] = {} 36 | for a in action_space: 37 | 
Q[s][a] = 0 38 | 39 | update_counts = {} 40 | 41 | reward_per_episode = [] 42 | for i in range(10000): 43 | if i % 2000 == 0: 44 | print('Iter:', i) 45 | 46 | # Begin a new episode 47 | s = grid.reset() 48 | a = epsilon_greedy(Q, s, eps=0.1) 49 | episode_reward = 0 50 | while not grid.game_over(): 51 | # Perform action and get next state + reward 52 | a = epsilon_greedy(Q, s, eps=0.1) 53 | r = grid.move(a) 54 | s2 = grid.current_state() 55 | # Update reward 56 | episode_reward += r 57 | # Get next action 58 | maxQ = max_dict(Q[s2])[1] 59 | a2 = epsilon_greedy(Q, s2, eps=0.1) 60 | # Update Q(s, a) 61 | Q[s][a] += alpha * (r + gamma * maxQ - Q [s][a]) 62 | # Check how often Q(s) is updated 63 | update_counts[s] = update_counts.get(s, 0) + 1 64 | # Next state becomes current state 65 | s = s2 66 | 67 | # Log the reward for this episode 68 | reward_per_episode.append(episode_reward) 69 | 70 | plt.plot(reward_per_episode) 71 | plt.title('Reward per episode') 72 | plt.show() 73 | 74 | # Determine the policy from Q* 75 | # Find V* from Q* 76 | policy = {} 77 | V = {} 78 | for s in grid.actions.keys(): 79 | a, max_q = max_dict(Q[s]) 80 | policy[s] = a 81 | V[s] = max_q 82 | 83 | # The proportion of time we spend updating each part of Q 84 | print('Update counts:') 85 | total = np.sum(list(update_counts.values())) 86 | for k, v in update_counts.items(): 87 | update_counts[k] = float(v) / total 88 | 89 | print_values(update_counts, grid) 90 | print('Values:') 91 | print_values(V, grid) 92 | print('Policy:') 93 | print_policy(policy, grid) 94 | -------------------------------------------------------------------------------- /machine_learning/reinforcement_learning/gridworld/sarsa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from machine_learning.reinforcement_learning.gridworld.gridworld import ( 4 | negative_grid 5 | ) 6 | from 
machine_learning.reinforcement_learning.gridworld.iterative_policy import ( 7 | print_values, print_policy 8 | ) 9 | from machine_learning.reinforcement_learning.gridworld.monte_carlo import ( 10 | max_dict 11 | ) 12 | gamma = 0.9 13 | alpha = 0.1 14 | action_space = ['u', 'd', 'l', 'r'] 15 | 16 | 17 | def epsilon_greedy(Q, s, eps=0.1): 18 | if np.random.random() < eps: 19 | return np.random.choice(action_space) 20 | else: 21 | a_optimal = max_dict(Q[s])[0] 22 | return a_optimal 23 | 24 | 25 | if __name__ == '__main__': 26 | grid = negative_grid(step_cost=-0.1) 27 | 28 | print('Rewards') 29 | print_values(grid.rewards, grid) 30 | 31 | # Initialize Q(s, a) = 0 32 | Q = {} 33 | states = grid.all_states() 34 | for s in states: 35 | Q[s] = {} 36 | for a in action_space: 37 | Q[s][a] = 0 38 | 39 | update_counts = {} 40 | 41 | reward_per_episode = [] 42 | for i in range(10000): 43 | if i % 2000 == 0: 44 | print('Iter:', i) 45 | 46 | # Begin a new episode 47 | s = grid.reset() 48 | a = epsilon_greedy(Q, s, eps=0.1) 49 | episode_reward = 0 50 | while not grid.game_over(): 51 | # Perform action and get next state + reward 52 | r = grid.move(a) 53 | s2 = grid.current_state() 54 | # Update reward 55 | episode_reward += r 56 | # Get next action 57 | a2 = epsilon_greedy(Q, s2, eps=0.1) 58 | # Update Q(s, a) 59 | Q[s][a] += alpha * (r + gamma * Q[s2][a2] - Q [s][a]) 60 | # Check how often Q(s) is updated 61 | update_counts[s] = update_counts.get(s, 0) + 1 62 | # Next state becomes current state 63 | s = s2 64 | a = a2 65 | 66 | # Log the reward for this episode 67 | reward_per_episode.append(episode_reward) 68 | 69 | plt.plot(reward_per_episode) 70 | plt.title('Reward per episode') 71 | plt.show() 72 | 73 | # Determine the policy from Q* 74 | # Find V* from Q* 75 | policy = {} 76 | V = {} 77 | for s in grid.actions.keys(): 78 | a, max_q = max_dict(Q[s]) 79 | policy[s] = a 80 | V[s] = max_q 81 | 82 | # The proportion of time we spend updating each part of Q 83 | print('Update 
counts:') 84 | total = np.sum(list(update_counts.values())) 85 | for k, v in update_counts.items(): 86 | update_counts[k] = float(v) / total 87 | 88 | print_values(update_counts, grid) 89 | print('Values:') 90 | print_values(V, grid) 91 | print('Policy:') 92 | print_policy(policy, grid) 93 | -------------------------------------------------------------------------------- /machine_learning/supervised_learning/linear_algebra/cramer_rule.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "MgP9hj2kWbfE" 17 | }, 18 | "source": [ 19 | "#Solving ML problems using Cramer's rule" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 19, 25 | "metadata": { 26 | "id": "fbqkBf7_SXwD" 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "import numpy as np\n", 31 | "import scipy\n", 32 | "from sklearn.datasets import load_iris" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 33, 38 | "metadata": { 39 | "colab": { 40 | "base_uri": "https://localhost:8080/" 41 | }, 42 | "id": "eIQQgjrXSkrp", 43 | "outputId": "1abc3830-36b0-40df-b97e-3060354ea142" 44 | }, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "(4, 4)\n", 51 | "(4,)\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "iris = load_iris()\n", 57 | "X = iris.data[-4:]\n", 58 | "y = iris.target[-4:]\n", 59 | "\n", 60 | "print(X.shape)\n", 61 | "print(y.shape)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 17, 67 | "metadata": { 68 | "colab": { 69 | "base_uri": "https://localhost:8080/" 70 | }, 71 | "id": "xXAlqluKUe6D", 72 | "outputId": "1070973d-52ef-45b2-aa83-9257c4a1c421" 73 | }, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": 
[ 78 | "array([-0.03868472, -0.1934236 , 0.61895551, -0.1934236 ])" 79 | ] 80 | }, 81 | "execution_count": 17, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "w = np.linalg.solve(X, y)\n", 88 | "w" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 37, 94 | "metadata": { 95 | "colab": { 96 | "base_uri": "https://localhost:8080/" 97 | }, 98 | "id": "UKb6KbjWjUCF", 99 | "outputId": "913fd481-5c51-4b9e-f414-642040421738" 100 | }, 101 | "outputs": [ 102 | { 103 | "name": "stderr", 104 | "output_type": "stream", 105 | "text": [ 106 | "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:1: FutureWarning: `rcond` parameter will change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions.\n", 107 | "To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`.\n", 108 | " \"\"\"Entry point for launching an IPython kernel.\n" 109 | ] 110 | }, 111 | { 112 | "data": { 113 | "text/plain": [ 114 | "array([-0.03868472, -0.1934236 , 0.61895551, -0.1934236 ])" 115 | ] 116 | }, 117 | "execution_count": 37, 118 | "metadata": {}, 119 | "output_type": "execute_result" 120 | } 121 | ], 122 | "source": [ 123 | "v = np.linalg.lstsq(X, y)\n", 124 | "v[0]" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 36, 130 | "metadata": { 131 | "colab": { 132 | "base_uri": "https://localhost:8080/" 133 | }, 134 | "id": "HUMqhE72ew17", 135 | "outputId": "3649039a-a2ac-4320-e950-65b8e08e48e6" 136 | }, 137 | "outputs": [ 138 | { 139 | "data": { 140 | "text/plain": [ 141 | "1.9999999999999996" 142 | ] 143 | }, 144 | "execution_count": 36, 145 | "metadata": {}, 146 | "output_type": "execute_result" 147 | } 148 | ], 149 | "source": [ 150 | "y_pred = 0\n", 151 | "for i in range(4):\n", 152 | " y_pred += X[0][i] * w[i]\n", 153 | "\n", 154 | "y_pred" 155 | ] 156 | } 157 | ], 158 | 
"metadata": { 159 | "colab": { 160 | "name": "cramer_rule.ipynb", 161 | "provenance": [], 162 | "authorship_tag": "ABX9TyNNdU+ywNicUIyC5YaQYV2W", 163 | "include_colab_link": true 164 | }, 165 | "kernelspec": { 166 | "name": "python3", 167 | "display_name": "Python 3" 168 | }, 169 | "language_info": { 170 | "name": "python" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 0 175 | } 176 | -------------------------------------------------------------------------------- /machine_learning/unsupervised_learning/dimensionality_reduction/some_play_with_svd.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "some_play_with_svd.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyPQEYfVDZBXXdRFHs3wrTL+", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "XMJDsJWKB1xM" 35 | }, 36 | "source": [ 37 | "# Singular Value Decomposition" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "W-QTRSr29MTe" 44 | }, 45 | "source": [ 46 | "import numpy as np\n", 47 | "import matplotlib.pyplot as plt" 48 | ], 49 | "execution_count": null, 50 | "outputs": [] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "metadata": { 55 | "colab": { 56 | "base_uri": "https://localhost:8080/" 57 | }, 58 | "id": "Y1mPvwgg_dsa", 59 | "outputId": "53331e55-0c82-4f8f-9895-5e9f2355c504" 60 | }, 61 | "source": [ 62 | "from tensorflow.keras.datasets.mnist import load_data\n", 63 | "(X_train, _), (X_test, _) = load_data()" 64 | ], 65 |
"execution_count": 26, 66 | "outputs": [ 67 | { 68 | "output_type": "stream", 69 | "text": [ 70 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", 71 | "11493376/11490434 [==============================] - 0s 0us/step\n" 72 | ], 73 | "name": "stdout" 74 | } 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "metadata": { 80 | "colab": { 81 | "base_uri": "https://localhost:8080/", 82 | "height": 282 83 | }, 84 | "id": "pss2K3y3_mpC", 85 | "outputId": "e744e489-6ff3-4526-e983-8e05360cdf66" 86 | }, 87 | "source": [ 88 | "plt.imshow(X_train[0], cmap='viridis')" 89 | ], 90 | "execution_count": 56, 91 | "outputs": [ 92 | { 93 | "output_type": "execute_result", 94 | "data": { 95 | "text/plain": [ 96 | "" 97 | ] 98 | }, 99 | "metadata": { 100 | "tags": [] 101 | }, 102 | "execution_count": 56 103 | }, 104 | { 105 | "output_type": "display_data", 106 | "data": { 107 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOZ0lEQVR4nO3dbYxc5XnG8euKbezamMQbB9chLjjgFAg0Jl0ZEBZQobgOqgSoCsSKIkJpnSY4Ca0rQWlV3IpWbpUQUUqRTHExFS+BBIQ/0CTUQpCowWWhBgwEDMY0NmaNWYENIX5Z3/2w42iBnWeXmTMv3vv/k1Yzc+45c24NXD5nznNmHkeEAIx/H+p0AwDag7ADSRB2IAnCDiRB2IEkJrZzY4d5ckzRtHZuEkjlV3pbe2OPR6o1FXbbiyVdJ2mCpH+LiJWl50/RNJ3qc5rZJICC9bGubq3hw3jbEyTdIOnzkk6UtMT2iY2+HoDWauYz+wJJL0TE5ojYK+lOSedV0xaAqjUT9qMk/WLY4621Ze9ie6ntPtt9+7Snic0BaEbLz8ZHxKqI6I2I3kma3OrNAaijmbBvkzRn2ONP1JYB6ELNhP1RSfNsz7V9mKQvSlpbTVsAqtbw0FtE7Le9TNKPNDT0tjoinq6sMwCVamqcPSLul3R/Rb0AaCEulwWSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJpmZxRffzxPJ/4gkfm9nS7T/3F8fUrQ1OPVBc9+hjdxTrU7/uYv3Vaw+rW3u893vFdXcOvl2sn3r38mL9uD9/pFjvhKbCbnuLpN2SBiXtj4jeKpoCUL0q9uy/FxE7K3gdAC3EZ3YgiWbDHpJ+bPsx20tHeoLtpbb7bPft054mNwegUc0exi+MiG22j5T0gO2fR8TDw58QEaskrZKkI9wTTW4PQIOa2rNHxLba7Q5J9
0paUEVTAKrXcNhtT7M9/eB9SYskbayqMQDVauYwfpake20ffJ3bI+KHlXQ1zkw4YV6xHpMnFeuvnPWRYv2d0+qPCfd8uDxe/JPPlMebO+k/fzm9WP/Hf1lcrK8/+fa6tZf2vVNcd2X/54r1j//k0PtE2nDYI2KzpM9U2AuAFmLoDUiCsANJEHYgCcIOJEHYgST4imsFBs/+bLF+7S03FOufmlT/q5jj2b4YLNb/5vqvFOsT3y4Pf51+97K6tenb9hfXnbyzPDQ3tW99sd6N2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs1dg8nOvFOuP/WpOsf6pSf1VtlOp5dtPK9Y3v1X+Kepbjv1+3dqbB8rj5LP++b+L9VY69L7AOjr27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQhCPaN6J4hHviVJ/Ttu11i4FLTi/Wdy0u/9zzhCcPL9af+Pr1H7ing67Z+TvF+qNnlcfRB994s1iP0+v/APGWbxZX1dwlT5SfgPdZH+u0KwZGnMuaPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4exeYMPOjxfrg6wPF+ku31x8rf/rM1cV1F/zDN4r1I2/o3HfK8cE1Nc5ue7XtHbY3DlvWY/sB25tqtzOqbBhA9cZyGH+LpPfOen+lpHURMU/SutpjAF1s1LBHxMOS3nsceZ6kNbX7aySdX3FfACrW6G/QzYqI7bX7r0qaVe+JtpdKWipJUzS1wc0BaFbTZ+Nj6Axf3bN8EbEqInojoneSJje7OQANajTs/bZnS1Ltdkd1LQFohUbDvlbSxbX7F0u6r5p2ALTKqJ/Zbd8h6WxJM21vlXS1pJWS7rJ9qaSXJV3YyibHu8Gdrze1/r5djc/v/ukvPVOsv3bjhPILHCjPsY7uMWrYI2JJnRJXxwCHEC6XBZIg7EAShB1IgrADSRB2IAmmbB4HTrji+bq1S04uD5r8+9HrivWzvnBZsT79e48U6+ge7NmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2ceB0rTJr3/thOK6/7f2nWL9ymtuLdb/8sILivX43w/Xrc35+58V11Ubf+Y8A/bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEUzYnN/BHpxfrt1397WJ97sQpDW/707cuK9bn3bS9WN+/eUvD2x6vmpqyGcD4QNiBJAg7kARhB5Ig7EAShB1IgrADSTDOjqI4Y36xfsTKrcX6HZ/8UcPbPv7BPy7Wf/tv63+PX5IGN21ueNuHqqbG2W2vtr3D9sZhy1bY3mZ7Q+3v3CobBlC9sRzG3yJp8QjLvxsR82t/91fbFoCqjRr2iHhY0kAbegHQQs2coFtm+8naYf6Mek+yvdR2n+2+fdrTxOYANKPRsN8o6VhJ8yVtl/Sdek+MiFUR0RsRvZM0ucHNAWhWQ2GPiP6IGIyIA5JukrSg2rYAVK2hsNuePezhBZI21nsugO4w6ji77TsknS1ppqR+SVfXHs+XFJK2SPpqRJS/fCzG2cejCbOOLNZfuei4urX1V1xXXPdDo+yLvvTSomL9zYWvF+vjUWmcfdRJIiJiyQiLb266KwBtxeWyQBKEHUiCsANJEHYgCcIOJMFXXNExd20tT9k81YcV67+MvcX6H3zj8vqvfe/64rqHKn5KGgBhB7Ig7EAShB1IgrADSRB2IAnCDiQx6rfekNuBheWfkn7xC+Upm0+av6VubbRx9NFcP3BKsT71vr6mXn+8Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzj7OufekYv35b5bHum86Y02xfuaU8nfKm7En9hXrjwzMLb/AgVF/3TwV9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7IeAiXOPLtZfvOTjdWsrLrqzuO4fHr6zoZ6qcFV/b7H+0HWnFesz1pR/dx7vNuqe3fYc2w/afsb207a/VVveY/sB25tqtzNa3y6ARo3lMH6/pOURcaKk0yRdZ
vtESVdKWhcR8yStqz0G0KVGDXtEbI+Ix2v3d0t6VtJRks6TdPBayjWSzm9VkwCa94E+s9s+RtIpktZLmhURBy8+flXSrDrrLJW0VJKmaGqjfQJo0pjPxts+XNIPJF0eEbuG12JodsgRZ4iMiFUR0RsRvZM0ualmATRuTGG3PUlDQb8tIu6pLe63PbtWny1pR2taBFCFUQ/jbVvSzZKejYhrh5XWSrpY0sra7X0t6XAcmHjMbxXrb/7u7GL9or/7YbH+px+5p1hvpeXby8NjP/vX+sNrPbf8T3HdGQcYWqvSWD6znyHpy5Kesr2htuwqDYX8LtuXSnpZ0oWtaRFAFUYNe0T8VNKIk7tLOqfadgC0CpfLAkkQdiAJwg4kQdiBJAg7kARfcR2jibN/s25tYPW04rpfm/tQsb5ken9DPVVh2baFxfrjN5anbJ75/Y3Fes9uxsq7BXt2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUgizTj73t8v/2zx3j8bKNavOu7+urVFv/F2Qz1VpX/wnbq1M9cuL657/F//vFjveaM8Tn6gWEU3Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0mkGWffcn7537XnT767Zdu+4Y1ji/XrHlpUrHuw3o/7Djn+mpfq1ub1ry+uO1isYjxhzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTgiyk+w50i6VdIsSSFpVURcZ3uFpD+R9FrtqVdFRP0vfUs6wj1xqpn4FWiV9bFOu2JgxAszxnJRzX5JyyPicdvTJT1m+4Fa7bsR8e2qGgXQOmOZn327pO21+7ttPyvpqFY3BqBaH+gzu+1jJJ0i6eA1mMtsP2l7te0ZddZZarvPdt8+7WmqWQCNG3PYbR8u6QeSLo+IXZJulHSspPka2vN/Z6T1ImJVRPRGRO8kTa6gZQCNGFPYbU/SUNBvi4h7JCki+iNiMCIOSLpJ0oLWtQmgWaOG3bYl3Szp2Yi4dtjy2cOedoGk8nSeADpqLGfjz5D0ZUlP2d5QW3aVpCW252toOG6LpK+2pEMAlRjL2fifShpp3K44pg6gu3AFHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlRf0q60o3Zr0l6ediimZJ2tq2BD6Zbe+vWviR6a1SVvR0dER8bqdDWsL9v43ZfRPR2rIGCbu2tW/uS6K1R7eqNw3ggCcIOJNHpsK/q8PZLurW3bu1LordGtaW3jn5mB9A+nd6zA2gTwg4k0ZGw215s+znbL9i+shM91GN7i+2nbG+w3dfhXlbb3mF747BlPbYfsL2pdjviHHsd6m2F7W21926D7XM71Nsc2w/afsb207a/VVve0feu0Fdb3re2f2a3PUHS85I+J2mrpEclLYmIZ9raSB22t0jqjYiOX4Bh+0xJb0m6NSJOqi37J0kDEbGy9g/ljIi4okt6WyHprU5P412brWj28GnGJZ0v6Svq4HtX6OtCteF968SefYGkFyJic0TslXSnpPM60EfXi4iHJQ28Z/F5ktbU7q/R0P8sbVent64QEdsj4vHa/d2SDk4z3tH3rtBXW3Qi7EdJ+sWwx1vVXfO9h6Qf237M9tJONzOCWRGxvXb/VUmzOtnMCEadxrud3jPNeNe8d41Mf94sTtC938KI+Kykz0u6rHa42pVi6DNYN42djmka73YZYZrxX+vke9fo9OfN6kTYt0maM+zxJ2rLukJEbKvd7pB0r7pvKur+gzPo1m53dLifX+umabxHmmZcXfDedXL6806E/VFJ82zPtX2YpC9KWtuBPt7H9rTaiRPZniZpkbpvKuq1ki6u3b9Y0n0d7OVdumUa73rTjKvD713Hpz+PiLb/STpXQ2fkX5T0V53ooU5fn5T0RO3v6U73JukODR3W7dPQuY1LJX1U0jpJmyT9l6SeLurtPyQ9JelJDQVrdod6W6ihQ/QnJW2o/Z3b6feu0Fdb3jcul
wWS4AQdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTx/65XcTNOWsh5AAAAAElFTkSuQmCC\n", 108 | "text/plain": [ 109 | "
" 110 | ] 111 | }, 112 | "metadata": { 113 | "tags": [], 114 | "needs_background": "light" 115 | } 116 | } 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "metadata": { 122 | "colab": { 123 | "base_uri": "https://localhost:8080/" 124 | }, 125 | "id": "XCtPgKUUALGk", 126 | "outputId": "f97e666d-6e8f-481c-9b09-13447338f29f" 127 | }, 128 | "source": [ 129 | "U, D, V = np.linalg.svd(X_train[0])\n", 130 | "\n", 131 | "print(U.shape)\n", 132 | "print(D.shape)\n", 133 | "print(V.shape)" 134 | ], 135 | "execution_count": 35, 136 | "outputs": [ 137 | { 138 | "output_type": "stream", 139 | "text": [ 140 | "(28, 28)\n", 141 | "(28,)\n", 142 | "(28, 28)\n" 143 | ], 144 | "name": "stdout" 145 | } 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "metadata": { 151 | "colab": { 152 | "base_uri": "https://localhost:8080/", 153 | "height": 282 154 | }, 155 | "id": "X7DWdYrRAQQ5", 156 | "outputId": "1460239c-cc61-4439-9d84-98d851213d3d" 157 | }, 158 | "source": [ 159 | "plt.imshow(U, cmap='viridis')" 160 | ], 161 | "execution_count": 53, 162 | "outputs": [ 163 | { 164 | "output_type": "execute_result", 165 | "data": { 166 | "text/plain": [ 167 | "" 168 | ] 169 | }, 170 | "metadata": { 171 | "tags": [] 172 | }, 173 | "execution_count": 53 174 | }, 175 | { 176 | "output_type": "display_data", 177 | "data": { 178 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAUNklEQVR4nO3dfWxV93kH8O/ji238im1ejLHNa1kITTeSuqwbqGVFS1ImjbTTsrIpomoWqjbRWrXThpI/mn+qoWpNl0lLK7qg0KyjadVkQQtaw1Ba1lZLcDLCe2JeDMb4BWyD37F977M/fKgM+Pcc576cc+nv+5Es2/fx8Xnu8X3uub7P+f1+oqogot9+BXEnQETRYLETeYLFTuQJFjuRJ1jsRJ6YFenOSsq0qLImyl0STWv1wstm/ETn/Igyya6x/l5MjAzJdLGMil1EHgTwLIAEgH9V1R3WzxdV1uBDW76WyS6JsuKt7c+Z8TU7vhxRJtl1es8zzljaL+NFJAHgXwB8GsBqAFtEZHW6v4+IciuT/9nXAjitqmdVdQzAjwBszk5aRJRtmRR7PYC2Kd9fDG67iYhsE5FmEWmeGBnKYHdElImcvxuvqjtVtUlVm2aVlOV6d0TkkEmxtwNonPJ9Q3AbEeWhTIr9EICVIrJMRIoAfA7A3uykRUTZlnbrTVUnROQJAD/DZOttl6oez1pmRBk6bLTX7tTWWiYy6rOr6j4A+7KUCxHlEC+XJfIEi53IEyx2Ik+w2Ik8wWIn8gSLncgTkY5nJ8omq48O+NlLt/DMTuQJFjuRJ1jsRJ5gsRN5gsVO5AkWO5En2HqjvMXWWnbxzE7kCRY7kSdY7ESeYLETeYLFTuQJFjuRJ1jsRJ6ItM+uBcB4uTt+vSZlbl/c435uKuq3993/8REzXnGoxN7+rqQztuzlCXPb9k8Um/GJMvt+l16yn5Pnnhh3xvp+p9DctuSKve/BRfa+F/+kzYxf+mdjFaDX7eW72Uefnn6qzx37T/fjlGd2Ik+w2Ik8wWIn8gSLncgTLHYiT7DYiTzBYifyRLR99gQwXq7u+MLr5vapAXcvXJLu3wsAxSXuXvQku89e9wtxxs79qd3Lrjli5zb/0FUzfvapIjM+eNXdyy65bPfRB5aEPN/bqePis8aFEwAm3qx2xsYa7V9edsl9zHNtZIGdW0l3fLkNv1fljKVGE85YRsUuIq0ABgAkAUyoalMmv4+IcicbZ/Y/UtUrWfg9RJRD/J+dyBOZFrsCeF1E3haRbdP9gIhsE5FmEWlODg1luDsiSlemL+PXq2q7iCwAsF9ETqnqwak/oKo7AewEgNkNIe/IEFHOZHRmV9X24HM3gFcArM1GUkSUfWkXu4iUiUjFja8B3A/gWLYSI6LsyuRlfC2AV0Tkxu/5d1X9L2sDFSBV5H4ln0jYPeGCMSPmHsYLABg/U2HG54SM665oGXDGBhrnmNvOvmon17HBHte9/AvvmvH3nrvLGas4PNvcdnj1qBlfvMfdtwWA1Cn7uA40umM1p0KuAWiw951Ls4bj66OHsXr8BcbUCmkXu6qeBfB76W5PRNFi643IEyx2Ik+w2Ik8wWIn8gSLncgT0S7ZXKBIlrvbLcl+eyhnmTFddOU5oy8HoKLNfl4bq7DjqVJ7GKulY53dxknU2ZcRn1n0u2a87jV3O3POMXuM0siJSjPe+fv2/S4IGTlc5O5YovyMPf/3QIN7eGyu1Zyy26VxtgXLLrlryPp78MxO5AkWO5EnWOxEnmCxE3mCxU7kCRY7kSdY7ESeiLbPngCk1D0GTwftdMovuXufOst+3tKQp7WBRrtv2r+k1Bmbe9xuNpd02/ereo891PPC/XafvqzdvRx12zftfRe/ZvfRG35hL3U90GgvR118zX3fWh+y++jF7pWJc67sgn3tw0CDfX1CLo1Wux/MajyMeWYn8gSLncgTLHYiT7DYiTzBYifyBIudyBMsdiJPRNt
nTwE67n5+Keq1e91F/e4x68nZ9vPWeGnIePYqe7GasSXu5aTr37CXmr640Z5u+XqVHa9ttsfqt3zB/Wese8meSvraMruH37s+5PqFYfsagfJz7txGF9nXJxT3pT+HQKZa/speirriXESJTEM0vYWVeGYn8gSLncgTLHYiT7DYiTzBYifyBIudyBMsdiJPRNtnV4GMunvpNSfs/mFi1D2ePVVoP2/1L7V7+GPVIWs+X3P3fAe/edXctOon9tjn4UX2rlsftuNFl9y5jbuH4QMARu62l2wuO2r36Rv39Zjxq/e4x6yXdtoPv+vxTRuPspB1BuI0Xua+NsKatyH0HonILhHpFpFjU26rEZH9ItISfI7xz0JEMzGTp68XADx4y23bARxQ1ZUADgTfE1EeCy12VT0IoPeWmzcD2B18vRvAQ1nOi4iyLN1/TGpVtSP4uhNAresHRWSbiDSLSHNycDDN3RFRpjJ+F0JVFYDznTVV3amqTaralCi3BxcQUe6kW+xdIlIHAMHn7uylRES5kG6x7wWwNfh6K4BXs5MOEeVKaJ9dRPYA2ABgnohcBPANADsA/FhEHgVwHkBIJzigACbcPcKyDnvc9liVu58sIW3yoUb7Bwr77ee94l533j1XFprbLn3z1vc3b3ZpY40ZX/iSGcZojfv6hP6l9v26++/azfi19UvNeMsjdu6pIndui193ryEAANer4xvPPnC3PdZ+ztH4cqs6634sJ4ypFUKLXVW3OEIbw7YlovyRv5cJEVFWsdiJPMFiJ/IEi53IEyx2Ik9EPMQVKDC6LQXX7fZY733u4ZZVZ+w2jlbY8UX7QqaiLnPH++6yh8+ee9geFFjVZF+TNNA/34wnjVWTx+bYw4bb/nKFGS+5Ym9v/T0BoPaQe/uCv+2yN97bYMdzaPGr9hTb15ZHlMg0+la6H2/JX7m345mdyBMsdiJPsNiJPMFiJ/IEi53IEyx2Ik+w2Ik8EW2fHYAkjWlwQ6aDHl7o7tlWttp90aL2IjNefLnfjA8tdE8HXfuWvWRzx3qjEQ4g8cI8M971qZDhucZS14kP2VOBjSTs2YMGPmwP9ax5yx7qefk+99902XZ7im38oR3Ope4muzSK7VHLObXg/9xDwduGjesacpEMEeUfFjuRJ1jsRJ5gsRN5gsVO5AkWO5EnWOxEnoi8z65GO1wL7F55stI9eDpZbI8pL7tohjG6wF7beKjOnVtJj73vBc32oO+iAbuXPbujxIwnS4zpmv/8qLlt3+f/wIxXtqbMOMS+xmCs0j0HwfuP2D3+yjP2rnOp/II9jn+83H6s5tLIPHfZpma58+KZncgTLHYiT7DYiTzBYifyBIudyBMsdiJPsNiJPBF5n92SChnPjkJ37zNZaPc9S3rsfvFQnX0orPnXe1fZ29acsvvsV5e7e9EAUPcru5c9Mt89prznMbuPPjLPPm49a+x5AMIs+49RZ0z+1+5lX77Xvr4gl4oG87fPvuTx952xlrfdxzv0zC4iu0SkW0SOTbntaRFpF5HDwcemD5owEUVrJi/jXwDw4DS3f0dV1wQf+7KbFhFlW2ixq+pBADFOwkNE2ZDJG3RPiMiR4GW+czEzEdkmIs0i0pwcGspgd0SUiXSL/bsAVgBYA6ADwLddP6iqO1W1SVWbEmVlae6OiDKVVrGrapeqJlU1BeD7ANZmNy0iyra0il1E6qZ8+xkAx1w/S0T5IbTPLiJ7AGwAME9ELgL4BoANIrIGgAJoBfDFGe1NATGmQB+tCUlH3OO+UyGbllyxx4x3fNKe/7z6qPt5sX+F3ZO9cLfd45/7a3s8/ES5Ha8+5F7nvOOBOmcMAK5/ZNiM179k99l7VtsHvu1L7msMZr1rj2cvtKe8z6mee+w++uwrESUyjbZ/WumMjXW5r9kILXZV3TLNzc/PKCsiyhu8XJbIEyx2Ik+w2Ik8wWIn8gSLncgTkQ5xleDDZbTKbncUFLpbWBOl9rYj8+0W0rxlPWa8+t/
cwy2vV9stJKjd1kvZYVz5iP1n6vwL99LHyQ677bfqqT4z3vbZejM+96Q9fHek033V5FCDuWmsJBXfENYwXWvd5+jxX7u345mdyBMsdiJPsNiJPMFiJ/IEi53IEyx2Ik+w2Ik8Ee1U0goUjLnDI7V2bzM16h7qOdRgDzMdarB/d+FIsRkf+5i7l76g2Z7queSMPR7y5NcWmvHy1pAloV90X0MwWG/f79N/vciMj9UZfzAARf32RQKll91jmkdq82om85tUtNqPp7HK+PrwK190XxvR0+M+3jyzE3mCxU7kCRY7kSdY7ESeYLETeYLFTuQJFjuRJyJtdKoASWNYuaTs3iYm3M9NyVJ73HZi2H5eK/65e0w4ANT9Wasz1nJoibltstiezrn0op3brKGQqao3ubef3Wn3g+cesY9b6X473vonZhiphPsagXlH7em9r64IGeifQ1UtI2a8+6OlEWVyu7H57jkCUmfdjwWe2Yk8wWIn8gSLncgTLHYiT7DYiTzBYifyBIudyBPRDihOKCaq3ONtZ3fa6RSed8ev32P3RWt/bs8bX3mi14yfXeDupUvIUawM6aOPu9umAIBksd0rX/yauxd+eY09Fl5Dnu67P+peAhgA6g/a88YX97rHw194wD0XPxDvssg137pgxrtfWhVRJrfrW+l+LCffdT9WQs/sItIoIm+IyAkROS4iXwlurxGR/SLSEnyuTidxIorGTF7GTwD4uqquBvBxAI+LyGoA2wEcUNWVAA4E3xNRngotdlXtUNV3gq8HAJwEUA9gM4DdwY/tBvBQrpIkosx9oDfoRGQpgHsBvAmgVlU7glAngFrHNttEpFlEmpODQxmkSkSZmHGxi0g5gJ8C+Kqq9k+NqaoCmHa0hqruVNUmVW1KlIe8E0VEOTOjYheRQkwW+g9V9eXg5i4RqQvidQC6c5MiEWVDaOtNRATA8wBOquozU0J7AWwFsCP4/Gro3goAFLtbb+Xn7TaRNS3x8Dr7X4SCpD1csm/NXDM+95h7mGnHOnNTVJ2x21NDIVMqa8isxV0fc9+3WaP2tuMhS10PrLKHoSZG7OMqj11zxvSI3XqL07UvLbB/YEMkaUzLasVaj5WZ9NnXAXgEwFERORzc9iQmi/zHIvIogPMAHp5hrkQUg9BiV9VfAnA9X2zMbjpElCu8XJbIEyx2Ik+w2Ik8wWIn8gSLncgT0Q5xTQEYcffSrT46AEjS3esevm73eycW2Xd1qNGernnuEXd81fd6zG3bH5hvxocX2fueqLCPS+kF930r7rV/d9+H7fjqf7hsxk/9jb3cdOXP3NcvLLxgX39wbVl8Szq3frbGjBfbI6JzanCZe0hzylh5nGd2Ik+w2Ik8wWIn8gSLncgTLHYiT7DYiTzBYifyRKSNTEkKivrcfXZNhPXZ3bHR8xXmtmPrh814osVegjcx5u5tdn5ynrntrI32nMgrt9tjykfrys14qtjoldsrLkNS9kOg5TF7uem6/wlZ8rnDPc9A//L8Hc8+57R9v0Zr4jtPzj3sfrx0Gw9zntmJPMFiJ/IEi53IEyx2Ik+w2Ik8wWIn8gSLncgTkfbZC8aBki53j7B/SfrpVLSG/ECr3UcPM1hvz2lv2m/34S/l8Ry9pR12POy4DNbfmasAhfXRD29/zoyv2fHlbKZzk7FKY95448/BMzuRJ1jsRJ5gsRN5gsVO5AkWO5EnWOxEnmCxE3liJuuzNwL4AYBaAApgp6o+KyJPA3gMwI2JxZ9U1X25SpQon4T10ePsw7vM5CqWCQBfV9V3RKQCwNsisj+IfUdV/zF36RFRtsxkffYOAB3B1wMichJAfa4TI6Ls+kD/s4vIUgD3AngzuOkJETkiIrtEpNqxzTYRaRaR5okR9xRFRJRbMy52ESkH8FMAX1XVfgDfBbACwBpMnvm/Pd12qrpTVZtUtWlWyZ15nTTRb4MZFbuIFGKy0H+oqi8DgKp2qWpSVVMAvg9gbe7SJKJMhRa7iAi
A5wGcVNVnptw+ddrRzwA4lv30iChbZvJu/DoAjwA4KiKHg9ueBLBFRNZgsh3XCuCLOcmQ6A6USWsuV225mbwb/0sA0w2gZU+d6A7CK+iIPMFiJ/IEi53IEyx2Ik+w2Ik8wWIn8kSkU0kT0SSrl56r4bE8sxN5gsVO5AkWO5EnWOxEnmCxE3mCxU7kCRY7kSdEVaPbmchlAOen3DQPwJXIEvhg8jW3fM0LYG7pymZuS1R1/nSBSIv9tp2LNKtqU2wJGPI1t3zNC2Bu6YoqN76MJ/IEi53IE3EX+86Y92/J19zyNS+AuaUrktxi/Z+diKIT95mdiCLCYifyRCzFLiIPish7InJaRLbHkYOLiLSKyFEROSwizTHnsktEukXk2JTbakRkv4i0BJ+nXWMvptyeFpH24NgdFpFNMeXWKCJviMgJETkuIl8Jbo/12Bl5RXLcIv+fXUQSAN4H8McALgI4BGCLqp6INBEHEWkF0KSqsV+AISKfADAI4Aeqek9w27cA9KrqjuCJslpV/z5PcnsawGDcy3gHqxXVTV1mHMBDAD6PGI+dkdfDiOC4xXFmXwvgtKqeVdUxAD8CsDmGPPKeqh4E0HvLzZsB7A6+3o3JB0vkHLnlBVXtUNV3gq8HANxYZjzWY2fkFYk4ir0eQNuU7y8iv9Z7VwCvi8jbIrIt7mSmUauqHcHXnQBq40xmGqHLeEfplmXG8+bYpbP8eab4Bt3t1qvqfQA+DeDx4OVqXtLJ/8HyqXc6o2W8ozLNMuO/EeexS3f580zFUeztABqnfN8Q3JYXVLU9+NwN4BXk31LUXTdW0A0+d8ecz2/k0zLe0y0zjjw4dnEufx5HsR8CsFJElolIEYDPAdgbQx63EZGy4I0TiEgZgPuRf0tR7wWwNfh6K4BXY8zlJvmyjLdrmXHEfOxiX/5cVSP/ALAJk+/InwHwVBw5OPJaDuDd4ON43LkB2IPJl3XjmHxv41EAcwEcANAC4L8B1ORRbi8COArgCCYLqy6m3NZj8iX6EQCHg49NcR87I69IjhsvlyXyBN+gI/IEi53IEyx2Ik+w2Ik8wWIn8gSLncgTLHYiT/w/xSX3gpUMTO4AAAAASUVORK5CYII=\n", 179 | "text/plain": [ 180 | "
" 181 | ] 182 | }, 183 | "metadata": { 184 | "tags": [], 185 | "needs_background": "light" 186 | } 187 | } 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "metadata": { 193 | "colab": { 194 | "base_uri": "https://localhost:8080/", 195 | "height": 82 196 | }, 197 | "id": "YkR97H0DAc8D", 198 | "outputId": "1aac799d-f96a-423f-81c4-70b70a34c170" 199 | }, 200 | "source": [ 201 | "plt.imshow(D.reshape(1, 28), cmap='viridis')" 202 | ], 203 | "execution_count": 54, 204 | "outputs": [ 205 | { 206 | "output_type": "execute_result", 207 | "data": { 208 | "text/plain": [ 209 | "" 210 | ] 211 | }, 212 | "metadata": { 213 | "tags": [] 214 | }, 215 | "execution_count": 54 216 | }, 217 | { 218 | "output_type": "display_data", 219 | "data": { 220 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAAvCAYAAADginEnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAGjklEQVR4nO3dXYxcZR3H8e+PvoTYEtwC1gYRRQyQGHzbkJgQ0/iKXtAaItqrcqHlwka9w6rBBmNsiBrvTFAbNFHQ1LdqiAhRo4kvsCVNX4CWgjWy2bZSFFlE6u7+vDhnYVhntnt2zu7MnPl9ks3MnPnPOc8/T/Z/Zp55zjOyTURENN85vW5AREQsjxT8iIghkYIfETEkUvAjIoZECn5ExJBIwY+IGBJdFXxJ6yTdJ+mx8nakQ9y0pP3l395ujhkREYujbubhS7odeNr2LkmfAUZs39ImbtL22i7aGRERXeq24B8BNtqekLQB+K3tK9rEpeBHRPRYt2P4621PlPdPAOs7xJ0raUzSnyRt7vKYERGxCCvPFiDpfuDVbZ76XOsD25bU6ePCpbbHJV0G/FrSQduPtznWNmAbwJpX6O1XXr76rAnMOvzcugXHAsy8sKJSvKYrhbPy+WrxqvhJSzMV939mqtoLqn7yqxo/s8T7r6rC/qu3pGqulQ8Q8aJn+cdTti9q99yyDOnMec2dwC9s75kvbvTN5/qBey9ZcFuufmDLgmMB/v34+ZXiVz1T7cPQhYeqnSHOOVOtH1ZNVivgq8f/WSle/zlTKZ6pau3x89XOiJ6ueIabrniGnln4/iv/z1Rsi6ueDGcq5hqNdr/37LM92u65bod09gJbJV0HHABeVX55+yJJI5LOk/QDSU8AHwWqVZ+IiOhatwV/F/BeisJ/FLgS2CLpw5K+VcZcBTwKvBt4DrgT+HiXx42IiIrOOoY/H9unJX0e2Gn7/QCS7gYut/2xMuYPkg6VMX+UtBI4IUnO2swREcumjittLwb+1vL4yXJb2xjbU8AzwAU1HDsiIhaor5ZWkLStnL459vfT+SIqIqJOdRT8caB1Os1rym1tY8ohnfOB03N3ZPsO26O2Ry+6oNq0yYiImF9XY/ilB4Gryxk4M8Aa4D1zYk5SzL8/ArwS+EvG7yMillcd7/BnC7fKPwBLuk3S9eXj31O8y18LnAI+UsNxIyKigjre4V8DHGiZpbMD2GT71paY/w
K/tL29huNFRMQiLNcsHYAbJB2QtEfSwi+hjYiIWtTxDn8hfg7cZfsFSTcD3wHeNTeodS0dYHLFhmNH2uzrQuCp/9/8xdoaW4ej9e2qQ76NlXybbZjy7VWul3Z6oqu1dAAkvYOXX3i1A8D2lzvEr6BYQ7/aYjYvvX6s0zoRTZR8my35Nlc/5lrHkM6DwBslvV7Saoq1cl72q1blwmqzrgceqeG4ERFRQddDOranJG0H7gVWALttH5Z0GzBmey/wyXLGzhTwNHBTt8eNiIhqahnDt30PcM+cbbe23N8B7KjjWMAdNe1nUCTfZku+zdV3uXY9hh8REYOhr9bSiYiIpTNQBV/SdZKOSDo294dWmkjScUkHJe2XNNbr9tRN0m5Jp8rls2e3rZN0n6THytuRXraxTh3y3SlpvOzj/ZI+2Ms21kXSJZJ+I+lhSYclfarc3sj+nSffvurfgRnSKadzHqX4wZUnKWYHbbH9cE8btoQkHQdGbTdy3rKkdwKTwHdtv6ncdjvFtN1d5Ul9xPYtvWxnXTrkuxOYtP2VXratbuXMvA22H5J0HrAP2EwxYaNx/TtPvjfSR/07SO/wrwGO2X7C9hngbmBTj9sUXbD9O4pZW602UVyYR3m7eVkbtYQ65NtItidsP1Tef5ZiKvbFNLR/58m3rwxSwV/oEg5NYuBXkvaVVyEPg/W2J8r7J4D1vWzMMtleLjuyuylDHK0kvQ54K/BnhqB/5+QLfdS/g1Twh9G1tt8GfAD4RDkkMDTKJbQHY8xx8b4BvAF4CzABfLW3zamXpLXAj4BP2/5X63NN7N82+fZV/w5SwV/ID600iu3x8vYU8BOKYa2mOzl7ZXZ5e6rH7VlStk/anrY9A3yTBvWxpFUUxe97tn9cbm5s/7bLt9/6d5AK/lmXcGgSSWvKL3+QtAZ4H3Bo/lc1wl5ga3l/K/CzHrZlyc1ZduRDNKSPJQn4NvCI7a+1PNXI/u2Ub7/178DM0gEopzR9nZeWcPhSj5u0ZCRdRvGuHooror/ftHwl3QVspFhV8CTwBeCnwA+B1wJ/BW603YgvOjvku5Hi476B48DNLWPcA0vStRQ/fHSQ4pfwAD5LMa7duP6dJ98t9FH/DlTBj4iIxRukIZ2IiOhCCn5ExJBIwY+IGBIp+BERQyIFPyJiSKTgR0QMiRT8iIghkYIfETEk/getTpto5Zz1CwAAAABJRU5ErkJggg==\n", 221 | "text/plain": [ 222 | "
" 223 | ] 224 | }, 225 | "metadata": { 226 | "tags": [], 227 | "needs_background": "light" 228 | } 229 | } 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "metadata": { 235 | "colab": { 236 | "base_uri": "https://localhost:8080/", 237 | "height": 282 238 | }, 239 | "id": "NNSBp4GSAjUj", 240 | "outputId": "67f6af3b-33d8-411a-b0a4-09096771e1a0" 241 | }, 242 | "source": [ 243 | "plt.imshow(V, cmap='viridis')" 244 | ], 245 | "execution_count": 55, 246 | "outputs": [ 247 | { 248 | "output_type": "execute_result", 249 | "data": { 250 | "text/plain": [ 251 | "" 252 | ] 253 | }, 254 | "metadata": { 255 | "tags": [] 256 | }, 257 | "execution_count": 55 258 | }, 259 | { 260 | "output_type": "display_data", 261 | "data": { 262 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAARe0lEQVR4nO3dfWxV530H8O8Xv2BjzIvBGMJbIKXRKFuAemwKUZSsL6KpJuhURaFTRNdoVEqjNW00NcomNfsvm1aibGoz0YJCozRVpyYKW1EaSqPRaGsWh1FeAimEQMEF28QQG+N3//aHTyo38fM7N/f1wPP9SMjX9/G558exvz7X93ef89DMICLXvymVLkBEykNhF4mEwi4SCYVdJBIKu0gkqsu6s2kNVjOjqZy7LBpzfi1abUpHI+1X6mja3pky7uw/bdNS8/Y/lrLtWErxacM14R3YqL8xR9LG/X1XynBPN0au9k1afEFhJ7kBwBMAqgB8z8we876+ZkYTlm/5eiG7rJiRaeGxgYXD7ras89NsV/1vQ9oPnlU7Ya9OSVTaL4OUfaeqcWob8n8LVvdUueNjKb9kq+ZfDY4N99W629ZeqPHHL1f6t+jkTu3aFhzL+2k8ySoA3wbwGQArAWwmuTLfxxOR0irkb/Z1AE6a2SkzGwLwQwAbi1OWiBRbIWFfCODshM/PJff9HpJbSbaRbBvt7ytgdyJSiJK/Gm9m282s1cxaq+obSr07EQkoJOztABZP+HxRcp+IZFAhYX8NwAqSy0jWArgHwO7ilCUixZZ3683MRkg+AOCnGG+97TSzo0WrLGOqnZcb5rzqH8bpF/zHrur3W3PDjX4Lqn9OeHx4ul/b4Gx3GIPNfm2cOeSO19aFG9LV1f5jX6lx+p2A+/YCAJj3Qnj7mqt+S3LM6dEDwLvL/O9JFhXUZzezPQD2FKkWESkhvV1WJBIKu0gkFHaRSCjsIpFQ2EUiobCLRKKs89mvZYNzw03dmtu73W07h1KmS+6f4Y7PftPvZc85FJ7KCfpTMYdm+lM9r873ax9oqnfHx5yfsMHZKY3yOSk9/mH/XHVxTfj/PmXY75NbygzWqZf88SzSmV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQq23HBWy/OXwkH+YBz7it5j6bvDbX9UDTvss7XLNKS2mtCu4elN/AaBqMDxW946/86oB/7g1nvVrm/Mfx8ODzf4lzS+vaXbHe5dk
8+qyHp3ZRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIqM+eoxlvh8caU6aojt7s98mvrO13x+uXDLjj/cdnBcdmnnQ3xew3/ceG+b3srrX+5Z77W8Lbj6T8v2bOdKbuAugbneOOV3/yo8GxnqX+FNe+Zf57HxpOu8OZpDO7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJ9dlz1HtjeKzbuWQxADQd8HvVH/m239N9e1O4jw4AXBruR78zx79U9NAM/1LQc476l7Ge97rfC7+yqC441tfl7/vSSv/Hs2Hdu+5414zw+x8a2t1NseQ//QsBvLMqsiWbSZ4G0AtgFMCImbUWoygRKb5inNnvNLOLRXgcESkh/c0uEolCw24AXiL5Osmtk30Bya0k20i2jfanXLBMREqm0Kfxt5lZO8l5APaSPG5m+yd+gZltB7AdAOrnLy7kuo0iUoCCzuxm1p587ATwPIB1xShKRIov77CTbCDZ+N5tAJ8GcKRYhYlIcRXyNL4FwPMcXxK4GsAPzOzFolSVQVWD4V76yFDKssiz0sanuuNNh/2/frrHwnPKp9T52475U+3BlD+8Lq/we+UX7wxfOL6uwe/hz/q5f52Aed856o5f+NqtwbGem/w++tAM/8BMGXGHMynvsJvZKQC3FLEWESkhtd5EIqGwi0RCYReJhMIuEgmFXSQSmuKao8Hm8DTUqj7/d+bUS37/ylJmSw7f0+2ON/x0bnDMWzIZAGp7/RbU6FT//3bDX51yx7t+vSQ4tuRf3U3RtdY/bh1/E26tAcBIeHYtqvv8dmhdt7/voRlasllEMkphF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpFQnz1HjSfCzfDZJ4bdbX97r9/s7j3R4I7zQLiPDgDmLIvc0uZfprpnif8jUJfyHoD2p5b7X3BreC7oiS9Mdzdd8UyPO356oz8FdnBR+PuycI//H5v5y3Pu+G82L3XHs0hndpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEuqz5+jq/HAve6Tev+xw48v+YZ562Z87Pe9+f854/zdagmMj0/zaajdccscvHp/jji/bPeCOD80KX2q6b6E/lz6tjz79jDsMjoWXq+4JT7MHAFxa4ffRr8VLSevMLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQn32HC38RbixeuUG/zBeWun30Wvf9X/nDv7AnzPeVBPudZ/6vD9ve03jZXd87j90uOMnH/mYO17jTEmfdcy/9vqVlCnjPSv88RlvhY/7QJO/7yn+JQqAa++y8elndpI7SXaSPDLhviaSe0meSD7OLm2ZIlKoXJ7GPwVgw/vuexjAPjNbAWBf8rmIZFhq2M1sP4D3rz+0EcCu5PYuAJuKXJeIFFm+L9C1mNn55PYFAME3Z5PcSrKNZNtof1+euxORQhX8aryZGYDgKyFmtt3MWs2starev7CiiJROvmHvILkAAJKPncUrSURKId+w7wawJbm9BcALxSlHREoltc9O8lkAdwCYS/IcgG8CeAzAj0jeB+AMgLtLWWQW2BSnsepPy8ZYg3/t9ub/8h/g3aX+t+nMZ8MLkTe+5TeET77lN6v7/95/j8D8X/q1D3wpPF+++W/9x+6+vdEdb9o/1R2/7Ly/YemeIXfbU1/wj9v04+G58lmVGnYz2xwY+kSRaxGREtLbZUUiobCLREJhF4mEwi4SCYVdJBKa4pqjofvfCY7Ne9Bv0wx81m8RjdT5LaaRlDcezv+fcGvv0hevuNtOe26mOz6YMp+xc61/vmj53qzg2Om/8H/8eMlv603xO5qo6Q3XdnGV/z2ZedB/7NFwtzOzdGYXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhPnuunm4ODnWt9/vsV3/r94trZ6Vc1jhleeCLq8Lfxmk/8fvoV1v8fQ/P8JvZcw/454uOj4cvZT37uH9cOue5w+if69e+9Ce9wbFbth92t/3Vl/xL
ZLd/0j+uWaQzu0gkFHaRSCjsIpFQ2EUiobCLREJhF4mEwi4SCfXZc/QnD7UFx175tz92t134c/+xu//SWdcYQPOOae74+fXhb2P1QMrlmpf5ve6m//PPB0y5jPZQc7hPXxs+pAAAS7lac9WgX9vZT4WvE3D23291t63/uH/crkU6s4tEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVCfPUe7j/xRcKx5wN+2/U5/fNHT/oXhu//A/zaNNISb3UON/pzvmsv++GidPw7z+9FVveH57F23+I998w7//Qe/edgdxpJ/DD/+25umu9vWfb7DHR/cN9/feQalntlJ7iTZSfLIhPseJdlO8mDy767SlikihcrlafxTADZMcv/jZrY6+benuGWJSLGlht3M9gPoLkMtIlJChbxA9wDJQ8nT/OCKYCS3kmwj2Tba31fA7kSkEPmG/UkANwFYDeA8gG+FvtDMtptZq5m1VtWnrFAoIiWTV9jNrMPMRs1sDMB3AawrblkiUmx5hZ3kggmffg7AkdDXikg2pPbZST4L4A4Ac0meA/BNAHeQXA3AAJwG8OUS1pgJ046FF+S+tNLvNbf8t99PPvdn/r5revzHv/nJi8Gxnif8CeczH29yx8+vr3HHF7/U744P/flgcKzqZ/7i773L/V744Dl3GB3rwueyqqv+tr0vpvTR/cOSSalhN7PNk9y9owS1iEgJ6e2yIpFQ2EUiobCLREJhF4mEwi4SCVrKFMViqp+/2JZv+XrZ9jdR2iWPDz30HXf8Dx+/v4jViJTGqV3b0H/h7KS9Xp3ZRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIXDeXkj78tcL65Oqjy/VOZ3aRSCjsIpFQ2EUiobCLREJhF4mEwi4SCYVdJBLXVJ/d66WrTy7i05ldJBIKu0gkFHaRSCjsIpFQ2EUiobCLREJhF4lEpvrshc5JF5Gw1DM7ycUkXyb5BsmjJL+a3N9Eci/JE8lHf7FtEamoXJ7GjwB4yMxWAvhTAF8huRLAwwD2mdkKAPuSz0Uko1LDbmbnzexAcrsXwDEACwFsBLAr+bJdADaVqkgRKdyHeoGO5I0A1gB4FUCLmZ1Phi4AaAlss5VkG8m20f6+AkoVkULkHHaS0wH8GMCDZtYzcczGV4ecdIVIM9tuZq1m1lpV31BQsSKSv5zCTrIG40F/xsyeS+7uILkgGV8AoLM0JYpIMaS23kgSwA4Ax8xs24Sh3QC2AHgs+fhC2mN9rKUL/6tpqiIVkUuffT2AewEcJnkwue8RjIf8RyTvA3AGwN2lKVFEiiE17Gb2CoBJF3cH8IniliMipaK3y4pEQmEXiYTCLhIJhV0kEgq7SCTKOsX1aEezeulyXbgWp2PrzC4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUiobCLRCJTl5IWuVak9dGz2IfXmV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiYT67CIlUEgfvlQ9eJ3ZRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFI5LI++2IA3wfQAsAAbDezJ0g+CuCvAXQlX/qIme0pVaEi1xOvl16qufC5vKlmBMBDZnaAZCOA10nuTcYeN7N/zmvPIlJWuazPfh7A+eR2L8ljABaWujARKa4P9Tc7yRsBrAHwanLXAyQPkdxJcnZgm60k20i2jfb3FVSsiOQv57CTnA7gxwAeNLMeAE8CuAnAaoyf+b812XZmtt3MWs2staq+oQgli0g+cgo7yRqMB/0ZM3sOAMysw8xGzWwMwHcBrCtdmSJSqNSwkySAHQCOmdm2CfcvmPBlnwNwpPjliUix5PJq/HoA9wI4TPJgct8jADaTXI3xdtxpAF8uSYUikSlkeuy6F7uCY7m8Gv8KAE4ypJ66yDVE76ATiYTCLhIJhV0kEgq7SCQUdpFIKOwikYjmUtJp0wZX/Yvf2+RoMasRyZ/Xhz/VsS04pjO7SCQUdpFIKOwikVDYRSKh
sItEQmEXiYTCLhIJmln5dkZ2ATgz4a65AC6WrYAPJ6u1ZbUuQLXlq5i1LTWz5skGyhr2D+ycbDOz1ooV4MhqbVmtC1Bt+SpXbXoaLxIJhV0kEpUO+/YK79+T1dqyWheg2vJVltoq+je7iJRPpc/sIlImCrtIJCoSdpIbSL5J8iTJhytRQwjJ0yQPkzxIsq3Ctewk2UnyyIT7mkjuJXki+TjpGnsVqu1Rku3JsTtI8q4K1baY5Msk3yB5lORXk/sreuycuspy3Mr+NzvJKgC/BvApAOcAvAZgs5m9UdZCAkieBtBqZhV/AwbJ2wFcAfB9M1uV3PdPALrN7LHkF+VsM/tGRmp7FMCVSi/jnaxWtGDiMuMANgH4Iip47Jy67kYZjlslzuzrAJw0s1NmNgTghwA2VqCOzDOz/QC633f3RgC7ktu7MP7DUnaB2jLBzM6b2YHkdi+A95YZr+ixc+oqi0qEfSGAsxM+P4dsrfduAF4i+TrJrZUuZhItZnY+uX0BQEsli5lE6jLe5fS+ZcYzc+zyWf68UHqB7oNuM7O1AD4D4CvJ09VMsvG/wbLUO81pGe9ymWSZ8d+p5LHLd/nzQlUi7O0AFk/4fFFyXyaYWXvysRPA88jeUtQd762gm3zsrHA9v5OlZbwnW2YcGTh2lVz+vBJhfw3ACpLLSNYCuAfA7grU8QEkG5IXTkCyAcCnkb2lqHcD2JLc3gLghQrW8nuysox3aJlxVPjYVXz5czMr+z8Ad2H8Ffm3APxdJWoI1LUcwK+Sf0crXRuAZzH+tG4Y469t3AdgDoB9AE4A+BmApgzV9jSAwwAOYTxYCypU220Yf4p+CMDB5N9dlT52Tl1lOW56u6xIJPQCnUgkFHaRSCjsIpFQ2EUiobCLREJhF4mEwi4Sif8Hwv4m0oErNXcAAAAASUVORK5CYII=\n", 263 | "text/plain": [ 264 | "
" 265 | ] 266 | }, 267 | "metadata": { 268 | "tags": [], 269 | "needs_background": "light" 270 | } 271 | } 272 | ] 273 | } 274 | ] 275 | } 276 | -------------------------------------------------------------------------------- /modern_approach/forward_forward/forward_forward_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyOGnxSridVLcFl/Z7lVc4J1", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU", 18 | "gpuClass": "standard" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "source": [ 34 | "#Forward-Forward propagation" 35 | ], 36 | "metadata": { 37 | "id": "ZAovmTs66ICD" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": { 44 | "id": "gYFTLW4K5qay" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "import torch\n", 49 | "import torch.nn as nn\n", 50 | "from tqdm import tqdm\n", 51 | "from torch.optim import Adam\n", 52 | "from torchvision.datasets import MNIST\n", 53 | "from torchvision.transforms import Compose, ToTensor, Normalize, Lambda\n", 54 | "from torch.utils.data import DataLoader" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "source": [ 60 | "def MNIST_loaders(train_batch_size=50000, test_batch_size=10000):\n", 61 | " transform = Compose([\n", 62 | " ToTensor(),\n", 63 | " Normalize((0.1307,), (0.3081,)),\n", 64 | " Lambda(lambda x: torch.flatten(x))\n", 65 | " ])\n", 66 | " train_loader = DataLoader(\n", 67 | " MNIST('./data/', train=True, download=True, transform=transform),\n", 68 | " 
batch_size=train_batch_size, shuffle=True\n", 69 | " )\n", 70 | " test_loader = DataLoader(\n", 71 | " MNIST('./data/', train=False, download=True, transform=transform),\n", 72 | " batch_size=test_batch_size, shuffle=False\n", 73 | " )\n", 74 | " return train_loader, test_loader\n", 75 | "\n", 76 | "def overlay_y_on_x(x, y):\n", 77 | " x_ = x.clone()\n", 78 | " x_[:, :10] *= 0.0\n", 79 | " x_[range(x.shape[0]), y] = x.max()\n", 80 | " return x_" 81 | ], 82 | "metadata": { 83 | "id": "BeQ6BjzS9JLZ" 84 | }, 85 | "execution_count": 3, 86 | "outputs": [] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "source": [ 91 | "class Net(torch.nn.Module):\n", 92 | " def __init__(self, dims):\n", 93 | " super().__init__()\n", 94 | " self.layers = []\n", 95 | " for d in range(len(dims) - 1):\n", 96 | " self.layers += [Layer(dims[d], dims[d + 1]).cuda()]\n", 97 | "\n", 98 | " def predict(self, x):\n", 99 | " goodness_per_label = []\n", 100 | " for label in range(10):\n", 101 | " h = overlay_y_on_x(x, label)\n", 102 | " goodness = []\n", 103 | " for layer in self.layers:\n", 104 | " h = layer(h)\n", 105 | " goodness += [h.pow(2).mean(1)]\n", 106 | " goodness_per_label += [sum(goodness).unsqueeze(1)]\n", 107 | " goodness_per_label = torch.cat(goodness_per_label, 1)\n", 108 | " return goodness_per_label.argmax(1)\n", 109 | "\n", 110 | " def train(self, x_pos, x_neg):\n", 111 | " h_pos, h_neg = x_pos, x_neg\n", 112 | " for i, layer in enumerate(self.layers):\n", 113 | " print('Training layer', i, '...')\n", 114 | " h_pos, h_neg = layer.train(h_pos, h_neg)" 115 | ], 116 | "metadata": { 117 | "id": "qRvPhl89--ap" 118 | }, 119 | "execution_count": 4, 120 | "outputs": [] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "source": [ 125 | "class Layer(nn.Linear):\n", 126 | " def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):\n", 127 | " super().__init__(in_features, out_features, bias, device, dtype)\n", 128 | " self.relu = torch.nn.ReLU()\n", 129 | " self.opt 
= Adam(self.parameters(), lr=0.03)\n", 130 | " self.threshold = 2.0\n", 131 | " self.epochs = 1000\n", 132 | "\n", 133 | " def forward(self, x):\n", 134 | " x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)\n", 135 | " return self.relu(\n", 136 | " torch.mm(x_direction, self.weight.T) +\n", 137 | " self.bias.unsqueeze(0)\n", 138 | " )\n", 139 | "\n", 140 | " def train(self, x_pos, x_neg):\n", 141 | " for i in tqdm(range(self.epochs)):\n", 142 | " g_pos = self.forward(x_pos).pow(2).mean(1)\n", 143 | " g_neg = self.forward(x_neg).pow(2).mean(1)\n", 144 | " loss = torch.log(1 + torch.exp(torch.cat([\n", 145 | " -g_pos + self.threshold,\n", 146 | " g_neg - self.threshold\n", 147 | " ]))).mean()\n", 148 | " self.opt.zero_grad()\n", 149 | " loss.backward()\n", 150 | " self.opt.step()\n", 151 | " return self.forward(x_pos).detach(), self.forward(x_neg).detach()" 152 | ], 153 | "metadata": { 154 | "id": "MC3qw48jCR4P" 155 | }, 156 | "execution_count": 5, 157 | "outputs": [] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "source": [ 162 | "torch.manual_seed(42)\n", 163 | "train_loader, test_loader = MNIST_loaders()\n", 164 | "\n", 165 | "net = Net([784, 500, 500])\n", 166 | "x, y = next(iter(train_loader))\n", 167 | "x, y = x.cuda(), y.cuda()\n", 168 | "x_pos = overlay_y_on_x(x, y)\n", 169 | "rnd = torch.randperm(x.size(0))\n", 170 | "x_neg = overlay_y_on_x(x, y[rnd])\n", 171 | "net.train(x_pos, x_neg)\n", 172 | "\n", 173 | "print('Train error:', 1.0 - net.predict(x).eq(y).float().mean().item())\n", 174 | "\n", 175 | "x_test, y_test = next(iter(test_loader))\n", 176 | "x_test, y_test = x_test.cuda(), y_test.cuda()\n", 177 | "\n", 178 | "print('Test error:', 1.0 - net.predict(x_test).eq(y_test).float().mean().item())" 179 | ], 180 | "metadata": { 181 | "colab": { 182 | "base_uri": "https://localhost:8080/" 183 | }, 184 | "id": "YPfnhB-CE1fR", 185 | "outputId": "4a97be71-9d65-43ac-b568-760ee6abe41f" 186 | }, 187 | "execution_count": 7, 188 | "outputs": [ 189 | { 190 | 
"output_type": "stream", 191 | "name": "stdout", 192 | "text": [ 193 | "Training layer 0 ...\n" 194 | ] 195 | }, 196 | { 197 | "output_type": "stream", 198 | "name": "stderr", 199 | "text": [ 200 | "100%|██████████| 1000/1000 [00:59<00:00, 16.67it/s]\n" 201 | ] 202 | }, 203 | { 204 | "output_type": "stream", 205 | "name": "stdout", 206 | "text": [ 207 | "Training layer 1 ...\n" 208 | ] 209 | }, 210 | { 211 | "output_type": "stream", 212 | "name": "stderr", 213 | "text": [ 214 | "100%|██████████| 1000/1000 [00:39<00:00, 25.17it/s]\n" 215 | ] 216 | }, 217 | { 218 | "output_type": "stream", 219 | "name": "stdout", 220 | "text": [ 221 | "Train error: 0.07084000110626221\n", 222 | "Test error: 0.06929999589920044\n" 223 | ] 224 | } 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "source": [ 230 | "Credits: Mohammad Pezeshki (https://github.com/mohammadpz)" 231 | ], 232 | "metadata": { 233 | "id": "iR-dwsMj8zzC" 234 | } 235 | } 236 | ] 237 | } -------------------------------------------------------------------------------- /modern_approach/transformer/self_attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "self_attention.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyPmS2YULIuLDzTcuwBBAsob", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "yOFn0WjjXr2B" 35 | }, 36 | "source": [ 37 | "#Self-attention" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "tyk3ybvYXnHu" 44 | 
}, 45 | "source": [ 46 | "import numpy as n\n", 47 | "import tensorflow as tf" 48 | ], 49 | "execution_count": 18, 50 | "outputs": [] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "XtRRq7wmX7ys" 56 | }, 57 | "source": [ 58 | "##Krok 1. Przygotowanie wejść" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "metadata": { 64 | "id": "P0c54R30YEZk" 65 | }, 66 | "source": [ 67 | "x = [[1, 0, 1, 0], # wejście 1\n", 68 | " [0, 2, 0, 2], # wejście 2\n", 69 | " [1, 1, 1, 1]] # wejście 3\n", 70 | "\n", 71 | "x = np.array(x, dtype=np.float32)" 72 | ], 73 | "execution_count": 3, 74 | "outputs": [] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": { 79 | "id": "cire_KEMYkIt" 80 | }, 81 | "source": [ 82 | "##Krok 2. Inicjalizacja wag" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "metadata": { 88 | "id": "MzQnEQyHYmpN" 89 | }, 90 | "source": [ 91 | "w_key = [[0, 0, 1],\n", 92 | " [1, 1, 0],\n", 93 | " [0, 1, 0],\n", 94 | " [1, 1, 0]]\n", 95 | "\n", 96 | "w_query = [[1, 0, 1],\n", 97 | " [1, 0, 0],\n", 98 | " [0, 0, 1],\n", 99 | " [0, 1, 1]]\n", 100 | "\n", 101 | "w_value = [[0, 2, 1],\n", 102 | " [0, 3, 0],\n", 103 | " [1, 0, 3],\n", 104 | " [1, 1, 0]]\n", 105 | "\n", 106 | "w_key = np.array(w_key, dtype=np.float32)\n", 107 | "w_query = np.array(w_query, dtype=np.float32)\n", 108 | "w_value = np.array(w_value, dtype=np.float32)" 109 | ], 110 | "execution_count": 4, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "ecF6Vrq1b3kk" 117 | }, 118 | "source": [ 119 | "##Krok 3. 
Wyznaczenie key, query i value" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "metadata": { 125 | "colab": { 126 | "base_uri": "https://localhost:8080/" 127 | }, 128 | "id": "fmspQnWfcFHH", 129 | "outputId": "4f9d0877-c4da-4eae-d373-fcaa129f72c7" 130 | }, 131 | "source": [ 132 | "keys = x @ w_key\n", 133 | "querys = x @ w_query\n", 134 | "values = x @ w_value\n", 135 | "\n", 136 | "print('Keys: \\n', keys)\n", 137 | "print('Querys: \\n', querys)\n", 138 | "print('Values: \\n', values)" 139 | ], 140 | "execution_count": 5, 141 | "outputs": [ 142 | { 143 | "output_type": "stream", 144 | "text": [ 145 | "Keys: \n", 146 | " [[0. 1. 1.]\n", 147 | " [4. 4. 0.]\n", 148 | " [2. 3. 1.]]\n", 149 | "Querys: \n", 150 | " [[1. 0. 2.]\n", 151 | " [2. 2. 2.]\n", 152 | " [2. 1. 3.]]\n", 153 | "Values: \n", 154 | " [[1. 2. 4.]\n", 155 | " [2. 8. 0.]\n", 156 | " [2. 6. 4.]]\n" 157 | ], 158 | "name": "stdout" 159 | } 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": { 165 | "id": "hTu3f5PSfLQ-" 166 | }, 167 | "source": [ 168 | "##Krok 4. Obliczenie attention scores" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "metadata": { 174 | "colab": { 175 | "base_uri": "https://localhost:8080/" 176 | }, 177 | "id": "wHW48xpYfcZF", 178 | "outputId": "0c24d183-43c5-46e3-a46c-45c3c4b7ef83" 179 | }, 180 | "source": [ 181 | "attn_scores = querys @ keys.T\n", 182 | "print(attn_scores)" 183 | ], 184 | "execution_count": 14, 185 | "outputs": [ 186 | { 187 | "output_type": "stream", 188 | "text": [ 189 | "[[ 2. 4. 4.]\n", 190 | " [ 4. 16. 12.]\n", 191 | " [ 4. 12. 10.]]\n" 192 | ], 193 | "name": "stdout" 194 | } 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": { 200 | "id": "RPvnjAmmgHUO" 201 | }, 202 | "source": [ 203 | "##Krok 5. 
Obliczenie softmax" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "metadata": { 209 | "colab": { 210 | "base_uri": "https://localhost:8080/" 211 | }, 212 | "id": "qAW5QWTIgT_T", 213 | "outputId": "4932d038-e787-4530-d5ec-75fe794493fc" 214 | }, 215 | "source": [ 216 | "attn_scores_softmax = np.round_(tf.nn.softmax(attn_scores, axis=-1), decimals=1)\n", 217 | "print(attn_scores_softmax)" 218 | ], 219 | "execution_count": 30, 220 | "outputs": [ 221 | { 222 | "output_type": "stream", 223 | "text": [ 224 | "[[0.1 0.5 0.5]\n", 225 | " [0. 1. 0. ]\n", 226 | " [0. 0.9 0.1]]\n" 227 | ], 228 | "name": "stdout" 229 | } 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "id": "Jd2bE_wzkz-6" 236 | }, 237 | "source": [ 238 | "##Krok 6. Mnożenie scores i values" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "metadata": { 244 | "colab": { 245 | "base_uri": "https://localhost:8080/" 246 | }, 247 | "id": "OsRCJnFFlKjb", 248 | "outputId": "66df39f5-7183-48ca-8d75-ab03b3b5159a" 249 | }, 250 | "source": [ 251 | "weighted_values = values[:, None] * attn_scores_softmax.T[:, :, None]\n", 252 | "print(weighted_values)" 253 | ], 254 | "execution_count": 31, 255 | "outputs": [ 256 | { 257 | "output_type": "stream", 258 | "text": [ 259 | "[[[0.1 0.2 0.4]\n", 260 | " [0. 0. 0. ]\n", 261 | " [0. 0. 0. ]]\n", 262 | "\n", 263 | " [[1. 4. 0. ]\n", 264 | " [2. 8. 0. ]\n", 265 | " [1.8 7.2 0. ]]\n", 266 | "\n", 267 | " [[1. 3. 2. ]\n", 268 | " [0. 0. 0. ]\n", 269 | " [0.2 0.6 0.4]]]\n" 270 | ], 271 | "name": "stdout" 272 | } 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": { 278 | "id": "toP8zcCLl7Yo" 279 | }, 280 | "source": [ 281 | "##Krok 7. 
Suma zważonych wartości" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "metadata": { 287 | "colab": { 288 | "base_uri": "https://localhost:8080/" 289 | }, 290 | "id": "1G0hGuCMmDTs", 291 | "outputId": "16140ec9-885f-4f4d-8247-c056b6d0a503" 292 | }, 293 | "source": [ 294 | "outputs = np.sum(weighted_values, axis=0)\n", 295 | "print(outputs)" 296 | ], 297 | "execution_count": 32, 298 | "outputs": [ 299 | { 300 | "output_type": "stream", 301 | "text": [ 302 | "[[2.1 7.2 2.4 ]\n", 303 | " [2. 8. 0. ]\n", 304 | " [2. 7.7999997 0.4 ]]\n" 305 | ], 306 | "name": "stdout" 307 | } 308 | ] 309 | } 310 | ] 311 | } -------------------------------------------------------------------------------- /modern_approach/zero_shot_learning/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PsorTheDoctor/artificial-intelligence/82516ac57eb13f14e8214633a0960bea0cd9e0fb/modern_approach/zero_shot_learning/example.png -------------------------------------------------------------------------------- /neural_networks/CNN/pixelcnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pixelcnn.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyPfd36wg6CHF4lkcJIwlq3z", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "id": "WWCYgQAABKid" 36 | }, 37 | "source": [ 38 | "#PixelCNN\n", 39 | "##Import bibliotek" 40 | ] 41 | }, 42 | { 43 | "cell_type": 
"code", 44 | "metadata": { 45 | "id": "DTLjE01TAf4J" 46 | }, 47 | "source": [ 48 | "import numpy as np\n", 49 | "import tensorflow as tf\n", 50 | "from tensorflow import keras\n", 51 | "from tensorflow.keras import layers\n", 52 | "from tqdm import tqdm" 53 | ], 54 | "execution_count": 1, 55 | "outputs": [] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "id": "Lug7PJ_xBQI9" 61 | }, 62 | "source": [ 63 | "##Załadowanie danych" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/" 71 | }, 72 | "id": "YI2AIzVEBIrW", 73 | "outputId": "4881db92-4ec5-4374-f97c-7ac8e33f7a0d" 74 | }, 75 | "source": [ 76 | "num_classes = 10\n", 77 | "input_shape = (28, 28, 1)\n", 78 | "n_residual_blocks = 5\n", 79 | "# Podział danych na traningowe i testowe\n", 80 | "(x, _), (y, _) = keras.datasets.mnist.load_data()\n", 81 | "# Zaokrąglenie wszystkich pikseli mniejszych od 33% z 256 do 0\n", 82 | "# Wszystko powyżej tej wartości zostanie zaokrąglone do 1, więc wartości\n", 83 | "# będą równe 0 lub 1\n", 84 | "data = np.concatenate((x, y), axis=0)\n", 85 | "data = np.where(data < (0.33 * 256), 0, 1)\n", 86 | "data = data.astype(np.float32)" 87 | ], 88 | "execution_count": 2, 89 | "outputs": [ 90 | { 91 | "output_type": "stream", 92 | "text": [ 93 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", 94 | "11493376/11490434 [==============================] - 0s 0us/step\n" 95 | ], 96 | "name": "stdout" 97 | } 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": { 103 | "id": "NdhDl5CNKdNm" 104 | }, 105 | "source": [ 106 | "##Stworzenie warstw modelu" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "metadata": { 112 | "id": "R8F0wRltBLGc" 113 | }, 114 | "source": [ 115 | "class PixelConvLayer(layers.Layer):\n", 116 | " def __init__(self, mask_type, **kwargs):\n", 117 | " super(PixelConvLayer, self).__init__()\n", 118 | " 
self.mask_type = mask_type\n", 119 | " self.conv = layers.Conv2D(**kwargs)\n", 120 | "\n", 121 | " def build(self, input_shape):\n", 122 | " self.conv.build(input_shape)\n", 123 | " kernel_shape = self.conv.kernel.get_shape()\n", 124 | " self.mask = np.zeros(shape=kernel_shape)\n", 125 | " self.mask[: kernel_shape[0] // 2, ...] = 1.0\n", 126 | " self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0\n", 127 | " if self.mask_type == 'B':\n", 128 | " self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0\n", 129 | "\n", 130 | " def call(self, inputs):\n", 131 | " self.conv.kernel.assign(self.conv.kernel * self.mask)\n", 132 | " return self.conv(inputs)\n", 133 | "\n", 134 | "\n", 135 | "class ResidualBlock(layers.Layer):\n", 136 | " def __init__(self, filters, **kwargs):\n", 137 | " super(ResidualBlock, self).__init__(kwargs);\n", 138 | " self.conv1 = layers.Conv2D(\n", 139 | " filters=filters, kernel_size=1, activation='relu'\n", 140 | " )\n", 141 | " self.pixel_conv = PixelConvLayer(\n", 142 | " mask_type='B',\n", 143 | " filters = filters // 2,\n", 144 | " kernel_size=3,\n", 145 | " activation='relu',\n", 146 | " padding='same'\n", 147 | " )\n", 148 | " self.conv2 = layers.Conv2D(\n", 149 | " filters=filters, kernel_size=1, activation='relu'\n", 150 | " )\n", 151 | "\n", 152 | " def call(self, inputs):\n", 153 | " x = self.conv1(inputs)\n", 154 | " x = self.pixel_conv(x)\n", 155 | " x = self.conv2(x)\n", 156 | " return layers.add([inputs, x])" 157 | ], 158 | "execution_count": 13, 159 | "outputs": [] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": { 164 | "id": "-6jAS8QuMhVt" 165 | }, 166 | "source": [ 167 | "##Budowa modelu" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "metadata": { 173 | "colab": { 174 | "base_uri": "https://localhost:8080/" 175 | }, 176 | "id": "98GvFMF4MkTz", 177 | "outputId": "fc7a0fd7-99c7-40e3-c416-ed6dd7a9eadf" 178 | }, 179 | "source": [ 180 | "inputs = keras.Input(shape=input_shape)\n", 
181 | "x = PixelConvLayer(\n", 182 | " mask_type='A', filters=128, kernel_size=7, activation='relu', padding='same'\n", 183 | ")(inputs)\n", 184 | "\n", 185 | "for _ in range(n_residual_blocks):\n", 186 | " x = ResidualBlock(filters=128)(x)\n", 187 | "\n", 188 | "for _ in range(2):\n", 189 | " x = PixelConvLayer(\n", 190 | " mask_type='B',\n", 191 | " filters=128,\n", 192 | " kernel_size=1,\n", 193 | " strides=1,\n", 194 | " activation='relu',\n", 195 | " padding='valid'\n", 196 | " )(x)\n", 197 | "\n", 198 | "out = layers.Conv2D(\n", 199 | " filters=1, kernel_size=1, strides=1, activation='sigmoid', padding='valid'\n", 200 | ")(x)\n", 201 | "\n", 202 | "pixel_cnn = keras.Model(inputs, out)\n", 203 | "adam = keras.optimizers.Adam(learning_rate=0.0005)\n", 204 | "pixel_cnn.compile(optimizer=adam, loss='binary_crossentropy')\n", 205 | "\n", 206 | "pixel_cnn.summary()\n", 207 | "pixel_cnn.fit(\n", 208 | " x=data, y=data, batch_size=128, epochs=50, validation_split=0.1, verbose=2\n", 209 | ")" 210 | ], 211 | "execution_count": 14, 212 | "outputs": [ 213 | { 214 | "output_type": "stream", 215 | "text": [ 216 | "Model: \"model\"\n", 217 | "_________________________________________________________________\n", 218 | "Layer (type) Output Shape Param # \n", 219 | "=================================================================\n", 220 | "input_6 (InputLayer) [(None, 28, 28, 1)] 0 \n", 221 | "_________________________________________________________________\n", 222 | "pixel_conv_layer_5 (PixelCon (None, 28, 28, 128) 6400 \n", 223 | "_________________________________________________________________\n", 224 | "residual_block_1 (ResidualBl (None, 28, 28, 128) 98624 \n", 225 | "_________________________________________________________________\n", 226 | "residual_block_2 (ResidualBl (None, 28, 28, 128) 98624 \n", 227 | "_________________________________________________________________\n", 228 | "residual_block_3 (ResidualBl (None, 28, 28, 128) 98624 \n", 229 | 
"_________________________________________________________________\n", 230 | "residual_block_4 (ResidualBl (None, 28, 28, 128) 98624 \n", 231 | "_________________________________________________________________\n", 232 | "residual_block_5 (ResidualBl (None, 28, 28, 128) 98624 \n", 233 | "_________________________________________________________________\n", 234 | "pixel_conv_layer_11 (PixelCo (None, 28, 28, 128) 16512 \n", 235 | "_________________________________________________________________\n", 236 | "pixel_conv_layer_12 (PixelCo (None, 28, 28, 128) 16512 \n", 237 | "_________________________________________________________________\n", 238 | "conv2d_24 (Conv2D) (None, 28, 28, 1) 129 \n", 239 | "=================================================================\n", 240 | "Total params: 532,673\n", 241 | "Trainable params: 39,553\n", 242 | "Non-trainable params: 493,120\n", 243 | "_________________________________________________________________\n", 244 | "Epoch 1/50\n", 245 | "493/493 - 77s - loss: 0.1387 - val_loss: 0.0975\n", 246 | "Epoch 2/50\n", 247 | "493/493 - 46s - loss: 0.0965 - val_loss: 0.0954\n", 248 | "Epoch 3/50\n", 249 | "493/493 - 46s - loss: 0.0949 - val_loss: 0.0946\n", 250 | "Epoch 4/50\n", 251 | "493/493 - 47s - loss: 0.0942 - val_loss: 0.0939\n", 252 | "Epoch 5/50\n", 253 | "493/493 - 47s - loss: 0.0938 - val_loss: 0.0941\n", 254 | "Epoch 6/50\n", 255 | "493/493 - 47s - loss: 0.0935 - val_loss: 0.0936\n", 256 | "Epoch 7/50\n", 257 | "493/493 - 48s - loss: 0.0932 - val_loss: 0.0932\n", 258 | "Epoch 8/50\n", 259 | "493/493 - 48s - loss: 0.0930 - val_loss: 0.0931\n", 260 | "Epoch 9/50\n", 261 | "493/493 - 48s - loss: 0.0929 - val_loss: 0.0934\n", 262 | "Epoch 10/50\n", 263 | "493/493 - 48s - loss: 0.0928 - val_loss: 0.0930\n", 264 | "Epoch 11/50\n", 265 | "493/493 - 48s - loss: 0.0927 - val_loss: 0.0928\n", 266 | "Epoch 12/50\n", 267 | "493/493 - 48s - loss: 0.0926 - val_loss: 0.0927\n", 268 | "Epoch 13/50\n", 269 | "493/493 - 48s - loss: 0.0925 - 
val_loss: 0.0928\n", 270 | "Epoch 14/50\n", 271 | "493/493 - 48s - loss: 0.0925 - val_loss: 0.0926\n", 272 | "Epoch 15/50\n", 273 | "493/493 - 48s - loss: 0.0924 - val_loss: 0.0925\n", 274 | "Epoch 16/50\n", 275 | "493/493 - 48s - loss: 0.0924 - val_loss: 0.0927\n", 276 | "Epoch 17/50\n", 277 | "493/493 - 48s - loss: 0.0922 - val_loss: 0.0927\n", 278 | "Epoch 18/50\n", 279 | "493/493 - 48s - loss: 0.0922 - val_loss: 0.0932\n", 280 | "Epoch 19/50\n", 281 | "493/493 - 48s - loss: 0.0922 - val_loss: 0.0923\n", 282 | "Epoch 20/50\n", 283 | "493/493 - 48s - loss: 0.0921 - val_loss: 0.0923\n", 284 | "Epoch 21/50\n", 285 | "493/493 - 48s - loss: 0.0921 - val_loss: 0.0922\n", 286 | "Epoch 22/50\n", 287 | "493/493 - 48s - loss: 0.0921 - val_loss: 0.0924\n", 288 | "Epoch 23/50\n", 289 | "493/493 - 48s - loss: 0.0920 - val_loss: 0.0922\n", 290 | "Epoch 24/50\n", 291 | "493/493 - 48s - loss: 0.0920 - val_loss: 0.0925\n", 292 | "Epoch 25/50\n", 293 | "493/493 - 48s - loss: 0.0920 - val_loss: 0.0921\n", 294 | "Epoch 26/50\n", 295 | "493/493 - 48s - loss: 0.0919 - val_loss: 0.0933\n", 296 | "Epoch 27/50\n", 297 | "493/493 - 48s - loss: 0.0919 - val_loss: 0.0920\n", 298 | "Epoch 28/50\n", 299 | "493/493 - 48s - loss: 0.0918 - val_loss: 0.0919\n", 300 | "Epoch 29/50\n", 301 | "493/493 - 48s - loss: 0.0918 - val_loss: 0.0920\n", 302 | "Epoch 30/50\n", 303 | "493/493 - 48s - loss: 0.0918 - val_loss: 0.0920\n", 304 | "Epoch 31/50\n", 305 | "493/493 - 48s - loss: 0.0918 - val_loss: 0.0923\n", 306 | "Epoch 32/50\n", 307 | "493/493 - 48s - loss: 0.0917 - val_loss: 0.0919\n", 308 | "Epoch 33/50\n", 309 | "493/493 - 48s - loss: 0.0917 - val_loss: 0.0920\n", 310 | "Epoch 34/50\n", 311 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0919\n", 312 | "Epoch 35/50\n", 313 | "493/493 - 48s - loss: 0.0917 - val_loss: 0.0918\n", 314 | "Epoch 36/50\n", 315 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0917\n", 316 | "Epoch 37/50\n", 317 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0919\n", 318 | 
"Epoch 38/50\n", 319 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0917\n", 320 | "Epoch 39/50\n", 321 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0922\n", 322 | "Epoch 40/50\n", 323 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0919\n", 324 | "Epoch 41/50\n", 325 | "493/493 - 48s - loss: 0.0916 - val_loss: 0.0917\n", 326 | "Epoch 42/50\n", 327 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0916\n", 328 | "Epoch 43/50\n", 329 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0917\n", 330 | "Epoch 44/50\n", 331 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0916\n", 332 | "Epoch 45/50\n", 333 | "493/493 - 48s - loss: 0.0915 - val_loss: 0.0918\n", 334 | "Epoch 46/50\n", 335 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0916\n", 336 | "Epoch 47/50\n", 337 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0916\n", 338 | "Epoch 48/50\n", 339 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0917\n", 340 | "Epoch 49/50\n", 341 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0919\n", 342 | "Epoch 50/50\n", 343 | "493/493 - 48s - loss: 0.0914 - val_loss: 0.0922\n" 344 | ], 345 | "name": "stdout" 346 | }, 347 | { 348 | "output_type": "execute_result", 349 | "data": { 350 | "text/plain": [ 351 | "" 352 | ] 353 | }, 354 | "metadata": { 355 | "tags": [] 356 | }, 357 | "execution_count": 14 358 | } 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": { 364 | "id": "N24P8Z9-T7XX" 365 | }, 366 | "source": [ 367 | "##Demonstracja" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "metadata": { 373 | "colab": { 374 | "base_uri": "https://localhost:8080/", 375 | "height": 146 376 | }, 377 | "id": "rw3u7OYUTxWp", 378 | "outputId": "d3fb76c3-c49f-44a3-e801-2a527a98343e" 379 | }, 380 | "source": [ 381 | "from IPython.display import Image, display\n", 382 | "\n", 383 | "batch = 4\n", 384 | "pixels = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])\n", 385 | "batch, rows, cols, channels = pixels.shape\n", 386 | "\n", 387 | "for row in tqdm(range(rows)):\n", 388 | " 
for col in range(cols):\n", 389 | " for channel in range(channels):\n", 390 | " \n", 391 | " probs = pixel_cnn.predict(pixels)[:, row, col, channel]\n", 392 | " \n", 393 | " pixels[:, row, col, channel] = tf.math.ceil(\n", 394 | " probs - tf.random.uniform(probs.shape)\n", 395 | " )\n", 396 | "\n", 397 | "def deprocess_image(x):\n", 398 | " x = np.stack((x, x, x), 2)\n", 399 | "\n", 400 | " x *= 255.0\n", 401 | "\n", 402 | " x = np.clip(x, 0, 255).astype('uint8')\n", 403 | " return x\n", 404 | "\n", 405 | "for i, pic in enumerate(pixels):\n", 406 | " keras.preprocessing.image.save_img(\n", 407 | " 'generated_image_{}.png'.format(i), deprocess_image(np.squeeze(pic, -1))\n", 408 | " )\n", 409 | "\n", 410 | "display(Image('generated_image_0.png'))\n", 411 | "display(Image('generated_image_1.png'))\n", 412 | "display(Image('generated_image_2.png'))\n", 413 | "display(Image('generated_image_3.png'))" 414 | ], 415 | "execution_count": 18, 416 | "outputs": [ 417 | { 418 | "output_type": "stream", 419 | "text": [ 420 | "100%|██████████| 28/28 [00:32<00:00, 1.18s/it]\n" 421 | ], 422 | "name": "stderr" 423 | }, 424 | { 425 | "output_type": "display_data", 426 | "data": { 427 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAAh0lEQVR4nO2UwQ7AIAhDwez/f5kdXKYSdKHgTvZmnG+1gERHR6litRaRtsd61w3tccMXfnTBvKx1qTUzV8vw3al3WimzEFwy7MRrZZ9RfiNR2HQgkC3V/x0Kd8Lm5p9VPHNMI1PwMUVYh5aUt05DF0T4H0OmWa4bNDEH4z2N4B4Iduy9lmniBgARMCy+4/TbAAAAAElFTkSuQmCC\n", 428 | "text/plain": [ 429 | "" 430 | ] 431 | }, 432 | "metadata": { 433 | "tags": [] 434 | } 435 | }, 436 | { 437 | "output_type": "display_data", 438 | "data": { 439 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAAZklEQVR4nO2USxYAIARF1f73rFlO6UM9M3fY5x5CREmS4Ch6iZlluywOXKkH4zNDIN34FmBHIkUZSaf/bxQp5CkHKTBxkaJcSSC+Ek2dt6vw3PwQHFL7gCykxsuH7g75+qzpu0a5AQqvIRRWxCD2AAAAAElFTkSuQmCC\n", 440 | "text/plain": [ 441 | "" 442 | ] 443 | }, 444 | "metadata": { 445 | "tags": [] 446 | } 447 | }, 448 | { 449 | "output_type": 
"display_data", 450 | "data": { 451 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAAbUlEQVR4nO2QMQ7AMAgD7Sj//zIdWlVtAqQKDB24ESUn20BRFEUSTLGIyKUjAfQs3ZOWaDxj4nt9NZFqVKTDOr53fvOSqkGsP0uaZXTua6nzM5Q0nTbEIXlftut3BGpaePV/tul82p7Sk8YnPgA8jCQivImO7AAAAABJRU5ErkJggg==\n", 452 | "text/plain": [ 453 | "" 454 | ] 455 | }, 456 | "metadata": { 457 | "tags": [] 458 | } 459 | }, 460 | { 461 | "output_type": "display_data", 462 | "data": { 463 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAAYUlEQVR4nO2SMQ7AMAgDTdX/f9kZKnUCh9BE6sCtiAMLgKZpmhVIknRLV9koqhXpazSzPdKpsSLN8CepyA7g1s36yuHIgkuv6Ug/6h7C+Ml+l/PXzzz2mnSXEWfj1/4xYgBOgCEY9BznlAAAAABJRU5ErkJggg==\n", 464 | "text/plain": [ 465 | "" 466 | ] 467 | }, 468 | "metadata": { 469 | "tags": [] 470 | } 471 | } 472 | ] 473 | } 474 | ] 475 | } -------------------------------------------------------------------------------- /neural_networks/GAN/draggan.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4", 8 | "authorship_tag": "ABX9TyNxNhZKkvQ0+dfoe14tVQd8", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "source": [ 34 | "#DragGAN" 35 | ], 36 | "metadata": { 37 | "id": "FiZr6NwNrF2H" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "id": "9I5eZYJKq0no" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "!git clone https://github.com/Zeqiang-Lai/DragGAN.git\n", 49 | "\n", 50 | "import sys\n", 51 | 
"sys.path.append(\".\")\n", 52 | "sys.path.append('./DragGAN')\n", 53 | "\n", 54 | "!pip install -r DragGAN/requirements.txt\n", 55 | "\n", 56 | "from gradio_app import main" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "source": [ 62 | "demo = main()\n", 63 | "demo.queue(concurrency_count=1, max_size=20).launch()" 64 | ], 65 | "metadata": { 66 | "id": "VPgSsKPsrKKl" 67 | }, 68 | "execution_count": null, 69 | "outputs": [] 70 | } 71 | ] 72 | } -------------------------------------------------------------------------------- /neural_networks/GAN/ersgan_for_vdm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyOlGtAwGTMODsdOeXhygfwK", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "#ESRGAN for video diffusion" 33 | ], 34 | "metadata": { 35 | "id": "754mykfUJ1U7" 36 | } 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 14, 41 | "metadata": { 42 | "id": "4B0zUQQCJza4" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import numpy as np\n", 47 | "import tensorflow_hub as hub\n", 48 | "import tensorflow as tf\n", 49 | "import cv2\n", 50 | "from google.colab.patches import cv2_imshow" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "source": [ 56 | "model = hub.load('https://tfhub.dev/captain-pool/esrgan-tf2/1')" 57 | ], 58 | "metadata": { 59 | "id": "yMVVA5dcLPVt" 60 | }, 61 | "execution_count": 2, 62 | "outputs": [] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "source": [ 67 | "video_file = 
'example.mp4'\n", 68 | "cap = cv2.VideoCapture(video_file)\n", 69 | "n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", 70 | "frames = []\n", 71 | "\n", 72 | "for i in range(n_frames):\n", 73 | " _, frame = cap.read()\n", 74 | " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", 75 | "\n", 76 | " frame = tf.expand_dims(frame, 0)\n", 77 | " low_resolution = tf.cast(frame, tf.float32)\n", 78 | " super_resolution = model(low_resolution)\n", 79 | "\n", 80 | " frames.append(super_resolution)\n", 81 | "\n", 82 | " print(f'{i + 1}/{n_frames} frames processed.')\n", 83 | "\n", 84 | "cap.release()\n", 85 | "frames = np.array(frames)\n", 86 | "frames.shape" 87 | ], 88 | "metadata": { 89 | "colab": { 90 | "base_uri": "https://localhost:8080/" 91 | }, 92 | "id": "5WRnIyHXKOAv", 93 | "outputId": "ff12e830-1c1f-4158-e065-1f1f8a9eabbd" 94 | }, 95 | "execution_count": 6, 96 | "outputs": [ 97 | { 98 | "output_type": "execute_result", 99 | "data": { 100 | "text/plain": [ 101 | "(16, 1, 1024, 1024, 3)" 102 | ] 103 | }, 104 | "metadata": {}, 105 | "execution_count": 6 106 | } 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "source": [ 112 | "output_frames = frames.reshape(16, 1024, 1024, 3)\n", 113 | "cv2_imshow(output_frames[0])" 114 | ], 115 | "metadata": { 116 | "id": "nR8Z7xMMPTMC" 117 | }, 118 | "execution_count": null, 119 | "outputs": [] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "source": [ 124 | "output_filename = 'result.mp4'\n", 125 | "fps = 16\n", 126 | "resolution = (1024, 1024)\n", 127 | "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n", 128 | "writer = cv2.VideoWriter(output_filename, fourcc, fps, resolution)\n", 129 | "\n", 130 | "for frame in output_frames:\n", 131 | " frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n", 132 | " frame = (frame * 255).astype(np.uint8)\n", 133 | " writer.write(frame)\n", 134 | "\n", 135 | "writer.release()" 136 | ], 137 | "metadata": { 138 | "id": "yY5aHcMyNa4T" 139 | }, 140 | "execution_count": 16, 141 | "outputs": [] 142 
| } 143 | ] 144 | } -------------------------------------------------------------------------------- /neural_networks/MLP/experimental/zip_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "zip_learning.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyMopxhbnZCQHibT6Wyxr/mi", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "c1ACAYy4Iuev", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "# Zip Learning - Uczenie skompresowane\n", 36 | "Eksperymentalna metoda uczenia na skompresowanym zbiorze Mnist." 
37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "6ljDzo5Lupqs", 43 | "colab_type": "text" 44 | }, 45 | "source": [ 46 | "## Import bibliotek" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "metadata": { 52 | "id": "P5lGUmImIt48", 53 | "colab_type": "code", 54 | "colab": {} 55 | }, 56 | "source": [ 57 | "%tensorflow_version 2.x\n", 58 | "from datetime import datetime\n", 59 | "import numpy as np\n", 60 | "\n", 61 | "import tensorflow as tf\n", 62 | "from tensorflow.keras.datasets.mnist import load_data\n", 63 | "from tensorflow.keras.models import Sequential\n", 64 | "from tensorflow.keras.layers import InputLayer, Dense, Dropout" 65 | ], 66 | "execution_count": 0, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "0RfpmBAhvBMm", 73 | "colab_type": "text" 74 | }, 75 | "source": [ 76 | "## Funkcje kompresujące\n", 77 | "Zastosowany algorytm kompresji jest podobny do algorytmu kompresji bezstratnej RLE. Funkcje kompresji dla tablic jednowymiarowych (wektorów) i dwuwymiarowych (macierzy) to odpowiednio `zip1D` i `zip2D`. Obie operują na obiektach `numpy.ndarray`. Funkcje zwracają po 2 listy, z których pierwsza, `unique_vals`, zawiera serię wartości występujących w argumencie (jeśli wiele takich samych wartości stoi po sobie, to są zapisywane jako jedna), a druga, `vals_ctr`, liczbę wystąpień każdej z nich. 
\n", 78 | "\n", 79 | "\n" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "metadata": { 85 | "id": "VkIMCZ07K935", 86 | "colab_type": "code", 87 | "colab": {} 88 | }, 89 | "source": [ 90 | "# Kompresja 1-wymiarowa\n", 91 | "def zip1D(array):\n", 92 | " unique_vals = []\n", 93 | " vals_ctr = []\n", 94 | " current_val = None\n", 95 | " idx = -1\n", 96 | "\n", 97 | " for i in range(len(array)):\n", 98 | " if array[i] != current_val: # \"is not\" doesn't work with numpy arrays!\n", 99 | " current_val = array[i]\n", 100 | " unique_vals.append(current_val)\n", 101 | " vals_ctr.append(0)\n", 102 | " idx += 1\n", 103 | " vals_ctr[idx] += 1\n", 104 | "\n", 105 | " return unique_vals, vals_ctr\n", 106 | "\n", 107 | "# Kompresja 2-wymiarowa\n", 108 | "def zip2D(mat):\n", 109 | " if type(mat) is np.ndarray:\n", 110 | " array = mat.flatten(order='C')\n", 111 | " return zip1D(array)" 112 | ], 113 | "execution_count": 0, 114 | "outputs": [] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": { 119 | "id": "ea_iVUta02EB", 120 | "colab_type": "text" 121 | }, 122 | "source": [ 123 | "Przykład działania funkcji `zip1D`:" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "metadata": { 129 | "id": "tuW_cOd80TFD", 130 | "colab_type": "code", 131 | "colab": { 132 | "base_uri": "https://localhost:8080/", 133 | "height": 35 134 | }, 135 | "outputId": "c580479b-3c55-4ecd-855d-286dbc2367bb" 136 | }, 137 | "source": [ 138 | "array = np.array(['A','A','B','A','A','A','A'])\n", 139 | "\n", 140 | "unique_vals, vals_ctr = zip1D(array)\n", 141 | "unique_vals, vals_ctr" 142 | ], 143 | "execution_count": 114, 144 | "outputs": [ 145 | { 146 | "output_type": "execute_result", 147 | "data": { 148 | "text/plain": [ 149 | "(['A', 'B', 'A'], [2, 1, 4])" 150 | ] 151 | }, 152 | "metadata": { 153 | "tags": [] 154 | }, 155 | "execution_count": 114 156 | } 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": { 162 | "id": "C37RSqzi1COH", 163 | "colab_type": 
"text" 164 | }, 165 | "source": [ 166 | "Przykład działania funkcji `zip2D`:" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "metadata": { 172 | "id": "FgU4HUN60gYv", 173 | "colab_type": "code", 174 | "colab": { 175 | "base_uri": "https://localhost:8080/", 176 | "height": 35 177 | }, 178 | "outputId": "3f58125f-2a5b-42f8-d560-606c4e8e2005" 179 | }, 180 | "source": [ 181 | "mat = np.array([['A','B','A'], \n", 182 | " ['B','B','B'], \n", 183 | " ['A','B','A']])\n", 184 | "\n", 185 | "unique_vals, vals_ctr = zip2D(mat)\n", 186 | "unique_vals, vals_ctr" 187 | ], 188 | "execution_count": 115, 189 | "outputs": [ 190 | { 191 | "output_type": "execute_result", 192 | "data": { 193 | "text/plain": [ 194 | "(['A', 'B', 'A', 'B', 'A', 'B', 'A'], [1, 1, 1, 3, 1, 1, 1])" 195 | ] 196 | }, 197 | "metadata": { 198 | "tags": [] 199 | }, 200 | "execution_count": 115 201 | } 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": { 207 | "id": "N-GGUJ2WvYB6", 208 | "colab_type": "text" 209 | }, 210 | "source": [ 211 | "## Załadowanie danych" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "metadata": { 217 | "id": "bGk5m10QJ8CT", 218 | "colab_type": "code", 219 | "colab": {} 220 | }, 221 | "source": [ 222 | "(X_train, y_train), (X_test, y_test) = load_data()" 223 | ], 224 | "execution_count": 0, 225 | "outputs": [] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "metadata": { 230 | "id": "859C0hRsqXE4", 231 | "colab_type": "code", 232 | "colab": { 233 | "base_uri": "https://localhost:8080/", 234 | "height": 52 235 | }, 236 | "outputId": "79cd53cc-6de0-47bf-ada5-c9f52880e8cb" 237 | }, 238 | "source": [ 239 | "print(X_train.shape)\n", 240 | "print(y_train.shape)" 241 | ], 242 | "execution_count": 93, 243 | "outputs": [ 244 | { 245 | "output_type": "stream", 246 | "text": [ 247 | "(60000, 28, 28)\n", 248 | "(60000,)\n" 249 | ], 250 | "name": "stdout" 251 | } 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": { 257 | "id": 
"PNKvvLvAvczo", 258 | "colab_type": "text" 259 | }, 260 | "source": [ 261 | "## Normalizacja" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "metadata": { 267 | "id": "KsN4OFKKRPqU", 268 | "colab_type": "code", 269 | "colab": {} 270 | }, 271 | "source": [ 272 | "# X_train = X_train / 255.0\n", 273 | "n_samples = len(X_train)\n", 274 | "width = len(X_train[0])\n", 275 | "height = len(X_train[0][0])\n", 276 | "threshold = 128\n", 277 | "\n", 278 | "for n in range(n_samples):\n", 279 | " for x in range(width):\n", 280 | " for y in range(height):\n", 281 | " if X_train[n][x][y] < threshold:\n", 282 | " X_train[n][x][y] = 0\n", 283 | " else:\n", 284 | " X_train[n][x][y] = 1" 285 | ], 286 | "execution_count": 0, 287 | "outputs": [] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": { 292 | "id": "zrIYdbeUvqgg", 293 | "colab_type": "text" 294 | }, 295 | "source": [ 296 | "## Kompresja" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "metadata": { 302 | "id": "eOX2awg3Krpe", 303 | "colab_type": "code", 304 | "colab": { 305 | "base_uri": "https://localhost:8080/", 306 | "height": 35 307 | }, 308 | "outputId": "c3e6571c-fbdb-402d-94aa-786be5427796" 309 | }, 310 | "source": [ 311 | "X_train_unique_vals = []\n", 312 | "X_train_vals_ctr = []\n", 313 | "\n", 314 | "start = datetime.now()\n", 315 | "for i in range(len(X_train)):\n", 316 | " current_val, val_ctr = zip2D(X_train[i])\n", 317 | " X_train_unique_vals.append(current_val)\n", 318 | " X_train_vals_ctr.append(val_ctr)\n", 319 | "\n", 320 | "zipping_time = datetime.now() - start\n", 321 | "print(zipping_time)" 322 | ], 323 | "execution_count": 97, 324 | "outputs": [ 325 | { 326 | "output_type": "stream", 327 | "text": [ 328 | "0:00:13.333470\n" 329 | ], 330 | "name": "stdout" 331 | } 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "metadata": { 337 | "id": "lZ3eWuCZMdOg", 338 | "colab_type": "code", 339 | "colab": { 340 | "base_uri": "https://localhost:8080/", 341 | "height": 
87 342 | }, 343 | "outputId": "d85136e6-11e1-4582-c1c8-9e15d8252ec0" 344 | }, 345 | "source": [ 346 | "print(len(X_train_vals_ctr))\n", 347 | "print(len(X_train_unique_vals))\n", 348 | "\n", 349 | "print(len(X_train_unique_vals[0]))\n", 350 | "print(len(X_train_vals_ctr[0]))" 351 | ], 352 | "execution_count": 98, 353 | "outputs": [ 354 | { 355 | "output_type": "stream", 356 | "text": [ 357 | "60000\n", 358 | "60000\n", 359 | "47\n", 360 | "47\n" 361 | ], 362 | "name": "stdout" 363 | } 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": { 369 | "id": "71lTm-MW1XCy", 370 | "colab_type": "text" 371 | }, 372 | "source": [ 373 | "## Znalezienie wektora o największej długości" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "metadata": { 379 | "id": "M3qinvVjVABc", 380 | "colab_type": "code", 381 | "colab": { 382 | "base_uri": "https://localhost:8080/", 383 | "height": 35 384 | }, 385 | "outputId": "769a8b07-dc9f-41df-8a69-06fe67608608" 386 | }, 387 | "source": [ 388 | "best_length = 0\n", 389 | "for i in range(len(X_train_vals_ctr)):\n", 390 | " if len(X_train_vals_ctr[i]) > best_length:\n", 391 | " best_length = len(X_train_vals_ctr[i])\n", 392 | "\n", 393 | "print(best_length)" 394 | ], 395 | "execution_count": 99, 396 | "outputs": [ 397 | { 398 | "output_type": "stream", 399 | "text": [ 400 | "101\n" 401 | ], 402 | "name": "stdout" 403 | } 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "metadata": { 409 | "id": "NHn-aUW9v7QB", 410 | "colab_type": "text" 411 | }, 412 | "source": [ 413 | "## Padding" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "metadata": { 419 | "id": "9LPL27F6Vspe", 420 | "colab_type": "code", 421 | "colab": { 422 | "base_uri": "https://localhost:8080/", 423 | "height": 35 424 | }, 425 | "outputId": "7172c958-a973-4c8f-b27b-b53a979d50eb" 426 | }, 427 | "source": [ 428 | "padded_X_train_vals_ctr = np.zeros([len(X_train_vals_ctr), best_length])\n", 429 | "\n", 430 | "for i in 
range(len(X_train_vals_ctr)):\n", 431 | " padded = np.pad(X_train_vals_ctr[i], \n", 432 | " pad_width=(best_length - len(X_train_vals_ctr[i]), 0), \n", 433 | " mode='constant')\n", 434 | " padded_X_train_vals_ctr[i] = padded\n", 435 | "\n", 436 | "print(padded_X_train_vals_ctr[0].shape)" 437 | ], 438 | "execution_count": 100, 439 | "outputs": [ 440 | { 441 | "output_type": "stream", 442 | "text": [ 443 | "(101,)\n" 444 | ], 445 | "name": "stdout" 446 | } 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "metadata": { 452 | "id": "WtvStaNPegMa", 453 | "colab_type": "code", 454 | "colab": { 455 | "base_uri": "https://localhost:8080/", 456 | "height": 52 457 | }, 458 | "outputId": "05f76082-fb36-456f-f1b7-a5a8cdee709d" 459 | }, 460 | "source": [ 461 | "print(padded_X_train_vals_ctr.shape)\n", 462 | "print(y_train.shape)" 463 | ], 464 | "execution_count": 101, 465 | "outputs": [ 466 | { 467 | "output_type": "stream", 468 | "text": [ 469 | "(60000, 101)\n", 470 | "(60000,)\n" 471 | ], 472 | "name": "stdout" 473 | } 474 | ] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | "metadata": { 479 | "id": "yEyr9xC-v0SK", 480 | "colab_type": "text" 481 | }, 482 | "source": [ 483 | "## Budowa sieci MLP" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "metadata": { 489 | "id": "jiUJ8hknSwqD", 490 | "colab_type": "code", 491 | "colab": { 492 | "base_uri": "https://localhost:8080/", 493 | "height": 260 494 | }, 495 | "outputId": "6e8bc998-713c-478d-ec05-26878556088d" 496 | }, 497 | "source": [ 498 | "model = Sequential()\n", 499 | "model.add(InputLayer(input_shape=(padded_X_train_vals_ctr.shape)))\n", 500 | "model.add(Dense(units=128, activation='relu'))\n", 501 | "model.add(Dropout(0.2))\n", 502 | "model.add(Dense(units=10, activation='softmax'))\n", 503 | "\n", 504 | "model.compile(optimizer='adam',\n", 505 | " loss='sparse_categorical_crossentropy',\n", 506 | " metrics=['accuracy'])\n", 507 | "\n", 508 | "model.summary()" 509 | ], 510 | "execution_count": 
102, 511 | "outputs": [ 512 | { 513 | "output_type": "stream", 514 | "text": [ 515 | "Model: \"sequential_8\"\n", 516 | "_________________________________________________________________\n", 517 | "Layer (type) Output Shape Param # \n", 518 | "=================================================================\n", 519 | "dense_12 (Dense) (None, 60000, 128) 13056 \n", 520 | "_________________________________________________________________\n", 521 | "dropout_6 (Dropout) (None, 60000, 128) 0 \n", 522 | "_________________________________________________________________\n", 523 | "dense_13 (Dense) (None, 60000, 10) 1290 \n", 524 | "=================================================================\n", 525 | "Total params: 14,346\n", 526 | "Trainable params: 14,346\n", 527 | "Non-trainable params: 0\n", 528 | "_________________________________________________________________\n" 529 | ], 530 | "name": "stdout" 531 | } 532 | ] 533 | }, 534 | { 535 | "cell_type": "markdown", 536 | "metadata": { 537 | "id": "d5Lh2G7DwCPx", 538 | "colab_type": "text" 539 | }, 540 | "source": [ 541 | "## Trening modelu" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "metadata": { 547 | "id": "I-1P3W2ociwS", 548 | "colab_type": "code", 549 | "colab": { 550 | "base_uri": "https://localhost:8080/", 551 | "height": 191 552 | }, 553 | "outputId": "a978bab7-2e0d-4074-d654-45e45f20c9ee" 554 | }, 555 | "source": [ 556 | "history = model.fit(padded_X_train_vals_ctr, y_train, epochs=5)" 557 | ], 558 | "execution_count": 106, 559 | "outputs": [ 560 | { 561 | "output_type": "stream", 562 | "text": [ 563 | "Epoch 1/5\n", 564 | "1875/1875 [==============================] - 3s 1ms/step - loss: 0.3783 - accuracy: 0.8725\n", 565 | "Epoch 2/5\n", 566 | "1875/1875 [==============================] - 3s 1ms/step - loss: 0.3756 - accuracy: 0.8718\n", 567 | "Epoch 3/5\n", 568 | "1875/1875 [==============================] - 3s 1ms/step - loss: 0.3752 - accuracy: 0.8728\n", 569 | "Epoch 4/5\n", 570 | 
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.3760 - accuracy: 0.8723\n", 571 | "Epoch 5/5\n", 572 | "1875/1875 [==============================] - 3s 1ms/step - loss: 0.3738 - accuracy: 0.8734\n" 573 | ], 574 | "name": "stdout" 575 | } 576 | ] 577 | } 578 | ] 579 | } -------------------------------------------------------------------------------- /neural_networks/NSL/adversarial_regularization/adversarial_regularization_mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "adversarial_regularization_mnist.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyNcJ8TeO0Jlrt9MiDyM2qgm", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "JhX2byHcpx-w", 33 | "colab_type": "text" 34 | }, 35 | "source": [ 36 | "# Adversarial regularization for image classification" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "metadata": { 42 | "id": "PItAtRzGpfRV", 43 | "colab_type": "code", 44 | "colab": { 45 | "base_uri": "https://localhost:8080/", 46 | "height": 35 47 | }, 48 | "outputId": "23e280c2-25e0-4f63-da62-feeb2be97e83" 49 | }, 50 | "source": [ 51 | "!pip install -q neural-structured-learning" 52 | ], 53 | "execution_count": 1, 54 | "outputs": [ 55 | { 56 | "output_type": "stream", 57 | "text": [ 58 | "\u001b[?25l\r\u001b[K |███▏ | 10kB 23.5MB/s eta 0:00:01\r\u001b[K |██████▎ | 20kB 30.4MB/s eta 0:00:01\r\u001b[K |█████████▍ | 30kB 22.0MB/s eta 0:00:01\r\u001b[K |████████████▌ | 40kB 22.3MB/s eta 0:00:01\r\u001b[K 
|███████████████▋ | 51kB 20.5MB/s eta 0:00:01\r\u001b[K |██████████████████▉ | 61kB 23.2MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 71kB 18.5MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 81kB 18.0MB/s eta 0:00:01\r\u001b[K |████████████████████████████▏ | 92kB 18.0MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▎| 102kB 18.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 112kB 18.3MB/s \n", 59 | "\u001b[?25h" 60 | ], 61 | "name": "stdout" 62 | } 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "metadata": { 68 | "id": "69h_NWQMqJEd", 69 | "colab_type": "code", 70 | "colab": { 71 | "base_uri": "https://localhost:8080/", 72 | "height": 64 73 | }, 74 | "outputId": "2afed11e-8a08-4ccc-ca43-5e7dd37ede86" 75 | }, 76 | "source": [ 77 | "from __future__ import absolute_import, division, print_function, unicode_literals\n", 78 | "\n", 79 | "import matplotlib.pyplot as plt\n", 80 | "import neural_structured_learning as nsl\n", 81 | "import numpy as np\n", 82 | "import tensorflow as tf\n", 83 | "import tensorflow_datasets as tfds" 84 | ], 85 | "execution_count": 2, 86 | "outputs": [ 87 | { 88 | "output_type": "display_data", 89 | "data": { 90 | "text/html": [ 91 | "

\n", 92 | "The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.
\n", 93 | "We recommend you upgrade now \n", 94 | "or ensure your notebook will continue to use TensorFlow 1.x via the %tensorflow_version 1.x magic:\n", 95 | "more info.

\n" 96 | ], 97 | "text/plain": [ 98 | "" 99 | ] 100 | }, 101 | "metadata": { 102 | "tags": [] 103 | } 104 | } 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": { 110 | "id": "0K0XGMSPqs22", 111 | "colab_type": "text" 112 | }, 113 | "source": [ 114 | "### Hiperparametry" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "metadata": { 120 | "id": "hZIyxe8Hqmbp", 121 | "colab_type": "code", 122 | "colab": {} 123 | }, 124 | "source": [ 125 | "class HParams(object):\n", 126 | " def __init__(self):\n", 127 | " self.input_shape = [28, 28, 1]\n", 128 | " self.num_classes = 10\n", 129 | " self.conv_filters = [32, 64, 64]\n", 130 | " self.kernel_size = (3, 3)\n", 131 | " self.pool_size = (2, 2)\n", 132 | " self.num_fc_units = [64]\n", 133 | " self.batch_size = 32\n", 134 | " self.epochs = 5\n", 135 | " self.adv_multiplier = 0.2\n", 136 | " self.adv_step_size = 0.2\n", 137 | " self.adv_grad_norm = 'infinity'\n", 138 | "\n", 139 | "HPARAMS = HParams()" 140 | ], 141 | "execution_count": 0, 142 | "outputs": [] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "id": "PHE9rXs4rsLh", 148 | "colab_type": "text" 149 | }, 150 | "source": [ 151 | "### Zbiór danych MNIST" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "metadata": { 157 | "id": "TmJQNSqgrlaE", 158 | "colab_type": "code", 159 | "colab": {} 160 | }, 161 | "source": [ 162 | "datasets = tfds.load('mnist')\n", 163 | "\n", 164 | "train_dataset = datasets['train']\n", 165 | "test_dataset = datasets['test']\n", 166 | "\n", 167 | "IMAGE_INPUT_NAME = 'image'\n", 168 | "LABEL_INPUT_NAME = 'label'" 169 | ], 170 | "execution_count": 0, 171 | "outputs": [] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "metadata": { 176 | "id": "jnm1ycvysQNX", 177 | "colab_type": "code", 178 | "colab": {} 179 | }, 180 | "source": [ 181 | "def normalize(features):\n", 182 | " features[IMAGE_INPUT_NAME] = tf.cast(\n", 183 | " features[IMAGE_INPUT_NAME], dtype=tf.float32) / 255.0\n", 
184 | " return features\n", 185 | "\n", 186 | "def convert_to_tuples(features):\n", 187 | " return features[IMAGE_INPUT_NAME], features[LABEL_INPUT_NAME]\n", 188 | "\n", 189 | "def convert_to_dictionaries(image, label):\n", 190 | " return {IMAGE_INPUT_NAME: image, LABEL_INPUT_NAME: label}\n", 191 | "\n", 192 | "train_dataset = train_dataset.map(normalize).shuffle(10000).batch(HPARAMS.batch_size).map(convert_to_tuples)\n", 193 | "test_dataset = test_dataset.map(normalize).batch(HPARAMS.batch_size).map(convert_to_tuples)" 194 | ], 195 | "execution_count": 0, 196 | "outputs": [] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": { 201 | "id": "wzBCUYYSw2-V", 202 | "colab_type": "text" 203 | }, 204 | "source": [ 205 | "### Model bazowy" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "metadata": { 211 | "id": "ihr_UqDWtgHf", 212 | "colab_type": "code", 213 | "colab": {} 214 | }, 215 | "source": [ 216 | "def build_base_model(hparams):\n", 217 | " inputs = tf.keras.Input(\n", 218 | " shape=hparams.input_shape, dtype=tf.float32, name=IMAGE_INPUT_NAME)\n", 219 | " \n", 220 | " x = inputs\n", 221 | " for i, num_filters in enumerate(hparams.conv_filters):\n", 222 | " x = tf.keras.layers.Conv2D(\n", 223 | " num_filters, hparams.kernel_size, activation='relu')(x)\n", 224 | " if i < len(hparams.conv_filters) - 1:\n", 225 | " # max pooling między warstwami splotu\n", 226 | " x = tf.keras.layers.MaxPooling2D(hparams.pool_size)(x)\n", 227 | " x = tf.keras.layers.Flatten()(x)\n", 228 | " for num_units in hparams.num_fc_units:\n", 229 | " x = tf.keras.layers.Dense(num_units, activation='relu')(x)\n", 230 | " pred = tf.keras.layers.Dense(hparams.num_classes, activation='softmax')(x)\n", 231 | " model = tf.keras.Model(inputs=inputs, outputs=pred)\n", 232 | " return model" 233 | ], 234 | "execution_count": 0, 235 | "outputs": [] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "metadata": { 240 | "id": "E0-Ix8KUvUnY", 241 | "colab_type": "code", 242 | "colab": 
{ 243 | "base_uri": "https://localhost:8080/", 244 | "height": 592 245 | }, 246 | "outputId": "69884dd5-f06a-485a-823a-e89c6c31cf33" 247 | }, 248 | "source": [ 249 | "base_model = build_base_model(HPARAMS)\n", 250 | "base_model.summary()" 251 | ], 252 | "execution_count": 9, 253 | "outputs": [ 254 | { 255 | "output_type": "stream", 256 | "text": [ 257 | "WARNING:tensorflow:From /tensorflow-1.15.0/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", 258 | "Instructions for updating:\n", 259 | "If using Keras pass *_constraint arguments to layers.\n" 260 | ], 261 | "name": "stdout" 262 | }, 263 | { 264 | "output_type": "stream", 265 | "text": [ 266 | "WARNING:tensorflow:From /tensorflow-1.15.0/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", 267 | "Instructions for updating:\n", 268 | "If using Keras pass *_constraint arguments to layers.\n" 269 | ], 270 | "name": "stderr" 271 | }, 272 | { 273 | "output_type": "stream", 274 | "text": [ 275 | "Model: \"model\"\n", 276 | "_________________________________________________________________\n", 277 | "Layer (type) Output Shape Param # \n", 278 | "=================================================================\n", 279 | "image (InputLayer) [(None, 28, 28, 1)] 0 \n", 280 | "_________________________________________________________________\n", 281 | "conv2d (Conv2D) (None, 26, 26, 32) 320 \n", 282 | "_________________________________________________________________\n", 283 | "max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0 \n", 284 | "_________________________________________________________________\n", 285 | "conv2d_1 (Conv2D) (None, 11, 11, 64) 18496 \n", 286 | 
"_________________________________________________________________\n", 287 | "max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64) 0 \n", 288 | "_________________________________________________________________\n", 289 | "conv2d_2 (Conv2D) (None, 3, 3, 64) 36928 \n", 290 | "_________________________________________________________________\n", 291 | "flatten (Flatten) (None, 576) 0 \n", 292 | "_________________________________________________________________\n", 293 | "dense (Dense) (None, 64) 36928 \n", 294 | "_________________________________________________________________\n", 295 | "dense_1 (Dense) (None, 10) 650 \n", 296 | "=================================================================\n", 297 | "Total params: 93,322\n", 298 | "Trainable params: 93,322\n", 299 | "Non-trainable params: 0\n", 300 | "_________________________________________________________________\n" 301 | ], 302 | "name": "stdout" 303 | } 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "metadata": { 309 | "id": "i8SJt6H7vmSt", 310 | "colab_type": "code", 311 | "colab": { 312 | "base_uri": "https://localhost:8080/", 313 | "height": 225 314 | }, 315 | "outputId": "73be9d68-23ed-45cd-974a-3fe8bccb435b" 316 | }, 317 | "source": [ 318 | "base_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',\n", 319 | " metrics=['acc'])\n", 320 | "base_model.fit(train_dataset, epochs=HPARAMS.epochs)" 321 | ], 322 | "execution_count": 11, 323 | "outputs": [ 324 | { 325 | "output_type": "stream", 326 | "text": [ 327 | "Train on None steps\n", 328 | "Epoch 1/5\n", 329 | "1875/1875 [==============================] - 25s 13ms/step - loss: 0.1426 - acc: 0.9550\n", 330 | "Epoch 2/5\n", 331 | "1875/1875 [==============================] - 16s 9ms/step - loss: 0.0463 - acc: 0.9855\n", 332 | "Epoch 3/5\n", 333 | "1875/1875 [==============================] - 16s 9ms/step - loss: 0.0334 - acc: 0.9896\n", 334 | "Epoch 4/5\n", 335 | "1875/1875 [==============================] - 16s 9ms/step - loss: 
0.0239 - acc: 0.9923\n", 336 | "Epoch 5/5\n", 337 | "1875/1875 [==============================] - 17s 9ms/step - loss: 0.0197 - acc: 0.9940\n" 338 | ], 339 | "name": "stdout" 340 | }, 341 | { 342 | "output_type": "execute_result", 343 | "data": { 344 | "text/plain": [ 345 | "" 346 | ] 347 | }, 348 | "metadata": { 349 | "tags": [] 350 | }, 351 | "execution_count": 11 352 | } 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "metadata": { 358 | "id": "n3x2EpLawHDD", 359 | "colab_type": "code", 360 | "colab": { 361 | "base_uri": "https://localhost:8080/", 362 | "height": 52 363 | }, 364 | "outputId": "8a725a8c-28cd-47ea-819c-3c9a0e32814e" 365 | }, 366 | "source": [ 367 | "results = base_model.evaluate(test_dataset)\n", 368 | "named_results = dict(zip(base_model.metrics_names, results))\n", 369 | "print('\\naccuracy:', named_results['acc'])" 370 | ], 371 | "execution_count": 12, 372 | "outputs": [ 373 | { 374 | "output_type": "stream", 375 | "text": [ 376 | " 313/Unknown - 3s 10ms/step - loss: 0.0325 - acc: 0.9906\n", 377 | "accuracy: 0.9906\n" 378 | ], 379 | "name": "stdout" 380 | } 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": { 386 | "id": "dTY2e7VYw_tp", 387 | "colab_type": "text" 388 | }, 389 | "source": [ 390 | "### Adversarial-regularized model" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "metadata": { 396 | "id": "NfyH08YDw_Ui", 397 | "colab_type": "code", 398 | "colab": {} 399 | }, 400 | "source": [ 401 | "adv_config = nsl.configs.make_adv_reg_config(\n", 402 | " multiplier = HPARAMS.adv_multiplier,\n", 403 | " adv_step_size = HPARAMS.adv_step_size,\n", 404 | " adv_grad_norm = HPARAMS.adv_grad_norm\n", 405 | ")" 406 | ], 407 | "execution_count": 0, 408 | "outputs": [] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "metadata": { 413 | "id": "67W6vp_Rypry", 414 | "colab_type": "code", 415 | "colab": {} 416 | }, 417 | "source": [ 418 | "base_adv_model = build_base_model(HPARAMS)\n", 419 | "adv_model = 
nsl.keras.AdversarialRegularization(\n", 420 | " base_adv_model,\n", 421 | " label_keys = [LABEL_INPUT_NAME],\n", 422 | " adv_config = adv_config\n", 423 | ")\n", 424 | "\n", 425 | "train_set_for_adv_model = train_dataset.map(convert_to_dictionaries)\n", 426 | "test_set_for_adv_model = test_dataset.map(convert_to_dictionaries)" 427 | ], 428 | "execution_count": 0, 429 | "outputs": [] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "metadata": { 434 | "id": "4asz8sS_0_70", 435 | "colab_type": "code", 436 | "colab": { 437 | "base_uri": "https://localhost:8080/", 438 | "height": 453 439 | }, 440 | "outputId": "cdb08b25-8ed1-418f-c86b-360399bb22ff" 441 | }, 442 | "source": [ 443 | "adv_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',\n", 444 | " metrics=['acc'])\n", 445 | "adv_model.fit(train_set_for_adv_model, epochs=HPARAMS.epochs)" 446 | ], 447 | "execution_count": 18, 448 | "outputs": [ 449 | { 450 | "output_type": "stream", 451 | "text": [ 452 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/neural_structured_learning/keras/adversarial_regularization.py:167: The name tf.losses.Reduction is deprecated. Please use tf.compat.v1.losses.Reduction instead.\n", 453 | "\n" 454 | ], 455 | "name": "stdout" 456 | }, 457 | { 458 | "output_type": "stream", 459 | "text": [ 460 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/neural_structured_learning/keras/adversarial_regularization.py:167: The name tf.losses.Reduction is deprecated. 
Please use tf.compat.v1.losses.Reduction instead.\n", 461 | "\n" 462 | ], 463 | "name": "stderr" 464 | }, 465 | { 466 | "output_type": "stream", 467 | "text": [ 468 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/neural_structured_learning/lib/adversarial_neighbor.py:97: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", 469 | "Instructions for updating:\n", 470 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n" 471 | ], 472 | "name": "stdout" 473 | }, 474 | { 475 | "output_type": "stream", 476 | "text": [ 477 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/neural_structured_learning/lib/adversarial_neighbor.py:97: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", 478 | "Instructions for updating:\n", 479 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n" 480 | ], 481 | "name": "stderr" 482 | }, 483 | { 484 | "output_type": "stream", 485 | "text": [ 486 | "WARNING:tensorflow:Output output_1 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to output_1.\n" 487 | ], 488 | "name": "stdout" 489 | }, 490 | { 491 | "output_type": "stream", 492 | "text": [ 493 | "WARNING:tensorflow:Output output_1 missing from loss dictionary. We assume this was done on purpose. 
The fit and evaluate APIs will not be expecting any data to be passed to output_1.\n" 494 | ], 495 | "name": "stderr" 496 | }, 497 | { 498 | "output_type": "stream", 499 | "text": [ 500 | "Train on None steps\n", 501 | "Epoch 1/5\n", 502 | "1875/1875 [==============================] - 24s 13ms/step - loss: 0.2993 - sparse_categorical_crossentropy: 0.1377 - sparse_categorical_accuracy: 0.9599 - adversarial_loss: 0.8081\n", 503 | "Epoch 2/5\n", 504 | "1875/1875 [==============================] - 20s 11ms/step - loss: 0.1227 - sparse_categorical_crossentropy: 0.0431 - sparse_categorical_accuracy: 0.9870 - adversarial_loss: 0.3981\n", 505 | "Epoch 3/5\n", 506 | "1875/1875 [==============================] - 20s 11ms/step - loss: 0.0796 - sparse_categorical_crossentropy: 0.0325 - sparse_categorical_accuracy: 0.9897 - adversarial_loss: 0.2356\n", 507 | "Epoch 4/5\n", 508 | "1875/1875 [==============================] - 20s 11ms/step - loss: 0.0547 - sparse_categorical_crossentropy: 0.0246 - sparse_categorical_accuracy: 0.9920 - adversarial_loss: 0.1504\n", 509 | "Epoch 5/5\n", 510 | "1875/1875 [==============================] - 20s 11ms/step - loss: 0.0463 - sparse_categorical_crossentropy: 0.0184 - sparse_categorical_accuracy: 0.9942 - adversarial_loss: 0.1393\n" 511 | ], 512 | "name": "stdout" 513 | }, 514 | { 515 | "output_type": "execute_result", 516 | "data": { 517 | "text/plain": [ 518 | "" 519 | ] 520 | }, 521 | "metadata": { 522 | "tags": [] 523 | }, 524 | "execution_count": 18 525 | } 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "metadata": { 531 | "id": "LcOB61b-1fiZ", 532 | "colab_type": "code", 533 | "colab": { 534 | "base_uri": "https://localhost:8080/", 535 | "height": 72 536 | }, 537 | "outputId": "2f6b85a9-8d56-4a9b-c35b-aef5aef78ee8" 538 | }, 539 | "source": [ 540 | "results = adv_model.evaluate(test_set_for_adv_model)\n", 541 | "named_results = dict(zip(adv_model.metrics_names, results))\n", 542 | "print('\\naccuracy:', 
named_results['sparse_categorical_accuracy'])" 543 | ], 544 | "execution_count": 20, 545 | "outputs": [ 546 | { 547 | "output_type": "stream", 548 | "text": [ 549 | " 313/Unknown - 3s 11ms/step - loss: 0.0593 - sparse_categorical_crossentropy: 0.0343 - sparse_categorical_accuracy: 0.9895 - adversarial_loss: 0.1249\n", 550 | "accuracy: 0.9895\n" 551 | ], 552 | "name": "stdout" 553 | } 554 | ] 555 | } 556 | ] 557 | } 558 | -------------------------------------------------------------------------------- /neural_networks/RBN/rbn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "rbn.ipynb", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyNMkvjA2+wmCQMUd75RE5uR", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": "text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "source": [ 33 | "#Radial Basis Network" 34 | ], 35 | "metadata": { 36 | "id": "q5daW45oZGBn" 37 | } 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": { 43 | "id": "55FRqfZ3O7Wy" 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "import numpy as np\n", 48 | "import pandas as pd\n", 49 | "from keras.layers import Layer, Dense, Flatten\n", 50 | "from keras import backend as K\n", 51 | "from keras.models import Sequential\n", 52 | "from keras.losses import binary_crossentropy" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "source": [ 58 | "X = np.load('./k49-train-imgs.npz')['arr_0']\n", 59 | "y = np.load('./k49-train-labels.npz')['arr_0']\n", 60 | "y = (y <= 25).astype(int)\n", 61 | "\n", 62 | "print(X.shape)\n", 63 | "print(y.shape)" 64 
| ], 65 | "metadata": { 66 | "colab": { 67 | "base_uri": "https://localhost:8080/" 68 | }, 69 | "id": "QT-n9_P4cmWn", 70 | "outputId": "f2c966ea-1137-4579-e0ab-2a19cf910422" 71 | }, 72 | "execution_count": 11, 73 | "outputs": [ 74 | { 75 | "output_type": "stream", 76 | "name": "stdout", 77 | "text": [ 78 | "(232365, 28, 28)\n", 79 | "(232365,)\n" 80 | ] 81 | } 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "source": [ 87 | "class RBFLayer(Layer):\n", 88 | " def __init__(self, units, gamma, **kwargs):\n", 89 | " super(RBFLayer, self).__init__(**kwargs)\n", 90 | " self.units = units\n", 91 | " self.gamma = K.cast_to_floatx(gamma)\n", 92 | "\n", 93 | " def build(self, input_shape):\n", 94 | " self.mu = self.add_weight(name='mu',\n", 95 | " shape=(int(input_shape[1]), self.units),\n", 96 | " initializer='uniform',\n", 97 | " trainable=True)\n", 98 | " super(RBFLayer, self).build(input_shape)\n", 99 | " \n", 100 | " def call(self, inputs):\n", 101 | " diff = K.expand_dims(inputs) - self.mu\n", 102 | " l2 = K.sum(K.pow(diff, 2), axis=1)\n", 103 | " res = K.exp(-1 * self.gamma * l2)\n", 104 | " return res" 105 | ], 106 | "metadata": { 107 | "id": "TfuUdz2bbxIc" 108 | }, 109 | "execution_count": null, 110 | "outputs": [] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "source": [ 115 | "model = Sequential()\n", 116 | "model.add(Flatten(input_shape=(28, 28)))\n", 117 | "model.add(RBFLayer(units=10, gamma=0.5))\n", 118 | "model.add(Dense(1, activation='sigmoid'))\n", 119 | "\n", 120 | "model.compile(optimizer='rmsprop', loss=binary_crossentropy)" 121 | ], 122 | "metadata": { 123 | "id": "ye0X9YpfcYJZ" 124 | }, 125 | "execution_count": 3, 126 | "outputs": [] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "source": [ 131 | "model.fit(X, y, batch_size=256, epochs=3)" 132 | ], 133 | "metadata": { 134 | "colab": { 135 | "base_uri": "https://localhost:8080/" 136 | }, 137 | "id": "_p7Q3CgJeK_A", 138 | "outputId": "d38d9d79-49f7-481d-d587-67d3ec846ed2" 139 | }, 140 | 
"execution_count": 12, 141 | "outputs": [ 142 | { 143 | "output_type": "stream", 144 | "name": "stdout", 145 | "text": [ 146 | "Epoch 1/3\n", 147 | "908/908 [==============================] - 20s 21ms/step - loss: 0.6823\n", 148 | "Epoch 2/3\n", 149 | "908/908 [==============================] - 19s 21ms/step - loss: 0.6806\n", 150 | "Epoch 3/3\n", 151 | "908/908 [==============================] - 19s 21ms/step - loss: 0.6806\n" 152 | ] 153 | }, 154 | { 155 | "output_type": "execute_result", 156 | "data": { 157 | "text/plain": [ 158 | "" 159 | ] 160 | }, 161 | "metadata": {}, 162 | "execution_count": 12 163 | } 164 | ] 165 | } 166 | ] 167 | } -------------------------------------------------------------------------------- /neural_networks/RNN/seq2seq_sorting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "seq2seq_sorting.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyPQj847hUoM1ycujsbogaLF", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "id": "VDYwQHxP923y" 36 | }, 37 | "source": [ 38 | "#Seq2Seq: Sortowanie\n", 39 | "##Import bibliotek" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "metadata": { 45 | "id": "0pb-Wrij92WN" 46 | }, 47 | "source": [ 48 | "from tensorflow import keras\n", 49 | "from tensorflow.keras import layers\n", 50 | "import numpy as np\n", 51 | "\n", 52 | "TRAINING_SIZE = 50000\n", 53 | "NUMBERS_TO_SORT = 10" 54 | ], 55 | "execution_count": 3, 56 | "outputs": [] 57 
| }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": { 61 | "id": "8SY7nCYXHF9w" 62 | }, 63 | "source": [ 64 | "##Generowanie danych" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "metadata": { 70 | "colab": { 71 | "base_uri": "https://localhost:8080/" 72 | }, 73 | "id": "vktABSDC-Jgg", 74 | "outputId": "8da5a9c9-66f9-4871-b827-43f5badd9b37" 75 | }, 76 | "source": [ 77 | "class CharacterTable:\n", 78 | " \n", 79 | " def __init__(self, chars):\n", 80 | " self.chars = sorted(set(chars))\n", 81 | " self.char_indices = dict((c, i) for i, c in enumerate(self.chars))\n", 82 | " self.indices_char = dict((i, c) for i, c in enumerate(self.chars))\n", 83 | "\n", 84 | " def encode(self, C, num_rows):\n", 85 | " x = np.zeros((num_rows, len(self.chars)))\n", 86 | " for i, c in enumerate(C):\n", 87 | " x[i, self.char_indices[c]] = 1\n", 88 | " return x\n", 89 | "\n", 90 | " def decode(self, x, calc_argmax=True):\n", 91 | " if calc_argmax:\n", 92 | " x = x.argmax(axis=-1)\n", 93 | " return ''.join(self.indices_char[x] for x in x)\n", 94 | "\n", 95 | "# Wszystkie liczby, znaki i spacja\n", 96 | "chars = '0123456789'\n", 97 | "ctable = CharacterTable(chars)\n", 98 | "\n", 99 | "questions = []\n", 100 | "expected = []\n", 101 | "\n", 102 | "while len(questions) < TRAINING_SIZE:\n", 103 | " randomize_string = lambda: str(\n", 104 | " ''.join(\n", 105 | " np.random.choice(list('0123456789'))\n", 106 | " for i in range(NUMBERS_TO_SORT)\n", 107 | " )\n", 108 | " )\n", 109 | " query = randomize_string()\n", 110 | "\n", 111 | " # String to list\n", 112 | " ans = [int(q) for q in query]\n", 113 | " ans = sorted(ans)\n", 114 | "\n", 115 | " # Sorted list to string\n", 116 | " answer = ''\n", 117 | " for num in ans:\n", 118 | " answer += str(num)\n", 119 | " \n", 120 | " questions.append(query)\n", 121 | " expected.append(answer)\n", 122 | "\n", 123 | "print('Liczba przykładów:', len(questions))\n", 124 | "print('Questions: ', questions[:5])\n", 125 | "print('Answers: ', 
expected[:5])" 126 | ], 127 | "execution_count": 26, 128 | "outputs": [ 129 | { 130 | "output_type": "stream", 131 | "name": "stdout", 132 | "text": [ 133 | "Liczba przykładów: 50000\n", 134 | "Questions: ['8847644026', '3025793945', '1664364979', '2849673969', '3700480495']\n", 135 | "Answers: ['0244466788', '0233455799', '1344666799', '2346678999', '0003445789']\n" 136 | ] 137 | } 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": { 143 | "id": "B_665xMOHtSA" 144 | }, 145 | "source": [ 146 | "##Wektoryzacja danych" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "metadata": { 152 | "colab": { 153 | "base_uri": "https://localhost:8080/" 154 | }, 155 | "id": "1R4OKuJVG68v", 156 | "outputId": "549c6fab-6dfc-4b2a-a841-cad771634470" 157 | }, 158 | "source": [ 159 | "x = np.zeros((len(questions), NUMBERS_TO_SORT, len(chars)), dtype=np.bool)\n", 160 | "y = np.zeros((len(questions), NUMBERS_TO_SORT, len(chars)), dtype=np.bool)\n", 161 | "\n", 162 | "for i, sentence in enumerate(questions):\n", 163 | " x[i] = ctable.encode(sentence, NUMBERS_TO_SORT)\n", 164 | "for i, sentence in enumerate(expected):\n", 165 | " y[i] = ctable.encode(sentence, NUMBERS_TO_SORT)\n", 166 | "\n", 167 | "indices = np.arange(len(y))\n", 168 | "np.random.shuffle(indices)\n", 169 | "x = x[indices]\n", 170 | "y = y[indices]\n", 171 | "\n", 172 | "split_at = len(x) - len(x) // 10\n", 173 | "(x_train, x_val) = x[:split_at], x[split_at:]\n", 174 | "(y_train, y_val) = y[:split_at], y[split_at:]\n", 175 | "\n", 176 | "print('Dane treningowe:')\n", 177 | "print(x_train.shape)\n", 178 | "print(y_train.shape)\n", 179 | "\n", 180 | "print('Dane walidacyjne:')\n", 181 | "print(x_val.shape)\n", 182 | "print(y_val.shape)" 183 | ], 184 | "execution_count": 27, 185 | "outputs": [ 186 | { 187 | "output_type": "stream", 188 | "name": "stdout", 189 | "text": [ 190 | "Dane treningowe:\n", 191 | "(45000, 10, 10)\n", 192 | "(45000, 10, 10)\n", 193 | "Dane walidacyjne:\n", 194 | "(5000, 
10, 10)\n", 195 | "(5000, 10, 10)\n" 196 | ] 197 | } 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": { 203 | "id": "s4Z6pNxiH2gx" 204 | }, 205 | "source": [ 206 | "##Budowa modelu" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "metadata": { 212 | "colab": { 213 | "base_uri": "https://localhost:8080/" 214 | }, 215 | "id": "2uRxKHkk_9bL", 216 | "outputId": "a41b8a76-0c2e-4030-e40f-ccb75e67ff6e" 217 | }, 218 | "source": [ 219 | "num_layers = 1\n", 220 | "\n", 221 | "model = keras.Sequential()\n", 222 | "model.add(layers.LSTM(16, input_shape=(NUMBERS_TO_SORT, len(chars))))\n", 223 | "model.add(layers.RepeatVector(NUMBERS_TO_SORT))\n", 224 | "\n", 225 | "for _ in range(num_layers):\n", 226 | " model.add(layers.LSTM(16, return_sequences=True))\n", 227 | "\n", 228 | "model.add(layers.Dense(len(chars), activation='softmax'))\n", 229 | "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n", 230 | "model.summary()" 231 | ], 232 | "execution_count": 38, 233 | "outputs": [ 234 | { 235 | "output_type": "stream", 236 | "name": "stdout", 237 | "text": [ 238 | "Model: \"sequential_7\"\n", 239 | "_________________________________________________________________\n", 240 | "Layer (type) Output Shape Param # \n", 241 | "=================================================================\n", 242 | "lstm_20 (LSTM) (None, 16) 1728 \n", 243 | "_________________________________________________________________\n", 244 | "repeat_vector_7 (RepeatVecto (None, 10, 16) 0 \n", 245 | "_________________________________________________________________\n", 246 | "lstm_21 (LSTM) (None, 10, 16) 2112 \n", 247 | "_________________________________________________________________\n", 248 | "dense_7 (Dense) (None, 10, 10) 170 \n", 249 | "=================================================================\n", 250 | "Total params: 4,010\n", 251 | "Trainable params: 4,010\n", 252 | "Non-trainable params: 0\n", 253 | 
"_________________________________________________________________\n" 254 | ] 255 | } 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": { 261 | "id": "b8Y3x8I3H5SG" 262 | }, 263 | "source": [ 264 | "##Trening modelu" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "metadata": { 270 | "colab": { 271 | "base_uri": "https://localhost:8080/" 272 | }, 273 | "id": "J7j6Pq8xIB_0", 274 | "outputId": "a2dc2968-f20b-4043-afc1-9270126f973d" 275 | }, 276 | "source": [ 277 | "epochs = 5\n", 278 | "batch_size = 32\n", 279 | "\n", 280 | "for epoch in range(1, epochs + 1):\n", 281 | " print()\n", 282 | " print('Iteracja', epoch)\n", 283 | " model.fit(\n", 284 | " x_train,\n", 285 | " y_train,\n", 286 | " batch_size=batch_size,\n", 287 | " epochs=1,\n", 288 | " validation_data=(x_val, y_val)\n", 289 | " )\n", 290 | " # Wybór 10 losowych próbek ze zbioru walidacyjnego, \n", 291 | " # abyśmy mogli zobaczyć błędy\n", 292 | " for i in range(10):\n", 293 | " ind = np.random.randint(0, len(x_val))\n", 294 | " rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]\n", 295 | " preds = np.argmax(model.predict(rowx), axis=-1)\n", 296 | " q = ctable.decode(rowx[0])\n", 297 | " correct = ctable.decode(rowy[0])\n", 298 | " guess = ctable.decode(preds[0], calc_argmax=False)\n", 299 | " print('Q', q, end=' ')\n", 300 | " print('T', correct, end=' ')\n", 301 | " if correct == guess:\n", 302 | " print('☑ ' + guess)\n", 303 | " else:\n", 304 | " print('☒ ' + guess)" 305 | ], 306 | "execution_count": 39, 307 | "outputs": [ 308 | { 309 | "output_type": "stream", 310 | "name": "stdout", 311 | "text": [ 312 | "\n", 313 | "Iteracja 1\n", 314 | "1407/1407 [==============================] - 17s 10ms/step - loss: 0.8683 - accuracy: 0.7232 - val_loss: 0.3988 - val_accuracy: 0.8993\n", 315 | "Q 7543839439 T 3334457899 ☒ 2334557899\n", 316 | "Q 2589669405 T 0245566899 ☒ 0145566899\n", 317 | "Q 3739211020 T 0011223379 ☑ 0011223379\n", 318 | "Q 4230435049 T 0023344459 ☒ 
0123344459\n", 319 | "Q 8195226339 T 1223356899 ☒ 0223356899\n", 320 | "Q 5887675366 T 3556667788 ☒ 1556667788\n", 321 | "Q 5740260425 T 0022445567 ☒ 0122445567\n", 322 | "Q 9658427089 T 0245678899 ☑ 0245678899\n", 323 | "Q 2872732756 T 2223567778 ☒ 2233567777\n", 324 | "Q 6575054576 T 0455556677 ☒ 0455566677\n", 325 | "\n", 326 | "Iteracja 2\n", 327 | "1407/1407 [==============================] - 13s 9ms/step - loss: 0.2551 - accuracy: 0.9442 - val_loss: 0.1640 - val_accuracy: 0.9680\n", 328 | "Q 8464997804 T 0444678899 ☑ 0444678899\n", 329 | "Q 2124903773 T 0122334779 ☑ 0122334779\n", 330 | "Q 6137035694 T 0133456679 ☑ 0133456679\n", 331 | "Q 7719639560 T 0135667799 ☑ 0135667799\n", 332 | "Q 1797418294 T 1124477899 ☒ 0124477899\n", 333 | "Q 2732399463 T 2233346799 ☑ 2233346799\n", 334 | "Q 6638838210 T 0123366888 ☑ 0123366888\n", 335 | "Q 3799983347 T 3334778999 ☑ 3334778999\n", 336 | "Q 0519453235 T 0123345559 ☑ 0123345559\n", 337 | "Q 3277231476 T 1223346777 ☒ 0223346777\n", 338 | "\n", 339 | "Iteracja 3\n", 340 | "1407/1407 [==============================] - 13s 9ms/step - loss: 0.0940 - accuracy: 0.9908 - val_loss: 0.0543 - val_accuracy: 0.9963\n", 341 | "Q 1764144907 T 0114446779 ☑ 0114446779\n", 342 | "Q 1481556210 T 0111245568 ☑ 0111245568\n", 343 | "Q 8801835874 T 0134578888 ☑ 0134578888\n", 344 | "Q 8141561751 T 1111455678 ☑ 1111455678\n", 345 | "Q 3746504061 T 0013445667 ☑ 0013445667\n", 346 | "Q 7179021402 T 0011224779 ☑ 0011224779\n", 347 | "Q 6905020396 T 0002356699 ☑ 0002356699\n", 348 | "Q 6346478826 T 2344666788 ☑ 2344666788\n", 349 | "Q 7245078965 T 0245567789 ☑ 0245567789\n", 350 | "Q 2627281325 T 1222235678 ☑ 1222235678\n", 351 | "\n", 352 | "Iteracja 4\n", 353 | "1407/1407 [==============================] - 13s 10ms/step - loss: 0.0370 - accuracy: 0.9977 - val_loss: 0.0254 - val_accuracy: 0.9984\n", 354 | "Q 3998260233 T 0223336899 ☑ 0223336899\n", 355 | "Q 7663668555 T 3555666678 ☑ 3555666678\n", 356 | "Q 8607591008 T 0001567889 ☑ 
0001567889\n", 357 | "Q 2011430331 T 0011123334 ☑ 0011123334\n", 358 | "Q 2654766745 T 2445566677 ☑ 2445566677\n", 359 | "Q 8297424699 T 2244678999 ☑ 2244678999\n", 360 | "Q 6398993191 T 1133689999 ☑ 1133689999\n", 361 | "Q 9180138065 T 0011356889 ☑ 0011356889\n", 362 | "Q 2649767969 T 2466677999 ☑ 2466677999\n", 363 | "Q 0694320242 T 0022234469 ☑ 0022234469\n", 364 | "\n", 365 | "Iteracja 5\n", 366 | "1407/1407 [==============================] - 13s 10ms/step - loss: 0.0188 - accuracy: 0.9988 - val_loss: 0.0137 - val_accuracy: 0.9990\n", 367 | "Q 5852344260 T 0223445568 ☑ 0223445568\n", 368 | "Q 4426186161 T 1112446668 ☑ 1112446668\n", 369 | "Q 5216904308 T 0012345689 ☑ 0012345689\n", 370 | "Q 8159999274 T 1245789999 ☑ 1245789999\n", 371 | "Q 8701108145 T 0011145788 ☑ 0011145788\n", 372 | "Q 4473650311 T 0113344567 ☑ 0113344567\n", 373 | "Q 1234457849 T 1234445789 ☑ 1234445789\n", 374 | "Q 5719021444 T 0112444579 ☑ 0112444579\n", 375 | "Q 4772336309 T 0233346779 ☑ 0233346779\n", 376 | "Q 7848340530 T 0033445788 ☑ 0033445788\n" 377 | ] 378 | } 379 | ] 380 | } 381 | ] 382 | } --------------------------------------------------------------------------------