├── HW0
    ├── Colab_Tutorial.ipynb
    ├── Google_Colab_Tutorial.pdf
    ├── Pytorch_Tutorial.ipynb
    ├── Pytorch_Tutorial_1.pdf
    └── Pytorch_Tutorial_2.pdf
├── HW1
    ├── HW01.pdf
    ├── data
    │   ├── covid.test.csv
    │   ├── covid.train.csv
    │   └── sampleSubmission.csv
    └── homework1.ipynb
├── HW10
    ├── HW10.pdf
    ├── homework10.ipynb
    └── hw10_report.pdf
├── HW11
    ├── HW11.pdf
    ├── homework11.ipynb
    └── hw11_report.pdf
├── HW12
    ├── HW12.pdf
    ├── homework12.ipynb
    └── hw12_report.pdf
├── HW13
    ├── HW13.pdf
    └── homework13.ipynb
├── HW14
    ├── HW14.pdf
    └── homework14.ipynb
├── HW15
    ├── HW15.pdf
    └── homework15.ipynb
├── HW2
    ├── HW02.pdf
    ├── README2_1.md
    ├── README2_2.md
    ├── homework2_1.ipynb
    └── homework2_2.ipynb
├── HW3
    ├── HW03.pdf
    └── homework3.ipynb
├── HW4
    ├── HW04.pdf
    ├── homework4.ipynb
    └── hw4_report.pdf
├── HW5
    ├── HW05.pdf
    └── homework5.ipynb
├── HW6
    ├── HW06.pdf
    ├── homework6.ipynb
    └── hw6_report.pdf
├── HW7
    ├── HW07.pdf
    ├── data
    │   ├── hw7_dev.json
    │   ├── hw7_test.json
    │   └── hw7_train.json
    └── homework7.ipynb
├── HW8
    ├── HW08.pdf
    ├── homework8.ipynb
    └── hw8_report.pdf
├── HW9
    ├── HW09.pdf
    └── homework9.ipynb
├── HomeworkParticipation.JPG
├── README.md
├── StudentSource.jpg
├── cover.png
├── myPerformance.jpg
└── slides
    ├── GuestLecture_QML.pdf
    ├── W14_PAC-introduction.pdf
    ├── attack_v3.pdf
    ├── auto_v8.pdf
    ├── bert_v8.pdf
    ├── classification_v2.pdf
    ├── cnn_v4.pdf
    ├── da_v6.pdf
    ├── drl_v5.pdf
    ├── gan_v10.pdf
    ├── introduction-2021-v6-Chinese.pdf
    ├── introduction-2021-v6-English.pdf
    ├── life_v2.pdf
    ├── meta_v3.pdf
    ├── normalization_v4.pdf
    ├── optimizer_v4.pdf
    ├── overfit-v6.pdf
    ├── regression (v16).pdf
    ├── self_v7.pdf
    ├── seq2seq_v9.pdf
    ├── small-gradient-v7.pdf
    ├── tiny_v7.pdf
    └── xai_v4.pdf
/HW0/Colab_Tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Untitled1.ipynb", 7 | "provenance": [], 8 | "toc_visible": true, 9 | "authorship_tag": "ABX9TyP7fmSj06arI6zfMZAGI0Pn", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "[Open In Colab badge]" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "d8U6fiicXLUM" 35 | }, 36 | "source": [ 37 | "# **Google Colab Tutorial**\n", 38 | "\n", 39 | "\n", 40 | "Should you have any questions, contact the TAs at ntu-ml-2021spring-ta@googlegroups.com\n", 41 | "\n", 42 | "[Colaboratory logo]\n", 43 | "\n", 44 | "**What is Colaboratory?**\n", 45 | "\n", 46 | "Colaboratory, or \"Colab\" for short, allows you to write and execute Python in your browser, with:\n", 47 | "- Zero configuration required\n", 48 | "- Free access to GPUs\n", 49 | "- Easy sharing\n", 50 | "\n", 51 | "Whether you're a **student**, a **data scientist** or an **AI researcher**, Colab can make your work easier. Watch [Introduction to Colab](https://www.youtube.com/watch?v=inN8seMm7UI) to learn more, or just get started below!\n", 52 | "\n", 53 | "You can type Python code in a code cell, or prefix a line with an exclamation mark (!) to run it as a bash command in the underlying Linux environment.\n", 54 | "\n", 55 | "To use the free GPU provided by Google, click \"Runtime\" (執行階段) -> \"Change Runtime Type\" (變更執行階段類型). Under \"Hardware Accelerator\" (硬體加速器) there are three options; select \"GPU\".\n", 56 | "* Doing this will restart the session, so make sure you change to the desired runtime before executing any code.\n" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "metadata": { 62 | "colab": { 63 | "base_uri": "https://localhost:8080/" 64 | }, 65 | "id": "fsy07w40XKOz", 66 | "outputId": "849f894a-a4b3-40d0-ea63-d3fcac55bb8b" 67 | }, 68 | "source": [ 69 | "import torch\n", 70 | "torch.cuda.is_available() # check whether a GPU is available\n", 71 | "# Outputs True if running with GPU" 72 | ], 73 | "execution_count": 1, 74 | "outputs": [ 75 | { 76 | "output_type": "execute_result", 77 | "data": { 78 | "text/plain": [ 79 | "False" 80 | ] 81 | }, 82 | "metadata": { 83 | "tags": [] 84 | }, 85 | "execution_count": 1 86 | } 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": { 92 | "id": "h-KYsnyIXUMq" 93 | }, 94 | "source": [ 95 | "**1. Download files via Google Drive**\n", 96 | "\n", 97 | "  A file stored in Google Drive has the following sharing link:\n", 98 | "\n", 99 | "  https://drive.google.com/open?id=1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV\n", 100 | "  \n", 101 | "  The random string after \"open?id=\" is the **file_id**\n", 102 | "![](https://i.imgur.com/33SW1WZ.png)\n", 103 | "\n", 104 | "  Knowing the **file_id**, you can download the file in Colab with the following command.\n" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "metadata": { 110 | "colab": { 111 | "base_uri": "https://localhost:8080/" 112 | }, 113 | "id": "unWRjSw4XZYz", 114 | "outputId": "f66257e2-cd9e-47b0-8767-4963809d8056" 115 | }, 116 | "source": [ 117 | "# Download the file with file_id \"1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV\" and save it as Minori.jpg\n", 118 | "!gdown --id '1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV' --output Minori.jpg" 119 | ], 120 | "execution_count": 2, 121 | "outputs": [ 122 | { 123 | "output_type": "stream", 124 | "text": [ 125 | "Downloading...\n", 126 | "From: https://drive.google.com/uc?id=1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV\n", 127 | "To: /content/Minori.jpg\n", 128 | "\r 0% 0.00/219k [00:00\n", 187 | "  ![image.png](data:image/png;base64, [embedded PNG screenshot omitted])\n", 188 | "\n", 189 | "There should be a file named `Minori.jpg`. If you do not see it, click the middle icon (the refresh button).\n", 190 | "  ![](https://i.imgur.com/CNBTH23.png)\n", 191 | "\n", 192 | "Double-click the file to view the image.\n", 193 | "\n", 194 | "\n", 195 | "   \n", 196 | "![](https://i.imgur.com/h2PLMrq.png)\n", 197 | "\n" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": { 203 | "id": "zqAFLkyfXuTw" 204 | }, 205 | "source": [ 206 | "**2. Mounting Google Drive**\n", 207 | "\n", 208 | "  One advantage of Google Colab is that it connects easily to other Google services such as Google Drive. By mounting Google Drive, your working files can be stored permanently. To mount it, run `from google.colab import drive` followed by `drive.mount('/content/drive')` in a code cell, then log in to your Google account and paste the authentication code into the input box to finish the process." 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": { 214 | "id": "Uzuw3wBRX3qw" 215 | }, 216 | "source": [ 217 | "After mounting, the contents of your Google Drive will appear under a directory named `MyDrive`; check the file browser for this folder to confirm that the code ran successfully.\n", 218 | "\n", 219 | "There is also an icon for mounting Google Drive, which automatically generates this mounting code.\n", 220 | "\n", 221 | "![](https://i.imgur.com/hM9Jgi7.png) \n", 222 | "\n", 223 | "After mounting the drive, all changes will be synced with Google Drive.\n", 224 | "Since models can be quite large, make sure that your Google Drive has enough space. You can apply for a G Suite drive with unlimited space using your student ID (available until 2022/07).\n", 225 | "https://www.cc.ntu.edu.tw/chinese/services/serv_i06.asp\n", 226 | "http://www.cc.ntu.edu.tw/english/spotlight/2016/a105038.asp" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "metadata": { 232 | "id": "476v8iJJX8sl" 233 | }, 234 | "source": [ 235 | "%cd /content/drive/MyDrive \n", 236 | "# change directory to Google Drive\n", 237 | "!mkdir -p ML2021 # make a directory named ML2021 (-p: no error if it already exists)\n", 238 | "%cd ./ML2021 \n", 239 | "# change directory to ML2021" 240 | ], 241 | "execution_count": null, 242 | "outputs": [] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": { 247 | "id": "ZyXibDQIX_3N" 248 | }, 249 | "source": [ 250 | "Use the bash command `pwd` to print the current directory." 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "metadata": { 256 | "id": "ONO1WIIHYBhg" 257 | }, 258 | "source": [ 259 | "!pwd # print the current directory" 260 | ], 261 | "execution_count": null, 262 | "outputs": [] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": { 267 | "id": "pY3yrcWmYB_C" 268 | }, 269 | "source": [ 270 | "Repeat the download. This time, the file will be stored permanently in your Google Drive." 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "metadata": { 276 | "id": "1GWvRU3sYDgk" 277 | }, 278 | "source": [ 279 | "# Download the file with file_id \"1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV\" and save it as Minori.jpg\n", 280 | "!gdown --id '1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV' --output Minori.jpg" 281 | ], 282 | "execution_count": null, 283 | "outputs": [] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": { 288 | "id": "NvTT17f-YX8_" 289 | }, 290 | "source": [ 291 | "The TAs will provide the homework data using code similar to the code above. The data can also be stored in your Google Drive and loaded from there."
 292 | ] 293 | } 294 | ] 295 | } -------------------------------------------------------------------------------- /HW0/Google_Colab_Tutorial.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW0/Google_Colab_Tutorial.pdf -------------------------------------------------------------------------------- /HW0/Pytorch_Tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Untitled2.ipynb", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyMNIStRQ4RvQtfoo9vQ5pay", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "[Open In Colab badge]" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "eItOSGyKakne" 35 | }, 36 | "source": [ 37 | "# **PyTorch Tutorial**\n" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "0AatWga6aTh_" 44 | }, 45 | "source": [ 46 | "import torch" 47 | ], 48 | "execution_count": 1, 49 | "outputs": [] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": { 54 | "id": "hsZeHjlzan6W" 55 | }, 56 | "source": [ 57 | "**1. Understanding the PyTorch documentation with torch.max**\n", 58 | "\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "metadata": { 64 | "colab": { 65 | "base_uri": "https://localhost:8080/" 66 | }, 67 | "id": "kWSXhKOLapaG", 68 | "outputId": "c06ac6e4-7f5a-4841-868c-35cddfdf671c" 69 | }, 70 | "source": [ 71 | "x = torch.randn(4,5)\n", 72 | "y = torch.randn(4,5)\n", 73 | "z = torch.randn(4,5)\n", 74 | "print(x)\n", 75 | "print(y)\n", 76 | "print(z)" 77 | ], 78 | "execution_count": 2, 79 | "outputs": [ 80 | { 81 | "output_type": "stream", 82 | "text": [ 83 | "tensor([[-0.2319, 0.7826, 0.5202, -0.1384, 0.6040],\n", 84 | " [-0.5828, -0.2883, 0.3508, 1.5920, 0.1148],\n", 85 | " [ 0.5468, -1.3432, -0.3204, 0.9779, -0.0244],\n", 86 | " [ 0.0494, -0.4104, 2.3866, 0.2881, 1.0155]])\n", 87 | "tensor([[-1.2664, 1.0201, 0.4730, -1.8292, 0.8858],\n", 88 | " [-2.5277, 0.1873, -0.1353, -0.6365, -0.7611],\n", 89 | " [ 0.9117, 0.3282, -0.5094, -1.0268, -1.8023],\n", 90 | " [-0.1177, -0.1617, -1.0410, 0.9007, 0.8385]])\n", 91 | "tensor([[-1.5717, -1.5700, -3.2840, -0.1819, 1.5947],\n", 92 | " [ 1.4966, -0.1422, 1.4710, -2.4297, 1.3951],\n", 93 | " [-0.5066, 0.5822, 0.9158, -0.2434, 0.3539],\n", 94 | " [-0.0303, 0.1333, -0.6207, 1.2630, 0.3462]])\n" 95 | ], 96 | "name": "stdout" 97 | } 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "metadata": { 103 | "colab": { 104 | "base_uri": "https://localhost:8080/" 105 | }, 106 | "id": "sQW_dRseasEt", 107 | "outputId": "2ab028f7-e7ed-4d58-ca3f-7bb3a16d16f7" 108 | }, 109 | "source": [ 110 | "# 1. max of the entire tensor (torch.max(input) → Tensor)\n", 111 | "m = torch.max(x)\n", 112 | "print(m)" 113 | ], 114 | "execution_count": 3, 115 | "outputs": [ 116 | { 117 | "output_type": "stream", 118 | "text": [ 119 | "tensor(2.3866)\n" 120 | ], 121 | "name": "stdout" 122 | } 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "metadata": { 128 | "colab": { 129 | "base_uri": "https://localhost:8080/" 130 | }, 131 | "id": "1PqGQti8asmb", 132 | "outputId": "a9072e53-274f-4160-aa2c-6279bc38b130" 133 | }, 134 | "source": [ 135 | "# 2. max along a dimension (torch.max(input, dim, keepdim=False, *, out=None) → (Tensor, LongTensor))\n", 136 | "m, idx = torch.max(x,0)\n", 137 | "print(m)\n", 138 | "print(idx)" 139 | ], 140 | "execution_count": 4, 141 | "outputs": [ 142 | { 143 | "output_type": "stream", 144 | "text": [ 145 | "tensor([0.5468, 0.7826, 2.3866, 1.5920, 1.0155])\n", 146 | "tensor([2, 0, 3, 1, 3])\n" 147 | ], 148 | "name": "stdout" 149 | } 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "metadata": { 155 | "colab": { 156 | "base_uri": "https://localhost:8080/" 157 | }, 158 | "id": "jOlacFQNat5h", 159 | "outputId": "1f8b3a88-5589-486d-aaaf-a0e7bc825dab" 160 | }, 161 | "source": [ 162 | "# 2-2: same call, with input and dim passed as keyword arguments\n", 163 | "m, idx = torch.max(input=x,dim=0)\n", 164 | "print(m)\n", 165 | "print(idx)" 166 | ], 167 | "execution_count": 5, 168 | "outputs": [ 169 | { 170 | "output_type": "stream", 171 | "text": [ 172 | "tensor([0.5468, 0.7826, 2.3866, 1.5920, 1.0155])\n", 173 | "tensor([2, 0, 3, 1, 3])\n" 174 | ], 175 | "name": "stdout" 176 | } 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "metadata": { 182 | "colab": { 183 | "base_uri": "https://localhost:8080/" 184 | }, 185 | "id": "sLysSuo1avgS", 186 | "outputId": "d9e1163b-abcf-4701-91b2-ee879939638e" 187 | }, 188 | "source": [ 189 | "# 2-3: keepdim passed positionally as the third argument\n", 190 | "m, idx = torch.max(x,0,False)\n", 191 | "print(m)\n", 192 | "print(idx)" 193 | ], 194 | "execution_count": 6, 195 | "outputs": [ 196 | { 197 | "output_type": "stream", 198 | "text": [ 199 | "tensor([0.5468, 0.7826, 2.3866, 1.5920, 1.0155])\n", 200 | "tensor([2, 0, 3, 1, 3])\n" 201 | ], 202 | "name": "stdout" 203 | } 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "metadata": { 209 | "colab": { 210 | "base_uri": "https://localhost:8080/" 211 | }, 212 | "id": "CwA6osCvawxl", 213 | "outputId": "cf4e1219-e22a-4b02-f037-41a733ad21db" 214 | }, 215 | "source": [ 216 | "# 2-4: keepdim=True keeps the reduced dimension (shape 1x5 instead of 5)\n", 217 | "m, idx = torch.max(x,dim=0,keepdim=True)\n", 218 | "print(m)\n", 219 | "print(idx)" 220 | ], 221 | "execution_count": 7, 222 | "outputs": [ 223 | { 224 | "output_type": "stream", 225 | "text": [ 226 | "tensor([[0.5468, 0.7826, 2.3866, 1.5920, 1.0155]])\n", 227 | "tensor([[2, 0, 3, 1, 3]])\n" 228 | ], 229 | "name": "stdout" 230 | } 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "metadata": { 236 | "colab": { 237 | "base_uri": "https://localhost:8080/" 238 | }, 239 | "id": "IgYZZTO0ayKb", 240 | "outputId": "6d66eef2-a95d-420f-afad-60f24d2f0a89" 241 | }, 242 | "source": [ 243 | "# 2-5: write the results into existing tensors through the keyword-only out argument\n", 244 | "p = (m,idx)\n", 245 | "torch.max(x,0,False,out=p)\n", 246 | "print(p[0])\n", 247 | "print(p[1])\n" 248 | ], 249 | "execution_count": 8, 250 | "outputs": [ 251 | { 252 | "output_type": "stream", 253 | "text": [ 254 | "tensor([0.5468, 0.7826, 2.3866, 1.5920, 1.0155])\n", 255 | "tensor([2, 0, 3, 1, 3])\n" 256 | ], 257 | "name": "stdout" 258 | } 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "metadata": { 264 | "colab": { 265 | "base_uri": "https://localhost:8080/", 266 | "height": 318 267 | }, 268 | "id": "RuhKYn_8azTX", 269 | "outputId": "fd302638-1983-43cd-972e-4c38641c59a7" 270 | }, 271 | "source": [ 272 | "# 2-6\n", 273 | "p = (m,idx)\n", 274 | "torch.max(x,0,False,p)\n", 275 | "print(p[0])\n", 276 | "print(p[1])" 277 | ], 278 | "execution_count": 9, 279 | "outputs": [ 280 | { 281 | "output_type": "error", 282 | "ename": "TypeError", 283 | "evalue": "ignored", 284 | "traceback": [ 285 | 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 286 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 287 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# 2-6\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 288 | "\u001b[0;31mTypeError\u001b[0m: max() received an invalid combination of arguments - got (Tensor, int, bool, tuple), but expected one of:\n * (Tensor input)\n * (Tensor input, name dim, bool keepdim, *, tuple of Tensors out)\n * (Tensor input, Tensor other, *, Tensor out)\n didn't match because some of the arguments have invalid types: (Tensor, !int!, !bool!, !tuple!)\n * (Tensor input, int dim, bool keepdim, *, tuple of Tensors out)\n" 289 | ] 290 | } 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "metadata": { 296 | "colab": { 297 | "base_uri": 
"https://localhost:8080/", 298 | "height": 250 299 | }, 300 | "id": "sT5TyG8Ca03o", 301 | "outputId": "72eaccca-de2c-458f-9250-07acf324b88b" 302 | }, 303 | "source": [ 304 | "# 2-7\n", 305 | "m, idx = torch.max(x,True)" 306 | ], 307 | "execution_count": 10, 308 | "outputs": [ 309 | { 310 | "output_type": "error", 311 | "ename": "TypeError", 312 | "evalue": "ignored", 313 | "traceback": [ 314 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 315 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 316 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# 2-7\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 317 | "\u001b[0;31mTypeError\u001b[0m: max() received an invalid combination of arguments - got (Tensor, bool), but expected one of:\n * (Tensor input)\n * (Tensor input, name dim, bool keepdim, *, tuple of Tensors out)\n * (Tensor input, Tensor other, *, Tensor out)\n * (Tensor input, int dim, bool keepdim, *, tuple of Tensors out)\n" 318 | ] 319 | } 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "metadata": { 325 | "colab": { 326 | "base_uri": "https://localhost:8080/" 327 | }, 328 | "id": "gIt5t7Vma2HD", 329 | "outputId": "b0cbfd7c-897b-43a7-bd92-716ff57d495a" 330 | }, 331 | "source": [ 332 | "# 3. 
element-wise max of two tensors (torch.max(input, other, *, out=None) → Tensor)\n", 333 | "t = torch.max(x,y)\n", 334 | "print(t)" 335 | ], 336 | "execution_count": 11, 337 | "outputs": [ 338 | { 339 | "output_type": "stream", 340 | "text": [ 341 | "tensor([[-0.2319, 1.0201, 0.5202, -0.1384, 0.8858],\n", 342 | " [-0.5828, 0.1873, 0.3508, 1.5920, 0.1148],\n", 343 | " [ 0.9117, 0.3282, -0.3204, 0.9779, -0.0244],\n", 344 | " [ 0.0494, -0.1617, 2.3866, 0.9007, 1.0155]])\n" 345 | ], 346 | "name": "stdout" 347 | } 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": { 353 | "id": "Sz_mMLKYa3Wy" 354 | }, 355 | "source": [ 356 | "**2. Common errors**\n", 357 | "\n", 358 | "The following code blocks show some common errors when using the torch library. First run the code block that raises the error, then run the next code block to see the fix. You need to switch the runtime to GPU before running them.\n" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "metadata": { 364 | "id": "B5Ujt5kPa7JG" 365 | }, 366 | "source": [ 367 | "import torch" 368 | ], 369 | "execution_count": 4, 370 | "outputs": [] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "metadata": { 375 | "colab": { 376 | "base_uri": "https://localhost:8080/", 377 | "height": 342 378 | }, 379 | "id": "r_MqFNWPa8qm", 380 | "outputId": "7edcb3fc-fc95-4f1d-e0b8-a83ab42df8af" 381 | }, 382 | "source": [ 383 | "# 1. 
different device error\n", 384 | "model = torch.nn.Linear(5,1).to(\"cuda:0\")\n", 385 | "x = torch.Tensor([1,2,3,4,5]).to(\"cpu\")\n", 386 | "y = model(x)" 387 | ], 388 | "execution_count": 5, 389 | "outputs": [ 390 | { 391 | "output_type": "error", 392 | "ename": "RuntimeError", 393 | "evalue": "ignored", 394 | "traceback": [ 395 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 396 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 397 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cuda:0\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 398 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, 
*input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 399 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 400 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 1751\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhas_torch_function_variadic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1752\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1753\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 
1754\u001b[0m \u001b[0;34m\u001b[0m\n\u001b[1;32m 1755\u001b[0m \u001b[0;34m\u001b[0m\n", 401 | "\u001b[0;31mRuntimeError\u001b[0m: Tensor for 'out' is on CPU, Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" 402 | ] 403 | } 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "metadata": { 409 | "id": "-7TyFUwRa9vm" 410 | }, 411 | "source": [ 412 | "# 1. different device error (fixed)\n", 413 | "x = torch.Tensor([1,2,3,4,5]).to(\"cuda:0\")\n", 414 | "y = model(x)\n", 415 | "print(y.shape)" 416 | ], 417 | "execution_count": null, 418 | "outputs": [] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "metadata": { 423 | "id": "D_1ZgoqKa_Cx" 424 | }, 425 | "source": [ 426 | "# 2. mismatched dimensions error\n", 427 | "x = torch.randn(4,5)\n", 428 | "y = torch.randn(5,4)\n", 429 | "z = x + y" 430 | ], 431 | "execution_count": null, 432 | "outputs": [] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "metadata": { 437 | "id": "g2Qt06MrbAHH" 438 | }, 439 | "source": [ 440 | "# 2. mismatched dimensions error (fixed)\n", 441 | "y = y.transpose(0,1)\n", 442 | "z = x + y\n", 443 | "print(z.shape)" 444 | ], 445 | "execution_count": null, 446 | "outputs": [] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "metadata": { 451 | "id": "upiWzJtUbBYL" 452 | }, 453 | "source": [ 454 | "# 3. cuda out of memory error\n", 455 | "import torch\n", 456 | "import torchvision.models as models\n", 457 | "resnet18 = models.resnet18().to(\"cuda:0\") # ResNet-18, a neural network for image recognition\n", 458 | "data = torch.randn(2048,3,244,244) # create fake data (2048 images)\n", 459 | "out = resnet18(data.to(\"cuda:0\")) # feed the whole batch to the model at once\n", 460 | "print(out.shape)\n" 461 | ], 462 | "execution_count": null, 463 | "outputs": [] 464 | }, 465 | { 466 | "cell_type": "code", 467 | "metadata": { 468 | "id": "XcpqggzlbCcv" 469 | }, 470 | "source": [ 471 | "# 3. 
cuda out of memory error (fixed)\n", 472 | "for d in data:\n", 473 | " out = resnet18(d.to(\"cuda:0\").unsqueeze(0))\n", 474 | "print(out.shape)" 475 | ], 476 | "execution_count": null, 477 | "outputs": [] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "metadata": { 482 | "id": "GRHKl7m9bDvb" 483 | }, 484 | "source": [ 485 | "# 4. mismatched tensor type\n", 486 | "import torch.nn as nn\n", 487 | "L = nn.CrossEntropyLoss()\n", 488 | "outs = torch.randn(5,5)\n", 489 | "labels = torch.Tensor([1,2,3,4,0])\n", 490 | "lossval = L(outs,labels) # Calculate CrossEntropyLoss between outs and labels" 491 | ], 492 | "execution_count": null, 493 | "outputs": [] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "metadata": { 498 | "id": "xpcNI81xbFId" 499 | }, 500 | "source": [ 501 | "# 4. mismatched tensor type (fixed)\n", 502 | "labels = labels.long()\n", 503 | "lossval = L(outs,labels)\n", 504 | "print(lossval)" 505 | ], 506 | "execution_count": null, 507 | "outputs": [] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "metadata": { 512 | "id": "Ppqk6FOwbHpH" 513 | }, 514 | "source": [ 515 | "**3. 
More on dataset and dataloader**\n", 516 | "\n", 517 | "Let the dataset be the English alphabet \"abcdefghijklmnopqrstuvwxyz\"" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "metadata": { 523 | "id": "k7B48CmrbGXQ" 524 | }, 525 | "source": [ 526 | "dataset = \"abcdefghijklmnopqrstuvwxyz\"" 527 | ], 528 | "execution_count": null, 529 | "outputs": [] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": { 534 | "id": "p-uSDXmvbPJ9" 535 | }, 536 | "source": [ 537 | "A simple dataloader can be implemented with a Python `for` loop:" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "metadata": { 543 | "id": "rFBzgIgKbQVB" 544 | }, 545 | "source": [ 546 | "for datapoint in dataset:\n", 547 | " print(datapoint)" 548 | ], 549 | "execution_count": null, 550 | "outputs": [] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "metadata": { 555 | "id": "8egynVzdbRs1" 556 | }, 557 | "source": [ 558 | "When loading data, we often want to shuffle it. This is where torch.utils.data.DataLoader comes in handy. If, from the point of view of torch.utils.data.DataLoader, each data point is just an index (0, 1, 2, ...), shuffling can be done simply by shuffling an array of indices.\n", 559 | "\n", 560 | "torch.utils.data.DataLoader needs two pieces of information to fulfill its role. First, it needs to know the length of the dataset. 
Second, once torch.utils.data.DataLoader outputs the shuffled indices, the dataset needs to return the data corresponding to each index.\n", 561 | "\n", 562 | "Therefore, torch.utils.data.Dataset provides this information through two functions, `__len__()` and `__getitem__()`, which torch.utils.data.DataLoader relies on." 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "metadata": { 568 | "id": "tnyuFR_IbUVU" 569 | }, 570 | "source": [ 571 | "import torch\n", 572 | "import torch.utils.data\n", 573 | "class ExampleDataset(torch.utils.data.Dataset):\n", 574 | " def __init__(self):\n", 575 | " self.data = \"abcdefghijklmnopqrstuvwxyz\"\n", 576 | " \n", 577 | " def __getitem__(self,idx): # given index idx, return the corresponding data point\n", 578 | " return self.data[idx]\n", 579 | " \n", 580 | " def __len__(self): # return the number of data points in the dataset\n", 581 | " return len(self.data)\n", 582 | "\n", 583 | "dataset1 = ExampleDataset() # create the dataset\n", 584 | "dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=True, batch_size=1)\n", 585 | "for datapoint in dataloader:\n", 586 | " print(datapoint)" 587 | ], 588 | "execution_count": null, 589 | "outputs": [] 590 | }, 591 | { 592 | "cell_type": "markdown", 593 | "metadata": { 594 | "id": "2fWMkpltbZOJ" 595 | }, 596 | "source": [ 597 | "A simple form of data augmentation can be implemented by changing `__len__()` and `__getitem__()`. Suppose we want to double the length of the dataset by adding the uppercase letters, generated from the lowercase data alone. The dataset can then be changed as follows." 
598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "metadata": { 603 | "id": "82KhTMF5bayN" 604 | }, 605 | "source": [ 606 | "import torch.utils.data\n", 607 | "class ExampleDataset(torch.utils.data.Dataset):\n", 608 | " def __init__(self):\n", 609 | " self.data = \"abcdefghijklmnopqrstuvwxyz\"\n", 610 | " \n", 611 | " def __getitem__(self,idx): # given index idx, return the corresponding data point\n", 612 | " if idx >= len(self.data): # indices 26-51 map to uppercase letters\n", 613 | " return self.data[idx % len(self.data)].upper()\n", 614 | " else: # indices 0-25 map to lowercase letters\n", 615 | " return self.data[idx]\n", 616 | " \n", 617 | " def __len__(self): # return the number of data points in the dataset\n", 618 | " return 2 * len(self.data) # the length is now twice as large\n", 619 | "\n", 620 | "dataset1 = ExampleDataset() # create the dataset\n", 621 | "dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=True, batch_size=1)\n", 622 | "for datapoint in dataloader:\n", 623 | " print(datapoint)" 624 | ], 625 | "execution_count": null, 626 | "outputs": [] 627 | } 628 | ] 629 | } -------------------------------------------------------------------------------- /HW0/Pytorch_Tutorial_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW0/Pytorch_Tutorial_1.pdf -------------------------------------------------------------------------------- /HW0/Pytorch_Tutorial_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW0/Pytorch_Tutorial_2.pdf -------------------------------------------------------------------------------- /HW1/HW01.pdf: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW1/HW01.pdf -------------------------------------------------------------------------------- /HW1/data/sampleSubmission.csv: -------------------------------------------------------------------------------- 1 | id,tested_positive 2 | 0,0.0 3 | 1,0.0 4 | 2,0.0 5 | 3,0.0 6 | 4,0.0 7 | 5,0.0 8 | 6,0.0 9 | 7,0.0 10 | 8,0.0 11 | 9,0.0 12 | 10,0.0 13 | 11,0.0 14 | 12,0.0 15 | 13,0.0 16 | 14,0.0 17 | 15,0.0 18 | 16,0.0 19 | 17,0.0 20 | 18,0.0 21 | 19,0.0 22 | 20,0.0 23 | 21,0.0 24 | 22,0.0 25 | 23,0.0 26 | 24,0.0 27 | 25,0.0 28 | 26,0.0 29 | 27,0.0 30 | 28,0.0 31 | 29,0.0 32 | 30,0.0 33 | 31,0.0 34 | 32,0.0 35 | 33,0.0 36 | 34,0.0 37 | 35,0.0 38 | 36,0.0 39 | 37,0.0 40 | 38,0.0 41 | 39,0.0 42 | 40,0.0 43 | 41,0.0 44 | 42,0.0 45 | 43,0.0 46 | 44,0.0 47 | 45,0.0 48 | 46,0.0 49 | 47,0.0 50 | 48,0.0 51 | 49,0.0 52 | 50,0.0 53 | 51,0.0 54 | 52,0.0 55 | 53,0.0 56 | 54,0.0 57 | 55,0.0 58 | 56,0.0 59 | 57,0.0 60 | 58,0.0 61 | 59,0.0 62 | 60,0.0 63 | 61,0.0 64 | 62,0.0 65 | 63,0.0 66 | 64,0.0 67 | 65,0.0 68 | 66,0.0 69 | 67,0.0 70 | 68,0.0 71 | 69,0.0 72 | 70,0.0 73 | 71,0.0 74 | 72,0.0 75 | 73,0.0 76 | 74,0.0 77 | 75,0.0 78 | 76,0.0 79 | 77,0.0 80 | 78,0.0 81 | 79,0.0 82 | 80,0.0 83 | 81,0.0 84 | 82,0.0 85 | 83,0.0 86 | 84,0.0 87 | 85,0.0 88 | 86,0.0 89 | 87,0.0 90 | 88,0.0 91 | 89,0.0 92 | 90,0.0 93 | 91,0.0 94 | 92,0.0 95 | 93,0.0 96 | 94,0.0 97 | 95,0.0 98 | 96,0.0 99 | 97,0.0 100 | 98,0.0 101 | 99,0.0 102 | 100,0.0 103 | 101,0.0 104 | 102,0.0 105 | 103,0.0 106 | 104,0.0 107 | 105,0.0 108 | 106,0.0 109 | 107,0.0 110 | 108,0.0 111 | 109,0.0 112 | 110,0.0 113 | 111,0.0 114 | 112,0.0 115 | 113,0.0 116 | 114,0.0 117 | 115,0.0 118 | 116,0.0 119 | 117,0.0 120 | 118,0.0 121 | 119,0.0 122 | 120,0.0 123 | 121,0.0 124 | 122,0.0 125 | 123,0.0 126 | 124,0.0 127 | 125,0.0 128 | 126,0.0 129 | 127,0.0 130 | 128,0.0 131 | 129,0.0 132 | 130,0.0 133 | 131,0.0 134 | 132,0.0 135 | 133,0.0 136 | 
134,0.0 137 | 135,0.0 138 | 136,0.0 139 | 137,0.0 140 | 138,0.0 141 | 139,0.0 142 | 140,0.0 143 | 141,0.0 144 | 142,0.0 145 | 143,0.0 146 | 144,0.0 147 | 145,0.0 148 | 146,0.0 149 | 147,0.0 150 | 148,0.0 151 | 149,0.0 152 | 150,0.0 153 | 151,0.0 154 | 152,0.0 155 | 153,0.0 156 | 154,0.0 157 | 155,0.0 158 | 156,0.0 159 | 157,0.0 160 | 158,0.0 161 | 159,0.0 162 | 160,0.0 163 | 161,0.0 164 | 162,0.0 165 | 163,0.0 166 | 164,0.0 167 | 165,0.0 168 | 166,0.0 169 | 167,0.0 170 | 168,0.0 171 | 169,0.0 172 | 170,0.0 173 | 171,0.0 174 | 172,0.0 175 | 173,0.0 176 | 174,0.0 177 | 175,0.0 178 | 176,0.0 179 | 177,0.0 180 | 178,0.0 181 | 179,0.0 182 | 180,0.0 183 | 181,0.0 184 | 182,0.0 185 | 183,0.0 186 | 184,0.0 187 | 185,0.0 188 | 186,0.0 189 | 187,0.0 190 | 188,0.0 191 | 189,0.0 192 | 190,0.0 193 | 191,0.0 194 | 192,0.0 195 | 193,0.0 196 | 194,0.0 197 | 195,0.0 198 | 196,0.0 199 | 197,0.0 200 | 198,0.0 201 | 199,0.0 202 | 200,0.0 203 | 201,0.0 204 | 202,0.0 205 | 203,0.0 206 | 204,0.0 207 | 205,0.0 208 | 206,0.0 209 | 207,0.0 210 | 208,0.0 211 | 209,0.0 212 | 210,0.0 213 | 211,0.0 214 | 212,0.0 215 | 213,0.0 216 | 214,0.0 217 | 215,0.0 218 | 216,0.0 219 | 217,0.0 220 | 218,0.0 221 | 219,0.0 222 | 220,0.0 223 | 221,0.0 224 | 222,0.0 225 | 223,0.0 226 | 224,0.0 227 | 225,0.0 228 | 226,0.0 229 | 227,0.0 230 | 228,0.0 231 | 229,0.0 232 | 230,0.0 233 | 231,0.0 234 | 232,0.0 235 | 233,0.0 236 | 234,0.0 237 | 235,0.0 238 | 236,0.0 239 | 237,0.0 240 | 238,0.0 241 | 239,0.0 242 | 240,0.0 243 | 241,0.0 244 | 242,0.0 245 | 243,0.0 246 | 244,0.0 247 | 245,0.0 248 | 246,0.0 249 | 247,0.0 250 | 248,0.0 251 | 249,0.0 252 | 250,0.0 253 | 251,0.0 254 | 252,0.0 255 | 253,0.0 256 | 254,0.0 257 | 255,0.0 258 | 256,0.0 259 | 257,0.0 260 | 258,0.0 261 | 259,0.0 262 | 260,0.0 263 | 261,0.0 264 | 262,0.0 265 | 263,0.0 266 | 264,0.0 267 | 265,0.0 268 | 266,0.0 269 | 267,0.0 270 | 268,0.0 271 | 269,0.0 272 | 270,0.0 273 | 271,0.0 274 | 272,0.0 275 | 273,0.0 276 | 274,0.0 277 | 275,0.0 278 | 276,0.0 279 
| 277,0.0 280 | 278,0.0 281 | 279,0.0 282 | 280,0.0 283 | 281,0.0 284 | 282,0.0 285 | 283,0.0 286 | 284,0.0 287 | 285,0.0 288 | 286,0.0 289 | 287,0.0 290 | 288,0.0 291 | 289,0.0 292 | 290,0.0 293 | 291,0.0 294 | 292,0.0 295 | 293,0.0 296 | 294,0.0 297 | 295,0.0 298 | 296,0.0 299 | 297,0.0 300 | 298,0.0 301 | 299,0.0 302 | 300,0.0 303 | 301,0.0 304 | 302,0.0 305 | 303,0.0 306 | 304,0.0 307 | 305,0.0 308 | 306,0.0 309 | 307,0.0 310 | 308,0.0 311 | 309,0.0 312 | 310,0.0 313 | 311,0.0 314 | 312,0.0 315 | 313,0.0 316 | 314,0.0 317 | 315,0.0 318 | 316,0.0 319 | 317,0.0 320 | 318,0.0 321 | 319,0.0 322 | 320,0.0 323 | 321,0.0 324 | 322,0.0 325 | 323,0.0 326 | 324,0.0 327 | 325,0.0 328 | 326,0.0 329 | 327,0.0 330 | 328,0.0 331 | 329,0.0 332 | 330,0.0 333 | 331,0.0 334 | 332,0.0 335 | 333,0.0 336 | 334,0.0 337 | 335,0.0 338 | 336,0.0 339 | 337,0.0 340 | 338,0.0 341 | 339,0.0 342 | 340,0.0 343 | 341,0.0 344 | 342,0.0 345 | 343,0.0 346 | 344,0.0 347 | 345,0.0 348 | 346,0.0 349 | 347,0.0 350 | 348,0.0 351 | 349,0.0 352 | 350,0.0 353 | 351,0.0 354 | 352,0.0 355 | 353,0.0 356 | 354,0.0 357 | 355,0.0 358 | 356,0.0 359 | 357,0.0 360 | 358,0.0 361 | 359,0.0 362 | 360,0.0 363 | 361,0.0 364 | 362,0.0 365 | 363,0.0 366 | 364,0.0 367 | 365,0.0 368 | 366,0.0 369 | 367,0.0 370 | 368,0.0 371 | 369,0.0 372 | 370,0.0 373 | 371,0.0 374 | 372,0.0 375 | 373,0.0 376 | 374,0.0 377 | 375,0.0 378 | 376,0.0 379 | 377,0.0 380 | 378,0.0 381 | 379,0.0 382 | 380,0.0 383 | 381,0.0 384 | 382,0.0 385 | 383,0.0 386 | 384,0.0 387 | 385,0.0 388 | 386,0.0 389 | 387,0.0 390 | 388,0.0 391 | 389,0.0 392 | 390,0.0 393 | 391,0.0 394 | 392,0.0 395 | 393,0.0 396 | 394,0.0 397 | 395,0.0 398 | 396,0.0 399 | 397,0.0 400 | 398,0.0 401 | 399,0.0 402 | 400,0.0 403 | 401,0.0 404 | 402,0.0 405 | 403,0.0 406 | 404,0.0 407 | 405,0.0 408 | 406,0.0 409 | 407,0.0 410 | 408,0.0 411 | 409,0.0 412 | 410,0.0 413 | 411,0.0 414 | 412,0.0 415 | 413,0.0 416 | 414,0.0 417 | 415,0.0 418 | 416,0.0 419 | 417,0.0 420 | 418,0.0 421 | 419,0.0 
422 | 420,0.0 423 | 421,0.0 424 | 422,0.0 425 | 423,0.0 426 | 424,0.0 427 | 425,0.0 428 | 426,0.0 429 | 427,0.0 430 | 428,0.0 431 | 429,0.0 432 | 430,0.0 433 | 431,0.0 434 | 432,0.0 435 | 433,0.0 436 | 434,0.0 437 | 435,0.0 438 | 436,0.0 439 | 437,0.0 440 | 438,0.0 441 | 439,0.0 442 | 440,0.0 443 | 441,0.0 444 | 442,0.0 445 | 443,0.0 446 | 444,0.0 447 | 445,0.0 448 | 446,0.0 449 | 447,0.0 450 | 448,0.0 451 | 449,0.0 452 | 450,0.0 453 | 451,0.0 454 | 452,0.0 455 | 453,0.0 456 | 454,0.0 457 | 455,0.0 458 | 456,0.0 459 | 457,0.0 460 | 458,0.0 461 | 459,0.0 462 | 460,0.0 463 | 461,0.0 464 | 462,0.0 465 | 463,0.0 466 | 464,0.0 467 | 465,0.0 468 | 466,0.0 469 | 467,0.0 470 | 468,0.0 471 | 469,0.0 472 | 470,0.0 473 | 471,0.0 474 | 472,0.0 475 | 473,0.0 476 | 474,0.0 477 | 475,0.0 478 | 476,0.0 479 | 477,0.0 480 | 478,0.0 481 | 479,0.0 482 | 480,0.0 483 | 481,0.0 484 | 482,0.0 485 | 483,0.0 486 | 484,0.0 487 | 485,0.0 488 | 486,0.0 489 | 487,0.0 490 | 488,0.0 491 | 489,0.0 492 | 490,0.0 493 | 491,0.0 494 | 492,0.0 495 | 493,0.0 496 | 494,0.0 497 | 495,0.0 498 | 496,0.0 499 | 497,0.0 500 | 498,0.0 501 | 499,0.0 502 | 500,0.0 503 | 501,0.0 504 | 502,0.0 505 | 503,0.0 506 | 504,0.0 507 | 505,0.0 508 | 506,0.0 509 | 507,0.0 510 | 508,0.0 511 | 509,0.0 512 | 510,0.0 513 | 511,0.0 514 | 512,0.0 515 | 513,0.0 516 | 514,0.0 517 | 515,0.0 518 | 516,0.0 519 | 517,0.0 520 | 518,0.0 521 | 519,0.0 522 | 520,0.0 523 | 521,0.0 524 | 522,0.0 525 | 523,0.0 526 | 524,0.0 527 | 525,0.0 528 | 526,0.0 529 | 527,0.0 530 | 528,0.0 531 | 529,0.0 532 | 530,0.0 533 | 531,0.0 534 | 532,0.0 535 | 533,0.0 536 | 534,0.0 537 | 535,0.0 538 | 536,0.0 539 | 537,0.0 540 | 538,0.0 541 | 539,0.0 542 | 540,0.0 543 | 541,0.0 544 | 542,0.0 545 | 543,0.0 546 | 544,0.0 547 | 545,0.0 548 | 546,0.0 549 | 547,0.0 550 | 548,0.0 551 | 549,0.0 552 | 550,0.0 553 | 551,0.0 554 | 552,0.0 555 | 553,0.0 556 | 554,0.0 557 | 555,0.0 558 | 556,0.0 559 | 557,0.0 560 | 558,0.0 561 | 559,0.0 562 | 560,0.0 563 | 561,0.0 564 | 
562,0.0 565 | 563,0.0 566 | 564,0.0 567 | 565,0.0 568 | 566,0.0 569 | 567,0.0 570 | 568,0.0 571 | 569,0.0 572 | 570,0.0 573 | 571,0.0 574 | 572,0.0 575 | 573,0.0 576 | 574,0.0 577 | 575,0.0 578 | 576,0.0 579 | 577,0.0 580 | 578,0.0 581 | 579,0.0 582 | 580,0.0 583 | 581,0.0 584 | 582,0.0 585 | 583,0.0 586 | 584,0.0 587 | 585,0.0 588 | 586,0.0 589 | 587,0.0 590 | 588,0.0 591 | 589,0.0 592 | 590,0.0 593 | 591,0.0 594 | 592,0.0 595 | 593,0.0 596 | 594,0.0 597 | 595,0.0 598 | 596,0.0 599 | 597,0.0 600 | 598,0.0 601 | 599,0.0 602 | 600,0.0 603 | 601,0.0 604 | 602,0.0 605 | 603,0.0 606 | 604,0.0 607 | 605,0.0 608 | 606,0.0 609 | 607,0.0 610 | 608,0.0 611 | 609,0.0 612 | 610,0.0 613 | 611,0.0 614 | 612,0.0 615 | 613,0.0 616 | 614,0.0 617 | 615,0.0 618 | 616,0.0 619 | 617,0.0 620 | 618,0.0 621 | 619,0.0 622 | 620,0.0 623 | 621,0.0 624 | 622,0.0 625 | 623,0.0 626 | 624,0.0 627 | 625,0.0 628 | 626,0.0 629 | 627,0.0 630 | 628,0.0 631 | 629,0.0 632 | 630,0.0 633 | 631,0.0 634 | 632,0.0 635 | 633,0.0 636 | 634,0.0 637 | 635,0.0 638 | 636,0.0 639 | 637,0.0 640 | 638,0.0 641 | 639,0.0 642 | 640,0.0 643 | 641,0.0 644 | 642,0.0 645 | 643,0.0 646 | 644,0.0 647 | 645,0.0 648 | 646,0.0 649 | 647,0.0 650 | 648,0.0 651 | 649,0.0 652 | 650,0.0 653 | 651,0.0 654 | 652,0.0 655 | 653,0.0 656 | 654,0.0 657 | 655,0.0 658 | 656,0.0 659 | 657,0.0 660 | 658,0.0 661 | 659,0.0 662 | 660,0.0 663 | 661,0.0 664 | 662,0.0 665 | 663,0.0 666 | 664,0.0 667 | 665,0.0 668 | 666,0.0 669 | 667,0.0 670 | 668,0.0 671 | 669,0.0 672 | 670,0.0 673 | 671,0.0 674 | 672,0.0 675 | 673,0.0 676 | 674,0.0 677 | 675,0.0 678 | 676,0.0 679 | 677,0.0 680 | 678,0.0 681 | 679,0.0 682 | 680,0.0 683 | 681,0.0 684 | 682,0.0 685 | 683,0.0 686 | 684,0.0 687 | 685,0.0 688 | 686,0.0 689 | 687,0.0 690 | 688,0.0 691 | 689,0.0 692 | 690,0.0 693 | 691,0.0 694 | 692,0.0 695 | 693,0.0 696 | 694,0.0 697 | 695,0.0 698 | 696,0.0 699 | 697,0.0 700 | 698,0.0 701 | 699,0.0 702 | 700,0.0 703 | 701,0.0 704 | 702,0.0 705 | 703,0.0 706 | 704,0.0 707 
| 705,0.0 708 | 706,0.0 709 | 707,0.0 710 | 708,0.0 711 | 709,0.0 712 | 710,0.0 713 | 711,0.0 714 | 712,0.0 715 | 713,0.0 716 | 714,0.0 717 | 715,0.0 718 | 716,0.0 719 | 717,0.0 720 | 718,0.0 721 | 719,0.0 722 | 720,0.0 723 | 721,0.0 724 | 722,0.0 725 | 723,0.0 726 | 724,0.0 727 | 725,0.0 728 | 726,0.0 729 | 727,0.0 730 | 728,0.0 731 | 729,0.0 732 | 730,0.0 733 | 731,0.0 734 | 732,0.0 735 | 733,0.0 736 | 734,0.0 737 | 735,0.0 738 | 736,0.0 739 | 737,0.0 740 | 738,0.0 741 | 739,0.0 742 | 740,0.0 743 | 741,0.0 744 | 742,0.0 745 | 743,0.0 746 | 744,0.0 747 | 745,0.0 748 | 746,0.0 749 | 747,0.0 750 | 748,0.0 751 | 749,0.0 752 | 750,0.0 753 | 751,0.0 754 | 752,0.0 755 | 753,0.0 756 | 754,0.0 757 | 755,0.0 758 | 756,0.0 759 | 757,0.0 760 | 758,0.0 761 | 759,0.0 762 | 760,0.0 763 | 761,0.0 764 | 762,0.0 765 | 763,0.0 766 | 764,0.0 767 | 765,0.0 768 | 766,0.0 769 | 767,0.0 770 | 768,0.0 771 | 769,0.0 772 | 770,0.0 773 | 771,0.0 774 | 772,0.0 775 | 773,0.0 776 | 774,0.0 777 | 775,0.0 778 | 776,0.0 779 | 777,0.0 780 | 778,0.0 781 | 779,0.0 782 | 780,0.0 783 | 781,0.0 784 | 782,0.0 785 | 783,0.0 786 | 784,0.0 787 | 785,0.0 788 | 786,0.0 789 | 787,0.0 790 | 788,0.0 791 | 789,0.0 792 | 790,0.0 793 | 791,0.0 794 | 792,0.0 795 | 793,0.0 796 | 794,0.0 797 | 795,0.0 798 | 796,0.0 799 | 797,0.0 800 | 798,0.0 801 | 799,0.0 802 | 800,0.0 803 | 801,0.0 804 | 802,0.0 805 | 803,0.0 806 | 804,0.0 807 | 805,0.0 808 | 806,0.0 809 | 807,0.0 810 | 808,0.0 811 | 809,0.0 812 | 810,0.0 813 | 811,0.0 814 | 812,0.0 815 | 813,0.0 816 | 814,0.0 817 | 815,0.0 818 | 816,0.0 819 | 817,0.0 820 | 818,0.0 821 | 819,0.0 822 | 820,0.0 823 | 821,0.0 824 | 822,0.0 825 | 823,0.0 826 | 824,0.0 827 | 825,0.0 828 | 826,0.0 829 | 827,0.0 830 | 828,0.0 831 | 829,0.0 832 | 830,0.0 833 | 831,0.0 834 | 832,0.0 835 | 833,0.0 836 | 834,0.0 837 | 835,0.0 838 | 836,0.0 839 | 837,0.0 840 | 838,0.0 841 | 839,0.0 842 | 840,0.0 843 | 841,0.0 844 | 842,0.0 845 | 843,0.0 846 | 844,0.0 847 | 845,0.0 848 | 846,0.0 849 | 847,0.0 
850 | 848,0.0 851 | 849,0.0 852 | 850,0.0 853 | 851,0.0 854 | 852,0.0 855 | 853,0.0 856 | 854,0.0 857 | 855,0.0 858 | 856,0.0 859 | 857,0.0 860 | 858,0.0 861 | 859,0.0 862 | 860,0.0 863 | 861,0.0 864 | 862,0.0 865 | 863,0.0 866 | 864,0.0 867 | 865,0.0 868 | 866,0.0 869 | 867,0.0 870 | 868,0.0 871 | 869,0.0 872 | 870,0.0 873 | 871,0.0 874 | 872,0.0 875 | 873,0.0 876 | 874,0.0 877 | 875,0.0 878 | 876,0.0 879 | 877,0.0 880 | 878,0.0 881 | 879,0.0 882 | 880,0.0 883 | 881,0.0 884 | 882,0.0 885 | 883,0.0 886 | 884,0.0 887 | 885,0.0 888 | 886,0.0 889 | 887,0.0 890 | 888,0.0 891 | 889,0.0 892 | 890,0.0 893 | 891,0.0 894 | 892,0.0 895 | -------------------------------------------------------------------------------- /HW10/HW10.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW10/HW10.pdf -------------------------------------------------------------------------------- /HW10/hw10_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW10/hw10_report.pdf -------------------------------------------------------------------------------- /HW11/HW11.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW11/HW11.pdf -------------------------------------------------------------------------------- /HW11/hw11_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW11/hw11_report.pdf -------------------------------------------------------------------------------- /HW12/HW12.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW12/HW12.pdf -------------------------------------------------------------------------------- /HW12/hw12_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW12/hw12_report.pdf -------------------------------------------------------------------------------- /HW13/HW13.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW13/HW13.pdf -------------------------------------------------------------------------------- /HW14/HW14.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW14/HW14.pdf -------------------------------------------------------------------------------- /HW15/HW15.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW15/HW15.pdf -------------------------------------------------------------------------------- /HW15/homework15.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "ML2021 HW15 Meta Learning.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "name": "python3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | 
"cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "wzVBe3h7Xh-2" 35 | }, 36 | "source": [ 37 | "\n", 38 | "# **HW15 Meta Learning: Few-shot Classification**\n", 39 | "\n", 40 | "Please email ntu-ml-2021spring-ta@googlegroups.com if you have any questions.\n", 41 | "\n", 42 | "Useful Links:\n", 43 | "1. [Go to hyperparameter setting.](#hyp)\n", 44 | "1. [Go to meta algorithm setting.](#modelsetting)\n", 45 | "1. [Go to main loop.](#mainloop)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": { 51 | "id": "RdpzIMG6XsGK" 52 | }, 53 | "source": [ 54 | "## **Step 0: Check GPU**" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "metadata": { 60 | "id": "zjjHsZbaL7SV" 61 | }, 62 | "source": [ 63 | "!nvidia-smi" 64 | ], 65 | "execution_count": null, 66 | "outputs": [] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "metadata": { 71 | "cellView": "form", 72 | "id": "gWpc6vW3MQhv" 73 | }, 74 | "source": [ 75 | "#@markdown ### Install `qqdm`\n", 76 | "# Check if installed\n", 77 | "try:\n", 78 | " import qqdm\n", 79 | "except ImportError:\n", 80 | " ! pip install qqdm > /dev/null 2>&1\n", 81 | "print(\"Done!\")" 82 | ], 83 | "execution_count": null, 84 | "outputs": [] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": { 89 | "id": "bQ3wvyjnXwGX" 90 | }, 91 | "source": [ 92 | "## **Step 1: Download Data**\n", 93 | "\n", 94 | "Run this cell to download the data, which has been pre-processed by the TAs. 
\n", 95 | "The dataset has been augmented, so extra data augmentation is not required.\n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "metadata": { 101 | "id": "g7Gt4Jucug41" 102 | }, 103 | "source": [ 104 | "workspace_dir = '.'\n", 105 | "\n", 106 | "# gdown is a package that downloads files from Google Drive\n", 107 | "!gdown --id 1FLDrQ0k-iJ-mk8ors0WItqvwgu0w9J0U \\\n", 108 | " --output \"{workspace_dir}/Omniglot.tar.gz\"" 109 | ], 110 | "execution_count": null, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "AMGFHI9XX9ms" 117 | }, 118 | "source": [ 119 | "### Decompress the dataset\n", 120 | "\n", 121 | "Since the dataset is quite large, decompression takes a while; in the meantime, you can read the main program [here](#mainprogram). \n", 122 | "You can come back here later via [*back to pre-process*](#preprocess)." 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "metadata": { 128 | "id": "AvvlAQBUug42" 129 | }, 130 | "source": [ 131 | "# Use the `tar` command to decompress\n", 132 | "!tar -zxf \"{workspace_dir}/Omniglot.tar.gz\" \\\n", 133 | " -C \"{workspace_dir}/\"" 134 | ], 135 | "execution_count": null, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "id": "T5P9eT0fYDqV" 142 | }, 143 | "source": [ 144 | "### Data Preview\n", 145 | "\n", 146 | "Take a quick look at some samples from the dataset."
147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "metadata": { 152 | "colab": { 153 | "base_uri": "https://localhost:8080/", 154 | "height": 297 155 | }, 156 | "id": "7VtgHLurYE5x", 157 | "outputId": "961971b2-8b61-4d03-c06e-571a778ab52d" 158 | }, 159 | "source": [ 160 | "from PIL import Image\n", 161 | "from IPython.display import display\n", 162 | "for i in range(10, 20):\n", 163 | " im = Image.open(\"Omniglot/images_background/Japanese_(hiragana).0/character13/0500_\" + str (i) + \".png\")\n", 164 | " display(im)" 165 | ], 166 | "execution_count": null, 167 | "outputs": [ 168 | { 169 | "output_type": "display_data", 170 | "data": { 171 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABEUlEQVR4nGP8z4AbMOGRQ5F8uekLmux/BNjHs+0/CkDWqcKJppMFQv1jYGBgYGb69w/FMsb/DAwMDD+bnjMwMHxbb6nIyMDAwMBWqoyk8/+zRwwMDD//vGFmYGBgYBDhQNbJ8Pvvp7t/OL0nBENMZUG1s2vFsz+C71nYsPnz5QSz/Vt1f//DGgj/GV0MbEMYTqDJQrz7Sc/62VVZZtefKIEAlfy3XthM3TBD7Qu2EGL07ZPWncHGzog9bP/+/v1EvvUv9rBlYvk37VsgE1ad/34+zeeL+4UaKywMDAwM7w7/unH46puSKlZUjYz/GRj+z278y2xkbW7Cy4ApyfD1838mQVY0lzLAAx47IDqBDQpJAN4Euv7fFejQAAAAAElFTkSuQmCC\n", 172 | "text/plain": [ 173 | "" 174 | ] 175 | }, 176 | "metadata": { 177 | "tags": [] 178 | } 179 | }, 180 | { 181 | "output_type": "display_data", 182 | "data": { 183 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABNklEQVR4nGP8zwAH//8xM6AAJiT2pdTXuCVfbvmGW5LhPwNuSUaGP7gl5ZkuoUqyMDAwMLw78I9PjVVYSu2AP9P//39ZUSQfFP/8/4dVWvER9y6GC08+T+dCltQ9zvD5wZPr9/7uPsMkzuXKgnAhHPz71aJ07eHHH3/gIghVDIwsEv9ERHF6RevDG9yBwMn8GZvk/2+3nvxnEOe4g+rR////////scmBX+nov+OCV/4jA4jkNR73ed4aD3M032GRvME/5ddV4QKRqX+xSH4vFM4/4cmg+eQ/Fsn/X6doiPDy7vmHVfL/3xdzjB2+o8r9h/mTSTxBn4UJ1SNIgfD1iTS6JDzgfxSJ7UEzFWbnr5fZQrP+YJf8FKcpsegHuhw0yti02QI9MGxkYPwPtZmREUMOJokdAAB60yoWf/hgewAAAABJRU5ErkJggg==\n", 184 | "text/plain": [ 185 | "" 186 | ] 187 | }, 188 | "metadata": { 189 | "tags": [] 190 | } 191 | }, 192 | { 193 | "output_type": "display_data", 194 | "data": { 195 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABN0lEQVR4nGP8z4AbsCBz/j9lkGZE4jMi6/wTz7AQWTWKToafqMYy4bESXfI/bklG8Yc/cEoyGbz59R8OIA76eBNq2v/7P/c8gJnM6KHHwsDAcC7iN8y133KEYR7lMmdg/M/A8PUxTPWFtCXOMElGTkYWBgYGbg24rawSPEhuQATC/6dPGe7/Q/EKQvJa0GuGH39RPQpz+FNL0zNXW9gyv/1H8gyU/tuicOXf/6tcEs+RJGGB8HOnrwYjA1r4wSQ/3tP5/vXLpV+i7EiSsPh85/JEmJHh/eelfoyYkv8f7nvNwHDy7HkhbK79/+/fv19JFl/+Y3EQAwMjI+PrA36cWP35////m55y15E1/keSfOurdvAvdskfm/3kdv37j1Xy33Jh7ZWo+v7/h6fbV49UedGTIiO+7AAAZ4kCU7KEzEEAAAAASUVORK5CYII=\n", 196 | "text/plain": [ 197 | "" 198 | ] 199 | }, 200 | "metadata": { 201 | "tags": [] 202 | } 203 | }, 204 | { 205 | "output_type": "display_data", 206 | "data": { 207 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABLUlEQVR4nGP8z4AbMKHx/3/7g1vyredeBIcFTfLP7c8w5r+/EMlX90UEISLv/315B5VbepDxPwMDA8OiAmYuiNDfZ0LcUNs/BkEkPzw4D3XHpyZHL0YIU9ydEeYVKP3dW3g51B2MCNcyQgCn47VfUCamVxjY/uH2JwpAlUQLS+RAeH3rnLbIN+ySn6JPsP35/18Kq7Hvz/edPhr75f9/bDoZGLgUGJveb32hgkUnn0H3ob/8vJ8PILT+R4BLMqIXDssyhPyGCTAiuf7fwQiuj9o/5FbC7EK2k8l+2RnFe20WjNiM/f///7+frp6v4Tz04HvyUlaIAbvOf1eNJQ8juKiSXy2lt/7FIfl7Dl/rn//YJf+tFrR7jKwY2Z8MZw/5qDAi8VEkGf4jSzEwAABSseqGZyInRAAAAABJRU5ErkJggg==\n", 208 | "text/plain": [ 209 | "" 210 | ] 211 | }, 212 | "metadata": { 213 | "tags": [] 214 | } 215 | }, 216 | { 217 | "output_type": "display_data", 218 | "data": { 219 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABJklEQVR4nGP8z4AbMKFy/+GRfF54D7fkh8Wv8RgLAxCXsGCR+Pfx0BtTfZjkz5dQNz/79+Lhv32HTj4TbtFjZGBg/M/AwHDO/xdE8s87Qdb/v3WsnYz5WBmgkh9PQ73wqLBVg0FAhwPmkv/I4JrgCWQukoN+/2NECy645P9bTXf4PVFDCG7sU33JNFMOwZvIxsIk/03k2/PnvpTsW2RJeAh9lzRjFtdBNRUuyfjzCwOHP2powhzEaPBypvvlvWjOhZn/xlZYUIgH1U641/5/e/6bufT8BSFs/mTkVmH4+gKHgxgYGBh+PmdhxCn5/a85D1YH/f///0Mk336UeEBI/ntXzdv9C7vkvyOWvJnf/2OX/GxlMBNNDuHPvxdlRFGcygBNJrgAAEPeDmCQZ6aqAAAAAElFTkSuQmCC\n", 220 | "text/plain": [ 221 | "" 222 | ] 223 | }, 224 | "metadata": { 225 | "tags": [] 226 | } 227 | }, 228 | { 229 | "output_type": "display_data", 230 | "data": { 231 | 
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABOElEQVR4nGP8z4AbMKHxXz1AVv0fBfxJt/+B4KHp/P/j7U+cxrKEvHyD2041ppP/cUrKqh34D3EKAwMDC5okq9Sdrz9/3v50ysqTEUny//9/v169/K+01e/Rjz8MUppQnf8YGBi+3bxw+fGLu78Y/nxlCjHXZRfiYmRgYPzP8KfzPsO/iw+4FUUEHAyZX0R0xcIcwsLA8P/lQwYmTzsDXlbGX6f/fGf8zYgcQr9//fr19/////9fRgsKczLO/occQiysrKxMDAwM/5fumr0vmEUcrhPZK38OmvowCkiaYQ0EJoWbuyvmWggjRJDj5IkrPw9D8V84Hyb5bsO3////v9sQJHbxP4bkWYmajXcfnM4QWPQHU/JXjbSYqKiwRj9SXP9nhEXQz/cfH/3n0BJCdiEjKQlscEsCAN5i3onYmdekAAAAAElFTkSuQmCC\n", 232 | "text/plain": [ 233 | "" 234 | ] 235 | }, 236 | "metadata": { 237 | "tags": [] 238 | } 239 | }, 240 | { 241 | "output_type": "display_data", 242 | "data": { 243 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABO0lEQVR4nGP8z4AbMOGRo5bks4v/UCRZkNj/ly06zs3w4+eLf4yK7OiSDH9+/Xl+ZObDNwxcG00Qkk+v/WdgYPh/+1XWic9ySSa87DpIxu4v/sfAwMDw69OtpFAxfkaYSYz/GRgYGL68ZWBgYGA4H39IDy4D18nDw8DAwMDwnIkNWY6CQIB55T+2CIBI/j1+8B4DA5P0P2ySF4NYVVkYPi1m/IEq+///////Fwge+/Hz5/sipvif/5EARPKW2Mx/////P8wi+w5ZEmKsrF4vhx3jh/6/aO6CqLkfxS0iIiQbJvsWWScjVOnnM78ZWNTW95wXwuJPXkcGBoY/x5S5cIbQlYOubJh2/v318+fPn7ctLd78x3Dt//Uz/zEwMDxmXymMGUKMCop/GBgY1BI0UAMI6tp/mA5ASGIHADm3qpNJq4xdAAAAAElFTkSuQmCC\n", 244 | "text/plain": [ 245 | "" 246 | ] 247 | }, 248 | "metadata": { 249 | "tags": [] 250 | } 251 | }, 252 | { 253 | "output_type": "display_data", 254 | "data": { 255 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABI0lEQVR4nGP8z4AbMKFy7yz/jsz9jwz+5su//P///79///7+/v//Pwuqzi9KfP9/Xdv1juEXZxMLA5Lk/39/3r153Mdw9BS/CANjNBMDAyPEQf8ffv92/vizq8+4pBlVg12FGBhYmeB2/nARFJSxTpokWfXpy89/MCdAjWWd8IVRjo/9+3Q+HkaERVBJJm0Ghp9ff6B5GuGg36WbGNXe4AiEd5tMC869xqHzwPcq7XuTsev8/5lTnlWPEUUSKRB+3+L/z4VdklH8qyfj/x84dHrs/8XwOxGHJKshA8NXjof/mLF5hYGBgYFD/84/bK5lYPj///+PxziMfbTg6/+nZ9KYsEo+2PWHgasoE8lKWHwyMDD8/8XAwMiKEgqMJKS+gZcEAF56gf6wykc6AAAAAElFTkSuQmCC\n", 256 | "text/plain": [ 257 | "" 258 | ] 259 | }, 260 | "metadata": { 261 | "tags": [] 262 | } 263 | }, 264 | { 265 | "output_type": "display_data", 266 | 
"data": { 267 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABXklEQVR4nGP8z4AbMOGRI1Xy9z/ckp+TVv3HKfnn6Lq/MDYLuiSHMkTjv98MzFDJv/9ZGBj+//339ef1symvmN49/bnn1H9rxv8MDAwM/yecjhFmuLPl7bMPv98IszH8ZmCUtuW0h+rkvhT3n4FVV95M/1tpgBsDpxazIDcjIyPUhq/P/zJwSLIxMr4z9u1nRnUQEy8vAwMDw3+G/yw8H+GOQ3bt/w9bnrOZfTZjwiL5/0LaPd63PJ/YsYXQ6wSezacOW/6+jYio/zDwb7Hkuf///5/ij/sDE0Lo/H9IXYOB4deET9+whS3vrbmvfx3axmTJjGns/6dJMpp+4gxid+EiSJL/f98sc7Ph6PmDVfL//++zZWM+/8cu+TZfOOrdf2yS/15u8BCpe/8fm+TXfm1Bj4O//2OT/Dtf0O/g1///sUreMF707T86gEp+S1/2B0PuPzSZbPscgpHUGBgAt9BS1wiwXusAAAAASUVORK5CYII=\n", 268 | "text/plain": [ 269 | "" 270 | ] 271 | }, 272 | "metadata": { 273 | "tags": [] 274 | } 275 | }, 276 | { 277 | "output_type": "display_data", 278 | "data": { 279 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABQElEQVR4nGP8z4AbMKHwfvxA4TIi6/xbyNDPjMRnQVb5/wkDii2oxuK1Ew2gGMvwn5EBZjAjTPL3W4jAp6uGT86f+MfAwMDAnKwKde25oF8MDAwMDP9e8nD/kWNhYGBg4JymDZV8d+AvAwMDA8OX0mBXDSWIZ9gYGRgY/kPBv68Hdl6SmfXvPxKAOuj/g92rzzEwfcDmlf97nKs4Vu3y/IPml////////9VI5fL3//+v8KMaC9HJbs8kysHAIMT9D4vO/9eEJ/z59y6HMfEPFgcp+7XfF9tyjfMDSsBDJdl6eXf+F1s1GdVUeHz++cnAxOAlvAI5sOGxwsLNzfn9njlyXKNF2X8BLIGAAyBZ8f/zge9oamF++nG/S59L9RayN//DJP8tlpTL3XwbJfT+w73y7KShLDOqoajpFh3gdS0Aq5C/ToYG3GgAAAAASUVORK5CYII=\n", 280 | "text/plain": [ 281 | "" 282 | ] 283 | }, 284 | "metadata": { 285 | "tags": [] 286 | } 287 | } 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": { 293 | "id": "baVsWfcSYHVN" 294 | }, 295 | "source": [ 296 | "## **Step 2: Build the model**" 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "metadata": { 302 | "id": "gqiOdDLgYOlQ" 303 | }, 304 | "source": [ 305 | "### Library importation" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "metadata": { 311 | "id": "-9pfkqh8gxHD" 312 | }, 313 | "source": [ 314 | "# Import modules we need\n", 315 | "import glob, random\n", 316 | "from collections import OrderedDict\n", 317 | "\n", 318 | "import numpy as np\n", 319 | "\n", 
320 | "try:\n", 321 | " from qqdm.notebook import qqdm as tqdm\n", 322 | "except ModuleNotFoundError:\n", 323 | " from tqdm.auto import tqdm\n", 324 | "\n", 325 | "import torch, torch.nn as nn\n", 326 | "import torch.nn.functional as F\n", 327 | "from torch.utils.data import DataLoader, Dataset\n", 328 | "import torchvision.transforms as transforms\n", 329 | "\n", 330 | "from PIL import Image\n", 331 | "from IPython.display import display\n", 332 | "\n", 333 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 334 | "\n", 335 | "# fix random seeds\n", 336 | "random_seed = 0\n", 337 | "random.seed(random_seed)\n", 338 | "np.random.seed(random_seed)\n", 339 | "torch.manual_seed(random_seed)\n", 340 | "if torch.cuda.is_available():\n", 341 | " torch.cuda.manual_seed_all(random_seed)" 342 | ], 343 | "execution_count": null, 344 | "outputs": [] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": { 349 | "id": "3TlwLtC1YRT7" 350 | }, 351 | "source": [ 352 | "### Model Construction Preliminaries\n", 353 | "\n", 354 | "Since our task is image classification, we need to build a CNN-based model. 
\n", 355 | "However, to implement the MAML algorithm, we need to adjust some code in `nn.Module`.\n" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": { 361 | "id": "dFwB3tuEDYfy" 362 | }, 363 | "source": [ 364 | "Take a look at the MAML pseudocode:\n", 365 | "\n", 366 | "\n", 367 | "\n", 368 | "On the 10th line, the gradients are taken with respect to $\\theta$, \n", 369 | "**the original model parameters** (outer loop), rather than the adapted parameters of the \n", 370 | "**inner loop**. We therefore use `functional_forward`, which computes the output logits of the input images \n", 371 | "from an explicitly given set of parameters (the adapted ones, which still depend on $\\theta$), instead of `forward` in `nn.Module`.\n", 372 | "\n", 373 | "The following defines these functions.\n", 374 | "\n", 375 | "" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "metadata": { 381 | "id": "iuYQiPeQYc__" 382 | }, 383 | "source": [ 384 | "### Model block definition" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "metadata": { 390 | "id": "GgFbbKHYg3Hk" 391 | }, 392 | "source": [ 393 | "def ConvBlock(in_ch: int, out_ch: int):\n", 394 | " return nn.Sequential(\n", 395 | " nn.Conv2d(in_ch, out_ch, 3, padding=1),\n", 396 | " nn.BatchNorm2d(out_ch),\n", 397 | " nn.ReLU(),\n", 398 | " nn.MaxPool2d(kernel_size=2, stride=2)\n", 399 | " )\n", 400 | "\n", 401 | "def ConvBlockFunction(x, w, b, w_bn, b_bn):\n", 402 | " x = F.conv2d(x, w, b, padding=1)\n", 403 | " x = F.batch_norm(x,\n", 404 | " running_mean=None,\n", 405 | " running_var=None,\n", 406 | " weight=w_bn, bias=b_bn,\n", 407 | " training=True)\n", 408 | " x = F.relu(x)\n", 409 | " x = F.max_pool2d(x, kernel_size=2, stride=2)\n", 410 | " return x" 411 | ], 412 | "execution_count": null, 413 | "outputs": [] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "metadata": { 418 | "id": "iQEzgWN7fi7B" 419 | }, 420 | "source": [ 421 | "### Model definition" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "metadata": { 427 | "id": "0bFBGEQoHQUW" 428 | }, 429 | "source": [ 430 | "class
Classifier(nn.Module):\n", 431 | " def __init__(self, in_ch, k_way):\n", 432 | " super(Classifier, self).__init__()\n", 433 | " self.conv1 = ConvBlock(in_ch, 64)\n", 434 | " self.conv2 = ConvBlock(64, 64)\n", 435 | " self.conv3 = ConvBlock(64, 64)\n", 436 | " self.conv4 = ConvBlock(64, 64)\n", 437 | " self.logits = nn.Linear(64, k_way)\n", 438 | "\n", 439 | " def forward(self, x):\n", 440 | " x = self.conv1(x)\n", 441 | " x = self.conv2(x)\n", 442 | " x = self.conv3(x)\n", 443 | " x = self.conv4(x)\n", 444 | " x = x.view(x.shape[0], -1)\n", 445 | " x = self.logits(x)\n", 446 | " return x\n", 447 | "\n", 448 | " def functional_forward(self, x, params):\n", 449 | " '''\n", 450 | " Arguments:\n", 451 | " x: input images [batch, 1, 28, 28]\n", 452 | " params: model parameters as an `OrderedDict` keyed like\n", 453 | " `model.named_parameters()`, i.e. weights and biases\n", 454 | " of the convolution layers (`conv{block}.0.*`),\n", 455 | " the batch normalization layers (`conv{block}.1.*`),\n", 456 | " and the final linear layer (`logits.*`)\n", 457 | "\n", 458 | " Returns:\n", 459 | " logits of shape [batch, k_way]\n", 460 | "\n", 461 | " Unlike `forward`, this runs the network with the\n", 462 | " tensors in `params` (e.g. the inner-loop fast weights)\n", 463 | " instead of the module's own parameters.\n", 464 | " '''\n", 465 | " for block in [1, 2, 3, 4]:\n", 466 | " x = ConvBlockFunction(\n", 467 | " x,\n", 468 | " params[f'conv{block}.0.weight'],\n", 469 | " params[f'conv{block}.0.bias'],\n", 470 | " params.get(f'conv{block}.1.weight'),\n", 471 | " params.get(f'conv{block}.1.bias'))\n", 472 | " x = x.view(x.shape[0], -1)\n", 473 | " x = F.linear(x,\n", 474 | " params['logits.weight'],\n", 475 | " params['logits.bias'])\n", 476 | " return x" 477 | ], 478 | "execution_count": null, 479 | "outputs": [] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": { 484 | "id": "gmJq_0B9Yj0G" 485 | }, 486 | "source": [ 487 | "### Create Label\n", 488 | "\n", 489 | "This function is used to create labels. 
\n", 490 | "In an N-way K-shot few-shot classification problem,\n", 491 | "each task has `n_way` classes, while there are `k_shot` images for each class. \n", 492 | "For example, `create_label(5, 2)` returns two labels for each of the classes 0 to 4.\n" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "metadata": { 498 | "colab": { 499 | "base_uri": "https://localhost:8080/" 500 | }, 501 | "id": "GQF5vgLvg5aX", 502 | "outputId": "5df41e04-290c-428b-b06f-cc749f09f027" 503 | }, 504 | "source": [ 505 | "def create_label(n_way, k_shot):\n", 506 | " return (torch.arange(n_way)\n", 507 | " .repeat_interleave(k_shot)\n", 508 | " .long())\n", 509 | "\n", 510 | "# Try creating labels for a 5-way 2-shot setting\n", 511 | "create_label(5, 2)" 512 | ], 513 | "execution_count": null, 514 | "outputs": [ 515 | { 516 | "output_type": "execute_result", 517 | "data": { 518 | "text/plain": [ 519 | "tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])" 520 | ] 521 | }, 522 | "metadata": { 523 | "tags": [] 524 | }, 525 | "execution_count": 9 526 | } 527 | ] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "metadata": { 532 | "id": "2nCFv9PGw50J" 533 | }, 534 | "source": [ 535 | "### Accuracy calculation" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "metadata": { 541 | "id": "FahDr0xQw50S" 542 | }, 543 | "source": [ 544 | "def calculate_accuracy(logits, val_label):\n", 545 | " \"\"\" utility function for accuracy calculation \"\"\"\n", 546 | " acc = np.asarray([(\n", 547 | " torch.argmax(logits, -1).cpu().numpy() == val_label.cpu().numpy())]\n", 548 | " ).mean() \n", 549 | " return acc" 550 | ], 551 | "execution_count": null, 552 | "outputs": [] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "metadata": { 557 | "id": "9Hl7ro2mYzsI" 558 | }, 559 | "source": [ 560 | "### Define Dataset\n", 561 | "\n", 562 | "Define the dataset. \n", 563 | "Each item contains `k_shot + q_query` randomly chosen images of one character, \n", 564 | "so the returned tensor has shape `[k_shot+q_query, 1, 28, 28]`. 
\n" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "metadata": { 570 | "id": "-tJ2mot9hHPb" 571 | }, 572 | "source": [ 573 | "class Omniglot(Dataset):\n", 574 | " def __init__(self, data_dir, k_shot, q_query):\n", 575 | " self.file_list = [f for f in glob.glob(\n", 576 | " data_dir + \"**/character*\", \n", 577 | " recursive=True)]\n", 578 | " self.transform = transforms.Compose(\n", 579 | " [transforms.ToTensor()])\n", 580 | " self.n = k_shot + q_query\n", 581 | "\n", 582 | " def __getitem__(self, idx):\n", 583 | " sample = np.arange(20)\n", 584 | "\n", 585 | " # Randomly choose which of the 20 drawings of this character to use.\n", 586 | " np.random.shuffle(sample) \n", 587 | " img_path = self.file_list[idx]\n", 588 | " img_list = [f for f in glob.glob(\n", 589 | " img_path + \"**/*.png\", recursive=True)]\n", 590 | " img_list.sort()\n", 591 | " imgs = [self.transform(\n", 592 | " Image.open(img_file)) \n", 593 | " for img_file in img_list]\n", 594 | " # `k_shot + q_query` examples for each character\n", 595 | " imgs = torch.stack(imgs)[sample[:self.n]] \n", 596 | " return imgs\n", 597 | "\n", 598 | " def __len__(self):\n", 599 | " return len(self.file_list)" 600 | ], 601 | "execution_count": null, 602 | "outputs": [] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "metadata": { 607 | "id": "Gm5iVp90Ylii" 608 | }, 609 | "source": [ 610 | "## **Step 3: Core MAML**\n", 611 | "\n", 612 | "Here is the main Meta Learning algorithm. \n", 613 | "The algorithm is the same as the one in the original MAML paper. \n", 614 | "The function updates the model parameters using the data of one meta-batch.\n", 615 | "Here we implement the second-order MAML (inner_train_step = 1), according to [the slides of meta learning in 2019 (p.
13 ~ p.18)](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=13&view=FitW)\n", 616 | "\n", 617 | "As for the mathematical derivation of the first-order version, please refer to [p.25 of the slides in 2019](http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf#page=25&view=FitW).\n", 618 | "\n", 619 | "The following is the algorithm with some explanation." 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "metadata": { 625 | "id": "KjNxrWW_yNck" 626 | }, 627 | "source": [ 628 | "def OriginalMAML(\n", 629 | " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n", 630 | " inner_train_step=1, inner_lr=0.4, train=True):\n", 631 | " criterion, task_loss, task_acc = loss_fn, [], []\n", 632 | "\n", 633 | " for meta_batch in x:\n", 634 | " # Get data\n", 635 | " support_set = meta_batch[: n_way * k_shot] \n", 636 | " query_set = meta_batch[n_way * k_shot :] \n", 637 | " \n", 638 | " # Copy the params for inner loop\n", 639 | " fast_weights = OrderedDict(model.named_parameters())\n", 640 | " \n", 641 | " ### ---------- INNER TRAIN LOOP ---------- ###\n", 642 | " for inner_step in range(inner_train_step): \n", 643 | " # Simply training\n", 644 | " train_label = create_label(n_way, k_shot) \\\n", 645 | " .to(device)\n", 646 | " logits = model.functional_forward(\n", 647 | " support_set, fast_weights)\n", 648 | " loss = criterion(logits, train_label)\n", 649 | " # Inner gradients update! 
vvvvvvvvvvvvvvvvvvvv #\n", 650 | " \"\"\" Inner Loop Update \"\"\" #\n", 651 | " grads = torch.autograd.grad( #\n", 652 | " loss, fast_weights.values(), #\n", 653 | " create_graph=True) #\n", 654 | " # Perform SGD #\n", 655 | " fast_weights = OrderedDict( #\n", 656 | " (name, param - inner_lr * grad) #\n", 657 | " for ((name, param), grad) #\n", 658 | " in zip(fast_weights.items(), grads)) #\n", 659 | " # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #\n", 660 | "\n", 661 | " ### ---------- INNER VALID LOOP ---------- ###\n", 662 | " val_label = create_label(n_way, q_query).to(device)\n", 663 | " \n", 664 | " # Collect gradients for outer loop\n", 665 | " logits = model.functional_forward(\n", 666 | " query_set, fast_weights) \n", 667 | " loss = criterion(logits, val_label)\n", 668 | " task_loss.append(loss)\n", 669 | " task_acc.append(\n", 670 | " calculate_accuracy(logits, val_label))\n", 671 | "\n", 672 | " # Update outer loop\n", 673 | " model.train()\n", 674 | " optimizer.zero_grad()\n", 675 | "\n", 676 | " meta_batch_loss = torch.stack(task_loss).mean()\n", 677 | " if train:\n", 678 | " meta_batch_loss.backward() # <--- May change later!\n", 679 | " optimizer.step()\n", 680 | " task_acc = np.mean(task_acc)\n", 681 | " return meta_batch_loss, task_acc" 682 | ], 683 | "execution_count": null, 684 | "outputs": [] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "metadata": { 689 | "id": "MF5ZahPdxKbp" 690 | }, 691 | "source": [ 692 | "## Variations of MAML\n", 693 | "\n", 694 | "### First-order approximation of MAML (FOMAML)\n", 695 | "\n", 696 | "A slight modification of the MAML above: a first-order approximation drops the second-order terms of the meta-gradient to reduce the amount of computation.\n", 697 | "\n", 698 | "### Almost No Inner Loop (ANIL)\n", 699 | "\n", 700 | "The algorithm from [this paper](https://arxiv.org/abs/1909.09157), which relies on feature reuse: only the final layer is adapted in the inner loop, reducing the amount of computation."
701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "metadata": { 706 | "id": "qyQ7ZUN4foh-" 707 | }, 708 | "source": [ 709 | "To build these variants, we need to change some blocks of the MAML algorithm. \n", 710 | "Below, we have factored out the three parts that may need to be modified as separate functions. \n", 711 | "Replace these functions with the appropriate alternative versions to complete each algorithm." 712 | ] 713 | }, 714 | { 715 | "cell_type": "markdown", 716 | "metadata": { 717 | "id": "Ne5cOja0H8H7" 718 | }, 719 | "source": [ 720 | "### Part 1: Inner loop update" 721 | ] 722 | }, 723 | { 724 | "cell_type": "markdown", 725 | "metadata": { 726 | "id": "LChAX51sIFwi" 727 | }, 728 | "source": [ 729 | "MAML" 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "metadata": { 735 | "id": "Aqgb0kEVzQol" 736 | }, 737 | "source": [ 738 | "def inner_update_MAML(fast_weights, loss, inner_lr):\n", 739 | " \"\"\" Inner Loop Update \"\"\"\n", 740 | " grads = torch.autograd.grad(\n", 741 | " loss, fast_weights.values(), create_graph=True)\n", 742 | " # Perform SGD\n", 743 | " fast_weights = OrderedDict(\n", 744 | " (name, param - inner_lr * grad)\n", 745 | " for ((name, param), grad) in zip(fast_weights.items(), grads))\n", 746 | " return fast_weights" 747 | ], 748 | "execution_count": null, 749 | "outputs": [] 750 | }, 751 | { 752 | "cell_type": "markdown", 753 | "metadata": { 754 | "id": "QnQ_BN-L2Gd7" 755 | }, 756 | "source": [ 757 | "Alternatives" 758 | ] 759 | }, 760 | { 761 | "cell_type": "code", 762 | "metadata": { 763 | "id": "Ug5LIO6V15cd" 764 | }, 765 | "source": [ 766 | "def inner_update_alt1(fast_weights, loss, inner_lr):\n", 767 | " grads = torch.autograd.grad(\n", 768 | " loss, fast_weights.values(), create_graph=False)\n", 769 | " # Perform SGD\n", 770 | " fast_weights = OrderedDict(\n", 771 | " (name, param - inner_lr * grad)\n", 772 | " for ((name, param), grad) in zip(fast_weights.items(), grads))\n", 773 | " return fast_weights\n", 774 |
"\n", 775 | "def inner_update_alt2(fast_weights, loss, inner_lr):\n", 776 | " grads = torch.autograd.grad(\n", 777 | " loss, list(fast_weights.values())[-2:], create_graph=True)\n", 778 | " # Update only the final linear layer (`logits`); the conv layers keep their weights\n", 779 | " for ((name, param), grad) in zip(\n", 780 | " list(fast_weights.items())[-2:], grads):\n", 781 | " fast_weights[name] = param - inner_lr * grad\n", 782 | " return fast_weights" 783 | ], 784 | "execution_count": null, 785 | "outputs": [] 786 | }, 787 | { 788 | "cell_type": "markdown", 789 | "metadata": { 790 | "id": "1ZfaWPMt164t" 791 | }, 792 | "source": [ 793 | "### Part 2: Collect gradients" 794 | ] 795 | }, 796 | { 797 | "cell_type": "markdown", 798 | "metadata": { 799 | "id": "-W7zL2nN164u" 800 | }, 801 | "source": [ 802 | "MAML \n", 803 | "(Does nothing here, since PyTorch computes the gradients automatically.)" 804 | ] 805 | }, 806 | { 807 | "cell_type": "code", 808 | "metadata": { 809 | "id": "sgcuPPm2zSFL" 810 | }, 811 | "source": [ 812 | "def collect_gradients_MAML(\n", 813 | " special_grad: OrderedDict, fast_weights, model, len_data):\n", 814 | " \"\"\" Does nothing; gradients are computed by `backward` later \"\"\"\n", 815 | " return special_grad" 816 | ], 817 | "execution_count": null, 818 | "outputs": [] 819 | }, 820 | { 821 | "cell_type": "markdown", 822 | "metadata": { 823 | "id": "2OxEME6l2QOO" 824 | }, 825 | "source": [ 826 | "Alternatives" 827 | ] 828 | }, 829 | { 830 | "cell_type": "code", 831 | "metadata": { 832 | "id": "fWLYwZlM2RZO" 833 | }, 834 | "source": [ 835 | "def collect_gradients_alt(\n", 836 | " special_grad: OrderedDict, fast_weights, model, len_data):\n", 837 | " \"\"\" Accumulate (original - adapted) parameter differences, averaged over the meta-batch \"\"\"\n", 838 | " diff = OrderedDict(\n", 839 | " (name, params - fast_weights[name]) \n", 840 | " for (name, params) in model.named_parameters())\n", 841 | " for name in diff:\n", 842 | " special_grad[name] = special_grad.get(name, 0) + diff[name] / len_data\n", 843 | " return special_grad" 844 | ], 845 | "execution_count": 
null, 846 | "outputs": [] 847 | }, 848 | { 849 | "cell_type": "markdown", 850 | "metadata": { 851 | "id": "ahqE-Sf92TID" 852 | }, 853 | "source": [ 854 | "### Part 3: Outer loop gradients calculation" 855 | ] 856 | }, 857 | { 858 | "cell_type": "markdown", 859 | "metadata": { 860 | "id": "-wr0hSd02TIE" 861 | }, 862 | "source": [ 863 | "MAML \n", 864 | "(Simply calls PyTorch `backward`.)" 865 | ] 866 | }, 867 | { 868 | "cell_type": "code", 869 | "metadata": { 870 | "id": "_hBSQ02xzTXb" 871 | }, 872 | "source": [ 873 | "def outer_update_MAML(model, meta_batch_loss, grad_tensors):\n", 874 | " \"\"\" Simply backpropagate the meta-batch loss \"\"\"\n", 875 | " meta_batch_loss.backward()" 876 | ], 877 | "execution_count": null, 878 | "outputs": [] 879 | }, 880 | { 881 | "cell_type": "markdown", 882 | "metadata": { 883 | "id": "Q4zxf6yr2TIE" 884 | }, 885 | "source": [ 886 | "Alternatives" 887 | ] 888 | }, 889 | { 890 | "cell_type": "code", 891 | "metadata": { 892 | "id": "DEyCwYmI2bdC" 893 | }, 894 | "source": [ 895 | "def outer_update_alt(model, meta_batch_loss, grad_tensors):\n", 896 | " \"\"\" Replace the gradients\n", 897 | " with precalculated tensors \"\"\"\n", 898 | " for (name, params) in model.named_parameters():\n", 899 | " params.grad = grad_tensors[name]" 900 | ], 901 | "execution_count": null, 902 | "outputs": [] 903 | }, 904 | { 905 | "cell_type": "markdown", 906 | "metadata": { 907 | "id": "z1jck3KE2g1D" 908 | }, 909 | "source": [ 910 | "### Complete the algorithm\n", 911 | "Here we have wrapped the algorithm in `MetaAlgorithmGenerator`. \n", 912 | "You can build a modified algorithm by passing the alternative functions as arguments, for example:\n", 913 | "```python\n", 914 | "MyAlgorithm = MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n", 915 | "```\n", 916 | "By default, all three blocks use the `MAML` versions."
917 | ] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "metadata": { 922 | "id": "XosNxVMDxL6V" 923 | }, 924 | "source": [ 925 | "def MetaAlgorithmGenerator(\n", 926 | " inner_update = inner_update_MAML, \n", 927 | " collect_gradients = collect_gradients_MAML, \n", 928 | " outer_update = outer_update_MAML):\n", 929 | "\n", 930 | " global calculate_accuracy\n", 931 | "\n", 932 | " def MetaAlgorithm(\n", 933 | " model, optimizer, x, n_way, k_shot, q_query, loss_fn,\n", 934 | " inner_train_step=1, inner_lr=0.4, train=True): \n", 935 | " criterion = loss_fn\n", 936 | " task_loss, task_acc = [], []\n", 937 | " special_grad = OrderedDict() # Added for variants!\n", 938 | "\n", 939 | " for meta_batch in x:\n", 940 | " support_set = meta_batch[: n_way * k_shot] \n", 941 | " query_set = meta_batch[n_way * k_shot :] \n", 942 | " \n", 943 | " fast_weights = OrderedDict(model.named_parameters())\n", 944 | " \n", 945 | " ### ---------- INNER TRAIN LOOP ---------- ###\n", 946 | " for inner_step in range(inner_train_step): \n", 947 | " train_label = create_label(n_way, k_shot).to(device)\n", 948 | " logits = model.functional_forward(support_set, fast_weights)\n", 949 | " loss = criterion(logits, train_label)\n", 950 | "\n", 951 | " fast_weights = inner_update(fast_weights, loss, inner_lr)\n", 952 | "\n", 953 | " ### ---------- INNER VALID LOOP ---------- ###\n", 954 | " val_label = create_label(n_way, q_query).to(device)\n", 955 | " # FIXME: W for val?\n", 956 | " special_grad = collect_gradients(\n", 957 | " special_grad, fast_weights, model, len(x))\n", 958 | " \n", 959 | " # Collect gradients for outer loop\n", 960 | " logits = model.functional_forward(query_set, fast_weights) \n", 961 | " loss = criterion(logits, val_label)\n", 962 | " task_loss.append(loss)\n", 963 | " task_acc.append(calculate_accuracy(logits, val_label))\n", 964 | "\n", 965 | " # Update outer loop\n", 966 | " model.train()\n", 967 | " optimizer.zero_grad()\n", 968 | "\n", 969 | " meta_batch_loss = 
torch.stack(task_loss).mean()\n", 970 | " if train:\n", 971 | " # Notice the update part!\n", 972 | " outer_update(model, meta_batch_loss, special_grad)\n", 973 | " optimizer.step()\n", 974 | " task_acc = np.mean(task_acc)\n", 975 | " return meta_batch_loss, task_acc\n", 976 | " return MetaAlgorithm" 977 | ], 978 | "execution_count": null, 979 | "outputs": [] 980 | }, 981 | { 982 | "cell_type": "code", 983 | "metadata": { 984 | "id": "jEsPtV-GzbDv", 985 | "cellView": "form" 986 | }, 987 | "source": [ 988 | "#@title The answer is hidden here; please try to fill it in yourself first!\n", 989 | "Give_me_the_answer = True #@param {\"type\": \"boolean\"}\n", 990 | "\n", 991 | "def HiddenAnswer():\n", 992 | " MAML = MetaAlgorithmGenerator()\n", 993 | " FOMAML = MetaAlgorithmGenerator(inner_update=inner_update_alt1)\n", 994 | " ANIL = MetaAlgorithmGenerator(inner_update=inner_update_alt2)\n", 995 | " return MAML, FOMAML, ANIL" 996 | ], 997 | "execution_count": null, 998 | "outputs": [] 999 | }, 1000 | { 1001 | "cell_type": "code", 1002 | "metadata": { 1003 | "id": "2P__5N2Yz9O4" 1004 | }, 1005 | "source": [ 1006 | "# `HiddenAnswer` is defined in the previous (hidden) cell.\n", 1007 | "if Give_me_the_answer:\n", 1008 | " MAML, FOMAML, ANIL = HiddenAnswer()\n", 1009 | "else: \n", 1010 | " # TODO: Please fill in the function names \\\n", 1011 | " # as the function arguments to finish the algorithm.\n", 1012 | " MAML = MetaAlgorithmGenerator()\n", 1013 | " FOMAML = MetaAlgorithmGenerator()\n", 1014 | " ANIL = MetaAlgorithmGenerator()" 1015 | ], 1016 | "execution_count": null, 1017 | "outputs": [] 1018 | }, 1019 | { 1020 | "cell_type": "markdown", 1021 | "metadata": { 1022 | "id": "nBoRBhVlZAST" 1023 | }, 1024 | "source": [ 1025 | "## **Step 4: Initialization**\n", 1026 | "\n", 1027 | "With all the components defined, the following cells initialize the data loaders, model, and optimizer before training."
1028 | ] 1029 | }, 1030 | { 1031 | "cell_type": "markdown", 1032 | "metadata": { 1033 | "id": "Ip-i7aseftUF" 1034 | }, 1035 | "source": [ 1036 | "\n", 1037 | "### Hyperparameters \n", 1038 | "[Go back to top!](#top)" 1039 | ] 1040 | }, 1041 | { 1042 | "cell_type": "code", 1043 | "metadata": { 1044 | "id": "0wFHmVcBhE4M" 1045 | }, 1046 | "source": [ 1047 | "n_way = 5\n", 1048 | "k_shot = 1\n", 1049 | "q_query = 1\n", 1050 | "inner_train_step = 1\n", 1051 | "inner_lr = 0.4\n", 1052 | "meta_lr = 0.001\n", 1053 | "meta_batch_size = 32\n", 1054 | "max_epoch = 30\n", 1055 | "eval_batches = test_batches = 20\n", 1056 | "train_data_path = './Omniglot/images_background/'\n", 1057 | "test_data_path = './Omniglot/images_evaluation/' " 1058 | ], 1059 | "execution_count": null, 1060 | "outputs": [] 1061 | }, 1062 | { 1063 | "cell_type": "markdown", 1064 | "metadata": { 1065 | "id": "Uvzo7NVpfu5V" 1066 | }, 1067 | "source": [ 1068 | "### Dataloader initialization" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "code", 1073 | "metadata": { 1074 | "id": "3I13GJavhP0_" 1075 | }, 1076 | "source": [ 1077 | "def dataloader_init(datasets, num_workers=2):\n", 1078 | " train_set, val_set, test_set = datasets\n", 1079 | " train_loader = DataLoader(train_set,\n", 1080 | " # The \"batch_size\" here is not \\\n", 1081 | " # the meta batch size, but \\\n", 1082 | " # how many different \\\n", 1083 | " # characters in a task, \\\n", 1084 | " # i.e. 
the \"n_way\" in \\\n", 1085 | " # few-shot classification.\n", 1086 | " batch_size=n_way,\n", 1087 | " num_workers=num_workers,\n", 1088 | " shuffle=True,\n", 1089 | " drop_last=True)\n", 1090 | " val_loader = DataLoader(val_set,\n", 1091 | " batch_size=n_way,\n", 1092 | " num_workers=num_workers,\n", 1093 | " shuffle=True,\n", 1094 | " drop_last=True)\n", 1095 | " test_loader = DataLoader(test_set,\n", 1096 | " batch_size=n_way,\n", 1097 | " num_workers=num_workers,\n", 1098 | " shuffle=True,\n", 1099 | " drop_last=True)\n", 1100 | " train_iter = iter(train_loader)\n", 1101 | " val_iter = iter(val_loader)\n", 1102 | " test_iter = iter(test_loader)\n", 1103 | " return (train_loader, val_loader, test_loader), \\\n", 1104 | " (train_iter, val_iter, test_iter)\n", 1105 | "\n", 1106 | "train_set, val_set = torch.utils.data.random_split(\n", 1107 | " Omniglot(train_data_path, k_shot, q_query), [3200, 656])\n", 1108 | "test_set = Omniglot(test_data_path, k_shot, q_query)\n", 1109 | "\n", 1110 | "(train_loader, val_loader, test_loader), \\\n", 1111 | "(train_iter, val_iter, test_iter) = dataloader_init(\n", 1112 | " (train_set, val_set, test_set))" 1113 | ], 1114 | "execution_count": null, 1115 | "outputs": [] 1116 | }, 1117 | { 1118 | "cell_type": "markdown", 1119 | "metadata": { 1120 | "id": "KVund--bfw0e" 1121 | }, 1122 | "source": [ 1123 | "### Model & optimizer initialization" 1124 | ] 1125 | }, 1126 | { 1127 | "cell_type": "code", 1128 | "metadata": { 1129 | "id": "Kxug882ihF2B" 1130 | }, 1131 | "source": [ 1132 | "def model_init():\n", 1133 | " meta_model = Classifier(1, n_way).to(device)\n", 1134 | " optimizer = torch.optim.Adam(meta_model.parameters(), \n", 1135 | " lr=meta_lr)\n", 1136 | " loss_fn = nn.CrossEntropyLoss().to(device)\n", 1137 | " return meta_model, optimizer, loss_fn\n", 1138 | "\n", 1139 | "meta_model, optimizer, loss_fn = model_init()" 1140 | ], 1141 | "execution_count": null, 1142 | "outputs": [] 1143 | }, 1144 | { 1145 | "cell_type": 
"markdown", 1146 | "metadata": { 1147 | "id": "gj8cLRNLf2zg" 1148 | }, 1149 | "source": [ 1150 | "### Utility function to get a meta-batch" 1151 | ] 1152 | }, 1153 | { 1154 | "cell_type": "code", 1155 | "metadata": { 1156 | "id": "zrkCSsxOhC-N" 1157 | }, 1158 | "source": [ 1159 | "def get_meta_batch(meta_batch_size,\n", 1160 | " k_shot, q_query, \n", 1161 | " data_loader, iterator):\n", 1162 | " data = []\n", 1163 | " for _ in range(meta_batch_size):\n", 1164 | " try:\n", 1165 | " # a \"task_data\" tensor represents \\\n", 1166 | " # the data of one task, with shape \\\n", 1167 | " # [n_way, k_shot+q_query, 1, 28, 28]\n", 1168 | " task_data = next(iterator) \n", 1169 | " except StopIteration:\n", 1170 | " iterator = iter(data_loader)\n", 1171 | " task_data = next(iterator)\n", 1172 | " train_data = (task_data[:, :k_shot]\n", 1173 | " .reshape(-1, 1, 28, 28))\n", 1174 | " val_data = (task_data[:, k_shot:]\n", 1175 | " .reshape(-1, 1, 28, 28))\n", 1176 | " task_data = torch.cat(\n", 1177 | " (train_data, val_data), 0)\n", 1178 | " data.append(task_data)\n", 1179 | " return torch.stack(data).to(device), iterator" 1180 | ], 1181 | "execution_count": null, 1182 | "outputs": [] 1183 | }, 1184 | { 1185 | "cell_type": "markdown", 1186 | "metadata": { 1187 | "id": "O5JCtob4fyh_" 1188 | }, 1189 | "source": [ 1190 | "\n", 1191 | "### Choose the meta learning algorithm\n", 1192 | "[Go back to top!](#top)" 1193 | ] 1194 | }, 1195 | { 1196 | "cell_type": "code", 1197 | "metadata": { 1198 | "id": "3av6pAI7OxOP" 1199 | }, 1200 | "source": [ 1201 | "# You can change this to `FOMAML` or `ANIL`\n", 1202 | "MetaAlgorithm = MAML" 1203 | ], 1204 | "execution_count": null, 1205 | "outputs": [] 1206 | }, 1207 | { 1208 | "cell_type": "markdown", 1209 | "metadata": { 1210 | "id": "pWQczA3FwjEG" 1211 | }, 1212 | "source": [ 1213 | "\n", 1214 | "## **Step 5: Main program for training & testing**" 1215 | ] 1216 | }, 1217 | { 1218 | "cell_type": "markdown", 1219 | "metadata": { 1220 | 
"id": "8EirEnaof7ep" 1221 | }, 1222 | "source": [ 1223 | "### Start training!\n", 1224 | "\n", 1225 | "[Go back to top!](#top)" 1226 | ] 1227 | }, 1228 | { 1229 | "cell_type": "code", 1230 | "metadata": { 1231 | "id": "JQZjJrLAhBWw" 1232 | }, 1233 | "source": [ 1234 | "for epoch in range(max_epoch):\n", 1235 | " print(\"Epoch %d\" % (epoch + 1))\n", 1236 | " train_meta_loss = []\n", 1237 | " train_acc = []\n", 1238 | " # The \"step\" here is a meta-gradient update step\n", 1239 | " for step in tqdm(range(\n", 1240 | " len(train_loader) // meta_batch_size)): \n", 1241 | " x, train_iter = get_meta_batch(\n", 1242 | " meta_batch_size, k_shot, q_query, \n", 1243 | " train_loader, train_iter)\n", 1244 | " meta_loss, acc = MetaAlgorithm(\n", 1245 | " meta_model, optimizer, x, \n", 1246 | " n_way, k_shot, q_query, loss_fn)\n", 1247 | " train_meta_loss.append(meta_loss.item())\n", 1248 | " train_acc.append(acc)\n", 1249 | " print(\" Loss : \", \"%.3f\" % (np.mean(train_meta_loss)), end='\\t')\n", 1250 | " print(\" Accuracy: \", \"%.3f %%\" % (np.mean(train_acc) * 100))\n", 1251 | "\n", 1252 | " # See the validation accuracy after each epoch.\n", 1253 | " # You are welcome to implement early stopping.\n", 1254 | " val_acc = []\n", 1255 | " for eval_step in tqdm(range(\n", 1256 | " len(val_loader) // (eval_batches))):\n", 1257 | " x, val_iter = get_meta_batch(\n", 1258 | " eval_batches, k_shot, q_query, \n", 1259 | " val_loader, val_iter)\n", 1260 | " # Take three inner-loop steps during validation.\n", 1261 | " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n", 1262 | " n_way, k_shot, q_query, \n", 1263 | " loss_fn, \n", 1264 | " inner_train_step=3, \n", 1265 | " train=False) \n", 1266 | " val_acc.append(acc)\n", 1267 | " print(\" Validation accuracy: \", \"%.3f %%\" % (np.mean(val_acc) * 100))" 1268 | ], 1269 | "execution_count": null, 1270 | "outputs": [] 1271 | }, 1272 | { 1273 | "cell_type": "markdown", 1274 | "metadata": { 1275 | "id": "u5Ew8-POf9sw" 1276 | }, 1277 | 
"source": [ 1278 | "### Testing the result" 1279 | ] 1280 | }, 1281 | { 1282 | "cell_type": "code", 1283 | "metadata": { 1284 | "id": "CYN_zGB3g_5_" 1285 | }, 1286 | "source": [ 1287 | "test_acc = []\n", 1288 | "for test_step in tqdm(range(\n", 1289 | " len(test_loader) // (test_batches))):\n", 1290 | " x, test_iter = get_meta_batch(\n", 1291 | " test_batches, k_shot, q_query, \n", 1292 | " test_loader, test_iter)\n", 1293 | " # Take three inner-loop steps during testing\n", 1294 | " _, acc = MetaAlgorithm(meta_model, optimizer, x, \n", 1295 | " n_way, k_shot, q_query, loss_fn, \n", 1296 | " inner_train_step=3, train=False)\n", 1297 | " test_acc.append(acc)\n", 1298 | "print(\" Testing accuracy: \", \"%.3f %%\" % (np.mean(test_acc) * 100))" 1299 | ], 1300 | "execution_count": null, 1301 | "outputs": [] 1302 | }, 1303 | { 1304 | "cell_type": "markdown", 1305 | "metadata": { 1306 | "id": "rtD8X3RLf-6w" 1307 | }, 1308 | "source": [ 1309 | "## **References**\n", 1310 | "1. Chelsea Finn, Pieter Abbeel, & Sergey Levine. (2017). [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.](https://arxiv.org/abs/1703.03400)\n", 1311 | "1. Aniruddh Raghu, Maithra Raghu, Samy Bengio, & Oriol Vinyals. (2020). [Rapid Learning or Feature Reuse? 
Towards Understanding the Effectiveness of MAML.](https://arxiv.org/abs/1909.09157)" 1312 | ] 1313 | } 1314 | ] 1315 | } 1316 | -------------------------------------------------------------------------------- /HW2/HW02.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW2/HW02.pdf -------------------------------------------------------------------------------- /HW2/README2_1.md: -------------------------------------------------------------------------------- 1 | # Homework 2-1: Phoneme Classification 2 | 3 | Detail : [Link](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/hw/HW02/HW02.pdf) 4 | 5 | Code : [Link](homework2_1.ipynb) 6 | -------------------------------------------------------------------------------- /HW2/README2_2.md: -------------------------------------------------------------------------------- 1 | # Homework 2-2: Hessian Matrix 2 | 3 | Detail : [Link](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/hw/HW02/HW02.pdf) 4 | 5 | [code](homework2_2.ipynb) 6 | -------------------------------------------------------------------------------- /HW2/homework2_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Untitled0.ipynb", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyNh31QPg5onyqlA0B16rpPP", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": "text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "mEVFnEg4a6yE" 34 | }, 35 | "source": [ 36 | "# 
**Homework 2-2 Hessian Matrix**\n", 37 | "\n", 38 | "## Hessian Matrix\n", 39 | "Imagine we are training a neural network and want to find out whether the model is at a **local minima like** point, a **saddle point**, or **none of the above**. We can decide by computing the Hessian matrix.\n", 40 | "\n", 41 | "In practice, it is very hard to find a point where the gradient is exactly zero or where all eigenvalues of the Hessian matrix are greater than zero. In this homework, we make the following two assumptions:\n", 42 | "1. Treat a gradient norm less than 1e-3 as **gradient equals zero**.\n", 43 | "2. If the minimum ratio is greater than 0.5 and the gradient norm is less than 1e-3, we assume that the model is at a “local minima like” point.\n", 44 | "\n", 45 | "> The minimum ratio is defined as the proportion of positive eigenvalues of the Hessian." 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": { 51 | "id": "AfEv2ee-bATF" 52 | }, 53 | "source": [ 54 | "## IMPORTANT NOTICE\n", 55 | "In this homework, students with different student IDs will get different answers. Make sure to fill in your `student_id` in the following block correctly. Otherwise, your code may not run correctly and you will get a wrong answer." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "id": "LR4SxRxQaqVo" 62 | }, 63 | "source": [ 64 | "student_id = '40673034h' # fill with your student ID\n", 65 | "\n", 66 | "assert student_id != 'your_student_id', 'Please fill in your student_id before you start.'" 67 | ], 68 | "execution_count": 1, 69 | "outputs": [] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": { 74 | "id": "CS0o8eJsbEIX" 75 | }, 76 | "source": [ 77 | "## Calculate Hessian Matrix\n", 78 | "The Hessian computation is provided by the TA; you don't need to (and shouldn't) change the following code. 
The only thing you need to do is to run the following blocks and determine whether the model is at `local minima like`, `saddle point`, or `none of the above` according to the value of `gradient norm` and `minimum ratio`." 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "id": "cr3BIXy4bG0Y" 85 | }, 86 | "source": [ 87 | "### Install Package to Compute Hessian.\n", 88 | "\n", 89 | "The autograd-lib library is used to compute Hessian matrix. You can check the full document here https://github.com/cybertronai/autograd-lib." 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "metadata": { 95 | "colab": { 96 | "base_uri": "https://localhost:8080/" 97 | }, 98 | "id": "s5glDmIebEqw", 99 | "outputId": "1b536f61-3e6b-46bb-8802-c293c117025a" 100 | }, 101 | "source": [ 102 | "!pip install autograd-lib" 103 | ], 104 | "execution_count": 2, 105 | "outputs": [ 106 | { 107 | "output_type": "stream", 108 | "text": [ 109 | "Collecting autograd-lib\n", 110 | " Downloading https://files.pythonhosted.org/packages/b9/ac/a3927e1e2a886a12b914bce86965bec3b925ad14ffb696b2f84d9f8ee949/autograd_lib-0.0.7-py3-none-any.whl\n", 111 | "Requirement already satisfied: gin-config in /usr/local/lib/python3.7/dist-packages (from autograd-lib) (0.4.0)\n", 112 | "Collecting pytorch-lightning\n", 113 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/58/01/5df6324efdc3f79025ea7eaf19478936c401a16dae4fd3fbd29f7d426974/pytorch_lightning-1.2.6-py3-none-any.whl (829kB)\n", 114 | "\u001b[K |████████████████████████████████| 839kB 5.2MB/s \n", 115 | "\u001b[?25hRequirement already satisfied: seaborn in /usr/local/lib/python3.7/dist-packages (from autograd-lib) (0.11.1)\n", 116 | "Requirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning->autograd-lib) (1.19.5)\n", 117 | "Collecting torchmetrics>=0.2.0\n", 118 | "\u001b[?25l Downloading 
https://files.pythonhosted.org/packages/3a/42/d984612cabf005a265aa99c8d4ab2958e37b753aafb12f31c81df38751c8/torchmetrics-0.2.0-py3-none-any.whl (176kB)\n", 119 | "\u001b[K |████████████████████████████████| 184kB 8.0MB/s \n", 120 | "\u001b[?25hRequirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning->autograd-lib) (4.41.1)\n", 121 | "Collecting fsspec[http]>=0.8.1\n", 122 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/91/0d/a6bfee0ddf47b254286b9bd574e6f50978c69897647ae15b14230711806e/fsspec-0.8.7-py3-none-any.whl (103kB)\n", 123 | "\u001b[K |████████████████████████████████| 112kB 11.4MB/s \n", 124 | "\u001b[?25hRequirement already satisfied: tensorboard>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning->autograd-lib) (2.4.1)\n", 125 | "Collecting PyYAML!=5.4.*,>=5.1\n", 126 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", 127 | "\u001b[K |████████████████████████████████| 276kB 10.7MB/s \n", 128 | "\u001b[?25hRequirement already satisfied: torch>=1.4 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning->autograd-lib) (1.8.1+cu101)\n", 129 | "Collecting future>=0.17.1\n", 130 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)\n", 131 | "\u001b[K |████████████████████████████████| 829kB 12.7MB/s \n", 132 | "\u001b[?25hRequirement already satisfied: scipy>=1.0 in /usr/local/lib/python3.7/dist-packages (from seaborn->autograd-lib) (1.4.1)\n", 133 | "Requirement already satisfied: pandas>=0.23 in /usr/local/lib/python3.7/dist-packages (from seaborn->autograd-lib) (1.1.5)\n", 134 | "Requirement already satisfied: matplotlib>=2.2 in /usr/local/lib/python3.7/dist-packages (from seaborn->autograd-lib) (3.2.2)\n", 135 | "Requirement already 
satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from fsspec[http]>=0.8.1->pytorch-lightning->autograd-lib) (3.8.1)\n", 136 | "Collecting aiohttp; extra == \"http\"\n", 137 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/88/c0/5890b4c8b04a79b7360e8fe4490feb0bb3ab179743f199f0e6220cebd568/aiohttp-3.7.4.post0-cp37-cp37m-manylinux2014_x86_64.whl (1.3MB)\n", 138 | "\u001b[K |████████████████████████████████| 1.3MB 21.0MB/s \n", 139 | "\u001b[?25hRequirement already satisfied: requests; extra == \"http\" in /usr/local/lib/python3.7/dist-packages (from fsspec[http]>=0.8.1->pytorch-lightning->autograd-lib) (2.23.0)\n", 140 | "Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (3.12.4)\n", 141 | "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (1.32.0)\n", 142 | "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (1.15.0)\n", 143 | "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (1.8.0)\n", 144 | "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (0.36.2)\n", 145 | "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (0.12.0)\n", 146 | "Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (1.28.0)\n", 147 | "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from 
tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (3.3.4)\n", 148 | "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (0.4.3)\n", 149 | "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (1.0.1)\n", 150 | "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (54.2.0)\n", 151 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.4->pytorch-lightning->autograd-lib) (3.7.4.3)\n", 152 | "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.23->seaborn->autograd-lib) (2.8.1)\n", 153 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.23->seaborn->autograd-lib) (2018.9)\n", 154 | "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.2->seaborn->autograd-lib) (1.3.1)\n", 155 | "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.2->seaborn->autograd-lib) (2.4.7)\n", 156 | "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.2->seaborn->autograd-lib) (0.10.0)\n", 157 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->fsspec[http]>=0.8.1->pytorch-lightning->autograd-lib) (3.4.1)\n", 158 | "Collecting async-timeout<4.0,>=3.0\n", 159 | " Downloading https://files.pythonhosted.org/packages/e1/1e/5a4441be21b0726c4464f3f23c8b19628372f606755a9d2e46c187e65ec4/async_timeout-3.0.1-py3-none-any.whl\n", 160 | "Requirement already 
satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp; extra == \"http\"->fsspec[http]>=0.8.1->pytorch-lightning->autograd-lib) (20.3.0)\n", 161 | "Collecting multidict<7.0,>=4.5\n", 162 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/7c/a6/4123b8165acbe773d1a8dc8e3f0d1edea16d29f7de018eda769abb56bd30/multidict-5.1.0-cp37-cp37m-manylinux2014_x86_64.whl (142kB)\n", 163 | "\u001b[K |████████████████████████████████| 143kB 29.1MB/s \n", 164 | "\u001b[?25hRequirement already satisfied: chardet<5.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp; extra == \"http\"->fsspec[http]>=0.8.1->pytorch-lightning->autograd-lib) (3.0.4)\n", 165 | "Collecting yarl<2.0,>=1.0\n", 166 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/f1/62/046834c5fc998c88ab2ef722f5d42122230a632212c8afa76418324f53ff/yarl-1.6.3-cp37-cp37m-manylinux2014_x86_64.whl (294kB)\n", 167 | "\u001b[K |████████████████████████████████| 296kB 21.3MB/s \n", 168 | "\u001b[?25hRequirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests; extra == \"http\"->fsspec[http]>=0.8.1->pytorch-lightning->autograd-lib) (1.24.3)\n", 169 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests; extra == \"http\"->fsspec[http]>=0.8.1->pytorch-lightning->autograd-lib) (2020.12.5)\n", 170 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests; extra == \"http\"->fsspec[http]>=0.8.1->pytorch-lightning->autograd-lib) (2.10)\n", 171 | "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (4.2.1)\n", 172 | "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3.6\" in /usr/local/lib/python3.7/dist-packages (from 
google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (4.7.2)\n", 173 | "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (0.2.8)\n", 174 | "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (1.3.0)\n", 175 | "Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.7/dist-packages (from rsa<5,>=3.1.4; python_version >= \"3.6\"->google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (0.4.8)\n", 176 | "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning->autograd-lib) (3.1.0)\n", 177 | "Building wheels for collected packages: PyYAML, future\n", 178 | " Building wheel for PyYAML (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 179 | " Created wheel for PyYAML: filename=PyYAML-5.3.1-cp37-cp37m-linux_x86_64.whl size=44620 sha256=b80020fe2f1cc81e7a85d1a230840ae9e3051ccf18808e5d38b12578f4c8080e\n", 180 | " Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n", 181 | " Building wheel for future (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 182 | " Created wheel for future: filename=future-0.18.2-cp37-none-any.whl size=491058 sha256=eb42531dfa9bff327ad80fe7f584e2231e31f800574c57c0cfaa780ca568c4a0\n", 183 | " Stored in directory: /root/.cache/pip/wheels/8b/99/a0/81daf51dcd359a9377b110a8a886b3895921802d2fc1b2397e\n", 184 | "Successfully built PyYAML future\n", 185 | "Installing collected packages: torchmetrics, async-timeout, multidict, yarl, aiohttp, fsspec, PyYAML, future, pytorch-lightning, autograd-lib\n", 186 | " Found existing installation: PyYAML 3.13\n", 187 | " Uninstalling PyYAML-3.13:\n", 188 | " Successfully uninstalled PyYAML-3.13\n", 189 | " Found existing installation: future 0.16.0\n", 190 | " Uninstalling future-0.16.0:\n", 191 | " Successfully uninstalled future-0.16.0\n", 192 | "Successfully installed PyYAML-5.3.1 aiohttp-3.7.4.post0 async-timeout-3.0.1 autograd-lib-0.0.7 fsspec-0.8.7 future-0.18.2 multidict-5.1.0 pytorch-lightning-1.2.6 torchmetrics-0.2.0 yarl-1.6.3\n" 193 | ], 194 | "name": "stdout" 195 | } 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": { 201 | "id": "Ts3YO-DrbNhW" 202 | }, 203 | "source": [ 204 | "### Import Libraries" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "metadata": { 210 | "id": "kYWpBK-sbPX8" 211 | }, 212 | "source": [ 213 | "import numpy as np\n", 214 | "from math import pi\n", 215 | "from collections import defaultdict\n", 216 | "from autograd_lib import autograd_lib\n", 217 | "\n", 218 | "import torch\n", 219 | "import torch.nn as nn\n", 220 | "from torch.utils.data import DataLoader, Dataset\n", 221 | "\n", 222 | "import warnings\n", 223 | "warnings.filterwarnings(\"ignore\")" 224 | ], 225 | "execution_count": 3, 226 | "outputs": [] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": { 231 | "id": "qy_Hs63GbQ_a" 232 | }, 233 | "source": [ 234 | "### Define NN Model\n", 235 | "The NN model here is used to fit a single variable math function.\n", 236 | "$$f(x) = 
\\frac{\\sin(5\\pi x)}{5\\pi x}.$$" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "metadata": { 242 | "id": "cIX42-5mbTVl" 243 | }, 244 | "source": [ 245 | "class MathRegressor(nn.Module):\n", 246 | " def __init__(self, num_hidden=128):\n", 247 | " super().__init__()\n", 248 | " self.regressor = nn.Sequential(\n", 249 | " nn.Linear(1, num_hidden),\n", 250 | " nn.ReLU(),\n", 251 | " nn.Linear(num_hidden, 1)\n", 252 | " )\n", 253 | "\n", 254 | " def forward(self, x):\n", 255 | " x = self.regressor(x)\n", 256 | " return x" 257 | ], 258 | "execution_count": 4, 259 | "outputs": [] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": { 264 | "id": "StneuhWqbVGL" 265 | }, 266 | "source": [ 267 | "### Get Pretrained Checkpoints\n", 268 | "The pretrained checkpoints are provided by the TA. Each student will get a different checkpoint." 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "metadata": { 274 | "colab": { 275 | "base_uri": "https://localhost:8080/" 276 | }, 277 | "id": "ibMNFpMgbW2-", 278 | "outputId": "8a469925-0a79-4f51-dde1-437072c4de3f" 279 | }, 280 | "source": [ 281 | "!gdown --id 1ym6G7KKNkbsqSnMmnxdQKHO1JBoF0LPR" 282 | ], 283 | "execution_count": 5, 284 | "outputs": [ 285 | { 286 | "output_type": "stream", 287 | "text": [ 288 | "Downloading...\n", 289 | "From: https://drive.google.com/uc?id=1ym6G7KKNkbsqSnMmnxdQKHO1JBoF0LPR\n", 290 | "To: /content/data.pth\n", 291 | "\r 0% 0.00/34.5k [00:00nli', B, A) # do batch-wise outer product\n", 403 | "\n", 404 | " # full Hessian\n", 405 | " hess[layer] += torch.einsum('nli,nkj->likj', BA, BA) # do batch-wise outer product, then sum over the batch\n", 406 | "\n", 407 | "# function to compute the minimum ratio\n", 408 | "def compute_minimum_ratio(model, criterion, train, target):\n", 409 | " model.zero_grad()\n", 410 | " # compute Hessian matrix\n", 411 | " # save the activations of each layer\n", 412 | " with autograd_lib.module_hook(save_activations):\n", 413 | " output = model(train)\n", 414 | " 
loss = criterion(output, target)\n", 415 | "\n", 416 | " # compute Hessian from the activations stored in the previous step and the backpropagated values\n", 417 | " with autograd_lib.module_hook(compute_hess):\n", 418 | " autograd_lib.backward_hessian(output, loss='LeastSquares')\n", 419 | "\n", 420 | " layer_hess = list(hess.values())\n", 421 | " minimum_ratio = []\n", 422 | "\n", 423 | " # compute eigenvalues of the Hessian matrix\n", 424 | " for h in layer_hess:\n", 425 | " size = h.shape[0] * h.shape[1]\n", 426 | " h = h.reshape(size, size)\n", 427 | " h_eig = torch.linalg.eigvalsh(h) # eigenvalues of a real symmetric matrix (replaces the deprecated torch.symeig)\n", 428 | " num_greater = torch.sum(h_eig > 0).item()\n", 429 | " minimum_ratio.append(num_greater / len(h_eig))\n", 430 | "\n", 431 | " ratio_mean = np.mean(minimum_ratio) # compute mean of minimum ratio\n", 432 | "\n", 433 | " return ratio_mean" 434 | ], 435 | "execution_count": 8, 436 | "outputs": [] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": { 441 | "id": "78xJoeY_bmSr" 442 | }, 443 | "source": [ 444 | "### Mathematical Derivation\n", 445 | "\n", 446 | "Method used here: https://en.wikipedia.org/wiki/Gauss–Newton_algorithm\n", 447 | "\n", 448 | "> **Notations** \\\\\n", 449 | "> $\\mathbf{A}$: the input of the layer. \\\\\n", 450 | "> $\\mathbf{B}$: the backprop value. \\\\\n", 451 | "> $\\mathbf{Z}$: the output of the layer. \\\\\n", 452 | "> $L$: the total loss; mean squared error is used here, so $L=e^2$. 
\\\\\n", 453 | "> $w$: the weight value.\n", 454 | "\n", 455 | "Assume that the input dimension of the layer is $n$, and the output dimension of the layer is $m$.\n", 456 | "\n", 457 | "The derivative of the loss is\n", 458 | "\n", 459 | "\\begin{align*}\n", 460 | " \\left(\\frac{\\partial L}{\\partial w}\\right)_{kl} &= \\mathbf{B}_k \\mathbf{A}_l, \\quad k = 1, \\dots, m, \\; l = 1, \\dots, n,\n", 461 | "\\end{align*}\n", 462 | "\n", 463 | "which can be written as the outer product\n", 464 | "\n", 465 | "\\begin{align*}\n", 466 | " \\frac{\\partial L}{\\partial w} &= \\mathbf{B} \\times \\mathbf{A}.\n", 467 | "\\end{align*}\n", 468 | "\n", 469 | "The Hessian can be derived as\n", 470 | "\n", 471 | "\\begin{align*}\n", 472 | " \\mathbf{H}_{ij}&=\\frac{\\partial^2 L}{\\partial w_i \\partial w_j} \\\\\n", 473 | " &= \\frac{\\partial}{\\partial w_i}\\left(\\frac{\\partial L}{\\partial w_j}\\right) \\\\\n", 474 | " &= \\frac{\\partial}{\\partial w_i}\\left(\\frac{2e\\partial e}{\\partial w_j}\\right) \\\\\n", 475 | " &= 2\\frac{\\partial e}{\\partial w_i}\\frac{\\partial e}{\\partial w_j}+2e\\frac{\\partial^2 e}{\\partial w_j \\partial w_i}.\n", 476 | "\\end{align*}\n", 477 | "\n", 478 | "We neglect the second-order term because it is relatively small ($e$ is small):\n", 479 | "\n", 480 | "\\begin{align*}\n", 481 | " \\mathbf{H}_{ij}\n", 482 | " &\\propto \\frac{\\partial e}{\\partial w_i}\\frac{\\partial e}{\\partial w_j}.\n", 483 | "\\end{align*}\n", 484 | "\n", 485 | "Treating the error $e$ as a constant, $\\frac{\\partial L}{\\partial w} = 2e\\frac{\\partial e}{\\partial w}$ differs from $\\frac{\\partial e}{\\partial w}$ only by a scale factor, so\n", 486 | "\n", 487 | "\\begin{align*}\n", 488 | " \\mathbf{H}_{ij}\n", 489 | " &\\propto \\frac{\\partial L}{\\partial w_i}\\frac{\\partial L}{\\partial w_j},\n", 490 | "\\end{align*}\n", 491 | "\n", 492 | "and the full Hessian becomes\n", 493 | "\n", 494 | "\\begin{align*}\n", 495 | " \\mathbf{H} &\\propto (\\mathbf{B}\\times\\mathbf{A})\\times(\\mathbf{B}\\times\\mathbf{A}).\n", 496 | "\\end{align*}\n" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "metadata": { 502 | "id": "EDCQuAdlboHT" 503 | },
504 | "source": [ 505 | "# the main function to compute gradient norm and minimum ratio\n", 506 | "def main(model, train, target):\n", 507 | " criterion = nn.MSELoss()\n", 508 | "\n", 509 | " gradient_norm = compute_gradient_norm(model, criterion, train, target)\n", 510 | " minimum_ratio = compute_minimum_ratio(model, criterion, train, target)\n", 511 | "\n", 512 | " print('gradient norm: {}, minimum ratio: {}'.format(gradient_norm, minimum_ratio))" 513 | ], 514 | "execution_count": 9, 515 | "outputs": [] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "metadata": { 520 | "id": "75_j79LlbsTY" 521 | }, 522 | "source": [ 523 | "After running this block, you will get the values of `gradient norm` and `minimum ratio`. Determine whether the model is at a `local minima like` point, a `saddle point`, or `none of the above`, and then submit your choice to NTU COOL." 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "metadata": { 529 | "colab": { 530 | "base_uri": "https://localhost:8080/" 531 | }, 532 | "id": "fsGdnelqbqBu", 533 | "outputId": "5dfd75e3-69f0-4637-8f22-2683ccbaa46a" 534 | }, 535 | "source": [ 536 | "if __name__ == '__main__':\n", 537 | " # fix random seed\n", 538 | " torch.manual_seed(0)\n", 539 | "\n", 540 | " # reset compute dictionaries\n", 541 | " activations = defaultdict(int)\n", 542 | " hess = defaultdict(float)\n", 543 | "\n", 544 | " # compute gradient norm and minimum ratio\n", 545 | " main(model, train, target)" 546 | ], 547 | "execution_count": 10, 548 | "outputs": [ 549 | { 550 | "output_type": "stream", 551 | "text": [ 552 | "gradient norm: 0.000772382496506907, minimum ratio: 0.48046875\n" 553 | ], 554 | "name": "stdout" 555 | } 556 | ] 557 | }, 558 | { 559 | "cell_type": "markdown", 560 | "metadata": { 561 | "id": "wU-vs9zXbvpc" 562 | }, 563 | "source": [ 564 | "● gradient norm < 1e-3 and minimum ratio > 0.5 => local minima like\n", 565 | "\n", 566 | "● gradient norm < 1e-3 and minimum ratio <= 0.5 => saddle point\n", 567 | "\n", 568 | "● gradient norm >= 1e-3 =>
none of the above." 569 | ] 570 | } 571 | ] 572 | } -------------------------------------------------------------------------------- /HW3/HW03.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW3/HW03.pdf -------------------------------------------------------------------------------- /HW4/HW04.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW4/HW04.pdf -------------------------------------------------------------------------------- /HW4/hw4_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW4/hw4_report.pdf -------------------------------------------------------------------------------- /HW5/HW05.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW5/HW05.pdf -------------------------------------------------------------------------------- /HW6/HW06.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW6/HW06.pdf -------------------------------------------------------------------------------- /HW6/hw6_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW6/hw6_report.pdf -------------------------------------------------------------------------------- /HW7/HW07.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW7/HW07.pdf -------------------------------------------------------------------------------- /HW8/HW08.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW8/HW08.pdf -------------------------------------------------------------------------------- /HW8/hw8_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW8/hw8_report.pdf -------------------------------------------------------------------------------- /HW9/HW09.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HW9/HW09.pdf -------------------------------------------------------------------------------- /HomeworkParticipation.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/HomeworkParticipation.JPG -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![cover](https://github.com/Offliners/OFF/blob/main/cover.png) 2 | 3 | Course Syllabus: [Link](https://speech.ee.ntu.edu.tw/~hylee/ml/2021-spring.html) 4 | 5 | ## Outline 6 | |#|Homework|Slide|Code|Public Score|Private Score|Score| 7 | |-|-|-|-|-|-|-| 8 | |0|Colab Tutorial|[Link](HW0/Google_Colab_Tutorial.pdf)|[Colab Tutorial](HW0/Colab_Tutorial.ipynb)|x|x|x| 9 | |0|PyTorch
Tutorial|[Link](HW0/Pytorch_Tutorial_1.pdf) [Link](HW0/Pytorch_Tutorial_2.pdf)|[PyTorch Tutorial](HW0/Pytorch_Tutorial.ipynb)|x|x|x| 10 | |1|Regression|[Link](HW1/HW01.pdf)|[COVID-19 Cases Prediction](HW1/homework1.ipynb)|`0.86928`|`0.92417`|9| 11 | |2|Classification|[Link](HW2/HW02.pdf)|[Phoneme Classification](HW2/homework2_1.ipynb)|`0.75171`|`0.74957`|6| 12 | |2|Linear Algebra|[Link](HW2/HW02.pdf)|[Hessian Matrix](HW2/homework2_2.ipynb)|x|x|2| 13 | |3|CNN|[Link](HW3/HW03.pdf)|[Image Classification](HW3/homework3.ipynb)|`0.77956`|`0.77166`|8| 14 | |4|Self-Attention|[Link](HW4/HW04.pdf)|[Phoneme Classification](HW4/homework4.ipynb)|`0.95928`|`0.95833`|10+0.5([report](HW4/hw4_report.pdf))| 15 | |5|Transformer|[Link](HW5/HW05.pdf)|[Machine Translation](HW5/homework5.ipynb)|`26.44`|`25.54`|8| 16 | |6|GAN|[Link](HW6/HW06.pdf)|[Anime Face Generation](HW6/homework6.ipynb)|`0.710`, `7716.47`|`0.710`, `7716.47`|10+0.5([report](HW6/hw6_report.pdf))| 17 | |7|BERT|[Link](HW7/HW07.pdf)|[Question Answering](HW7/homework7.ipynb)|`0.81178`|`0.81833`|9| 18 | |8|Anomaly Detection|[Link](HW8/HW08.pdf)|[Anomaly Detection](HW8/homework8.ipynb)|`0.89049`|`0.88126`|10+0.5([report](HW8/hw8_report.pdf))| 19 | |9|Explainable AI|[Link](HW9/HW09.pdf)|[Explainable AI](HW9/homework9.ipynb)|x|x|8.4| 20 | |10|Attack|[Link](HW10/HW10.pdf)|[Adversarial Attack](HW10/homework10.ipynb)|`0.000`|`0.010`|10+0.5([report](HW10/hw10_report.pdf))| 21 | |11|Adaptation|[Link](HW11/HW11.pdf)|[Domain Adaptation](HW11/homework11.ipynb)|`0.81598`|`0.81690`|10+0.5([report](HW11/hw11_report.pdf))| 22 | |12|RL|[Link](HW12/HW12.pdf)|[Reinforcement Learning](HW12/homework12.ipynb)|`280`|`280`|10+0.5([report](HW12/hw12_report.pdf))| 23 | |13|Compression|[Link](HW13/HW13.pdf)|[Food Classification](HW13/homework13.ipynb)|`0.80167`|`0.81350`|9.5| 24 | |14|Life-Long Learning|[Link](HW14/HW14.pdf)|[Permuted MNIST](HW14/homework14.ipynb)|x|x|8.8| 25 | |15|Meta Learning|[Link](HW15/HW15.pdf)|[Few-shot image
classification](HW15/homework15.ipynb)|x|x|9.2| 26 | 27 | 45 | 46 | ## Useful Tips 47 | ### Prevent Google Colab from disconnecting (valid until 2021/06/25) 48 | Press `F12`, enter this code in the console, then press `Enter`: 49 | ```javascript 50 | function ClickConnect(){ 51 | console.log("Connect Clicked - Start"); 52 | document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click(); 53 | console.log("Connect Clicked - End"); 54 | }; 55 | setInterval(ClickConnect, 60000); // click the Connect button every 60 seconds 56 | ``` 57 | 58 | ### Automatically download the output file from Google Colab 59 | Insert a code cell at the bottom: 60 | ```python 61 | from google.colab import files 62 | files.download("output_file.csv") # replace "output_file.csv" with your output file name 63 | ``` 64 | 65 | ### Display GPU information on Google Colab 66 | Insert a code cell to check which GPU is assigned: 67 | ```shell 68 | !nvidia-smi 69 | ``` 70 | ### "Sorry, something went wrong. Reload?" when viewing *.ipynb on GitHub 71 | Paste the notebook URL into https://nbviewer.jupyter.org/ 72 | 73 | ## Student Source 74 |

75 | Student Source 76 |

77 | 78 | ## Homework Participation 79 |

80 | Class Participation 81 |

82 | 83 | ## My Performance 84 |

85 | My Performance 86 |

87 | 88 | ## Reference 89 | * TA's GitHub: https://github.com/ga642381/ML2021-Spring 90 | * PyTorch documentation: https://pytorch.org/docs/stable/index.html 91 | -------------------------------------------------------------------------------- /StudentSource.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/StudentSource.jpg -------------------------------------------------------------------------------- /cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/cover.png -------------------------------------------------------------------------------- /myPerformance.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/myPerformance.jpg -------------------------------------------------------------------------------- /slides/GuestLecture_QML.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/GuestLecture_QML.pdf -------------------------------------------------------------------------------- /slides/W14_PAC-introduction.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/W14_PAC-introduction.pdf -------------------------------------------------------------------------------- /slides/attack_v3.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/attack_v3.pdf -------------------------------------------------------------------------------- /slides/auto_v8.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/auto_v8.pdf -------------------------------------------------------------------------------- /slides/bert_v8.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/bert_v8.pdf -------------------------------------------------------------------------------- /slides/classification_v2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/classification_v2.pdf -------------------------------------------------------------------------------- /slides/cnn_v4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/cnn_v4.pdf -------------------------------------------------------------------------------- /slides/da_v6.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/da_v6.pdf -------------------------------------------------------------------------------- /slides/drl_v5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/drl_v5.pdf 
-------------------------------------------------------------------------------- /slides/gan_v10.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/gan_v10.pdf -------------------------------------------------------------------------------- /slides/introduction-2021-v6-Chinese.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/introduction-2021-v6-Chinese.pdf -------------------------------------------------------------------------------- /slides/introduction-2021-v6-English.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/introduction-2021-v6-English.pdf -------------------------------------------------------------------------------- /slides/life_v2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/life_v2.pdf -------------------------------------------------------------------------------- /slides/meta_v3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/meta_v3.pdf -------------------------------------------------------------------------------- /slides/normalization_v4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/normalization_v4.pdf 
-------------------------------------------------------------------------------- /slides/optimizer_v4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/optimizer_v4.pdf -------------------------------------------------------------------------------- /slides/overfit-v6.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/overfit-v6.pdf -------------------------------------------------------------------------------- /slides/regression (v16).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/regression (v16).pdf -------------------------------------------------------------------------------- /slides/self_v7.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/self_v7.pdf -------------------------------------------------------------------------------- /slides/seq2seq_v9.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/seq2seq_v9.pdf -------------------------------------------------------------------------------- /slides/small-gradient-v7.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/small-gradient-v7.pdf -------------------------------------------------------------------------------- /slides/tiny_v7.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/tiny_v7.pdf -------------------------------------------------------------------------------- /slides/xai_v4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Offliners/NTUML-2021Spring/21e38b137d76a6f6e57673c1871cd93e413bee02/slides/xai_v4.pdf --------------------------------------------------------------------------------
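The Hessian cell in `HW2/homework2_2.ipynb` depends on PyTorch and `autograd_lib`. As a rough, dependency-light sketch of the same Gauss-Newton idea, the NumPy code below builds the approximate Hessian of one linear layer from its inputs `A` and backprop values `B`, computes the fraction of positive eigenvalues, and applies the notebook's thresholds. The function names (`gauss_newton_hessian`, `positive_eigen_ratio`, `classify_critical_point`), the use of NumPy, the `2/N` scaling for a batch-averaged squared error, the `tol` parameter, and the reading of `B` as the derivative of each sample's error with respect to the layer output are assumptions made for illustration; this is not the notebook's code.

```python
import numpy as np


def gauss_newton_hessian(A, B):
    """Gauss-Newton approximation of the Hessian for one linear layer (hypothetical helper).

    A: (N, n) array, the layer input for each of N samples.
    B: (N, m) array, assumed to hold de_k/dz for each sample, where z is the
       m-dimensional layer output and e_k is that sample's scalar error.
    For L = (1/N) * sum_k e_k**2, dropping the second-order term gives
    H ~= (2/N) * sum_k g_k g_k^T with g_k = vec(B_k A_k^T), flattened row-major
    so that it matches an (m, n) weight matrix.
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    n_samples = A.shape[0]
    # Per-sample outer products B_k A_k^T: shape (N, m, n), flattened to (N, m*n).
    G = np.einsum("ki,kj->kij", B, A).reshape(n_samples, -1)
    # Sum of the outer products g_k g_k^T, scaled by 2/N.
    return (2.0 / n_samples) * (G.T @ G)


def positive_eigen_ratio(H, tol=0.0):
    """Fraction of eigenvalues of the symmetric matrix H that are greater than tol."""
    eigenvalues = np.linalg.eigvalsh(H)  # real eigenvalues of a symmetric matrix
    return float(np.sum(eigenvalues > tol)) / eigenvalues.size


def classify_critical_point(gradient_norm, min_ratio):
    """Apply the thresholds listed at the end of the notebook."""
    if gradient_norm >= 1e-3:
        return "none of the above"
    if min_ratio > 0.5:
        return "local minima like"
    return "saddle point"


if __name__ == "__main__":
    # Toy layer: one sample with input a = [1, 2] and a single output where de/dz = 3.
    H = gauss_newton_hessian([[1.0, 2.0]], [[3.0]])
    # H has eigenvalues 0 and 90; a small tol keeps rounding noise on the 0 out of the count.
    print(positive_eigen_ratio(H, tol=1e-9))  # 0.5
    # Values printed by the notebook's final cell.
    print(classify_critical_point(0.000772382496506907, 0.48046875))  # saddle point
```

Because `G.T @ G` is positive semidefinite, every eigenvalue of this approximation is non-negative up to rounding, so the ratio effectively counts the eigenvalues that are numerically nonzero. The assumed `tol` separates those from rounding noise, whereas the notebook compares against exactly 0.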