├── hw-02
│   ├── README.md
│   └── hw2_strong.ipynb
├── hw-03
│   ├── README.md
│   └── hw3_strong.ipynb
├── hw-04
│   ├── hw4_report.pdf
│   ├── README.md
│   └── hw4_strong.ipynb
├── hw-06
│   ├── hw6_report.pdf
│   ├── README.md
│   ├── hw6_StyleGAN.ipynb
│   └── hw6_WGAN.ipynb
├── hw-07
│   ├── hw7_report.pdf
│   └── README.md
├── hw-08
│   ├── hw8_report.pdf
│   ├── README.md
│   └── hw8_strong.ipynb
├── hw-11
│   ├── hw11_report.pdf
│   └── README.md
├── hw-12
│   ├── hw12_report.pdf
│   └── README.md
├── hw-13
│   ├── student_best.ckpt
│   ├── network_pruning_Q3.png
│   ├── README.md
│   └── network_pruning_Q3.ipynb
├── hw-10
│   └── README.md
├── README.md
├── .gitignore
└── hw-01
    ├── preprocessing.ipynb
    └── hw1_boss.ipynb

/hw-02/README.md:
--------------------------------------------------------------------------------
1 | # HW2
2 | ## How to Run
3 | Simply run the entire notebook in Jupyter Notebook.

--------------------------------------------------------------------------------
/hw-03/README.md:
--------------------------------------------------------------------------------
1 | # HW3
2 | ## How to Run
3 | Simply run the entire notebook in Jupyter Notebook.

--------------------------------------------------------------------------------
/hw-04/hw4_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yujunkuo/ML2022-Homework/HEAD/hw-04/hw4_report.pdf

--------------------------------------------------------------------------------
/hw-06/hw6_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yujunkuo/ML2022-Homework/HEAD/hw-06/hw6_report.pdf

--------------------------------------------------------------------------------
/hw-07/hw7_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yujunkuo/ML2022-Homework/HEAD/hw-07/hw7_report.pdf

--------------------------------------------------------------------------------
/hw-08/hw8_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yujunkuo/ML2022-Homework/HEAD/hw-08/hw8_report.pdf

--------------------------------------------------------------------------------
/hw-11/hw11_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yujunkuo/ML2022-Homework/HEAD/hw-11/hw11_report.pdf

--------------------------------------------------------------------------------
/hw-12/hw12_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yujunkuo/ML2022-Homework/HEAD/hw-12/hw12_report.pdf

--------------------------------------------------------------------------------
/hw-13/student_best.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yujunkuo/ML2022-Homework/HEAD/hw-13/student_best.ckpt

--------------------------------------------------------------------------------
/hw-13/network_pruning_Q3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yujunkuo/ML2022-Homework/HEAD/hw-13/network_pruning_Q3.png

--------------------------------------------------------------------------------
/hw-08/README.md:
--------------------------------------------------------------------------------
 1 | # HW8
 2 | 
 3 | ## How to Run
 4 | 
 5 | Simply run the entire ipynb file.
 6 | 
 7 | 
 8 | ## Results
 9 | 
10 | ```(AUC) Public Leaderboard: 0.78175```
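The (AUC) score above comes from ranking test images by an anomaly score. As a point of reference, a minimal sketch of the usual recipe — scoring each image by its autoencoder reconstruction error — looks roughly like this; the trained `model` (a fully-connected autoencoder as in hw8_strong.ipynb) and the `test_loader` below are assumptions for illustration, not the exact notebook code:

```python
import torch
import torch.nn as nn

# Assumes a trained FCN autoencoder `model` and a `test_loader` that yields
# image batches already normalized to [-1.0, 1.0] (both are assumptions here).
eval_loss = nn.MSELoss(reduction="none")

model.eval()
scores = []
with torch.no_grad():
    for imgs in test_loader:
        imgs = imgs.float().cuda()
        flat = imgs.view(imgs.shape[0], -1)       # the FCN model takes flattened images
        recon = model(flat)
        err = eval_loss(recon, flat).mean(dim=1)  # mean squared error per image
        scores.append(err.sqrt().cpu())           # RMSE used as the anomaly score

scores = torch.cat(scores)  # higher reconstruction error => more likely anomalous
```

The images are then ranked by this score, and the leaderboard computes ROC-AUC against the hidden anomaly labels.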
--------------------------------------------------------------------------------
/hw-12/README.md:
--------------------------------------------------------------------------------
 1 | # HW12
 2 | 
 3 | ## RL Algorithm
 4 | - DQN
 5 | 
 6 | ## How to Run
 7 | - Simply run the entire ipynb file.
 8 | 
 9 | ## Results
10 | ```(Reward) Public Leaderboard: 264```

--------------------------------------------------------------------------------
/hw-13/README.md:
--------------------------------------------------------------------------------
1 | # HW13
2 | 
3 | ## How to Run
4 | - Simply run the entire ipynb file.
5 | 
6 | ## Results
7 | ``` Parameter Number: Less than 100,000 ```
8 | 
9 | ```Public Leaderboard Accuracy: 0.79 ~ 0.81```

--------------------------------------------------------------------------------
/hw-10/README.md:
--------------------------------------------------------------------------------
 1 | # HW10
 2 | 
 3 | ## How to Run
 4 | 
 5 | - Simply run the entire ipynb file.
 6 | - The final leaderboard submission uses the dmifgsm files generated after the program finishes running.
 7 | 
 8 | 
 9 | ## Results
10 | 
11 | ```(Accuracy) Public Leaderboard: 0.04```

--------------------------------------------------------------------------------
/hw-04/README.md:
--------------------------------------------------------------------------------
1 | # HW4
2 | ## How to Run
3 | 1. Create the models/ and outputs/ folders under this directory to store the trained models and the output results.
4 | 2. Then run the entire notebook in Jupyter Notebook.
5 | > Conformer Reference: [Conformer](https://github.com/lucidrains/conformer)

--------------------------------------------------------------------------------
/hw-11/README.md:
--------------------------------------------------------------------------------
 1 | # HW11
 2 | 
 3 | ## Optimization Methods
 4 | - Adaptive Lambda
 5 | - Total Epochs = 1400
 6 | 
 7 | ## How to Run
 8 | - Simply run the entire ipynb file.
 9 | 
10 | ## Results
11 | ```(Accuracy) Public Leaderboard: 0.78328```

--------------------------------------------------------------------------------
/hw-07/README.md:
--------------------------------------------------------------------------------
 1 | # HW7
 2 | 
 3 | ## How to Run
 4 | 
 5 | Simply run the entire ipynb file.
 6 | 
 7 | 
 8 | ## Results
 9 | 
10 | (Accuracy) Public Leaderboard: 0.8459
11 | 
12 | 
13 | ## References
14 | 
15 | - [MacBERT](https://huggingface.co/luhua/chinese_pretrain_mrc_macbert_large)
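The MacBERT checkpoint referenced above is an extractive question-answering (Chinese MRC) model. A minimal sketch of loading and querying it with the Hugging Face `transformers` library is shown below — illustrative only; the question and context strings are placeholders, and this is not the code used in the homework notebook:

```python
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

model_name = "luhua/chinese_pretrain_mrc_macbert_large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# Placeholder question/context pair, just to show the call pattern.
question = "範例問題?"
context = "這是一段用來示範擷取式問答的中文段落。"
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Pick the most likely answer span from the start/end logits.
start = int(outputs.start_logits.argmax())
end = int(outputs.end_logits.argmax())
answer = tokenizer.decode(inputs["input_ids"][0][start : end + 1])
print(answer)
```

In the homework itself the checkpoint is fine-tuned and evaluated on the provided QA dataset; the snippet only shows how the pretrained model is loaded and queried.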
--------------------------------------------------------------------------------
/hw-06/README.md:
--------------------------------------------------------------------------------
 1 | # HW6
 2 | 
 3 | ## How to Run
 4 | 
 5 | 1. WGAN & WGAN-GP: Run with Jupyter Notebook
 6 | 
 7 | 2. StyleGAN:
 8 |     - Optional (For Multiple GPUs):
 9 |         - $ export CUDA_VISIBLE_DEVICES=1
10 |     - Install:
11 |         - $ pip install stylegan2_pytorch
12 |     - Train:
13 |         - $ stylegan2_pytorch --data ../faces --multi-gpus --image-size 64 --batch-size 1 --num-train-steps 40000 --gradient-accumulate-every 8
14 |     - Generate:
15 |         - $ for run in {1..1000}; do stylegan2_pytorch --generate --num_image_tiles 1; done
16 |     - Modify File Names:
17 |         - Run hw6_StyleGAN.ipynb to rename the generated files
18 | 
19 | 
20 | ## Results
21 | 
22 | - styleGAN Regular: 0.575 8567.92
23 | - *(Best) styleGAN MR: 0.749 8422.84
24 | - styleGAN EMA: 0.735 8465.07
25 | 
26 | ## References
27 | 
28 | - WGAN: [WGAN-Link](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan/wgan.py)
29 | - WGAN-GP: [WGAN-GP-Link](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py)
30 | - StyleGAN: [StyleGAN-Link](https://github.com/lucidrains/stylegan2-pytorch)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # ML2022-Homework
 2 | 
 3 | Welcome to the repository for my **Machine Learning** course assignments! 🎉
 4 | 
 5 | This repository contains my code implementations and reports for the assignments given during the Machine Learning course at National Taiwan University (NTU) in the spring semester of 2022. The assignments cover a wide range of topics in machine learning, supporting both a deeper understanding and the practical application of various algorithms and techniques.
 6 | 
 7 | ## Course Information
 8 | - Institution: **National Taiwan University, College of Electrical Engineering**
 9 | - Course Title: **Machine Learning (2022 Spring)**
10 | - Instructor: **Professor [Hung-Yi Lee](https://speech.ee.ntu.edu.tw/~hylee/index.php)**
11 | - Course Website: https://speech.ee.ntu.edu.tw/~hylee/ml/2022-spring.php
12 | 
13 | ## Homework Titles
14 | - HW-01: Regression
15 | - HW-02: Classification
16 | - HW-03: Convolutional Neural Networks (CNN)
17 | - HW-04: Self-Attention Mechanisms
18 | - HW-05: Transformer Architecture
19 | - HW-06: Generative Adversarial Networks (GAN)
20 | - HW-07: Bidirectional Encoder Representations from Transformers (BERT)
21 | - HW-08: Autoencoder
22 | - HW-09: Explainable AI
23 | - HW-10: Adversarial Attacks
24 | - HW-11: Domain Adaptation
25 | - HW-12: Reinforcement Learning (RL)
26 | - HW-13: Compression Techniques
27 | - HW-14: Life-long Learning **(⚠️ Not Yet Implemented)**
28 | - HW-15: Meta Learning **(⚠️ Not Yet Implemented)**
29 | 
30 | ## Code Template
31 | The code for each homework assignment is based on the template provided by the course instructor and teaching assistants. If there are any concerns regarding copyright or related issues, please contact me immediately at johnboy880313@gmail.com.
32 | 
33 | ## Contribution
34 | Contributions that improve the codebase or documentation, or that address any issues, are welcome. If you have ideas for enhancements or bug fixes, or would like to collaborate on the pending assignments (e.g., HW-14 and HW-15), please feel free to create a pull request or get in touch. Your feedback is valuable, and I appreciate any contributions or suggestions that enhance the code quality and understanding of the covered topics.
35 | 
36 | Let's learn and grow together in the fascinating field of Machine Learning!
🚀 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # MAC File 2 | .DS_Store 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /hw-06/hw6_StyleGAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "771c4130", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# import torch\n", 11 | "# from torchvision.utils import save_image\n", 12 | "# from stylegan2_pytorch import ModelLoader\n", 13 | "\n", 14 | "# loader = ModelLoader(\n", 15 | "# base_dir = './models/defult/model_40.pt', # path to where you invoked the command line tool\n", 16 | "# name = 'default' # the project name, defaults to 'default'\n", 17 | "# )\n", 18 | "# for i in range(1000):\n", 19 | "# noise = torch.randn(1, 512).cuda() # noise\n", 20 | "# styles = loader.noise_to_styles(noise, trunc_psi = 0.7) # pass through mapping network\n", 21 | "# images = loader.styles_to_images(styles) # call the generator on intermediate style vectors\n", 22 | "\n", 23 | "# save_image(images, f'./output/{i+1}.jpg') # save your images, or do whatever you desire" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "id": "8bf12cf9", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "import os\n", 34 | "files = os.listdir('results/default')\n", 35 | "print(len(files))\n", 36 | "files = [each for each in files if \"generated\" in each]\n", 37 | "print(len(files))" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "16e00965", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "files.sort()\n", 48 | "files = files[3:]\n", 49 | "files" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "a9a60405", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "ema_files = []\n", 60 | "for i, v in enumerate(files):\n", 61 | " if i % 3 == 0:\n", 62 | " ema_files.append(v)\n", 63 | "print(len(ema_files))\n", 64 | "ema_files" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "149122e1", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "mr_files = []\n", 75 | "for i, v in enumerate(files):\n", 76 | " if (i-1) % 3 == 0:\n", 77 | " mr_files.append(v)\n", 78 | "print(len(mr_files))\n", 79 | "mr_files" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "040e8618", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "regular_files = []\n", 90 | "for i, v in enumerate(files):\n", 91 | " if (i-2) % 3 == 0:\n", 92 | " regular_files.append(v)\n", 93 | "print(len(regular_files))\n", 94 | "regular_files" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "59580494", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "for i, name in enumerate(ema_files):\n", 105 | " os.system(f'mv ./results/default/{name} 
./results/ema_output/{i+1}.jpg')" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "40816998", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "for i, name in enumerate(mr_files):\n", 116 | " os.system(f'mv ./results/default/{name} ./results/mr_output/{i+1}.jpg')" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "ac513a5a", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "for i, name in enumerate(regular_files):\n", 127 | " os.system(f'mv ./results/default/{name} ./results/regular_output/{i+1}.jpg')" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "0c89e976", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "!ls" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "a1489ada", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "%cd results" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "018ef0d3", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "!ls" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "0dedb541", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "%cd ema_output\n", 168 | "!tar -zcf ../ema_output.tgz *.jpg\n", 169 | "%cd .." 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "264be038", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "%cd mr_output\n", 180 | "!tar -zcf ../mr_output.tgz *.jpg\n", 181 | "%cd .." 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "53cf46bb", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "%cd regular_output\n", 192 | "!tar -zcf ../regular_output.tgz *.jpg\n", 193 | "%cd .." 
194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "kuokuo_env", 200 | "language": "python", 201 | "name": "kuokuo_env" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.7.5" 214 | }, 215 | "toc": { 216 | "base_numbering": 1, 217 | "nav_menu": {}, 218 | "number_sections": true, 219 | "sideBar": true, 220 | "skip_h1_title": false, 221 | "title_cell": "Table of Contents", 222 | "title_sidebar": "Contents", 223 | "toc_cell": false, 224 | "toc_position": {}, 225 | "toc_section_display": true, 226 | "toc_window_display": false 227 | } 228 | }, 229 | "nbformat": 4, 230 | "nbformat_minor": 5 231 | } 232 | -------------------------------------------------------------------------------- /hw-01/preprocessing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Data Preprocessing" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## View Training Data" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import pandas as pd\n", 24 | "import numpy as np" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "df = pd.read_csv('./covid.train.csv')\n", 34 | "df.head()" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# Calculate Pearson Correlation\n", 44 | "cor = df.corr()\n", 45 | "# Correlation with output variable\n", 46 | "cor_target = abs(cor[\"tested_positive.4\"])\n", 47 | "# Selecting highly correlated features\n", 48 | "relevant_features = cor_target[cor_target > 0.8] # 0.5" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "df = df[list(relevant_features.index)]" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "pd.set_option('display.max_columns', None)\n", 67 | "df.head()" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "len(df.columns)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "df.shape" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def concat_df(df_old, df_new):\n", 95 | " new_cols = {x: y for x, y in zip(df_new.columns, df_old.columns)}\n", 96 | " df_out = df_old.append(df_new.rename(columns=new_cols))\n", 97 | " return df_out" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "## Split Training Data (兩天一組)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "df1 = df.iloc[:, :10]\n", 114 | "df2 = df.iloc[:, 5:15]\n", 115 | "df3 = df.iloc[:, 10:20]\n", 116 | 
"df4 = df.iloc[:, 15:25]" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "result1 = concat_df(df1, df2)\n", 126 | "result1 = concat_df(result1, df3)\n", 127 | "result1 = concat_df(result1, df4)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "print(result1.shape)\n", 137 | "result1.head()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "## Split Test Data (兩天一組)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "df_test = pd.read_csv('./covid.test.csv')\n", 154 | "df_test = df_test[list(relevant_features.index)[:-1]]\n", 155 | "df_test.head()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "df_test.shape" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "df_test_1 = df_test.iloc[:, :10]\n", 174 | "df_test_2 = df_test.iloc[:, 5:15]\n", 175 | "df_test_3 = df_test.iloc[:, 10:20]" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "result_test_1 = concat_df(df_test_1, df_test_2)\n", 185 | "result_test_1 = concat_df(result_test_1, df_test_3)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "print(result_test_1.shape)\n", 195 | "result_test_1.head()" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "## 合併產生新的訓練資料" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "result_train_test_1 = concat_df(result1, result_test_1)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "result_train_test_1.shape" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "result_train_test_1.tail()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "result_train_test_1.to_csv(\"./covid.train_test_1.csv\", index=False)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "## 產生新的測試資料" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "df_test = pd.read_csv('./covid.test.csv')\n", 255 | "df_test.head()" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "df_test = df_test[list(relevant_features.index)[:-1]]\n", 265 | "df_test.head()" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "df_test1 = df_test.iloc[:, -9:]\n", 275 | "df_test1.head()" 276 | ] 277 | }, 278 | { 279 | 
"cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "df_test1.to_csv(\"./covid.test_1.csv\", index=False)" 285 | ] 286 | } 287 | ], 288 | "metadata": { 289 | "kernelspec": { 290 | "display_name": "Python 3 (ipykernel)", 291 | "language": "python", 292 | "name": "python3" 293 | }, 294 | "language_info": { 295 | "codemirror_mode": { 296 | "name": "ipython", 297 | "version": 3 298 | }, 299 | "file_extension": ".py", 300 | "mimetype": "text/x-python", 301 | "name": "python", 302 | "nbconvert_exporter": "python", 303 | "pygments_lexer": "ipython3", 304 | "version": "3.7.5" 305 | }, 306 | "toc": { 307 | "base_numbering": 1, 308 | "nav_menu": {}, 309 | "number_sections": true, 310 | "sideBar": true, 311 | "skip_h1_title": false, 312 | "title_cell": "Table of Contents", 313 | "title_sidebar": "Contents", 314 | "toc_cell": false, 315 | "toc_position": {}, 316 | "toc_section_display": true, 317 | "toc_window_display": false 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 5 322 | } 323 | -------------------------------------------------------------------------------- /hw-13/network_pruning_Q3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Import some useful packages for this homework\n", 10 | "import os\n", 11 | "import random\n", 12 | "\n", 13 | "import numpy as np\n", 14 | "import pandas as pd\n", 15 | "\n", 16 | "import torch\n", 17 | "import torch.nn as nn\n", 18 | "import torch.nn.functional as F\n", 19 | "import torch.nn.utils.prune as prune\n", 20 | "import torchvision.transforms as transforms\n", 21 | "\n", 22 | "from PIL import Image\n", 23 | "from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset # \"ConcatDataset\" and \"Subset\" are possibly useful\n", 24 | "from torchvision.datasets import DatasetFolder, VisionDataset\n", 25 | "from torchsummary import summary\n", 26 | "from tqdm.auto import tqdm" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "import matplotlib.pyplot as plt" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 45 | "# define testing transforms\n", 46 | "test_tfm = transforms.Compose([\n", 47 | " # It is not encouraged to modify this part if you are using the provided teacher model. 
This transform is stardard and good enough for testing.\n", 48 | " transforms.Resize(256),\n", 49 | " transforms.CenterCrop(224),\n", 50 | " transforms.ToTensor(),\n", 51 | " normalize,\n", 52 | "])" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "class FoodDataset(Dataset):\n", 62 | " def __init__(self, path, tfm=test_tfm, files = None):\n", 63 | " super().__init__()\n", 64 | " self.path = path\n", 65 | " self.files = sorted([os.path.join(path,x) for x in os.listdir(path) if x.endswith(\".jpg\")])\n", 66 | " if files != None:\n", 67 | " self.files = files\n", 68 | " print(f\"One {path} sample\",self.files[0])\n", 69 | " self.transform = tfm\n", 70 | " \n", 71 | " def __len__(self):\n", 72 | " return len(self.files)\n", 73 | " \n", 74 | " def __getitem__(self,idx):\n", 75 | " fname = self.files[idx]\n", 76 | " im = Image.open(fname)\n", 77 | " im = self.transform(im)\n", 78 | " try:\n", 79 | " label = int(fname.split(\"/\")[-1].split(\"_\")[0])\n", 80 | " except:\n", 81 | " label = -1 # test has no label\n", 82 | " return im,label" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "# Form valid dataloaders\n", 92 | "valid_set = FoodDataset(os.path.join('./food11-hw13', \"validation\"), tfm=test_tfm)\n", 93 | "valid_loader = DataLoader(valid_set, batch_size=64, shuffle=False, num_workers=0, pin_memory=True)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 103 | "device = torch.device(\"cpu\")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "def evaluate(model):\n", 113 | " model.to(device)\n", 114 | " model.eval()\n", 115 | "\n", 116 | " valid_accs = []\n", 117 | " valid_lens = []\n", 118 | "\n", 119 | " for batch in tqdm(valid_loader):\n", 120 | " # A batch consists of image data and corresponding labels.\n", 121 | " imgs, labels = batch\n", 122 | " imgs = imgs.to(device)\n", 123 | " labels = labels.to(device)\n", 124 | "\n", 125 | " # We don't need gradient in validation.\n", 126 | " # Using torch.no_grad() accelerates the forward process.\n", 127 | " with torch.no_grad():\n", 128 | " logits = model(imgs) # MEDIUM BASELINE\n", 129 | "\n", 130 | " # Compute the accuracy for current batch.\n", 131 | " acc = (logits.argmax(dim=-1) == labels).float().sum()\n", 132 | "\n", 133 | " # Record the loss and accuracy.\n", 134 | " batch_len = len(imgs)\n", 135 | " valid_accs.append(acc)\n", 136 | " valid_lens.append(batch_len)\n", 137 | "\n", 138 | " # The average accuracy for entire validation set is the average of the recorded values.\n", 139 | " valid_acc = sum(valid_accs) / sum(valid_lens)\n", 140 | " return valid_acc.item()" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": { 146 | "id": "kWEnrhWatOQb" 147 | }, 148 | "source": [ 149 | "Let's say now you want to prune all the parameters named with `weight` in all the `nn.Conv2d` layers in the `model`, with pruning ratio **0.2**. Then please refer to the code below to achieve this." 
150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "valid_acc_list = []\n", 159 | "\n", 160 | "for ratio in np.arange(0, 1, 0.05):\n", 161 | " # Specify the pruning ratio\n", 162 | " ratio = round(ratio, 2)\n", 163 | " # Load model\n", 164 | " teacher_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False, num_classes=11)\n", 165 | " teacher_ckpt_path = os.path.join('./food11-hw13', \"resnet18_teacher.ckpt\")\n", 166 | " teacher_model.load_state_dict(torch.load(teacher_ckpt_path, map_location='cpu'))\n", 167 | " for name, module in teacher_model.named_modules():\n", 168 | " if isinstance(module, torch.nn.Conv2d): # if the nn.module is torch.nn.Conv2d\n", 169 | " prune.l1_unstructured(module, name='weight', amount=ratio) # use 'prune' method provided by 'torch.nn.utils.prune' to prune the weight parameters in the nn.Conv2d layers\n", 170 | " # Next, you just have to generize the above code to different ratio and test the accuracy on the validation set of food11-hw13.\n", 171 | " valid_acc = evaluate(teacher_model)\n", 172 | " valid_acc_list.append(valid_acc)\n", 173 | " print(valid_acc)\n" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "plt.figure(figsize=(12,6))\n", 183 | "plt.plot(np.arange(0, 1, 0.05), valid_acc_list, \"-o\")\n", 184 | "plt.grid(ls=\"--\")\n", 185 | "plt.xticks(np.arange(0, 1, 0.05))\n", 186 | "plt.title(\"Pruning Ratio vs. Model Accuracy\")\n", 187 | "plt.xlabel(\"Pruning Ratio\")\n", 188 | "plt.ylabel(\"Model Accuracy\")\n", 189 | "# plt.savefig(\"pruning.png\")\n", 190 | "plt.show()" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "valid_acc_list" 200 | ] 201 | } 202 | ], 203 | "metadata": { 204 | "colab": { 205 | "collapsed_sections": [], 206 | "name": "network pruning example for report Q3-1", 207 | "provenance": [] 208 | }, 209 | "kernelspec": { 210 | "display_name": "kuokuo_env", 211 | "language": "python", 212 | "name": "kuokuo_env" 213 | }, 214 | "language_info": { 215 | "codemirror_mode": { 216 | "name": "ipython", 217 | "version": 3 218 | }, 219 | "file_extension": ".py", 220 | "mimetype": "text/x-python", 221 | "name": "python", 222 | "nbconvert_exporter": "python", 223 | "pygments_lexer": "ipython3", 224 | "version": "3.7.5" 225 | }, 226 | "toc": { 227 | "base_numbering": 1, 228 | "nav_menu": {}, 229 | "number_sections": true, 230 | "sideBar": true, 231 | "skip_h1_title": false, 232 | "title_cell": "Table of Contents", 233 | "title_sidebar": "Contents", 234 | "toc_cell": false, 235 | "toc_position": {}, 236 | "toc_section_display": true, 237 | "toc_window_display": false 238 | } 239 | }, 240 | "nbformat": 4, 241 | "nbformat_minor": 1 242 | } 243 | -------------------------------------------------------------------------------- /hw-01/hw1_boss.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "guE34D3Fj2R9" 7 | }, 8 | "source": [ 9 | "# **Homework 1: COVID-19 Cases Prediction (Regression) -1**" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "V57zhcTp1Xxb" 16 | }, 17 | "source": [ 18 | "Objectives:\n", 19 | "* Solve a regression problem with deep neural networks (DNN).\n", 20 | "* 
Understand basic DNN training tips.\n", 21 | "* Familiarize yourself with PyTorch.\n", 22 | "\n", 23 | "If you have any questions, please contact the TAs via TA hours, NTU COOL, or email to mlta-2022-spring@googlegroups.com" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "Tm2aXcb-j9Fc" 30 | }, 31 | "source": [ 32 | "# Download data\n", 33 | "If the Google Drive links below do not work, you can download data from [Kaggle](https://www.kaggle.com/c/ml2022spring-hw1/data), and upload data manually to the workspace." 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": { 40 | "colab": { 41 | "base_uri": "https://localhost:8080/" 42 | }, 43 | "id": "YPmfl-awlKZA", 44 | "outputId": "b97be226-663d-4c3a-8833-950e8f79e6ab" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "!gdown --id '1kLSW_-cW2Huj7bh84YTdimGBOJaODiOS' --output covid.train.csv\n", 49 | "!gdown --id '1iiI5qROrAhZn-o4FPqsE97bMzDEFvIdg' --output covid.test.csv" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "igqIMEgu64-F" 56 | }, 57 | "source": [ 58 | "# Import packages" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "id": "xybQNYCXYu13" 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "# Numerical Operations\n", 70 | "import math\n", 71 | "import numpy as np\n", 72 | "\n", 73 | "# Reading/Writing Data\n", 74 | "import pandas as pd\n", 75 | "import os\n", 76 | "import csv\n", 77 | "\n", 78 | "# For Progress Bar\n", 79 | "from tqdm import tqdm\n", 80 | "\n", 81 | "# Pytorch\n", 82 | "import torch \n", 83 | "import torch.nn as nn\n", 84 | "from torch.utils.data import Dataset, DataLoader, random_split\n", 85 | "\n", 86 | "# For plotting learning curve\n", 87 | "from torch.utils.tensorboard import SummaryWriter" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": { 93 | "id": "fTAVqRfc2KK3" 94 | }, 95 | "source": [ 96 | "# Some Utility Functions\n", 97 | "\n", 98 | "You do not need to modify this part." 
99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "id": "RbrcpfYN2I-H" 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "def same_seed(seed): \n", 110 | " '''Fixes random number generator seeds for reproducibility.'''\n", 111 | " torch.backends.cudnn.deterministic = True\n", 112 | " torch.backends.cudnn.benchmark = False\n", 113 | " np.random.seed(seed)\n", 114 | " torch.manual_seed(seed)\n", 115 | " if torch.cuda.is_available():\n", 116 | " torch.cuda.manual_seed_all(seed)\n", 117 | "\n", 118 | "def train_valid_split(data_set, valid_ratio, seed):\n", 119 | " '''Split provided training data into training set and validation set'''\n", 120 | " valid_set_size = int(valid_ratio * len(data_set)) \n", 121 | " train_set_size = len(data_set) - valid_set_size\n", 122 | " train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))\n", 123 | " return np.array(train_set), np.array(valid_set)\n", 124 | "\n", 125 | "def predict(test_loader, model, device):\n", 126 | " model.eval() # Set your model to evaluation mode.\n", 127 | " preds = []\n", 128 | " for x in tqdm(test_loader):\n", 129 | " x = x.to(device) \n", 130 | " with torch.no_grad(): \n", 131 | " pred = model(x) \n", 132 | " preds.append(pred.detach().cpu()) \n", 133 | " preds = torch.cat(preds, dim=0).numpy() \n", 134 | " return preds" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": { 140 | "id": "IqO3lTm78nNO" 141 | }, 142 | "source": [ 143 | "# Dataset" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": { 150 | "id": "-mjaJM0wprMs" 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "class COVID19Dataset(Dataset):\n", 155 | " '''\n", 156 | " x: Features.\n", 157 | " y: Targets, if none, do prediction.\n", 158 | " '''\n", 159 | " def __init__(self, x, y=None):\n", 160 | " if y is None:\n", 161 | " self.y = y\n", 162 | " else:\n", 163 | " self.y = torch.FloatTensor(y)\n", 164 | " self.x = torch.FloatTensor(x)\n", 165 | " \n", 166 | " def __getitem__(self, idx):\n", 167 | " if self.y is None:\n", 168 | " return self.x[idx]\n", 169 | " else:\n", 170 | " return self.x[idx], self.y[idx]\n", 171 | "\n", 172 | " def __len__(self):\n", 173 | " return len(self.x)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": { 179 | "id": "m73ooU75CL_j" 180 | }, 181 | "source": [ 182 | "# Neural Network Model\n", 183 | "Try out different model architectures by modifying the class below." 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "class My_Model(nn.Module):\n", 193 | " def __init__(self, input_dim):\n", 194 | " super(My_Model, self).__init__()\n", 195 | " # TODO: modify model's structure, be aware of dimensions. \n", 196 | " self.layers = nn.Sequential(\n", 197 | " nn.Linear(input_dim, 512),\n", 198 | " nn.Dropout(0.4),\n", 199 | " nn.LeakyReLU(),\n", 200 | " nn.Linear(512, 1),\n", 201 | " )\n", 202 | "\n", 203 | " def forward(self, x):\n", 204 | " x = self.layers(x)\n", 205 | " x = x.squeeze(1) # (B, 1) -> (B)\n", 206 | " return x" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": { 212 | "id": "x5-LKF6R8xeq" 213 | }, 214 | "source": [ 215 | "# Feature Selection\n", 216 | "Choose features you deem useful by modifying the function below." 
217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": { 223 | "id": "0FEnKRaIIeKp" 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "def select_feat(train_data, valid_data, test_data, select_all=True, feature_selection_idx_list=None):\n", 228 | " '''Selects useful features to perform regression'''\n", 229 | " y_train, y_valid = train_data[:,-1], valid_data[:,-1]\n", 230 | " raw_x_train, raw_x_valid, raw_x_test = train_data[:,:-1], valid_data[:,:-1], test_data\n", 231 | "\n", 232 | " if not select_all and feature_selection_idx_list is not None:\n", 233 | " feat_idx = feature_selection_idx_list\n", 234 | " else:\n", 235 | " feat_idx = list(range(raw_x_train.shape[1]))\n", 236 | " \n", 237 | " return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], raw_x_test[:,feat_idx], y_train, y_valid" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": { 243 | "id": "kADIPNQ2Ih5X" 244 | }, 245 | "source": [ 246 | "# Training Loop" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": { 253 | "id": "k4Rq8_TztAhq" 254 | }, 255 | "outputs": [], 256 | "source": [ 257 | "def trainer(train_loader, valid_loader, model, config, device):\n", 258 | "\n", 259 | " criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.\n", 260 | "\n", 261 | " # Define your optimization algorithm. \n", 262 | " # TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.\n", 263 | " # TODO: L2 regularization (optimizer(weight decay...) or implement by your self).\n", 264 | " optimizer = torch.optim.AdamW(model.parameters(), weight_decay=5e-5) # 5e-5\n", 265 | "# optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.5) \n", 266 | "\n", 267 | " writer = SummaryWriter() # Writer of tensoboard.\n", 268 | "\n", 269 | " if not os.path.isdir('./models'):\n", 270 | " os.mkdir('./models') # Create directory of saving models.\n", 271 | "\n", 272 | " n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0\n", 273 | "\n", 274 | " for epoch in range(n_epochs):\n", 275 | " model.train() # Set your model to train mode.\n", 276 | " loss_record = []\n", 277 | "\n", 278 | " # tqdm is a package to visualize your training progress.\n", 279 | " train_pbar = tqdm(train_loader, position=0, leave=True)\n", 280 | "\n", 281 | " for x, y in train_pbar:\n", 282 | " optimizer.zero_grad() # Set gradient to zero.\n", 283 | " x, y = x.to(device), y.to(device) # Move your data to device. 
\n", 284 | " pred = model(x) \n", 285 | " loss = criterion(pred, y) # RMSE\n", 286 | " loss.backward() # Compute gradient(backpropagation).\n", 287 | " optimizer.step() # Update parameters.\n", 288 | " step += 1\n", 289 | " loss_record.append(loss.detach().item())\n", 290 | " \n", 291 | " # Display current epoch number and loss on tqdm progress bar.\n", 292 | " train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')\n", 293 | " train_pbar.set_postfix({'loss': loss.detach().item()})\n", 294 | "\n", 295 | " mean_train_loss = sum(loss_record)/len(loss_record)\n", 296 | " writer.add_scalar('Loss/train', mean_train_loss, step)\n", 297 | "\n", 298 | " model.eval() # Set your model to evaluation mode.\n", 299 | " loss_record = []\n", 300 | " for x, y in valid_loader:\n", 301 | " x, y = x.to(device), y.to(device)\n", 302 | " with torch.no_grad():\n", 303 | " pred = model(x)\n", 304 | " loss = criterion(pred, y)\n", 305 | "\n", 306 | " loss_record.append(loss.item())\n", 307 | " \n", 308 | " mean_valid_loss = sum(loss_record)/len(loss_record)\n", 309 | " print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')\n", 310 | " writer.add_scalar('Loss/valid', mean_valid_loss, step)\n", 311 | "\n", 312 | " if mean_valid_loss < best_loss:\n", 313 | " best_loss = mean_valid_loss\n", 314 | " torch.save(model.state_dict(), config['save_path']) # Save your best model\n", 315 | " print('Saving model with loss {:.3f}...'.format(best_loss))\n", 316 | " early_stop_count = 0\n", 317 | " else: \n", 318 | " early_stop_count += 1\n", 319 | "\n", 320 | " if early_stop_count >= config['early_stop']:\n", 321 | " print('\\nModel is not improving, so we halt the training session.')\n", 322 | " print(f'Best Loss: {best_loss}')\n", 323 | " return" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": { 329 | "id": "0pgkOh2e9UjE" 330 | }, 331 | "source": [ 332 | "# Configurations\n", 333 | "`config` contains hyper-parameters for training and the path to save your model." 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": { 340 | "colab": { 341 | "base_uri": "https://localhost:8080/" 342 | }, 343 | "id": "QoWPUahCtoT6", 344 | "outputId": "1f6cecc7-9386-470e-8159-bda5242ebefb" 345 | }, 346 | "outputs": [], 347 | "source": [ 348 | "# device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 349 | "torch.cuda.set_device(1) \n", 350 | "device = torch.device(\"cuda\", 1)\n", 351 | "print(device)\n", 352 | "\n", 353 | "config = {\n", 354 | " 'seed': 5201314, # 5201314 Your seed number, you can pick your lucky number. :)\n", 355 | " 'select_all': True, # True Whether to use all features.\n", 356 | " 'valid_ratio': 0.3, # 0.3 # validation_size = train_size * valid_ratio\n", 357 | " 'n_epochs': 5000, # 5000 Number of epochs. \n", 358 | " 'batch_size': 1024, # 1024\n", 359 | " 'learning_rate': 2e-4, # 2e-4 \n", 360 | " 'early_stop': 300, # 300 If model has not improved for this many consecutive epochs, stop training. \n", 361 | " 'save_path': './models/model_train_test_73_1.ckpt' # Your model will be saved here.\n", 362 | "}" 363 | ] 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "metadata": { 368 | "id": "lrS-aJJh9XkW" 369 | }, 370 | "source": [ 371 | "# Dataloader\n", 372 | "Read data from files and set up training, validation, and testing sets. You do not need to modify this part." 
373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": { 379 | "colab": { 380 | "base_uri": "https://localhost:8080/" 381 | }, 382 | "id": "2jc7ZfDot2t9", 383 | "outputId": "ea3c28ed-a8d6-475e-e7d6-2694e825dad8" 384 | }, 385 | "outputs": [], 386 | "source": [ 387 | "# Set seed for reproducibility\n", 388 | "same_seed(config['seed'])\n", 389 | "\n", 390 | "\n", 391 | "# train_data size: 2699 x 118 (id + 37 states + 16 features x 5 days) \n", 392 | "# test_data size: 1078 x 117 (without last day's positive rate)\n", 393 | "train_data, test_data = pd.read_csv('./covid.train_test_1.csv').values, pd.read_csv('./covid.test_1.csv').values\n", 394 | "train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])\n", 395 | "\n", 396 | "# Print out the data size before feature selection.\n", 397 | "print(f\"\"\" -- Before feature selection --\n", 398 | "train_data size: {train_data.shape} \n", 399 | "valid_data size: {valid_data.shape} \n", 400 | "test_data size: {test_data.shape}\"\"\")\n", 401 | "\n", 402 | "# Select features\n", 403 | "x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'])\n", 404 | "\n", 405 | "# Print out the data size after feature selection.\n", 406 | "print(f\"\"\" -- After feature selection --\n", 407 | "train_data size: {x_train.shape} \n", 408 | "valid_data size: {x_valid.shape} \n", 409 | "test_data size: {x_test.shape}\"\"\")\n", 410 | "\n", 411 | "# Print out the number of features.\n", 412 | "print(f'number of features: {x_train.shape[1]}')\n", 413 | "\n", 414 | "train_dataset = COVID19Dataset(x_train, y_train)\n", 415 | "valid_dataset = COVID19Dataset(x_valid, y_valid)\n", 416 | "test_dataset = COVID19Dataset(x_test)\n", 417 | " \n", 418 | "# Pytorch data loader loads pytorch dataset into batches.\n", 419 | "train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)\n", 420 | "valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)\n", 421 | "test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "# print(relevant_features)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "metadata": { 437 | "id": "1my3axtUH3tU" 438 | }, 439 | "outputs": [], 440 | "source": [ 441 | "# train_data size: (2160, 118) \n", 442 | "# valid_data size: (539, 118) \n", 443 | "# test_data size: (1078, 117)\n", 444 | "# number of features: 117" 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": { 450 | "id": "0OBYgjCA-YwD" 451 | }, 452 | "source": [ 453 | "# Start training!" 
454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "metadata": { 460 | "colab": { 461 | "base_uri": "https://localhost:8080/" 462 | }, 463 | "id": "YdttVRkAfu2t", 464 | "outputId": "cf54e2b2-c13b-43ca-e46b-e75d7d164812" 465 | }, 466 | "outputs": [], 467 | "source": [ 468 | "model = My_Model(input_dim=x_train.shape[1]).to(device) # put your model and data on the same computation device.\n", 469 | "trainer(train_loader, valid_loader, model, config, device)" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": { 476 | "id": "KBcodjJjODyY" 477 | }, 478 | "outputs": [], 479 | "source": [ 480 | "# Best Valid loss without FS: 2.4238\n", 481 | "# Best Valid loss with FS: 1.1371\n", 482 | "# Best Valid loss with Adam: 1.0121\n", 483 | "# Best Valid loss with AdamW: 0.9102\n", 484 | "# Current Best: 0.9029\n", 485 | "\n", 486 | "# 0.8854" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": { 492 | "id": "Ik09KPqU-di-" 493 | }, 494 | "source": [ 495 | "# Plot learning curves with `tensorboard` (optional)\n", 496 | "\n", 497 | "`tensorboard` is a tool that allows you to visualize your training progress.\n", 498 | "\n", 499 | "If this block does not display your learning curve, please wait for few minutes, and re-run this block. It might take some time to load your logging information. " 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "metadata": { 506 | "id": "loA4nKmLGQ-n" 507 | }, 508 | "outputs": [], 509 | "source": [ 510 | "%reload_ext tensorboard\n", 511 | "%tensorboard --logdir=./runs/" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": { 517 | "id": "yhAHGqC9-woK" 518 | }, 519 | "source": [ 520 | "# Testing\n", 521 | "The predictions of your model on testing set will be stored at `pred.csv`." 
522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": { 528 | "id": "Q5eVdpbvAlAe" 529 | }, 530 | "outputs": [], 531 | "source": [ 532 | "def save_pred(preds, file):\n", 533 | " ''' Save predictions to specified file '''\n", 534 | " with open(file, 'w') as fp:\n", 535 | " writer = csv.writer(fp)\n", 536 | " writer.writerow(['id', 'tested_positive'])\n", 537 | " for i, p in enumerate(preds):\n", 538 | " writer.writerow([i, p])\n", 539 | "\n", 540 | "model = My_Model(input_dim=x_train.shape[1]).to(device)\n", 541 | "model.load_state_dict(torch.load(config['save_path']))\n", 542 | "preds = predict(test_loader, model, device) \n", 543 | "save_pred(preds, 'pred.csv') " 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "metadata": {}, 550 | "outputs": [], 551 | "source": [ 552 | "preds" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "metadata": {}, 559 | "outputs": [], 560 | "source": [ 561 | "# np.save('pred_arr_train_test_1', preds)" 562 | ] 563 | }, 564 | { 565 | "cell_type": "markdown", 566 | "metadata": { 567 | "id": "IJ_k5rY0GvSV" 568 | }, 569 | "source": [ 570 | "# Reference\n", 571 | "This notebook uses code written by Heng-Jui Chang @ NTUEE (https://github.com/ga642381/ML2021-Spring/blob/main/HW01/HW01.ipynb)" 572 | ] 573 | } 574 | ], 575 | "metadata": { 576 | "accelerator": "GPU", 577 | "colab": { 578 | "collapsed_sections": [], 579 | "name": "ML2022Spring_HW1.ipynb", 580 | "provenance": [] 581 | }, 582 | "kernelspec": { 583 | "display_name": "Python 3 (ipykernel)", 584 | "language": "python", 585 | "name": "python3" 586 | }, 587 | "language_info": { 588 | "codemirror_mode": { 589 | "name": "ipython", 590 | "version": 3 591 | }, 592 | "file_extension": ".py", 593 | "mimetype": "text/x-python", 594 | "name": "python", 595 | "nbconvert_exporter": "python", 596 | "pygments_lexer": "ipython3", 597 | "version": "3.7.5" 598 | }, 599 | "toc": { 600 | "base_numbering": 1, 601 | "nav_menu": {}, 602 | "number_sections": true, 603 | "sideBar": true, 604 | "skip_h1_title": false, 605 | "title_cell": "Table of Contents", 606 | "title_sidebar": "Contents", 607 | "toc_cell": false, 608 | "toc_position": {}, 609 | "toc_section_display": true, 610 | "toc_window_display": true 611 | } 612 | }, 613 | "nbformat": 4, 614 | "nbformat_minor": 1 615 | } 616 | -------------------------------------------------------------------------------- /hw-08/hw8_strong.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "YiVfKn-6tXz8" 7 | }, 8 | "source": [ 9 | "# **Homework 8 - Anomaly Detection**\n", 10 | "\n", 11 | "If there are any questions, please contact mlta-2022spring-ta@googlegroups.com\n", 12 | "\n", 13 | "Slide: [Link]() Kaggle: [Link](https://www.kaggle.com/c/ml2022spring-hw8)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "bDk9r2YOcDc9" 20 | }, 21 | "source": [ 22 | "# Set up the environment\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "!nvidia-smi" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": { 37 | "id": "Oi12tJMYWi0Q" 38 | }, 39 | "source": [ 40 | "## Package installation" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "colab": { 48 | "base_uri": 
"https://localhost:8080/" 49 | }, 50 | "id": "7LexxyPWWjJB", 51 | "outputId": "3a733a84-fca3-4e7c-fb9b-bfd5f890114d" 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# Training progress bar\n", 56 | "%pip install -q qqdm" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": { 62 | "id": "DCgNXSsEWuY7" 63 | }, 64 | "source": [ 65 | "## Downloading data" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": { 72 | "colab": { 73 | "base_uri": "https://localhost:8080/" 74 | }, 75 | "id": "SCLJtgF2BLSK", 76 | "outputId": "54d462c4-2121-46a9-a966-1ca9b01e7b61" 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "!wget https://github.com/MachineLearningHW/HW8_Dataset/releases/download/v1.0.0/data.zip" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": { 87 | "colab": { 88 | "base_uri": "https://localhost:8080/" 89 | }, 90 | "id": "0K5kmlkuWzhJ", 91 | "outputId": "c12176a4-f513-4ed3-c351-c82ef26e8072" 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "!unzip data.zip" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "id": "HNe7QU7n7cqh" 102 | }, 103 | "source": [ 104 | "# Import packages" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": { 111 | "id": "Jk3qFK_a7k8P" 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "import random\n", 116 | "import numpy as np\n", 117 | "import torch\n", 118 | "from torch import nn\n", 119 | "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset\n", 120 | "import torchvision.transforms as transforms\n", 121 | "import torch.nn.functional as F\n", 122 | "from torch.autograd import Variable\n", 123 | "import torchvision.models as models\n", 124 | "from torch.optim import Adam, AdamW, lr_scheduler\n", 125 | "from qqdm import qqdm, format_str\n", 126 | "import pandas as pd" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": { 132 | "id": "6X6fkGPnYyaF" 133 | }, 134 | "source": [ 135 | "# Loading data" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "colab": { 143 | "base_uri": "https://localhost:8080/" 144 | }, 145 | "id": "k7Wd4yiUYzAm", 146 | "outputId": "2c2fa7bf-1c0a-4090-ac16-627c1fe5ca5c" 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "train = np.load('data/trainingset.npy', allow_pickle=True)\n", 151 | "test = np.load('data/testingset.npy', allow_pickle=True)\n", 152 | "\n", 153 | "print(train.shape)\n", 154 | "print(test.shape)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": { 160 | "id": "_flpmj6OYIa6" 161 | }, 162 | "source": [ 163 | "## Random seed\n", 164 | "Set the random seed to a certain value for reproducibility." 
165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "id": "Gb-dgXQYYI2Q" 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "def same_seeds(seed):\n", 176 | " random.seed(seed)\n", 177 | " np.random.seed(seed)\n", 178 | " torch.manual_seed(seed)\n", 179 | " if torch.cuda.is_available():\n", 180 | " torch.cuda.manual_seed(seed)\n", 181 | " torch.cuda.manual_seed_all(seed)\n", 182 | " torch.backends.cudnn.benchmark = False\n", 183 | " torch.backends.cudnn.deterministic = True\n", 184 | "\n", 185 | "same_seeds(42)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": { 191 | "id": "zR9zC0_Df-CR" 192 | }, 193 | "source": [ 194 | "# Autoencoder" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": { 200 | "id": "1EbfwRREhA7c" 201 | }, 202 | "source": [ 203 | "# Models & loss" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": { 210 | "id": "Wi8ds1fugCkR" 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "class fcn_autoencoder(nn.Module):\n", 215 | " def __init__(self):\n", 216 | " super(fcn_autoencoder, self).__init__()\n", 217 | " self.encoder = nn.Sequential(\n", 218 | " nn.Linear(64 * 64 * 3, 1024),\n", 219 | " nn.LeakyReLU(0.1), \n", 220 | " nn.Linear(1024, 512),\n", 221 | " nn.LeakyReLU(0.1), \n", 222 | " nn.Linear(512, 256),\n", 223 | " nn.LeakyReLU(0.1),\n", 224 | " nn.Linear(256, 128), \n", 225 | " nn.LeakyReLU(0.1),\n", 226 | " nn.Linear(128, 64),\n", 227 | " )\n", 228 | "\n", 229 | " self.decoder = nn.Sequential(\n", 230 | " nn.Linear(64, 128),\n", 231 | " nn.LeakyReLU(0.1),\n", 232 | " nn.Linear(128, 256),\n", 233 | " nn.LeakyReLU(0.1),\n", 234 | " nn.Linear(256, 512),\n", 235 | " nn.LeakyReLU(0.1),\n", 236 | " nn.Linear(512, 1024),\n", 237 | " nn.LeakyReLU(0.1),\n", 238 | " nn.Linear(1024, 64 * 64 * 3), \n", 239 | " nn.Tanh()\n", 240 | " )\n", 241 | "\n", 242 | " def forward(self, x):\n", 243 | " code = self.encoder(x)\n", 244 | " # Adjust Latent Repr for Report Image\n", 245 | " # code = target_code\n", 246 | " y = self.decoder(code)\n", 247 | " # return code, y\n", 248 | " return y\n", 249 | "\n", 250 | "\n", 251 | "class conv_autoencoder(nn.Module):\n", 252 | " def __init__(self):\n", 253 | " super(conv_autoencoder, self).__init__()\n", 254 | " self.encoder = nn.Sequential(\n", 255 | " nn.Conv2d(3, 128, 4, stride=2, padding=1),\n", 256 | " nn.BatchNorm2d(128),\n", 257 | " nn.ReLU(),\n", 258 | " nn.Conv2d(128, 256, 4, stride=2, padding=1), \n", 259 | " nn.BatchNorm2d(256),\n", 260 | " nn.ReLU(),\n", 261 | " nn.Conv2d(256, 512, 4, stride=2, padding=1),\n", 262 | " nn.BatchNorm2d(512),\n", 263 | " nn.ReLU(),\n", 264 | " )\n", 265 | " self.decoder = nn.Sequential(\n", 266 | " nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),\n", 267 | " nn.BatchNorm2d(256),\n", 268 | " nn.ReLU(),\n", 269 | " nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),\n", 270 | " nn.BatchNorm2d(128),\n", 271 | " nn.ReLU(),\n", 272 | " nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1),\n", 273 | " nn.Tanh(),\n", 274 | " )\n", 275 | "\n", 276 | " def forward(self, x):\n", 277 | " x = self.encoder(x)\n", 278 | " x = self.decoder(x)\n", 279 | " return x\n", 280 | "\n", 281 | "\n", 282 | "class VAE(nn.Module):\n", 283 | " def __init__(self):\n", 284 | " super(VAE, self).__init__()\n", 285 | " self.encoder = nn.Sequential(\n", 286 | " nn.Conv2d(3, 12, 4, stride=2, padding=1), \n", 287 | " nn.ReLU(),\n", 288 | " nn.Conv2d(12, 24, 4, stride=2, padding=1), 
\n", 289 | " nn.ReLU(),\n", 290 | " )\n", 291 | " self.enc_out_1 = nn.Sequential(\n", 292 | " nn.Conv2d(24, 48, 4, stride=2, padding=1), \n", 293 | " nn.ReLU(),\n", 294 | " )\n", 295 | " self.enc_out_2 = nn.Sequential(\n", 296 | " nn.Conv2d(24, 48, 4, stride=2, padding=1),\n", 297 | " nn.ReLU(),\n", 298 | " )\n", 299 | " self.decoder = nn.Sequential(\n", 300 | " nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1), \n", 301 | " nn.ReLU(),\n", 302 | " nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1), \n", 303 | " nn.ReLU(),\n", 304 | " nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1), \n", 305 | " nn.Tanh(),\n", 306 | " )\n", 307 | "\n", 308 | " def encode(self, x):\n", 309 | " h1 = self.encoder(x)\n", 310 | " return self.enc_out_1(h1), self.enc_out_2(h1)\n", 311 | "\n", 312 | " def reparametrize(self, mu, logvar):\n", 313 | " std = logvar.mul(0.5).exp_()\n", 314 | " if torch.cuda.is_available():\n", 315 | " eps = torch.cuda.FloatTensor(std.size()).normal_()\n", 316 | " else:\n", 317 | " eps = torch.FloatTensor(std.size()).normal_()\n", 318 | " eps = Variable(eps)\n", 319 | " return eps.mul(std).add_(mu)\n", 320 | "\n", 321 | " def decode(self, z):\n", 322 | " return self.decoder(z)\n", 323 | "\n", 324 | " def forward(self, x):\n", 325 | " mu, logvar = self.encode(x)\n", 326 | " z = self.reparametrize(mu, logvar)\n", 327 | " return self.decode(z), mu, logvar\n", 328 | "\n", 329 | "\n", 330 | "def loss_vae(recon_x, x, mu, logvar, criterion):\n", 331 | " \"\"\"\n", 332 | " recon_x: generating images\n", 333 | " x: origin images\n", 334 | " mu: latent mean\n", 335 | " logvar: latent log variance\n", 336 | " \"\"\"\n", 337 | " mse = criterion(recon_x, x)\n", 338 | " KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)\n", 339 | " KLD = torch.sum(KLD_element).mul_(-0.5)\n", 340 | " return mse + KLD" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": { 346 | "id": "vrJ9bScg9AgO" 347 | }, 348 | "source": [ 349 | "# Dataset module\n", 350 | "\n", 351 | "Module for obtaining and processing data. The transform function here normalizes image's pixels from [0, 255] to [-1.0, 1.0].\n" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": { 358 | "id": "33fWhE-h9LPq" 359 | }, 360 | "outputs": [], 361 | "source": [ 362 | "class CustomTensorDataset(TensorDataset):\n", 363 | " \"\"\"TensorDataset with support of transforms.\n", 364 | " \"\"\"\n", 365 | " def __init__(self, tensors):\n", 366 | " self.tensors = tensors\n", 367 | " if tensors.shape[-1] == 3:\n", 368 | " self.tensors = tensors.permute(0, 3, 1, 2)\n", 369 | " \n", 370 | " self.transform = transforms.Compose([\n", 371 | " transforms.Lambda(lambda x: x.to(torch.float32)),\n", 372 | " transforms.Lambda(lambda x: 2. * x/255. 
- 1.),\n", 373 | " ])\n", 374 | " \n", 375 | " def __getitem__(self, index):\n", 376 | " x = self.tensors[index]\n", 377 | " \n", 378 | " if self.transform:\n", 379 | " # mapping images to [-1.0, 1.0]\n", 380 | " x = self.transform(x)\n", 381 | "\n", 382 | " return x\n", 383 | "\n", 384 | " def __len__(self):\n", 385 | " return len(self.tensors)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": { 391 | "id": "XKNUImqUhIeq" 392 | }, 393 | "source": [ 394 | "# Training" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": { 400 | "id": "7ebAJdjFmS08" 401 | }, 402 | "source": [ 403 | "## Configuration\n" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "metadata": { 410 | "id": "in7yLfmqtZTk" 411 | }, 412 | "outputs": [], 413 | "source": [ 414 | "# Training hyperparameters\n", 415 | "num_epochs = 200 # 500 \n", 416 | "batch_size = 256\n", 417 | "learning_rate = 5e-4 # 1e-3\n", 418 | "\n", 419 | "# Build training dataloader\n", 420 | "x = torch.from_numpy(train)\n", 421 | "train_dataset = CustomTensorDataset(x)\n", 422 | "\n", 423 | "train_sampler = RandomSampler(train_dataset)\n", 424 | "train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)\n", 425 | "\n", 426 | "# Model\n", 427 | "model_type = 'fcn' # selecting a model type from {'cnn', 'fcn', 'vae', 'resnet'}\n", 428 | "model_classes = {'fcn': fcn_autoencoder(), 'cnn': conv_autoencoder(), 'vae': VAE()}\n", 429 | "model = model_classes[model_type].cuda()\n", 430 | "\n", 431 | "# Loss and optimizer\n", 432 | "criterion = nn.MSELoss()\n", 433 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", 434 | "total_steps = len(train_dataloader) * num_epochs\n", 435 | "break_steps = int(0.05 * total_steps)\n", 436 | "scheduler = lr_scheduler.StepLR(optimizer, step_size=break_steps, gamma=0.95)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": { 442 | "id": "wyooN-JPm8sS" 443 | }, 444 | "source": [ 445 | "## Training loop" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": { 452 | "id": "JoW1UrrxgI_U" 453 | }, 454 | "outputs": [], 455 | "source": [ 456 | "best_loss = np.inf\n", 457 | "model.train()\n", 458 | "\n", 459 | "qqdm_train = qqdm(range(num_epochs), desc=format_str('bold', 'Description'))\n", 460 | "for epoch in qqdm_train:\n", 461 | " tot_loss = list()\n", 462 | " for data in train_dataloader:\n", 463 | "\n", 464 | " # ===================loading=====================\n", 465 | " img = data.float().cuda()\n", 466 | " if model_type in ['fcn']:\n", 467 | " img = img.view(img.shape[0], -1)\n", 468 | " # ===================forward=====================\n", 469 | " output = model(img)\n", 470 | " if model_type in ['vae']:\n", 471 | " loss = loss_vae(output[0], img, output[1], output[2], criterion)\n", 472 | " else:\n", 473 | " loss = criterion(output, img)\n", 474 | "\n", 475 | " tot_loss.append(loss.item())\n", 476 | " # ===================backward====================\n", 477 | " optimizer.zero_grad()\n", 478 | " loss.backward()\n", 479 | " optimizer.step()\n", 480 | " scheduler.step()\n", 481 | " # ===================save_best====================\n", 482 | " mean_loss = np.mean(tot_loss)\n", 483 | " if mean_loss < best_loss:\n", 484 | " best_loss = mean_loss\n", 485 | " torch.save(model, './models/best_model_{}.pt'.format(model_type))\n", 486 | " # ===================log========================\n", 487 | " qqdm_train.set_infos({\n", 
488 | " 'epoch': f'{epoch + 1:.0f}/{num_epochs:.0f}',\n", 489 | " 'loss': f'{mean_loss:.4f}',\n", 490 | " })\n", 491 | " # ===================save_last========================\n", 492 | " torch.save(model, './models/last_model_{}.pt'.format(model_type))\n", 493 | " print(f\"LR: {scheduler.get_last_lr()[0]}\")" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": {}, 499 | "source": [ 500 | "# Adjust latent representation and plot its result" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [ 509 | "# torch.save(model.state_dict(), \"./models/report_model\")" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "# model = fcn_autoencoder()\n", 519 | "# model = model.cuda()\n", 520 | "# model.load_state_dict(torch.load(\"./models/report_model\"))\n", 521 | "# model.eval()" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [ 530 | "# eval_batch_size = 1\n", 531 | "\n", 532 | "# # build testing dataloader\n", 533 | "# data = torch.tensor(test, dtype=torch.float32)\n", 534 | "# test_dataset = CustomTensorDataset(data)\n", 535 | "# test_sampler = SequentialSampler(test_dataset)\n", 536 | "# test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=eval_batch_size, num_workers=1)" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": null, 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "# # %pip install opencv-python\n", 546 | "# import torchvision\n", 547 | "# import matplotlib.pyplot as plt\n", 548 | "\n", 549 | "# anomality = list()\n", 550 | "# with torch.no_grad():\n", 551 | "# for i, data in enumerate(test_dataloader):\n", 552 | "# if i == 9:\n", 553 | "# ######### 原始圖片 #########\n", 554 | "# print(\"原始圖片\")\n", 555 | "# img = data.float()\n", 556 | "# img = img.squeeze(0)\n", 557 | "# print(img.shape)\n", 558 | "# img = img.permute(1, 2, 0).numpy()\n", 559 | "# plt.imshow(img)\n", 560 | "# plt.show()\n", 561 | "# ###### Latent Repr. #######\n", 562 | "# img_train = data.float().cuda()\n", 563 | "# img_train = img_train.view(img_train.shape[0], -1)\n", 564 | "# target_code, output = model(img_train)\n", 565 | "# ######### 原始輸出 #########\n", 566 | "# print(\"原始輸出\")\n", 567 | "# print(output.shape)\n", 568 | "# output = output.squeeze(0).cpu()\n", 569 | "# output = output.reshape((3, 64, 64))\n", 570 | "# print(output.shape)\n", 571 | "# output = output.permute(1, 2, 0).numpy()\n", 572 | "# plt.imshow(output)\n", 573 | "# plt.show()\n", 574 | "# ## Adjust Latent Repr. 
###\n", 575 | "# print(\"Code 調整\")\n", 576 | "# print(target_code)\n", 577 | "# target_code[0][-1] += 20\n", 578 | "# print(target_code)\n", 579 | "# ###########################\n", 580 | "# break" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": null, 586 | "metadata": {}, 587 | "outputs": [], 588 | "source": [ 589 | "# target_code" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [ 598 | "# # %pip install opencv-python\n", 599 | "# import torchvision\n", 600 | "# import matplotlib.pyplot as plt\n", 601 | "\n", 602 | "# anomality = list()\n", 603 | "# with torch.no_grad():\n", 604 | "# for i, data in enumerate(test_dataloader):\n", 605 | "# if i == 9:\n", 606 | "# ######### 原始圖片 #########\n", 607 | "# print(\"原始圖片\")\n", 608 | "# img = data.float()\n", 609 | "# img = img.squeeze(0)\n", 610 | "# print(img.shape)\n", 611 | "# img = img.permute(1, 2, 0).numpy()\n", 612 | "# plt.imshow(img)\n", 613 | "# plt.show()\n", 614 | "# ###### Latent Repr. #######\n", 615 | "# img_train = data.float().cuda()\n", 616 | "# img_train = img_train.view(img_train.shape[0], -1)\n", 617 | "# target_code, output = model(img_train)\n", 618 | "# ######### 調整後輸出 #########\n", 619 | "# print(\"調整後輸出\")\n", 620 | "# print(output.shape)\n", 621 | "# output = output.squeeze(0).cpu()\n", 622 | "# output = output.reshape((3, 64, 64))\n", 623 | "# print(output.shape)\n", 624 | "# output = output.permute(1, 2, 0).numpy()\n", 625 | "# plt.imshow(output)\n", 626 | "# plt.show()\n", 627 | "# ###########################\n", 628 | "# break" 629 | ] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "metadata": { 634 | "id": "Wk0UxFuchLzR" 635 | }, 636 | "source": [ 637 | "# Inference\n", 638 | "Model is loaded and generates its anomaly score predictions." 
639 | ] 640 | }, 641 | { 642 | "cell_type": "markdown", 643 | "metadata": { 644 | "id": "evgMW3OwoGqD" 645 | }, 646 | "source": [ 647 | "## Initialize\n", 648 | "- dataloader\n", 649 | "- model\n", 650 | "- prediction file" 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": null, 656 | "metadata": { 657 | "id": "_MBnXAswoKmq" 658 | }, 659 | "outputs": [], 660 | "source": [ 661 | "eval_batch_size = 200\n", 662 | "\n", 663 | "# build testing dataloader\n", 664 | "data = torch.tensor(test, dtype=torch.float32)\n", 665 | "test_dataset = CustomTensorDataset(data)\n", 666 | "test_sampler = SequentialSampler(test_dataset)\n", 667 | "test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=eval_batch_size, num_workers=1)\n", 668 | "eval_loss = nn.MSELoss(reduction='none')\n", 669 | "\n", 670 | "# load trained model\n", 671 | "checkpoint_path = f'./models/best_model_{model_type}.pt'\n", 672 | "model = torch.load(checkpoint_path)\n", 673 | "model.eval()\n", 674 | "\n", 675 | "# prediction file \n", 676 | "out_file = './outputs/prediction.csv'" 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": null, 682 | "metadata": { 683 | "id": "_1IxCX2iCW6V" 684 | }, 685 | "outputs": [], 686 | "source": [ 687 | "anomality = list()\n", 688 | "with torch.no_grad():\n", 689 | " for i, data in enumerate(test_dataloader):\n", 690 | " img = data.float().cuda()\n", 691 | " if model_type in ['fcn']:\n", 692 | " img = img.view(img.shape[0], -1)\n", 693 | " output = model(img)\n", 694 | " if model_type in ['vae']:\n", 695 | " output = output[0]\n", 696 | " if model_type in ['fcn']:\n", 697 | " loss = eval_loss(output, img).sum(-1)\n", 698 | " else:\n", 699 | " loss = eval_loss(output, img).sum([1, 2, 3])\n", 700 | " anomality.append(loss)\n", 701 | "anomality = torch.cat(anomality, axis=0)\n", 702 | "anomality = torch.sqrt(anomality).reshape(len(test), 1).cpu().numpy()\n", 703 | "\n", 704 | "df = pd.DataFrame(anomality, columns=['score'])\n", 705 | "df.to_csv(out_file, index_label = 'ID')" 706 | ] 707 | } 708 | ], 709 | "metadata": { 710 | "accelerator": "GPU", 711 | "colab": { 712 | "collapsed_sections": [ 713 | "bDk9r2YOcDc9", 714 | "Oi12tJMYWi0Q", 715 | "DCgNXSsEWuY7", 716 | "HNe7QU7n7cqh", 717 | "6X6fkGPnYyaF", 718 | "1EbfwRREhA7c", 719 | "vrJ9bScg9AgO", 720 | "XKNUImqUhIeq" 721 | ], 722 | "name": "ML2022Spring - HW8.ipynb", 723 | "provenance": [] 724 | }, 725 | "kernelspec": { 726 | "display_name": "kuokuo_env", 727 | "language": "python", 728 | "name": "kuokuo_env" 729 | }, 730 | "language_info": { 731 | "codemirror_mode": { 732 | "name": "ipython", 733 | "version": 3 734 | }, 735 | "file_extension": ".py", 736 | "mimetype": "text/x-python", 737 | "name": "python", 738 | "nbconvert_exporter": "python", 739 | "pygments_lexer": "ipython3", 740 | "version": "3.7.5" 741 | }, 742 | "toc": { 743 | "base_numbering": 1, 744 | "nav_menu": {}, 745 | "number_sections": true, 746 | "sideBar": true, 747 | "skip_h1_title": false, 748 | "title_cell": "Table of Contents", 749 | "title_sidebar": "Contents", 750 | "toc_cell": false, 751 | "toc_position": {}, 752 | "toc_section_display": true, 753 | "toc_window_display": false 754 | } 755 | }, 756 | "nbformat": 4, 757 | "nbformat_minor": 1 758 | } 759 | -------------------------------------------------------------------------------- /hw-02/hw2_strong.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | 
"id": "OYlaRwNu7ojq" 7 | }, 8 | "source": [ 9 | "# **Homework 2 Phoneme Classification**\n", 10 | "\n", 11 | "* Slides: https://docs.google.com/presentation/d/1v6HkBWiJb8WNDcJ9_-2kwVstxUWml87b9CnA16Gdoio/edit?usp=sharing\n", 12 | "* Kaggle: https://www.kaggle.com/c/ml2022spring-hw2\n", 13 | "* Video: TBA\n" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "colab": { 21 | "base_uri": "https://localhost:8080/" 22 | }, 23 | "id": "mLQI0mNcmM-O", 24 | "outputId": "7d5b4d81-9438-4d50-8153-cd235c47ee21" 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "!nvidia-smi" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "!pwd" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": "KVUGfWTo7_Oj" 44 | }, 45 | "source": [ 46 | "## Download Data\n", 47 | "Download data from google drive, then unzip it.\n", 48 | "\n", 49 | "You should have\n", 50 | "- `libriphone/train_split.txt`\n", 51 | "- `libriphone/train_labels`\n", 52 | "- `libriphone/test_split.txt`\n", 53 | "- `libriphone/feat/train/*.pt`: training feature
\n", 54 | "- `libriphone/feat/test/*.pt`: testing feature
\n", 55 | "\n", 56 | "after running the following block.\n", 57 | "\n", 58 | "> **Notes: if the links are dead, you can download the data directly from [Kaggle](https://www.kaggle.com/c/ml2022spring-hw2/data) and upload it to the workspace, or you can use [the Kaggle API](https://www.kaggle.com/general/74235) to directly download the data into colab.**\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": { 64 | "id": "Bj5jYXsD9Ef3" 65 | }, 66 | "source": [ 67 | "### Download train/test metadata" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": { 74 | "colab": { 75 | "base_uri": "https://localhost:8080/" 76 | }, 77 | "id": "OzkiMEcC3Foq", 78 | "outputId": "cc90c16c-ee21-400e-ec08-dfcd422212a6" 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "# Main link\n", 83 | "### !wget -O libriphone.zip \"https://github.com/xraychen/shiny-robot/releases/download/v1.0/libriphone.zip\"\n", 84 | "\n", 85 | "# Backup Link 0\n", 86 | "# !pip install --upgrade gdown\n", 87 | "# !gdown --id '1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc' --output libriphone.zip\n", 88 | "\n", 89 | "# Backup link 1\n", 90 | "# !pip install --upgrade gdown\n", 91 | "# !gdown --id '1R1uQYi4QpX0tBfUWt2mbZcncdBsJkxeW' --output libriphone.zip\n", 92 | "\n", 93 | "# Backup link 2\n", 94 | "# !wget -O libriphone.zip \"https://www.dropbox.com/s/wqww8c5dbrl2ka9/libriphone.zip?dl=1\"\n", 95 | "\n", 96 | "# Backup link 3\n", 97 | "# !wget -O libriphone.zip \"https://www.dropbox.com/s/p2ljbtb2bam13in/libriphone.zip?dl=1\"\n", 98 | "\n", 99 | "### !unzip -q libriphone.zip\n", 100 | "### !ls libriphone" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": { 106 | "id": "_L_4anls8Drv" 107 | }, 108 | "source": [ 109 | "### Preparing Data" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": { 115 | "id": "po4N3C-AWuWl" 116 | }, 117 | "source": [ 118 | "**Helper functions to pre-process the training data from raw MFCC features of each utterance.**\n", 119 | "\n", 120 | "A phoneme may span several frames and is dependent to past and future frames. \\\n", 121 | "Hence we concatenate neighboring phonemes for training to achieve higher accuracy. 
The **concat_feat** function concatenates past and future k frames (total 2k+1 = n frames), and we predict the center frame.\n", 122 | "\n", 123 | "Feel free to modify the data preprocess functions, but **do not drop any frame** (if you modify the functions, remember to check that the number of frames are the same as mentioned in the slides)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "id": "IJjLT8em-y9G" 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "import os\n", 135 | "import random\n", 136 | "import pandas as pd\n", 137 | "import torch\n", 138 | "from tqdm import tqdm\n", 139 | "\n", 140 | "def load_feat(path):\n", 141 | " feat = torch.load(path)\n", 142 | " return feat\n", 143 | "\n", 144 | "def shift(x, n):\n", 145 | " if n < 0:\n", 146 | " left = x[0].repeat(-n, 1)\n", 147 | " right = x[:n]\n", 148 | "\n", 149 | " elif n > 0:\n", 150 | " right = x[-1].repeat(n, 1)\n", 151 | " left = x[n:]\n", 152 | " else:\n", 153 | " return x\n", 154 | "\n", 155 | " return torch.cat((left, right), dim=0)\n", 156 | "\n", 157 | "def concat_feat(x, concat_n):\n", 158 | " assert concat_n % 2 == 1 # n must be odd\n", 159 | " if concat_n < 2:\n", 160 | " return x\n", 161 | " seq_len, feature_dim = x.size(0), x.size(1)\n", 162 | " x = x.repeat(1, concat_n) \n", 163 | " x = x.view(seq_len, concat_n, feature_dim).permute(1, 0, 2) # concat_n, seq_len, feature_dim\n", 164 | " mid = (concat_n // 2)\n", 165 | " for r_idx in range(1, mid+1):\n", 166 | " x[mid + r_idx, :] = shift(x[mid + r_idx], r_idx)\n", 167 | " x[mid - r_idx, :] = shift(x[mid - r_idx], -r_idx)\n", 168 | "\n", 169 | "# return x.permute(1, 0, 2).view(seq_len, concat_n * feature_dim)\n", 170 | " return x.permute(1, 0, 2).view(seq_len, concat_n, feature_dim)\n", 171 | "\n", 172 | "def preprocess_data(split, feat_dir, phone_path, concat_nframes, train_ratio=0.8, train_val_seed=1337):\n", 173 | " class_num = 41 # NOTE: pre-computed, should not need change\n", 174 | " mode = 'train' if (split == 'train' or split == 'val') else 'test'\n", 175 | "\n", 176 | " label_dict = {}\n", 177 | " if mode != 'test':\n", 178 | " phone_file = open(os.path.join(phone_path, f'{mode}_labels.txt')).readlines()\n", 179 | "\n", 180 | " for line in phone_file:\n", 181 | " line = line.strip('\\n').split(' ')\n", 182 | " label_dict[line[0]] = [int(p) for p in line[1:]]\n", 183 | "\n", 184 | " if split == 'train' or split == 'val':\n", 185 | " # split training and validation data\n", 186 | " usage_list = open(os.path.join(phone_path, 'train_split.txt')).readlines()\n", 187 | " random.seed(train_val_seed)\n", 188 | " random.shuffle(usage_list)\n", 189 | " percent = int(len(usage_list) * train_ratio)\n", 190 | " usage_list = usage_list[:percent] if split == 'train' else usage_list[percent:]\n", 191 | " elif split == 'test':\n", 192 | " usage_list = open(os.path.join(phone_path, 'test_split.txt')).readlines()\n", 193 | " else:\n", 194 | " raise ValueError('Invalid \\'split\\' argument for dataset: PhoneDataset!')\n", 195 | "\n", 196 | " usage_list = [line.strip('\\n') for line in usage_list]\n", 197 | " print('[Dataset] - # phone classes: ' + str(class_num) + ', number of utterances for ' + split + ': ' + str(len(usage_list)))\n", 198 | "\n", 199 | " max_len = 3000000\n", 200 | "# X = torch.empty(max_len, 39 * concat_nframes)\n", 201 | " X = torch.empty((max_len, concat_nframes, 39))\n", 202 | " if mode != 'test':\n", 203 | " y = torch.empty(max_len, dtype=torch.long)\n", 204 | "\n", 205 | " idx = 0\n", 
206 | " for i, fname in tqdm(enumerate(usage_list)):\n", 207 | " feat = load_feat(os.path.join(feat_dir, mode, f'{fname}.pt'))\n", 208 | " cur_len = len(feat)\n", 209 | " feat = concat_feat(feat, concat_nframes)\n", 210 | " if mode != 'test':\n", 211 | " label = torch.LongTensor(label_dict[fname])\n", 212 | "\n", 213 | " X[idx: idx + cur_len, :, :] = feat\n", 214 | " if mode != 'test':\n", 215 | " y[idx: idx + cur_len] = label\n", 216 | "\n", 217 | " idx += cur_len\n", 218 | "\n", 219 | " X = X[:idx, :, :]\n", 220 | " if mode != 'test':\n", 221 | " y = y[:idx]\n", 222 | "\n", 223 | " print(f'[INFO] {split} set')\n", 224 | " print(X.shape)\n", 225 | " if mode != 'test':\n", 226 | " print(y.shape)\n", 227 | " return X, y\n", 228 | " else:\n", 229 | " return X\n" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "id": "us5XW_x6udZQ" 236 | }, 237 | "source": [ 238 | "## Define Dataset" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": { 245 | "id": "Fjf5EcmJtf4e" 246 | }, 247 | "outputs": [], 248 | "source": [ 249 | "import torch\n", 250 | "from torch.utils.data import Dataset\n", 251 | "from torch.utils.data import DataLoader\n", 252 | "\n", 253 | "class LibriDataset(Dataset):\n", 254 | " def __init__(self, X, y=None):\n", 255 | " self.data = X\n", 256 | " if y is not None:\n", 257 | " self.label = torch.LongTensor(y)\n", 258 | " else:\n", 259 | " self.label = None\n", 260 | "\n", 261 | " def __getitem__(self, idx):\n", 262 | " if self.label is not None:\n", 263 | " return self.data[idx], self.label[idx]\n", 264 | " else:\n", 265 | " return self.data[idx]\n", 266 | "\n", 267 | " def __len__(self):\n", 268 | " return len(self.data)\n" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": { 274 | "id": "IRqKNvNZwe3V" 275 | }, 276 | "source": [ 277 | "## Define Model" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": { 284 | "id": "Bg-GRd7ywdrL" 285 | }, 286 | "outputs": [], 287 | "source": [ 288 | "# import torch\n", 289 | "# import torch.nn as nn\n", 290 | "# import torch.nn.functional as F\n", 291 | "\n", 292 | "# class BasicBlock(nn.Module):\n", 293 | "# def __init__(self, input_dim, output_dim):\n", 294 | "# super(BasicBlock, self).__init__()\n", 295 | "\n", 296 | "# self.block = nn.Sequential(\n", 297 | "# nn.Linear(input_dim, output_dim),\n", 298 | "# nn.ReLU(),\n", 299 | "# )\n", 300 | "\n", 301 | "# def forward(self, x):\n", 302 | "# x = self.block(x)\n", 303 | "# return x\n", 304 | "\n", 305 | "\n", 306 | "# class Classifier(nn.Module):\n", 307 | "# def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):\n", 308 | "# super(Classifier, self).__init__()\n", 309 | "\n", 310 | "# self.fc = nn.Sequential(\n", 311 | "# BasicBlock(input_dim, hidden_dim),\n", 312 | "# *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],\n", 313 | "# nn.Linear(hidden_dim, output_dim)\n", 314 | "# )\n", 315 | "\n", 316 | "# def forward(self, x):\n", 317 | "# x = self.fc(x)\n", 318 | "# return x" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": null, 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "import torch\n", 328 | "import torch.nn as nn\n", 329 | "import torch.nn.functional as F\n", 330 | "\n", 331 | "\n", 332 | "class Classifier(nn.Module):\n", 333 | " \n", 334 | " def __init__(self, dropout=0.1):\n", 335 | " super(Classifier, self).__init__()\n", 336 | " self.lstm_1 = 
nn.LSTM(39, 256, batch_first=True, bidirectional=True, dropout=dropout)\n", 337 | " self.ln_1 = nn.LayerNorm(512)\n", 338 | " self.lstm_2 = nn.LSTM(512, 128, batch_first=True, bidirectional=True, dropout=dropout)\n", 339 | " self.ln_2 = nn.LayerNorm(256)\n", 340 | " self.lstm_3 = nn.LSTM(256, 64, batch_first=True, bidirectional=True, dropout=dropout)\n", 341 | " self.linear = nn.Linear(128, 41)\n", 342 | "\n", 343 | " def forward(self, x):\n", 344 | " lstm_1_output, _ = self.lstm_1(x)\n", 345 | " lstm_2_output, _ = self.lstm_2(self.ln_1(lstm_1_output))\n", 346 | " lstm_3_output, (ht, ct) = self.lstm_3(self.ln_2(lstm_2_output))\n", 347 | " ht = torch.cat((ht[0], ht[1]), 1) # 把 BiLSTM 的兩層結果接起來\n", 348 | " ht = ht.unsqueeze(0) # 添加第0維\n", 349 | " linear_output = self.linear(ht[-1])\n", 350 | " return linear_output" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": { 356 | "id": "TlIq8JeqvvHC" 357 | }, 358 | "source": [ 359 | "## Hyper-parameters" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": { 366 | "id": "iIHn79Iav1ri" 367 | }, 368 | "outputs": [], 369 | "source": [ 370 | "# data prarameters\n", 371 | "concat_nframes = 25 # the number of frames to concat with, n must be odd (total 2k+1 = n frames)\n", 372 | "train_ratio = 0.8 # the ratio of data used for training, the rest will be used for validation\n", 373 | "\n", 374 | "# training parameters\n", 375 | "seed = 0 # random seed\n", 376 | "batch_size = 256 # batch size\n", 377 | "num_epoch = 10 # the number of training epoch\n", 378 | "learning_rate = 0.005 # learning rate\n", 379 | "model_path = './model/model_rnn.ckpt' # the path where the checkpoint will be saved\n", 380 | "\n", 381 | "# model parameters\n", 382 | "input_dim = 39 * concat_nframes # the input dim of the model, you should not change the value\n", 383 | "# hidden_layers = 4 # the number of hidden layers\n", 384 | "# hidden_dim = 128 # the hidden dim" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": { 390 | "id": "IIUFRgG5yoDn" 391 | }, 392 | "source": [ 393 | "## Prepare dataset and model" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "metadata": { 400 | "colab": { 401 | "base_uri": "https://localhost:8080/" 402 | }, 403 | "id": "c1zI3v5jyrDn", 404 | "outputId": "3ea2823a-83f3-42d9-ef05-2f2c002f9538" 405 | }, 406 | "outputs": [], 407 | "source": [ 408 | "import gc\n", 409 | "\n", 410 | "# preprocess data\n", 411 | "train_X, train_y = preprocess_data(split='train', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n", 412 | "val_X, val_y = preprocess_data(split='val', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n", 413 | "\n", 414 | "# get dataset\n", 415 | "train_set = LibriDataset(train_X, train_y)\n", 416 | "val_set = LibriDataset(val_X, val_y)\n", 417 | "\n", 418 | "# remove raw feature to save memory\n", 419 | "del train_X, train_y, val_X, val_y\n", 420 | "gc.collect()\n", 421 | "\n", 422 | "# get dataloader\n", 423 | "train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", 424 | "val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": null, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "# [Dataset] - # phone classes: 41, number of utterances for train: 3428\n", 
434 | "# 3428it [00:01, 2097.23it/s]\n", 435 | "# [INFO] train set\n", 436 | "# torch.Size([2116368, 39])\n", 437 | "# torch.Size([2116368])\n", 438 | "# [Dataset] - # phone classes: 41, number of utterances for val: 858\n", 439 | "# 858it [00:00, 2087.91it/s]\n", 440 | "# [INFO] val set\n", 441 | "# torch.Size([527790, 39])\n", 442 | "# torch.Size([527790])" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": null, 448 | "metadata": { 449 | "colab": { 450 | "base_uri": "https://localhost:8080/" 451 | }, 452 | "id": "CfRUEgC0GxUV", 453 | "outputId": "f9804711-72b1-4717-896b-821a300cfe87" 454 | }, 455 | "outputs": [], 456 | "source": [ 457 | "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", 458 | "print(f'DEVICE: {device}')" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": { 465 | "id": "88xPiUnm0tAd" 466 | }, 467 | "outputs": [], 468 | "source": [ 469 | "import numpy as np\n", 470 | "\n", 471 | "#fix seed\n", 472 | "def same_seeds(seed):\n", 473 | " torch.manual_seed(seed)\n", 474 | " if torch.cuda.is_available():\n", 475 | " torch.cuda.manual_seed(seed)\n", 476 | " torch.cuda.manual_seed_all(seed) \n", 477 | " np.random.seed(seed) \n", 478 | " torch.backends.cudnn.benchmark = False\n", 479 | " torch.backends.cudnn.deterministic = True" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": null, 485 | "metadata": { 486 | "id": "QTp3ZXg1yO9Y" 487 | }, 488 | "outputs": [], 489 | "source": [ 490 | "# fix random seed\n", 491 | "same_seeds(seed)\n", 492 | "\n", 493 | "# create model, define a loss function, and optimizer\n", 494 | "model = Classifier().to(device)\n", 495 | "criterion = nn.CrossEntropyLoss() \n", 496 | "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": { 502 | "id": "pwWH1KIqzxEr" 503 | }, 504 | "source": [ 505 | "## Training" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "metadata": { 512 | "colab": { 513 | "base_uri": "https://localhost:8080/" 514 | }, 515 | "id": "CdMWsBs7zzNs", 516 | "outputId": "17922ad2-a319-4253-8783-3e4939d0a7cf" 517 | }, 518 | "outputs": [], 519 | "source": [ 520 | "best_acc = 0.0\n", 521 | "for epoch in range(num_epoch):\n", 522 | " train_acc = 0.0\n", 523 | " train_loss = 0.0\n", 524 | " val_acc = 0.0\n", 525 | " val_loss = 0.0\n", 526 | " \n", 527 | " # training\n", 528 | " model.train() # set the model to training mode\n", 529 | " for i, batch in enumerate(tqdm(train_loader)):\n", 530 | " features, labels = batch\n", 531 | " features = features.to(device)\n", 532 | " labels = labels.to(device)\n", 533 | " \n", 534 | " optimizer.zero_grad() \n", 535 | " outputs = model(features) \n", 536 | " \n", 537 | " loss = criterion(outputs, labels)\n", 538 | " loss.backward() \n", 539 | " optimizer.step() \n", 540 | " \n", 541 | " _, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n", 542 | " train_acc += (train_pred.detach() == labels.detach()).sum().item()\n", 543 | " train_loss += loss.item()\n", 544 | " \n", 545 | " # validation\n", 546 | " if len(val_set) > 0:\n", 547 | " model.eval() # set the model to evaluation mode\n", 548 | " with torch.no_grad():\n", 549 | " for i, batch in enumerate(tqdm(val_loader)):\n", 550 | " features, labels = batch\n", 551 | " features = features.to(device)\n", 552 | " labels = labels.to(device)\n", 553 | " outputs = model(features)\n", 
554 | " \n", 555 | " loss = criterion(outputs, labels) \n", 556 | " \n", 557 | " _, val_pred = torch.max(outputs, 1) \n", 558 | " val_acc += (val_pred.cpu() == labels.cpu()).sum().item() # get the index of the class with the highest probability\n", 559 | " val_loss += loss.item()\n", 560 | "\n", 561 | " print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f} | Val Acc: {:3.6f} loss: {:3.6f}'.format(\n", 562 | " epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader), val_acc/len(val_set), val_loss/len(val_loader)\n", 563 | " ))\n", 564 | "\n", 565 | " # if the model improves, save a checkpoint at this epoch\n", 566 | " if val_acc > best_acc:\n", 567 | " best_acc = val_acc\n", 568 | " torch.save(model.state_dict(), model_path)\n", 569 | " print('saving model with acc {:.3f}'.format(best_acc/len(val_set)))\n", 570 | " else:\n", 571 | " print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f}'.format(\n", 572 | " epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader)\n", 573 | " ))\n", 574 | "\n", 575 | "# if not validating, save the last epoch\n", 576 | "if len(val_set) == 0:\n", 577 | " torch.save(model.state_dict(), model_path)\n", 578 | " print('saving model at last epoch')\n" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "# 0.458 Baseline\n", 588 | "# 0.633 11 frames + 2 layers\n", 589 | "# 0.689\n", 590 | "# 0.714" 591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | "execution_count": null, 596 | "metadata": { 597 | "colab": { 598 | "base_uri": "https://localhost:8080/" 599 | }, 600 | "id": "ab33MxosWLmG", 601 | "outputId": "911e8c9b-fc0f-4591-b0f6-311a1231c5e2" 602 | }, 603 | "outputs": [], 604 | "source": [ 605 | "del train_loader, val_loader\n", 606 | "gc.collect()" 607 | ] 608 | }, 609 | { 610 | "cell_type": "markdown", 611 | "metadata": { 612 | "id": "1Hi7jTn3PX-m" 613 | }, 614 | "source": [ 615 | "## Testing\n", 616 | "Create a testing dataset, and load model from the saved checkpoint." 
617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": null, 622 | "metadata": { 623 | "colab": { 624 | "base_uri": "https://localhost:8080/" 625 | }, 626 | "id": "VOG1Ou0PGrhc", 627 | "outputId": "abaaa25b-a93c-49b0-d228-9eca1e2ab2e0" 628 | }, 629 | "outputs": [], 630 | "source": [ 631 | "# load data\n", 632 | "test_X = preprocess_data(split='test', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes)\n", 633 | "test_set = LibriDataset(test_X, None)\n", 634 | "test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)" 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": null, 640 | "metadata": {}, 641 | "outputs": [], 642 | "source": [ 643 | "# [Dataset] - # phone classes: 41, number of utterances for test: 1078\n", 644 | "# 1078it [00:00, 2496.92it/s]\n", 645 | "# [INFO] test set\n", 646 | "# torch.Size([646268, 39])" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "metadata": { 653 | "colab": { 654 | "base_uri": "https://localhost:8080/" 655 | }, 656 | "id": "ay0Fu8Ovkdad", 657 | "outputId": "e5b20aa7-4d8b-43a9-e068-f5c89706a360" 658 | }, 659 | "outputs": [], 660 | "source": [ 661 | "# load model\n", 662 | "model = Classifier().to(device)\n", 663 | "model.load_state_dict(torch.load(model_path))" 664 | ] 665 | }, 666 | { 667 | "cell_type": "markdown", 668 | "metadata": { 669 | "id": "zp-DV1p4r7Nz" 670 | }, 671 | "source": [ 672 | "Make prediction." 673 | ] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "execution_count": null, 678 | "metadata": { 679 | "colab": { 680 | "base_uri": "https://localhost:8080/" 681 | }, 682 | "id": "84HU5GGjPqR0", 683 | "outputId": "cebd6694-8f74-44ff-f922-96ca4385acb8" 684 | }, 685 | "outputs": [], 686 | "source": [ 687 | "test_acc = 0.0\n", 688 | "test_lengths = 0\n", 689 | "pred = np.array([], dtype=np.int32)\n", 690 | "\n", 691 | "model.eval()\n", 692 | "with torch.no_grad():\n", 693 | " for i, batch in enumerate(tqdm(test_loader)):\n", 694 | " features = batch\n", 695 | " features = features.to(device)\n", 696 | "\n", 697 | " outputs = model(features)\n", 698 | "\n", 699 | " _, test_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n", 700 | " pred = np.concatenate((pred, test_pred.cpu().numpy()), axis=0)\n" 701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "metadata": { 706 | "id": "wyZqy40Prz0v" 707 | }, 708 | "source": [ 709 | "Write prediction to a CSV file.\n", 710 | "\n", 711 | "After finish running this block, download the file `prediction.csv` from the files section on the left-hand side and submit it to Kaggle." 
712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": null, 717 | "metadata": { 718 | "id": "GuljYSPHcZir" 719 | }, 720 | "outputs": [], 721 | "source": [ 722 | "# with open('prediction/prediction_2.csv', 'w') as f:\n", 723 | "# f.write('Id,Class\\n')\n", 724 | "# for i, y in enumerate(pred):\n", 725 | "# f.write('{},{}\\n'.format(i, y))" 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "execution_count": null, 731 | "metadata": {}, 732 | "outputs": [], 733 | "source": [ 734 | "with open('prediction/prediction_rnn.csv', 'w') as f:\n", 735 | " f.write('Id,Class\\n')\n", 736 | " for i, y in enumerate(pred):\n", 737 | " f.write('{},{}\\n'.format(i, y))" 738 | ] 739 | } 740 | ], 741 | "metadata": { 742 | "accelerator": "GPU", 743 | "colab": { 744 | "collapsed_sections": [], 745 | "name": "ML2022Spring - HW2.ipynb", 746 | "provenance": [] 747 | }, 748 | "kernelspec": { 749 | "display_name": "kuokuo_env", 750 | "language": "python", 751 | "name": "kuokuo_env" 752 | }, 753 | "language_info": { 754 | "codemirror_mode": { 755 | "name": "ipython", 756 | "version": 3 757 | }, 758 | "file_extension": ".py", 759 | "mimetype": "text/x-python", 760 | "name": "python", 761 | "nbconvert_exporter": "python", 762 | "pygments_lexer": "ipython3", 763 | "version": "3.7.5" 764 | }, 765 | "toc": { 766 | "base_numbering": 1, 767 | "nav_menu": {}, 768 | "number_sections": true, 769 | "sideBar": true, 770 | "skip_h1_title": false, 771 | "title_cell": "Table of Contents", 772 | "title_sidebar": "Contents", 773 | "toc_cell": false, 774 | "toc_position": {}, 775 | "toc_section_display": true, 776 | "toc_window_display": false 777 | } 778 | }, 779 | "nbformat": 4, 780 | "nbformat_minor": 1 781 | } 782 | -------------------------------------------------------------------------------- /hw-04/hw4_strong.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "C_jdZ5vHJ4A9" 7 | }, 8 | "source": [ 9 | "# Task description\n", 10 | "- Classify the speakers of given features.\n", 11 | "- Main goal: Learn how to use transformer.\n", 12 | "- Baselines:\n", 13 | " - Easy: Run sample code and know how to use transformer.\n", 14 | " - Medium: Know how to adjust parameters of transformer.\n", 15 | " - Strong: Construct [conformer](https://arxiv.org/abs/2005.08100) which is a variety of transformer. 
\n", 16 | " - Boss: Implement [Self-Attention Pooling](https://arxiv.org/pdf/2008.01077v1.pdf) & [Additive Margin Softmax](https://arxiv.org/pdf/1801.05599.pdf) to further boost the performance.\n", 17 | "\n", 18 | "- Other links\n", 19 | " - Kaggle: [link](https://www.kaggle.com/t/ac77388c90204a4c8daebeddd40ff916)\n", 20 | " - Slide: [link](https://docs.google.com/presentation/d/1HLAj7UUIjZOycDe7DaVLSwJfXVd3bXPOyzSb6Zk3hYU/edit?usp=sharing)\n", 21 | " - Data: [link](https://drive.google.com/drive/folders/1vI1kuLB-q1VilIftiwnPOCAeOOFfBZge?usp=sharing)\n", 22 | "\n", 23 | "# Download dataset\n", 24 | "- Data is [here](https://drive.google.com/drive/folders/1vI1kuLB-q1VilIftiwnPOCAeOOFfBZge?usp=sharing)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "!pwd\n", 34 | "!nvidia-smi" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "id": "LhLNWB-AK2Z5" 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "!wget https://github.com/MachineLearningHW/ML_HW4_Dataset/releases/latest/download/Dataset.tar.gz.partaa\n", 46 | "!wget https://github.com/MachineLearningHW/ML_HW4_Dataset/releases/latest/download/Dataset.tar.gz.partab\n", 47 | "!wget https://github.com/MachineLearningHW/ML_HW4_Dataset/releases/latest/download/Dataset.tar.gz.partac\n", 48 | "!wget https://github.com/MachineLearningHW/ML_HW4_Dataset/releases/latest/download/Dataset.tar.gz.partad\n", 49 | "\n", 50 | "!cat Dataset.tar.gz.part* > Dataset.tar.gz\n", 51 | "\n", 52 | "# unzip the file\n", 53 | "!tar zxvf Dataset.tar.gz" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": { 59 | "id": "ENWVAUDVJtVY" 60 | }, 61 | "source": [ 62 | "## Fix Random Seed" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "id": "E6burzCXIyuA" 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "import numpy as np\n", 74 | "import torch\n", 75 | "import random\n", 76 | "\n", 77 | "def set_seed(seed):\n", 78 | " np.random.seed(seed)\n", 79 | " random.seed(seed)\n", 80 | " torch.manual_seed(seed)\n", 81 | " if torch.cuda.is_available():\n", 82 | " torch.cuda.manual_seed(seed)\n", 83 | " torch.cuda.manual_seed_all(seed)\n", 84 | " torch.backends.cudnn.benchmark = False\n", 85 | " torch.backends.cudnn.deterministic = True\n", 86 | "\n", 87 | "set_seed(42)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "# Model and output's file name\n", 97 | "NAME = \"strong\"" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": { 103 | "id": "k7dVbxW2LASN" 104 | }, 105 | "source": [ 106 | "# Data\n", 107 | "\n", 108 | "## Dataset\n", 109 | "- Original dataset is [Voxceleb2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html).\n", 110 | "- The [license](https://creativecommons.org/licenses/by/4.0/) and [complete version](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/files/license.txt) of Voxceleb2.\n", 111 | "- We randomly select 600 speakers from Voxceleb2.\n", 112 | "- Then preprocess the raw waveforms into mel-spectrograms.\n", 113 | "\n", 114 | "- Args:\n", 115 | " - data_dir: The path to the data directory.\n", 116 | " - metadata_path: The path to the metadata.\n", 117 | " - segment_len: The length of audio segment for training. 
\n", 118 | "- The architecture of data directory \\\\\n", 119 | " - data directory \\\\\n", 120 | " |---- metadata.json \\\\\n", 121 | " |---- testdata.json \\\\\n", 122 | " |---- mapping.json \\\\\n", 123 | " |---- uttr-{random string}.pt \\\\\n", 124 | "\n", 125 | "- The information in metadata\n", 126 | " - \"n_mels\": The dimention of mel-spectrogram.\n", 127 | " - \"speakers\": A dictionary. \n", 128 | " - Key: speaker ids.\n", 129 | " - value: \"feature_path\" and \"mel_len\"\n", 130 | "\n", 131 | "\n", 132 | "For efficiency, we segment the mel-spectrograms into segments in the traing step." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "id": "KpuGxl4CI2pr" 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "import os\n", 144 | "import json\n", 145 | "import torch\n", 146 | "import random\n", 147 | "from pathlib import Path\n", 148 | "from torch.utils.data import Dataset\n", 149 | "from torch.nn.utils.rnn import pad_sequence\n", 150 | "\n", 151 | "from conformer import ConformerBlock\n", 152 | "\n", 153 | "\n", 154 | "class myDataset(Dataset):\n", 155 | " def __init__(self, data_dir, segment_len=128):\n", 156 | " self.data_dir = data_dir\n", 157 | " self.segment_len = segment_len\n", 158 | "\n", 159 | " # Load the mapping from speaker neme to their corresponding id. \n", 160 | " mapping_path = Path(data_dir) / \"mapping.json\"\n", 161 | " mapping = json.load(mapping_path.open())\n", 162 | " self.speaker2id = mapping[\"speaker2id\"]\n", 163 | "\n", 164 | " # Load metadata of training data.\n", 165 | " metadata_path = Path(data_dir) / \"metadata.json\"\n", 166 | " metadata = json.load(open(metadata_path))[\"speakers\"]\n", 167 | "\n", 168 | " # Get the total number of speaker.\n", 169 | " self.speaker_num = len(metadata.keys())\n", 170 | " self.data = []\n", 171 | " for speaker in metadata.keys():\n", 172 | " for utterances in metadata[speaker]:\n", 173 | " self.data.append([utterances[\"feature_path\"], self.speaker2id[speaker]])\n", 174 | "\n", 175 | " def __len__(self):\n", 176 | " return len(self.data)\n", 177 | "\n", 178 | " def __getitem__(self, index):\n", 179 | " feat_path, speaker = self.data[index]\n", 180 | " # Load preprocessed mel-spectrogram.\n", 181 | " mel = torch.load(os.path.join(self.data_dir, feat_path))\n", 182 | "\n", 183 | " # Segmemt mel-spectrogram into \"segment_len\" frames.\n", 184 | " if len(mel) > self.segment_len:\n", 185 | " # Randomly get the starting point of the segment.\n", 186 | " start = random.randint(0, len(mel) - self.segment_len)\n", 187 | " # Get a segment with \"segment_len\" frames.\n", 188 | " mel = torch.FloatTensor(mel[start:start+self.segment_len])\n", 189 | " else:\n", 190 | " mel = torch.FloatTensor(mel)\n", 191 | " # Turn the speaker id into long for computing loss later.\n", 192 | " speaker = torch.FloatTensor([speaker]).long()\n", 193 | " return mel, speaker\n", 194 | "\n", 195 | " def get_speaker_number(self):\n", 196 | " return self.speaker_num" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": { 202 | "id": "668hverTMlGN" 203 | }, 204 | "source": [ 205 | "## Dataloader\n", 206 | "- Split dataset into training dataset(90%) and validation dataset(10%).\n", 207 | "- Create dataloader to iterate the data." 
208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "id": "B7c2gZYoJDRS" 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "import torch\n", 219 | "from torch.utils.data import DataLoader, random_split\n", 220 | "from torch.nn.utils.rnn import pad_sequence\n", 221 | "\n", 222 | "\n", 223 | "def collate_batch(batch):\n", 224 | " # Process features within a batch.\n", 225 | " \"\"\"Collate a batch of data.\"\"\"\n", 226 | " mel, speaker = zip(*batch)\n", 227 | " # Because we train the model batch by batch, we need to pad the features in the same batch to make their lengths the same.\n", 228 | " mel = pad_sequence(mel, batch_first=True, padding_value=-20) # pad log 10^(-20) which is very small value.\n", 229 | " # mel: (batch size, length, 40)\n", 230 | " return mel, torch.FloatTensor(speaker).long()\n", 231 | "\n", 232 | "\n", 233 | "def get_dataloader(data_dir, batch_size, n_workers):\n", 234 | " \"\"\"Generate dataloader\"\"\"\n", 235 | " dataset = myDataset(data_dir)\n", 236 | " speaker_num = dataset.get_speaker_number()\n", 237 | " # Split dataset into training dataset and validation dataset\n", 238 | " trainlen = int(0.9 * len(dataset))\n", 239 | " lengths = [trainlen, len(dataset) - trainlen]\n", 240 | " trainset, validset = random_split(dataset, lengths)\n", 241 | "\n", 242 | " train_loader = DataLoader(\n", 243 | " trainset,\n", 244 | " batch_size=batch_size,\n", 245 | " shuffle=True,\n", 246 | " drop_last=True,\n", 247 | " num_workers=n_workers,\n", 248 | " pin_memory=True,\n", 249 | " collate_fn=collate_batch,\n", 250 | " )\n", 251 | " valid_loader = DataLoader(\n", 252 | " validset,\n", 253 | " batch_size=batch_size,\n", 254 | " num_workers=n_workers,\n", 255 | " drop_last=True,\n", 256 | " pin_memory=True,\n", 257 | " collate_fn=collate_batch,\n", 258 | " )\n", 259 | "\n", 260 | " return train_loader, valid_loader, speaker_num" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": { 266 | "id": "5FOSZYxrMqhc" 267 | }, 268 | "source": [ 269 | "# Model\n", 270 | "- TransformerEncoderLayer:\n", 271 | " - Base transformer encoder layer in [Attention Is All You Need](https://arxiv.org/abs/1706.03762)\n", 272 | " - Parameters:\n", 273 | " - d_model: the number of expected features of the input (required).\n", 274 | "\n", 275 | " - nhead: the number of heads of the multiheadattention models (required).\n", 276 | "\n", 277 | " - dim_feedforward: the dimension of the feedforward network model (default=2048).\n", 278 | "\n", 279 | " - dropout: the dropout value (default=0.1).\n", 280 | "\n", 281 | " - activation: the activation function of intermediate layer, relu or gelu (default=relu).\n", 282 | "\n", 283 | "- TransformerEncoder:\n", 284 | " - TransformerEncoder is a stack of N transformer encoder layers\n", 285 | " - Parameters:\n", 286 | " - encoder_layer: an instance of the TransformerEncoderLayer() class (required).\n", 287 | "\n", 288 | " - num_layers: the number of sub-encoder-layers in the encoder (required).\n", 289 | "\n", 290 | " - norm: the layer normalization component (optional)." 
291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "import torch\n", 300 | "import torch.nn as nn\n", 301 | "import torch.nn.functional as F" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "class Classifier(nn.Module):\n", 311 | " def __init__(self, d_model=80, n_spks=600, dropout=0.1):\n", 312 | " super().__init__()\n", 313 | " # Project the dimension of features from that of input into d_model.\n", 314 | " self.prenet = nn.Linear(40, d_model)\n", 315 | " \n", 316 | " self.block1 = ConformerBlock(\n", 317 | " dim = d_model,\n", 318 | " dim_head = 32,\n", 319 | " heads = 8,\n", 320 | " ff_mult = 4,\n", 321 | " conv_expansion_factor = 2,\n", 322 | " conv_kernel_size = 31,\n", 323 | " attn_dropout = 0.,\n", 324 | " ff_dropout = 0.,\n", 325 | " conv_dropout = 0.\n", 326 | " )\n", 327 | " \n", 328 | " self.block2 = ConformerBlock(\n", 329 | " dim = d_model,\n", 330 | " dim_head = 32,\n", 331 | " heads = 8,\n", 332 | " ff_mult = 4,\n", 333 | " conv_expansion_factor = 2,\n", 334 | " conv_kernel_size = 31,\n", 335 | " attn_dropout = 0.,\n", 336 | " ff_dropout = 0.,\n", 337 | " conv_dropout = 0.\n", 338 | " )\n", 339 | " \n", 340 | " self.block3 = ConformerBlock(\n", 341 | " dim = d_model,\n", 342 | " dim_head = 32,\n", 343 | " heads = 8,\n", 344 | " ff_mult = 4,\n", 345 | " conv_expansion_factor = 2,\n", 346 | " conv_kernel_size = 31,\n", 347 | " attn_dropout = 0.1,\n", 348 | " ff_dropout = 0.1,\n", 349 | " conv_dropout = 0.1\n", 350 | " )\n", 351 | " \n", 352 | " # Project the the dimension of features from d_model into speaker nums.\n", 353 | " self.pred_layer = nn.Sequential(\n", 354 | " nn.Linear(d_model, n_spks),\n", 355 | " )\n", 356 | "\n", 357 | " def forward(self, mels):\n", 358 | " \"\"\"\n", 359 | " args:\n", 360 | " mels: (batch size, length, 40)\n", 361 | " return:\n", 362 | " out: (batch size, n_spks)\n", 363 | " \"\"\"\n", 364 | " # out: (batch size, length, d_model)\n", 365 | " out = self.prenet(mels)\n", 366 | " # The encoder layer\n", 367 | " out = self.block1(out)\n", 368 | " out = self.block2(out)\n", 369 | " out = self.block3(out)\n", 370 | " # mean pooling\n", 371 | " stats = out.mean(dim=1)\n", 372 | " # out: (batch, n_spks)\n", 373 | " out = self.pred_layer(stats)\n", 374 | " return out" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": { 380 | "id": "W7yX8JinM5Ly" 381 | }, 382 | "source": [ 383 | "# Learning rate schedule\n", 384 | "- For transformer architecture, the design of learning rate schedule is different from that of CNN.\n", 385 | "- Previous works show that the warmup of learning rate is useful for training models with transformer architectures.\n", 386 | "- The warmup schedule\n", 387 | " - Set learning rate to 0 in the beginning.\n", 388 | " - The learning rate increases linearly from 0 to initial learning rate during warmup period." 
389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": { 395 | "id": "ykt0N1nVJJi2" 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "import math\n", 400 | "\n", 401 | "import torch\n", 402 | "from torch.optim import Optimizer\n", 403 | "from torch.optim.lr_scheduler import LambdaLR\n", 404 | "\n", 405 | "\n", 406 | "def get_cosine_schedule_with_warmup(\n", 407 | " optimizer: Optimizer,\n", 408 | " num_warmup_steps: int,\n", 409 | " num_training_steps: int,\n", 410 | " num_cycles: float = 0.5,\n", 411 | " last_epoch: int = -1,\n", 412 | "):\n", 413 | " \"\"\"\n", 414 | " Create a schedule with a learning rate that decreases following the values of the cosine function between the\n", 415 | " initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\n", 416 | " initial lr set in the optimizer.\n", 417 | "\n", 418 | " Args:\n", 419 | " optimizer (:class:`~torch.optim.Optimizer`):\n", 420 | " The optimizer for which to schedule the learning rate.\n", 421 | " num_warmup_steps (:obj:`int`):\n", 422 | " The number of steps for the warmup phase.\n", 423 | " num_training_steps (:obj:`int`):\n", 424 | " The total number of training steps.\n", 425 | " num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n", 426 | " The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0\n", 427 | " following a half-cosine).\n", 428 | " last_epoch (:obj:`int`, `optional`, defaults to -1):\n", 429 | " The index of the last epoch when resuming training.\n", 430 | "\n", 431 | " Return:\n", 432 | " :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n", 433 | " \"\"\"\n", 434 | " def lr_lambda(current_step):\n", 435 | " # Warmup\n", 436 | " if current_step < num_warmup_steps:\n", 437 | " return float(current_step) / float(max(1, num_warmup_steps))\n", 438 | " # decadence\n", 439 | " progress = float(current_step - num_warmup_steps) / float(\n", 440 | " max(1, num_training_steps - num_warmup_steps)\n", 441 | " )\n", 442 | " return max(\n", 443 | " 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))\n", 444 | " )\n", 445 | "\n", 446 | " return LambdaLR(optimizer, lr_lambda, last_epoch)" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": { 452 | "id": "-LN2XkteM_uH" 453 | }, 454 | "source": [ 455 | "# Model Function\n", 456 | "- Model forward function." 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "metadata": { 463 | "id": "N-rr8529JMz0" 464 | }, 465 | "outputs": [], 466 | "source": [ 467 | "import torch\n", 468 | "\n", 469 | "\n", 470 | "def model_fn(batch, model, criterion, device):\n", 471 | " \"\"\"Forward a batch through the model.\"\"\"\n", 472 | "\n", 473 | " mels, labels = batch\n", 474 | " mels = mels.to(device)\n", 475 | " labels = labels.to(device)\n", 476 | "\n", 477 | " outs = model(mels)\n", 478 | "\n", 479 | " loss = criterion(outs, labels)\n", 480 | "\n", 481 | " # Get the speaker id with highest probability.\n", 482 | " preds = outs.argmax(1)\n", 483 | " # Compute accuracy.\n", 484 | " accuracy = torch.mean((preds == labels).float())\n", 485 | "\n", 486 | " return loss, accuracy" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": { 492 | "id": "cwM_xyOtNCI2" 493 | }, 494 | "source": [ 495 | "# Validate\n", 496 | "- Calculate accuracy of the validation set." 
497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": null, 502 | "metadata": { 503 | "id": "YAiv6kpdJRTJ" 504 | }, 505 | "outputs": [], 506 | "source": [ 507 | "from tqdm import tqdm\n", 508 | "import torch\n", 509 | "\n", 510 | "\n", 511 | "def valid(dataloader, model, criterion, device): \n", 512 | " \"\"\"Validate on validation set.\"\"\"\n", 513 | "\n", 514 | " model.eval()\n", 515 | " running_loss = 0.0\n", 516 | " running_accuracy = 0.0\n", 517 | " pbar = tqdm(total=len(dataloader.dataset), ncols=0, desc=\"Valid\", unit=\" uttr\")\n", 518 | "\n", 519 | " for i, batch in enumerate(dataloader):\n", 520 | " with torch.no_grad():\n", 521 | " loss, accuracy = model_fn(batch, model, criterion, device)\n", 522 | " running_loss += loss.item()\n", 523 | " running_accuracy += accuracy.item()\n", 524 | "\n", 525 | " pbar.update(dataloader.batch_size)\n", 526 | " pbar.set_postfix(\n", 527 | " loss=f\"{running_loss / (i+1):.2f}\",\n", 528 | " accuracy=f\"{running_accuracy / (i+1):.2f}\",\n", 529 | " )\n", 530 | "\n", 531 | " pbar.close()\n", 532 | " model.train()\n", 533 | "\n", 534 | " return running_accuracy / len(dataloader)" 535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": { 540 | "id": "g6ne9G-eNEdG" 541 | }, 542 | "source": [ 543 | "# Main function" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "metadata": {}, 550 | "outputs": [], 551 | "source": [ 552 | "from tqdm import tqdm\n", 553 | "\n", 554 | "import torch\n", 555 | "import torch.nn as nn\n", 556 | "from torch.optim import AdamW\n", 557 | "from torch.utils.data import DataLoader, random_split\n", 558 | "\n", 559 | "\n", 560 | "def parse_args():\n", 561 | " \"\"\"arguments\"\"\"\n", 562 | " config = {\n", 563 | " \"data_dir\": \"./Dataset\",\n", 564 | " \"save_path\": f\"./models/{NAME}.ckpt\",\n", 565 | " \"batch_size\": 256,\n", 566 | " \"n_workers\": 0,\n", 567 | " \"valid_steps\": 1000,\n", 568 | " \"warmup_steps\": 1000,\n", 569 | " \"save_steps\": 1000,\n", 570 | " \"total_steps\": 70000,\n", 571 | " }\n", 572 | "\n", 573 | " return config\n", 574 | "\n", 575 | "\n", 576 | "def main(\n", 577 | " data_dir,\n", 578 | " save_path,\n", 579 | " batch_size,\n", 580 | " n_workers,\n", 581 | " valid_steps,\n", 582 | " warmup_steps,\n", 583 | " total_steps,\n", 584 | " save_steps,\n", 585 | "):\n", 586 | " \"\"\"Main function.\"\"\"\n", 587 | " device = torch.device(\"cuda\", 1)\n", 588 | " print(f\"[Info]: Use {device} now!\")\n", 589 | "\n", 590 | " train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)\n", 591 | " train_iterator = iter(train_loader)\n", 592 | " print(f\"[Info]: Finish loading data!\",flush = True)\n", 593 | "\n", 594 | " model = Classifier(n_spks=speaker_num)\n", 595 | "# model.load_state_dict(torch.load(save_path))\n", 596 | " model = model.to(device)\n", 597 | " criterion = nn.CrossEntropyLoss()\n", 598 | " optimizer = AdamW(model.parameters(), lr=1e-3)\n", 599 | " scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)\n", 600 | " print(f\"[Info]: Finish creating model!\",flush = True)\n", 601 | "\n", 602 | " best_accuracy = -1.0\n", 603 | " best_state_dict = None\n", 604 | "\n", 605 | " pbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n", 606 | "\n", 607 | " for step in range(total_steps):\n", 608 | " # Get data\n", 609 | " try:\n", 610 | " batch = next(train_iterator)\n", 611 | " except StopIteration:\n", 612 | " train_iterator = 
iter(train_loader)\n", 613 | " batch = next(train_iterator)\n", 614 | "\n", 615 | " loss, accuracy = model_fn(batch, model, criterion, device)\n", 616 | " batch_loss = loss.item()\n", 617 | " batch_accuracy = accuracy.item()\n", 618 | "\n", 619 | " # Updata model\n", 620 | " loss.backward()\n", 621 | " optimizer.step()\n", 622 | " scheduler.step()\n", 623 | " optimizer.zero_grad()\n", 624 | "\n", 625 | " # Log\n", 626 | " pbar.update()\n", 627 | " pbar.set_postfix(\n", 628 | " loss=f\"{batch_loss:.2f}\",\n", 629 | " accuracy=f\"{batch_accuracy:.2f}\",\n", 630 | " step=step + 1,\n", 631 | " )\n", 632 | "\n", 633 | " # Do validation\n", 634 | " if (step + 1) % valid_steps == 0:\n", 635 | " pbar.close()\n", 636 | "\n", 637 | " valid_accuracy = valid(valid_loader, model, criterion, device)\n", 638 | "\n", 639 | " # keep the best model\n", 640 | " if valid_accuracy > best_accuracy:\n", 641 | " best_accuracy = valid_accuracy\n", 642 | " best_state_dict = model.state_dict()\n", 643 | "\n", 644 | " pbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n", 645 | "\n", 646 | " # Save the best model so far.\n", 647 | " if (step + 1) % save_steps == 0 and best_state_dict is not None:\n", 648 | " torch.save(best_state_dict, save_path)\n", 649 | " pbar.write(f\"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})\")\n", 650 | "\n", 651 | " pbar.close()\n", 652 | "\n", 653 | "\n", 654 | "if __name__ == \"__main__\":\n", 655 | " main(**parse_args())" 656 | ] 657 | }, 658 | { 659 | "cell_type": "markdown", 660 | "metadata": { 661 | "id": "NLatBYAhNNMx" 662 | }, 663 | "source": [ 664 | "# Inference\n", 665 | "\n", 666 | "## Dataset of inference" 667 | ] 668 | }, 669 | { 670 | "cell_type": "code", 671 | "execution_count": null, 672 | "metadata": { 673 | "colab": { 674 | "background_save": true 675 | }, 676 | "id": "efS4pCmAJXJH" 677 | }, 678 | "outputs": [], 679 | "source": [ 680 | "import os\n", 681 | "import json\n", 682 | "import torch\n", 683 | "from pathlib import Path\n", 684 | "from torch.utils.data import Dataset\n", 685 | "\n", 686 | "\n", 687 | "class InferenceDataset(Dataset):\n", 688 | " def __init__(self, data_dir):\n", 689 | " testdata_path = Path(data_dir) / \"testdata.json\"\n", 690 | " metadata = json.load(testdata_path.open())\n", 691 | " self.data_dir = data_dir\n", 692 | " self.data = metadata[\"utterances\"]\n", 693 | "\n", 694 | " def __len__(self):\n", 695 | " return len(self.data)\n", 696 | "\n", 697 | " def __getitem__(self, index):\n", 698 | " utterance = self.data[index]\n", 699 | " feat_path = utterance[\"feature_path\"]\n", 700 | " mel = torch.load(os.path.join(self.data_dir, feat_path))\n", 701 | "\n", 702 | " return feat_path, mel\n", 703 | "\n", 704 | "\n", 705 | "def inference_collate_batch(batch):\n", 706 | " \"\"\"Collate a batch of data.\"\"\"\n", 707 | " feat_paths, mels = zip(*batch)\n", 708 | "\n", 709 | " return feat_paths, torch.stack(mels)" 710 | ] 711 | }, 712 | { 713 | "cell_type": "markdown", 714 | "metadata": { 715 | "id": "tl0WnYwxNK_S" 716 | }, 717 | "source": [ 718 | "## Main funcrion of Inference" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": null, 724 | "metadata": { 725 | "colab": { 726 | "background_save": true 727 | }, 728 | "id": "i8SAbuXEJb2A" 729 | }, 730 | "outputs": [], 731 | "source": [ 732 | "import json\n", 733 | "import csv\n", 734 | "from pathlib import Path\n", 735 | "from tqdm.notebook import tqdm\n", 736 | "\n", 737 | "import torch\n", 738 | "from torch.utils.data import 
DataLoader\n", 739 | "\n", 740 | "def parse_args():\n", 741 | " \"\"\"arguments\"\"\"\n", 742 | " config = {\n", 743 | " \"data_dir\": \"./Dataset\",\n", 744 | " \"model_path\": f\"./models/{NAME}.ckpt\",\n", 745 | " \"output_path\": f\"./outputs/{NAME}.csv\",\n", 746 | " }\n", 747 | "\n", 748 | " return config\n", 749 | "\n", 750 | "\n", 751 | "def main(\n", 752 | " data_dir,\n", 753 | " model_path,\n", 754 | " output_path,\n", 755 | "):\n", 756 | " \"\"\"Main function.\"\"\"\n", 757 | " device = torch.device(\"cuda\", 1)\n", 758 | " print(f\"[Info]: Use {device} now!\")\n", 759 | "\n", 760 | " mapping_path = Path(data_dir) / \"mapping.json\"\n", 761 | " mapping = json.load(mapping_path.open())\n", 762 | "\n", 763 | " dataset = InferenceDataset(data_dir)\n", 764 | " dataloader = DataLoader(\n", 765 | " dataset,\n", 766 | " batch_size=1,\n", 767 | " shuffle=False,\n", 768 | " drop_last=False,\n", 769 | " num_workers=8,\n", 770 | " collate_fn=inference_collate_batch,\n", 771 | " )\n", 772 | " print(f\"[Info]: Finish loading data!\",flush = True)\n", 773 | "\n", 774 | " speaker_num = len(mapping[\"id2speaker\"])\n", 775 | " model = Classifier(n_spks=speaker_num).to(device)\n", 776 | " model.load_state_dict(torch.load(model_path))\n", 777 | " model.eval()\n", 778 | " print(f\"[Info]: Finish creating model!\",flush = True)\n", 779 | "\n", 780 | " results = [[\"Id\", \"Category\"]]\n", 781 | " for feat_paths, mels in tqdm(dataloader):\n", 782 | " with torch.no_grad():\n", 783 | " mels = mels.to(device)\n", 784 | " outs = model(mels)\n", 785 | " preds = outs.argmax(1).cpu().numpy()\n", 786 | " for feat_path, pred in zip(feat_paths, preds):\n", 787 | " results.append([feat_path, mapping[\"id2speaker\"][str(pred)]])\n", 788 | "\n", 789 | " with open(output_path, 'w', newline='') as csvfile:\n", 790 | " writer = csv.writer(csvfile)\n", 791 | " writer.writerows(results)\n", 792 | "\n", 793 | "\n", 794 | "if __name__ == \"__main__\":\n", 795 | " main(**parse_args())" 796 | ] 797 | } 798 | ], 799 | "metadata": { 800 | "accelerator": "GPU", 801 | "colab": { 802 | "collapsed_sections": [], 803 | "name": "hw04.ipynb", 804 | "provenance": [] 805 | }, 806 | "kernelspec": { 807 | "display_name": "kuokuo_env", 808 | "language": "python", 809 | "name": "kuokuo_env" 810 | }, 811 | "language_info": { 812 | "codemirror_mode": { 813 | "name": "ipython", 814 | "version": 3 815 | }, 816 | "file_extension": ".py", 817 | "mimetype": "text/x-python", 818 | "name": "python", 819 | "nbconvert_exporter": "python", 820 | "pygments_lexer": "ipython3", 821 | "version": "3.7.5" 822 | }, 823 | "toc": { 824 | "base_numbering": 1, 825 | "nav_menu": {}, 826 | "number_sections": true, 827 | "sideBar": true, 828 | "skip_h1_title": false, 829 | "title_cell": "Table of Contents", 830 | "title_sidebar": "Contents", 831 | "toc_cell": false, 832 | "toc_position": {}, 833 | "toc_section_display": true, 834 | "toc_window_display": false 835 | } 836 | }, 837 | "nbformat": 4, 838 | "nbformat_minor": 1 839 | } 840 | -------------------------------------------------------------------------------- /hw-03/hw3_strong.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "jRDuJsGCgxCO" 7 | }, 8 | "source": [ 9 | "# HW3 Image Classification\n", 10 | "## We strongly recommend that you run with Kaggle for this homework\n", 11 | "https://www.kaggle.com/c/ml2022spring-hw3b/code?competitionId=34954&sortBy=dateCreated" 12 | ] 13 | }, 14 
| { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "id": "EVgrPb3HhJUT" 18 | }, 19 | "source": [ 20 | "# Get Data\n", 21 | "Notes: if the links are dead, you can download the data directly from Kaggle and upload it to the workspace, or you can use the Kaggle API to directly download the data into colab.\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "! nvidia-smi" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": { 37 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 38 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", 39 | "id": "EAO6dg9eVaU_", 40 | "papermill": { 41 | "duration": 19.351342, 42 | "end_time": "2022-02-23T10:03:06.247288", 43 | "exception": false, 44 | "start_time": "2022-02-23T10:02:46.895946", 45 | "status": "completed" 46 | }, 47 | "tags": [] 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "# ! wget https://www.dropbox.com/s/6l2vcvxl54b0b6w/food11.zip\n", 52 | "# Size: 1.08 GB\n", 53 | "! wget -O food11.zip \"https://github.com/virginiakm1988/ML2022-Spring/blob/main/HW03/food11.zip?raw=true\"" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "id": "HEsBm1lkhGmk" 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "! unzip food11.zip" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": { 70 | "id": "n5ceUnRihL-f" 71 | }, 72 | "source": [ 73 | "# Training" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "id": "ay3WkYnHVaVE", 81 | "papermill": { 82 | "duration": 0.0189, 83 | "end_time": "2022-02-23T10:03:06.279758", 84 | "exception": false, 85 | "start_time": "2022-02-23T10:03:06.260858", 86 | "status": "completed" 87 | }, 88 | "tags": [] 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "_exp_name = \"sample\"" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": { 99 | "id": "CwOGtRWHVaVF", 100 | "papermill": { 101 | "duration": 1.654263, 102 | "end_time": "2022-02-23T10:03:07.947242", 103 | "exception": false, 104 | "start_time": "2022-02-23T10:03:06.292979", 105 | "status": "completed" 106 | }, 107 | "tags": [] 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "# Import necessary packages.\n", 112 | "import numpy as np\n", 113 | "import pandas as pd\n", 114 | "import torch\n", 115 | "import os\n", 116 | "import torch.nn as nn\n", 117 | "import torchvision.transforms as transforms\n", 118 | "from PIL import Image\n", 119 | "# \"ConcatDataset\" and \"Subset\" are possibly useful when doing semi-supervised learning.\n", 120 | "from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset\n", 121 | "from torchvision.datasets import DatasetFolder, VisionDataset\n", 122 | "\n", 123 | "# This is for the progress bar.\n", 124 | "from tqdm.auto import tqdm\n", 125 | "import random" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "id": "8kJm9GekVaVH", 133 | "papermill": { 134 | "duration": 0.078771, 135 | "end_time": "2022-02-23T10:03:08.039428", 136 | "exception": false, 137 | "start_time": "2022-02-23T10:03:07.960657", 138 | "status": "completed" 139 | }, 140 | "tags": [] 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "myseed = 42 # set a random seed for reproducibility\n", 145 | "torch.backends.cudnn.deterministic = True\n", 146 | "torch.backends.cudnn.benchmark = False\n", 
147 | "np.random.seed(myseed)\n", 148 | "torch.manual_seed(myseed)\n", 149 | "if torch.cuda.is_available():\n", 150 | " torch.cuda.manual_seed_all(myseed)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": { 156 | "id": "d9MVtgbSVaVH", 157 | "papermill": { 158 | "duration": 0.01289, 159 | "end_time": "2022-02-23T10:03:08.065357", 160 | "exception": false, 161 | "start_time": "2022-02-23T10:03:08.052467", 162 | "status": "completed" 163 | }, 164 | "tags": [] 165 | }, 166 | "source": [ 167 | "## **Transforms**\n", 168 | "Torchvision provides lots of useful utilities for image preprocessing, data wrapping as well as data augmentation.\n", 169 | "\n", 170 | "Please refer to PyTorch official website for details about different transforms." 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": { 177 | "id": "jvI3Xmq4VaVJ", 178 | "papermill": { 179 | "duration": 0.021406, 180 | "end_time": "2022-02-23T10:03:08.099437", 181 | "exception": false, 182 | "start_time": "2022-02-23T10:03:08.078031", 183 | "status": "completed" 184 | }, 185 | "tags": [] 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "# Normally, We don't need augmentations in testing and validation.\n", 190 | "# All we need here is to resize the PIL image and transform it into Tensor.\n", 191 | "test_tfm = transforms.Compose([\n", 192 | " transforms.Resize((128, 128)),\n", 193 | " transforms.ToTensor(),\n", 194 | "])\n", 195 | "\n", 196 | "# However, it is also possible to use augmentation in the testing phase.\n", 197 | "# You may use train_tfm to produce a variety of images and then test using ensemble methods\n", 198 | "train_tfm = transforms.Compose([\n", 199 | " # Resize the image into a fixed shape (height = width = 128)\n", 200 | " transforms.Resize((128, 128)),\n", 201 | " # You may add some transforms here.\n", 202 | " # ToTensor() should be the last one of the transforms.\n", 203 | " transforms.RandomHorizontalFlip(p=0.3),\n", 204 | " transforms.RandomVerticalFlip(p=0.3),\n", 205 | " transforms.RandomRotation(45),\n", 206 | " transforms.RandomGrayscale(p=0.1),\n", 207 | " transforms.ToTensor(),\n", 208 | "])\n" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": { 214 | "id": "D0ivMf-jVaVK", 215 | "papermill": { 216 | "duration": 0.012739, 217 | "end_time": "2022-02-23T10:03:08.125181", 218 | "exception": false, 219 | "start_time": "2022-02-23T10:03:08.112442", 220 | "status": "completed" 221 | }, 222 | "tags": [] 223 | }, 224 | "source": [ 225 | "## **Datasets**\n", 226 | "The data is labelled by the name, so we load images and label while calling '__getitem__'" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "id": "xBdtPhKwVaVL", 234 | "papermill": { 235 | "duration": 0.023022, 236 | "end_time": "2022-02-23T10:03:08.160912", 237 | "exception": false, 238 | "start_time": "2022-02-23T10:03:08.13789", 239 | "status": "completed" 240 | }, 241 | "tags": [] 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "class FoodDataset(Dataset):\n", 246 | "\n", 247 | " def __init__(self,path,tfm=test_tfm,files = None):\n", 248 | " super(FoodDataset).__init__()\n", 249 | " self.path = path\n", 250 | " self.files = sorted([os.path.join(path,x) for x in os.listdir(path) if x.endswith(\".jpg\")])\n", 251 | " if files != None:\n", 252 | " self.files = files\n", 253 | " print(f\"One {path} sample\",self.files[0])\n", 254 | " self.transform = tfm\n", 255 | " \n", 256 | " def 
__len__(self):\n", 257 | " return len(self.files)\n", 258 | " \n", 259 | " def __getitem__(self,idx):\n", 260 | " fname = self.files[idx]\n", 261 | " im = Image.open(fname)\n", 262 | " im = self.transform(im)\n", 263 | " #im = self.data[idx]\n", 264 | " try:\n", 265 | " label = int(fname.split(\"/\")[-1].split(\"_\")[0])\n", 266 | " except:\n", 267 | " label = -1 # test has no label\n", 268 | " return im,label\n", 269 | "\n" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "class Residual_Block(nn.Module):\n", 279 | " def __init__(self, i_channel, o_channel, stride=1, down_sample=None):\n", 280 | " super(Residual_Block, self).__init__()\n", 281 | " self.conv1 = nn.Conv2d(in_channels=i_channel, \n", 282 | " out_channels=o_channel, \n", 283 | " kernel_size=3, \n", 284 | " stride=stride, \n", 285 | " padding=1,\n", 286 | " bias=False)\n", 287 | " self.bn1 = nn.BatchNorm2d(o_channel)\n", 288 | " self.relu = nn.ReLU(inplace=True)\n", 289 | " self.conv2 = nn.Conv2d(in_channels=o_channel, \n", 290 | " out_channels=o_channel, \n", 291 | " kernel_size=3, \n", 292 | " stride=1, \n", 293 | " padding=1,\n", 294 | " bias=False)\n", 295 | " self.bn2 = nn.BatchNorm2d(o_channel)\n", 296 | " self.down_sample = down_sample\n", 297 | "\n", 298 | " def forward(self, x):\n", 299 | " residual = x\n", 300 | " out = self.conv1(x)\n", 301 | " out = self.bn1(out)\n", 302 | " out = self.relu(out)\n", 303 | " out = self.conv2(out)\n", 304 | " out = self.bn2(out)\n", 305 | " \n", 306 | " if self.down_sample:\n", 307 | " residual = self.down_sample(x)\n", 308 | " out += residual\n", 309 | " out = self.relu(out)\n", 310 | "\n", 311 | " return out" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "class ResNet(nn.Module):\n", 321 | " def __init__(self, block, layers, num_classes=11):\n", 322 | " super(ResNet, self).__init__()\n", 323 | " self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False)\n", 324 | " self.in_channels = 16\n", 325 | " self.bn = nn.BatchNorm2d(16)\n", 326 | " self.relu1 = nn.ReLU(inplace=True)\n", 327 | " self.layer1 = self.make_layer(block, 16, layers[0])\n", 328 | " self.layer2 = self.make_layer(block, 32, layers[0], 2)\n", 329 | " self.layer3 = self.make_layer(block, 64, layers[1], 2)\n", 330 | " self.avg_pool = nn.AvgPool2d(8)\n", 331 | " self.fc1 = nn.Linear(1024, 256)\n", 332 | " self.relu2 = nn.ReLU(inplace=True)\n", 333 | " self.dropout = nn.Dropout(p=0.2)\n", 334 | " self.fc2 = nn.Linear(256, num_classes)\n", 335 | " \n", 336 | " def make_layer(self, block, out_channels, blocks, stride=1): \n", 337 | " down_sample = None\n", 338 | " if (stride != 1) or (self.in_channels != out_channels):\n", 339 | " down_sample = nn.Sequential(\n", 340 | " nn.Conv2d(self.in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),\n", 341 | " nn.BatchNorm2d(out_channels)\n", 342 | " )\n", 343 | "\n", 344 | " layers = []\n", 345 | " layers.append(block(self.in_channels, out_channels, stride, down_sample))\n", 346 | " self.in_channels = out_channels\n", 347 | " for i in range(1, blocks):\n", 348 | " layers.append(block(out_channels, out_channels))\n", 349 | " return nn.Sequential(*layers)\n", 350 | "\n", 351 | " def forward(self, x):\n", 352 | " out = self.conv(x)\n", 353 | " out = self.bn(out)\n", 354 | " out = self.relu1(out)\n", 355 | " out = 
self.layer1(out)\n", 356 | " out = self.layer2(out)\n", 357 | " out = self.layer3(out)\n", 358 | " out = self.avg_pool(out)\n", 359 | " out = out.view(out.size()[0], -1)\n", 360 | " out = self.fc1(out)\n", 361 | " out = self.relu2(out)\n", 362 | " out = self.dropout(out)\n", 363 | " out = self.fc2(out)\n", 364 | " return out" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": { 371 | "id": "2_OeWtstVaVO", 372 | "papermill": { 373 | "duration": 0.054295, 374 | "end_time": "2022-02-23T10:03:08.266338", 375 | "exception": false, 376 | "start_time": "2022-02-23T10:03:08.212043", 377 | "status": "completed" 378 | }, 379 | "tags": [] 380 | }, 381 | "outputs": [], 382 | "source": [ 383 | "batch_size = 128\n", 384 | "_dataset_dir = \"./food11\"\n", 385 | "# Construct datasets.\n", 386 | "# The argument \"loader\" tells how torchvision reads the data.\n", 387 | "train_set = FoodDataset(os.path.join(_dataset_dir,\"training\"), tfm=train_tfm)\n", 388 | "train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)\n", 389 | "valid_set = FoodDataset(os.path.join(_dataset_dir,\"validation\"), tfm=test_tfm)\n", 390 | "valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": { 397 | "id": "zbVkfIFhVaVO", 398 | "papermill": { 399 | "duration": 32830.720158, 400 | "end_time": "2022-02-23T19:10:19.001001", 401 | "exception": false, 402 | "start_time": "2022-02-23T10:03:08.280843", 403 | "status": "completed" 404 | }, 405 | "tags": [] 406 | }, 407 | "outputs": [], 408 | "source": [ 409 | "# \"cuda\" only when GPUs are available.\n", 410 | "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", 411 | "\n", 412 | "# The number of training epochs and patience.\n", 413 | "n_epochs = 600\n", 414 | "patience = 600 # If no improvement in 'patience' epochs, early stop\n", 415 | "\n", 416 | "# Initialize a model, and put it on the device specified.\n", 417 | "model = ResNet(Residual_Block, [2, 2, 2, 2]).to(device)\n", 418 | "\n", 419 | "# For the classification task, we use cross-entropy as the measurement of performance.\n", 420 | "criterion = nn.CrossEntropyLoss()\n", 421 | "\n", 422 | "# Initialize optimizer, you may fine-tune some hyperparameters such as learning rate on your own.\n", 423 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.999]) \n", 424 | "\n", 425 | "# Initialize trackers, these are not parameters and should not be changed\n", 426 | "stale = 0\n", 427 | "best_acc = 0\n", 428 | "\n", 429 | "for epoch in range(n_epochs):\n", 430 | "\n", 431 | " # ---------- Training ----------\n", 432 | " # Make sure the model is in train mode before training.\n", 433 | " model.train()\n", 434 | "\n", 435 | " # These are used to record information in training.\n", 436 | " train_loss = []\n", 437 | " train_accs = []\n", 438 | "\n", 439 | " for batch in tqdm(train_loader):\n", 440 | "\n", 441 | " # A batch consists of image data and corresponding labels.\n", 442 | " imgs, labels = batch\n", 443 | " #imgs = imgs.half()\n", 444 | " #print(imgs.shape,labels.shape)\n", 445 | "\n", 446 | " # Forward the data. 
(Make sure data and model are on the same device.)\n", 447 | " logits = model(imgs.to(device))\n", 448 | "\n", 449 | " # Calculate the cross-entropy loss.\n", 450 | " # We don't need to apply softmax before computing cross-entropy as it is done automatically.\n", 451 | " loss = criterion(logits, labels.to(device))\n", 452 | "\n", 453 | " # Gradients stored in the parameters in the previous step should be cleared out first.\n", 454 | " optimizer.zero_grad()\n", 455 | "\n", 456 | " # Compute the gradients for parameters.\n", 457 | " loss.backward()\n", 458 | "\n", 459 | " # Clip the gradient norms for stable training.\n", 460 | " grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)\n", 461 | "\n", 462 | " # Update the parameters with computed gradients.\n", 463 | " optimizer.step()\n", 464 | "\n", 465 | " # Compute the accuracy for current batch.\n", 466 | " acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()\n", 467 | "\n", 468 | " # Record the loss and accuracy.\n", 469 | " train_loss.append(loss.item())\n", 470 | " train_accs.append(acc)\n", 471 | " \n", 472 | " train_loss = sum(train_loss) / len(train_loss)\n", 473 | " train_acc = sum(train_accs) / len(train_accs)\n", 474 | "\n", 475 | " # Print the information.\n", 476 | " print(f\"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}\")\n", 477 | "\n", 478 | " # ---------- Validation ----------\n", 479 | " # Make sure the model is in eval mode so that some modules like dropout are disabled and work normally.\n", 480 | " model.eval()\n", 481 | "\n", 482 | " # These are used to record information in validation.\n", 483 | " valid_loss = []\n", 484 | " valid_accs = []\n", 485 | "\n", 486 | " # Iterate the validation set by batches.\n", 487 | " for batch in tqdm(valid_loader):\n", 488 | "\n", 489 | " # A batch consists of image data and corresponding labels.\n", 490 | " imgs, labels = batch\n", 491 | " #imgs = imgs.half()\n", 492 | "\n", 493 | " # We don't need gradient in validation.\n", 494 | " # Using torch.no_grad() accelerates the forward process.\n", 495 | " with torch.no_grad():\n", 496 | " logits = model(imgs.to(device))\n", 497 | "\n", 498 | " # We can still compute the loss (but not the gradient).\n", 499 | " loss = criterion(logits, labels.to(device))\n", 500 | "\n", 501 | " # Compute the accuracy for current batch.\n", 502 | " acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()\n", 503 | "\n", 504 | " # Record the loss and accuracy.\n", 505 | " valid_loss.append(loss.item())\n", 506 | " valid_accs.append(acc)\n", 507 | " #break\n", 508 | "\n", 509 | " # The average loss and accuracy for entire validation set is the average of the recorded values.\n", 510 | " valid_loss = sum(valid_loss) / len(valid_loss)\n", 511 | " valid_acc = sum(valid_accs) / len(valid_accs)\n", 512 | "\n", 513 | " # Print the information.\n", 514 | " print(f\"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}\")\n", 515 | "\n", 516 | "\n", 517 | " # update logs\n", 518 | " if valid_acc > best_acc:\n", 519 | " with open(f\"./{_exp_name}_log.txt\",\"a\"):\n", 520 | " print(f\"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f} -> best\")\n", 521 | " else:\n", 522 | " with open(f\"./{_exp_name}_log.txt\",\"a\"):\n", 523 | " print(f\"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}\")\n", 524 | "\n", 525 | "\n", 526 | " # save models\n", 527 | " if valid_acc > 
best_acc:\n", 528 | " print(f\"Best model found at epoch {epoch}, saving model\")\n", 529 | " torch.save(model.state_dict(), f\"{_exp_name}_best.ckpt\") # only save best to prevent output memory exceed error\n", 530 | " best_acc = valid_acc\n", 531 | " stale = 0\n", 532 | " else:\n", 533 | " stale += 1\n", 534 | " if stale > patience:\n", 535 | " print(f\"No improvment {patience} consecutive epochs, early stopping\")\n", 536 | " break" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "metadata": { 542 | "id": "G31uyjpvVaVP", 543 | "papermill": { 544 | "duration": 0.498773, 545 | "end_time": "2022-02-23T19:10:20.961802", 546 | "exception": false, 547 | "start_time": "2022-02-23T19:10:20.463029", 548 | "status": "completed" 549 | }, 550 | "tags": [] 551 | }, 552 | "source": [ 553 | "# Testing and generate prediction CSV" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": null, 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [ 562 | "class FoodDataset(Dataset):\n", 563 | "\n", 564 | " def __init__(self, path, mode=\"train\", files = None):\n", 565 | " super(FoodDataset).__init__()\n", 566 | " self.path = path\n", 567 | " self.files = sorted([os.path.join(path,x) for x in os.listdir(path) if x.endswith(\".jpg\")])\n", 568 | " if files != None:\n", 569 | " self.files = files\n", 570 | " print(f\"One {path} sample\",self.files[0])\n", 571 | " self.mode = mode\n", 572 | " \n", 573 | " def __len__(self):\n", 574 | " return len(self.files)\n", 575 | " \n", 576 | " def __getitem__(self,idx):\n", 577 | " fname = self.files[idx]\n", 578 | " im = Image.open(fname)\n", 579 | " if self.mode == \"training\":\n", 580 | " im = train_tfm(im)\n", 581 | " label = int(fname.split(\"/\")[-1].split(\"_\")[0])\n", 582 | " return im, label\n", 583 | " elif self.mode == \"validation\":\n", 584 | " im = test_tfm(im)\n", 585 | " label = -1 # test has no label\n", 586 | " return im, label\n", 587 | " else:\n", 588 | " ims = [test_tfm(im)]\n", 589 | " ims += [train_tfm(im) for _ in range(10)]\n", 590 | " label = -1 # test has no label\n", 591 | " return torch.stack(ims), label" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": null, 597 | "metadata": { 598 | "id": "B9QNdHIXVaVP", 599 | "papermill": { 600 | "duration": 0.493644, 601 | "end_time": "2022-02-23T19:10:19.985992", 602 | "exception": false, 603 | "start_time": "2022-02-23T19:10:19.492348", 604 | "status": "completed" 605 | }, 606 | "tags": [] 607 | }, 608 | "outputs": [], 609 | "source": [ 610 | "test_set = FoodDataset(os.path.join(_dataset_dir,\"test\"), mode=\"test\")\n", 611 | "test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "metadata": { 618 | "id": "bpLtxx5FVaVP", 619 | "papermill": { 620 | "duration": 49.157727, 621 | "end_time": "2022-02-23T19:11:10.61523", 622 | "exception": false, 623 | "start_time": "2022-02-23T19:10:21.457503", 624 | "status": "completed" 625 | }, 626 | "tags": [] 627 | }, 628 | "outputs": [], 629 | "source": [ 630 | "model_best = ResNet(Residual_Block, [2, 2, 2, 2]).to(device)\n", 631 | "model_best.load_state_dict(torch.load(f\"{_exp_name}_best.ckpt\"))\n", 632 | "model_best.eval()\n", 633 | "prediction = []\n", 634 | "with torch.no_grad():\n", 635 | " for img_list, _ in test_loader:\n", 636 | " # TTA\n", 637 | " test_pred = []\n", 638 | " for imgs in img_list:\n", 639 | " imgs_first = imgs[0].unsqueeze(0)\n", 
640 | " origin_logit = model_best(imgs_first.to(device)).squeeze(0)\n", 641 | " tta_logit = model_best(imgs[1:].to(device))\n", 642 | " tta_logit = torch.mean(tta_logit, 0)\n", 643 | " logit = (0.6*origin_logit) + (0.4*tta_logit)\n", 644 | " test_pred.append(logit)\n", 645 | " test_pred = torch.stack(test_pred)\n", 646 | " \n", 647 | " test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)\n", 648 | " prediction += test_label.squeeze().tolist()" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": { 655 | "id": "fKupB3VUVaVQ", 656 | "papermill": { 657 | "duration": 0.554276, 658 | "end_time": "2022-02-23T19:11:11.870035", 659 | "exception": false, 660 | "start_time": "2022-02-23T19:11:11.315759", 661 | "status": "completed" 662 | }, 663 | "tags": [] 664 | }, 665 | "outputs": [], 666 | "source": [ 667 | "#create test csv\n", 668 | "def pad4(i):\n", 669 | " return \"0\"*(4-len(str(i)))+str(i)\n", 670 | "df = pd.DataFrame()\n", 671 | "df[\"Id\"] = [pad4(i) for i in range(1,len(test_set)+1)]\n", 672 | "df[\"Category\"] = prediction\n", 673 | "df.to_csv(\"submission.csv\",index = False)" 674 | ] 675 | }, 676 | { 677 | "cell_type": "markdown", 678 | "metadata": { 679 | "id": "Ivk0hrE-V8Cu" 680 | }, 681 | "source": [ 682 | "# Q1. Augmentation Implementation\n", 683 | "## Implement augmentation by finishing train_tfm in the code with image size of your choice. \n", 684 | "## Directly copy the following block and paste it on GradeScope after you finish the code\n", 685 | "### Your train_tfm must be capable of producing 5+ different results when given an identical image multiple times.\n", 686 | "### Your train_tfm in the report can be different from train_tfm in your training code.\n" 687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": null, 692 | "metadata": { 693 | "id": "GSfKNo42WjKm" 694 | }, 695 | "outputs": [], 696 | "source": [ 697 | "# train_tfm = transforms.Compose([\n", 698 | "# transforms.Resize((224, 224)),\n", 699 | "# transforms.RandomHorizontalFlip(p=0.5),\n", 700 | "# transforms.RandomVerticalFlip(p=0.5),\n", 701 | "# transforms.RandomGrayscale(p=0.1),\n", 702 | "# transforms.RandomApply(transforms=[transforms.RandomRotation(degrees=(0, 180))], p=0.5),\n", 703 | "# transforms.RandomApply(transforms=[transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))], p=0.3),\n", 704 | "# transforms.RandomApply(transforms=[transforms.RandomPosterize(bits=2)], p=0.2),\n", 705 | "# transforms.RandomApply(transforms=[transforms.RandomAdjustSharpness(sharpness_factor=2)], p=0.3),\n", 706 | "# transforms.RandomApply(transforms=[transforms.RandomAutocontrast()], p=0.3),\n", 707 | "# transforms.RandomApply(transforms=[transforms.TrivialAugmentWide()], p=0.1),\n", 708 | "# transforms.ToTensor(),\n", 709 | "# ])" 710 | ] 711 | }, 712 | { 713 | "cell_type": "markdown", 714 | "metadata": { 715 | "id": "3HemRgZ6WwRM" 716 | }, 717 | "source": [ 718 | "# Q2. 
Residual Implementation\n", 719 | "![](https://i.imgur.com/GYsq1Ap.png)\n", 720 | "## Directly copy the following block and paste it on GradeScope after you finish the code\n" 721 | ] 722 | }, 723 | { 724 | "cell_type": "code", 725 | "execution_count": null, 726 | "metadata": { 727 | "id": "Q4OK9kRaWuiV" 728 | }, 729 | "outputs": [], 730 | "source": [ 731 | "# from torch import nn\n", 732 | "# class Residual_Network(nn.Module):\n", 733 | "# def __init__(self):\n", 734 | "# super(Residual_Network, self).__init__()\n", 735 | " \n", 736 | "# self.cnn_layer1 = nn.Sequential(\n", 737 | "# nn.Conv2d(3, 64, 3, 1, 1),\n", 738 | "# nn.BatchNorm2d(64),\n", 739 | "# )\n", 740 | "\n", 741 | "# self.cnn_layer2 = nn.Sequential(\n", 742 | "# nn.Conv2d(64, 64, 3, 1, 1),\n", 743 | "# nn.BatchNorm2d(64),\n", 744 | "# )\n", 745 | "\n", 746 | "# self.cnn_layer3 = nn.Sequential(\n", 747 | "# nn.Conv2d(64, 128, 3, 2, 1),\n", 748 | "# nn.BatchNorm2d(128),\n", 749 | "# )\n", 750 | "\n", 751 | "# self.cnn_layer4 = nn.Sequential(\n", 752 | "# nn.Conv2d(128, 128, 3, 1, 1),\n", 753 | "# nn.BatchNorm2d(128),\n", 754 | "# )\n", 755 | "# self.cnn_layer5 = nn.Sequential(\n", 756 | "# nn.Conv2d(128, 256, 3, 2, 1),\n", 757 | "# nn.BatchNorm2d(256),\n", 758 | "# )\n", 759 | "# self.cnn_layer6 = nn.Sequential(\n", 760 | "# nn.Conv2d(256, 256, 3, 1, 1),\n", 761 | "# nn.BatchNorm2d(256),\n", 762 | "# )\n", 763 | "# self.fc_layer = nn.Sequential(\n", 764 | "# nn.Linear(256* 32* 32, 256),\n", 765 | "# nn.ReLU(),\n", 766 | "# nn.Linear(256, 11)\n", 767 | "# )\n", 768 | "# self.relu = nn.ReLU()\n", 769 | "\n", 770 | "# def forward(self, x):\n", 771 | "# # input (x): [batch_size, 3, 128, 128]\n", 772 | "# # output: [batch_size, 11]\n", 773 | "\n", 774 | "# # Extract features by convolutional layers.\n", 775 | "# x1 = self.cnn_layer1(x)\n", 776 | " \n", 777 | "# x1 = self.relu(x1)\n", 778 | " \n", 779 | "# x2 = self.cnn_layer2(x1)\n", 780 | " \n", 781 | "# x2 = self.relu(x2+x1)\n", 782 | " \n", 783 | "# x3 = self.cnn_layer3(x2)\n", 784 | " \n", 785 | "# x3 = self.relu(x3)\n", 786 | " \n", 787 | "# x4 = self.cnn_layer4(x3)\n", 788 | " \n", 789 | "# x4 = self.relu(x4+x3)\n", 790 | " \n", 791 | "# x5 = self.cnn_layer5(x4)\n", 792 | " \n", 793 | "# x5 = self.relu(x5)\n", 794 | " \n", 795 | "# x6 = self.cnn_layer6(x5)\n", 796 | " \n", 797 | "# x6 = self.relu(x6+x5)\n", 798 | " \n", 799 | "# # The extracted feature map must be flatten before going to fully-connected layers.\n", 800 | "# xout = x6.flatten(1)\n", 801 | "\n", 802 | "# # The features are transformed by fully-connected layers to obtain the final logits.\n", 803 | "# xout = self.fc_layer(xout)\n", 804 | "# return xout" 805 | ] 806 | } 807 | ], 808 | "metadata": { 809 | "colab": { 810 | "collapsed_sections": [], 811 | "name": "2022ML HW3 Image Classification", 812 | "provenance": [] 813 | }, 814 | "kernelspec": { 815 | "display_name": "kuokuo_env", 816 | "language": "python", 817 | "name": "kuokuo_env" 818 | }, 819 | "language_info": { 820 | "codemirror_mode": { 821 | "name": "ipython", 822 | "version": 3 823 | }, 824 | "file_extension": ".py", 825 | "mimetype": "text/x-python", 826 | "name": "python", 827 | "nbconvert_exporter": "python", 828 | "pygments_lexer": "ipython3", 829 | "version": "3.7.5" 830 | }, 831 | "toc": { 832 | "base_numbering": 1, 833 | "nav_menu": {}, 834 | "number_sections": true, 835 | "sideBar": true, 836 | "skip_h1_title": false, 837 | "title_cell": "Table of Contents", 838 | "title_sidebar": "Contents", 839 | "toc_cell": false, 840 | "toc_position": {}, 
841 | "toc_section_display": true, 842 | "toc_window_display": false 843 | } 844 | }, 845 | "nbformat": 4, 846 | "nbformat_minor": 1 847 | } 848 | -------------------------------------------------------------------------------- /hw-06/hw6_WGAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "Iv6bjjqyGmqV" 7 | }, 8 | "source": [ 9 | "# Homework 6 - Generative Adversarial Network\n", 10 | "This is the sample code for hw6 of 2022 Machine Learning course in National Taiwan University. \n", 11 | "\n", 12 | "In this sample code, there are 5 sections:\n", 13 | "1. Environment setting\n", 14 | "2. Dataset preparation\n", 15 | "3. Model setting\n", 16 | "4. Train\n", 17 | "5. Inference\n", 18 | "\n", 19 | "Your goal is to do anime face generation, if you have any question, please discuss at NTU COOL " 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "xnp-5lUFLak7" 26 | }, 27 | "source": [ 28 | "# Environment setting\n", 29 | "In this section, we will prepare for the dataset and set some environment variable" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "id": "_qhoMUt9LniJ" 36 | }, 37 | "source": [ 38 | "## Download Dataset" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "!nvidia-smi" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "colab": { 55 | "base_uri": "https://localhost:8080/" 56 | }, 57 | "id": "AaJRTJEFLrND", 58 | "outputId": "0290ac41-c7e2-45cd-8a32-b88e847be85c" 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "# get dataset from huggingface hub\n", 63 | "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash\n", 64 | "!apt-get install git-lfs\n", 65 | "!git lfs install\n", 66 | "!git clone https://huggingface.co/datasets/LeoFeng/MLHW_6\n", 67 | "!unzip ./faces.zip -d ." 
68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": { 73 | "id": "lBkkAB9sO3R4" 74 | }, 75 | "source": [ 76 | "## Other setting" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "id": "Qxf1TXTLO6Ek" 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "# import module\n", 88 | "import os\n", 89 | "import glob\n", 90 | "import random\n", 91 | "from datetime import datetime\n", 92 | "\n", 93 | "import torch\n", 94 | "import torch.nn as nn\n", 95 | "import torch.nn.functional as F\n", 96 | "import torchvision\n", 97 | "import torchvision.transforms as transforms\n", 98 | "from torch import optim\n", 99 | "from torch.autograd import Variable\n", 100 | "from torch.utils.data import Dataset, DataLoader\n", 101 | "\n", 102 | "import matplotlib.pyplot as plt\n", 103 | "import numpy as np\n", 104 | "import logging\n", 105 | "from tqdm import tqdm\n", 106 | "\n", 107 | "\n", 108 | "# seed setting\n", 109 | "def same_seeds(seed):\n", 110 | " # Python built-in random module\n", 111 | " random.seed(seed)\n", 112 | " # Numpy\n", 113 | " np.random.seed(seed)\n", 114 | " # Torch\n", 115 | " torch.manual_seed(seed)\n", 116 | " if torch.cuda.is_available():\n", 117 | " torch.cuda.manual_seed(seed)\n", 118 | " torch.cuda.manual_seed_all(seed)\n", 119 | " torch.backends.cudnn.benchmark = False\n", 120 | " torch.backends.cudnn.deterministic = True\n", 121 | "\n", 122 | "same_seeds(2022)\n", 123 | "workspace_dir = '.'" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "device = torch.device(\"cuda\", 0)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": { 138 | "id": "eg2qsevzOeQT" 139 | }, 140 | "source": [ 141 | "# Dataset preparation\n", 142 | "In this section, we prepare for the dataset for Pytorch" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": { 148 | "id": "UT6s1x92OudB" 149 | }, 150 | "source": [ 151 | "## Create dataset for Pytorch\n", 152 | "\n", 153 | "In order to unified image information, we use the transform function to:\n", 154 | "1. Resize image to 64x64\n", 155 | "2. 
Normalize the image\n", 156 | "\n", 157 | "This CrypkoDataset class will be use in Section 4" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": { 164 | "id": "9MsHqaglOywi" 165 | }, 166 | "outputs": [], 167 | "source": [ 168 | "# prepare for CrypkoDataset\n", 169 | "\n", 170 | "class CrypkoDataset(Dataset):\n", 171 | " def __init__(self, fnames, transform):\n", 172 | " self.transform = transform\n", 173 | " self.fnames = fnames\n", 174 | " self.num_samples = len(self.fnames)\n", 175 | "\n", 176 | " def __getitem__(self,idx):\n", 177 | " fname = self.fnames[idx]\n", 178 | " img = torchvision.io.read_image(fname)\n", 179 | " img = self.transform(img)\n", 180 | " return img\n", 181 | "\n", 182 | " def __len__(self):\n", 183 | " return self.num_samples\n", 184 | "\n", 185 | "def get_dataset(root):\n", 186 | " fnames = glob.glob(os.path.join(root, '*'))\n", 187 | " compose = [\n", 188 | " transforms.ToPILImage(),\n", 189 | " transforms.Resize((64, 64)),\n", 190 | " transforms.ToTensor(),\n", 191 | " transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\n", 192 | " ]\n", 193 | " transform = transforms.Compose(compose)\n", 194 | " dataset = CrypkoDataset(fnames, transform)\n", 195 | " return dataset" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": { 201 | "id": "BPMZTwAiQSnx" 202 | }, 203 | "source": [ 204 | "## Show the image\n", 205 | "Show some sample in the dataset" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": { 212 | "colab": { 213 | "base_uri": "https://localhost:8080/", 214 | "height": 211 215 | }, 216 | "id": "rX5-Q71TOyy4", 217 | "outputId": "664c5793-8d53-4533-82a7-4061ed144a0c" 218 | }, 219 | "outputs": [], 220 | "source": [ 221 | "temp_dataset = get_dataset(os.path.join(workspace_dir, 'faces'))\n", 222 | "\n", 223 | "images = [temp_dataset[i] for i in range(4)]\n", 224 | "grid_img = torchvision.utils.make_grid(images, nrow=4)\n", 225 | "plt.figure(figsize=(10,10))\n", 226 | "plt.imshow(grid_img.permute(1, 2, 0))\n", 227 | "plt.show()" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": { 233 | "id": "IgV-jpcfQwEM" 234 | }, 235 | "source": [ 236 | "# Model setting\n", 237 | "In this section, we will create models and trainer." 
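One detail worth noting about the transform above: `Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))` maps images from [0, 1] to [-1, 1], which matches the `Tanh` output of the Generator defined in the next section; that is also why the trainer later maps samples back with `(x + 1) / 2.0` before saving or plotting. A minimal round-trip sketch, using a random stand-in image:

```python
# Sketch: the [0, 1] <-> [-1, 1] mapping used by the dataset transform and the trainer.
import torch

img01 = torch.rand(3, 64, 64)        # stand-in for an image after ToTensor(), values in [0, 1]
img_norm = (img01 - 0.5) / 0.5       # what Normalize((0.5,)*3, (0.5,)*3) computes, values in [-1, 1]
img_back = (img_norm + 1) / 2.0      # inverse mapping used before save_image / imshow
assert torch.allclose(img01, img_back, atol=1e-6)
```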
238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": { 243 | "id": "EY4rAlw8RNhG" 244 | }, 245 | "source": [ 246 | "## Create model\n", 247 | "In this section, we will create models for Generator and Discriminator" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": { 254 | "id": "8dfregFtRVGo" 255 | }, 256 | "outputs": [], 257 | "source": [ 258 | "# Generator\n", 259 | "\n", 260 | "class Generator(nn.Module):\n", 261 | " \"\"\"\n", 262 | " Input shape: (batch, in_dim)\n", 263 | " Output shape: (batch, 3, 64, 64)\n", 264 | " \"\"\"\n", 265 | " def __init__(self, in_dim, feature_dim=64):\n", 266 | " super().__init__()\n", 267 | " \n", 268 | " #input: (batch, 100)\n", 269 | " self.l1 = nn.Sequential(\n", 270 | " nn.Linear(in_dim, feature_dim * 8 * 4 * 4, bias=False),\n", 271 | " nn.BatchNorm1d(feature_dim * 8 * 4 * 4),\n", 272 | " nn.ReLU()\n", 273 | " )\n", 274 | " self.l2 = nn.Sequential(\n", 275 | " self.dconv_bn_relu(feature_dim * 8, feature_dim * 4), #(batch, feature_dim * 16, 8, 8) \n", 276 | " self.dconv_bn_relu(feature_dim * 4, feature_dim * 2), #(batch, feature_dim * 16, 16, 16) \n", 277 | " self.dconv_bn_relu(feature_dim * 2, feature_dim), #(batch, feature_dim * 16, 32, 32) \n", 278 | " )\n", 279 | " self.l3 = nn.Sequential(\n", 280 | " nn.ConvTranspose2d(feature_dim, 3, kernel_size=5, stride=2,\n", 281 | " padding=2, output_padding=1, bias=False),\n", 282 | " nn.Tanh() \n", 283 | " )\n", 284 | " self.apply(weights_init)\n", 285 | " def dconv_bn_relu(self, in_dim, out_dim):\n", 286 | " return nn.Sequential(\n", 287 | " nn.ConvTranspose2d(in_dim, out_dim, kernel_size=5, stride=2,\n", 288 | " padding=2, output_padding=1, bias=False), #double height and width\n", 289 | " nn.BatchNorm2d(out_dim),\n", 290 | " nn.ReLU(True)\n", 291 | " )\n", 292 | " def forward(self, x):\n", 293 | " y = self.l1(x)\n", 294 | " y = y.view(y.size(0), -1, 4, 4)\n", 295 | " y = self.l2(y)\n", 296 | " y = self.l3(y)\n", 297 | " return y" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": { 304 | "id": "gbFa8bBZ3jf6" 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "# # Discriminator\n", 309 | "# class Discriminator(nn.Module):\n", 310 | "# \"\"\"\n", 311 | "# Input shape: (batch, 3, 64, 64)\n", 312 | "# Output shape: (batch)\n", 313 | "# \"\"\"\n", 314 | "# def __init__(self, in_dim, feature_dim=64):\n", 315 | "# super(Discriminator, self).__init__()\n", 316 | " \n", 317 | "# #input: (batch, 3, 64, 64)\n", 318 | "# \"\"\"\n", 319 | "# NOTE FOR SETTING DISCRIMINATOR:\n", 320 | "\n", 321 | "# Remove last sigmoid layer for WGAN\n", 322 | "# \"\"\"\n", 323 | "# self.l1 = nn.Sequential(\n", 324 | "# nn.Conv2d(in_dim, feature_dim, kernel_size=4, stride=2, padding=1), #(batch, 3, 32, 32)\n", 325 | "# nn.LeakyReLU(0.2),\n", 326 | "# self.conv_bn_lrelu(feature_dim, feature_dim * 2), #(batch, 3, 16, 16)\n", 327 | "# self.conv_bn_lrelu(feature_dim * 2, feature_dim * 4), #(batch, 3, 8, 8)\n", 328 | "# self.conv_bn_lrelu(feature_dim * 4, feature_dim * 8), #(batch, 3, 4, 4)\n", 329 | "# nn.Conv2d(feature_dim * 8, 1, kernel_size=4, stride=1, padding=0), \n", 330 | "# # nn.Sigmoid()\n", 331 | "# )\n", 332 | "# self.apply(weights_init)\n", 333 | "# def conv_bn_lrelu(self, in_dim, out_dim):\n", 334 | "# \"\"\"\n", 335 | "# NOTE FOR SETTING DISCRIMINATOR:\n", 336 | "\n", 337 | "# You can't use nn.Batchnorm for WGAN-GP\n", 338 | "# Use nn.InstanceNorm2d instead\n", 339 | "# \"\"\"\n", 340 | "\n", 341 | "# 
return nn.Sequential(\n", 342 | "# nn.Conv2d(in_dim, out_dim, 4, 2, 1),\n", 343 | "# nn.BatchNorm2d(out_dim),\n", 344 | "# nn.LeakyReLU(0.2),\n", 345 | "# )\n", 346 | "# def forward(self, x):\n", 347 | "# y = self.l1(x)\n", 348 | "# y = y.view(-1)\n", 349 | "# return y" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": null, 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [ 358 | "# Discriminator\n", 359 | "class Discriminator(nn.Module):\n", 360 | " \"\"\"\n", 361 | " Input shape: (batch, 3, 64, 64)\n", 362 | " Output shape: (batch)\n", 363 | " \"\"\"\n", 364 | " def __init__(self, in_dim, feature_dim=64):\n", 365 | " super(Discriminator, self).__init__()\n", 366 | " \n", 367 | " #input: (batch, 3, 64, 64)\n", 368 | " \"\"\"\n", 369 | " NOTE FOR SETTING DISCRIMINATOR:\n", 370 | "\n", 371 | " Remove last sigmoid layer for WGAN\n", 372 | " \"\"\"\n", 373 | " self.l1 = nn.Sequential(\n", 374 | " nn.Conv2d(in_dim, feature_dim, kernel_size=4, stride=2, padding=1), #(batch, 3, 32, 32)\n", 375 | " nn.LeakyReLU(0.2)\n", 376 | " )\n", 377 | " \n", 378 | " self.l2 = nn.Sequential(\n", 379 | " nn.Conv2d(feature_dim, feature_dim * 2, kernel_size=4, stride=2, padding=1), #(batch, 3, 8, 8)\n", 380 | " nn.BatchNorm2d(feature_dim * 2),\n", 381 | " nn.LeakyReLU(0.2)\n", 382 | " )\n", 383 | " \n", 384 | " self.l3 = nn.Sequential(\n", 385 | " nn.Conv2d(feature_dim * 2, feature_dim * 4, kernel_size=4, stride=2, padding=1), #(batch, 3, 4, 4)\n", 386 | " nn.BatchNorm2d(feature_dim * 4),\n", 387 | " nn.LeakyReLU(0.2)\n", 388 | " )\n", 389 | " \n", 390 | " self.l4 = nn.Sequential(\n", 391 | " nn.Conv2d(feature_dim * 4, feature_dim * 8, kernel_size=4, stride=2, padding=1), #(batch, 3, 4, 4)\n", 392 | " nn.BatchNorm2d(feature_dim * 8),\n", 393 | " nn.LeakyReLU(0.2)\n", 394 | " )\n", 395 | " \n", 396 | " self.l5 = nn.Sequential(\n", 397 | " nn.Conv2d(feature_dim * 8, 1, kernel_size=4, stride=1, padding=0),\n", 398 | " )\n", 399 | " \n", 400 | " self.apply(weights_init)\n", 401 | " \n", 402 | " def forward(self, x):\n", 403 | " x = self.l1(x)\n", 404 | " x = self.l2(x)\n", 405 | " x = self.l3(x)\n", 406 | " x = self.l4(x)\n", 407 | " x = self.l5(x)\n", 408 | " x = x.view(-1)\n", 409 | " return x" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "metadata": { 416 | "id": "Hb7Y38bsR35o" 417 | }, 418 | "outputs": [], 419 | "source": [ 420 | "# setting for weight init function\n", 421 | "def weights_init(m):\n", 422 | " classname = m.__class__.__name__\n", 423 | " if classname.find('Conv') != -1:\n", 424 | " m.weight.data.normal_(0.0, 0.02)\n", 425 | " elif classname.find('BatchNorm') != -1:\n", 426 | " m.weight.data.normal_(1.0, 0.02)\n", 427 | " m.bias.data.fill_(0)" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": { 433 | "id": "eC-6M2P3SAu9" 434 | }, 435 | "source": [ 436 | "## Create trainer\n", 437 | "In this section, we will create a trainer which contains following functions:\n", 438 | "1. prepare_environment: prepare the overall environment, construct the models, create directory for the log and ckpt\n", 439 | "2. train: train for generator and discriminator, you can try to modify the code here to construct WGAN or WGAN-GP\n", 440 | "3. 
inference: after training, you can pass the generator ckpt path into it and the function will save the result for you" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "metadata": { 447 | "id": "f8ajFDWBTRzn" 448 | }, 449 | "outputs": [], 450 | "source": [ 451 | "class TrainerGAN():\n", 452 | " def __init__(self, config):\n", 453 | " self.config = config\n", 454 | " \n", 455 | " self.G = Generator(100)\n", 456 | " self.D = Discriminator(3)\n", 457 | " \n", 458 | " self.loss = nn.BCELoss()\n", 459 | "\n", 460 | " \"\"\"\n", 461 | " NOTE FOR SETTING OPTIMIZER:\n", 462 | "\n", 463 | " GAN: use Adam optimizer\n", 464 | " WGAN: use RMSprop optimizer\n", 465 | " WGAN-GP: use Adam optimizer \n", 466 | " \"\"\"\n", 467 | " # self.opt_D = torch.optim.Adam(self.D.parameters(), lr=self.config[\"lr\"], betas=(0.5, 0.999))\n", 468 | " # self.opt_G = torch.optim.Adam(self.G.parameters(), lr=self.config[\"lr\"], betas=(0.5, 0.999))\n", 469 | " self.opt_D = torch.optim.RMSprop(self.D.parameters(), lr=self.config[\"lr\"])\n", 470 | " self.opt_G = torch.optim.RMSprop(self.G.parameters(), lr=self.config[\"lr\"])\n", 471 | " \n", 472 | " self.dataloader = None\n", 473 | " self.log_dir = os.path.join(self.config[\"workspace_dir\"], 'logs')\n", 474 | " self.ckpt_dir = os.path.join(self.config[\"workspace_dir\"], 'checkpoints')\n", 475 | " \n", 476 | " FORMAT = '%(asctime)s - %(levelname)s: %(message)s'\n", 477 | " logging.basicConfig(level=logging.INFO, \n", 478 | " format=FORMAT,\n", 479 | " datefmt='%Y-%m-%d %H:%M')\n", 480 | " \n", 481 | " self.steps = 0\n", 482 | " self.z_samples = Variable(torch.randn(100, self.config[\"z_dim\"])).to(device)\n", 483 | " \n", 484 | " def prepare_environment(self):\n", 485 | " \"\"\"\n", 486 | " Use this funciton to prepare function\n", 487 | " \"\"\"\n", 488 | " os.makedirs(self.log_dir, exist_ok=True)\n", 489 | " os.makedirs(self.ckpt_dir, exist_ok=True)\n", 490 | " \n", 491 | " # update dir by time\n", 492 | " time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n", 493 | " self.log_dir = os.path.join(self.log_dir, time+f'_{self.config[\"model_type\"]}')\n", 494 | " self.ckpt_dir = os.path.join(self.ckpt_dir, time+f'_{self.config[\"model_type\"]}')\n", 495 | " os.makedirs(self.log_dir)\n", 496 | " os.makedirs(self.ckpt_dir)\n", 497 | " \n", 498 | " # create dataset by the above function\n", 499 | " dataset = get_dataset(os.path.join(self.config[\"workspace_dir\"], 'faces'))\n", 500 | " self.dataloader = DataLoader(dataset, batch_size=self.config[\"batch_size\"], shuffle=True, num_workers=2)\n", 501 | " \n", 502 | " # model preparation\n", 503 | " self.G = self.G.to(device)\n", 504 | " self.D = self.D.to(device)\n", 505 | " self.G.train()\n", 506 | " self.D.train()\n", 507 | " def gp(self):\n", 508 | " \"\"\"\n", 509 | " Implement gradient penalty function\n", 510 | " \"\"\"\n", 511 | " pass\n", 512 | " \n", 513 | " def train(self):\n", 514 | " \"\"\"\n", 515 | " Use this function to train generator and discriminator\n", 516 | " \"\"\"\n", 517 | " self.prepare_environment()\n", 518 | " \n", 519 | " for e, epoch in enumerate(range(self.config[\"n_epoch\"])):\n", 520 | " progress_bar = tqdm(self.dataloader)\n", 521 | " progress_bar.set_description(f\"Epoch {e+1}\")\n", 522 | " for i, data in enumerate(progress_bar):\n", 523 | " imgs = data.to(device)\n", 524 | " bs = imgs.size(0)\n", 525 | "\n", 526 | " # *********************\n", 527 | " # * Train D *\n", 528 | " # *********************\n", 529 | " z = Variable(torch.randn(bs, 
self.config[\"z_dim\"])).to(device)\n", 530 | " r_imgs = Variable(imgs).to(device)\n", 531 | " f_imgs = self.G(z)\n", 532 | " r_label = torch.ones((bs)).to(device)\n", 533 | " f_label = torch.zeros((bs)).to(device)\n", 534 | "\n", 535 | "\n", 536 | " # Discriminator forwarding\n", 537 | " r_logit = self.D(r_imgs)\n", 538 | " f_logit = self.D(f_imgs)\n", 539 | "\n", 540 | " \"\"\"\n", 541 | " NOTE FOR SETTING DISCRIMINATOR LOSS:\n", 542 | " \n", 543 | " GAN: \n", 544 | " loss_D = (r_loss + f_loss)/2\n", 545 | " WGAN: \n", 546 | " loss_D = -torch.mean(r_logit) + torch.mean(f_logit)\n", 547 | " WGAN-GP: \n", 548 | " gradient_penalty = self.gp(r_imgs, f_imgs)\n", 549 | " loss_D = -torch.mean(r_logit) + torch.mean(f_logit) + gradient_penalty\n", 550 | " \"\"\"\n", 551 | " # Loss for discriminator\n", 552 | " # r_loss = self.loss(r_logit, r_label)\n", 553 | " # f_loss = self.loss(f_logit, f_label)\n", 554 | " # loss_D = (r_loss + f_loss) / 2\n", 555 | " loss_D = -torch.mean(r_logit) + torch.mean(f_logit)\n", 556 | "\n", 557 | " # Discriminator backwarding\n", 558 | " self.D.zero_grad()\n", 559 | " loss_D.backward()\n", 560 | " self.opt_D.step()\n", 561 | "\n", 562 | " \"\"\"\n", 563 | " NOTE FOR SETTING WEIGHT CLIP:\n", 564 | " \n", 565 | " WGAN: below code\n", 566 | " \"\"\"\n", 567 | " for p in self.D.parameters():\n", 568 | " p.data.clamp_(-self.config[\"clip_value\"], self.config[\"clip_value\"])\n", 569 | "\n", 570 | "\n", 571 | "\n", 572 | " # *********************\n", 573 | " # * Train G *\n", 574 | " # *********************\n", 575 | " if self.steps % self.config[\"n_critic\"] == 0:\n", 576 | " # Generate some fake images.\n", 577 | " z = Variable(torch.randn(bs, self.config[\"z_dim\"])).to(device)\n", 578 | " f_imgs = self.G(z)\n", 579 | "\n", 580 | " # Generator forwarding\n", 581 | " f_logit = self.D(f_imgs)\n", 582 | "\n", 583 | "\n", 584 | " \"\"\"\n", 585 | " NOTE FOR SETTING LOSS FOR GENERATOR:\n", 586 | " \n", 587 | " GAN: loss_G = self.loss(f_logit, r_label)\n", 588 | " WGAN: loss_G = -torch.mean(self.D(f_imgs))\n", 589 | " WGAN-GP: loss_G = -torch.mean(self.D(f_imgs))\n", 590 | " \"\"\"\n", 591 | " # Loss for the generator.\n", 592 | "# loss_G = self.loss(f_logit, r_label)\n", 593 | " loss_G = -torch.mean(self.D(f_imgs))\n", 594 | "\n", 595 | " # Generator backwarding\n", 596 | " self.G.zero_grad()\n", 597 | " loss_G.backward()\n", 598 | " self.opt_G.step()\n", 599 | " \n", 600 | " if self.steps % 10 == 0:\n", 601 | " progress_bar.set_postfix(loss_G=loss_G.item(), loss_D=loss_D.item())\n", 602 | " self.steps += 1\n", 603 | "\n", 604 | " self.G.eval()\n", 605 | " f_imgs_sample = (self.G(self.z_samples).data + 1) / 2.0\n", 606 | " filename = os.path.join(self.log_dir, f'Epoch_{epoch+1:03d}.jpg')\n", 607 | " torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)\n", 608 | " logging.info(f'Save some samples to {filename}.')\n", 609 | "\n", 610 | " # Show some images during training.\n", 611 | " grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)\n", 612 | " plt.figure(figsize=(10,10))\n", 613 | " plt.imshow(grid_img.permute(1, 2, 0))\n", 614 | " plt.show()\n", 615 | "\n", 616 | " self.G.train()\n", 617 | "\n", 618 | " if (e+1) % 2 == 0 or e == 0:\n", 619 | " # Save the checkpoints.\n", 620 | " torch.save(self.G.state_dict(), os.path.join(self.ckpt_dir, f'G_{e}.pth'))\n", 621 | " torch.save(self.D.state_dict(), os.path.join(self.ckpt_dir, f'D_{e}.pth'))\n", 622 | "\n", 623 | " logging.info('Finish training')\n", 624 | "\n", 625 | " def inference(self, 
G_path, n_generate=1000, n_output=30, show=False):\n", 626 | " \"\"\"\n", 627 | " 1. G_path is the path for Generator ckpt\n", 628 | " 2. You can use this function to generate final answer\n", 629 | " \"\"\"\n", 630 | "\n", 631 | " self.G.load_state_dict(torch.load(G_path))\n", 632 | " self.G.to(device)\n", 633 | " self.G.eval()\n", 634 | " z = Variable(torch.randn(n_generate, self.config[\"z_dim\"])).to(device)\n", 635 | " imgs = (self.G(z).data + 1) / 2.0\n", 636 | " \n", 637 | " os.makedirs('output', exist_ok=True)\n", 638 | " for i in range(n_generate):\n", 639 | " torchvision.utils.save_image(imgs[i], f'output/{i+1}.jpg')\n", 640 | " \n", 641 | " if show:\n", 642 | " row, col = n_output//10 + 1, 10\n", 643 | " grid_img = torchvision.utils.make_grid(imgs[:n_output].cpu(), nrow=row)\n", 644 | " plt.figure(figsize=(row, col))\n", 645 | " plt.imshow(grid_img.permute(1, 2, 0))\n", 646 | " plt.show()" 647 | ] 648 | }, 649 | { 650 | "cell_type": "markdown", 651 | "metadata": { 652 | "id": "-uf8BdVoYNJ8" 653 | }, 654 | "source": [ 655 | "# Train\n", 656 | "In this section, we will first set the config for trainer, then use it to train generator and discriminator" 657 | ] 658 | }, 659 | { 660 | "cell_type": "markdown", 661 | "metadata": { 662 | "id": "ykjfugCdYmYS" 663 | }, 664 | "source": [ 665 | "## Set config" 666 | ] 667 | }, 668 | { 669 | "cell_type": "code", 670 | "execution_count": null, 671 | "metadata": { 672 | "id": "Jg4YdRVPYJSj" 673 | }, 674 | "outputs": [], 675 | "source": [ 676 | "config = {\n", 677 | " \"model_type\": \"WGAN\",\n", 678 | " \"batch_size\": 64,\n", 679 | " \"lr\": 1e-4,\n", 680 | " \"n_epoch\": 10,\n", 681 | " \"n_critic\": 3,\n", 682 | " \"z_dim\": 100,\n", 683 | " \"workspace_dir\": workspace_dir, # define in the environment setting\n", 684 | " \"clip_value\": 1,\n", 685 | "}" 686 | ] 687 | }, 688 | { 689 | "cell_type": "markdown", 690 | "metadata": { 691 | "id": "ntn56Ffvip-x" 692 | }, 693 | "source": [ 694 | "## Start to train" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": null, 700 | "metadata": { 701 | "colab": { 702 | "base_uri": "https://localhost:8080/", 703 | "height": 648 704 | }, 705 | "id": "NTHoXrLUYJUn", 706 | "outputId": "dbbe78aa-1934-4ca1-9e97-f2189a6c73ec" 707 | }, 708 | "outputs": [], 709 | "source": [ 710 | "trainer = TrainerGAN(config)\n", 711 | "trainer.train()" 712 | ] 713 | }, 714 | { 715 | "cell_type": "markdown", 716 | "metadata": { 717 | "id": "4g3_RUzYix0W" 718 | }, 719 | "source": [ 720 | "# Inference\n", 721 | "In this section, we will use trainer to train model" 722 | ] 723 | }, 724 | { 725 | "cell_type": "markdown", 726 | "metadata": { 727 | "id": "T6hdMgj_i3kk" 728 | }, 729 | "source": [ 730 | "## Inference through trainer" 731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "execution_count": null, 736 | "metadata": {}, 737 | "outputs": [], 738 | "source": [ 739 | "%cd checkpoints\n", 740 | "!ls" 741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "execution_count": null, 746 | "metadata": {}, 747 | "outputs": [], 748 | "source": [ 749 | "%cd 2022-04-21_00-23-19_WGAN\n", 750 | "!ls\n", 751 | "%cd ../.." 
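Note that `gp()` in `TrainerGAN` above is left as a stub because this run trains plain WGAN with weight clipping. If WGAN-GP were used instead, a common way to implement the gradient penalty is sketched below — an assumption about one reasonable implementation, not code from this notebook. It would be added to the discriminator loss as described in the in-code note (`loss_D = -torch.mean(r_logit) + torch.mean(f_logit) + gradient_penalty`), and the discriminator's `BatchNorm2d` layers would be swapped for `InstanceNorm2d`, as the commented-out Discriminator above points out.

```python
# Sketch of a WGAN-GP gradient penalty (assumptions: D maps images to per-sample scores,
# inputs live on `device`, and lambda_gp=10 is the usual penalty weight).
import torch
from torch import autograd

def gradient_penalty(D, real_imgs, fake_imgs, device, lambda_gp=10.0):
    bs = real_imgs.size(0)
    # Random points on the lines between real and generated samples.
    alpha = torch.rand(bs, 1, 1, 1, device=device)
    interpolates = (alpha * real_imgs + (1 - alpha) * fake_imgs.detach()).requires_grad_(True)
    d_interpolates = D(interpolates)
    grads = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    grads = grads.view(bs, -1)
    # Penalize the deviation of the per-sample gradient norm from 1.
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()
```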
752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "execution_count": null, 757 | "metadata": { 758 | "id": "72EEf52FrOCp" 759 | }, 760 | "outputs": [], 761 | "source": [ 762 | "# save the 1000 images into ./output folder\n", 763 | "trainer.inference(f'{workspace_dir}/checkpoints/2022-04-21_00-23-19_WGAN/G_13.pth') # you have to modify the path when running this line" 764 | ] 765 | }, 766 | { 767 | "cell_type": "markdown", 768 | "metadata": { 769 | "id": "WuoaEVUgk7oZ" 770 | }, 771 | "source": [ 772 | "## Prepare .tar file for submission" 773 | ] 774 | }, 775 | { 776 | "cell_type": "code", 777 | "execution_count": null, 778 | "metadata": { 779 | "colab": { 780 | "base_uri": "https://localhost:8080/" 781 | }, 782 | "id": "QI2cnbbWlA3Z", 783 | "outputId": "7b4e05af-fbe0-4ac6-9126-054c6c4a8a75" 784 | }, 785 | "outputs": [], 786 | "source": [ 787 | "%cd output\n", 788 | "!tar -zcf ../submission.tgz *.jpg\n", 789 | "%cd .." 790 | ] 791 | }, 792 | { 793 | "cell_type": "markdown", 794 | "metadata": {}, 795 | "source": [ 796 | "# Gradient Norm" 797 | ] 798 | }, 799 | { 800 | "cell_type": "code", 801 | "execution_count": null, 802 | "metadata": {}, 803 | "outputs": [], 804 | "source": [ 805 | "# Get Gradient Norm\n", 806 | "D = Discriminator(3)\n", 807 | "D.load_state_dict(torch.load(f'{workspace_dir}/checkpoints/2022-04-21_00-23-19_WGAN/D_9.pth'))\n", 808 | "D = D.to(device)\n", 809 | "D.train()\n", 810 | "\n", 811 | "G = Generator(100)\n", 812 | "G.load_state_dict(torch.load(f'{workspace_dir}/checkpoints/2022-04-21_00-23-19_WGAN/G_9.pth'))\n", 813 | "G = G.to(device)\n", 814 | "G.train()\n", 815 | "\n", 816 | "dataset = get_dataset(os.path.join(workspace_dir, 'faces'))\n", 817 | "dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)\n", 818 | "\n", 819 | "for e, epoch in enumerate(range(1)):\n", 820 | " progress_bar = tqdm(dataloader)\n", 821 | " progress_bar.set_description(f\"Epoch {e+1}\")\n", 822 | " for i, data in enumerate(progress_bar):\n", 823 | " imgs = data.to(device)\n", 824 | " bs = imgs.size(0)\n", 825 | "\n", 826 | " # *********************\n", 827 | " # * Train D *\n", 828 | " # *********************\n", 829 | " z = Variable(torch.randn(bs, 100)).to(device)\n", 830 | " r_imgs = Variable(imgs).to(device)\n", 831 | " f_imgs = G(z)\n", 832 | " r_label = torch.ones((bs)).to(device)\n", 833 | " f_label = torch.zeros((bs)).to(device)\n", 834 | "\n", 835 | "\n", 836 | " # Discriminator forwarding\n", 837 | " r_logit = D(r_imgs)\n", 838 | " f_logit = D(f_imgs)\n", 839 | "\n", 840 | " \"\"\"\n", 841 | " NOTE FOR SETTING DISCRIMINATOR LOSS:\n", 842 | "\n", 843 | " GAN: \n", 844 | " loss_D = (r_loss + f_loss)/2\n", 845 | " WGAN: \n", 846 | " loss_D = -torch.mean(r_logit) + torch.mean(f_logit)\n", 847 | " WGAN-GP: \n", 848 | " gradient_penalty = self.gp(r_imgs, f_imgs)\n", 849 | " loss_D = -torch.mean(r_logit) + torch.mean(f_logit) + gradient_penalty\n", 850 | " \"\"\"\n", 851 | " # Loss for discriminator\n", 852 | " # r_loss = self.loss(r_logit, r_label)\n", 853 | " # f_loss = self.loss(f_logit, f_label)\n", 854 | " # loss_D = (r_loss + f_loss) / 2\n", 855 | " loss_D = -torch.mean(r_logit) + torch.mean(f_logit)\n", 856 | "\n", 857 | " # Discriminator backwarding\n", 858 | " D.zero_grad()\n", 859 | " loss_D.backward()\n", 860 | " break\n", 861 | " break" 862 | ] 863 | }, 864 | { 865 | "cell_type": "code", 866 | "execution_count": null, 867 | "metadata": {}, 868 | "outputs": [], 869 | "source": [ 870 | "import math\n", 871 | "\n", 872 | "param_norm_log_list = 
[]\n", 873 | "\n", 874 | "for name, p in D.named_parameters():\n", 875 | " if not p.requires_grad:\n", 876 | " continue\n", 877 | " param_norm = p.grad.detach().data.norm(2)\n", 878 | " param_norm_log = math.log(param_norm.item()) if param_norm.item() else 0\n", 879 | " print(name, param_norm_log)\n", 880 | " param_norm_log_list.append(param_norm_log)\n", 881 | "print(param_norm_log_list[0], param_norm_log_list[2], param_norm_log_list[6], param_norm_log_list[10], param_norm_log_list[14])" 882 | ] 883 | }, 884 | { 885 | "cell_type": "code", 886 | "execution_count": null, 887 | "metadata": {}, 888 | "outputs": [], 889 | "source": [ 890 | "D" 891 | ] 892 | } 893 | ], 894 | "metadata": { 895 | "accelerator": "GPU", 896 | "colab": { 897 | "collapsed_sections": [], 898 | "name": "ML_HW6.ipynb", 899 | "provenance": [] 900 | }, 901 | "kernelspec": { 902 | "display_name": "kuokuo_env", 903 | "language": "python", 904 | "name": "kuokuo_env" 905 | }, 906 | "language_info": { 907 | "codemirror_mode": { 908 | "name": "ipython", 909 | "version": 3 910 | }, 911 | "file_extension": ".py", 912 | "mimetype": "text/x-python", 913 | "name": "python", 914 | "nbconvert_exporter": "python", 915 | "pygments_lexer": "ipython3", 916 | "version": "3.7.5" 917 | }, 918 | "toc": { 919 | "base_numbering": 1, 920 | "nav_menu": {}, 921 | "number_sections": true, 922 | "sideBar": true, 923 | "skip_h1_title": false, 924 | "title_cell": "Table of Contents", 925 | "title_sidebar": "Contents", 926 | "toc_cell": false, 927 | "toc_position": {}, 928 | "toc_section_display": true, 929 | "toc_window_display": false 930 | } 931 | }, 932 | "nbformat": 4, 933 | "nbformat_minor": 1 934 | } 935 | --------------------------------------------------------------------------------