├── .gitignore
├── notebooks
│   ├── 0-requirements.ipynb
│   ├── 1-DCNN.ipynb
│   ├── 2-tensor_manipulation.ipynb
│   ├── 3-tensor-decomposition.ipynb
│   ├── 4-tensor_regression.ipynb
│   ├── 5-decomposition_with_pytorch_and_backprop.ipynb
│   ├── 6-tensor_regression_layer_pytorch.ipynb
│   └── images
│       ├── FC.png
│       ├── TRL.png
│       ├── example-unfolding-fibers.png
│       └── example_tensor.png
└── slides
    ├── 1-deep-nets.pdf
    ├── 2-tensor-basics.pdf
    ├── 3-tensor-decomposition.pdf
    ├── 4-tensor-regression.pdf
    └── 5-tensor+deep.pdf

/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | *.DS_Store
 6 | data/
 7 |
 8 | # C extensions
 9 | *.so
10 | *.py~
11 |
12 | # Pycharm
13 | .idea
14 |
15 | # vim temp files
16 | *.swp
17 |
18 | # Sphinx doc
19 | doc/_build/
20 | doc/auto_examples/
21 | doc/modules/generated/
22 |
23 | # Distribution / packaging
24 | .Python
25 | env/
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | .pytest_cache/
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | .hypothesis/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # IPython Notebook
77 | .ipynb_checkpoints
78 |
--------------------------------------------------------------------------------
/notebooks/0-requirements.ipynb:
--------------------------------------------------------------------------------
 1 | {
 2 |  "cells": [
 3 |   {
 4 |    "cell_type": "markdown",
 5 |    "metadata": {},
 6 |    "source": [
 7 |     "# Installation instructions\n",
 8 |     "\n",
 9 |     "If you are new to Python or simply want a pain-free experience, I recommend you install the [Anaconda](https://www.anaconda.com/download/) distribution. It comes with everything you need, ready to use!\n",
10 |     "\n",
11 |     "Once you have Anaconda installed, you will want either [JupyterLab](http://jupyterlab.readthedocs.io/en/stable/) or the [Jupyter Notebook](http://jupyter.org/install.html). Typically, you can simply run:\n",
12 |     "```bash\n",
13 |     "conda install jupyterlab\n",
14 |     "```\n",
15 |     "\n",
16 |     "For PyTorch, simply follow the [instructions](https://pytorch.org/). With conda, it should be something like:\n",
17 |     "```bash\n",
18 |     "conda install pytorch torchvision -c pytorch\n",
19 |     "```\n",
20 |     "\n",
21 |     "Finally, to get the latest version of TensorLy, you can clone the GitHub repository. 
In the command line, run:\n",
22 |     "```bash\n",
23 |     "git clone https://github.com/tensorly/tensorly\n",
24 |     "cd tensorly\n",
25 |     "pip install -e .\n",
26 |     "```\n",
27 |     "\n",
28 |     "Or, if you have an issue during installation, you can also use conda:\n",
29 |     "```bash\n",
30 |     "conda install -c tensorly tensorly\n",
31 |     "```\n",
32 |     "\n",
33 |     "Lastly, you can install scikit-learn using pip. In the command line, just type:\n",
34 |     "```bash\n",
35 |     "pip install scikit-learn\n",
36 |     "```\n",
37 |     "\n",
38 |     "# Checking whether you have the correct versions\n",
39 |     "First, open JupyterLab (just type `jupyter lab` in the command line and open the link that appears). You can also run the code directly in Python or IPython, but it's not as nice. Generally, if you are not already familiar with JupyterLab or the Jupyter Notebook, I recommend you spend a little time exploring it, as it's a very useful tool, especially for data science.\n",
40 |     "\n",
41 |     "To check that you have all you need, you can run the following:"
42 |    ]
43 |   },
44 |   {
45 |    "cell_type": "code",
46 |    "execution_count": 1,
47 |    "metadata": {},
48 |    "outputs": [
49 |     {
50 |      "name": "stderr",
51 |      "output_type": "stream",
52 |      "text": [
53 |       "Using numpy backend.\n"
54 |      ]
55 |     }
56 |    ],
57 |    "source": [
58 |     "import torch\n",
59 |     "import numpy as np\n",
60 |     "import tensorly as tl\n",
61 |     "import sklearn"
62 |    ]
63 |   },
64 |   {
65 |    "cell_type": "code",
66 |    "execution_count": 2,
67 |    "metadata": {},
68 |    "outputs": [
69 |     {
70 |      "name": "stdout",
71 |      "output_type": "stream",
72 |      "text": [
73 |       "PyTorch version: 0.4.1\n",
74 |       "Tensorly version: 0.4.3\n"
75 |      ]
76 |     }
77 |    ],
78 |    "source": [
79 |     "print('PyTorch version: {}'.format(torch.__version__))\n",
80 |     "print('Tensorly version: {}'.format(tl.__version__))"
81 |    ]
82 |   },
83 |   {
84 |    "cell_type": "markdown",
85 |    "metadata": {},
86 |    "source": [
87 |     "For PyTorch, you should have at least version 0.4.0."
88 |    ]
89 |   }
90 |  ],
91 | 
"metadata": {
 92 |   "kernelspec": {
 93 |    "display_name": "Python 3",
 94 |    "language": "python",
 95 |    "name": "python3"
 96 |   },
 97 |   "language_info": {
 98 |    "codemirror_mode": {
 99 |     "name": "ipython",
100 |     "version": 3
101 |    },
102 |    "file_extension": ".py",
103 |    "mimetype": "text/x-python",
104 |    "name": "python",
105 |    "nbconvert_exporter": "python",
106 |    "pygments_lexer": "ipython3",
107 |    "version": "3.6.6"
108 |   }
109 |  },
110 |  "nbformat": 4,
111 |  "nbformat_minor": 2
112 | }
113 |
--------------------------------------------------------------------------------
/notebooks/1-DCNN.ipynb:
--------------------------------------------------------------------------------
 1 | {
 2 |  "cells": [
 3 |   {
 4 |    "cell_type": "code",
 5 |    "execution_count": 1,
 6 |    "metadata": {},
 7 |    "outputs": [],
 8 |    "source": [
 9 |     "import torch\n",
10 |     "import torch.nn as nn\n",
11 |     "import torch.optim as optim\n",
12 |     "from torchvision import datasets, transforms\n",
13 |     "import torch.nn.functional as F\n",
14 |     "\n",
15 |     "import matplotlib.pyplot as plt"
16 |    ]
17 |   },
18 |   {
19 |    "cell_type": "code",
20 |    "execution_count": 2,
21 |    "metadata": {},
22 |    "outputs": [],
23 |    "source": [
24 |     "# choose the size of your minibatch\n",
25 |     "batch_size = 32\n",
26 |     "\n",
27 |     "device = 'cpu'\n",
28 |     "# to run on GPU, uncomment the following line:\n",
29 |     "#device = 'cuda:0'"
30 |    ]
31 |   },
32 |   {
33 |    "cell_type": "markdown",
34 |    "metadata": {},
35 |    "source": [
36 |     "We will train a deep convolutional neural network on the MNIST dataset. It consists of 70,000 images (60,000 for training and 10,000 for testing) of handwritten digits. Our task is to predict the digit represented by each image."
37 |    ]
38 |   },
39 |   {
40 |    "cell_type": "markdown",
41 |    "metadata": {},
42 |    "source": [
43 |     "# Load the data\n",
44 |     "\n",
45 |     "Note the normalisation (subtract the mean, divide by the standard deviation)."
46 |    ]
47 |   },
48 |   {
49 |    "cell_type": "code",
50 |    "execution_count": 3,
51 |    "metadata": {},
52 |    "outputs": [],
53 |    "source": [
54 |     "train_loader = torch.utils.data.DataLoader(\n",
55 |     "        datasets.MNIST('./data/', train=True, download=True,\n",
56 |     "                       transform=transforms.Compose([\n",
57 |     "                           transforms.ToTensor(),\n",
58 |     "                           transforms.Normalize((0.1307,), (0.3081,))\n",
59 |     "                       ])),\n",
60 |     "        batch_size=batch_size, shuffle=True)\n",
61 |     "\n",
62 |     "test_loader = torch.utils.data.DataLoader(\n",
63 |     "        datasets.MNIST('./data/', train=False, transform=transforms.Compose([\n",
64 |     "                           transforms.ToTensor(),\n",
65 |     "                           transforms.Normalize((0.1307,), (0.3081,))\n",
66 |     "                       ])),\n",
67 |     "        batch_size=batch_size, shuffle=True)"
68 |    ]
69 |   },
70 |   {
71 |    "cell_type": "markdown",
72 |    "metadata": {},
73 |    "source": [
74 |     "# Let's visualise the data"
75 |    ]
76 |   },
77 |   {
78 |    "cell_type": "code",
79 |    "execution_count": 4,
80 |    "metadata": {},
81 |    "outputs": [],
82 |    "source": [
83 |     "batch, labels = next(iter(train_loader))"
84 |    ]
85 |   },
86 |   {
87 |    "cell_type": "markdown",
88 |    "metadata": {},
89 |    "source": [
90 |     "This gives us a batch of images and the corresponding labels (the class each sample belongs to).\n",
91 |     "batch has shape (batch_size, n_channels, height, width) and labels is simply of shape (batch_size,).\n",
92 |     "Since the samples of MNIST are black-and-white images, n_channels is 1; let's remove it."
93 |    ]
94 |   },
95 |   {
96 |    "cell_type": "code",
97 |    "execution_count": 5,
98 |    "metadata": {},
99 |    "outputs": [],
100 |    "source": [
101 |     "batch = batch.squeeze()  # By default, removes the dimensions of size 1"
102 |    ]
103 |   },
104 |   {
105 |    "cell_type": "markdown",
106 |    "metadata": {},
107 |    "source": [
108 |     "ToTensor() converts the images from uint8 (values from 0 to 255) to float32 (ranging from 0 to 1). We typically would need to convert them back to uint8 images to properly visualise them. 
However, in this case, since they are just grayscale images, we don't have to."
109 |    ]
110 |   },
111 |   {
112 |    "cell_type": "code",
113 |    "execution_count": 6,
114 |    "metadata": {},
115 |    "outputs": [
116 |     {
117 |      "data": {
118 |       "image/png": "[base64-encoded PNG output truncated: inline matplotlib figure]"
5cmU2qxMU+5R4fDMqy37H2n3Z7QI8b731lndNqu70hx56yDu+9957NdeqVUtznz59vHJHHXWUZjvkeeutt3rlDjjggFLfNxu4IwcAIGA05AAABIyGHACAgAU3Rl7Z2DGaOXPmVPj17Gu899573rkDDzywwq9fmey2226FrgJKMXXqVM0//PBDAWtSfOxzMXbstdhcd911mjfbbLOU5eyGJXYs/Prrr9fcsmXLlNeff/75mocPH+6dS7Xx1scff+wdd+7cWbNdDa5p06Yp3zfbuCMHACBgNOQAAASMrvVKplWrVpr333//AtYk9+644w7NSf9vrcziq2XZla9WrFiR7+oEzW7cYadApfPjjz9q3meffTQX87BGpvuOP/zww6X+/MILL0x5zfvvv6956NChZauYiCxatMg7XrVqlWbbtT5t2rQyv3Z5cUcOAEDAaMgBAAhYcF3rthvKrrbz97//3Sv39ttva/72228r9J4LFizwjm1XpF0wv3379l65+FPnv4ivVNStWzfNPXr0KPW1RfwuHCBb4vuRsz955i666CLveMstt9RcpUqVjF7jgQce0FzM3enlUbdu3VJ/bleGsxuoiPz2zywTXbt21Xzuued65xo3blzqNfGV4nKJO3IAAAJGQw4AQMBoyAEACFhwY+QvvfSS5mOPPVbz9ttv75X7/PPPNY8ZM0aznVojIjJq1CjNdiezpUuXan7wwQe9a04//XTNdkcbO9Yl8tuN5X8R3+kr1a5m8ak/Xbp0KbVcMbE7z82aNauANUGxsuPgZ599tncuk3Hx+Dj4fffdl52KFaF77rlH89133625atVfmzb7XJWISIMGDTTbse9evXp55ewuafa7Nz5d03r++ec1P/roo2nrnk3ckQMAEDAacgAAAubyOdXEOZfVN7Mbt/fr16/Cr7dy5UrNH374oWbb/S4isvPOO2u++OKLK/y+1syZMzXHN0344osvsvpehda9e3fNTzzxREbXvPLKK5oPP/zwrNepWNipkXY1MRGRDh06aB4/fnze6hSK3r17ax45cmSZr99333294w8++KDCdUqaE088UfP999+vuVq1all9H9tNnq4t/OmnnzT/97//9c7ZtuKKK67QbNuT8oqiKHU/vsEdOQAAAaMhBwAgYEF3rW+yya//DomvqvaPf/xDs92IpFGjRtmsQrlMnz7dO37sscc0Dxo0SPOyZcvyVqdCyLRr3e75brvTX3/99dxUrAik61rPdEWyYmU3w4jPlknF7m29zTbbeOfmz5+fnYol1N577605vpnK7rvvrtl+rm3atMnotW3XenwYyQ552O+nCRMmZPTa2UDXOgAARYCGHACAgNGQAwAQsOBWdrPsuNO4ceO8c/Ex818cdthhKV/viCOO0GzH1bPhySef1MxKTmVjx2ztKnuHHnqoV+7TTz/NW52SJN1KVciOYcOGaWZMvGwmTpxYasavuCMHACBgNOQAAAQs6OlnCJvd2ODmm2/WfMkll3jlUnX9xqeB7LffflmsXbIx/az83nzzTc2dO3dOWe6rr77SbDdXmjdvXm4qhsRh+hkAAEWAhhwAgIAF/dQ6wmY3Irjssss0b7311l65k046SfPUqVM125XhgHyxGyVNnjzZOzdr1izNBx98sGa605FL3JEDABAwGnIAAAJGQw4AQMCYfgYUoRdffFFz165dvXNMPwMqB6afAQBQBGjIAQAIGF3rQBHafPPNNcdXyNtxxx3zXR0ApaBrHQCAIkBDDgBAwOhaBwCgEqJrHQCAIkBDDgBAwGjIAQAIGA05AAABoyEHACBgNOQAAAQsr9PPAABAdnFHDgBAwGjIAQAIGA05AAABoyEHACBgNOQAAASMhhwAgIDRkAMAEDAacgAAAkZDDgBAwGjIAQAIGA05AAABoyEHACBgNOQAAASMhhwAgIDRkAMAEDAacgAAAkZDDgBAwGjIAQAIGA05AAABoyEHACBgNOQAAASMhhwAgIDRkAM
AELD/BxF4ABnvkAZHAAAAAElFTkSuQmCC\n", 119 | "text/plain": [ 120 | "" 121 | ] 122 | }, 123 | "metadata": { 124 | "needs_background": "light" 125 | }, 126 | "output_type": "display_data" 127 | } 128 | ], 129 | "source": [ 130 | "n_columns = 5\n", 131 | "n_rows = 5\n", 132 | "fig = plt.figure(figsize=(8, 8))\n", 133 | "index = 0\n", 134 | "for row in range(n_rows):\n", 135 | "    for column in range(n_columns):\n", 136 | "        ax = fig.add_subplot(n_rows, n_columns, index+1)\n", 137 | "        ax.imshow(batch[index].detach().cpu().numpy(), cmap=plt.cm.Greys_r)\n", 138 | "        ax.set_axis_off()\n", 139 | "        ax.set_title(labels[index].item())\n", 140 | "        index += 1\n", 141 | "plt.subplots_adjust(top = 1, bottom = 0, right = 0.8, left = 0, \n", 142 | "                    hspace = 0, wspace = 0)\n", 143 | "plt.show()\n" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "# Define the net\n", 151 | "\n", 152 | "Our network is composed of two blocks of convolution, max-pooling and non-linearity, followed by flattening and two fully-connected layers."
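The shape comments in the `Net` class defined in the next cell rely on simple size arithmetic (5×5 valid convolutions and 2×2 max-pooling on 28×28 MNIST inputs). It can be verified with a small standalone sketch — the helper names here are illustrative, not part of the notebook:

```python
# Trace the spatial size of a 28x28 MNIST image through the network:
# two 5x5 valid convolutions, each followed by 2x2 max-pooling.

def conv_out(size, kernel):
    """Output size of a 'valid' (no padding, stride 1) convolution."""
    return size - kernel + 1

def pool_out(size, window):
    """Output size of a non-overlapping max-pooling."""
    return size // window

size, channels = 28, 1
for out_channels in (10, 20):      # conv1 then conv2
    size = conv_out(size, 5)       # lose 2 pixels on each side
    size = pool_out(size, 2)       # halve the resolution
    channels = out_channels

flattened = channels * size * size
print(size, channels, flattened)   # 4 20 320 -> matches nn.Linear(320, 50)
```

This is where the magic number 320 in the first fully-connected layer comes from: 20 channels of 4×4 feature maps.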
153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 7, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "class Net(nn.Module):\n", 162 | "    def __init__(self):\n", 163 | "        super(Net, self).__init__()\n", 164 | "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", 165 | "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", 166 | "        self.conv2_drop = nn.Dropout2d()\n", 167 | "        self.fc1 = nn.Linear(320, 50)\n", 168 | "        self.fc2 = nn.Linear(50, 10) # We output log-probabilities for 10 classes\n", 169 | "\n", 170 | "    def forward(self, x):\n", 171 | "        # the input is (bs, 1, 28, 28)\n", 172 | "        x = self.conv1(x) # Lose 2 pixels on each side\n", 173 | "        \n", 174 | "        # x is now (bs, 10, 24, 24)\n", 175 | "        x = F.max_pool2d(x, 2) # divide resolution by two\n", 176 | "        x = F.relu(x)\n", 177 | "        \n", 178 | "        x = self.conv2(x)\n", 179 | "        # x is (bs, 20, 8, 8)\n", 180 | "        \n", 181 | "        x = F.max_pool2d(x, 2)\n", 182 | "        # x is (bs, 20, 4, 4)\n", 183 | "        x = F.relu(x)\n", 184 | "        \n", 185 | "        x = x.view(-1, 320) \n", 186 | "        # we flattened x (320 = 20*4*4)\n", 187 | "        \n", 188 | "        x = F.relu(self.fc1(x))\n", 189 | "        # x is (bs, 50)\n", 190 | "        x = F.dropout(x, training=self.training)\n", 191 | "\n", 192 | "        x = self.fc2(x)\n", 193 | "        # x is (bs, 10)\n", 194 | "        return F.log_softmax(x, dim=1)\n", 195 | "\n", 196 | "\n", 197 | "model = Net()" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 8, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n", 207 | "criterion = nn.NLLLoss()  # the model already outputs log-probabilities" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 9, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "Train Epoch: 0 [0/60000 (0%)]\tLoss: 2.278881\n", 220 | "Train Epoch: 0 [16000/60000 (27%)]\tLoss: 0.547932\n", 221 | "Train Epoch: 0 [32000/60000
(53%)]\tLoss: 0.056793\n", 222 | "Train Epoch: 0 [48000/60000 (80%)]\tLoss: 0.103723\n", 223 | "mean: 6.5000408540072385e-06\n", 224 | "\n", 225 | "Test set: Average loss: 0.0000, Accuracy: 9819/10000 (98%)\n", 226 | "\n", 227 | "Train Epoch: 1 [0/60000 (0%)]\tLoss: 0.124601\n", 228 | "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.056968\n", 229 | "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.013600\n", 230 | "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.022261\n", 231 | "mean: 9.150053301709704e-07\n", 232 | "\n", 233 | "Test set: Average loss: 0.0000, Accuracy: 9852/10000 (98%)\n", 234 | "\n", 235 | "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.099478\n", 236 | "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.146930\n", 237 | "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.113602\n", 238 | "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.140082\n", 239 | "mean: 1.9702911302488246e-08\n", 240 | "\n", 241 | "Test set: Average loss: 0.0000, Accuracy: 9880/10000 (98%)\n", 242 | "\n", 243 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.022954\n", 244 | "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.029932\n", 245 | "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.052703\n", 246 | "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.007666\n", 247 | "mean: 8.661115771246841e-07\n", 248 | "\n", 249 | "Test set: Average loss: 0.0000, Accuracy: 9812/10000 (98%)\n", 250 | "\n", 251 | "Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.086427\n", 252 | "Train Epoch: 4 [16000/60000 (27%)]\tLoss: 0.047477\n", 253 | "Train Epoch: 4 [32000/60000 (53%)]\tLoss: 0.164253\n", 254 | "Train Epoch: 4 [48000/60000 (80%)]\tLoss: 0.067096\n", 255 | "mean: 1.0501152836184247e-08\n", 256 | "\n", 257 | "Test set: Average loss: 0.0000, Accuracy: 9871/10000 (98%)\n", 258 | "\n" 259 | ] 260 | } 261 | ], 262 | "source": [ 263 | "n_epoch = 5 # Number of epochs\n", 264 | "\n", 265 | "model = model.to(device)\n", 266 | "\n", 267 | "def train(epoch):\n", 268 | " model.train()\n", 269 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 
270 | "        # Send the data and label to the correct device\n", 271 | "        data, target = data.to(device), target.to(device)\n", 272 | "        \n", 273 | "        # Important: do not forget to reset the gradients\n", 274 | "        optimizer.zero_grad()\n", 275 | "        \n", 276 | "        # Pass the data through the network\n", 277 | "        output = model(data)\n", 278 | "        \n", 279 | "        # Compute the loss\n", 280 | "        loss = criterion(output, target)\n", 281 | "        \n", 282 | "        # Backpropagate the gradient\n", 283 | "        loss.backward()\n", 284 | "        \n", 285 | "        # Update the weights\n", 286 | "        optimizer.step()\n", 287 | "        \n", 288 | "        # That's just printing some info...\n", 289 | "        if batch_idx % 500 == 0:\n", 290 | "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", 291 | "                epoch, batch_idx * len(data), len(train_loader.dataset),\n", 292 | "                100. * batch_idx / len(train_loader), loss))\n", 293 | "\n", 294 | "def test():\n", 295 | "    model.eval()\n", 296 | "    test_loss = 0\n", 297 | "    correct = 0\n", 298 | "    for data, target in test_loader:\n", 299 | "        data, target = data.to(device), target.to(device)\n", 300 | "        output = model(data)\n", 301 | "        test_loss += criterion(output, target).item()  # accumulate over batches, do not overwrite\n", 302 | "        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 303 | "        correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 304 | "\n", 305 | "    test_loss /= len(test_loader.dataset)\n", 306 | "    print('mean: {}'.format(test_loss))\n", 307 | "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 308 | "        test_loss, correct, len(test_loader.dataset),\n", 309 | "        100. 
* correct / len(test_loader.dataset)))\n", 310 | "\n", 311 | "\n", 312 | "for epoch in range(n_epoch):\n", 313 | " train(epoch)\n", 314 | " test()" 315 | ] 316 | } 317 | ], 318 | "metadata": { 319 | "kernelspec": { 320 | "display_name": "Python 3", 321 | "language": "python", 322 | "name": "python3" 323 | }, 324 | "language_info": { 325 | "codemirror_mode": { 326 | "name": "ipython", 327 | "version": 3 328 | }, 329 | "file_extension": ".py", 330 | "mimetype": "text/x-python", 331 | "name": "python", 332 | "nbconvert_exporter": "python", 333 | "pygments_lexer": "ipython3", 334 | "version": "3.6.6" 335 | } 336 | }, 337 | "nbformat": 4, 338 | "nbformat_minor": 2 339 | } 340 | -------------------------------------------------------------------------------- /notebooks/2-tensor_manipulation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using numpy backend.\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import numpy as np\n", 18 | "import tensorly as tl" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "# Defining our example tensor" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 3, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "X = tl.tensor(np.arange(24).reshape((3, 4, 2)))" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "array([[[ 0., 1.],\n", 46 | " [ 2., 3.],\n", 47 | " [ 4., 5.],\n", 48 | " [ 6., 7.]],\n", 49 | "\n", 50 | " [[ 8., 9.],\n", 51 | " [ 10., 11.],\n", 52 | " [ 12., 13.],\n", 53 | " [ 14., 15.]],\n", 54 | "\n", 55 | " [[ 16., 17.],\n", 56 | " [ 18., 19.],\n", 57 | " [ 20., 21.],\n", 58 | " [ 22., 23.]]])" 59 | ] 60 | }, 61 
| "execution_count": 4, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | } 65 | ], 66 | "source": [ 67 | "X" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "You can view the frontal slices by fixing the last axis:" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "array([[ 0., 2., 4., 6.],\n", 86 | "       [ 8., 10., 12., 14.],\n", 87 | "       [ 16., 18., 20., 22.]])" 88 | ] 89 | }, 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "X[:, :, 0]" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 6, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "array([[ 1., 3., 5., 7.],\n", 108 | "       [ 9., 11., 13., 15.],\n", 109 | "       [ 17., 19., 21., 23.]])" 110 | ] 111 | }, 112 | "execution_count": 6, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "X[:, :, 1]" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "# 2. Setting the backend\n", 126 | "\n", 127 | "In TensorLy you can dynamically set the backend to use either NumPy or PyTorch to represent tensors and perform the operations:" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 7, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "data": { 137 | "text/plain": [ 138 | "numpy.ndarray" 139 | ] 140 | }, 141 | "execution_count": 7, 142 | "metadata": {}, 143 | "output_type": "execute_result" 144 | } 145 | ], 146 | "source": [ 147 | "type(X)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "By default, the backend is set to NumPy; here is how to change it, first to PyTorch:" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | 
"execution_count": 12, 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "name": "stderr", 164 | "output_type": "stream", 165 | "text": [ 166 | "Using pytorch backend.\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "tl.set_backend('pytorch')" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 15, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "data": { 181 | "text/plain": [ 182 | "torch.Tensor" 183 | ] 184 | }, 185 | "execution_count": 15, 186 | "metadata": {}, 187 | "output_type": "execute_result" 188 | } 189 | ], 190 | "source": [ 191 | "X = tl.tensor(np.arange(24).reshape((3, 4, 2)), device='cpu')\n", 192 | "type(X)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 16, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "data": { 202 | "text/plain": [ 203 | "{'device': device(type='cpu'), 'dtype': torch.float32, 'requires_grad': False}" 204 | ] 205 | }, 206 | "execution_count": 16, 207 | "metadata": {}, 208 | "output_type": "execute_result" 209 | } 210 | ], 211 | "source": [ 212 | "tl.context(X)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "As expected, tensors are now represented as PyTorch tensors." 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "Let's change it back to NumPy for the rest of the tutorial."
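Conceptually, `tl.set_backend` just swaps which implementation the top-level API dispatches to. The mechanism can be illustrated with a toy registry sketch — all names here are illustrative and deliberately simplified, not TensorLy's actual internals:

```python
# A toy version of dynamic backend dispatch: a module-level registry maps a
# backend name to the functions implementing the API, and set_backend swaps
# which entry is active. Real backends would be NumPy and PyTorch; here we
# use plain list/tuple constructors to keep the sketch dependency-free.

_BACKENDS = {
    'list':  {'tensor': list},
    'tuple': {'tensor': tuple},
}
_active = 'list'

def set_backend(name):
    """Select the backend used by all subsequent API calls."""
    global _active
    if name not in _BACKENDS:
        raise ValueError('Unknown backend: {!r}'.format(name))
    _active = name

def tensor(data):
    # Dispatch to whichever backend is currently active.
    return _BACKENDS[_active]['tensor'](data)

x = tensor([1, 2, 3])
print(type(x).__name__)   # list

set_backend('tuple')
y = tensor([1, 2, 3])
print(type(y).__name__)   # tuple
```

This is why `type(X)` changes from `numpy.ndarray` to `torch.Tensor` in the cells above: the same `tl.tensor` call routes to a different backend after `tl.set_backend('pytorch')`.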
227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 10, 232 | "metadata": {}, 233 | "outputs": [ 234 | { 235 | "name": "stderr", 236 | "output_type": "stream", 237 | "text": [ 238 | "Using numpy backend.\n" 239 | ] 240 | } 241 | ], 242 | "source": [ 243 | "tl.set_backend('numpy')" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 11, 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "data": { 253 | "text/plain": [ 254 | "numpy.ndarray" 255 | ] 256 | }, 257 | "execution_count": 11, 258 | "metadata": {}, 259 | "output_type": "execute_result" 260 | } 261 | ], 262 | "source": [ 263 | "X = tl.tensor(np.arange(24).reshape((3, 4, 2)))\n", 264 | "type(X)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "metadata": {}, 270 | "source": [ 271 | "# 3. Basic tensor operations" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "## 3.1 Unfolding\n", 279 | "\n", 280 | "Also called **matricization**, **unfolding** a tensor is done by reading its elements in a given order so as to obtain a matrix instead of a tensor.\n", 281 | "\n", 282 | "It is done by stacking the **fibers** of the tensor into a matrix.\n", 283 | "\n", 284 | "![tensor_illustration](images/example-unfolding-fibers.png)\n", 285 | "\n", 286 | "\n", 287 | "### Convention\n", 288 | "\n", 289 | "    Remember that, to be consistent with the Python indexing that always starts at zero,\n", 290 | "    in tensorly, modes (and therefore unfolding) also start at zero!\n", 291 | "\n", 292 | "    Therefore ``unfold(tensor, 0)`` will unfold said tensor along its first dimension!\n", 293 | "    \n" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 11, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "data": { 303 | "text/plain": [ 304 | "array([[ 0., 1., 2., 3., 4., 5., 6., 7.],\n", 305 | "       [ 8., 9., 10., 11., 12., 13., 14., 15.],\n", 306 | "       [ 16., 17., 18., 19., 20., 21., 22., 23.]])" 307 | ] 308
| }, 309 | "execution_count": 11, 310 | "metadata": {}, 311 | "output_type": "execute_result" 312 | } 313 | ], 314 | "source": [ 315 | "tl.unfold(X, mode=0)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 12, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "data": { 325 | "text/plain": [ 326 | "array([[ 0., 1., 8., 9., 16., 17.],\n", 327 | " [ 2., 3., 10., 11., 18., 19.],\n", 328 | " [ 4., 5., 12., 13., 20., 21.],\n", 329 | " [ 6., 7., 14., 15., 22., 23.]])" 330 | ] 331 | }, 332 | "execution_count": 12, 333 | "metadata": {}, 334 | "output_type": "execute_result" 335 | } 336 | ], 337 | "source": [ 338 | "tl.unfold(X, mode=1)" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 13, 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "data": { 348 | "text/plain": [ 349 | "array([[ 0., 2., 4., 6., 8., 10., 12., 14., 16., 18., 20.,\n", 350 | " 22.],\n", 351 | " [ 1., 3., 5., 7., 9., 11., 13., 15., 17., 19., 21.,\n", 352 | " 23.]])" 353 | ] 354 | }, 355 | "execution_count": 13, 356 | "metadata": {}, 357 | "output_type": "execute_result" 358 | } 359 | ], 360 | "source": [ 361 | "tl.unfold(X, mode=2)" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "metadata": {}, 367 | "source": [ 368 | "## 3.2 Folding\n", 369 | "\n", 370 | "Folding is the inverse operation: you can **fold** an unfolded tensor back from matrix to full tensor using the ``tensorly.fold`` function." 
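The unfold/fold pair shown above — and the `mode_dot` product introduced in the next section — can be reproduced in a few lines of plain NumPy. This is a sketch assuming TensorLy's mode-0-first convention (bring the chosen mode to the front, then flatten the rest), not TensorLy's actual implementation:

```python
import numpy as np

def unfold(tensor, mode):
    # Bring the given mode to the front, then flatten the remaining modes.
    return np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)

def fold(unfolded, mode, shape):
    # Inverse of unfold: restore the moved axis and the original shape.
    full_shape = [shape[mode]] + [s for i, s in enumerate(shape) if i != mode]
    return np.moveaxis(unfolded.reshape(full_shape), 0, mode)

def mode_dot(tensor, matrix, mode):
    # n-mode product: multiply the mode-n unfolding, then fold back.
    new_shape = list(tensor.shape)
    new_shape[mode] = matrix.shape[0]
    return fold(matrix @ unfold(tensor, mode), mode, tuple(new_shape))

X = np.arange(24).reshape((3, 4, 2))

# Round-trip: folding an unfolding recovers the original tensor, for any mode.
for mode in range(X.ndim):
    assert np.array_equal(fold(unfold(X, mode), mode, X.shape), X)

# The mode-1 unfolding matches the tl.unfold(X, mode=1) output shown above.
print(unfold(X, 1)[0])   # [ 0  1  8  9 16 17]

# A (5, 4) matrix multiplied along mode 1 turns a (3, 4, 2) tensor into (3, 5, 2).
M = np.arange(20).reshape((5, 4))
print(mode_dot(X, M, 1).shape)   # (3, 5, 2)
```

Expressing `mode_dot` through `unfold` and `fold` like this is also a good way to remember why the result has size `D` in place of `I_n` and is unchanged everywhere else.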
371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 14, 376 | "metadata": {}, 377 | "outputs": [ 378 | { 379 | "data": { 380 | "text/plain": [ 381 | "array([[[ 0., 1.],\n", 382 | "        [ 2., 3.],\n", 383 | "        [ 4., 5.],\n", 384 | "        [ 6., 7.]],\n", 385 | "\n", 386 | "       [[ 8., 9.],\n", 387 | "        [ 10., 11.],\n", 388 | "        [ 12., 13.],\n", 389 | "        [ 14., 15.]],\n", 390 | "\n", 391 | "       [[ 16., 17.],\n", 392 | "        [ 18., 19.],\n", 393 | "        [ 20., 21.],\n", 394 | "        [ 22., 23.]]])" 395 | ] 396 | }, 397 | "execution_count": 14, 398 | "metadata": {}, 399 | "output_type": "execute_result" 400 | } 401 | ], 402 | "source": [ 403 | "unfolding = tl.unfold(X, 1)\n", 404 | "original_shape = X.shape\n", 405 | "tl.fold(unfolding, mode=1, shape=original_shape)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "metadata": {}, 411 | "source": [ 412 | "## 3.3 n-mode product\n", 413 | "\n", 414 | "Also known as **tensor contraction**. This is a natural generalization of matrix-vector and matrix-matrix product. When multiplying a tensor by a matrix or a vector, we now have to specify the **mode** $n$ along which to take the product.\n", 415 | "### Tensor times matrix\n", 416 | "\n", 417 | "In that case we are doing an operation analogous to a matrix multiplication on the $n$-th mode. Given a tensor $\\tilde X$ of size $(I_0, I_1, \\cdots, I_N)$, and a matrix $M$ of size $(D, I_n)$, the $n$-mode product of $\\tilde X$ by $M$ is written $\\tilde X \\times_n M$ and is of size $(I_0, \\cdots, I_{n-1}, D, I_{n+1}, \\cdots, I_N)$.\n", 418 | "\n", 419 | "### Tensor times vector\n", 420 | "\n", 421 | "In that case we are contracting over the $n$-th mode by multiplying it with a vector. 
Given a tensor $\\tilde X$ of size $(I_0, I_1, \\cdots, I_N)$, and a vector $v$ of size $(I_n)$, the $n$-mode product of $\\tilde X$ by $v$ is written $\\tilde X \\times_n v$ and is of size $(I_0, \\cdots, I_{n-1}, I_{n+1}, \\cdots, I_N)$.\n", 422 | "\n", 423 | "\n", 424 | "### Example\n", 425 | "\n", 426 | "In TensorLy, all the tensor algebra functions are located in the `tensorly.tenalg` module. For the n-mode product, you will need to use the function `mode_dot` that works transparently for multiplying a tensor by a matrix or a vector along a given mode.\n", 427 | "\n", 428 | "#### Tensor times matrix\n", 429 | "\n", 430 | "With the tensor $\\tilde X$ of size (3, 4, 2) we defined previously, let's define a matrix M of size (5, 4) to multiply along the second mode:" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 15, 436 | "metadata": {}, 437 | "outputs": [ 438 | { 439 | "name": "stdout", 440 | "output_type": "stream", 441 | "text": [ 442 | "(5, 4)\n" 443 | ] 444 | } 445 | ], 446 | "source": [ 447 | "M = tl.tensor(np.arange(4*5).reshape((5, 4)))\n", 448 | "print(M.shape)" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": {}, 454 | "source": [ 455 | "Keep in mind indexing starts at zero, so the second mode is represented by `mode=1`:" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 16, 461 | "metadata": { 462 | "collapsed": true 463 | }, 464 | "outputs": [], 465 | "source": [ 466 | "res = tl.tenalg.mode_dot(X, M, mode=1)" 467 | ] 468 | }, 469 | { 470 | "cell_type": "markdown", 471 | "metadata": {}, 472 | "source": [ 473 | "As expected, the result is of shape (3, 5, 2)" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 17, 479 | "metadata": {}, 480 | "outputs": [ 481 | { 482 | "data": { 483 | "text/plain": [ 484 | "(3, 5, 2)" 485 | ] 486 | }, 487 | "execution_count": 17, 488 | "metadata": {}, 489 | "output_type": 
"execute_result" 490 | } 491 | ], 492 | "source": [ 493 | "res.shape" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": {}, 499 | "source": [ 500 | "#### Tensor times vector\n", 501 | "\n", 502 | "Similarly, we can contract along the mode 1 with a vector of size 4 (our tensor is of size (3, 4, 2)).\n" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 18, 508 | "metadata": {}, 509 | "outputs": [ 510 | { 511 | "name": "stdout", 512 | "output_type": "stream", 513 | "text": [ 514 | "(4,)\n" 515 | ] 516 | } 517 | ], 518 | "source": [ 519 | "v = tl.tensor(np.arange(4))\n", 520 | "print(v.shape)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 19, 526 | "metadata": { 527 | "collapsed": true 528 | }, 529 | "outputs": [], 530 | "source": [ 531 | "res = tl.tenalg.mode_dot(X, v, mode=1)" 532 | ] 533 | }, 534 | { 535 | "cell_type": "markdown", 536 | "metadata": {}, 537 | "source": [ 538 | "Since we have multiplied by a vector, we have effectively contracted out one mode of the tensor, so the result is a matrix:" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 20, 544 | "metadata": {}, 545 | "outputs": [ 546 | { 547 | "data": { 548 | "text/plain": [ 549 | "(3, 2)" 550 | ] 551 | }, 552 | "execution_count": 20, 553 | "metadata": {}, 554 | "output_type": "execute_result" 555 | } 556 | ], 557 | "source": [ 558 | "res.shape" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "## Kronecker and Khatri-Rao product" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 21, 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "from tensorly.tenalg import kronecker, khatri_rao" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 22, 580 | "metadata": {}, 581 | "outputs": [], 582 | "source": [ 583 | "A = tl.tensor([[2, 1],\n", 584 | "               [3, 4]])" 585 | ] 586 | }, 587 | { 588
| "cell_type": "code", 589 | "execution_count": 23, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "B = tl.tensor([[0.5, 1],\n", 594 | "               [2, 0]])" 595 | ] 596 | }, 597 | { 598 | "cell_type": "markdown", 599 | "metadata": {}, 600 | "source": [ 601 | "The Kronecker and Khatri-Rao products take as input a list of matrices (they can combine more than two matrices at once)" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 24, 607 | "metadata": {}, 608 | "outputs": [ 609 | { 610 | "data": { 611 | "text/plain": [ 612 | "array([[ 1. , 2. , 0.5, 1. ],\n", 613 | "       [ 4. , 0. , 2. , 0. ],\n", 614 | "       [ 1.5, 3. , 2. , 4. ],\n", 615 | "       [ 6. , 0. , 8. , 0. ]])" 616 | ] 617 | }, 618 | "execution_count": 24, 619 | "metadata": {}, 620 | "output_type": "execute_result" 621 | } 622 | ], 623 | "source": [ 624 | "kronecker([A, B])" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 25, 630 | "metadata": {}, 631 | "outputs": [ 632 | { 633 | "data": { 634 | "text/plain": [ 635 | "array([[ 1. , 1. ],\n", 636 | "       [ 4. , 0. ],\n", 637 | "       [ 1.5, 4. ],\n", 638 | "       [ 6. , 0. 
]])" 639 | ] 640 | }, 641 | "execution_count": 25, 642 | "metadata": {}, 643 | "output_type": "execute_result" 644 | } 645 | ], 646 | "source": [ 647 | "khatri_rao([A, B])" 648 | ] 649 | }, 650 | { 651 | "cell_type": "markdown", 652 | "metadata": {}, 653 | "source": [ 654 | "Compare that to the result shown in the slides :)" 655 | ] 656 | } 657 | ], 658 | "metadata": { 659 | "kernelspec": { 660 | "display_name": "Python 3", 661 | "language": "python", 662 | "name": "python3" 663 | }, 664 | "language_info": { 665 | "codemirror_mode": { 666 | "name": "ipython", 667 | "version": 3 668 | }, 669 | "file_extension": ".py", 670 | "mimetype": "text/x-python", 671 | "name": "python", 672 | "nbconvert_exporter": "python", 673 | "pygments_lexer": "ipython3", 674 | "version": "3.6.6" 675 | } 676 | }, 677 | "nbformat": 4, 678 | "nbformat_minor": 2 679 | } 680 | -------------------------------------------------------------------------------- /notebooks/4-tensor_regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Low-rank Tensor Regression" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Tensor regression is available in the module `tensorly.regression`.\n", 15 | "\n", 16 | "Given a series of $N$ tensor samples/observations, $\\tilde X_i, i={1, \\cdots, N}$, and corresponding labels $y_i, i={1, \\cdots, N}$, we want to find the weight tensor $\\tilde W$ such that, for each $i={1, \\cdots, N}$:\n", 17 | "\n", 18 | "$$\n", 19 | " y_i = \\langle \\tilde X_i, \\tilde W \\rangle\n", 20 | "$$\n", 21 | "\n", 22 | "We additionally impose that $\\tilde W$ be a rank-r CP decomposition (Kruskal regression) or a rank $(r_1, \\cdots, r_N)$-Tucker decomposition (Tucker regression).\n", 23 | "\n", 24 | "TensorLy implements both types of tensor regression as scikit-learn-like estimators.\n", 25 | "\n", 26 | 
"For instance, Kruskal regression is available through the `tensorly.regression.kruskal_regression.KruskalRegressor` object. This implements a `fit` method that takes as parameters `X`, the data tensor whose first dimension is the number of samples, and `y`, the corresponding vector of labels.\n", 27 | "\n", 28 | "Given a set of testing samples, you can use the `predict` method to obtain the corresponding predictions from the model.\n", 29 | "\n" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stderr", 39 | "output_type": "stream", 40 | "text": [ 41 | "Using numpy backend.\n" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "from tensorly.base import tensor_to_vec, partial_tensor_to_vec\n", 47 | "from tensorly.datasets.synthetic import gen_image\n", 48 | "from tensorly.random import check_random_state\n", 49 | "from tensorly.regression.kruskal_regression import KruskalRegressor\n", 50 | "import tensorly.backend as T\n", 51 | "\n", 52 | "import matplotlib.pyplot as plt\n", 53 | "#show figures in the notebook\n", 54 | "%matplotlib inline" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": { 61 | "collapsed": true 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "# Parameters of the experiment\n", 66 | "image_height = 25\n", 67 | "image_width = 25\n", 68 | "\n", 69 | "# fix the random seed for reproducibility\n", 70 | "rng = check_random_state(1) \n", 71 | "\n", 72 | "# Generate a random tensor\n", 73 | "X = T.tensor(rng.normal(size=(1000, image_height, image_width), loc=0, scale=1))" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "Generate the original image" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 3, 86 | "metadata": { 87 | "collapsed": true 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "weight_img = gen_image(region='swiss', image_height=image_height, image_width=image_width)\n", 92 | 
"weight_img = T.tensor(weight_img)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "The true labels are obtained by taking the inner product between the true regression weights and each input tensor" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 5, 105 | "metadata": { 106 | "collapsed": true 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "y = T.dot(partial_tensor_to_vec(X, skip_begin=1), tensor_to_vec(weight_img))" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "## Let's view the true regression weights" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "data": { 127 | "text/plain": [ 128 | "Text(0.5,1,'True regression weights')" 129 | ] 130 | }, 131 | "execution_count": 4, 132 | "metadata": {}, 133 | "output_type": "execute_result" 134 | }, 135 | { 136 | "data": { 137 | "image/png": 
"<base64 image data omitted>\n", 138 | "text/plain": [ 139 | "" 140 | ] 141 | }, 142 | "metadata": {}, 143 | "output_type": "display_data" 144 | } 145 | ], 146 | "source": [ 147 | "fig = plt.figure()\n", 148 | "ax = fig.add_subplot(1, 1, 
1)\n", 149 | "ax.imshow(T.to_numpy(weight_img), cmap=plt.cm.OrRd, interpolation='nearest')\n", 150 | "ax.set_axis_off()\n", 151 | "ax.set_title('True regression weights')" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "## Tensor regression " 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "### Create a tensor Regressor estimator" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 40, 171 | "metadata": { 172 | "collapsed": true 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "estimator = KruskalRegressor(weight_rank=1, tol=10e-7, n_iter_max=100, reg_W=1, verbose=0)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "### Fit the estimator to the data" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 41, 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "data": { 193 | "text/plain": [ 194 | "" 195 | ] 196 | }, 197 | "execution_count": 41, 198 | "metadata": {}, 199 | "output_type": "execute_result" 200 | } 201 | ], 202 | "source": [ 203 | "estimator.fit(X, y)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "### Predict the labels given input tensors" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 42, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "y_pred = estimator.predict(X)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "Let's measure the RMSE" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 43, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "from tensorly.metrics import RMSE" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 44, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "data": { 245 | 
"text/plain": [ 246 | "6.209524216534728" 247 | ] 248 | }, 249 | "execution_count": 44, 250 | "metadata": {}, 251 | "output_type": "execute_result" 252 | } 253 | ], 254 | "source": [ 255 | "RMSE(y, y_pred)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "### Visualise the learned weights" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 45, 268 | "metadata": {}, 269 | "outputs": [ 270 | { 271 | "data": { 272 | "text/plain": [ 273 | "Text(0.5,1,'Learned regression weights')" 274 | ] 275 | }, 276 | "execution_count": 45, 277 | "metadata": {}, 278 | "output_type": "execute_result" 279 | }, 280 | { 281 | "data": { 282 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAET1JREFUeJzt3XuQZGddxvHn6bnt7M7uzoaYmJDL\nYrSoKEGFwsRCjZdUIgUqopZSQLLEWASVeFcQtQgay+IiRYyoFdACoiIVxSoVlVKSDUnQGCGmykCU\nXDey2ZDNzmz2PpfXP86ZTWcy27932d6Znvl9P1VTNTPn3fe8fU4/c7r7/PZ9XUoRgHw6Kz0AACuD\n8ANJEX4gKcIPJEX4gaQIP5AU4R8QtrfaLraHV3os/WL7H21fsdLjWIrt19n+VGXbbbZvP9ljWm5r\nMvy2H7Z9yUqPI7tSyitKKR9e6XEspZTy56WUS/vRl+1bbV/Vj76W05oM/3JYySt0v/e9ll5toF66\n8Nt+le17bE/ZvtP2i7u2vdX2A7aftn2f7R/u2rbN9h2232f7KUnvWHg5aPs9tvfYfsj2K7r+zWbb\nH7K90/b/2f4d20PttqH23z1p+0FJrwzG/bDtX7N9r6T9todtn2n7r21/pd33NV3tx21/uB3XF2z/\nqu3HTqC/b7N9t+29tnfZ/v329+ts32R7d3tM/8P26e22o1dE2x3bv2H7EdtP2P6I7c3ttoW3PFfY\nfrQ9Jm8/xnF4QbufTvvzB20/0bX9Jts/X3H8n/VS3valtu+3PW37A7a3L76aL3WebV8n6Tsl3WB7\nn+0b3Hhf+zinbd9r+0W9zu+KKKWsuS9JD0u6ZInfv0TSE5IulDQk6Yq27Vi7/ccknanmj+KPS9ov\n6Yx22zZJs5LeImlY0nj7uxlJP9X292ZJX5bk9t/8raQ/kbRB0mmS7pL0pnbb1ZK+KOlsSadIukVS\nkTTc4zHd07Yfb8f4n5J+S9KopK+T9KCky9r2vydpu6Qtks6SdK+kx06gv89KekP7/YSki9rv3yTp\n7yStb4/BSyVtarfdKumq9vsrJX2p7XdC0t9I+mi7bWv72G9sx/LNkg5LOv8Yx+JRSS9tv7+/Hef5\nXdu+teL4b5N0e/v9qZL2SnpNe25/rj2vV3W17XWejz7O9ufL2mM5KcmSzlf7PBqkrxUfwDKH/48k\n/fai390v6eJj9HOPpB/qegI8umj7Nkl
f6vp5ffsk/lpJp7dP4PGu7a+VdEv7/aclXd217VLF4b+y\n6+cLlxjP2yT9Wfv90eC2P1+l54b/ePq7TdK1kk5d1OZKSXdKevESYz4aCkn/Kumnu7a9sA3UsJ4J\n/1ld2++S9BPHOBYflfSL7XG+X9K71PwxfYGkKTV/yKLjv03PhP9ySZ/tamdJO/Ts8C95nhc/zvbn\n75X0P5IuktRZ6Twc6yvbe71zJV1h+y1dvxtVc7WX7cvVPKm2ttsm1FwVFuxYos/HF74ppRywvfDv\nTpE0Imln+zupeVIu9HHmov4eqRh/d/tzJZ1pe6rrd0OSPnOM/pca+/H095OS3inpi7YfknRtKeXv\n1QTxbEkfsz0p6SZJby+lzCza15l69mN8RE3wT+/63eNd3x9QcxyXsl3SD0p6TM0fpVslvUHSIUmf\nKaXM2z5XvY//4rEd/X0ppXS/RVo8tkXn+TlKKZ+2fYOkP5R0ju1PSPrlUsreYzyeFZEt/DskXVdK\nuW7xhvbJcqOk71NzFZizfY+aq8CC4/kvkDvUXHlOLaXMLrF9p5rQLDinos/u/e+Q9FAp5RuO0Xan\nmpf797U/n71Em+r+Sin/K+m17Xvt10i62fbzSin71bwiuNb2VkmfVHM1/tCiLr6s5g/MgnPUvI3a\n1Y7zeGyX9G414d8u6XZJf6wm/Nu7Hk+v499t4VhJktwk+3jG9JznRSnleknX2z5N0scl/Yqk3zyO\nPk+6tfyB30j7YdTC17CacF9t+8L2Q5kNtl9pe6Oa94VF0lckyfYbJX3VH9KUUnZK+pSk99re1H7g\ndZ7ti9smH5d0je2zbG+R9Nbj3MVdkva2H9qNu/kA8UW2X9bV/9tsb7H9fEk/eyL92X697a8ppcyr\neWktSXO2v8f2Be0HaXvVvJSfW6L/v5T0C+0HdhOSflfSX1UE8znaP0QHJb1e0m3tFXWXpB9RG/6K\n49/tHyRdYPvV7fPkZ9S8pai1S81nGZIk2y9rn2Mjaj43OqSlj8mKWsvh/6SaJ8jC1ztKKXer+dDm\nBkl71HwAtU2SSin3SXqvmg+2dkm6QNIdJziGy9W8rbiv3d/Nks5ot90o6Z8l/Zekz6n5AKxaKWVO\n0g9I+hZJD0l6UtIHJW1um7xTzZXxIUn/0u778An09/2S/tv2PknvV/N+/JCakNysJvhfUBO+m5bY\nxZ+qeYtwW9v/ITUfnn61tkvaXUp5tOtnS/p8V5tex/+oUsqTaj7sfZek3ZK+UdLd6nG8Fnm/pB9t\n7wRcL2mTmvO7R83bm92S3nM8D245LHxaiTXO9pvVBHapKx+6tG9tHpP0ulLKLSs9npNlLV/5U7N9\nhu2Xty93XyjplyR9YqXHNahsX2Z70vaYpF9X8yri31Z4WCdVtg/8MhlVc4974fbXxyR9YEVHNNi+\nXdJf6Jm3Ca8upRxc2SGdXLzsB5LiZT+Q1LK+7C/TDyzPy4wycHdVTlDGV2eOm6w2zX8rOPm72Xxe\n1cHjyg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKm1WdtfVUxRUwcRFdf0o48aa7DgJZSxsGl5\nceUHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaRW4X3+invenYqH5VV077wzWtGo5r74cj3mPoxlrmLW\n7H6dw2gqu37tZ27xIkZLCcbSx2n3uPIDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0hq8Ip85oNC\niIrVeMqOf4/3c6hiDcbhkd7b5ypWBhquOMSd3n+Dy+c+33O7JGlyc9xmem/cJrJ/f9xmy2TcZs9U\nz83+jpfHfczOxm2GKiZ26UQFR/NxHxXFNz6n4jENr++9fah/keXKDyRF+IGkCD+QFOEHkiL8QFKE\nH0iK8ANJEX4gqcEr8gn1aSafmmKJsECkYlaVoYq/rw7ajK+L+xirmO1n3VjFWILjWzOTzPh43Obw\nkd7
bg8Kn/raJznNFMdEqXGCIKz+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaRWYZFPRTXFkYpZ\neg5WtBkNZvKpmUkmmg1IigtRdj0R93GoYnmr6em4TWB+f7yfzuG4zfxTvWcVGqroQzMVy1/VzKQU\nFWLNVszYVPO87MtSW/27XnPlB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkVuF9/go199ZHK+7djgT9\nRJNwSH1ZsadqBZxNmyr2UzERStTFyIG4UcXqQZ354J53dOyleOIRqXLFnuD4h5N9qO4efs14QxWr\nB1Xiyg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKm1WeQzWzHJw5FgxRhJ4QQNNZN51BR/REUm\ne6biPqKiGWn5JvOoKCaan3q65/ahmok6mMxjQHoCsKoQfiApwg8kRfiBpAg/kBThB5Ii/EBShB9I\nam0W+YyMxm36MZNPzQwvVTP5BEUxp5wS97F5Y9ymZlabQGddn2byiWa1iVZLajqJ2wxVHP9oLMzk\nA2AtIfxAUoQfSIrwA0kRfiApwg8kRfiBpAbvPv98MEHGXDyZRLnjzrDNkR27wzYjm8d6bp/dF08m\nMbwxvl/tsd77+adrPhL2cfrWU8M2ux5+MmwTebzifvZ5k+vDNg9M9a4X2PYH8WQrc9NxzcHQRO9j\nK0kKjr8OHQq7KBUTfgxd/pJ4LCMTQSfr4j4qceUHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5DU\n4BX59EPFBBqdkfjvnkd699MZiQs7XDMpRTDesfG4j7HxeAKTmn4cTJCxfl9cZFUzlvVBkU/NOXTF\nOayaTCU4z5qtGEtfJupYXlz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8ktfqKfGqKKWpWpqnp\nJ1oRpmbFGFf8fQ36GRqJH8/wcNympp+oWKWiZEmdofgxh/1U9BEVJDX9VDwXon1VjEWlfyvpLBeu\n/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHklp9RT4Vy0XpQLyM09yBYFkwSZ3R3ktGzR2M+/BQ\nvNRTZ7Z3P1PTB8M+1j+1L2xT009kV0WbyYqxhP0ciMdac/yHO/FzQXO9Z2QqVct1xc/LjiqeuwHX\nFI1V4soPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0mtvvv8NfdKj8yETeaPxKvtRG1q+igz8WQS8+p9\nv/rpsAfp4L74XnRNP5HpijYHDvSuj6jpp8zE9/CrzmHFqj6dTu/ny/yReKKOMtuvyTyWb+UfrvxA\nUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5JahUU+FUUQY2Nhk85YXHzTGe3dpjMTF3bU7McjvU/D\nprAHaXxiXdhm01TFxBaByYo2Gybi4z+573DP7dExkaTOaM2ELBUr9oyN9u4jmOxDkkrN6kFVTnzC\nj1pc+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJDWART5RkUOfVuzZH8/2ExWIVK36MxwXf3RG\nexcLPRH2IE1UrJJT00802po+TgsKeGr6mT8YzwZUc/xdUXzTKb1XB5o7WFHkUzGTz3DNalPM5APg\nZCP8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5DU4BX5dEaCBnERhL/7u8I2Gw7sj8cy0nuGl9HZuMhE\nI9HjkRQUoryxplBlSzzfz8V79oZtoqKY2X1xcdTI89aHbV61u3ch1tAl8TkcmonHouGa49/7GtiZ\nr1iKq6bN8HjcpvQuKCqz8bJstWVCXPmBpAg/kBThB5Ii/EBShB9IivADSRF+IKnBu8/fDzWrp3Qq\nVnIJ7v/W7aeijYP7zKMVp2m0d01CdT/uPd7OSDyxRVQf0fQT3K+uOj8V99aDx9P0E5znmkk4oj4G\n0OobMYC+IPxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kNTqK/KpKdoIi
mYkVRboREU+FYUoVWMJ2tRM\nqmYkaQRYUoVWMJ3NVOCDFecyrG4+CbSGauZwCSeS2csOHb9OD+SNFRzjoJ91eynajWewcKV\n7j16yusLg6J25JfKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6p\nLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks\n45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzj\nl8oyfqks45fKMn6pLOOXyjJ+qayrw3gm8MPa/pUyiM/2Qq9UlvFLZRm/VJbxS2UZv1SW8UtlGb9U\nlvFLZRm/VJbxS2UZv1SW8UtlGb9UlvFLZRm/VJbxS2UZv1SW8UtlGb9UlvFLZRm/VJbxS2UZv1SW\n8UtlXR3GM4Ef1vavlEF8thd6pbKMXyrL+KWyjF8qy/ilsoxfKsv4pbKMXyrL+KWyjF8qy/ilsoxf\nKsv4pbKMXyrL+KWyjF8qy/ilsoxfKsv4pbKMXyrL+KWyjF8qy/ilsq4O45nAD2v7V8ogPtsLvVJZ\nxi+VZfxSWcYvlWX8UlnGL5Vl/FJZxi+VZfxSWcYvlWX8UlnGL5Vl/FJZxi+VZfxSWcYvlWX8UlnG\nL5Vl/FJZxi+VZfxSWcYvlWX8UllXh/FM4Ie1/StlEJ/thV6pLOOXyjJ+qSzjl8oyfqks45fKMn6p\nLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks\n45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKujqMZwI/rO1fKYP4bC/0SmUZv1SW8UtlGb9UlvFLZRm/\nVJbxS2UZv1SW8UtlGb9UlvFLZRm/VJbxS2UZv1SW8UtlGb9UlvFLZRm/VJbxS2UZv1SW8UtlGb9U\nlvFLZRm/VJbxS2UZv1TW1WE8E/hhbf9KGcRne6FXKsv4pbKMXyrL+KWyjF8qy/ilsoxfKsv4pbKM\nXyrL+KWyjF8qy/ilsoxfKsv4pbKMXyrL+KWyjF8qy/ilsoxfKsv4pbKMXyrL+KWyjF8qy/ilsq4O\n45nAD2v7V8ogPtsLvVJZxi+VZfxSWcYvlWX8UlnGL5Vl/FJZxi+VZfxSWcYvlWX8UlnGL5Vl/FJZ\nxi+VZfxSWcYvlWX8UlnGL5Vl/FJZxi+VZfxSWcYvlWX8UllXh/FM4Ie1/StlEJ/thV6pLOOXyjJ+\nqSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6p\nLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks49eb1dVhPBP4YW3/ShnEZ3uhVyrL\n+KWyjF8qy/ilsoxfKsv4pbKMXyrL+KWyjF8qy/ilsoxfKsv4pbKMXyrL+KWyjF8qy/ilsoxfKsv4\npbKMXyrL+KWyjF8qy/ilsoxfKsv4pbKMXyrL+KWyjF8q6+owngn8sLZ/pQzis73QK5Vl/FJZxi+V\nZfxSWcYvlWX8UlnGL5Vl/FJZxi+VZfxSWcYvlWX8UlnGL5Vl/FJZxi+VZfxSWcYvlWX8UlnGL5Vl\n/FJZxi+VZfxSWVeH8Uzgh7X9K2UQn+2FXqks45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOX\nyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyjJ+qSzjl8oyfqks45fK\nMn6pLOOXyjJ+qSzjl8oyfqks45fKMn6pLOOXyrg6jGcCP6ztXymD+Gwv9EplGb9UlvFLZRm/VJbx\nS2UZv1SW8Uv/D1X2jGbfF5/4AAAAAElFTkSuQmCC\n", 283 | "text/plain": [ 284 | "" 285 | ] 286 | }, 287 | "metadata": {}, 288 | "output_type": "display_data" 289 | } 290 | ], 291 | "source": [ 292 | "fig = plt.figure()\n", 293 | "ax = fig.add_subplot(1, 1, 1)\n", 294 | "ax.imshow(T.to_numpy(estimator.weight_tensor_), cmap=plt.cm.OrRd, interpolation='nearest')\n", 295 | "ax.set_axis_off()\n", 296 | "ax.set_title('Learned regression weights')" 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "metadata": {}, 302 | "source": [ 303 | "Wait! Weren't the learned weights supposed to be a low rank tensor in the Kruskal form!?\n", 304 | "\n", 305 | "They are! We simply plot the full tensor corresponding to the learned decomposition. 
You can access the decomposed form as follows:" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 46, 311 | "metadata": { 312 | "collapsed": true 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "factors = estimator.kruskal_weight_" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 47, 322 | "metadata": {}, 323 | "outputs": [ 324 | { 325 | "data": { 326 | "text/plain": [ 327 | "[(25, 1), (25, 1)]" 328 | ] 329 | }, 330 | "execution_count": 47, 331 | "metadata": {}, 332 | "output_type": "execute_result" 333 | } 334 | ], 335 | "source": [ 336 | "[f.shape for f in factors]" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "metadata": {}, 342 | "source": [ 343 | "# Tucker regression" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 48, 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "from tensorly.regression import TuckerRegressor" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 52, 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "estimator = TuckerRegressor(weight_ranks=[10, 5])" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 53, 367 | "metadata": {}, 368 | "outputs": [ 369 | { 370 | "name": "stdout", 371 | "output_type": "stream", 372 | "text": [ 373 | "\n", 374 | "Converged in 8 iterations\n" 375 | ] 376 | }, 377 | { 378 | "data": { 379 | "text/plain": [ 380 | "" 381 | ] 382 | }, 383 | "execution_count": 53, 384 | "metadata": {}, 385 | "output_type": "execute_result" 386 | } 387 | ], 388 | "source": [ 389 | "estimator.fit(X, y)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 54, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "y_pred = estimator.predict(X)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 55, 404 | "metadata": {}, 405 | "outputs": [ 406 | { 407 | "data": { 408 | "text/plain": [ 409 
| "0.0037852639020583777" 410 | ] 411 | }, 412 | "execution_count": 55, 413 | "metadata": {}, 414 | "output_type": "execute_result" 415 | } 416 | ], 417 | "source": [ 418 | "RMSE(y, y_pred)" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 56, 424 | "metadata": {}, 425 | "outputs": [ 426 | { 427 | "data": { 428 | "text/plain": [ 429 | "Text(0.5,1,'Learned regression weights')" 430 | ] 431 | }, 432 | "execution_count": 56, 433 | "metadata": {}, 434 | "output_type": "execute_result" 435 | }, 436 | { 437 | "data": { 438 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAACq5JREFUeJzt3XmsXGUdh/HnCwVaRRZbRSqrxhiM\n4BZFY4xxCWjccIsaFSpiAPddcEtRMcY1EtwCaBBcg0viFo0KRVyCqEgiiFuBorXYClSgNYCvf5z3\nluFyW+61dKbe3/NJbjIz59xz3jlznjnLTdO01pBUzw6THoCkyTB+qSjjl4oyfqko45eKMn6pKOPf\nTiQ5IElLsmDSY7mrJPlekqMmPY6ZJHlxkh/Mct5lSS7Y1mMat3kZf5Irkjx50uOorrX21NbamZMe\nx0xaa19orR12VywryXlJjrkrljVO8zL+cZjkEfquXvd8OtvQ7JWLP8nTk1yc5LokP0tyyMi0E5L8\nOcm/klya5Nkj05Yl+WmSjyX5J7B86nQwyYeTXJtkZZKnjvzO7knOSLI6yV+TvC/Jjn3ajv331ib5\nC/C0Oxn3FUneluQS4MYkC5IsTfK1JP/o637tyPyLkpzZx3VZkrcmuXorlveoJBclWZ9kTZKP9tcX\nJjk7ybq+TX+ZZK8+bdMRMckOSd6Z5Mok1yT5fJLd+7SpS56jklzVt8k7NrMdDuzr2aE/Pz3JNSPT\nz07y+lls/9udyic5LMnlSa5P8skkK6YfzWf6nJOcDDwOODXJDUlOzeBj/X1en+SSJA/e0uc7Ea21\nefcDXAE8eYbXHw5cAxwK7Agc1efdpU9/PrCU4UvxBcCNwN592jLgFuA1wAJgUX/tZuAVfXnHA38D\n0n/nm8BngLsD9wYuBI7t044Dfg/sC9wTOBdowIItvKeL+/yL+hh/Bbwb2Bm4H/AX4PA+/weAFcCe\nwD7AJcDVW7G8nwMv7Y93BR7dHx8LfAu4W98GjwB269POA47pj48G/tSXuyvwdeCsPu2A/t5P62N5\nCPBv4KDNbIurgEf0x5f3cR40Mu1hs9j+y4AL+uMlwHrgOf2zfV3/XI8ZmXdLn/Om99mfH9635R5A\ngIPo+9H29DPxAYw5/k8B75322uXA4zeznIuBZ43sAFdNm74M+NPI87v1nfg+wF59B140Mv1FwLn9\n8Y+B40amHcadx3/0yPNDZxjPicDn+uNN4fbnx3DH+OeyvPOBk4Al0+Y5GvgZcMgMY94UBfAj4JUj\n0x7Yg1rAbfHvMzL9QuCFm9kWZwFv7Nv5cuCDDF+mBwLXMXyR3dn2X8Zt8R8J/HxkvgCruH38M37O\n099nf/5E4A/Ao4EdJt3D5n6qXevtD
xyV5DUjr+3McLQnyZEMO9UBfdquDEeFKatmWObfpx601m5K\nMvV79wR2Alb312DYKaeWsXTa8q6cxfhH598fWJrkupHXdgR+spnlzzT2uSzv5cB7gN8nWQmc1Fr7\nNkOI+wJfTrIHcDbwjtbazdPWtZTbv8crGcLfa+S1v488volhO85kBfBM4GqGL6XzgJcCG4GftNb+\nk2R/trz9p49t0+uttTZ6iTR9bNM+5ztorf04yanAJ4D9knwDeHNrbf1m3s9EVIt/FXBya+3k6RP6\nznIa8CSGo8CtSS5mOApMmcs/gVzFcORZ0lq7ZYbpqxmimbLfLJY5uv5VwMrW2gM2M+9qhtP9S/vz\nfWeYZ9bLa639EXhRv9Z+DnBOksWttRsZzghOSnIA8F2Go/EZ0xbxN4YvmCn7MVxGrenjnIsVwIcY\n4l8BXAB8miH+FSPvZ0vbf9TUtgIgQ9lzGdMd9ovW2inAKUnuDXwVeAvwrjksc5ubzzf8duo3o6Z+\nFjDEfVySQ/tNmbsneVqSezBcFzbgHwBJXgb8zzdpWmurgR8AH0myW7/hdf8kj++zfBV4bZJ9kuwJ\nnDDHVVwIrO837RZluIH44CSPHFn+iUn2THJf4NVbs7wkL0lyr9bafxhOrQFuTfKEJAf3G2nrGU7l\nb51h+V8C3tBv2O0KvB/4yizCvIP+RbQBeAlwfj+irgGeS49/Ftt/1HeAg5Mc0feTVzFcUszWGoZ7\nGQAkeWTfx3ZiuG+0kZm3yUTN5/i/y7CDTP0sb61dxHDT5lTgWoYbUMsAWmuXAh9huLG1BjgY+OlW\njuFIhsuKS/v6zgH27tNOA74P/Bb4NcMNsFlrrd0KPAN4KLASWAucDuzeZ3kPw5FxJfDDvu5/b8Xy\nngL8LskNwMcZrsc3MkRyDkP4lzHEd/YMq/gswyXC+X35Gxlunv6vVgDrWmtXjTwP8JuReba0/Tdp\nra1luNn7QWAd8CDgIrawvab5OPC8/peAU4DdGD7faxkub9YBH57LmxuHqbuVmueSHM8Q7ExHPo3o\nlzZXAy9urZ076fFsK/P5yF9akr2TPLaf7j4QeBPwjUmPa3uV5PAkeyTZBXg7w1nELyY8rG2q2g2/\nSnZm+Bv31J+/vgx8cqIj2r49Bvgit10mHNFa2zDZIW1bnvZLRXnaLxU13tP+jes8zZC2tYWLc+cz\neeSXyjJ+qSjjl4oyfqko45eKMn6pKOOXijJ+qSjjl4oyfqko45eKMn6pKOOXijJ+qSjjl4oyfqko\n45eKMn6pKOOXijJ+qSjjl4oyfqko45eKMn6pKOOXivI/6vw/sHzRkkkPYeyWb1g76SHMex75paKM\nXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvilooxf\nKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8q\nyvilooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK\n+KWi0lob39o2rhvLypYvWjKO1UhzsnzD2vGsaOHizGY2j/xSUcYvFWX8UlHGLxVl/FJRxi8VZfxS\nUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJR\nxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHG\nLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYv\nFWX8UlHGLxVl/FJRxi8VZfxSUWmtjW9tG9eNcWXzx/JFSyY9hLFbvmHtpIfw/2vh4sxmNo/8UlHG\nL
xVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYv\nFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8V\nZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl/FJRxi8VZfxSUcYvFWX8UlHGLxVl\n/FJRaa2Nb20b141xZVJRCxdnNrN55JeKMn6pKOOXijJ+qSjjl4oyfqko45eKMn6pKOOXijJ+qSjj\nl4oyfqko45eKMn6pKOOXijJ+qSjjl4oyfqko45eKMn6pKOOXijJ+qSjjl4oyfqko45eKGu//2CNp\nu+GRXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvilooxfKsr4paKMXyrK+KWijF8qyvil\nooxfKsr4paKMXyrK+KWijF8qyvilooxfKuq/kZh4MOw+RSIAAAAASUVORK5CYII=\n", 439 | "text/plain": [ 440 | "" 441 | ] 442 | }, 443 | "metadata": {}, 444 | "output_type": "display_data" 445 | } 446 | ], 447 | "source": [ 448 | "fig = plt.figure()\n", 449 | "ax = fig.add_subplot(1, 1, 1)\n", 450 | "ax.imshow(T.to_numpy(estimator.weight_tensor_), cmap=plt.cm.OrRd, interpolation='nearest')\n", 451 | "ax.set_axis_off()\n", 452 | "ax.set_title('Learned regression weights')" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 57, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "core, factors = estimator.tucker_weight_" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 58, 467 | "metadata": {}, 468 | "outputs": [ 469 | { 470 | "data": { 471 | "text/plain": [ 472 | "(10, 5)" 473 | ] 474 | }, 475 | "execution_count": 58, 476 | "metadata": {}, 477 | "output_type": "execute_result" 478 | } 479 | ], 480 | "source": [ 481 | "core.shape" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": 59, 487 | "metadata": {}, 488 | "outputs": [ 489 | { 490 | "data": { 491 | "text/plain": [ 492 | "[(25, 10), (25, 5)]" 493 | ] 494 | }, 495 | "execution_count": 59, 496 | "metadata": {}, 497 | "output_type": "execute_result" 498 | } 499 | ], 500 | "source": [ 501 | "[f.shape for f in factors]" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "metadata": {}, 508 | "outputs": [], 509 | "source": [] 510 | } 511 | ], 
512 | "metadata": { 513 | "kernelspec": { 514 | "display_name": "Python 3", 515 | "language": "python", 516 | "name": "python3" 517 | }, 518 | "language_info": { 519 | "codemirror_mode": { 520 | "name": "ipython", 521 | "version": 3 522 | }, 523 | "file_extension": ".py", 524 | "mimetype": "text/x-python", 525 | "name": "python", 526 | "nbconvert_exporter": "python", 527 | "pygments_lexer": "ipython3", 528 | "version": "3.6.6" 529 | } 530 | }, 531 | "nbformat": 4, 532 | "nbformat_minor": 2 533 | } 534 | -------------------------------------------------------------------------------- /notebooks/5-decomposition_with_pytorch_and_backprop.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using numpy backend.\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import numpy as np\n", 18 | "\n", 19 | "# Import PyTorch\n", 20 | "import torch\n", 21 | "\n", 22 | "# Import TensorLy\n", 23 | "import tensorly as tl\n", 24 | "from tensorly.tucker_tensor import tucker_to_tensor\n", 25 | "from tensorly.random import check_random_state" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stderr", 35 | "output_type": "stream", 36 | "text": [ 37 | "Using pytorch backend.\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "tl.set_backend('pytorch')" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "Make the results reproducible by fixing the random seed" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "random_state = 1234\n", 59 | "rng = check_random_state(random_state)\n", 60 | "#device = 'cuda:8'\n", 61 | "device = 'cpu'" 62 | ] 63 | }, 64 | { 65 
| "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "Define a random tensor which we will try to decompose. " 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 4, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "shape = [5, 5, 5]\n", 78 | "tensor = tl.tensor(rng.random_sample(shape), device=device, requires_grad=True)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "Initialise a random Tucker decomposition of that tensor\n", 86 | "\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 5, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "ranks = [5, 5, 5]\n", 96 | "core = tl.tensor(rng.random_sample(ranks), device=device, requires_grad=True)\n", 97 | "factors = [tl.tensor(rng.random_sample((tensor.shape[i], ranks[i])),\n", 98 | " device=device, requires_grad=True) for i in range(tl.ndim(tensor))]\n" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "Now we just iterate through the training loop and backpropagate...\n", 106 | "\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 6, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "Epoch 1000,. Rec. error: 12.333837509155273\n", 119 | "Epoch 2000,. Rec. error: 8.580610275268555\n", 120 | "Epoch 3000,. Rec. error: 5.916234493255615\n", 121 | "Epoch 4000,. Rec. error: 4.007179260253906\n", 122 | "Epoch 5000,. Rec. error: 2.6491668224334717\n", 123 | "Epoch 6000,. Rec. error: 1.7086536884307861\n", 124 | "Epoch 7000,. Rec. error: 1.0960273742675781\n", 125 | "Epoch 8000,. Rec. error: 0.7418539524078369\n", 126 | "Epoch 9000,. Rec. 
error: 0.5647979378700256\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "n_iter = 10000\n", 132 | "lr = 0.00005\n", 133 | "penalty = 0.1\n", 134 | "\n", 135 | "optimizer = torch.optim.Adam([core]+factors, lr=lr)\n", 136 | "# [core, factors[0], factors[1], ...]\n", 137 | "\n", 138 | "for i in range(1, n_iter):\n", 139 | " # Important: do not forget to reset the gradients\n", 140 | " optimizer.zero_grad()\n", 141 | "\n", 142 | " # Reconstruct the tensor from the decomposed form\n", 143 | " rec = tucker_to_tensor(core, factors)\n", 144 | "\n", 145 | " # squared l2 loss\n", 146 | " loss = tl.norm(rec - tensor, 2)\n", 147 | "\n", 148 | " # squared l2 penalty on the factors of the decomposition\n", 149 | " for f in factors:\n", 150 | " loss = loss + penalty * f.pow(2).sum()\n", 151 | "\n", 152 | " loss.backward()\n", 153 | " optimizer.step()\n", 154 | "\n", 155 | " if i % 1000 == 0:\n", 156 | " rec_error = tl.norm(rec.data - tensor.data, 2)/tl.norm(tensor.data, 2)\n", 157 | " print(\"Epoch {},. Rec. 
error: {}\".format(i, rec_error))\n" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "# Now a CP decomposition :)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 7, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "# Create random factors\n", 174 | "factors = [tl.tensor(rng.random_sample((s, 5)),\n", 175 | " device=device,\n", 176 | " requires_grad=True)\\\n", 177 | " for s in shape]" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "If you are not familiar with list comprehension, note that this is equivalent to writing:" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 8, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "factors = []\n", 194 | "for s in shape:\n", 195 | " factors.append(tl.tensor(rng.random_sample((s, 5)),\n", 196 | " device=device,\n", 197 | " requires_grad=True))" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 9, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "data": { 207 | "text/plain": [ 208 | "[torch.Size([5, 5]), torch.Size([5, 5]), torch.Size([5, 5])]" 209 | ] 210 | }, 211 | "execution_count": 9, 212 | "metadata": {}, 213 | "output_type": "execute_result" 214 | } 215 | ], 216 | "source": [ 217 | "[f.shape for f in factors]" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 10, 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "Epoch 1000,. Rec. error: 0.637516438961029\n", 230 | "Epoch 2000,. Rec. error: 0.5492339134216309\n", 231 | "Epoch 3000,. Rec. error: 0.4935280978679657\n", 232 | "Epoch 4000,. Rec. error: 0.45761317014694214\n", 233 | "Epoch 5000,. Rec. error: 0.4335983991622925\n", 234 | "Epoch 6000,. Rec. error: 0.41660022735595703\n", 235 | "Epoch 7000,. Rec. 
error: 0.4029829502105713\n", 236 | "Epoch 8000,. Rec. error: 0.39038383960723877\n", 237 | "Epoch 9000,. Rec. error: 0.3780127465724945\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "# Optimise them...\n", 243 | "n_iter = 10000\n", 244 | "lr = 0.00005\n", 245 | "penalty = 0.1\n", 246 | "\n", 247 | "optimizer = torch.optim.Adam(factors, lr=lr)\n", 248 | "\n", 249 | "for i in range(1, n_iter):\n", 250 | " optimizer.zero_grad()\n", 251 | "\n", 252 | " # Reconstruct the tensor from the decomposed form\n", 253 | " rec = tl.kruskal_to_tensor(factors)\n", 254 | "\n", 255 | " # squared l2 loss\n", 256 | " loss = tl.norm(tensor - rec, 2)\n", 257 | "\n", 258 | " # squared l2 penalty on the factors of the decomposition\n", 259 | " for f in factors:\n", 260 | " loss = loss + penalty * f.pow(2).sum()\n", 261 | "\n", 262 | " loss.backward()\n", 263 | " optimizer.step()\n", 264 | "\n", 265 | " if i % 1000 == 0:\n", 266 | " rec_error = tl.norm(rec.data - tensor.data, 2)/tl.norm(tensor.data, 2)\n", 267 | " print(\"Epoch {},. Rec. 
error: {}\".format(i, rec_error))\n" 268 | ] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": "Python 3", 274 | "language": "python", 275 | "name": "python3" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.6.6" 288 | } 289 | }, 290 | "nbformat": 4, 291 | "nbformat_minor": 2 292 | } 293 | -------------------------------------------------------------------------------- /notebooks/images/FC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeanKossaifi/caltech-tutorial/4ae9040898e94d7bb3b220daf6251d2b35a2e13a/notebooks/images/FC.png -------------------------------------------------------------------------------- /notebooks/images/TRL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeanKossaifi/caltech-tutorial/4ae9040898e94d7bb3b220daf6251d2b35a2e13a/notebooks/images/TRL.png -------------------------------------------------------------------------------- /notebooks/images/example-unfolding-fibers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeanKossaifi/caltech-tutorial/4ae9040898e94d7bb3b220daf6251d2b35a2e13a/notebooks/images/example-unfolding-fibers.png -------------------------------------------------------------------------------- /notebooks/images/example_tensor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeanKossaifi/caltech-tutorial/4ae9040898e94d7bb3b220daf6251d2b35a2e13a/notebooks/images/example_tensor.png -------------------------------------------------------------------------------- 
/slides/1-deep-nets.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeanKossaifi/caltech-tutorial/4ae9040898e94d7bb3b220daf6251d2b35a2e13a/slides/1-deep-nets.pdf -------------------------------------------------------------------------------- /slides/2-tensor-basics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeanKossaifi/caltech-tutorial/4ae9040898e94d7bb3b220daf6251d2b35a2e13a/slides/2-tensor-basics.pdf -------------------------------------------------------------------------------- /slides/3-tensor-decomposition.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeanKossaifi/caltech-tutorial/4ae9040898e94d7bb3b220daf6251d2b35a2e13a/slides/3-tensor-decomposition.pdf -------------------------------------------------------------------------------- /slides/4-tensor-regression.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeanKossaifi/caltech-tutorial/4ae9040898e94d7bb3b220daf6251d2b35a2e13a/slides/4-tensor-regression.pdf -------------------------------------------------------------------------------- /slides/5-tensor+deep.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeanKossaifi/caltech-tutorial/4ae9040898e94d7bb3b220daf6251d2b35a2e13a/slides/5-tensor+deep.pdf --------------------------------------------------------------------------------