├── README.md
├── hw
│   ├── hw1hw1.ipynb
│   ├── hw2hw2.ipynb
│   ├── hw3hw3.ipynb
│   ├── hw4hw4.ipynb
│   └── hw7_tutorial.ipynb
├── 作业思路.md
├── 机器学习经典blog.md
└── 李宏毅机器学习.md
/README.md:
--------------------------------------------------------------------------------
1 | # Hung-yi Lee Machine Learning, Spring 2021 (李宏毅2021春季机器学习)
2 |
3 | ## Requirements
4 |
5 | These notes assume you already have a grounding in Python, PyTorch, and basic machine learning. They are a condensed version filtered through my own understanding, and they only list the points that I found meaningful.
6 |
7 | If you find any mistakes, please point them out.
8 |
9 | ## Course links
10 |
11 | Homepage: https://speech.ee.ntu.edu.tw/~hylee/ml/2021-spring.html
12 |
13 | My notes: [note](./李宏毅机器学习.md)
14 |
15 | Assignments:
16 |
17 | - hw1: a `linear regression` model that predicts COVID-19 case counts
18 |
19 |   - code: [hw1](./hw/hw1hw1.ipynb)
20 |
21 |   - public score: 3 days before the deadline [medium baseline]
22 |
23 |     ![image-20210323121054406](https://yumytest.oss-cn-chengdu.aliyuncs.com/img/image-20210323121054406.png)
24 |
25 |   - private score:
26 |
27 |     ![image-20210328101524785](https://yumytest.oss-cn-chengdu.aliyuncs.com/img/image-20210328101524785.png)
28 |
29 | - hw2: an audio classification model
30 |
31 |   - code: [hw2](./hw/hw2hw2.ipynb)
32 |
33 |   - public score: 10 days before the deadline [simple baseline]
34 |
35 |     ![image-20210323121701107](https://yumytest.oss-cn-chengdu.aliyuncs.com/img/image-20210323121701107.png)
36 |
37 |   - private score:
38 |
39 |     ![image-20210405171734542](https://yumytest.oss-cn-chengdu.aliyuncs.com/img/image-20210405171734542.png)
40 |
41 | - hw3:
42 |
43 |   - code: [hw3](./hw/hw3hw3.ipynb)
44 |
45 |   - public score: 1 day before the deadline [medium baseline]
46 |
47 |     ![image-20210416165537162](https://yumytest.oss-cn-chengdu.aliyuncs.com/img/image-20210416165537162.png)
48 |
49 |   - private score:
50 |
51 |     ![image-20210417001646181](https://yumytest.oss-cn-chengdu.aliyuncs.com/img/image-20210417001646181.png)
52 |
53 |
54 |
55 | - hw4: (too busy, barely had time for this one)
56 |
57 |   - code: [hw4](./hw/hw4hw4.ipynb)
58 |
59 |   - public score: (training only finished a few minutes after the deadline, so the submission never made the leaderboard; it would have placed around 892nd) [simple baseline]
60 |
61 |     ![image-20210417001838892](https://yumytest.oss-cn-chengdu.aliyuncs.com/img/image-20210417001838892.png)
62 |
63 |   - private score:
64 |
65 |     ![image-20210417001806459](https://yumytest.oss-cn-chengdu.aliyuncs.com/img/image-20210417001806459.png)
66 |
67 | - hw5:
68 |
69 |   - code:
70 |   - public score:
71 |   - private score:
72 |
73 | - hw6:
74 |
75 |   - code:
76 |   - public score:
77 |   - private score:
78 | - hw7:
79 |
80 |   - code: [hw7](./hw/hw7_tutorial.ipynb)
81 |
82 |   - public score: [medium baseline]
83 |
84 |
85 |
86 |   - private score:
87 |
88 |
89 | - hw8:
90 |
91 |   - code: [hw8](./hw/hw8hw8.ipynb)
92 |   - public score:
93 |   - private score:
94 |
95 |
96 |
97 |
98 |
99 | ---
100 |
101 | > A note on the homework: only the first assignment could run on my laptop (no discrete GPU)... the second just barely ran on my desktop, and from there on I basically moved to Colab. I even bought two months of Colab Pro, which stung...
--------------------------------------------------------------------------------
/hw/hw1hw1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
8 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
9 | "execution": {
10 | "iopub.execute_input": "2021-03-21T14:04:15.080622Z",
11 | "iopub.status.busy": "2021-03-21T14:04:15.079672Z",
12 | "iopub.status.idle": "2021-03-21T14:04:15.085505Z",
13 | "shell.execute_reply": "2021-03-21T14:04:15.086237Z"
14 | },
15 | "papermill": {
16 | "duration": 0.027335,
17 | "end_time": "2021-03-21T14:04:15.086616",
18 | "exception": false,
19 | "start_time": "2021-03-21T14:04:15.059281",
20 | "status": "completed"
21 | },
22 | "tags": []
23 | },
24 | "outputs": [
25 | {
26 | "name": "stdout",
27 | "output_type": "stream",
"text": [ 29 | "/kaggle/input/ml2021spring-hw1/covid.test.csv\n", 30 | "/kaggle/input/ml2021spring-hw1/sampleSubmission.csv\n", 31 | "/kaggle/input/ml2021spring-hw1/covid.train.csv\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "# This Python 3 environment comes with many helpful analytics libraries installed\n", 37 | "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n", 38 | "# For example, here's several helpful packages to load\n", 39 | "\n", 40 | "import numpy as np # linear algebra\n", 41 | "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", 42 | "\n", 43 | "# Input data files are available in the read-only \"../input/\" directory\n", 44 | "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n", 45 | "\n", 46 | "import os\n", 47 | "for dirname, _, filenames in os.walk('/kaggle/input'):\n", 48 | " for filename in filenames:\n", 49 | " print(os.path.join(dirname, filename))\n", 50 | "\n", 51 | "# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n", 52 | "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": { 59 | "execution": { 60 | "iopub.execute_input": "2021-03-21T14:04:15.119652Z", 61 | "iopub.status.busy": "2021-03-21T14:04:15.117998Z", 62 | "iopub.status.idle": "2021-03-21T14:04:16.327317Z", 63 | "shell.execute_reply": "2021-03-21T14:04:16.328221Z" 64 | }, 65 | "papermill": { 66 | "duration": 1.228857, 67 | "end_time": "2021-03-21T14:04:16.328445", 68 | "exception": false, 69 | "start_time": "2021-03-21T14:04:15.099588", 70 | "status": "completed" 71 | }, 72 | "tags": [] 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "import csv\n", 77 | "\n", 78 | "import torch\n", 79 | "import numpy as np\n", 80 | "from matplotlib.pyplot import figure\n", 81 | "from torch.utils.data import Dataset\n", 82 | "from torch.utils.data import DataLoader\n", 83 | "from matplotlib import pyplot as plt\n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "metadata": { 90 | "execution": { 91 | "iopub.execute_input": "2021-03-21T14:04:16.357208Z", 92 | "iopub.status.busy": "2021-03-21T14:04:16.356078Z", 93 | "iopub.status.idle": "2021-03-21T14:04:16.361594Z", 94 | "shell.execute_reply": "2021-03-21T14:04:16.362193Z" 95 | }, 96 | "papermill": { 97 | "duration": 0.021853, 98 | "end_time": "2021-03-21T14:04:16.362523", 99 | "exception": false, 100 | "start_time": "2021-03-21T14:04:16.340670", 101 | "status": "completed" 102 | }, 103 | "tags": [] 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "# 超参数\n", 108 | "batch_size = 64\n", 109 | "learning_rate = 0.001\n", 110 | "epochs = 3000\n", 111 | "model_path = 'model.pth'\n", 112 | "early_stop = 300\n", 113 | "target_only = False\n", 114 | "weight_decay = 0.1\n", 115 | "momentum = 0.1" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": { 122 | "execution": { 123 | "iopub.execute_input": "2021-03-21T14:04:16.392136Z", 124 | "iopub.status.busy": "2021-03-21T14:04:16.391002Z", 125 | "iopub.status.idle": "2021-03-21T14:04:16.405610Z", 126 | "shell.execute_reply": "2021-03-21T14:04:16.406212Z" 127 | }, 128 | "papermill": { 129 | "duration": 0.031123, 130 | "end_time": "2021-03-21T14:04:16.406416", 131 | 
"exception": false, 132 | "start_time": "2021-03-21T14:04:16.375293", 133 | "status": "completed" 134 | }, 135 | "tags": [] 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "class covidDataset(Dataset):\n", 140 | " # def __init__(self, filepath, type='train'):\n", 141 | " # data = np.loadtxt(filepath, skiprows=1, delimiter=',', dtype=np.float32)\n", 142 | " # self.len = data.shape[0]\n", 143 | " # if type == 'train':\n", 144 | " # self.x_data = torch.from_numpy(data[:, 1:-1])\n", 145 | " # self.y_data = torch.from_numpy(data[:, [-1]])\n", 146 | " # else:\n", 147 | " # self.x_data = torch.from_numpy(data[:, 1:])\n", 148 | " #\n", 149 | " # def __getitem__(self, index):\n", 150 | " # return self.x_data[index], self.y_data[index]\n", 151 | " #\n", 152 | " # def __len__(self):\n", 153 | " # return self.len\n", 154 | "\n", 155 | " def __init__(self,\n", 156 | " filepath,\n", 157 | " mode='train',\n", 158 | " target_only=False):\n", 159 | " self.mode = mode\n", 160 | "\n", 161 | " # Read data into numpy arrays\n", 162 | " data = np.loadtxt(filepath, skiprows=1, delimiter=',', dtype=np.float32)\n", 163 | " data = data[:, 1:]\n", 164 | "\n", 165 | " if not target_only:\n", 166 | " feats = list(range(93))\n", 167 | " else:\n", 168 | " # TODO: Using 40 states & 2 tested_positive features (indices = 57 & 75)\n", 169 | " feats = list(range(40))\n", 170 | " feats.append(57)\n", 171 | " feats.append(75)\n", 172 | "\n", 173 | " if mode == 'test':\n", 174 | " # Testing data\n", 175 | " # data: 893 x 93 (40 states + day 1 (18) + day 2 (18) + day 3 (17))\n", 176 | " data = data[:, feats]\n", 177 | " self.data = torch.from_numpy(data)\n", 178 | " else:\n", 179 | " # Training data (train/dev sets)\n", 180 | " # data: 2700 x 94 (40 states + day 1 (18) + day 2 (18) + day 3 (18))\n", 181 | " target = data[:, -1]\n", 182 | " data = data[:, feats]\n", 183 | "\n", 184 | " # Splitting training data into train & dev sets\n", 185 | " if mode == 'train':\n", 186 | " indices = [i for i in range(len(data)) if i % 10 != 0]\n", 187 | " elif mode == 'dev':\n", 188 | " indices = [i for i in range(len(data)) if i % 10 == 0]\n", 189 | "\n", 190 | " # Convert data into PyTorch tensors\n", 191 | " self.data = torch.from_numpy(data[indices])\n", 192 | " self.target = torch.from_numpy(target[indices])\n", 193 | "\n", 194 | " # Normalize features (you may remove this part to see what will happen)\n", 195 | " # self.data[:, 40:] = \\\n", 196 | " # (self.data[:, 40:] - self.data[:, 40:].mean(dim=0, keepdim=True)) \\\n", 197 | " # / self.data[:, 40:].std(dim=0, keepdim=True)\n", 198 | "\n", 199 | " self.dim = self.data.shape[1]\n", 200 | "\n", 201 | " print('Finished reading the {} set of COVID19 Dataset ({} samples found, each dim = {})'\n", 202 | " .format(mode, len(self.data), self.dim))\n", 203 | "\n", 204 | " def __getitem__(self, index):\n", 205 | " # Returns one sample at a time\n", 206 | " if self.mode in ['train', 'dev']:\n", 207 | " # For training\n", 208 | " return self.data[index], self.target[index]\n", 209 | " else:\n", 210 | " # For testing (no target)\n", 211 | " return self.data[index]\n", 212 | "\n", 213 | " def __len__(self):\n", 214 | " # Returns the size of the dataset\n", 215 | " return len(self.data)\n", 216 | "\n" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 5, 222 | "metadata": { 223 | "execution": { 224 | "iopub.execute_input": "2021-03-21T14:04:16.435596Z", 225 | "iopub.status.busy": "2021-03-21T14:04:16.434579Z", 226 | "iopub.status.idle": "2021-03-21T14:04:16.441453Z", 
227 | "shell.execute_reply": "2021-03-21T14:04:16.441998Z" 228 | }, 229 | "papermill": { 230 | "duration": 0.023349, 231 | "end_time": "2021-03-21T14:04:16.442236", 232 | "exception": false, 233 | "start_time": "2021-03-21T14:04:16.418887", 234 | "status": "completed" 235 | }, 236 | "tags": [] 237 | }, 238 | "outputs": [], 239 | "source": [ 240 | "def dataloader(path, mode, batch_size, n_jobs=0, target_only=False):\n", 241 | " ''' Generates a dataset, then is put into a dataloader. '''\n", 242 | " dataset = covidDataset(path, mode=mode, target_only=target_only) # Construct dataset\n", 243 | " dataloader = DataLoader(\n", 244 | " dataset=dataset,\n", 245 | " batch_size=batch_size,\n", 246 | " shuffle=(mode == 'train'),\n", 247 | " num_workers=n_jobs, # how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process\n", 248 | " pin_memory=True) # Construct dataloader\n", 249 | " return dataloader\n", 250 | "\n", 251 | "\n" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 6, 257 | "metadata": { 258 | "execution": { 259 | "iopub.execute_input": "2021-03-21T14:04:16.473921Z", 260 | "iopub.status.busy": "2021-03-21T14:04:16.472627Z", 261 | "iopub.status.idle": "2021-03-21T14:04:16.478693Z", 262 | "shell.execute_reply": "2021-03-21T14:04:16.479283Z" 263 | }, 264 | "papermill": { 265 | "duration": 0.02386, 266 | "end_time": "2021-03-21T14:04:16.479483", 267 | "exception": false, 268 | "start_time": "2021-03-21T14:04:16.455623", 269 | "status": "completed" 270 | }, 271 | "tags": [] 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 7, 281 | "metadata": { 282 | "execution": { 283 | "iopub.execute_input": "2021-03-21T14:04:16.512809Z", 284 | "iopub.status.busy": "2021-03-21T14:04:16.511763Z", 285 | "iopub.status.idle": "2021-03-21T14:04:16.521610Z", 286 | "shell.execute_reply": "2021-03-21T14:04:16.522184Z" 287 | }, 288 | "papermill": { 289 | "duration": 0.028764, 290 | "end_time": "2021-03-21T14:04:16.522381", 291 | "exception": false, 292 | "start_time": "2021-03-21T14:04:16.493617", 293 | "status": "completed" 294 | }, 295 | "tags": [] 296 | }, 297 | "outputs": [], 298 | "source": [ 299 | "class Model(torch.nn.Module):\n", 300 | " def __init__(self, features):\n", 301 | " super(Model, self).__init__()\n", 302 | " self.linear1 = torch.nn.Linear(features, 128)\n", 303 | "\n", 304 | " self.linear2 = torch.nn.Linear(128, 64)\n", 305 | "\n", 306 | " self.linear3 = torch.nn.Linear(64, 32)\n", 307 | "\n", 308 | " self.linear4 = torch.nn.Linear(32, 8)\n", 309 | "\n", 310 | " self.linear5 = torch.nn.Linear(8, 1)\n", 311 | " self.relu = torch.nn.ReLU()\n", 312 | "\n", 313 | "\n", 314 | "\n", 315 | " def forward(self, x):\n", 316 | " in_size = x.size(0)\n", 317 | " x = self.relu(self.linear1(x))\n", 318 | " x = self.relu(self.linear2(x))\n", 319 | " x = self.relu(self.linear3(x))\n", 320 | " x = self.relu(self.linear4(x))\n", 321 | " x = self.relu(self.linear5(x))\n", 322 | " x = x.view(in_size)\n", 323 | " return x\n" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 8, 329 | "metadata": { 330 | "execution": { 331 | "iopub.execute_input": "2021-03-21T14:04:16.553186Z", 332 | "iopub.status.busy": "2021-03-21T14:04:16.552137Z", 333 | "iopub.status.idle": "2021-03-21T14:04:17.105389Z", 334 | "shell.execute_reply": "2021-03-21T14:04:17.104794Z" 335 | }, 336 | 
"papermill": { 337 | "duration": 0.570374, 338 | "end_time": "2021-03-21T14:04:17.105555", 339 | "exception": false, 340 | "start_time": "2021-03-21T14:04:16.535181", 341 | "status": "completed" 342 | }, 343 | "tags": [] 344 | }, 345 | "outputs": [ 346 | { 347 | "name": "stdout", 348 | "output_type": "stream", 349 | "text": [ 350 | "Finished reading the train set of COVID19 Dataset (2430 samples found, each dim = 93)\n", 351 | "Finished reading the dev set of COVID19 Dataset (270 samples found, each dim = 93)\n", 352 | "Finished reading the test set of COVID19 Dataset (893 samples found, each dim = 93)\n" 353 | ] 354 | } 355 | ], 356 | "source": [ 357 | "\n", 358 | "# load data\n", 359 | "tr_path = '../input/ml2021spring-hw1/covid.train.csv'\n", 360 | "tt_path = '../input/ml2021spring-hw1/covid.test.csv'\n", 361 | "train_loader = dataloader(tr_path, 'train', batch_size, target_only=target_only)\n", 362 | "dev_loader = dataloader(tr_path, 'dev', batch_size, target_only=target_only)\n", 363 | "test_loader = dataloader(tt_path, 'test', batch_size, target_only=target_only)\n", 364 | "\n" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 9, 370 | "metadata": { 371 | "execution": { 372 | "iopub.execute_input": "2021-03-21T14:04:17.146010Z", 373 | "iopub.status.busy": "2021-03-21T14:04:17.145282Z", 374 | "iopub.status.idle": "2021-03-21T14:04:17.167847Z", 375 | "shell.execute_reply": "2021-03-21T14:04:17.165897Z" 376 | }, 377 | "papermill": { 378 | "duration": 0.048915, 379 | "end_time": "2021-03-21T14:04:17.168042", 380 | "exception": false, 381 | "start_time": "2021-03-21T14:04:17.119127", 382 | "status": "completed" 383 | }, 384 | "tags": [] 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "\n", 389 | "model = Model(train_loader.dataset.dim).to(device)\n", 390 | "\n", 391 | "# construct loss and optimizer\n", 392 | "criterion = torch.nn.MSELoss(reduction='mean')\n", 393 | "optimizer = torch.optim.Adam(model.parameters(),\n", 394 | " lr=learning_rate,\n", 395 | " weight_decay=weight_decay)\n", 396 | "# optimizer = torch.optim.SGD(model.parameters(),\n", 397 | "# lr=learning_rate,\n", 398 | "# weight_decay=weight_decay,\n", 399 | "# momentum=momentum)\n", 400 | "\n", 401 | "\n", 402 | "# lossRecord = [0.0 for _ in range(epochs)]\n" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 10, 408 | "metadata": { 409 | "execution": { 410 | "iopub.execute_input": "2021-03-21T14:04:17.201429Z", 411 | "iopub.status.busy": "2021-03-21T14:04:17.200670Z", 412 | "iopub.status.idle": "2021-03-21T14:04:17.203018Z", 413 | "shell.execute_reply": "2021-03-21T14:04:17.202411Z" 414 | }, 415 | "papermill": { 416 | "duration": 0.021217, 417 | "end_time": "2021-03-21T14:04:17.203225", 418 | "exception": false, 419 | "start_time": "2021-03-21T14:04:17.182008", 420 | "status": "completed" 421 | }, 422 | "tags": [] 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "def load_model():\n", 427 | " pass\n" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 11, 433 | "metadata": { 434 | "execution": { 435 | "iopub.execute_input": "2021-03-21T14:04:17.242837Z", 436 | "iopub.status.busy": "2021-03-21T14:04:17.242070Z", 437 | "iopub.status.idle": "2021-03-21T14:04:17.245890Z", 438 | "shell.execute_reply": "2021-03-21T14:04:17.245214Z" 439 | }, 440 | "papermill": { 441 | "duration": 0.028965, 442 | "end_time": "2021-03-21T14:04:17.246049", 443 | "exception": false, 444 | "start_time": "2021-03-21T14:04:17.217084", 445 | "status": "completed" 446 
430 | {
431 | "cell_type": "code",
432 | "execution_count": 11,
433 | "metadata": {
434 | "execution": {
435 | "iopub.execute_input": "2021-03-21T14:04:17.242837Z",
436 | "iopub.status.busy": "2021-03-21T14:04:17.242070Z",
437 | "iopub.status.idle": "2021-03-21T14:04:17.245890Z",
438 | "shell.execute_reply": "2021-03-21T14:04:17.245214Z"
439 | },
440 | "papermill": {
441 | "duration": 0.028965,
442 | "end_time": "2021-03-21T14:04:17.246049",
443 | "exception": false,
444 | "start_time": "2021-03-21T14:04:17.217084",
445 | "status": "completed"
446 | },
447 | "tags": []
448 | },
449 | "outputs": [],
450 | "source": [
451 | "\n",
452 | "def train(tr_set, dv_set):\n",
453 | "    # running_loss = 0.0\n",
454 | "    # for batch_idx, data in enumerate(train_loader, 0):\n",
455 | "    #     inputs, target = data\n",
456 | "    #     inputs, target = inputs.to(device), target.to(device)\n",
457 | "    #     optimizer.zero_grad()\n",
458 | "    #\n",
459 | "    #     outputs = model(inputs)\n",
460 | "    #     loss = criterion(outputs, target)\n",
461 | "    #     loss.backward()\n",
462 | "    #     optimizer.step()\n",
463 | "    #\n",
464 | "    #     running_loss += loss.item()\n",
465 | "    #     lossRecord[epoch] += loss.item()\n",
466 | "    #     if batch_idx % 10 == 0:\n",
467 | "    #         print('[%d, %5d] loss: %.3f' % (epoch, batch_idx, running_loss / 10))\n",
468 | "    #         running_loss = 0.0\n",
469 | "\n",
470 | "    min_mse = 10000.\n",
471 | "    loss_record = {'train': [], 'dev': []} # for recording training loss\n",
472 | "    early_stop_cnt = 0\n",
473 | "    epoch = 0\n",
474 | "    while epoch < epochs:\n",
475 | "        model.train() # set model to training mode\n",
476 | "        for x, y in tr_set: # iterate through the dataloader\n",
477 | "            optimizer.zero_grad() # set gradient to zero\n",
478 | "            x, y = x.to(device), y.to(device) # move data to device (cpu/cuda)\n",
479 | "            pred = model(x) # forward pass (compute output)\n",
480 | "            mse_loss = criterion(pred, y) # compute loss\n",
481 | "            mse_loss.backward() # compute gradient (backpropagation)\n",
482 | "            optimizer.step() # update model with optimizer\n",
483 | "            loss_record['train'].append(mse_loss.detach().cpu().item())\n",
484 | "\n",
485 | "        # After each epoch, test your model on the validation (development) set.\n",
486 | "        dev_mse = dev(dv_set)\n",
487 | "        if dev_mse < min_mse:\n",
488 | "            # Save model if your model improved\n",
489 | "            min_mse = dev_mse\n",
490 | "            print('Saving model (epoch = {:4d}, loss = {:.4f})'\n",
491 | "                  .format(epoch + 1, min_mse))\n",
492 | "            torch.save(model.state_dict(), model_path) # Save model to specified path\n",
493 | "            early_stop_cnt = 0\n",
494 | "        else:\n",
495 | "            early_stop_cnt += 1\n",
496 | "\n",
497 | "        epoch += 1\n",
498 | "        loss_record['dev'].append(dev_mse)\n",
499 | "        # stop training if the dev loss has not improved for early_stop consecutive epochs\n",
500 | "        if early_stop_cnt > early_stop:\n",
501 | "            # i.e. the model stopped improving for early_stop epochs\n",
502 | "            break\n",
503 | "\n",
504 | "    print('Finished training after {} epochs'.format(epoch))\n",
505 | "    return min_mse, loss_record\n"
506 | ]
507 | },
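{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# (Added cell, not part of the original run.) covidDataset above ships with\n",
"# its feature normalization commented out. If that idea is re-enabled, the\n",
"# statistics should come from the training split only and then be reused for\n",
"# the dev/test splits. A standalone, hedged sketch of that z-score\n",
"# normalization (illustration only; it was not used for the reported scores):\n",
"def normalize_features(dataset, mean=None, std=None):\n",
"    # columns 40 onward are the continuous features (the first 40 are state one-hots)\n",
"    if mean is None:\n",
"        mean = dataset.data[:, 40:].mean(dim=0, keepdim=True)\n",
"        std = dataset.data[:, 40:].std(dim=0, keepdim=True)\n",
"    dataset.data[:, 40:] = (dataset.data[:, 40:] - mean) / std\n",
"    return mean, std\n",
"\n",
"# example usage, commented out so the original pipeline stays unchanged:\n",
"# mu, sigma = normalize_features(train_loader.dataset)\n",
"# normalize_features(dev_loader.dataset, mu, sigma)\n",
"# normalize_features(test_loader.dataset, mu, sigma)\n"
]
},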
508 | {
509 | "cell_type": "code",
510 | "execution_count": 12,
511 | "metadata": {
512 | "execution": {
513 | "iopub.execute_input": "2021-03-21T14:04:17.284213Z",
514 | "iopub.status.busy": "2021-03-21T14:04:17.283145Z",
515 | "iopub.status.idle": "2021-03-21T14:04:17.287041Z",
516 | "shell.execute_reply": "2021-03-21T14:04:17.286391Z"
517 | },
518 | "papermill": {
519 | "duration": 0.026859,
520 | "end_time": "2021-03-21T14:04:17.287216",
521 | "exception": false,
522 | "start_time": "2021-03-21T14:04:17.260357",
523 | "status": "completed"
524 | },
525 | "tags": []
526 | },
527 | "outputs": [],
528 | "source": [
529 | "\n",
530 | "def dev(dv_set):\n",
531 | "    model.eval() # set model to evaluation mode\n",
532 | "    total_loss = 0\n",
533 | "    for x, y in dv_set: # iterate through the dataloader\n",
534 | "        x, y = x.to(device), y.to(device) # move data to device (cpu/cuda)\n",
535 | "        with torch.no_grad(): # disable gradient calculation\n",
536 | "            pred = model(x) # forward pass (compute output)\n",
537 | "            mse_loss = criterion(pred, y) # compute loss\n",
538 | "        total_loss += mse_loss.detach().cpu().item() * len(x) # accumulate loss\n",
539 | "    total_loss = total_loss / len(dv_set.dataset) # compute averaged loss\n",
540 | "\n",
541 | "    return total_loss\n"
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "execution_count": 13,
547 | "metadata": {
548 | "execution": {
549 | "iopub.execute_input": "2021-03-21T14:04:17.322689Z",
550 | "iopub.status.busy": "2021-03-21T14:04:17.321980Z",
551 | "iopub.status.idle": "2021-03-21T14:04:17.325760Z",
552 | "shell.execute_reply": "2021-03-21T14:04:17.325216Z"
553 | },
554 | "papermill": {
555 | "duration": 0.024572,
556 | "end_time": "2021-03-21T14:04:17.325918",
557 | "exception": false,
558 | "start_time": "2021-03-21T14:04:17.301346",
559 | "status": "completed"
560 | },
561 | "tags": []
562 | },
563 | "outputs": [],
564 | "source": [
565 | "\n",
566 | "def test():\n",
567 | "    model.eval() # set model to evaluation mode\n",
568 | "    preds = []\n",
569 | "    for x in test_loader: # iterate through the dataloader\n",
570 | "        x = x.to(device) # move data to device (cpu/cuda)\n",
571 | "        with torch.no_grad(): # disable gradient calculation\n",
572 | "            pred = model(x) # forward pass (compute output)\n",
573 | "            preds.append(pred.detach().cpu()) # collect prediction\n",
574 | "    preds = torch.cat(preds, dim=0).numpy() # concatenate all predictions and convert to a numpy array\n",
575 | "    return preds\n"
576 | ]
577 | },
578 | {
579 | "cell_type": "code",
580 | "execution_count": 14,
581 | "metadata": {
582 | "execution": {
583 | "iopub.execute_input": "2021-03-21T14:04:17.361838Z",
584 | "iopub.status.busy": "2021-03-21T14:04:17.360866Z",
585 | "iopub.status.idle": "2021-03-21T14:04:17.364824Z",
586 | "shell.execute_reply": "2021-03-21T14:04:17.364161Z"
587 | },
588 | "papermill": {
589 | "duration": 0.024901,
590 | "end_time": "2021-03-21T14:04:17.364977",
591 | "exception": false,
592 | "start_time": "2021-03-21T14:04:17.340076",
593 | "status": "completed"
594 | },
595 | "tags": []
596 | },
597 | "outputs": [],
598 | "source": [
599 | "\n",
600 | "def save_pred(preds, file):\n",
601 | "    ''' Save predictions to the specified file '''\n",
602 | "    with open(file, 'w', newline=\"\") as fp:\n",
603 | "        writer = csv.writer(fp)\n",
604 | "        writer.writerow(['id', 'tested_positive'])\n",
605 | "        for i, p in enumerate(preds):\n",
606 | "            writer.writerow([i, p])\n",
607 | "    print('Successfully saved results to {}'.format(file))"
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": 15,
613 | "metadata": {
614 | "execution": {
615 | "iopub.execute_input": "2021-03-21T14:04:17.402755Z",
616 | "iopub.status.busy": "2021-03-21T14:04:17.401811Z",
617 | "iopub.status.idle": "2021-03-21T14:04:17.404830Z",
618 | "shell.execute_reply": "2021-03-21T14:04:17.404175Z"
619 | },
620 | "papermill": {
621 | "duration": 0.025493,
622 | "end_time": "2021-03-21T14:04:17.404985",
623 | "exception": false,
624 | "start_time": "2021-03-21T14:04:17.379492",
625 | "status": "completed"
626 | },
627 | "tags": []
628 | },
629 | "outputs": [],
630 | "source": [
631 | "\n",
632 | "def plot_learning_curve(loss_record):\n",
633 | "    ''' Plot the learning curve of your DNN (train & dev loss) '''\n",
634 | "    total_steps = len(loss_record['train'])\n",
635 | "    x_1 = range(total_steps)\n",
636 | "    x_2 = x_1[::len(loss_record['train']) // len(loss_record['dev'])]\n",
637 | "    figure(figsize=(6, 4))\n",
638 | "    plt.plot(x_1, loss_record['train'], c='tab:red', label='train')\n",
639 | "    plt.plot(x_2,
loss_record['dev'], c='tab:cyan', label='dev')\n", 640 | " plt.ylim(0.0, 5.)\n", 641 | " plt.xlabel('Training steps')\n", 642 | " plt.ylabel('MSE loss')\n", 643 | " plt.title('Learning curve')\n", 644 | " plt.legend()\n", 645 | " plt.show()\n" 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": 16, 651 | "metadata": { 652 | "execution": { 653 | "iopub.execute_input": "2021-03-21T14:04:17.442937Z", 654 | "iopub.status.busy": "2021-03-21T14:04:17.442256Z", 655 | "iopub.status.idle": "2021-03-21T14:10:44.252984Z", 656 | "shell.execute_reply": "2021-03-21T14:10:44.251994Z" 657 | }, 658 | "papermill": { 659 | "duration": 386.833628, 660 | "end_time": "2021-03-21T14:10:44.253228", 661 | "exception": false, 662 | "start_time": "2021-03-21T14:04:17.419600", 663 | "status": "completed" 664 | }, 665 | "tags": [] 666 | }, 667 | "outputs": [ 668 | { 669 | "name": "stdout", 670 | "output_type": "stream", 671 | "text": [ 672 | "Saving model (epoch = 1, loss = 44.2446)\n", 673 | "Saving model (epoch = 2, loss = 18.4360)\n", 674 | "Saving model (epoch = 3, loss = 6.2738)\n", 675 | "Saving model (epoch = 4, loss = 5.6395)\n", 676 | "Saving model (epoch = 5, loss = 4.0320)\n", 677 | "Saving model (epoch = 6, loss = 3.0655)\n", 678 | "Saving model (epoch = 7, loss = 2.2270)\n", 679 | "Saving model (epoch = 8, loss = 1.7633)\n", 680 | "Saving model (epoch = 9, loss = 1.5082)\n", 681 | "Saving model (epoch = 11, loss = 1.4643)\n", 682 | "Saving model (epoch = 12, loss = 1.3422)\n", 683 | "Saving model (epoch = 14, loss = 1.2304)\n", 684 | "Saving model (epoch = 16, loss = 1.1859)\n", 685 | "Saving model (epoch = 18, loss = 1.1647)\n", 686 | "Saving model (epoch = 19, loss = 1.1531)\n", 687 | "Saving model (epoch = 21, loss = 1.1395)\n", 688 | "Saving model (epoch = 24, loss = 1.1075)\n", 689 | "Saving model (epoch = 33, loss = 1.0888)\n", 690 | "Saving model (epoch = 41, loss = 1.0633)\n", 691 | "Saving model (epoch = 46, loss = 1.0276)\n", 692 | "Saving model (epoch = 51, loss = 1.0117)\n", 693 | "Saving model (epoch = 58, loss = 1.0074)\n", 694 | "Saving model (epoch = 66, loss = 1.0057)\n", 695 | "Saving model (epoch = 69, loss = 0.9834)\n", 696 | "Saving model (epoch = 79, loss = 0.9831)\n", 697 | "Saving model (epoch = 90, loss = 0.9679)\n", 698 | "Saving model (epoch = 96, loss = 0.9523)\n", 699 | "Saving model (epoch = 99, loss = 0.9442)\n", 700 | "Saving model (epoch = 113, loss = 0.9382)\n", 701 | "Saving model (epoch = 114, loss = 0.9336)\n", 702 | "Saving model (epoch = 121, loss = 0.9329)\n", 703 | "Saving model (epoch = 137, loss = 0.9239)\n", 704 | "Saving model (epoch = 144, loss = 0.9164)\n", 705 | "Saving model (epoch = 156, loss = 0.9146)\n", 706 | "Saving model (epoch = 180, loss = 0.9068)\n", 707 | "Saving model (epoch = 188, loss = 0.9013)\n", 708 | "Saving model (epoch = 206, loss = 0.8997)\n", 709 | "Saving model (epoch = 208, loss = 0.8911)\n", 710 | "Saving model (epoch = 212, loss = 0.8858)\n", 711 | "Saving model (epoch = 246, loss = 0.8844)\n", 712 | "Saving model (epoch = 253, loss = 0.8707)\n", 713 | "Saving model (epoch = 268, loss = 0.8687)\n", 714 | "Saving model (epoch = 362, loss = 0.8682)\n", 715 | "Saving model (epoch = 385, loss = 0.8634)\n", 716 | "Saving model (epoch = 425, loss = 0.8589)\n", 717 | "Saving model (epoch = 450, loss = 0.8560)\n", 718 | "Saving model (epoch = 456, loss = 0.8507)\n", 719 | "Saving model (epoch = 586, loss = 0.8497)\n", 720 | "Saving model (epoch = 857, loss = 0.8480)\n", 721 | "Saving model (epoch = 888, loss = 
0.8462)\n", 722 | "Saving model (epoch = 893, loss = 0.8424)\n", 723 | "Saving model (epoch = 1068, loss = 0.8401)\n", 724 | "Saving model (epoch = 1092, loss = 0.8398)\n", 725 | "Saving model (epoch = 1348, loss = 0.8384)\n", 726 | "Saving model (epoch = 1368, loss = 0.8382)\n", 727 | "Saving model (epoch = 1389, loss = 0.8381)\n", 728 | "Saving model (epoch = 1393, loss = 0.8380)\n", 729 | "Saving model (epoch = 1414, loss = 0.8326)\n", 730 | "Saving model (epoch = 1570, loss = 0.8312)\n", 731 | "Saving model (epoch = 1601, loss = 0.8298)\n", 732 | "Saving model (epoch = 1776, loss = 0.8286)\n", 733 | "Saving model (epoch = 1813, loss = 0.8278)\n", 734 | "Saving model (epoch = 1821, loss = 0.8269)\n", 735 | "Saving model (epoch = 1864, loss = 0.8264)\n", 736 | "Finished training after 2165 epochs\n" 737 | ] 738 | }, 739 | { 740 | "data": { 741 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAEWCAYAAABsY4yMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAABSA0lEQVR4nO2dd5wTVdfHfyfbG+zSywK71EVQkKaIIiAqIiDYKGJDwUfsihSRR1CxF/RRsaHyqhRFEBFFRVAREelNll52YWkLu2xvue8fmSSTZJJMkpkkO3u+nw+azNy598zdzJl7zz33HBJCgGEYhjEeplALwDAMw+gDK3iGYRiDwgqeYRjGoLCCZxiGMSis4BmGYQwKK3iGYRiDwgqeqbEQ0RVEtCfUcjCMXhD7wTOhgIgOA7hXCLEy1LIwjFHhETxjWIgoItQyBIoR7oEJHazgmbCCiExENJmIDhBRLhF9RUR1ZOe/JqITRJRPRH8QUQfZuc+IaDYR/UBERQD6EtFhIppARNulaxYSUaxUvg8RZcuud1tWOj+RiHKI6DgR3UtEgohau7mPOkT0qVT2HBF9Kx2/i4j+dCprq0fhHqZI9xshKz+MiLar6S+mZsMKngk3HgYwFMCVAJoAOAfgXdn5HwG0AdAAwGYAXzpdPwrATABJAKyK9FYAAwCkA7gIwF0e2lcsS0QDADwOoD+A1pJ8nvgcQDyADpKsb3op7+4eXgNQBKCf0/l50mdv/cXUYFjBM+HGfQCmCiGyhRBlAKYDuJmIIgFACPGJEKJAdq4TEdWWXb9UCLFWCGEWQpRKx94WQhwXQpwFsAxAZw/tuyt7K4BPhRC7hBDFAGa4q4CIGgO4DsB/hBDnhBAVQojffegD53uYD2CkVHcSgIHSMcBLfzE1G1bwTLjRAsASIsojojwAuwFUAWhIRBFE9JJkjjgP4LB0TT3Z9VkKdZ6QfS4GkOihfXdlmzjVrdSOlWYAzgohznko4wnnuucBuJGIYgDcCGCzEOKIdM5tf/nZNmMgWMEz4UYWgOuEEMmyf7FCiGOwmCZugMVMUhtAmnQNya7Xyy0sB0Cq7HszD2WzANQhomSFc0WwmG4AAETUSKGMwz0IIf4FcASWWYHcPGNty11/MTUcVvBMKIkioljZv0gA7wOYSUQtAICI6hPRDVL5JABlAHJhUZIvBFHWrwDcTUTtiSgewH/dFRRC5MCyVvAeEaUQURQR9ZZObwPQgYg6Swu401W2Pw8We3tvAF/LjnvqL6aGwwqeCSU/ACiR/ZsO4C0A3wH4mYgKAPwN4BKp/P/BMpI9BuBf6VxQEEL8COBtAKsB7AewTjpV5uaS2wFUAMgEcArAo1I9ewE8C2AlgH2wLwR7Yz6APgBWCSHOyI576i+mhsMbnRjGD4ioPYCdAGKEEJWhlodhlOARPMOoRPI/jyaiFAAvA1jGyp0JZ3QdwUvb0QtgWdWvFEJ0060xhtEZIloBoCcsv+ffAYyX7O0ME5YEQ8F3c7IZMgzDMEGATTQMwzAGRe8R/CFYtk4LAB8IIT5UKDMOwDgASEhI6JqRkeFzO8ezsnE6uQ7aHj2MuA4XBCg1wzBM9WHTpk1nhBD1lc7preCbCCGOE1EDAL8AeEgI8Ye78t26dRMbN270uZ1pD0/CR8NGYsXDd6Dzju0BSMwwDFO9IKJN7tY3dTXRCCGOS/8/BWAJgB56tEO6bV5kGIapvuim4IkoQQqMBCJKAHANLH7DuiEcdqwzDMPUbPSMONcQliBI1nbmCSFW6NKSNIAXxAqeYRjGim4KXghxEEAnveqXwyYahqm5VFRUIDs7G6Wlpd4LV2NiY2ORmpqKqKgo1dcYK2Y0D+AZpsaRnZ2NpKQkpKWlgQw6ixdCIDc3F9nZ2UhPT1d9nSH84EnyBGIbPMPUPEpLS1G3bl3DKncAICLUrVvX51mKQRR8qCVgGCaUGFm5W/HnHg2h4K3wIivDMIwdgyh4IfsvwzBM8MjLy8N7773n83UDBw5EXl6e9gLJMISCJ45pzzBMiHCn4Kuqqjxe98MPPyA5OVknqSwYzIuGTTQMwwSXyZMn48CBA+jcuTOioqKQmJiIxo0bY+vWrfj3338xdOhQZGVlobS0FI888gjGjRsHAEhLS8PGjRtRWFiI6667Dpdffjn++usvNG3aFEuXLkVcXFzAshlCwVvVOtvgGaZmc+KFF1C2O1PTOmPaZ6DRU0+5Pf/SSy9h586d2Lp1K3777Tdcf/312Llzp82d8ZNPPkGdOnVQUlKC7t2746abbkLdunUd6ti3bx/mz5+Pjz76CLfeeiu++eYbjB49OmDZDaHgwSYahmHChB49ejj4qr/99ttYsmQJACArKwv79u1zUfDp6eno3LkzAKBr1644fPiwJrIYQ8FLsJpnmJqNp5F2sEhISLB9/u2337By5UqsW7cO8fHx6NOnj6Ive0xMjO1zREQESkpKNJHFGIusVi8aNtEwDBNkkpKSUFBQoHguPz8fKSkpiI+PR2ZmJv7++++gymaIETxvdGIYJlTUrVsXvXr1QseOHREXF4eGDRvazg0YMADvv/8+LrroIrRr1w6XXnppUGUzhIK3ozyCL88+huK/1yH55puDLA/DMDWBefPmKR6PiYnBjz/+qHjOamevV68edu60R1KfMGGCZnIZQsHbYtG4sdAcu
eN2VB7PQa1Bg2CKjQ2iZAzDMKHDEDZ4b8urVblngyQHwzBM+GAQBW+BF1kZhmHsGELB8yIrwzCMK4ZQ8FY4HjzDMIwdQyh4W8o+1u8MwzA2jKHgQ2SiKc/KwrkFC0PTOMMwYcv06dPx2muvhVoMY7hJWgm2iebIqNtQefo0at84DKbo6KC2zTAM4w1DjOAhQhOqoCo/36F9hmFqLjNnzkS7du3Qv39/7NmzBwBw4MABDBgwAF27dsUVV1yBzMxM5OfnIy0tDWazGQBQXFyMZs2aoaKiQnOZDDGCJw4zxjAMgGn7srGzUJtAXVY6JsbhuTapHsts2rQJCxYswJYtW1BZWYkuXbqga9euGDduHN5//320adMG69evx/jx47Fq1Sp06tQJv//+O/r27Ytly5bh2muvRVRUlKZyAwZR8FZYzTMMEwrWrFmDYcOGIT4+HgAwZMgQlJaW4q+//sItt9xiK1dWVgYAGD58OBYuXIi+fftiwYIFGD9+vC5yGULB2xZZeaMTw9RovI209YSc9I/ZbEZycjK2bt3qUnbIkCGYMmUKzp49i02bNqFfv366yGQIG7xXEw3byBmG0ZHevXtjyZIlKCkpQUFBAZYtW4b4+Hikp6fj66+/BgAIIbBt2zYAQGJiInr06IFHHnkEgwYNQkREhC5yGULBW3EXbMyGXiN8foEwTI2mS5cuGD58ODp37oybbroJV1xxBQDgyy+/xJw5c9CpUyd06NABS5cutV0zfPhwfPHFFxg+fLhuchnCRGPzogn2Tic2CTEMIzF16lRMnTrV5fiKFSsUy998880QOg8ODTGC51g0DMMwrhhCwVvhaJIMwzB2DKHg2Q+eYWo2eps6wgF/7tEYCr4G/HEZhlEmNjYWubm5hlbyQgjk5uYi1seMdMZYZJUImYnGwD8shgl3UlNTkZ2djdOnT4daFF2JjY1Faqpvfv41S8FrrYit7bGCZ5iQERUVhfT09FCLEZbUDBONXiN7XtRlGCaM0V3BE1EEEW0hou/1bovH0QzDMHaCYaJ5BMBuALX0asDdCL7o7/WIqJWkV7MMwzBhja4KnohSAVwPYCaAx/VsS2rQ4evRu+6yHI6J0b1phmGYcENvE80sABMBmN0VIKJxRLSRiDb6uwpuVetuQxXwIijDMDUQ3RQ8EQ0CcEoIsclTOSHEh0KIbkKIbvXr1/evMbUKnBdFGYapQeg5gu8FYAgRHQawAEA/IvpCx/a8R5PUrWGeITAME37opuCFEFOEEKlCiDQAIwCsEkKM1qMta6iCoG904hkBwzBhjEH84EMtAcMwTPgRlJ2sQojfAPymf0s8omYYhrFijBG8zUQTYkEYhmHCCEMo+FAvcvIaK8Mw4YgxFLyEXin7Svfswe6M9ijPynI4bm+NNTzDMOGHIRS8dZFVLy+a/MWLAQAFv/7q1DDbhBiGCV+MoeB5BM0wDOOCIRS8DR5QMwzD2DCEgrdGk+RYNAzDMHYMoeBVwzZzhmFqEIZS8CHLycowDBOGGELB2000IYJNQAzDhCGGUvBeCVQRO1/OMwaGYcIYQyh4G+4UrhdFnPvZZ9id0R7C7C4vCStyhmGqH4ZQ8F4zOnnh9OtvWK6vrNRIIoZhmNBjCAXPNnCGYRhXjKHgJTijE8MwjB1DKHj99bobBW617bOCZxgmDDGEgreiux+8c/XsRcMwTBhjCAVvd5OsfgpXmM04PnkKSnbtCrUoDMMYDIMpeJ3RoZnKnBzkf/stsh96SPvKGYap0RhCwVsRBIiqKoUTgWrm6jczYMKfqsJC5d8rw2iEsRQ8CJkdOqJ482blAmwzZ8IEUVmJvd2648T0GaEWhTEwhlDwziaa4vXrQySJncqzZ7E7oz3yly4NtShMGCIqKgAA+cuWhVgSxsgYQsFb8dUPXlRU4PyKnwI3rSuYgMoPHwYAnFuwMNDaGYZh/MIQCt6Wss9HE8yZ2bNx7NFHAWk05XvDlvZOzZqFilOn/KuDYRhGJ4yh4P0cglccz3E84OdibN6ChciZPNk/IRiGYXTCEAreir/BxrTAXF4esrYZhmGUMIaCtyb80Fu/c0gCRiv4t8QEAUMoeHJeJlVri1f7kLmpz+Gou6r4QVaNEAJl+/eHWozgwq67jI4YQsFb8d1Eo2cmKF9devxowmCcnTsXBwcNRsm2baEWhWEMgSEUvG2RVfdgYzrUX01HcMWbtyD/u+80rbN0+3YAQHl2tqb1hiU8s2OCQGSoBdACFxONytGzCPQhq6bKWQuOjBoFAKg9ZEiIJanm1ODfEKM/hhjBW/G6yKrjqKni6FHd6mYYhvEHYyh4qxdNCO3eladPa1dZDSXgGVU1ogbdKhNCDKHgXdQ6T3urNVSD/n41506ZUKCbgieiWCL6h4i2EdEuItI9bJ41o9Ppt99WeYGPwyhZ+cI1f6IqL0/7NhiGYTRCz0XWMgD9hBCFRBQF4E8i+lEI8bfWDbkk/NA6xrbCiPLUm294uUZbERiGYXxFtxG8sFAofY2S/ukynCXbTtYAtSqPtkMLdz+jM1Xnz4dahKCiqw2eiCKIaCuAUwB+EUK4BGononFEtJGINp72c6HSquDNvip4VujhSY2wwfNvL9iU7tmDvT0uQd6Sb0MtStDQVcELIaqEEJ0BpALoQUQdFcp8KIToJoToVr9+fb/a8TdcMCt4JuTUiJdZeFC2dy8AoGjt2hBLEjx8UvBEZCKiWr42IoTIA/AbgAG+XqsGMnt2kxQq472X7dmDc/PnayaXT/DLpmbBf28mCHhV8EQ0j4hqEVECgH8B7CGiJ1VcV5+IkqXPcQD6A8gMUF7ltuBqoilcs8bneg4PH4ETM551OHbqzVko3rTJb9mE01S8bP9+VJ49az8Q5BFcVWEhyg4cCGqbqglTpVe8ZQuE2axP5TyCZ3REzQj+AiHEeQBDAfwAoDmA21Vc1xjAaiLaDmADLDb47/0V1BNKsWiyxo5zW77y3Dlk3T9elZtj7gcfoHTHDtc2vbjJuPPlPjhoMA4OvN5ru94oO3QIx56YoHp2YuXoXXfj4PWDAm5fV8JI6RX9vR5HRo7C2U8+CbUoDOMzatwkoyQ3x6EA3hFCVBB5z6EkhNgO4OIA5VOFdQSvNh782c/monD1ah0l8ozDi8XPUWvO09NQsmkTUkaOQHy3bqqvK92506/2PFF28CAi69RBRHKy5nWHmoocS9avsn01LIwxYwjUjOA/AHAYQAKAP4ioBYCw8jXyZoPXEnNJCUr37PX5ukPDhyPv22/dFwijUauvHBx4PQ7deFOoxWAYxgmvI3ghxNsA5FtDjxBRX/1E8h37CF5/JXl84kQU/LIS0S1aqCovyspR9Pd6lG7bjpxt23WWLnRUHD8eeCW22Uz1fdkxYUyYrvHoiZpF1kekRVYiojlEtBlAvyDIphqrwSgYCr54w0YAgLnCcw5Wa+CsssxMHL3rLr3FMgbCT3fXoKCxcgjrezU4NajP1ZhoxkiLrNcAqA/gbgAv6SqVzwRvBG/F2yIrYxD0+jP7oOAPDrkBJ5591ms5hnFGjYK3/gIH
AvhUCLENYTaHNtls8EHEy4MZSEREIYRD6FxRWYlTb7yJqvx8v+s0EubSUpQdPBhqMYJG2d69ODcvRPszmGqNGgW/iYh+hkXB/0RESQB0cgr2j2Da4IMxtc5sfwEOjxhh+17wyy/I/fBDnHz5Fd3a1JKc6dORdf943eo/NmECDg68HubSUt3aYBgjoMZN8h4AnQEcFEIUE1FdWMw04YNWwcbUNGX9oHNbpbIFWVFZafl/uWe7v68IsxmorARFR2tab96Chf5KpKpU8d+WkEaiogKIjfWzLYYxPl5H8EIIMyyxZJ4motcAXCb5uIcNVlXrTcFXnj4TWEMarsKX7d/v6DYZghX+4xOeROZFnYLerhLFmzej4JeVli8avDvNWr8Ma6AHBlP9UeNF8xKAR2AJU/AvgIeJ6EW9BfOFuIwMAN4V/IH+/S0fAlEgGj3oBwcNRs7kKb7NBORtayDH+R9+CLgOXxFVVTg+aTLK9jtuHCr49VfN2ihatw57LuoUUIgJK97WUirPncOZjz7y/QXAL4zgUwP7XI0NfiCAq4UQnwghPoElYFjge+01JOWmGwEEZ6OTDYUHv+LYMb0a80kOJYo3bEDluXMayeM/ZXv3In/pUhyb4DWckXfcPLBFf60DABRvDFzBeyNn2jScfv0N5H+71L94NTXIZS9sqEFdrjaaZLLsc20d5AgIiogAoD5UQUBISqXi6FGXU4V//SUTSgdhhED50aM+KxIhBI7cfgeOjrnHp+uqCotw+r33ILTOkBUoYaQUzYVFAICcKVNwZvbsEEvjSHlWFpuWdCR/+XKUbA8ra7ULahT8iwC2ENFnRDQXwCYAL+grlm/YbkLvB1/H+q3xaTwFQCs7dAgHrrkWuR9+6FcbZbt3+1T+9Buv48zb/8P5H1f41Z7fBNDPajL2mIuLfQ7Spmb9t3jjRvXV6ax4i/75Bweuvgb5i5fo2o6vVObmamqOCyXHn5iAw7cOD7UYHlGzyDofwKUAFkv/egohFugtmC9Y1YGZCFUmE/rOno/51wx2Wz73/Q/8b0zjB7MiO9tSreTyZy4qclu2UgoHULxhoy0Ilp6Yi4otspXp4I5o3el76BDK9u0LaN+AlfM//Yy9PS7xOqra06Urjo5zH23UAW9yBfh70GvIUC6FhC7ZEV4jzKP33IvsBx6EuaQk1KLUCNwqeCLqYv0HS+jfbABZAJpIx8IG6zMoQCiPigIA/N/AGxXL7r20Z7DEUoW5TL23h1yVVPqg4M0B5qHMmfp0QNc74KQwD143EAcHD1Esen7FTzCXlbmvy0m5Fq2zmMhK//3XqxjF6zTP/R4wVXl5yJ3ziaHNKuVW06Ze8fUZBzz5wb/u4ZxAGMWjsakMufJw84yoiQHvEQ8PnxajUMUmy52UnEoFULhmDbLGjkODSZN0kEpjnPqueMMGHHv0UaSMGoVG/53msaxRyJn2XxT88gviOl3kUwhohnGHWwUvhAiriJGesPnB692QEJqbaNS8cE69OQsAYPYxVMHZzz8HoG5EG2zK9uzBgWvlGRztSpuIbLb0ihMngiyZG6S/e2lmJg4NHYaWPyxHTMuWgdcre1lVFRRYmvJ1fcAJc0mJY9YwxicKVq3GuQXz0dzPta5wQs1O1rDHZoM36ZpD3DdUjjKPT5jgtUzVmQA3aPmNDq9MWb+UHzmiQgQFGdS+ZDV4GZ/76iuH7+e/tyQlK/j1V4uCD0NzyuFbb7UnKAmT2U7luXOoUuGme2ziRMRdeBHq3D46CFIpkz1evzAbShSuWYOqc+dQe4iyqTIQwkgj+k8wf8Lm4mJ9G1B6IF2OhY9SyV++3McrVP611CgmIuQt+Rbl0kK1FsrszPsf4OSrrwKwKKUSnX3pvc3gzv/0s891hmP2qUM3DFWVqvL8d8twcuZMt+er8vNxfkWQvbp0JmvsOByfqI8Z1VAKXhAFd7OTiyAatK1iRFiyy25yKT/i6o8fTEp3aWT+kfedPHn66tWoKixUvkYI5EyZgsPDRyif94PTs2bh7Bwp/6qP/v96rMHkfbNI8zr9QQiBwjVr/F4Arjx1ShM5jj3+BI49+phumwpLfXQlDnc8edGMln3u5XTuQT2F8hXrTciVO4VklKvtA747o73FncxJccht8TlPPeU+FaCfCZKEEH4FNqsqLFKMASPKy3HiuefV24WJIBf6zLvvKZyXtZubKzWk/988EA+XMx99hCN3322tSKlyv+t2h1YvnbxFi5A1dhzyl3wbWEUB3qNVsfvifXbyxRdx/pdfvJYzl5Xh0DBl77vqiqcR/OOyz/9zOjdGB1n8xxpN0kSaRJQ8OOQGVBUUIG9RcEdPSqOSytyzXmcGOZOnKB4vWrMGgGXa6wt5ixYh86JOqDju6oq5r28/HLnjTrt8spHZ3m7dHMIcWylYuRLnvvwSJ198SdXL5szs91GRI0sBGIYudTbF6YPCOv36G+rcMzWcCZQfOYrdGe1RJvnF+4s1JWPFCY32XwRxbeDs3P/DsYce9lpOVFQGQZrg4knBk5vPSt9Dit2LRjaCD2CkULZ3L4rWrUPO09O8F5YR6OaNU6+/EdD1WlHw448AgLLDh1zOVebkoPiff2zfSzZvdjhf9q/rFFdICVk8mjxkD3zpjh04+dzzvojsUoevHHtyIvZd0dvv630mgN9nVX6+fd3BC0Vr1wIA8n18yeuGh/uu0MiMw9jxpOCFm89K30OKzURD+gp2Slp8k7MrvQ36zp6P3FrJOPm8TCn5q2xULbJqS/bDj+hS76lZsywmCTXKzJ9b1NCkcX7ZMlSePu140EO/V57JRcnOXZq1746qc3kuxw5cPwgH+l/tW0Ua9pWoqkLWff9B8eYt/lei0LdqFmEDwVfzmi/lzeXlqsJkBBtPCj6DiLYT0Q7ZZ+v3dkGSTxX2RVaT7YcTrCnG4r4WX+6tbS/QpX6LOVrd3YjKSpx49lmfwxgU/KzsqVGlJn6+B9ly3//A0SRB5NYmnLfwK8XjSm1425lrlvzJtVyIO//995Ygb9Izf/azz3D45ps1q1+O3A++dMcO2+e8bxbj9HvvhdBt1kJFzgkU/v67Khdfd5x69TWXEbvZ3WK6HA1eVM6/waJ//nG/jqWSo3fcib09LgmoDj3w5AffPmhSBIhdwduPhWqKYS4uhik+3q9rz//wg0uM9qwHHwJB3f0Ub9iAc/Pmo+zQIbT49FO/ZHBHVWEhIhITXU9oNDJU4yPtKwU//4zGz1mSVYvKSuRM+y/q3acyBo0CuR9+6Pf9Ose/V8Jq4st55hm0WbXK5XzO1Kker6/MzYVJrwxXGi8An5s3D+XZWf5vJlIYJwghULj6NyT27ePT4vJRaU2p1iYfZ0UySrZu9ftaPXE7ghdCHJH/A1AIoAuAetL3sEFugw9KXlYZziGK86WNML4+EOeSaikeL9u9W72Jxoc2fQ05rEUMn/PLl6M0c4/vF6q8L+eH2lxRgbJ9+wAApTt3In/JEhxzCttQvHkLTv/vHXcVOnw9Pest5C9d6rb9on82uD13cJD74HdWhKTgKxUWt9Wwr9f
lODjkBjeV+/Z7PP/Tz8hbtAiVVg8lCUsfO9YlKipQfvgwKk76ZkP3dceuuazM4+a4/MWLkT1+PPK++tqnet1Rcey46rWOcMWTm+T3RNRR+twYwE5YvGc+J6JHgyOeOuQjeD89A10oDYJ91cqKS3vjxlc+wJ7mylvf1bos+mIzzJmi7HnjlkptPAzyv/tOk3oA1/t1tqGL4mIcHDwElW5MGkIIHBk1CmfefVd1m1VOrp4O/t0a9VEgeDJLmcvKVCcqP/bII8h5ehr29brcYa1ByMN1SC/A45Mm48CA67D/yit9E1aqRlRW4tTrnkJfWSj46SfXKoSw7ZOwhrWoPOk5vEXZocMoXPOn1/YO9O/v+1pHmOHJBp8uhNgpfb4bwC9CiMEALkGYuUnKbfBajeDVxlwnJ51aumOnzzbwLe06AAAONUn16TpnitdbvFvUTE/zl2qjaK3uc6rRaYJVsnOXPaerE9YYL87kLfQ3ObgdNeEWnCNiFv651uIy6oyPoTZ825QjsO+K3tjT+WKf2gCAoj//RNFaWTIbJwXvU+pHhUFIwc8/I/ejjxWLVxUW2V5KSrs9cz/6GHu7dXeYPXgb6JRu346ssWMtba9U/s0YBU+/KPn86SoAPwCAEKIAQFg5JltdIi3mEt/9kwPB2UST9/XX2N83NIE27S8l/7Vo2YEDbt3Vip1cItWj/9/Cnx2f5UezdJDEFeE0as66915bTBsHfByclB9ydWP1hHxxumD1apybP1/VdadnzXJY7LUR6GDK+twquM9aQzjs7dbNKSidIwXSBqbKkye8ynN+ueOLSJjNKNkuuy8V9yOEQN43iz3mbQgnPC2yZhHRQ7DEge8CYAUAEFEcgKggyKYaRxt8SEWxE8IgT2639ntACAEiwsHrB7ktc2TUbYGIZcHsv7IXQiCzvd1bST5TyZvvIQdNWDn1urI7oz1azJun72/GacCTfb8loFbKyJE+VUNEmg2eitevd3vu3Pz5qHf//QCAypMnleVwd8xdrl5p458Vf8JsFG/YgJypUzVJ6B4MPCn4ewA8C6A/gOFCiDzp+KUAtHXRCBAT7FNGYXOTDPOnWkdK/cgTubd7D8ReoI+rp5yiP73bPl2wPrhOI72TL7+iuoryrOCM1v0l+5GH1bmlyghFYpDTb72NhF5S5BIvL6SSnbtw+OabEVGvnpdafX+xCSczEYSQZ/6xoXVsGWuwwcrcwF1Vg/H38+RFc0oI8R8hxA1CiJ9lx1cLIV7TXTIfkKfsUxNsbFd6G3zbW9vFk1C/UAqdRie+Yi4sdNihGk6U7typeDx/ifp8o8efnOhTm1qklKvKz0f54cPqyvqo3B2uLSxCzjPTPZZxp0zK9u/3nDVLgTPvWZKLKyWel5O/+BuLfLJFbuHU1qk3Z8FcEkCEVtv03X5/5z7/3Ha/ASf40YEjd9+NA9cOwLkv5znMSPXA7QieiDyuwgkhtA9e7Cf2Hy+pGgw8ONHiGz30D+8BiFTLEOLoDeVeHrbqTPHGjSg/cgRRTZr4WYPvL989F3dBm7/W+t5SRQVISht5aNiNqDh+HG3X65seMHfOx34vGB8cNBixHTog3Yc1DC3tz7kffICopk01qs3yDJqLi1H0xx9I9NWrJ0hYN//lL16se1ueFll7AkgFsAbAa7Ck8JP/CxsipKlZ2NjfQ4HBLVIHrh2AonXr/LvY36mwHzbxA4MGIX/5clScPGX3MNLLti4sLoaq3Gg9dEHpLrsbpLm4GDkzZnisqniD3d9fVdIWb6Ip+MOXHz7i0+hbVFU59LO51PuspOzAflW/jbxvFqNCWgco9zNoW/7y5did0d41HIbOeFLwjQA8BaAjgLcAXA3gjBDidyHE78EQTi12E40pZCPpUJtoHOLghBE7WraFWSNbo4PHQ5hSceQojj8xwXefcB/IatDI5lK7v//V9vj1GnD2//7P84K1E84eLhWnTmF3Rnuc8xR6wgmlWPH5S5e637QF+4IqSc/7kdtG+/widY7C6s45IWfqVOy/sg8AS4gFAD4PGqybr+RRPc3OuZZ1wJMNvkoIsUIIcScsC6v7AfwmedZ4hYiaEdFqItpNRLuISJ+IVpAn3YZ9kTXkI9qaPJ2wsCmjIx5+cgbmRCdpUt+Zd9zsOPVCqUKESwA4+4l2itETOf99RtP67pjxJsZMk7JOqc1ZK1NItpj0MgrX/Im9PS/D6VlvBSTb2c/mAgBOPBP4PatKEiJT6hVZdjOlP3FyTr7wosfz/rsJK1O+P7AQzmrwmJOViGIAXA9gJIA0AG8DUGs4qgTwhBBiMxElAdhERL8IITTPAO3gJql15X5ydu7cUIsQck6l1AUAHDSF1qv2+JNPhrT9Ap1SzPm7OKkUkz5n2jTN4wEFFG1SLXJX2a/t6wg2s4/XUb1dY5jdbIizoombcJDxtMg6FxbzzI8AZsh2tapCCJEDIEf6XEBEuwE0BaC5grdtmCCyR5MM1kYnNyN1xY0sQeTM+++HtH05xRs2hloEQ3Ji2n9Vly1c84eOkihTlpmpWV0+75hWSZHsZWeNne8rpXt8iK8UZNdWTyP42wEUAWgL4GHZxgICIIQQytGxFCCiNAAXA3DZ2UBE4wCMA4DmzZurrdKxDun/gsJoo1OIsU61S6OiEVvhe/o9JvTsztAuoKtXc4BGisfvhXAvOJuVTr78CpJvvcXjCH13RnvEd+/usV53Lri+cOiGoS7HqgqLsLdbN9QZM8bBFHj6befkePriyQZvEkIkSf9qyf4l+ajcEwF8A+BRIYRLIG8hxIdCiG5CiG7169f36yYcFHyQbd/hYxRyZWub9rju7bnYJMW6YaoPJUqhAXREK9fHMp2SVlc4JZcvXL3ashvXiwlG7vGjKW4Cy+3OaI+i9f9gb7duAFzXeUq2BMFsJcO36EY+QkRRsCj3L4UQujl92m3w8BpsbMLDT2nadqj93z2xvXUGAGCbTslIqiOl29Tv8vXXJU4LrJuJgoWqZBthiPPGqWBR9Jf7mcrZz/8viJJ4RjcFTxabzhwAu4UQuiYbtYYqUBNJclP7CzVr92jDxsizxnHXwbZmJsL6Dp38niMEOza+0Thy+x0ha1uEQejh6oCRN/hpgZ4j+F6w2PH7EdFW6d9APRqSm2iCyZ3T38BWP80fJ1Pq4mhDzzszl/buj8kPTsav3S/zqw0bOi/sLO5zLXJrJevaRkgIQawXK9bAWFVEKI7RKUuTEfCUyD1UhJHVVjcFL4T4UwhBQoiLhBCdpX8+BI5WT0yLNEubpG1Gp+z6jZDVoJFm9ckZ8cI7uHO65w3Bp1IsQZpOS+6G4cix+g3xv+F34Zlxj7ktwzMJ/3lz1L24ftanMHMfKlJdwvaGCl1t8MEiOq0FyGx2sMFr4SZ5+7Nv4o4ZbwZcj79YF3DNFL5/psqICABAQXxCiCUxJj9e1gdAWA0KmWpE+GoOHyEAwmQy1INg3Y1b3V0/g7UnwdDoOII/XTsFP11yhW71M6HD407WaoUQFuVuoKksCWviLD/vyUB9wejHhEem4mjjprhi6wbEl6nL2c
pUD4wxgicCQTj4wevln15lMuHNkWNwoo63JAaBI49z7w9aJSBntOO7y6/CpoyOPl+npztubnIKAP9/Z4wjpf9qv1nfXwwzgjeZLQp+8oO+JXbwlZ0t2+K73lfjcOPAEmSrgcyWEbzwMRmzSz06mkjCeR9AOPLmbfcCAFbf71uqvOpupqtJVObkhFoEG8YYwQOIMFfBbDLhWIPGAOz264qICOQlahPNMNiYhHWRNfyf7nDe0WsE2BOJ8QfDKHiT2YwqU4TL8Zl3P4hhr34IACiJjgm2WAEh36FrJT8hEQDwW5dLNN+V6w+seIKFfv3Mr2bjYggFT0QojovHoqtc91H93vVS2+d5A9wnENCT8shI9J09H9/36ufTdVYTjXWxdE/zlhj62kf4pXsvzBj7qNdduUFVvqwldEVPE409UT1jNAyh4NUgAJRHusYlP1avAXalt9G17YJ4y6j7k8G3uJzb07wlimLjFK9zXmTd36wFAGCLj7tn+cFl1MDurMbDsAr+fEICyqLsCt0sedo4M/q5t2xJuAPBUwYpW7sKI+r/TJmJtZ2VQ5pa3SSdR+JhZe9W8fYQRDiVUgfv3HIHqjSeVRyr18Dh72xUgrGYzeY242FYBV8VEYkHnpQpbiKQWT/FWBEZ6V7t+rlhyb7RSUoqHsKx+P7UFug7ez4O+ek99MJdD+CbftdhV6u2mslURYTRz72FGffqlg0ybAiG8g2jYQOjEYZV8ABwoFma7bOA8gheK168+wHMvf4mxXPWdn1V0NYps/U66wtCbb5ZLe92ddeeAIC/LuqqeN5T35IQMAfo6qmEtc5/OnTyWrbKZKoW3khuCYbo1bl/GEUMreDlmE3kUTEuu7wf9qe2sH2fd80Qn9tY0fNKxeMkTynoA7brTE7XyWylqpS4DrbVE3XqYUdLdaNxvUefal54/d/9EhMfmqKrHHpiFBPNiTr1dF/zYuwYZqOTdwjO6rAiwu5W+cZtYwFYNqBUmiLw0TDfNqIAgMkWWsCpZT/1q/PIX+khF0RBXRyzziJGzrSkHvvkudAltLaZrsiSmrA8Kgq1it1HF9QyF0CwMYqJxvq78XWjF+MfNWYEL8jVS+DV2+9zW9Yf3Clya31+j+Ct19nc2XwcwWuC/y3J+91MJqy9qKtGctv78z9TZuKG1z8OuMaTKXXdejWFBJuZTscm2M/KsNQYBW8mk8vPeI0b7xW/R0tuRtLWcL/+LrJ6PB5iu6mvyuHr/tfj6fsnuO17X7Da1EkARzQKHTHihXdw35SZmtSlCUH8+2o9SzhRp16121xoNAyj4EetWIpIT2nOFEbwHgv7gcld/dYRfIAjJaUXhMc6NYyN7w21bWS2aAkAqFDYk+ArepktrOEuwoqgKHpt2xg583+Y+NBkTetkfMMwCr4kJhaVke6XFARI9WKjGm8LJeVCEDgfn4C+s+fjty6XOLYN+P2QOkvtYKLxUGU47lC0hpOIrAo856iWyV3CHb1eZgeaNkdpbKzUhvb175QSv2vJ2aTaONRE/2B/RsAwCn5J32s9nt/Rup3HF4Ac54epNCratYzCdWQWONKoKQBgUb/rAACHmqQG4CIonL4pL7L6SnFMLJ6+73Gcrp2ieP5A0+aYfeNtutp9j9drGHAd9kXWcHqF6YNef4t7n37Z9vlA0+Z4cMKMsDer3PbsLIyZ9mqoxagWGEbBe2Pyg5Mx/1p1sWicFUZBQiJ+uuQKB68boZBGj2D39zYJgb3N0jBm2qv4UoqBE+hDahthOVTku3Jb3a0n1nbujs8G36x4/rHHpuGrqwe5ScPnub2p/3kCw155H31nz0eZ9GKU96f184c3jvIq58EmzbBKFkvIGeOP22FfZA3CS2z2TaOxq1Vb7GjVTve2AsE646gO5CUm4WxSbdv3/IQkzXdze6LGKHgl3D00zsffueUOvHTXeFzzzhc4JSXANjv7pgOAsF9rMpttZZdeeY3lGuml0Hf2/IDk/u7Kq2Wy+lGBLQyx8p/ftqFK4Vx5ZCReduN9BAB/deqGPOkHnZdUy3bcHzPKPdNewXMedqla4+R7qrvv7Pn4cOgIn9sONmVRUZgz5FbFeEkWtFEKJ+rUw//chIyoCTOhYDPs1Q9x0yvvAwDOxydg6GsfYs4Nw4PWfs1W8AoPTRWRyw89V2bKON6gkeVahYfBJMyoko3gXRZd/X1+bG0pPJQeKnWn9twtBm9q1wF9Z89HoRQczUEE6Zq/L7wYK6RE0IBnpTB30M0O13or7yveXhnW82pnbqHk66uuxxfXDcNiZ1Ojxkr3+TEPYXG/67BHWuwOJvkJ1TMvg1acl+7/Tw08yNRiGAV/2baNmtTT/715WHb5VQ7HSL6BycPoV74l32Q2u4wsS2Ni8c8F3rfVy+sDPCtFB/MHgDdHjsHT9z3uoPyc5bB+t8p6Pj4B+1Nb4A/ZwrDSdZb2lH8ynjZz6TUy9FavFqEJfujZB18EIcx0hbQ+VObG/q00U5s78EbbgEItnsvrO4Kf8kDoNsWFG1vbtMf+ps11b8cwCv6OHxb7fE15tOviKeBqH5aPeK1Kb6/CCIiEsClAJQUPAJN8cBtzftwUFRpZFk0/umE41nbqhu96X421nbuj1MNCGTltnnnwyRkYO/UlVTK5M+uoRQtPjW97X42Pbhgh6w/Hfi6PjMSOlm3dvoyUOFurNs4oLDq/esd9mHODZxNPdv1GGPbK+zgpmeTkTHj4Kbx0x3+8tm/9jclNJ8Uxdluz0t/+s8G3YFW3y7zWrRb53+ZowyY4nVxHs7oB4Fj9RprW5y97m6VhxaW9QyrDY4//F2NlC9x6YZhQBXraD9Xajwn2UaNJmN37xWvIrvS2+LX7ZQ5mE0AKruVmtGafGVjOZ0meP84o9anz2oOaO7S8+OybkvxBACiKi0diSTHeGjkGAHDzqh8Uy757yx34rvfV+Pj5SYrnT6bURVRVJeqcz7cdu+nl913K5dStr0q276/oh7yk2hjxwjsYtWIpxi5dYDtnDY8w+f9c65djkpK7WP9mv3TvhRfGPGg7784UV+pmkOIP8h3Td05/HYD/IQWc/8zlkZEe1heCy31PvQgAGPD3Hw7Hj9VrgOiKCtTPPxcKsXTBMCP4ygjXdH1asa3tBarKkTA7mGi0CvLlqZYnH3nKRbkDwOA3P3Fre7Yq+JKYGFQqpDm0Yn1ZZbZoibzEWpIsTopGKnMw1f10093L9+cel7u9psLp77nsiv4Y/MYcZDWwjwLd1WsNGpfvJhfviBfeUVTozjwvU7CeKI+0K1l/s4Y5K/j1HTs7FnAzfvHXBXdLu4740ykyqJYhC5zNY7e8+F7Ye7+Mfu4t3PrSe6rLL+p3HY42bKKjRIFjIAUfnMnIp0NudXuOhP2BI6VFVpUUx8Ti5dvvQ1FcPADg2z7X4vFHn/bbvOG8am+Va23n7nj6P094uNLS4P2TZ2L5FZZ1CdcRvG9CyUecL979gNtyG9tf5PB93YUXAwCyZbtM3bXt62xuX2qa4nG1ytPbHgw1OCt455mOu3tVykPsicz01gCAj4eOwLT7J
zietFm87I0XxCf4547q9Dc4ryLxfZXJhL6z52Np7/7+tGivh8hhgPDPBZ00jy9UaYrAu7fcgQefnKH6GltMqiBuPTSMgm9+4nioRYBJmG0jlwOpLfxe5FvcdwBWXNYHC/oPAmB56Le06xDQD+PBCTMw6rlZABx3wq6XFKcSSq6gzkpP/tKxJgR3Rm7iqlDYNGZlV3obfOTF3i1X3lOdFRQsD3NOvQYAgK/6X++xLivP3vuwqnJKlMS4rnWYyffMAxGSgq9ys27g7mWjZZx9pd/XkNc/xrcyt9xA6vJGuZSZ6/1ht/l8rZynHpiIa975AoDFHDfpocl48c77A6rTGevvvtiHWYkeORG8YRgFX6cg33shnclMa40TdS3K5WTd+pr7u+5p0crva3e1aoscaffopoyOimWcH0qlkbDzTGlvc/ti80E/vAJWdr8MubWSAQAPTnwW8wbcoOijreRRtCfN0h/yl8akhybb/PDXd3T/8pKT3dC/2DP5CUkYOOszl+NXvTcPnw5S3kTmDmuoaXdK4Es3pp9AF73lFMZbZox7W6Q7HH/n1rt8rsvbbPPBCTOwtU17h2MHm1h+P1UR9nvald4GfWfPx86W6mPI/9Ohs+1zibRQnaWBKaUkJgZvjLwHRbFxbmeJO1u2cftyt83OgrhFzzAKPlyYffNo22clTxs12GPIOP4QVl7i3mbtCz/06qd4fJnT1FiAXPKdnnRaePxY9hJz96PfJ8us5czMMQ/ZEnFESPFpSsIpXK+Me6a+jNEz3rB9H/rah27LWje3WZl51wMOOxqdseffVX4k3fmtV0WYcKx+QxRK5rxAOCt5ETl7DZlNJqxTeFn+3ONy9J09X9H84c2DaVerti7huq0eZvIXtjVb1yYnkx0AzRZtv7pqoMsI/56pL+Ntpxfb4j4DsKx3fyy8epDib/2Pzt3x0JPPYrmb58vd7ExPWMEzbjETYWlvz9Nzq+3YEweapWGHh6BT1gXa2PJyAFBUGMoOkb5zok49h+/fuLGfH2qcisy01i5yWiNNKik8Oc4KYOUll9t2NALAi3fe77DAHVElmWjcjODdrTFVmUwY/ews3DdZ3xDHu9Na4a3hdznIvOCawQCA5+55COWRkagissmvZr3I2YRZpPCSsg12FH5n7958OwBg1oi78cWAoaruQ4nZN9+On53cJg+mNndZW7GOwM1E9hDgslnv8fqWGXJ2Q2V3UDbRBMi8p/23pYYD+5uloe/s+fhEWsgNdZTE+dcO8ToylC8kByqt9WVRGRnpeu8a9UVJjKPNdHGfAYrlxvzXczCrpx6YGJAcP1/aG3tatMSplDp4c8QYm0KscjONd7eeYx1tW3dYzxlyK/rOno8HfFj8U8Pn19+Eb/tci7Wd7J43VuW7vuPFWNGzDyY+PAX93/3Sck6FDd7TYrjVW8rudiyw4OpBGDflBVsZa8iOpVde49YcKp8NZzVo5BBgz52p0hukIHthbBz2yMyV/6a1clmTCoWCN4wfPAA0zj0dahEC4hspAqWVUEcGWdJXWfnJUTOCV4u1roqISJS58+8OcL/D306LylU6uteq4bXbxmFDh07os3EdALjm37Wh7r6/uG4YAOBflflyfeXvjhfjsu2bMO2+J3C4STPb8TlDbnXwlHF/H3Y8KfgvBwzFTz2vRMbh/QAsyXmUTJ4bMxzTMDor7Scffsr2+Y4Zbzqcm/DIVCya5HnxdVvrDDz6xDNYNOl+h1eu/IVbEhODpx6Y6DBLfWDS80g7nuVQl5KC/63LJZgx9lGsePgOj3L4i6FG8AAwd/oTaHzmZKjFqDE45KENQPnmJSbZFtfWXdgF08c95nDeefetv3w4zL5L2UzksqagJWo2SjmbMrwFgPPEEj+8XXxlxWV9cCC1hYv3lVy572mejr8u7Op8qQtyhee8H+MnKYG91Uzmbj1r3rVDHL5PeGSqw/fcZMuI/aSTac7KzS/Pdjl2TPLCAoCvrxoIAPh00M0O+Qesge7MEREYOOszB+Vunb3IX4CA3QYvX/CdMfZRAMBHQ/XJUaubgieiT4joFBHt1KsNJZqfPI7Pn3kcI39aGsxm9aEaJLLocHCfJvUMe/VDW5Az500+u1u0so1fD0qbmLTA2/pCIAgCxk96zmsZq4nrt249Abi3wavx7X97xBiH7x/dMBwVERFYdnk/LLh6kBqxVbGlXQeP5/8z5QU8f89DXus5IwuFcPW7X/gli9q9JqUx6t0ZRz/3lu2zNdDg8iuusv0NPh94o0cX6LOSV5icf9NbezTRbPbSp/6i5wj+MwDe5/g6EGE2496lC0PRdI1jVXd7LBTn0ZO/OC8ojp/8PP6WFjU/89H90BOZaf65nb5w13hV5fI8eM0Alk1KziYumxJwo7icF4k9MW/AUPx8yRV447ax+ODG23BOFr45EOSzIDV4WpA+Vr9hQLMyLU2ESsiTBMl3y6/p3MPtNUrebg9MfM5Yi6xCiD8AnNWrfm+YhMCT//dBqJoPW+ZLng/hzK5WrvZjtdm4fMHZc8IThTLPnl8uuUKT9vOSaiHLyQffnRI43KQZfujZByNn/s+nNr6+yr7Za85g97uw9cTTgvSJuvUDys/rYCLUAXnAty+l9Q0AeH30WJ/rMpSCVwsRjSOijUS08fRpbRdJB677DTM+eAMXZwbVSqQZRYoZlQLD19EXY8HZ+0YLpo97zCXB9+pul+Gwm+Bvr97hPtGKO47IcpdqsdlHDwLpW/lmNufZjRbhorUcWNRIP3ghxIdCiG5CiG7162u/4NV76wa88dZMzJ3+BFpnHda8fqZm8OgTz/hUXilpilrufuY1lHsI6eAv29u2914oBLgkOfET59mNFhFmT/lgEvOGJ48tLXckyzGUm6Qnmp88jo9esOyYLIhPwJDXPw6xREx1wrqJJVj87iEXrZHQat1GCV+ToejNAQ8OAhU6mCCBMBjBh4Kk4iKsHD8Kk+bOxldTHkDdPEv854SS4hBLxjCMVvxwWd9Qi+CAp9hUxxvokwxFtxE8Ec0H0AdAPSLKBvCMEGKOXu35SoQQtoD/i6bYvSKyGjTCLz0ux9laydjbIh19Nq3H4cZN8UuIM8AwDOMbvzuloKyJkAgjX+tu3bqJjRv9y626O0Nf+6KZCLm1klEeFYU6BfkOUQTr5J/D2dopaHP0EPY1t0TiSzmfj3O1PLvJAcC3E8Z5DFrFMEzN4ETfzn5dR0SbhBDdlM7VGBt8oJiEcEjl5S6V2b7UNGxp1wHd/92GzwcOwxVbNuBooybo/u82ZBw5iINNm+Nnyc3u3qULEF1ZiZfeeQmTH7Tnam1y+iRKo6NxtnYKYktLcf83X2Bt52442LQ5umTuxM+X9sbL/3tJMb/rlE/fdUmkMfS3n7C6a0/ka+QH7Y33X3wKrbKP+r15JRAG/7ESy3r3xyfPPYkx0zzHk2EYo2OYEfzxSZOQv/Q7jSUKPwQsqeiSCwsgABxt1AQre1yO6/9chUZnzwAACuPiURITi8L4eKSezEFUVRUA+3bw84mJSCwuRnRlha3e0qho/NSzNxqf
OY0e/27DyZS6aHAu17aDdFvrDDQ7lYPIykrUKi4CYEk2/dT4CXjn1WdwrH4jJBUXoe75c4grKwNgCbiUWFKMvKRa2NcsDecTklAQn4jeW9ajNDoGCaUlaH9oP47Xb4jc2slocPYM1nbqjtPJKbhyy3o0zD2D+vnn8Gu3y/Br98vw5Bcf4tdul2FT+wvx94Vd0O7wAUz6/H1sa90edc7noTg2DgP+/gNmIpiEwPF6DfDabWPxyMJP0eLEcRTGxWP0jDcxftHnOJ1SFy1OHENWw8bov/5P1D2fhzdG3Ys2WYewr1k64spKsUjapi5nwVMPYsQL7wAAvpj2CDZlXIh5A25Aw7NnEFVZiednv4ZPB9+CW1cux9T7J9hi1mcc2o/6eWex5uIeuGvZ10goKca7t95pq/eSnVtsLn/XrV2NH3v1xRtvPocOB/ehNDoay67oj4+HuiZDeeajWbbt7nLefflpPDV+ostLPaGk2CVq46U7NuPvC7vYvl+0bzfaHTmIVd0uQ+uswx6TwjjL7omLM3dii4cAX5PmzsbLGifmkPPo/DmYNfIev6+PKS9DmYdk9r6gNMPXYwTPCp5hVCIQ+gBwcuTyWF9qVUQAESLMluximWmtcMGh/Q7XVZlMECCYhBkkLJFTzibVRn5SElrkHHPZ/l9pigBBIMJsRllUFIpj47AvNQ0XHtxje5mv63gxUk/loNmpEzherwFiy8uQUFKMKlMEzickIqdeA1y891+cqZ2CevnnUGUyoTAuAUnFhcitlYxjDRrhov2Z2NL2AoAIyQXnEVlViaanTiIvqRayGzTCvGtvwPhFn6PBuVysvagrGp47g+2t26MqIgJ9Nq1DWVQ0EkuKURkRgfyEJHw+8Ea0yj6CO35YjKjKSkSaqxz6zmwyAULgu979cf3a31AYF4/8xCS0OGHpg5MpdVEl1dX09AnbwOZUSh2s7toTt65cjrKoaBxIbYFzSbVQGRmJ3NrJqJOfj4rISHQ8uBeJxUWoVVwEAWBJn2vR7ORxZDdohIF//Ybj9RsisrISW9p1xOXbNqDXP3/79TtgBc8wDBPmtM/c7dd1nhS8gdwkw2lsxTAME3oMpOAZhmEYOQZS8OFjamIYhgkHDKPgyV0GIIZhmBqKYRR8yujRoRaBYRgmrDCMgicdou8xDMNUZwyj4BmGYRhHWMEzDMMYFMMo+KjG9nCbjaZPR+yFF4ZQGoZhmNBjGAVviotD+8zdaJ+5GykjhiP966/QYMITqq5tufx7naVjGIYJPoZR8ErUvfde1L3/Px7LtF3/N2JatdJdlsSrrkK7rVt0b4dhGMaKoRU8ANQZNQpx3bqi7rhxAIBG0x1za0bUdh+zPfXddxy+t5j3pWK5uIsdI+nVGz/epUzDSRNhitU+cXMoqHX99bbPSddcg+QR7jPVeKLBhCfQeObzWomlC55+H0zNxvm5l9NmzR9BlMQ9hlfwkfXrI+2LL1D/sUfRbvMmpIwYoT6oD5kQ1dSe4T6+SxdkbN+GqObNAQBxnTujxbx5bhW/gxwNLTk94zp18v0mACRc1lPxeHxPS+5O8vLyiEhO9tyAypyQkQ0aoOnrr6HVyl/Q8ocfkPr2W2g8fTpa//G7rUyjGTNU1VXn7ruRfNNNqsoqkTJqFNK/XYKIOnX8rsOZJq++irbr/0ZCr15o/cfviO3QweF80jXXaNaWO5JvuRkNpz2teC6ifj3UGjTIax2ps98LSAZ/f6da4yxH808/0byN+EsvRdP/ve3zdSmjlHNCABa94wsJl13mc/tqMLyCt0JEMMXHey8IIOHyy6WLgPhLLWm/6j/ysOVQdDTSv1qIFvPnIW3BfMR3uRgky97eauUvgFOy3/aZu2GKscSRTlu4AHXuGQMASLr2WqR/uwRtN/zjIFtsR3vM7Jj27YHISDR94w1FWVNGjET7zN2I79oVAND07beQsWO7y+wjeeQI1H/8cdv3xKuuQuwFF6Ddls3I+HcXWi791qG8VWHXuftu2wst9Z3/Ie2rhQCA6NRUxLRMt5WPatDALtPwWxVltVJ37L3I2LUT5CbLfPO5cxWPW19ySdcNQPvM3Wj032mIzcgAyH2gucS+fR36tuF/pwEAGkyciPRvl6D+o486lK89eBAiatdG8zkfI6pBA0Q1bWK5btrTaLtxA5q+8ToAFS9MNzR55WVbHc5YZ0a1hwxBndtuUyxTe9BgNHlhpsOxtn+vc/ieMmoUkvq6z0ea7vS3llPvwQfRPnM3kq4bYG/zhiHI2P2vS1nn2VdUi+Zo+eMPaJ+5G203/OO2DWc89aXzyyyhp/JAxxMZO3d4PN/is09R6+qrfa43sm5d+wvIzW9ZdV2y50dLaoyCV0v7zN1AhNQtRKgzejQoLg61b7SPNCOSkxHvZnoWnZqKunffheSRrskZnIm7sCNiMzIQkZSEdps32Y43fGoKUt+fjVYrV6LlksVov3OHy0OQ9tVCtFrxI2pdaxlRWpVDXIcOoKgoJF11lcODTJGRqDdurO176jv/Q/rib2CKiwOZTC7rEFENGqB95m40nDQR6Yu/QfqSxUjq3x9RjXxLDpy2YL7LsQZPPOFWuQNAwiU9ECm10+bPNfZ7iI0DANSWmYgAACaLgq87dizaZ+5G699/Q50xY5Dx7y40m/0e4qSXH2Ax2bXP3I26Y+5GbEYGkm+52aP8DadMQeOXXkTKqFGISEwERUYi7auFaLnMHpra2YwTkZJi+1z/icdBsbFou/5vtJg3D7WHDEGtgQPRZt1fLm0lXXsN2m7cgPju3QFYvMEaTJ7kWEgIl7AcEcnJaPGlJXtWxo7taCS9xNwR264d2m74B2kLFwAA0r7+Gil33A4AqDPa8mKRD1qavPyyw3fA8iJwnn21/uknxKRbXvoRSUn247+utA2QlGj922qYEhPR+KUXkdj/KodzycNvRWyniwCoG73LZ9xWSGF2mrFju9e6PBHXuTPiZS+bFl987rG8KSHB0uffLHI4nnTNNUi88ko0fFp5xhYoNTZlX/qSxTAXFzscsyoVyGLkx7Zvj4wtm32q25SQgMbPPIOqs+dQe8hgl/MpI0eh8NdVqDV4iGP7TRqj3v33I75LF5drAMvos2zPXsRddBHiLrrI4VzyjcNQ+4YhDooztl07tzI6P7CeiEhMRER733LeUkwMRFkZ4jp3RlynTijZtg2JV16Jwt9/Vywf3boVGk6ahIKffwEAtFr+Pczl5YiUKcvkW25G4apVLmaT5h/PQd7Chaj/+GMAgKiGDdFw4pO286nvvoODA65Di/muL5vIunXR/NNPcPTuMYpymeLjkTx0qMMxa983++hDnP9+Oeo9MB4HrrnWLs9nn+HQDTcAAOqNHYt6Yy0v1vgu9kGB833lfb0I8d27IyIx0XY8RVrbSL75Fhx77DEUrVnj8NuUE9+1q0/xxCOSkhDXqZPtmrgLO6LRU0/ZzsdK91hrsP33G52Whqpz51CVn297yaQtmI/y7GOoPcjppSsjqmlTy+/6kktxZNQoRDZqhHr3348Tz1jWw0yxsWi3cYPlXoc
OxbEnJuD88uVo/cfvMEVHI33hQlX31OT115B4ZR+IkmJE1q8Pc0kJzCUlDmVarVwJiowARUW5XN987lyUbN6E02+9jVqDB6N4/XpUnjoFkmaAQqYvrAOX5OHDUbJtG6JbtEBc164o2bTJpV7AMvOPSEpymOlm7NgOREb69Cz6So1V8LFOCiv92yU2O3lS//4o+mNNwN41qW/NUjwendoUrVb86HK8zapVHuurM2qUx/OeRsW+0OTllwKuo+Xy5Sg/eAAA0Pz/5sJcXAxTQgLMBQUuZdtt2ghERcEUHY3EKyz5ak0JCTAlJDiUS+rbV1GJxbZr63HUaoqORutVv7o9H3/pparuyZnEK66wySuHIi1/h+j0dJdzSjR+7jk0fu45t+cjEhNQe+gNKFqzBpHSfo+2GzdibzfFHA8utM/cbUtKH92ypapr4i++GG03bkREov1v0GrFjzj1+uvI/ehj27G4zp0R17mzujq7XIwmr76C+C5dENW0qU3BO9N45vOoc9edDspQTtqC+Sj86y/Ed+mCvMWLcf67ZQBkMztJZlNcHExxcQ7XRqe6jvCtJFzSA5WnTlq+CIHWq35F2YEDtoFS8ebNODLK0XSWfOMwJN84DIDld1iyaZPthVB72DDkL7GYAZOusZiA5LZ5pZeM1tRYBe9MbEaG7XPyLbeg9uDBLj8OT8R3747iDRv0EC0g6tx1F85+9hlSRrpfELLSYNIkFG/YgNrS6NNXmr7xus08EZ3a1PYwmWJibGsQprp1Xa5zVuQu9b71FiJSkv2SSQ1ajKDarPsLB4cMQdXpMx7XBBzalWY5aqg1cCBMcfFIvLI3ADgoXl9I//or1WX9bcMTtWUzgmYffwxTjGsMKVNsLOI8bFSUv1QSLr3UpuA9kXT11Sj45ReHY+nfLUXV2bPKFwgBiox0mAW7m1lbaTBhAmLaZSD51lsAAI1fmInGzz2raCIKFqzgFSAikA/KHQCaz/kY5vJynSTyn4aTJ6Ghkx3X3cih7t13oe7dd/ndVq2BrkmqtcC6zhDORKakoNXy5ag4dgzRLVqg9k03os4dd3q8pvWqX1GVf15V/USEpH6OC6eNZsxAZEP3i3NNZ82CqKwEALRc9h1MtWt7fZl6I6adZSAU09r77NZUuzbM+fluzyde3isgWXyh6VuzALPZ4Vhs27YKJf1/2Zvi4x0cDIhI0Tut2ZyPkf/NYr/b8QVW8BpB0dGIqAYx6dtt3qR6hMn4RkStWoioVQsA0GTmTC+lLfb/SIUZjVq8eSvVGmBfF4hp08bvduTUHnQ9Ytu1VVVfy6Xfomz/AU3a9UTaokUoXu85YTWZTC7ebUpYnRkiGzXUQjRFEnv1QmKv4LzcWMHXMNS6itY04rp0sSlnxjNqXxZRjRr57HXlD3EdOyCuYwfvBVWQcHkvNJ31JpL69VM8H9m4MaJbtNCkrWBAws2qfCjo1q2b2LhxY6jFYBiGqTYQ0SYhhOKKO/vBMwzDGBRW8AzDMAaFFTzDMIxBYQXPMAxjUFjBMwzDGBRW8AzDMAaFFTzDMIxBYQXPMAxjUFjBMwzDGBRdFTwRDSCiPUS0n4gm69kWwzAM44huCp6IIgC8C+A6ABcAGElEF+jVHsMwDOOIniP4HgD2CyEOCiHKASwA4F+gcYZhGMZn9Iwm2RRAlux7NoBLnAsR0TgA46SvhUS0x8/26gE44+e1NQnuJ3VwP6mD+0kdevaT2/CWeip4paDjLqErhRAfAvgw4MaINrqLqMbY4X5SB/eTOrif1BGqftLTRJMNoJnseyqA4zq2xzAMw8jQU8FvANCGiNKJKBrACADf6dgewzAMI0M3E40QopKIHgTwE4AIAJ8IIXbp1R40MPPUELif1MH9pA7uJ3WEpJ/CKqMTwzAMox28k5VhGMagsIJnGIYxKNVewdfEcAhE1IyIVhPRbiLaRUSPSMfrENEvRLRP+n+K7JopUh/tIaJrZce7EtEO6dzbRETS8RgiWigdX09EaUG/UQ0goggi2kJE30vfuY8UIKJkIlpERJnS76on95UjRPSY9LztJKL5RBQb9n0khKi2/2BZvD0AoCWAaADbAFwQarmCcN+NAXSRPicB2AtLOIhXAEyWjk8G8LL0+QKpb2IApEt9FiGd+wdAT1j2LfwI4Drp+HgA70ufRwBYGOr79rOvHgcwD8D30nfuI+V+mgvgXulzNIBk7iuH/mkK4BCAOOn7VwDuCvc+CnnHBdjpPQH8JPs+BcCUUMsVgn5YCuBqAHsANJaONQawR6lfYPFs6imVyZQdHwngA3kZ6XMkLLvwKNT36mO/pAL4FUA/mYLnPnLtp1qS8iKn49xX9nux7syvI8n/PYBrwr2PqruJRikcQtMQyRISpGncxQDWA2gohMgBAOn/DaRi7vqpqfTZ+bjDNUKISgD5AOrqchP6MQvARABm2THuI1daAjgN4FPJnPUxESWA+8qGEOIYgNcAHAWQAyBfCPEzwryPqruCVxUOwagQUSKAbwA8KoQ476mowjHh4bina6oFRDQIwCkhxCa1lygcM3QfyYgE0AXAbCHExQCKYDE3uKPG9ZVkW78BFnNLEwAJRDTa0yUKx4LeR9VdwdfYcAhEFAWLcv9SCLFYOnySiBpL5xsDOCUdd9dP2dJn5+MO1xBRJIDaAM5qfye60QvAECI6DEsk035E9AW4j5TIBpAthFgvfV8Ei8LnvrLTH8AhIcRpIUQFgMUALkOY91F1V/A1MhyCtOo+B8BuIcQbslPfAbhT+nwnLLZ56/ER0ip9OoA2AP6RppQFRHSpVOcdTtdY67oZwCohGQerA0KIKUKIVCFEGiy/i1VCiNHgPnJBCHECQBYRtZMOXQXgX3BfyTkK4FIiipfu7SoAuxHufRTqxQsNFj8GwuJFcgDA1FDLE6R7vhyWqdt2AFulfwNhsdf9CmCf9P86smumSn20B9KqvXS8G4Cd0rl3YN/dHAvgawD7YVn1bxnq+w6gv/rAvsjKfaTcR50BbJR+U98CSOG+cumjGQAypfv7HBYPmbDuIw5VwDAMY1Cqu4mGYRiGcQMreIZhGIPCCp5hGMagsIJnGIYxKKzgGYZhDAoreCZsIaK6RLRV+neCiI7Jvkd7ubYbEb2too2/tJPYpe5kIhqvV/0M4w12k2SqBUQ0HUChEOI12bFIYYnZEZZIcYK+F0J0DLUsTM2ER/BMtYKIPiOiN4hoNYCXiagHEf0lBcn6y7obk4j6kD0G/HQi+oSIfiOig0T0sKy+Qln538geE/1LWZzugdKxP6X43d8ryNWBiP6RZhfbiagNgJcAtJKOvSqVe5KINkhlZkjH0qT650rHFxFRvHTuJSL6Vzr+mnO7DOMJ3ZJuM4yOtAXQXwhRRUS1APQWliTv/QG8AOAmhWsyAPSFJX7+HiKaLSwxReRcDKADLLFB1gLoRUQbAXwgtXGIiOa7kek/AN4SQnwpmY8iYAnY1VEI0RkAiOgaWLas94AlsNR3RNQblm3w7QDcI4RYS0SfABgv/X8YgAwhhCCiZF87iqnZ8AieqY58LYSokj
7XBvA1Ee0E8CYsClqJ5UKIMiHEGVgCQjVUKPOPECJbCGGGJfxDGiwvhoNCiENSGXcKfh2Ap4hoEoAWQogShTLXSP+2ANgs1d1GOpclhFgrff4ClnAU5wGUAviYiG4EUOymbYZRhBU8Ux0pkn1+DsBqyc49GJZ4HkqUyT5XQXn2qlRGKYSrC0KIeQCGACgB8BMR9VMoRgBeFEJ0lv61FkLMsVbhWqWohGW0/w2AoQBWqJGFYaywgmeqO7UBHJM+36VD/ZkAWpI9P+ZwpUJE1BKWkf7bsEQFvAhAASwmISs/ARhDljj+IKKmRGRNENGciHpKn0cC+FMqV1sI8QOAR2EJCMYwqmEbPFPdeQXAXCJ6HMAqrSsXQpRIro4riOgMLFH+lBgOYDQRVQA4AeBZIcRZIlormY9+FEI8SUTtAayT1m8LAYyGZbawG8CdRPQBLJEJZ8Py8lpKRLGwjP4f0/r+GGPDbpIM4wUiShRCFEpeNe8C2CeEeFPD+tPA7pSMDrCJhmG8M5aItgLYBcuo+oPQisMw6uARPMMwjEHhETzDMIxBYQXPMAxjUFjBMwzDGBRW8AzDMAaFFTzDMIxB+X8m6i0JUQqYawAAAABJRU5ErkJggg==\n", 742 | "text/plain": [ 743 | "
" 744 | ] 745 | }, 746 | "metadata": { 747 | "needs_background": "light" 748 | }, 749 | "output_type": "display_data" 750 | }, 751 | { 752 | "name": "stdout", 753 | "output_type": "stream", 754 | "text": [ 755 | "Successful save results to pred.csv\n" 756 | ] 757 | } 758 | ], 759 | "source": [ 760 | "\n", 761 | "if __name__ == '__main__':\n", 762 | " model_loss, model_loss_record = train(train_loader, dev_loader)\n", 763 | " plot_learning_curve(model_loss_record)\n", 764 | "\n", 765 | " preds = test()\n", 766 | "\n", 767 | " # print(preds.shape)\n", 768 | " # print(type(preds))\n", 769 | "\n", 770 | " save_pred(preds, 'pred.csv')\n" 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "execution_count": null, 776 | "metadata": { 777 | "papermill": { 778 | "duration": 0.033301, 779 | "end_time": "2021-03-21T14:10:44.321619", 780 | "exception": false, 781 | "start_time": "2021-03-21T14:10:44.288318", 782 | "status": "completed" 783 | }, 784 | "tags": [] 785 | }, 786 | "outputs": [], 787 | "source": [] 788 | } 789 | ], 790 | "metadata": { 791 | "kernelspec": { 792 | "display_name": "Python 3", 793 | "language": "python", 794 | "name": "python3" 795 | }, 796 | "language_info": { 797 | "codemirror_mode": { 798 | "name": "ipython", 799 | "version": 3 800 | }, 801 | "file_extension": ".py", 802 | "mimetype": "text/x-python", 803 | "name": "python", 804 | "nbconvert_exporter": "python", 805 | "pygments_lexer": "ipython3", 806 | "version": "3.7.9" 807 | }, 808 | "papermill": { 809 | "default_parameters": {}, 810 | "duration": 396.639686, 811 | "end_time": "2021-03-21T14:10:45.267687", 812 | "environment_variables": {}, 813 | "exception": null, 814 | "input_path": "__notebook__.ipynb", 815 | "output_path": "__notebook__.ipynb", 816 | "parameters": {}, 817 | "start_time": "2021-03-21T14:04:08.628001", 818 | "version": "2.2.2" 819 | } 820 | }, 821 | "nbformat": 4, 822 | "nbformat_minor": 4 823 | } 824 | -------------------------------------------------------------------------------- /hw/hw2hw2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 8 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", 9 | "execution": { 10 | "iopub.execute_input": "2021-03-14T15:23:56.019504Z", 11 | "iopub.status.busy": "2021-03-14T15:23:56.018930Z", 12 | "iopub.status.idle": "2021-03-14T15:23:56.030603Z", 13 | "shell.execute_reply": "2021-03-14T15:23:56.030085Z" 14 | }, 15 | "papermill": { 16 | "duration": 0.02882, 17 | "end_time": "2021-03-14T15:23:56.030765", 18 | "exception": false, 19 | "start_time": "2021-03-14T15:23:56.001945", 20 | "status": "completed" 21 | }, 22 | "tags": [] 23 | }, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "/kaggle/input/ml2021spring-hw2/sampleSubmission.csv\n", 30 | "/kaggle/input/ml2021spring-hw2/timit_11/timit_11/train_11.npy\n", 31 | "/kaggle/input/ml2021spring-hw2/timit_11/timit_11/test_11.npy\n", 32 | "/kaggle/input/ml2021spring-hw2/timit_11/timit_11/train_label_11.npy\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "# This Python 3 environment comes with many helpful analytics libraries installed\n", 38 | "# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n", 39 | "# For example, here's several helpful packages to load\n", 40 | "\n", 41 | "import numpy as np # linear algebra\n", 42 | "import pandas 
as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", 43 | "\n", 44 | "# Input data files are available in the read-only \"../input/\" directory\n", 45 | "# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n", 46 | "\n", 47 | "import os\n", 48 | "for dirname, _, filenames in os.walk('/kaggle/input'):\n", 49 | " for filename in filenames:\n", 50 | " print(os.path.join(dirname, filename))\n", 51 | "\n", 52 | "# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n", 53 | "# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": { 59 | "papermill": { 60 | "duration": 0.010507, 61 | "end_time": "2021-03-14T15:23:56.052244", 62 | "exception": false, 63 | "start_time": "2021-03-14T15:23:56.041737", 64 | "status": "completed" 65 | }, 66 | "tags": [] 67 | }, 68 | "source": [ 69 | "# Load Data" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "metadata": { 76 | "execution": { 77 | "iopub.execute_input": "2021-03-14T15:23:56.078437Z", 78 | "iopub.status.busy": "2021-03-14T15:23:56.077789Z", 79 | "iopub.status.idle": "2021-03-14T15:24:42.276046Z", 80 | "shell.execute_reply": "2021-03-14T15:24:42.276476Z" 81 | }, 82 | "papermill": { 83 | "duration": 46.213812, 84 | "end_time": "2021-03-14T15:24:42.276624", 85 | "exception": false, 86 | "start_time": "2021-03-14T15:23:56.062812", 87 | "status": "completed" 88 | }, 89 | "tags": [] 90 | }, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Loading data ...\n", 97 | "/kaggle/working\n", 98 | "Size of training data: (1229932, 429)\n", 99 | "Size of testing data: (451552, 429)\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "print('Loading data ...')\n", 105 | "print(os.getcwd())\n", 106 | "\n", 107 | "data_root='../input/ml2021spring-hw2/timit_11/timit_11/'\n", 108 | "train = np.load(data_root + 'train_11.npy')\n", 109 | "train_label = np.load(data_root + 'train_label_11.npy')\n", 110 | "test = np.load(data_root + 'test_11.npy')\n", 111 | "\n", 112 | "print('Size of training data: {}'.format(train.shape))\n", 113 | "print('Size of testing data: {}'.format(test.shape))" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": { 119 | "papermill": { 120 | "duration": 0.011427, 121 | "end_time": "2021-03-14T15:24:42.299792", 122 | "exception": false, 123 | "start_time": "2021-03-14T15:24:42.288365", 124 | "status": "completed" 125 | }, 126 | "tags": [] 127 | }, 128 | "source": [ 129 | "# DataSet" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 3, 135 | "metadata": { 136 | "execution": { 137 | "iopub.execute_input": "2021-03-14T15:24:42.329912Z", 138 | "iopub.status.busy": "2021-03-14T15:24:42.329317Z", 139 | "iopub.status.idle": "2021-03-14T15:24:43.564479Z", 140 | "shell.execute_reply": "2021-03-14T15:24:43.563939Z" 141 | }, 142 | "papermill": { 143 | "duration": 1.25328, 144 | "end_time": "2021-03-14T15:24:43.564622", 145 | "exception": false, 146 | "start_time": "2021-03-14T15:24:42.311342", 147 | "status": "completed" 148 | }, 149 | "tags": [] 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "import torch\n", 154 | "from torch.utils.data import Dataset\n", 155 | "\n", 156 | "class TIMITDataset(Dataset):\n", 157 | " def __init__(self, 
X, y=None):\n", 158 | "        self.data = torch.from_numpy(X).float()\n", 159 | "        if y is not None:\n", 160 | "            y = y.astype(int) # np.int is a deprecated NumPy alias; the builtin int gives the same result\n", 161 | "            self.label = torch.LongTensor(y)\n", 162 | "        else:\n", 163 | "            self.label = None\n", 164 | "\n", 165 | "    def __getitem__(self, idx):\n", 166 | "        if self.label is not None:\n", 167 | "            return self.data[idx], self.label[idx]\n", 168 | "        else:\n", 169 | "            return self.data[idx]\n", 170 | "\n", 171 | "    def __len__(self):\n", 172 | "        return len(self.data)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": { 178 | "papermill": { 179 | "duration": 0.011238, 180 | "end_time": "2021-03-14T15:24:43.587838", 181 | "exception": false, 182 | "start_time": "2021-03-14T15:24:43.576600", 183 | "status": "completed" 184 | }, 185 | "tags": [] 186 | }, 187 | "source": [ 188 | "# Hyperparameters" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 4, 194 | "metadata": { 195 | "execution": { 196 | "iopub.execute_input": "2021-03-14T15:24:43.616533Z", 197 | "iopub.status.busy": "2021-03-14T15:24:43.614762Z", 198 | "iopub.status.idle": "2021-03-14T15:24:43.617112Z", 199 | "shell.execute_reply": "2021-03-14T15:24:43.617497Z" 200 | }, 201 | "papermill": { 202 | "duration": 0.018512, 203 | "end_time": "2021-03-14T15:24:43.617630", 204 | "exception": false, 205 | "start_time": "2021-03-14T15:24:43.599118", 206 | "status": "completed" 207 | }, 208 | "tags": [] 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "VAL_RATIO = 0.1 # fraction of the data held out for validation\n", 213 | "BATCH_SIZE = 512\n", 214 | "\n", 215 | "# training parameters\n", 216 | "num_epoch = 50                # number of training epochs\n", 217 | "learning_rate = 0.0001       # learning rate\n", 218 | "weight_decay = 0 # with a nonzero value training basically failed to converge, not sure why (a possible fix is sketched after the model cell below)\n", 219 | "dropout = 0.1 # probability of dropping a unit\n", 220 | "\n", 221 | "# the path where the checkpoint is saved\n", 222 | "model_path = './model.ckpt'" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 5, 228 | "metadata": { 229 | "execution": { 230 | "iopub.execute_input": "2021-03-14T15:24:43.645139Z", 231 | "iopub.status.busy": "2021-03-14T15:24:43.644456Z", 232 | "iopub.status.idle": "2021-03-14T15:24:43.648522Z", 233 | "shell.execute_reply": "2021-03-14T15:24:43.649151Z" 234 | }, 235 | "papermill": { 236 | "duration": 0.020441, 237 | "end_time": "2021-03-14T15:24:43.649306", 238 | "exception": false, 239 | "start_time": "2021-03-14T15:24:43.628865", 240 | "status": "completed" 241 | }, 242 | "tags": [] 243 | }, 244 | "outputs": [ 245 | { 246 | "name": "stdout", 247 | "output_type": "stream", 248 | "text": [ 249 | "Size of training set: (1106938, 429)\n", 250 | "Size of validation set: (122994, 429)\n" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "percent = int(train.shape[0] * (1 - VAL_RATIO))\n", 256 | "train_x, train_y, val_x, val_y = train[:percent], train_label[:percent], train[percent:], train_label[percent:]\n", 257 | "print('Size of training set: {}'.format(train_x.shape))\n", 258 | "print('Size of validation set: {}'.format(val_x.shape))" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 6, 264 | "metadata": { 265 | "execution": { 266 | "iopub.execute_input": "2021-03-14T15:24:43.685283Z", 267 | "iopub.status.busy": "2021-03-14T15:24:43.684705Z", 268 | "iopub.status.idle": "2021-03-14T15:24:46.010627Z", 269 | "shell.execute_reply": "2021-03-14T15:24:46.011662Z" 270 | }, 271 | "papermill": { 272 | "duration": 2.349299, 273 | "end_time": "2021-03-14T15:24:46.011862", 274 | "exception": false, 275 | "start_time": "2021-03-14T15:24:43.662563", 276 | 
"status": "completed" 277 | }, 278 | "tags": [] 279 | }, 280 | "outputs": [], 281 | "source": [ 282 | "\n", 283 | "\n", 284 | "from torch.utils.data import DataLoader\n", 285 | "\n", 286 | "train_set = TIMITDataset(train_x, train_y)\n", 287 | "val_set = TIMITDataset(val_x, val_y)\n", 288 | "train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True) #only shuffle the training data\n", 289 | "val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 7, 295 | "metadata": { 296 | "execution": { 297 | "iopub.execute_input": "2021-03-14T15:24:46.064476Z", 298 | "iopub.status.busy": "2021-03-14T15:24:46.063685Z", 299 | "iopub.status.idle": "2021-03-14T15:24:46.171975Z", 300 | "shell.execute_reply": "2021-03-14T15:24:46.170978Z" 301 | }, 302 | "papermill": { 303 | "duration": 0.141942, 304 | "end_time": "2021-03-14T15:24:46.172194", 305 | "exception": false, 306 | "start_time": "2021-03-14T15:24:46.030252", 307 | "status": "completed" 308 | }, 309 | "tags": [] 310 | }, 311 | "outputs": [ 312 | { 313 | "data": { 314 | "text/plain": [ 315 | "80" 316 | ] 317 | }, 318 | "execution_count": 7, 319 | "metadata": {}, 320 | "output_type": "execute_result" 321 | } 322 | ], 323 | "source": [ 324 | "import gc\n", 325 | "\n", 326 | "del train, train_label, train_x, train_y, val_x, val_y\n", 327 | "gc.collect()" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": { 333 | "papermill": { 334 | "duration": 0.012371, 335 | "end_time": "2021-03-14T15:24:46.212480", 336 | "exception": false, 337 | "start_time": "2021-03-14T15:24:46.200109", 338 | "status": "completed" 339 | }, 340 | "tags": [] 341 | }, 342 | "source": [ 343 | "# model" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 8, 349 | "metadata": { 350 | "execution": { 351 | "iopub.execute_input": "2021-03-14T15:24:46.247912Z", 352 | "iopub.status.busy": "2021-03-14T15:24:46.246192Z", 353 | "iopub.status.idle": "2021-03-14T15:24:46.248493Z", 354 | "shell.execute_reply": "2021-03-14T15:24:46.248889Z" 355 | }, 356 | "papermill": { 357 | "duration": 0.024228, 358 | "end_time": "2021-03-14T15:24:46.249021", 359 | "exception": false, 360 | "start_time": "2021-03-14T15:24:46.224793", 361 | "status": "completed" 362 | }, 363 | "tags": [] 364 | }, 365 | "outputs": [], 366 | "source": [ 367 | "import torch\n", 368 | "import torch.nn as nn\n", 369 | "\n", 370 | "class Classifier(nn.Module):\n", 371 | " def __init__(self):\n", 372 | " super(Classifier, self).__init__()\n", 373 | " self.layer1 = nn.Linear(429, 1024)\n", 374 | " self.bn1 = nn.BatchNorm1d(1024)\n", 375 | " \n", 376 | " self.layer2 = nn.Linear(1024, 256)\n", 377 | " self.bn2 = nn.BatchNorm1d(256)\n", 378 | " \n", 379 | " self.layer3 = nn.Linear(256, 64)\n", 380 | " self.bn3 = nn.BatchNorm1d(64)\n", 381 | " \n", 382 | " self.out = nn.Linear(64, 39) \n", 383 | " self.dropout = nn.Dropout(dropout)\n", 384 | " self.act_fn = nn.Sigmoid()\n", 385 | " \n", 386 | " def forward(self, x):\n", 387 | " x = self.layer1(x)\n", 388 | " x = self.dropout(x)\n", 389 | " x = self.bn1(x)\n", 390 | " x = self.act_fn(x)\n", 391 | "\n", 392 | " x = self.layer2(x)\n", 393 | " x = self.dropout(x)\n", 394 | " x = self.bn2(x)\n", 395 | " x = self.act_fn(x)\n", 396 | "\n", 397 | " x = self.layer3(x)\n", 398 | " x = self.dropout(x)\n", 399 | " x = self.bn3(x)\n", 400 | " x = self.act_fn(x)\n", 401 | "\n", 402 | " x = self.out(x)\n", 403 | " \n", 404 | " return x" 405 | ] 406 | }, 
407 | { 408 | "cell_type": "markdown", 409 | "metadata": { 410 | "papermill": { 411 | "duration": 0.012113, 412 | "end_time": "2021-03-14T15:24:46.273469", 413 | "exception": false, 414 | "start_time": "2021-03-14T15:24:46.261356", 415 | "status": "completed" 416 | }, 417 | "tags": [] 418 | }, 419 | "source": [ 420 | "# train" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 9, 426 | "metadata": { 427 | "execution": { 428 | "iopub.execute_input": "2021-03-14T15:24:46.302266Z", 429 | "iopub.status.busy": "2021-03-14T15:24:46.301688Z", 430 | "iopub.status.idle": "2021-03-14T15:24:46.305594Z", 431 | "shell.execute_reply": "2021-03-14T15:24:46.305195Z" 432 | }, 433 | "papermill": { 434 | "duration": 0.019715, 435 | "end_time": "2021-03-14T15:24:46.305722", 436 | "exception": false, 437 | "start_time": "2021-03-14T15:24:46.286007", 438 | "status": "completed" 439 | }, 440 | "tags": [] 441 | }, 442 | "outputs": [], 443 | "source": [ 444 | "#check device\n", 445 | "def get_device():\n", 446 | " return 'cuda' if torch.cuda.is_available() else 'cpu'" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 10, 452 | "metadata": { 453 | "execution": { 454 | "iopub.execute_input": "2021-03-14T15:24:46.333958Z", 455 | "iopub.status.busy": "2021-03-14T15:24:46.333456Z", 456 | "iopub.status.idle": "2021-03-14T15:24:46.337017Z", 457 | "shell.execute_reply": "2021-03-14T15:24:46.337413Z" 458 | }, 459 | "papermill": { 460 | "duration": 0.01941, 461 | "end_time": "2021-03-14T15:24:46.337541", 462 | "exception": false, 463 | "start_time": "2021-03-14T15:24:46.318131", 464 | "status": "completed" 465 | }, 466 | "tags": [] 467 | }, 468 | "outputs": [], 469 | "source": [ 470 | "# fix random seed\n", 471 | "# def same_seeds(seed):\n", 472 | "# torch.manual_seed(seed)\n", 473 | "# if torch.cuda.is_available():\n", 474 | "# torch.cuda.manual_seed(seed)\n", 475 | "# torch.cuda.manual_seed_all(seed) \n", 476 | "# np.random.seed(seed) \n", 477 | "# torch.backends.cudnn.benchmark = False\n", 478 | "# torch.backends.cudnn.deterministic = True" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 11, 484 | "metadata": { 485 | "execution": { 486 | "iopub.execute_input": "2021-03-14T15:24:46.712504Z", 487 | "iopub.status.busy": "2021-03-14T15:24:46.711737Z", 488 | "iopub.status.idle": "2021-03-14T15:24:50.522299Z", 489 | "shell.execute_reply": "2021-03-14T15:24:50.521827Z" 490 | }, 491 | "papermill": { 492 | "duration": 4.172385, 493 | "end_time": "2021-03-14T15:24:50.522433", 494 | "exception": false, 495 | "start_time": "2021-03-14T15:24:46.350048", 496 | "status": "completed" 497 | }, 498 | "tags": [] 499 | }, 500 | "outputs": [ 501 | { 502 | "name": "stdout", 503 | "output_type": "stream", 504 | "text": [ 505 | "DEVICE: cuda\n" 506 | ] 507 | } 508 | ], 509 | "source": [ 510 | "# fix random seed for reproducibility\n", 511 | "# same_seeds(0)\n", 512 | "\n", 513 | "# get device \n", 514 | "device = get_device()\n", 515 | "print(f'DEVICE: {device}')\n", 516 | "\n", 517 | "\n", 518 | "# create model, define a loss function, and optimizer\n", 519 | "model = Classifier().to(device)\n", 520 | "criterion = nn.CrossEntropyLoss() \n", 521 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay = weight_decay)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": 12, 527 | "metadata": { 528 | "execution": { 529 | "iopub.execute_input": "2021-03-14T15:24:50.562691Z", 530 | "iopub.status.busy": 
"2021-03-14T15:24:50.561023Z", 531 | "iopub.status.idle": "2021-03-14T15:43:54.978418Z", 532 | "shell.execute_reply": "2021-03-14T15:43:54.979050Z" 533 | }, 534 | "papermill": { 535 | "duration": 1144.443637, 536 | "end_time": "2021-03-14T15:43:54.979282", 537 | "exception": false, 538 | "start_time": "2021-03-14T15:24:50.535645", 539 | "status": "completed" 540 | }, 541 | "tags": [] 542 | }, 543 | "outputs": [ 544 | { 545 | "name": "stdout", 546 | "output_type": "stream", 547 | "text": [ 548 | "[001/050] Train Acc: 0.366569 Loss: 2.591440 | Val Acc: 0.442452 loss: 2.114810\n", 549 | "saving model with val acc 0.442\n", 550 | "[002/050] Train Acc: 0.478588 Loss: 1.916905 | Val Acc: 0.532546 loss: 1.655892\n", 551 | "saving model with val acc 0.533\n", 552 | "[003/050] Train Acc: 0.541635 Loss: 1.614705 | Val Acc: 0.593964 loss: 1.426979\n", 553 | "saving model with val acc 0.594\n", 554 | "[004/050] Train Acc: 0.577619 Loss: 1.452529 | Val Acc: 0.611314 loss: 1.301433\n", 555 | "saving model with val acc 0.611\n", 556 | "[005/050] Train Acc: 0.600430 Loss: 1.350428 | Val Acc: 0.637072 loss: 1.210813\n", 557 | "saving model with val acc 0.637\n", 558 | "[006/050] Train Acc: 0.616069 Loss: 1.282020 | Val Acc: 0.649812 loss: 1.146848\n", 559 | "saving model with val acc 0.650\n", 560 | "[007/050] Train Acc: 0.627174 Loss: 1.234129 | Val Acc: 0.660821 loss: 1.099274\n", 561 | "saving model with val acc 0.661\n", 562 | "[008/050] Train Acc: 0.635513 Loss: 1.197914 | Val Acc: 0.668911 loss: 1.066928\n", 563 | "saving model with val acc 0.669\n", 564 | "[009/050] Train Acc: 0.642452 Loss: 1.169371 | Val Acc: 0.675179 loss: 1.040947\n", 565 | "saving model with val acc 0.675\n", 566 | "[010/050] Train Acc: 0.648656 Loss: 1.145438 | Val Acc: 0.677984 loss: 1.028577\n", 567 | "saving model with val acc 0.678\n", 568 | "[011/050] Train Acc: 0.653290 Loss: 1.125981 | Val Acc: 0.683123 loss: 1.008011\n", 569 | "saving model with val acc 0.683\n", 570 | "[012/050] Train Acc: 0.657285 Loss: 1.108494 | Val Acc: 0.684212 loss: 1.000955\n", 571 | "saving model with val acc 0.684\n", 572 | "[013/050] Train Acc: 0.661304 Loss: 1.093018 | Val Acc: 0.691725 loss: 0.983416\n", 573 | "saving model with val acc 0.692\n", 574 | "[014/050] Train Acc: 0.665231 Loss: 1.078725 | Val Acc: 0.692774 loss: 0.977441\n", 575 | "saving model with val acc 0.693\n", 576 | "[015/050] Train Acc: 0.667876 Loss: 1.066131 | Val Acc: 0.697058 loss: 0.960904\n", 577 | "saving model with val acc 0.697\n", 578 | "[016/050] Train Acc: 0.670972 Loss: 1.054978 | Val Acc: 0.699741 loss: 0.953624\n", 579 | "saving model with val acc 0.700\n", 580 | "[017/050] Train Acc: 0.674522 Loss: 1.042549 | Val Acc: 0.701384 loss: 0.947594\n", 581 | "saving model with val acc 0.701\n", 582 | "[018/050] Train Acc: 0.676911 Loss: 1.032386 | Val Acc: 0.701750 loss: 0.940864\n", 583 | "saving model with val acc 0.702\n", 584 | "[019/050] Train Acc: 0.679211 Loss: 1.023117 | Val Acc: 0.701628 loss: 0.943298\n", 585 | "[020/050] Train Acc: 0.681863 Loss: 1.013738 | Val Acc: 0.707091 loss: 0.920681\n", 586 | "saving model with val acc 0.707\n", 587 | "[021/050] Train Acc: 0.684547 Loss: 1.004693 | Val Acc: 0.708262 loss: 0.919690\n", 588 | "saving model with val acc 0.708\n", 589 | "[022/050] Train Acc: 0.686926 Loss: 0.996457 | Val Acc: 0.707506 loss: 0.924066\n", 590 | "[023/050] Train Acc: 0.688497 Loss: 0.989403 | Val Acc: 0.709441 loss: 0.911010\n", 591 | "saving model with val acc 0.709\n", 592 | "[024/050] Train Acc: 0.690082 Loss: 0.982603 | Val Acc: 
0.711140 loss: 0.907062\n", 593 | "saving model with val acc 0.711\n", 594 | "[025/050] Train Acc: 0.692631 Loss: 0.974660 | Val Acc: 0.713027 loss: 0.897646\n", 595 | "saving model with val acc 0.713\n", 596 | "[026/050] Train Acc: 0.694134 Loss: 0.967665 | Val Acc: 0.715189 loss: 0.890961\n", 597 | "saving model with val acc 0.715\n", 598 | "[027/050] Train Acc: 0.695987 Loss: 0.960965 | Val Acc: 0.716287 loss: 0.886491\n", 599 | "saving model with val acc 0.716\n", 600 | "[028/050] Train Acc: 0.697741 Loss: 0.953969 | Val Acc: 0.718377 loss: 0.885856\n", 601 | "saving model with val acc 0.718\n", 602 | "[029/050] Train Acc: 0.699017 Loss: 0.949349 | Val Acc: 0.718954 loss: 0.882991\n", 603 | "saving model with val acc 0.719\n", 604 | "[030/050] Train Acc: 0.701137 Loss: 0.941593 | Val Acc: 0.718515 loss: 0.881140\n", 605 | "[031/050] Train Acc: 0.702531 Loss: 0.936155 | Val Acc: 0.719889 loss: 0.873693\n", 606 | "saving model with val acc 0.720\n", 607 | "[032/050] Train Acc: 0.704114 Loss: 0.930758 | Val Acc: 0.722751 loss: 0.865760\n", 608 | "saving model with val acc 0.723\n", 609 | "[033/050] Train Acc: 0.705065 Loss: 0.925543 | Val Acc: 0.721222 loss: 0.873852\n", 610 | "[034/050] Train Acc: 0.707352 Loss: 0.919805 | Val Acc: 0.722734 loss: 0.867530\n", 611 | "[035/050] Train Acc: 0.708702 Loss: 0.915352 | Val Acc: 0.722775 loss: 0.866566\n", 612 | "saving model with val acc 0.723\n", 613 | "[036/050] Train Acc: 0.709894 Loss: 0.909784 | Val Acc: 0.722312 loss: 0.864558\n", 614 | "[037/050] Train Acc: 0.710539 Loss: 0.906012 | Val Acc: 0.725775 loss: 0.854063\n", 615 | "saving model with val acc 0.726\n", 616 | "[038/050] Train Acc: 0.712677 Loss: 0.900292 | Val Acc: 0.724588 loss: 0.860009\n", 617 | "[039/050] Train Acc: 0.713300 Loss: 0.897194 | Val Acc: 0.725466 loss: 0.856330\n", 618 | "[040/050] Train Acc: 0.715254 Loss: 0.891468 | Val Acc: 0.726092 loss: 0.855583\n", 619 | "saving model with val acc 0.726\n", 620 | "[041/050] Train Acc: 0.716009 Loss: 0.887732 | Val Acc: 0.726247 loss: 0.853594\n", 621 | "saving model with val acc 0.726\n", 622 | "[042/050] Train Acc: 0.716956 Loss: 0.883696 | Val Acc: 0.727182 loss: 0.850804\n", 623 | "saving model with val acc 0.727\n", 624 | "[043/050] Train Acc: 0.718140 Loss: 0.879459 | Val Acc: 0.728271 loss: 0.845017\n", 625 | "saving model with val acc 0.728\n", 626 | "[044/050] Train Acc: 0.719535 Loss: 0.874435 | Val Acc: 0.728149 loss: 0.847982\n", 627 | "[045/050] Train Acc: 0.720074 Loss: 0.871938 | Val Acc: 0.727296 loss: 0.852061\n", 628 | "[046/050] Train Acc: 0.721647 Loss: 0.867166 | Val Acc: 0.729702 loss: 0.840858\n", 629 | "saving model with val acc 0.730\n", 630 | "[047/050] Train Acc: 0.722977 Loss: 0.863109 | Val Acc: 0.730117 loss: 0.837710\n", 631 | "saving model with val acc 0.730\n", 632 | "[048/050] Train Acc: 0.723118 Loss: 0.860942 | Val Acc: 0.730206 loss: 0.840058\n", 633 | "saving model with val acc 0.730\n", 634 | "[049/050] Train Acc: 0.724737 Loss: 0.856906 | Val Acc: 0.730572 loss: 0.840824\n", 635 | "saving model with val acc 0.731\n", 636 | "[050/050] Train Acc: 0.725250 Loss: 0.853080 | Val Acc: 0.733532 loss: 0.832365\n", 637 | "saving model with val acc 0.734\n" 638 | ] 639 | } 640 | ], 641 | "source": [ 642 | "# start training\n", 643 | "\n", 644 | "best_acc = 0.0\n", 645 | "for epoch in range(num_epoch):\n", 646 | " train_acc = 0.0\n", 647 | " train_loss = 0.0\n", 648 | " val_acc = 0.0\n", 649 | " val_loss = 0.0\n", 650 | "\n", 651 | " # training\n", 652 | " model.train() # set the model to 
training mode\n", 653 | " for i, data in enumerate(train_loader):\n", 654 | " inputs, labels = data\n", 655 | " inputs, labels = inputs.to(device), labels.to(device)\n", 656 | " optimizer.zero_grad() \n", 657 | " outputs = model(inputs) \n", 658 | " batch_loss = criterion(outputs, labels)\n", 659 | " _, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n", 660 | " batch_loss.backward() \n", 661 | " optimizer.step() \n", 662 | "\n", 663 | " train_acc += (train_pred.cpu() == labels.cpu()).sum().item()\n", 664 | " train_loss += batch_loss.item()\n", 665 | "\n", 666 | " # validation\n", 667 | " if len(val_set) > 0:\n", 668 | " model.eval() # set the model to evaluation mode\n", 669 | " with torch.no_grad():\n", 670 | " for i, data in enumerate(val_loader):\n", 671 | " inputs, labels = data\n", 672 | " inputs, labels = inputs.to(device), labels.to(device)\n", 673 | " outputs = model(inputs)\n", 674 | " batch_loss = criterion(outputs, labels) \n", 675 | " _, val_pred = torch.max(outputs, 1) \n", 676 | " \n", 677 | " val_acc += (val_pred.cpu() == labels.cpu()).sum().item() # get the index of the class with the highest probability\n", 678 | " val_loss += batch_loss.item()\n", 679 | "\n", 680 | " print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f} | Val Acc: {:3.6f} loss: {:3.6f}'.format(\n", 681 | " epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader), val_acc/len(val_set), val_loss/len(val_loader)\n", 682 | " ))\n", 683 | "\n", 684 | " # if the model improves, save a checkpoint at this epoch\n", 685 | " if val_acc > best_acc:\n", 686 | " best_acc = val_acc\n", 687 | " torch.save(model.state_dict(), model_path)\n", 688 | " print('saving model with val acc {:.3f}'.format(best_acc/len(val_set)))\n", 689 | " else:\n", 690 | " print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f}'.format(\n", 691 | " epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader)\n", 692 | " ))\n", 693 | "\n", 694 | "# if not validating, save the last epoch\n", 695 | "if len(val_set) == 0:\n", 696 | " torch.save(model.state_dict(), model_path)\n", 697 | " print('saving model at last epoch')" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": 13, 703 | "metadata": { 704 | "execution": { 705 | "iopub.execute_input": "2021-03-14T15:43:55.038358Z", 706 | "iopub.status.busy": "2021-03-14T15:43:55.037333Z", 707 | "iopub.status.idle": "2021-03-14T15:43:55.636852Z", 708 | "shell.execute_reply": "2021-03-14T15:43:55.637275Z" 709 | }, 710 | "papermill": { 711 | "duration": 0.631282, 712 | "end_time": "2021-03-14T15:43:55.637422", 713 | "exception": false, 714 | "start_time": "2021-03-14T15:43:55.006140", 715 | "status": "completed" 716 | }, 717 | "tags": [] 718 | }, 719 | "outputs": [ 720 | { 721 | "data": { 722 | "text/plain": [ 723 | "" 724 | ] 725 | }, 726 | "execution_count": 13, 727 | "metadata": {}, 728 | "output_type": "execute_result" 729 | } 730 | ], 731 | "source": [ 732 | "# create testing dataset\n", 733 | "test_set = TIMITDataset(test, None)\n", 734 | "test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)\n", 735 | "\n", 736 | "# create model and load weights from checkpoint\n", 737 | "model = Classifier().to(device)\n", 738 | "model.load_state_dict(torch.load(model_path))" 739 | ] 740 | }, 741 | { 742 | "cell_type": "code", 743 | "execution_count": 14, 744 | "metadata": { 745 | "execution": { 746 | "iopub.execute_input": "2021-03-14T15:43:55.696648Z", 747 | "iopub.status.busy": 
"2021-03-14T15:43:55.694486Z", 748 | "iopub.status.idle": "2021-03-14T15:43:58.685874Z", 749 | "shell.execute_reply": "2021-03-14T15:43:58.685414Z" 750 | }, 751 | "papermill": { 752 | "duration": 3.022661, 753 | "end_time": "2021-03-14T15:43:58.686002", 754 | "exception": false, 755 | "start_time": "2021-03-14T15:43:55.663341", 756 | "status": "completed" 757 | }, 758 | "tags": [] 759 | }, 760 | "outputs": [], 761 | "source": [ 762 | "predict = []\n", 763 | "model.eval() # set the model to evaluation mode\n", 764 | "with torch.no_grad():\n", 765 | " for i, data in enumerate(test_loader):\n", 766 | " inputs = data\n", 767 | " inputs = inputs.to(device)\n", 768 | " outputs = model(inputs)\n", 769 | " _, test_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n", 770 | "\n", 771 | " for y in test_pred.cpu().numpy():\n", 772 | " predict.append(y)" 773 | ] 774 | }, 775 | { 776 | "cell_type": "code", 777 | "execution_count": 15, 778 | "metadata": { 779 | "execution": { 780 | "iopub.execute_input": "2021-03-14T15:43:58.760603Z", 781 | "iopub.status.busy": "2021-03-14T15:43:58.759830Z", 782 | "iopub.status.idle": "2021-03-14T15:43:58.765488Z", 783 | "shell.execute_reply": "2021-03-14T15:43:58.765050Z" 784 | }, 785 | "papermill": { 786 | "duration": 0.05326, 787 | "end_time": "2021-03-14T15:43:58.765600", 788 | "exception": false, 789 | "start_time": "2021-03-14T15:43:58.712340", 790 | "status": "completed" 791 | }, 792 | "tags": [] 793 | }, 794 | "outputs": [ 795 | { 796 | "data": { 797 | "text/plain": [ 798 | "[36,\n", 799 | " 36,\n", 800 | " 36,\n", 801 | " 36,\n", 802 | " 36,\n", 803 | " 36,\n", 804 | " 36,\n", 805 | " 36,\n", 806 | " 36,\n", 807 | " 36,\n", 808 | " 36,\n", 809 | " 0,\n", 810 | " 0,\n", 811 | " 0,\n", 812 | " 0,\n", 813 | " 0,\n", 814 | " 0,\n", 815 | " 0,\n", 816 | " 0,\n", 817 | " 0,\n", 818 | " 37,\n", 819 | " 37,\n", 820 | " 37,\n", 821 | " 37,\n", 822 | " 37,\n", 823 | " 37,\n", 824 | " 37,\n", 825 | " 37,\n", 826 | " 3,\n", 827 | " 3,\n", 828 | " 3,\n", 829 | " 3,\n", 830 | " 3,\n", 831 | " 3,\n", 832 | " 3,\n", 833 | " 3,\n", 834 | " 3,\n", 835 | " 3,\n", 836 | " 3,\n", 837 | " 3,\n", 838 | " 3,\n", 839 | " 3,\n", 840 | " 38,\n", 841 | " 38,\n", 842 | " 38,\n", 843 | " 25,\n", 844 | " 25,\n", 845 | " 25,\n", 846 | " 15,\n", 847 | " 15,\n", 848 | " 17,\n", 849 | " 17,\n", 850 | " 17,\n", 851 | " 17,\n", 852 | " 17,\n", 853 | " 17,\n", 854 | " 17,\n", 855 | " 38,\n", 856 | " 38,\n", 857 | " 38,\n", 858 | " 38,\n", 859 | " 38,\n", 860 | " 38,\n", 861 | " 25,\n", 862 | " 25,\n", 863 | " 7,\n", 864 | " 7,\n", 865 | " 7,\n", 866 | " 7,\n", 867 | " 7,\n", 868 | " 7,\n", 869 | " 7,\n", 870 | " 7,\n", 871 | " 7,\n", 872 | " 7,\n", 873 | " 7,\n", 874 | " 14,\n", 875 | " 14,\n", 876 | " 14,\n", 877 | " 14,\n", 878 | " 14,\n", 879 | " 14,\n", 880 | " 38,\n", 881 | " 38,\n", 882 | " 38,\n", 883 | " 38,\n", 884 | " 38,\n", 885 | " 38,\n", 886 | " 38,\n", 887 | " 30,\n", 888 | " 35,\n", 889 | " 35,\n", 890 | " 35,\n", 891 | " 35,\n", 892 | " 35,\n", 893 | " 35,\n", 894 | " 35,\n", 895 | " 35,\n", 896 | " 35,\n", 897 | " 35,\n", 898 | " 35,\n", 899 | " 35,\n", 900 | " 35,\n", 901 | " 35,\n", 902 | " 5,\n", 903 | " 5,\n", 904 | " 5,\n", 905 | " 5,\n", 906 | " 5,\n", 907 | " 5,\n", 908 | " 5,\n", 909 | " 5,\n", 910 | " 5,\n", 911 | " 5,\n", 912 | " 5,\n", 913 | " 5,\n", 914 | " 5,\n", 915 | " 5,\n", 916 | " 5,\n", 917 | " 26,\n", 918 | " 26,\n", 919 | " 26,\n", 920 | " 25,\n", 921 | " 1,\n", 922 | " 1,\n", 923 | " 1,\n", 924 | " 1,\n", 925 | " 1,\n", 
926 | "          ...(several hundred more per-frame class indices omitted from this already-truncated display; the full predictions are written to prediction.csv below)...\n", 1785 | "          
0,\n", 1786 | " 0,\n", 1787 | " 0,\n", 1788 | " 0,\n", 1789 | " 23,\n", 1790 | " 23,\n", 1791 | " 1,\n", 1792 | " 1,\n", 1793 | " 1,\n", 1794 | " 1,\n", 1795 | " 1,\n", 1796 | " 32,\n", 1797 | " 32,\n", 1798 | " ...]" 1799 | ] 1800 | }, 1801 | "execution_count": 15, 1802 | "metadata": {}, 1803 | "output_type": "execute_result" 1804 | } 1805 | ], 1806 | "source": [ 1807 | "predict" 1808 | ] 1809 | }, 1810 | { 1811 | "cell_type": "code", 1812 | "execution_count": 16, 1813 | "metadata": { 1814 | "execution": { 1815 | "iopub.execute_input": "2021-03-14T15:43:58.828343Z", 1816 | "iopub.status.busy": "2021-03-14T15:43:58.825881Z", 1817 | "iopub.status.idle": "2021-03-14T15:43:59.217857Z", 1818 | "shell.execute_reply": "2021-03-14T15:43:59.217277Z" 1819 | }, 1820 | "papermill": { 1821 | "duration": 0.425387, 1822 | "end_time": "2021-03-14T15:43:59.217986", 1823 | "exception": false, 1824 | "start_time": "2021-03-14T15:43:58.792599", 1825 | "status": "completed" 1826 | }, 1827 | "tags": [] 1828 | }, 1829 | "outputs": [], 1830 | "source": [ 1831 | "with open('prediction.csv', 'w') as f:\n", 1832 | " f.write('Id,Class\\n')\n", 1833 | " for i, y in enumerate(predict):\n", 1834 | " f.write('{},{}\\n'.format(i, y))\n", 1835 | " " 1836 | ] 1837 | } 1838 | ], 1839 | "metadata": { 1840 | "kernelspec": { 1841 | "display_name": "Python 3", 1842 | "language": "python", 1843 | "name": "python3" 1844 | }, 1845 | "language_info": { 1846 | "codemirror_mode": { 1847 | "name": "ipython", 1848 | "version": 3 1849 | }, 1850 | "file_extension": ".py", 1851 | "mimetype": "text/x-python", 1852 | "name": "python", 1853 | "nbconvert_exporter": "python", 1854 | "pygments_lexer": "ipython3", 1855 | "version": "3.7.9" 1856 | }, 1857 | "papermill": { 1858 | "default_parameters": {}, 1859 | "duration": 1210.024426, 1860 | "end_time": "2021-03-14T15:44:01.349587", 1861 | "environment_variables": {}, 1862 | "exception": null, 1863 | "input_path": "__notebook__.ipynb", 1864 | "output_path": "__notebook__.ipynb", 1865 | "parameters": {}, 1866 | "start_time": "2021-03-14T15:23:51.325161", 1867 | "version": "2.2.2" 1868 | } 1869 | }, 1870 | "nbformat": 4, 1871 | "nbformat_minor": 4 1872 | } 1873 | -------------------------------------------------------------------------------- /hw/hw4hw4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "hw4hw4", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "toc_visible": true 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "name": "python3" 15 | }, 16 | "widgets": { 17 | "application/vnd.jupyter.widget-state+json": { 18 | "868b39af577b4acca0e3a5d8f37ef7d3": { 19 | "model_module": "@jupyter-widgets/controls", 20 | "model_name": "HBoxModel", 21 | "state": { 22 | "_dom_classes": [], 23 | "_model_module": "@jupyter-widgets/controls", 24 | "_model_module_version": "1.5.0", 25 | "_model_name": "HBoxModel", 26 | "_view_count": null, 27 | "_view_module": "@jupyter-widgets/controls", 28 | "_view_module_version": "1.5.0", 29 | "_view_name": "HBoxView", 30 | "box_style": "", 31 | "children": [ 32 | "IPY_MODEL_ed31f46ff8bc4f49b5a3fa90eed580a9", 33 | "IPY_MODEL_b3c53eb71e27489ca5e1c711c4ea2c97" 34 | ], 35 | "layout": "IPY_MODEL_513eb6c1a6744bf087f566f28171d7fa" 36 | } 37 | }, 38 | "ed31f46ff8bc4f49b5a3fa90eed580a9": { 39 | "model_module": "@jupyter-widgets/controls", 40 | "model_name": "FloatProgressModel", 41 | "state": { 
42 | "_dom_classes": [], 43 | "_model_module": "@jupyter-widgets/controls", 44 | "_model_module_version": "1.5.0", 45 | "_model_name": "FloatProgressModel", 46 | "_view_count": null, 47 | "_view_module": "@jupyter-widgets/controls", 48 | "_view_module_version": "1.5.0", 49 | "_view_name": "ProgressView", 50 | "bar_style": "", 51 | "description": " 38%", 52 | "description_tooltip": null, 53 | "layout": "IPY_MODEL_3f2eb39caf5642b3b4378ce9714f046a", 54 | "max": 6000, 55 | "min": 0, 56 | "orientation": "horizontal", 57 | "style": "IPY_MODEL_9f6512dfb3154ee7afa2cb6ac831a9bc", 58 | "value": 2307 59 | } 60 | }, 61 | "b3c53eb71e27489ca5e1c711c4ea2c97": { 62 | "model_module": "@jupyter-widgets/controls", 63 | "model_name": "HTMLModel", 64 | "state": { 65 | "_dom_classes": [], 66 | "_model_module": "@jupyter-widgets/controls", 67 | "_model_module_version": "1.5.0", 68 | "_model_name": "HTMLModel", 69 | "_view_count": null, 70 | "_view_module": "@jupyter-widgets/controls", 71 | "_view_module_version": "1.5.0", 72 | "_view_name": "HTMLView", 73 | "description": "", 74 | "description_tooltip": null, 75 | "layout": "IPY_MODEL_fd53b5d41c26451c8952fe8403782ceb", 76 | "placeholder": "​", 77 | "style": "IPY_MODEL_0c0d5b05635d4a72aa77eb762e726ee3", 78 | "value": " 2307/6000 [00:06<00:10, 348.07it/s]" 79 | } 80 | }, 81 | "513eb6c1a6744bf087f566f28171d7fa": { 82 | "model_module": "@jupyter-widgets/base", 83 | "model_name": "LayoutModel", 84 | "state": { 85 | "_model_module": "@jupyter-widgets/base", 86 | "_model_module_version": "1.2.0", 87 | "_model_name": "LayoutModel", 88 | "_view_count": null, 89 | "_view_module": "@jupyter-widgets/base", 90 | "_view_module_version": "1.2.0", 91 | "_view_name": "LayoutView", 92 | "align_content": null, 93 | "align_items": null, 94 | "align_self": null, 95 | "border": null, 96 | "bottom": null, 97 | "display": null, 98 | "flex": null, 99 | "flex_flow": null, 100 | "grid_area": null, 101 | "grid_auto_columns": null, 102 | "grid_auto_flow": null, 103 | "grid_auto_rows": null, 104 | "grid_column": null, 105 | "grid_gap": null, 106 | "grid_row": null, 107 | "grid_template_areas": null, 108 | "grid_template_columns": null, 109 | "grid_template_rows": null, 110 | "height": null, 111 | "justify_content": null, 112 | "justify_items": null, 113 | "left": null, 114 | "margin": null, 115 | "max_height": null, 116 | "max_width": null, 117 | "min_height": null, 118 | "min_width": null, 119 | "object_fit": null, 120 | "object_position": null, 121 | "order": null, 122 | "overflow": null, 123 | "overflow_x": null, 124 | "overflow_y": null, 125 | "padding": null, 126 | "right": null, 127 | "top": null, 128 | "visibility": null, 129 | "width": null 130 | } 131 | }, 132 | "3f2eb39caf5642b3b4378ce9714f046a": { 133 | "model_module": "@jupyter-widgets/base", 134 | "model_name": "LayoutModel", 135 | "state": { 136 | "_model_module": "@jupyter-widgets/base", 137 | "_model_module_version": "1.2.0", 138 | "_model_name": "LayoutModel", 139 | "_view_count": null, 140 | "_view_module": "@jupyter-widgets/base", 141 | "_view_module_version": "1.2.0", 142 | "_view_name": "LayoutView", 143 | "align_content": null, 144 | "align_items": null, 145 | "align_self": null, 146 | "border": null, 147 | "bottom": null, 148 | "display": null, 149 | "flex": null, 150 | "flex_flow": null, 151 | "grid_area": null, 152 | "grid_auto_columns": null, 153 | "grid_auto_flow": null, 154 | "grid_auto_rows": null, 155 | "grid_column": null, 156 | "grid_gap": null, 157 | "grid_row": null, 158 | "grid_template_areas": null, 159 | 
"grid_template_columns": null, 160 | "grid_template_rows": null, 161 | "height": null, 162 | "justify_content": null, 163 | "justify_items": null, 164 | "left": null, 165 | "margin": null, 166 | "max_height": null, 167 | "max_width": null, 168 | "min_height": null, 169 | "min_width": null, 170 | "object_fit": null, 171 | "object_position": null, 172 | "order": null, 173 | "overflow": null, 174 | "overflow_x": null, 175 | "overflow_y": null, 176 | "padding": null, 177 | "right": null, 178 | "top": null, 179 | "visibility": null, 180 | "width": null 181 | } 182 | }, 183 | "9f6512dfb3154ee7afa2cb6ac831a9bc": { 184 | "model_module": "@jupyter-widgets/controls", 185 | "model_name": "ProgressStyleModel", 186 | "state": { 187 | "_model_module": "@jupyter-widgets/controls", 188 | "_model_module_version": "1.5.0", 189 | "_model_name": "ProgressStyleModel", 190 | "_view_count": null, 191 | "_view_module": "@jupyter-widgets/base", 192 | "_view_module_version": "1.2.0", 193 | "_view_name": "StyleView", 194 | "bar_color": null, 195 | "description_width": "initial" 196 | } 197 | }, 198 | "fd53b5d41c26451c8952fe8403782ceb": { 199 | "model_module": "@jupyter-widgets/base", 200 | "model_name": "LayoutModel", 201 | "state": { 202 | "_model_module": "@jupyter-widgets/base", 203 | "_model_module_version": "1.2.0", 204 | "_model_name": "LayoutModel", 205 | "_view_count": null, 206 | "_view_module": "@jupyter-widgets/base", 207 | "_view_module_version": "1.2.0", 208 | "_view_name": "LayoutView", 209 | "align_content": null, 210 | "align_items": null, 211 | "align_self": null, 212 | "border": null, 213 | "bottom": null, 214 | "display": null, 215 | "flex": null, 216 | "flex_flow": null, 217 | "grid_area": null, 218 | "grid_auto_columns": null, 219 | "grid_auto_flow": null, 220 | "grid_auto_rows": null, 221 | "grid_column": null, 222 | "grid_gap": null, 223 | "grid_row": null, 224 | "grid_template_areas": null, 225 | "grid_template_columns": null, 226 | "grid_template_rows": null, 227 | "height": null, 228 | "justify_content": null, 229 | "justify_items": null, 230 | "left": null, 231 | "margin": null, 232 | "max_height": null, 233 | "max_width": null, 234 | "min_height": null, 235 | "min_width": null, 236 | "object_fit": null, 237 | "object_position": null, 238 | "order": null, 239 | "overflow": null, 240 | "overflow_x": null, 241 | "overflow_y": null, 242 | "padding": null, 243 | "right": null, 244 | "top": null, 245 | "visibility": null, 246 | "width": null 247 | } 248 | }, 249 | "0c0d5b05635d4a72aa77eb762e726ee3": { 250 | "model_module": "@jupyter-widgets/controls", 251 | "model_name": "DescriptionStyleModel", 252 | "state": { 253 | "_model_module": "@jupyter-widgets/controls", 254 | "_model_module_version": "1.5.0", 255 | "_model_name": "DescriptionStyleModel", 256 | "_view_count": null, 257 | "_view_module": "@jupyter-widgets/base", 258 | "_view_module_version": "1.2.0", 259 | "_view_name": "StyleView", 260 | "description_width": "" 261 | } 262 | } 263 | } 264 | } 265 | }, 266 | "cells": [ 267 | { 268 | "cell_type": "markdown", 269 | "metadata": { 270 | "id": "zC5KwRyl6Flp" 271 | }, 272 | "source": [ 273 | "# Task description\n", 274 | "- Classify the speakers of given features.\n", 275 | "- Main goal: Learn how to use transformer.\n", 276 | "- Baselines:\n", 277 | " - Easy: Run sample code and know how to use transformer.\n", 278 | " - Medium: Know how to adjust parameters of transformer.\n", 279 | " - Hard: Construct [conformer](https://arxiv.org/abs/2005.08100) which is a variety of transformer. 
\n", 280 | "\n", 281 | "- Other links\n", 282 | " - Kaggle: [link](https://www.kaggle.com/t/859c9ca9ede14fdea841be627c412322)\n", 283 | " - Slide: [link](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/hw/HW04/HW04.pdf)\n", 284 | " - Data: [link](https://drive.google.com/file/d/1T0RPnu-Sg5eIPwQPfYysipfcz81MnsYe/view?usp=sharing)\n", 285 | " - Video (Chinese): [link](https://www.youtube.com/watch?v=EPerg2UnGaI)\n", 286 | " - Video (English): [link](https://www.youtube.com/watch?v=Gpz6AUvCak0)\n", 287 | " - Solution for downloading dataset fail.: [link](https://drive.google.com/drive/folders/13T0Pa_WGgQxNkqZk781qhc5T9-zfh19e?usp=sharing)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": { 293 | "id": "TPDoreyypeJE" 294 | }, 295 | "source": [ 296 | "# Download dataset\n", 297 | "- **If all download links fail**\n", 298 | "- **Please follow [here](https://drive.google.com/drive/folders/13T0Pa_WGgQxNkqZk781qhc5T9-zfh19e?usp=sharing)**\n", 299 | "- **Data is [here](https://drive.google.com/file/d/1T0RPnu-Sg5eIPwQPfYysipfcz81MnsYe/view?usp=sharing)**" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "metadata": { 305 | "colab": { 306 | "base_uri": "https://localhost:8080/", 307 | "height": 244 308 | }, 309 | "id": "QvpaILXnJIcw", 310 | "outputId": "5b46c554-0c41-4174-a3e7-786004c38bbc" 311 | }, 312 | "source": [ 313 | "\"\"\"\n", 314 | " For Google drive, You can download data form any link below.\n", 315 | " If a link fails, please use another one.\n", 316 | "\"\"\"\n", 317 | "\"\"\" Download link 1 of Google drive \"\"\"\n", 318 | "# !gdown --id '1T0RPnu-Sg5eIPwQPfYysipfcz81MnsYe' --output Dataset.zip\n", 319 | "\"\"\" Download link 2 of Google drive \"\"\"\n", 320 | "# !gdown --id '1CtHZhJ-mTpNsO-MqvAPIi4Yrt3oSBXYV' --output Dataset.zip\n", 321 | "\"\"\" Download link 3 of Google drive \"\"\"\n", 322 | "# !gdown --id '14hmoMgB1fe6v50biIceKyndyeYABGrRq' --output Dataset.zip\n", 323 | "\"\"\" Download link 4 of Google drive \"\"\"\n", 324 | "# !gdown --id '1e9x-Pjl3n7-9tK9LS_WjiMo2lru4UBH9' --output Dataset.zip\n", 325 | "\"\"\" Download link 5 of Google drive \"\"\"\n", 326 | "# !gdown --id '10TC0g46bcAz_jkiMl65zNmwttT4RiRgY' --output Dataset.zip\n", 327 | "\"\"\" Download link 6 of Google drive \"\"\"\n", 328 | "# !gdown --id '1MUGBvG_JjqO0C2JYHuyV3B0lvaf1kWIm' --output Dataset.zip\n", 329 | "\"\"\" Download link 7 of Google drive \"\"\"\n", 330 | "# !gdown --id '18M91P5DHwILNyOlssZ57AiPOR0OwutOM' --output Dataset.zip\n", 331 | "\"\"\" For all download links fail, Please paste link into 'Paste link here' \"\"\"\n", 332 | "!gdown --id '1fq5-eH0xaY-YBiOqOsTBKzlfWe7ZNQCR' --output Dataset.zip\n", 333 | "\"\"\" For Google drive, you can unzip the data by the command below. \"\"\"\n", 334 | "!unzip Dataset.zip\n", 335 | "\n", 336 | "\"\"\"\n", 337 | " For Dropbox, we split dataset into five files. \n", 338 | " Please download all of them.\n", 339 | "\"\"\"\n", 340 | "# If Dropbox is not work. 
Please use google drive.\n", 341 | "# !wget https://www.dropbox.com/s/vw324newiku0sz0/Dataset.tar.gz.aa?dl=0\n", 342 | "# !wget https://www.dropbox.com/s/z840g69e7lnkayo/Dataset.tar.gz.ab?dl=0\n", 343 | "# !wget https://www.dropbox.com/s/hl081e1ggonio81/Dataset.tar.gz.ac?dl=0\n", 344 | "# !wget https://www.dropbox.com/s/fh3zd8ow668c4th/Dataset.tar.gz.ad?dl=0\n", 345 | "# !wget https://www.dropbox.com/s/ydzygoy2pv6gw9d/Dataset.tar.gz.ae?dl=0\n", 346 | "# !cat Dataset.tar.gz.* | tar zxvf -\n", 347 | "\n", 348 | "\"\"\"\n", 349 | " For Onedrive, we split dataset into five files. \n", 350 | " Please download all of them.\n", 351 | "\"\"\"\n", 352 | "# !wget --no-check-certificate \"https://onedrive.live.com/download?cid=10C95EE5FD151BFB&resid=10C95EE5FD151BFB%21106&authkey=ACB6opQR3CG9kmc\" -O Dataset.tar.gz.aa\n", 353 | "# !wget --no-check-certificate \"https://onedrive.live.com/download?cid=93DDDDD552E145DB&resid=93DDDDD552E145DB%21106&authkey=AP6EepjxSdvyV6Y\" -O Dataset.tar.gz.ab\n", 354 | "# !wget --no-check-certificate \"https://onedrive.live.com/download?cid=644545816461BCCC&resid=644545816461BCCC%21106&authkey=ALiefB0kI7Epb0Q\" -O Dataset.tar.gz.ac\n", 355 | "# !wget --no-check-certificate \"https://onedrive.live.com/download?cid=77CEBB3C3C512821&resid=77CEBB3C3C512821%21106&authkey=AAXCx4TTDYC0yjM\" -O Dataset.tar.gz.ad\n", 356 | "# !wget --no-check-certificate \"https://onedrive.live.com/download?cid=383D0E0146A11B02&resid=383D0E0146A11B02%21106&authkey=ALwVc4StVbig6QI\" -O Dataset.tar.gz.ae\n", 357 | "# !cat Dataset.tar.gz.* | tar zxvf -" 358 | ], 359 | "execution_count": null, 360 | "outputs": [ 361 | { 362 | "output_type": "stream", 363 | "text": [ 364 | "Downloading...\n", 365 | "From: https://drive.google.com/uc?id=1fq5-eH0xaY-YBiOqOsTBKzlfWe7ZNQCR\n", 366 | "To: /content/Dataset.zip\n", 367 | "527MB [00:06, 53.5MB/s]Traceback (most recent call last):\n", 368 | " File \"/usr/local/bin/gdown\", line 8, in \n", 369 | " sys.exit(main())\n", 370 | " File \"/usr/local/lib/python2.7/dist-packages/gdown/cli.py\", line 61, in main\n", 371 | " quiet=args.quiet,\n", 372 | "\n", 373 | "528MB [00:06, 82.9MB/s]\n", 374 | "Archive: Dataset.zip\n", 375 | "replace Dataset/uttr-f2e2dc9d3a6c471abd6464a316d1e21a.pt? [y]es, [n]o, [A]ll, [N]one, [r]ename: " 376 | ], 377 | "name": "stdout" 378 | }, 379 | { 380 | "output_type": "execute_result", 381 | "data": { 382 | "application/vnd.google.colaboratory.intrinsic+json": { 383 | "type": "string" 384 | }, 385 | "text/plain": [ 386 | "'\\n For Onedrive, we split dataset into five files. 
\\n Please download all of them.\\n'" 387 | ] 388 | }, 389 | "metadata": { 390 | "tags": [] 391 | }, 392 | "execution_count": 9 393 | } 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": { 399 | "id": "v1gYr_aoNDue" 400 | }, 401 | "source": [ 402 | "# Data" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "metadata": { 408 | "id": "Mz_NpuAipk3h" 409 | }, 410 | "source": [ 411 | "## Dataset\n", 412 | "- Original dataset is [Voxceleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/).\n", 413 | "- The [license](https://creativecommons.org/licenses/by/4.0/) and [complete version](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/files/license.txt) of Voxceleb1.\n", 414 | "- We randomly select 600 speakers from Voxceleb1.\n", 415 | "- Then preprocess the raw waveforms into mel-spectrograms.\n", 416 | "\n", 417 | "- Args:\n", 418 | " - data_dir: The path to the data directory.\n", 419 | " - metadata_path: The path to the metadata.\n", 420 | " - segment_len: The length of audio segment for training. \n", 421 | "- The architecture of data directory \\\\\n", 422 | " - data directory \\\\\n", 423 | " |---- metadata.json \\\\\n", 424 | " |---- testdata.json \\\\\n", 425 | " |---- mapping.json \\\\\n", 426 | " |---- uttr-{random string}.pt \\\\\n", 427 | "\n", 428 | "- The information in metadata\n", 429 | " - \"n_mels\": The dimention of mel-spectrogram.\n", 430 | " - \"speakers\": A dictionary. \n", 431 | " - Key: speaker ids.\n", 432 | " - value: \"feature_path\" and \"mel_len\"\n", 433 | "\n", 434 | "\n", 435 | "For efficiency, we segment the mel-spectrograms into segments in the traing step." 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "metadata": { 441 | "id": "cd7hoGhYtbXQ" 442 | }, 443 | "source": [ 444 | "import os\n", 445 | "import json\n", 446 | "import torch\n", 447 | "import random\n", 448 | "from pathlib import Path\n", 449 | "from torch.utils.data import Dataset\n", 450 | "from torch.nn.utils.rnn import pad_sequence\n", 451 | " \n", 452 | " \n", 453 | "class myDataset(Dataset):\n", 454 | " def __init__(self, data_dir, segment_len=128):\n", 455 | " self.data_dir = data_dir\n", 456 | " self.segment_len = segment_len\n", 457 | " \n", 458 | " # Load the mapping from speaker neme to their corresponding id. 
\n", 459 | " mapping_path = Path(data_dir) / \"mapping.json\" # 拼接路径\n", 460 | " mapping = json.load(mapping_path.open())\n", 461 | " self.speaker2id = mapping[\"speaker2id\"]\n", 462 | " \n", 463 | " # Load metadata of training data.\n", 464 | " metadata_path = Path(data_dir) / \"metadata.json\"\n", 465 | " metadata = json.load(open(metadata_path))[\"speakers\"]\n", 466 | " \n", 467 | " # Get the total number of speaker.\n", 468 | " self.speaker_num = len(metadata.keys())\n", 469 | " self.data = []\n", 470 | " for speaker in metadata.keys():\n", 471 | " for utterances in metadata[speaker]:\n", 472 | " self.data.append([utterances[\"feature_path\"], self.speaker2id[speaker]])\n", 473 | " \n", 474 | " def __len__(self):\n", 475 | " return len(self.data)\n", 476 | " \n", 477 | " def __getitem__(self, index):\n", 478 | " feat_path, speaker = self.data[index]\n", 479 | " # Load preprocessed mel-spectrogram.\n", 480 | " mel = torch.load(os.path.join(self.data_dir, feat_path))\n", 481 | " \n", 482 | " # Segmemt mel-spectrogram into \"segment_len\" frames.\n", 483 | " if len(mel) > self.segment_len:\n", 484 | " # Randomly get the starting point of the segment.\n", 485 | " start = random.randint(0, len(mel) - self.segment_len)\n", 486 | " # Get a segment with \"segment_len\" frames.\n", 487 | " mel = torch.FloatTensor(mel[start:start+self.segment_len])\n", 488 | " else:\n", 489 | " mel = torch.FloatTensor(mel)\n", 490 | " # Turn the speaker id into long for computing loss later.\n", 491 | " speaker = torch.FloatTensor([speaker]).long()\n", 492 | " return mel, speaker\n", 493 | " \n", 494 | " def get_speaker_number(self):\n", 495 | " return self.speaker_num" 496 | ], 497 | "execution_count": null, 498 | "outputs": [] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": { 503 | "id": "mqJxjoi_NGnB" 504 | }, 505 | "source": [ 506 | "## Dataloader\n", 507 | "- Split dataset into training dataset(90%) and validation dataset(10%).\n", 508 | "- Create dataloader to iterate the data.\n" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "metadata": { 514 | "id": "zuT1AuFENI8t" 515 | }, 516 | "source": [ 517 | "import torch\n", 518 | "from torch.utils.data import DataLoader, random_split\n", 519 | "from torch.nn.utils.rnn import pad_sequence\n", 520 | "\n", 521 | "\n", 522 | "def collate_batch(batch):\n", 523 | " # Process features within a batch.\n", 524 | " \"\"\"Collate a batch of data.\"\"\"\n", 525 | " mel, speaker = zip(*batch)\n", 526 | " # Because we train the model batch by batch, we need to pad the features in the same batch to make their lengths the same.\n", 527 | " mel = pad_sequence(mel, batch_first=True, padding_value=-20) # pad log 10^(-20) which is very small value.\n", 528 | " # mel: (batch size, length, 40)\n", 529 | " return mel, torch.FloatTensor(speaker).long()\n", 530 | "\n", 531 | "\n", 532 | "def get_dataloader(data_dir, batch_size, n_workers):\n", 533 | " \"\"\"Generate dataloader\"\"\"\n", 534 | " dataset = myDataset(data_dir)\n", 535 | " speaker_num = dataset.get_speaker_number()\n", 536 | " # Split dataset into training dataset and validation dataset\n", 537 | " trainlen = int(0.9 * len(dataset))\n", 538 | " lengths = [trainlen, len(dataset) - trainlen]\n", 539 | " trainset, validset = random_split(dataset, lengths)\n", 540 | "\n", 541 | " train_loader = DataLoader(\n", 542 | " trainset,\n", 543 | " batch_size=batch_size,\n", 544 | " shuffle=True,\n", 545 | " drop_last=True,\n", 546 | " num_workers=n_workers,\n", 547 | " pin_memory=True,\n", 548 | " 
collate_fn=collate_batch,\n", 549 | " )\n", 550 | " valid_loader = DataLoader(\n", 551 | " validset,\n", 552 | " batch_size=batch_size,\n", 553 | " num_workers=n_workers,\n", 554 | " drop_last=True,\n", 555 | " pin_memory=True,\n", 556 | " collate_fn=collate_batch,\n", 557 | " )\n", 558 | "\n", 559 | " return train_loader, valid_loader, speaker_num\n" 560 | ], 561 | "execution_count": null, 562 | "outputs": [] 563 | }, 564 | { 565 | "cell_type": "markdown", 566 | "metadata": { 567 | "id": "X0x6eXiHpr4R" 568 | }, 569 | "source": [ 570 | "# Model\n", 571 | "- TransformerEncoderLayer:\n", 572 | " - Base transformer encoder layer in [Attention Is All You Need](https://arxiv.org/abs/1706.03762)\n", 573 | " - Parameters:\n", 574 | " - d_model: the number of expected features of the input (required).\n", 575 | "\n", 576 | " - nhead: the number of heads of the multiheadattention models (required).\n", 577 | "\n", 578 | " - dim_feedforward: the dimension of the feedforward network model (default=2048).\n", 579 | "\n", 580 | " - dropout: the dropout value (default=0.1).\n", 581 | "\n", 582 | " - activation: the activation function of intermediate layer, relu or gelu (default=relu).\n", 583 | "\n", 584 | "- TransformerEncoder:\n", 585 | " - TransformerEncoder is a stack of N transformer encoder layers\n", 586 | " - Parameters:\n", 587 | " - encoder_layer: an instance of the TransformerEncoderLayer() class (required).\n", 588 | "\n", 589 | " - num_layers: the number of sub-encoder-layers in the encoder (required).\n", 590 | "\n", 591 | " - norm: the layer normalization component (optional)." 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "metadata": { 597 | "id": "SHX4eVj4tjtd" 598 | }, 599 | "source": [ 600 | "import torch\n", 601 | "import torch.nn as nn\n", 602 | "import torch.nn.functional as F\n", 603 | "\n", 604 | "\n", 605 | "class Classifier(nn.Module):\n", 606 | " def __init__(self, d_model=80, n_spks=600, dropout=0.2):\n", 607 | " super().__init__()\n", 608 | " # Project the dimension of features from that of input into d_model.\n", 609 | " self.prenet = nn.Linear(40, d_model)\n", 610 | " # TODO:\n", 611 | " # Change Transformer to Conformer.\n", 612 | " # https://arxiv.org/abs/2005.08100\n", 613 | " self.encoder_layer = nn.TransformerEncoderLayer(\n", 614 | " d_model=d_model, dim_feedforward=256, nhead=2\n", 615 | " )\n", 616 | " # self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2)\n", 617 | "\n", 618 | " # Project the the dimension of features from d_model into speaker nums.\n", 619 | " self.pred_layer = nn.Sequential(\n", 620 | " nn.Linear(d_model, d_model),\n", 621 | " nn.ReLU(),\n", 622 | " nn.Linear(d_model, n_spks),\n", 623 | " )\n", 624 | "\n", 625 | " def forward(self, mels):\n", 626 | " \"\"\"\n", 627 | " args:\n", 628 | " mels: (batch size, length, 40)\n", 629 | " return:\n", 630 | " out: (batch size, n_spks)\n", 631 | " \"\"\"\n", 632 | " # out: (batch size, length, d_model)\n", 633 | " out = self.prenet(mels)\n", 634 | " # out: (length, batch size, d_model)\n", 635 | " out = out.permute(1, 0, 2)\n", 636 | " # The encoder layer expect features in the shape of (length, batch size, d_model).\n", 637 | " out = self.encoder_layer(out)\n", 638 | " # out: (batch size, length, d_model)\n", 639 | " out = out.transpose(0, 1)\n", 640 | " # mean pooling\n", 641 | " stats = out.mean(dim=1)\n", 642 | "\n", 643 | " # out: (batch, n_spks)\n", 644 | " out = self.pred_layer(stats)\n", 645 | " return out\n" 646 | ], 647 | "execution_count": null, 648 | "outputs": [] 
649 | }, 650 | { 651 | "cell_type": "markdown", 652 | "metadata": { 653 | "id": "-__DolPGpvDZ" 654 | }, 655 | "source": [ 656 | "# Learning rate schedule\n", 657 | "- For transformer architecture, the design of learning rate schedule is different from that of CNN.\n", 658 | "- Previous works show that the warmup of learning rate is useful for training models with transformer architectures.\n", 659 | "- The warmup schedule\n", 660 | " - Set learning rate to 0 in the beginning.\n", 661 | " - The learning rate increases linearly from 0 to initial learning rate during warmup period." 662 | ] 663 | }, 664 | { 665 | "cell_type": "code", 666 | "metadata": { 667 | "id": "K-0816BntqT9" 668 | }, 669 | "source": [ 670 | "import math\n", 671 | "\n", 672 | "import torch\n", 673 | "from torch.optim import Optimizer\n", 674 | "from torch.optim.lr_scheduler import LambdaLR\n", 675 | "\n", 676 | "\n", 677 | "def get_cosine_schedule_with_warmup(\n", 678 | " optimizer: Optimizer,\n", 679 | " num_warmup_steps: int,\n", 680 | " num_training_steps: int,\n", 681 | " num_cycles: float = 0.5,\n", 682 | " last_epoch: int = -1,\n", 683 | "):\n", 684 | " \"\"\"\n", 685 | " Create a schedule with a learning rate that decreases following the values of the cosine function between the\n", 686 | " initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\n", 687 | " initial lr set in the optimizer.\n", 688 | "\n", 689 | " Args:\n", 690 | " optimizer (:class:`~torch.optim.Optimizer`):\n", 691 | " The optimizer for which to schedule the learning rate.\n", 692 | " num_warmup_steps (:obj:`int`):\n", 693 | " The number of steps for the warmup phase.\n", 694 | " num_training_steps (:obj:`int`):\n", 695 | " The total number of training steps.\n", 696 | " num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n", 697 | " The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0\n", 698 | " following a half-cosine).\n", 699 | " last_epoch (:obj:`int`, `optional`, defaults to -1):\n", 700 | " The index of the last epoch when resuming training.\n", 701 | "\n", 702 | " Return:\n", 703 | " :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n", 704 | " \"\"\"\n", 705 | "\n", 706 | " def lr_lambda(current_step):\n", 707 | " # Warmup\n", 708 | " if current_step < num_warmup_steps:\n", 709 | " return float(current_step) / float(max(1, num_warmup_steps))\n", 710 | " # decadence\n", 711 | " progress = float(current_step - num_warmup_steps) / float(\n", 712 | " max(1, num_training_steps - num_warmup_steps)\n", 713 | " )\n", 714 | " return max(\n", 715 | " 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))\n", 716 | " )\n", 717 | "\n", 718 | " return LambdaLR(optimizer, lr_lambda, last_epoch)\n" 719 | ], 720 | "execution_count": null, 721 | "outputs": [] 722 | }, 723 | { 724 | "cell_type": "markdown", 725 | "metadata": { 726 | "id": "IP03FFo9K8DS" 727 | }, 728 | "source": [ 729 | "# Model Function\n", 730 | "- Model forward function." 
731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "metadata": { 736 | "id": "fohaLEFJK9-t" 737 | }, 738 | "source": [ 739 | "import torch\n", 740 | "\n", 741 | "\n", 742 | "def model_fn(batch, model, criterion, device):\n", 743 | " \"\"\"Forward a batch through the model.\"\"\"\n", 744 | "\n", 745 | " mels, labels = batch\n", 746 | " mels = mels.to(device)\n", 747 | " labels = labels.to(device)\n", 748 | "\n", 749 | " outs = model(mels)\n", 750 | "\n", 751 | " loss = criterion(outs, labels)\n", 752 | "\n", 753 | " # Get the speaker id with highest probability.\n", 754 | " preds = outs.argmax(1)\n", 755 | " # Compute accuracy.\n", 756 | " accuracy = torch.mean((preds == labels).float())\n", 757 | "\n", 758 | " return loss, accuracy\n" 759 | ], 760 | "execution_count": null, 761 | "outputs": [] 762 | }, 763 | { 764 | "cell_type": "markdown", 765 | "metadata": { 766 | "id": "F7cg-YrzLQcf" 767 | }, 768 | "source": [ 769 | "# Validate\n", 770 | "- Calculate accuracy of the validation set." 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "metadata": { 776 | "id": "mD-_p6nWLO2L" 777 | }, 778 | "source": [ 779 | "from tqdm import tqdm\n", 780 | "import torch\n", 781 | "\n", 782 | "\n", 783 | "def valid(dataloader, model, criterion, device): \n", 784 | " \"\"\"Validate on validation set.\"\"\"\n", 785 | "\n", 786 | " model.eval()\n", 787 | " running_loss = 0.0\n", 788 | " running_accuracy = 0.0\n", 789 | " pbar = tqdm(total=len(dataloader.dataset), ncols=0, desc=\"Valid\", unit=\" uttr\")\n", 790 | "\n", 791 | " for i, batch in enumerate(dataloader):\n", 792 | " with torch.no_grad():\n", 793 | " loss, accuracy = model_fn(batch, model, criterion, device)\n", 794 | " running_loss += loss.item()\n", 795 | " running_accuracy += accuracy.item()\n", 796 | "\n", 797 | " pbar.update(dataloader.batch_size)\n", 798 | " pbar.set_postfix(\n", 799 | " loss=f\"{running_loss / (i+1):.2f}\",\n", 800 | " accuracy=f\"{running_accuracy / (i+1):.2f}\",\n", 801 | " )\n", 802 | "\n", 803 | " pbar.close()\n", 804 | " model.train()\n", 805 | "\n", 806 | " return running_accuracy / len(dataloader)\n" 807 | ], 808 | "execution_count": null, 809 | "outputs": [] 810 | }, 811 | { 812 | "cell_type": "markdown", 813 | "metadata": { 814 | "id": "noHXyal5p1W5" 815 | }, 816 | "source": [ 817 | "# Main function" 818 | ] 819 | }, 820 | { 821 | "cell_type": "code", 822 | "metadata": { 823 | "colab": { 824 | "background_save": true, 825 | "base_uri": "https://localhost:8080/" 826 | }, 827 | "id": "chRQE7oYtw62", 828 | "outputId": "1eed2105-d3ab-48eb-e660-017e7083e404" 829 | }, 830 | "source": [ 831 | "from tqdm import tqdm\n", 832 | "\n", 833 | "import torch\n", 834 | "import torch.nn as nn\n", 835 | "from torch.optim import AdamW\n", 836 | "from torch.utils.data import DataLoader, random_split\n", 837 | "\n", 838 | "\n", 839 | "def parse_args():\n", 840 | " \"\"\"arguments\"\"\"\n", 841 | " config = {\n", 842 | " \"data_dir\": \"./Dataset\",\n", 843 | " \"save_path\": \"model.ckpt\",\n", 844 | " \"batch_size\": 128,\n", 845 | " \"n_workers\": 8,\n", 846 | " \"valid_steps\": 2000,\n", 847 | " \"warmup_steps\": 1000,\n", 848 | " \"save_steps\": 10000,\n", 849 | " \"total_steps\": 70000,\n", 850 | " }\n", 851 | "\n", 852 | " return config\n", 853 | "\n", 854 | "\n", 855 | "def main(\n", 856 | " data_dir,\n", 857 | " save_path,\n", 858 | " batch_size,\n", 859 | " n_workers,\n", 860 | " valid_steps,\n", 861 | " warmup_steps,\n", 862 | " total_steps,\n", 863 | " save_steps,\n", 864 | "):\n", 865 | " \"\"\"Main 
function.\"\"\"\n", 866 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 867 | " print(f\"[Info]: Use {device} now!\")\n", 868 | "\n", 869 | " train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)\n", 870 | " train_iterator = iter(train_loader)\n", 871 | " print(f\"[Info]: Finish loading data!\",flush = True)\n", 872 | "\n", 873 | " model = Classifier(n_spks=speaker_num).to(device)\n", 874 | " criterion = nn.CrossEntropyLoss()\n", 875 | " optimizer = AdamW(model.parameters(), lr=1e-3)\n", 876 | " scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)\n", 877 | " print(f\"[Info]: Finish creating model!\",flush = True)\n", 878 | "\n", 879 | " best_accuracy = -1.0\n", 880 | " best_state_dict = None\n", 881 | "\n", 882 | " pbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n", 883 | "\n", 884 | " for step in range(total_steps):\n", 885 | " # Get data\n", 886 | " try:\n", 887 | " batch = next(train_iterator)\n", 888 | " except StopIteration:\n", 889 | " train_iterator = iter(train_loader)\n", 890 | " batch = next(train_iterator)\n", 891 | "\n", 892 | " loss, accuracy = model_fn(batch, model, criterion, device)\n", 893 | " batch_loss = loss.item()\n", 894 | " batch_accuracy = accuracy.item()\n", 895 | "\n", 896 | " # Updata model\n", 897 | " loss.backward()\n", 898 | " optimizer.step()\n", 899 | " scheduler.step()\n", 900 | " optimizer.zero_grad()\n", 901 | " \n", 902 | " # Log\n", 903 | " pbar.update()\n", 904 | " pbar.set_postfix(\n", 905 | " loss=f\"{batch_loss:.2f}\",\n", 906 | " accuracy=f\"{batch_accuracy:.2f}\",\n", 907 | " step=step + 1,\n", 908 | " )\n", 909 | "\n", 910 | " # Do validation\n", 911 | " if (step + 1) % valid_steps == 0:\n", 912 | " pbar.close()\n", 913 | "\n", 914 | " valid_accuracy = valid(valid_loader, model, criterion, device)\n", 915 | "\n", 916 | " # keep the best model\n", 917 | " if valid_accuracy > best_accuracy:\n", 918 | " best_accuracy = valid_accuracy\n", 919 | " best_state_dict = model.state_dict()\n", 920 | "\n", 921 | " pbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n", 922 | "\n", 923 | " # Save the best model so far.\n", 924 | " if (step + 1) % save_steps == 0 and best_state_dict is not None:\n", 925 | " torch.save(best_state_dict, save_path)\n", 926 | " pbar.write(f\"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})\")\n", 927 | "\n", 928 | " pbar.close()\n", 929 | "\n", 930 | "\n", 931 | "if __name__ == \"__main__\":\n", 932 | " main(**parse_args())\n" 933 | ], 934 | "execution_count": null, 935 | "outputs": [ 936 | { 937 | "output_type": "stream", 938 | "text": [ 939 | "[Info]: Use cuda now!\n" 940 | ], 941 | "name": "stdout" 942 | }, 943 | { 944 | "output_type": "stream", 945 | "text": [ 946 | "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", 947 | " cpuset_checked))\n" 948 | ], 949 | "name": "stderr" 950 | }, 951 | { 952 | "output_type": "stream", 953 | "text": [ 954 | "[Info]: Finish loading data!\n", 955 | "[Info]: Finish creating model!\n" 956 | ], 957 | "name": "stdout" 958 | }, 959 | { 960 | "output_type": "stream", 961 | "text": [ 962 | "Train: 100% 2000/2000 [06:27<00:00, 5.17 step/s, accuracy=0.33, loss=3.05, step=2000]\n", 963 | "Valid: 100% 6912/6944 [00:13<00:00, 497.46 uttr/s, accuracy=0.33, loss=3.04]\n", 964 | "Train: 100% 2000/2000 [05:55<00:00, 5.63 step/s, accuracy=0.44, loss=2.33, step=4000]\n", 965 | "Valid: 100% 6912/6944 [00:13<00:00, 499.16 uttr/s, accuracy=0.45, loss=2.41]\n", 966 | "Train: 100% 2000/2000 [05:52<00:00, 5.68 step/s, accuracy=0.57, loss=1.90, step=6000]\n", 967 | "Valid: 100% 6912/6944 [00:14<00:00, 472.39 uttr/s, accuracy=0.51, loss=2.09]\n", 968 | "Train: 100% 2000/2000 [05:52<00:00, 5.68 step/s, accuracy=0.62, loss=1.64, step=8000]\n", 969 | "Valid: 100% 6912/6944 [00:14<00:00, 471.86 uttr/s, accuracy=0.54, loss=1.92]\n", 970 | "Train: 100% 2000/2000 [05:52<00:00, 5.68 step/s, accuracy=0.55, loss=1.81, step=1e+4]\n", 971 | "Valid: 100% 6912/6944 [00:15<00:00, 459.10 uttr/s, accuracy=0.59, loss=1.76]\n", 972 | "Train: 1% 16/2000 [00:00<00:44, 44.19 step/s, accuracy=0.62, loss=1.55, step=1e+4]" 973 | ], 974 | "name": "stderr" 975 | }, 976 | { 977 | "output_type": "stream", 978 | "text": [ 979 | "Step 10000, best model saved. (accuracy=0.5884)\n" 980 | ], 981 | "name": "stdout" 982 | }, 983 | { 984 | "output_type": "stream", 985 | "text": [ 986 | "Train: 100% 2000/2000 [05:50<00:00, 5.71 step/s, accuracy=0.60, loss=1.37, step=12000]\n", 987 | "Valid: 100% 6912/6944 [00:16<00:00, 414.31 uttr/s, accuracy=0.61, loss=1.69]\n", 988 | "Train: 100% 2000/2000 [05:51<00:00, 5.70 step/s, accuracy=0.69, loss=1.47, step=14000]\n", 989 | "Valid: 100% 6912/6944 [00:16<00:00, 408.57 uttr/s, accuracy=0.63, loss=1.56]\n", 990 | "Train: 100% 2000/2000 [05:55<00:00, 5.62 step/s, accuracy=0.60, loss=1.49, step=16000]\n", 991 | "Valid: 100% 6912/6944 [00:17<00:00, 406.33 uttr/s, accuracy=0.64, loss=1.52]\n", 992 | "Train: 100% 2000/2000 [05:59<00:00, 5.56 step/s, accuracy=0.64, loss=1.46, step=18000]\n", 993 | "Valid: 100% 6912/6944 [00:16<00:00, 418.04 uttr/s, accuracy=0.64, loss=1.49]\n", 994 | "Train: 100% 2000/2000 [06:01<00:00, 5.53 step/s, accuracy=0.70, loss=1.23, step=2e+4]\n", 995 | "Valid: 100% 6912/6944 [00:14<00:00, 486.63 uttr/s, accuracy=0.67, loss=1.42]\n", 996 | "Train: 0% 8/2000 [00:00<00:42, 46.47 step/s, accuracy=0.72, loss=1.32, step=2e+4]" 997 | ], 998 | "name": "stderr" 999 | }, 1000 | { 1001 | "output_type": "stream", 1002 | "text": [ 1003 | "Step 20000, best model saved. 
(accuracy=0.6655)\n" 1004 | ], 1005 | "name": "stdout" 1006 | }, 1007 | { 1008 | "output_type": "stream", 1009 | "text": [ 1010 | "Train: 100% 2000/2000 [05:56<00:00, 5.61 step/s, accuracy=0.67, loss=1.07, step=22000]\n", 1011 | "Valid: 100% 6912/6944 [00:14<00:00, 488.76 uttr/s, accuracy=0.67, loss=1.39]\n", 1012 | "Train: 100% 2000/2000 [05:53<00:00, 5.66 step/s, accuracy=0.75, loss=1.18, step=24000]\n", 1013 | "Valid: 100% 6912/6944 [00:13<00:00, 496.34 uttr/s, accuracy=0.68, loss=1.37]\n", 1014 | "Train: 100% 2000/2000 [05:53<00:00, 5.66 step/s, accuracy=0.66, loss=1.19, step=26000]\n", 1015 | "Valid: 100% 6912/6944 [00:14<00:00, 475.30 uttr/s, accuracy=0.69, loss=1.34]\n", 1016 | "Train: 100% 2000/2000 [05:53<00:00, 5.66 step/s, accuracy=0.79, loss=0.86, step=28000]\n", 1017 | "Valid: 100% 6912/6944 [00:14<00:00, 482.66 uttr/s, accuracy=0.69, loss=1.30]\n", 1018 | "Train: 100% 2000/2000 [05:52<00:00, 5.68 step/s, accuracy=0.70, loss=1.13, step=3e+4]\n", 1019 | "Valid: 100% 6912/6944 [00:15<00:00, 455.44 uttr/s, accuracy=0.70, loss=1.30]\n", 1020 | "Train: 1% 15/2000 [00:00<00:43, 45.36 step/s, accuracy=0.73, loss=1.13, step=3e+4]" 1021 | ], 1022 | "name": "stderr" 1023 | }, 1024 | { 1025 | "output_type": "stream", 1026 | "text": [ 1027 | "Step 30000, best model saved. (accuracy=0.7005)\n" 1028 | ], 1029 | "name": "stdout" 1030 | }, 1031 | { 1032 | "output_type": "stream", 1033 | "text": [ 1034 | "Train: 100% 2000/2000 [05:49<00:00, 5.72 step/s, accuracy=0.79, loss=0.82, step=32000]\n", 1035 | "Valid: 100% 6912/6944 [00:16<00:00, 421.98 uttr/s, accuracy=0.70, loss=1.24]\n", 1036 | "Train: 100% 2000/2000 [05:52<00:00, 5.67 step/s, accuracy=0.80, loss=0.90, step=34000]\n", 1037 | "Valid: 100% 6912/6944 [00:16<00:00, 416.04 uttr/s, accuracy=0.71, loss=1.22]\n", 1038 | "Train: 100% 2000/2000 [05:55<00:00, 5.62 step/s, accuracy=0.71, loss=1.06, step=36000]\n", 1039 | "Valid: 100% 6912/6944 [00:16<00:00, 415.42 uttr/s, accuracy=0.72, loss=1.20]\n", 1040 | "Train: 100% 2000/2000 [05:59<00:00, 5.57 step/s, accuracy=0.84, loss=0.66, step=38000]\n", 1041 | "Valid: 100% 6912/6944 [00:16<00:00, 427.32 uttr/s, accuracy=0.72, loss=1.19]\n", 1042 | "Train: 100% 2000/2000 [06:02<00:00, 5.52 step/s, accuracy=0.84, loss=0.76, step=4e+4]\n", 1043 | "Valid: 100% 6912/6944 [00:15<00:00, 434.39 uttr/s, accuracy=0.73, loss=1.16]\n", 1044 | "Train: 1% 16/2000 [00:00<00:36, 54.71 step/s, accuracy=0.80, loss=0.85, step=4e+4]" 1045 | ], 1046 | "name": "stderr" 1047 | }, 1048 | { 1049 | "output_type": "stream", 1050 | "text": [ 1051 | "Step 40000, best model saved. 
(accuracy=0.7312)\n" 1052 | ], 1053 | "name": "stdout" 1054 | }, 1055 | { 1056 | "output_type": "stream", 1057 | "text": [ 1058 | "Train: 100% 2000/2000 [05:58<00:00, 5.58 step/s, accuracy=0.77, loss=0.85, step=42000]\n", 1059 | "Valid: 100% 6912/6944 [00:14<00:00, 472.11 uttr/s, accuracy=0.74, loss=1.15]\n", 1060 | "Train: 100% 2000/2000 [05:53<00:00, 5.65 step/s, accuracy=0.81, loss=0.81, step=44000]\n", 1061 | "Valid: 100% 6912/6944 [00:14<00:00, 489.00 uttr/s, accuracy=0.74, loss=1.11]\n", 1062 | "Train: 100% 2000/2000 [05:56<00:00, 5.61 step/s, accuracy=0.84, loss=0.63, step=46000]\n", 1063 | "Valid: 100% 6912/6944 [00:14<00:00, 476.40 uttr/s, accuracy=0.74, loss=1.11]\n", 1064 | "Train: 100% 2000/2000 [05:52<00:00, 5.68 step/s, accuracy=0.84, loss=0.70, step=48000]\n", 1065 | "Valid: 100% 6912/6944 [00:14<00:00, 463.28 uttr/s, accuracy=0.75, loss=1.10]\n", 1066 | "Train: 100% 2000/2000 [05:50<00:00, 5.71 step/s, accuracy=0.85, loss=0.56, step=5e+4]\n", 1067 | "Valid: 100% 6912/6944 [00:15<00:00, 437.16 uttr/s, accuracy=0.75, loss=1.09]\n", 1068 | "Train: 1% 16/2000 [00:00<00:43, 45.46 step/s, accuracy=0.83, loss=0.75, step=5e+4]" 1069 | ], 1070 | "name": "stderr" 1071 | }, 1072 | { 1073 | "output_type": "stream", 1074 | "text": [ 1075 | "Step 50000, best model saved. (accuracy=0.7471)\n" 1076 | ], 1077 | "name": "stdout" 1078 | }, 1079 | { 1080 | "output_type": "stream", 1081 | "text": [ 1082 | "Train: 100% 2000/2000 [05:52<00:00, 5.68 step/s, accuracy=0.83, loss=0.70, step=52000]\n", 1083 | "Valid: 100% 6912/6944 [00:15<00:00, 433.56 uttr/s, accuracy=0.75, loss=1.08]\n", 1084 | "Train: 100% 2000/2000 [06:31<00:00, 5.11 step/s, accuracy=0.81, loss=0.67, step=54000]\n", 1085 | "Valid: 100% 6912/6944 [00:20<00:00, 345.36 uttr/s, accuracy=0.74, loss=1.10]\n", 1086 | "Train: 100% 2000/2000 [06:56<00:00, 4.80 step/s, accuracy=0.74, loss=1.00, step=56000]\n", 1087 | "Valid: 100% 6912/6944 [00:18<00:00, 370.93 uttr/s, accuracy=0.75, loss=1.07]\n", 1088 | "Train: 100% 2000/2000 [06:09<00:00, 5.41 step/s, accuracy=0.77, loss=0.83, step=58000]\n", 1089 | "Valid: 100% 6912/6944 [00:15<00:00, 446.31 uttr/s, accuracy=0.75, loss=1.07]\n", 1090 | "Train: 100% 2000/2000 [05:54<00:00, 5.64 step/s, accuracy=0.84, loss=0.51, step=6e+4]\n", 1091 | "Valid: 100% 6912/6944 [00:15<00:00, 436.26 uttr/s, accuracy=0.76, loss=1.05]\n", 1092 | "Train: 1% 15/2000 [00:00<00:38, 51.84 step/s, accuracy=0.78, loss=0.83, step=6e+4]" 1093 | ], 1094 | "name": "stderr" 1095 | }, 1096 | { 1097 | "output_type": "stream", 1098 | "text": [ 1099 | "Step 60000, best model saved. 
(accuracy=0.7588)\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "Train: 100% 2000/2000 [06:00<00:00, 5.54 step/s, accuracy=0.80, loss=0.81, step=62000]\n", "Valid: 100% 6912/6944 [00:14<00:00, 486.52 uttr/s, accuracy=0.75, loss=1.07]\n", "Train: 100% 2000/2000 [05:57<00:00, 5.60 step/s, accuracy=0.88, loss=0.44, step=64000]\n", "Valid: 100% 6912/6944 [00:14<00:00, 488.02 uttr/s, accuracy=0.76, loss=1.05]\n", "Train: 100% 2000/2000 [05:54<00:00, 5.64 step/s, accuracy=0.77, loss=0.76, step=66000]\n", "Valid: 100% 6912/6944 [00:14<00:00, 483.56 uttr/s, accuracy=0.75, loss=1.03]\n", "Train: 100% 2000/2000 [05:55<00:00, 5.62 step/s, accuracy=0.81, loss=0.64, step=68000]\n", "Valid: 100% 6912/6944 [00:14<00:00, 488.28 uttr/s, accuracy=0.76, loss=1.02]\n", "Train: 100% 2000/2000 [05:51<00:00, 5.69 step/s, accuracy=0.78, loss=0.74, step=7e+4]\n", "Valid: 100% 6912/6944 [00:15<00:00, 456.34 uttr/s, accuracy=0.76, loss=1.02]\n", "Train: 0% 0/2000 [00:00

-------------------------------------------------------------------------------- /作业思路.md: --------------------------------------------------------------------------------

# Homework 5

- RNNEncoder -> TransformerEncoder
- RNNDecoder -> TransformerDecoder
- encoder_ffn_embed_dim -> 1024
- encoder_layers/decoder_layers -> 4
- hard: use back-translation
  1. train a zh-en model
  2. use the transformer architecture
  3. feed the Chinese-only data through it to get English output, giving a synthetic parallel dataset
  4. train on the synthetic dataset plus the original one (once this works, about 30 epochs are enough)
- Metric: BLEU (a precision-based measure of sentence similarity)

## tips

- Tokenize data with sub-word units: use sub-words as the smallest token, e.g. "transportation" can be split into trans+port+ation; this shrinks the total vocabulary and softens the impact of rare words
- Label smoothing regularization: give every class some weight in the target distribution, avoiding over-confident fits
- Learning rate scheduling: increase the learning rate linearly at first, then decay it with the inverse square root of the step; this makes training more stable (see the sketch after this list)
- Back-translation (BT): train a Chinese-to-English model, then push the Chinese-only corpus through it to synthesize extra training data and enlarge the dataset
  - the synthetic data's domain should match: for TED, use talk-style data
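The scheduling tip above is easy to express as a PyTorch `LambdaLR`; here is a minimal sketch of linear warm-up followed by inverse-square-root decay (`warmup_steps` and the peak learning rate are made-up values, not the homework defaults):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def inverse_sqrt_schedule(optimizer, warmup_steps=4000):
    def lr_lambda(step):
        step = max(step, 1)
        if step < warmup_steps:
            return step / warmup_steps          # linear warm-up
        return (warmup_steps / step) ** 0.5     # ~ 1/sqrt(step) decay
    return LambdaLR(optimizer, lr_lambda)

model = torch.nn.Linear(512, 512)               # toy model
optimizer = torch.optim.Adam(model.parameters(), lr=7e-4)  # peak lr
scheduler = inverse_sqrt_schedule(optimizer)
# call scheduler.step() once per training step, after optimizer.step()
```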
# Homework 7

Assignment: https://www.kaggle.com/c/ml2021-spring-hw7 [deadline: 2021.05.21]

Slides: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/hw/HW07/HW07.pdf (the link may expire)

Sample code: https://colab.research.google.com/github/ga642381/ML2021-Spring/blob/main/HW07/HW07.ipynb

Video: https://www.youtube.com/watch?v=DoZlp0vfDmI

## Overview

Take a pre-trained BERT and fine-tune it on the downstream task: extraction-based QA [given a passage and a question, the answer is a span found inside the passage].

- Dataset: Chinese reading comprehension

  (figure: image-20210507154757480)

  - train: 8k paragraphs + 27k questions
  - dev: 1k + 3.5k
  - test: 1k + 3.5k

- Goals:

  - simple: run the sample code
  - medium: learning-rate decay + tune the "doc_stride" value
  - strong: better preprocessing + switch the pre-trained model
  - boss:

- Metric: exact-match accuracy (each baseline requires a higher EM, rising from simple to boss)

- Automatic mixed precision: speeds up training [only works on some GPUs (e.g. T4, V100)]

- **Gradient accumulation**: for compute-hungry tasks (semantic segmentation, GANs, etc.) or very large inputs, the batch size often cannot exceed 2 or 4 before hitting the dreaded "CUDA OUT OF MEMORY". https://cowarder.site/2019/10/29/Gradient-Accumulation/

  In short: instead of zeroing the gradients after every batch, keep accumulating them; after a set number of batches, update the parameters and only then zero the gradients. This manually simulates a large batch size.

  Reference code:

```python
for i, (inputs, labels) in enumerate(trainloader):
    outputs = net(inputs)                  # forward pass
    loss = criterion(outputs, labels)      # compute the loss
    loss = loss / accumulation_steps       # normalize the loss
    loss.backward()                        # backward pass, accumulate gradients
    if (i+1) % accumulation_steps == 0:
        optimizer.step()                   # update parameters
        optimizer.zero_grad()              # zero the gradients
    if (i+1) % evaluation_steps == 0:
        evaluate_model()
```

## Improvements

### Faster training ✓

Only effective on certain GPUs [Colab Pro's V100]; set fp16_training to True:

```python
# Change "fp16_training" to True to support automatic mixed precision training (fp16)
fp16_training = True

if fp16_training:
    !pip install accelerate==0.2.0
    from accelerate import Accelerator
    accelerator = Accelerator(fp16=True)
    device = accelerator.device

# Documentation for the toolkit: https://huggingface.co/docs/accelerate/
```

### Changing doc_stride & preprocessing ✓

In the Dataset and Dataloader section; doc_stride is only used for the non-train splits:

```python
##### TODO: Change value of doc_stride #####
self.doc_stride = 32
```

The preprocessing change is explained in the comments; the range in `ran = random.uniform(0.2, 0.8)` can be tuned:

```python
##### TODO: Preprocessing #####
# Hint: How to prevent model from learning something it should not learn
# (i.e. answers are not always near the middle of window)
if self.split == "train":
    # Convert answer's start/end positions in paragraph_text to start/end positions in tokenized_paragraph
    answer_start_token = tokenized_paragraph.char_to_token(question["answer_start"])
    answer_end_token = tokenized_paragraph.char_to_token(question["answer_end"])

    # A single window is obtained by slicing the portion of paragraph containing the answer
    # Original approach: the window must contain the answer, as close to the middle as possible.
    # Problem: the model then over-fits to answers sitting mid-window, which need not hold at test time.
    # mid = (answer_start_token + answer_end_token) // 2
    # paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))
    # paragraph_end = paragraph_start + self.max_paragraph_len

    # New approach: place the answer at a random position inside the window
    ran = random.uniform(0.2, 0.8)
    mid = int(answer_start_token + ran * (answer_end_token - answer_start_token))
    paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))
    paragraph_end = paragraph_start + self.max_paragraph_len
```

### Fixing and improving the evaluation function ✓

Inspecting the prediction CSV shows unknown tokens [unk] and blank answers, so both cases can be handled (see the sketch below):

- blanks: when start > end, swap the two indices
- [unk]: just strip it
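A minimal sketch of the two fixes, assuming start/end indices and a Hugging Face tokenizer as in the sample code (`clean_answer` and its arguments are hypothetical names, not the notebook's):

```python
def clean_answer(start_index, end_index, paragraph_ids, tokenizer):
    # Fix 1: a blank prediction usually means start > end, so swap and retry.
    if start_index > end_index:
        start_index, end_index = end_index, start_index
    answer = tokenizer.decode(paragraph_ids[start_index : end_index + 1])
    # Fix 2: drop '[UNK]' tokens and the spaces decode() inserts between characters.
    return answer.replace("[UNK]", "").replace(" ", "")
```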
### Learning-rate decay ✓

```python
from transformers import AdamW, get_linear_schedule_with_warmup

learning_rate = 1e-4
optimizer = AdamW(model.parameters(), lr=learning_rate)
scheduler = get_linear_schedule_with_warmup(optimizer, 0, len(train_loader))

...
for epoch in range(num_epoch):
    step = 1
    ...
    for data in tqdm(train_loader):
        ##### TODO: Apply linear learning rate decay #####
        optimizer.step()
        scheduler.step()
```

Printed output:

```shell
lr: 9.412114014251782e-05
lr: 8.818289786223278e-05
lr: 8.224465558194774e-05
lr: 7.630641330166272e-05
lr: 7.036817102137767e-05
lr: 6.442992874109263e-05
lr: 5.84916864608076e-05
lr: 5.255344418052257e-05
lr: 4.661520190023753e-05
lr: 4.06769596199525e-05
lr: 3.473871733966746e-05
lr: 2.8800475059382425e-05
lr: 2.2862232779097388e-05
lr: 1.6923990498812355e-05
lr: 1.0985748218527318e-05
lr: 5.047505938242281e-06
```

### Switching models ✓

I did not dig deep here: I kept BertForQuestionAnswering and BertTokenizerFast and only changed the model name. The model-hub pages use the Auto* classes instead, which here raise `TypeError: forward() got an unexpected keyword argument`.

```python
model_name = 'xxxxxx'
model = BertForQuestionAnswering.from_pretrained(model_name).to(device)
tokenizer = BertTokenizerFast.from_pretrained(model_name)
```

### Hyperparameters

- model choice: model_name
- window size: self.doc_stride
- answer-position range: ran = random.uniform(0, 1)
- train_batch_size
- learning_rate

### Summary

I tried my best and still missed the strong baseline: used the MacBERT model, set doc_stride=8 and epoch=2, and it just would not improve, sob~

## leaderboard

See [here](./README.md)

## Misc

Some good tutorials:

- https://www.cnblogs.com/cxq1126/p/13517394.html

# Homework 8

Assignment: https://www.kaggle.com/c/ml2021spring-hw8 [deadline: 2021.05.21]

Slides: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/hw/HW08/HW08.pdf (the link may expire)

Sample code: https://colab.research.google.com/github/ga642381/ML2021-Spring/blob/main/HW08/HW08.ipynb

Video: https://youtu.be/xkpXP4byXqk

## Overview

Use an autoencoder to decide whether a test image comes from the same class/distribution as the training images.

- Dataset:

  - train: 140k face images (64x64x3)
  - test: 10k face images (labeled 0) plus 10k faces from another distribution (labeled 1)

- Goals:

  - simple: FCN autoencoder
  - medium: CNN autoencoder + a smaller model with fewer layers + a smaller batch size
  - strong: add BatchNorm + train longer
  - boss: add an extra classifier (discriminator) + add random noise to anomalous images + OCGAN

- Metric: a hard threshold on the score would have to be found by hand, so instead we want the model itself to be a sensitive detector whose scores clearly separate normal from anomalous.

  ROC-AUC score: the area under the ROC curve; higher is better, and it is normalized so the maximum is 1. Computation (slides p15-p17): sort by score, count TP and FP, normalize, then integrate the area. A sketch of scoring and evaluation follows.

- Submission format: id,score
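A minimal sketch of the usual route under these assumptions: per-image reconstruction error as the anomaly score, evaluated with sklearn's `roc_auc_score`; `model`, `test_loader`, and `labels` stand in for the homework's objects:

```python
import torch
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def anomaly_scores(model, test_loader, device):
    model.eval()
    scores = []
    for imgs in test_loader:
        imgs = imgs.to(device)
        recon = model(imgs)
        # Mean squared reconstruction error per image: high error => anomaly.
        err = ((recon - imgs) ** 2).flatten(1).mean(dim=1)
        scores.append(err.cpu())
    return torch.cat(scores)

# scores = anomaly_scores(model, test_loader, "cuda")
# print(roc_auc_score(labels, scores))   # labels: 0 = normal, 1 = anomaly
```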
## Improvements

### Models

Four models are provided: FCN, CNN, VAE, and ResNet; their parameters can all be tuned.

#### FCN

Nothing much to change.

#### CNN

- remove layers
- adjust the channels
- add BN layers

```python
class conv_autoencoder(nn.Module):
    def __init__(self):
        super(conv_autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 12, 4, stride=2, padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.Conv2d(12, 24, 4, stride=2, padding=1),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.Conv2d(24, 48, 4, stride=2, padding=1),
            nn.BatchNorm2d(48),
            nn.ReLU(),
            # nn.Conv2d(48, 96, 4, stride=2, padding=1),  # medium: remove this layer, 4x4x96
            # nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            # nn.ConvTranspose2d(96, 48, 4, stride=2, padding=1),  # medium: remove this layer
            # nn.ReLU(),
            nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),
            nn.BatchNorm2d(3),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
```

#### VAE

#### resnet

### Learning-rate decay ✓

```python
from torch.optim import lr_scheduler

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# learning-rate decay
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for epoch in qqdm_train:
    for data in train_dataloader:
        ...
        optimizer.step()
    scheduler.step()   # decay once per epoch (per batch, gamma=0.9 shrinks the lr far too fast)
    print(optimizer.param_groups[0]['lr'])
```

-------------------------------------------------------------------------------- /机器学习经典blog.md: --------------------------------------------------------------------------------

- A classic blog post on Naive Bayes: https://zhuanlan.zhihu.com/p/26262151
- A classic introduction to SVM: https://zhuanlan.zhihu.com/p/49331510
- An introduction to ensemble learning: https://www.cnblogs.com/zongfa/p/9304353.html

-------------------------------------------------------------------------------- /李宏毅机器学习.md: --------------------------------------------------------------------------------

[toc]

# Course overview

## introduction

The three steps of machine learning:

1. Define the model: a function with unknown parameters
2. Define the loss function
3. Optimize: find the best parameter values, e.g. with gradient descent

Machine learning = finding a function f(); rough categories:

- binary classification: outputs yes/no
- regression: outputs a scalar
- multi-class classification: e.g. a CNN outputting a class

How do we tell the machine which function we want?

- Supervised learning: labeled data
  - compute a loss
  - the machine automatically finds the case with the lowest loss
- Reinforcement learning: AlphaGo, supervised learning first and RL on top
- Unsupervised learning

How does the machine find the function you want?

## rule

- git
- github
- pyenv setup under Ubuntu

## gradient and error

### gradient: gradient descent

### where error comes from:

- variance: error from the spread of predictions; small training error but large test error == overfitting
  - fixes: more data; regularization, which makes the curve smoother
- bias: error from the mean prediction being off the truth; large training error == underfitting

(figure: image-20201027204600646)

# Deep learning

## Overview

How to improve accuracy?

(figure: image-20210309152535164)

- Overfitting:
  - more training data
  - data augmentation (flips, crops)
  - constrain the model: fewer parameters, parameter sharing (CNN), fewer features, early stopping, regularization, dropout
  - cross-validation: N-fold cross validation
- mismatch: the training and test sets have different distributions

## Optimization: zero gradient (critical points)

- When training reaches a point where the gradient is 0, it nearly stops
  - saddle point: zero derivative, but escapable
  - local optimum: local minima/maxima, not escapable

  (figure: image-20210309145858695)

- How to tell them apart: the Hessian, the matrix of second derivatives (a toy check follows this section)

  (figure: image-20210312155725425)

  - At a critical point, where the gradient is 0, the green term vanishes, and the red second-order term tells us whether we are at a local optimum or a saddle point
  - Quite simply, there are three cases; the figure shows the classification, and the second row of each case is the equivalent condition: check positive definiteness via the eigenvalues

  (figure: image-20210312160339857)

  - One can follow the saddle point's eigenvectors to find directions in which the loss decreases (almost nobody does this in practice)

- Empirically: saddle points are more common
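A toy check of this eigenvalue test (the loss function and the point are made up for illustration):

```python
import torch
from torch.autograd.functional import hessian

def loss_fn(theta):                        # toy scalar loss
    return theta[0] ** 2 - theta[1] ** 2  # has a saddle at the origin

theta = torch.zeros(2)                    # a critical point: gradient is 0 here
H = hessian(loss_fn, theta)               # the Hessian matrix
eigvals = torch.linalg.eigvalsh(H)        # eigenvalues of a symmetric matrix
if (eigvals > 0).all():
    print("local minimum")
elif (eigvals < 0).all():
    print("local maximum")
else:
    print("saddle point:", eigvals)       # mixed signs => saddle
```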
## Optimization: training tips

### batch

Why use batches?

- Experiments show that accuracy drops as batch_size grows. One possible explanation is that each batch gives a slightly different loss function, so when one batch gets stuck at a saddle point, another can keep training moving

  (figure: image-20210309153907587)

- Experiments also show that small batches perform better on the test set (generalization)
- Trade-off: with a larger batch size each single update takes longer, but a full epoch takes less total time

### momentum

- By analogy with momentum/inertia in physics: each gradient update direction also takes the previous update direction into account

  (figure: image-20210309155155708)

### Adaptive Learning Rate

- How should the learning rate be chosen??
- Principle: where the gradient varies gently, set a larger learning rate; where it is steep, a smaller one

  (figure: image-20210316100705186)

- Adagrad:

  - update rule:
    $$
    \boldsymbol{\theta}_{i}^{t+1} \leftarrow \boldsymbol{\theta}_{i}^{t}-\frac{\eta}{\sigma_{i}^{t}} \boldsymbol{g}_{i}^{t} \quad \sigma_{i}^{t}=\sqrt{\frac{1}{t+1} \sum_{k=0}^{t}\left(\boldsymbol{g}_{i}^{k}\right)^{2}}
    $$
  - where g is the gradient and $\eta$ the learning rate
  - intuition: where gradients are gentle, the accumulated $\sigma$ is smaller, so the effective learning rate is larger

    (figure: image-20210316101446560)

  - drawback: the weighting does not change over time, so it cannot adapt dynamically

- RMSProp (sketch after this list):

  - idea: choose the weight given to recent gradients yourself
    $$
    \begin{aligned}
    &\boldsymbol{\theta}_{i}^{1} \leftarrow \boldsymbol{\theta}_{i}^{0}-\frac{\eta}{\sigma_{i}^{0}} g_{i}^{0} \quad \sigma_{i}^{0}=\sqrt{\left(g_{i}^{0}\right)^{2}}\\
    &\boldsymbol{\theta}_{i}^{2} \leftarrow \boldsymbol{\theta}_{i}^{1}-\frac{\eta}{\sigma_{i}^{1}} g_{i}^{1} \quad \sigma_{i}^{1}=\sqrt{\alpha\left(\sigma_{i}^{0}\right)^{2}+(1-\alpha)\left(g_{i}^{1}\right)^{2}}\\
    &\boldsymbol{\theta}_{i}^{t+1} \leftarrow \boldsymbol{\theta}_{i}^{t}-\frac{\eta}{\sigma_{i}^{t}} \boldsymbol{g}_{i}^{t} \quad \sigma_{i}^{t}=\sqrt{\alpha\left(\sigma_{i}^{t-1}\right)^{2}+(1-\alpha)\left(\boldsymbol{g}_{i}^{t}\right)^{2}}
    \end{aligned}
    $$

- Adam: RMSProp + Momentum
- Learning-rate decay
- Warm up: learning rate first increases, then decreases (ResNet, Transformer)
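A minimal sketch of the RMSProp update written straight from the formula above, with toy numbers:

```python
import torch

def rmsprop_step(theta, grad, sigma, lr=1e-3, alpha=0.9, eps=1e-8):
    # sigma is the root of an exponentially weighted average of squared gradients
    sigma = torch.sqrt(alpha * sigma**2 + (1 - alpha) * grad**2)
    theta = theta - lr / (sigma + eps) * grad   # per-parameter adaptive step
    return theta, sigma

theta = torch.tensor([1.0, -2.0])
sigma = torch.zeros_like(theta)
for _ in range(3):
    grad = 2 * theta                # gradient of f(theta) = ||theta||^2
    theta, sigma = rmsprop_step(theta, grad, sigma)
```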
### Optimization summary

(figure: image-20210316103252293)

## Classification (short version)

- Done as regression: introduce one-hot vectors, so each class becomes one regression output
- What classification does differently:

  (figure: image-20210316103847177)

- Loss function: almost always cross-entropy; MSE also works but is harder to optimize (the gradient is nearly flat when the prediction is far from the target)
- PS: in PyTorch, nn.CrossEntropyLoss() already applies softmax internally, so do not add an extra softmax layer

# CNN & Self-Attention

## CNN

- Setting: inputs all the same size, output is one-hot
- Existing solution: flatten all the image's pixels into features and feed them to a DNN
- Observation 1: it is enough to find patterns in the image (my reading: local features) and extract them

  (figure: image-20210316145529979)

  This gives the simplification: consider receptive fields, e.g. start with any 3x3x3 block

  Basic concepts: kernel size, channel, stride, padding

- Observation 2: the same pattern can appear at different positions in different images

  Simplification: share parameters (identical filters)

- Convolutional layer:

  - uses filters to capture patterns in the image
  - the result of passing an image through a filter is called a feature map
  - the feature map has as many channels as there are filters
  - effect of stacking conv layers: as in the figure below, suppose the top matrix (the original image) goes through a 3x3 kernel to give the bottom matrix; applying another convolution then covers a larger region of the original image (the blue box). In other words, the deeper the layer, the wider the context it sees

  (figure: image-20210316153948211)

- Observation 3:

  - subsampling: shrink the image, e.g. keep every s-th pixel to form a new image

- The whole pipeline:

  (figure: image-20210316155023358)

- Application: playing Go

## self-attention

### intro

- So far all inputs had the same length; what if the input sequences have different lengths?
- Example: the input sequence "this is a cat"
  - one-hot representation: one dimension per word, but it captures no relations between words
  - word embedding: one vector per word (carrying semantics), with similar words clustering [a sentence is then a variable-length sequence of vectors]
- More examples: audio, graph structures (social networks), molecular structures
- Output settings:
  - one label per vector [sequence labeling]: POS tagging, frame-level sound classification (HW2), social-network graphs
  - one label for the whole sequence: sentiment analysis, speech/speaker recognition, deciding what a molecule is
  - the model decides the output length [seq2seq] (HW5)

### Self-attention

https://www.youtube.com/watch?v=hYdO9CscNes

#### The old approach

Consider the case where input and output lengths match [sequence labeling]:

- Old approach: run an FC layer on each vector separately and judge each output independently
- But the vectors in a sequence are related, so context must be considered
- So the FC layer can be fed the current vector together with its neighbors, as in the figure

  (figure: image-20210326152046725)

Problem: sequence lengths vary, and covering the whole sequence would need a huge number of parameters.

#### Enter self-attention

(figure: image-20210318151209129)

- Self-attention can be interleaved/stacked with other layers (used repeatedly)
- Internal structure (a minimal implementation follows this section):

  - a1~a4 may be the input layer or a hidden layer

    (figure: image-20210318151415172)

  - so how does b1 account for the vectors relevant to a1?
  - estimating the degree of relevance $\alpha$: $\alpha$ measures how strongly two positions are related (commonly the dot product of a query and a key)

    (figure: image-20210318151807291)

  - concretely:

    - compute the relevance between a1 and a2~a4

      (figure: image-20210318151949157)

    - usually a vector's relevance to itself is computed too (worth experimenting)
    - normalize with softmax

      (figure: image-20210318152052851)

    - extract the important information weighted by the attention scores: the larger the score, the larger that vector's share in the final result

      (figure: image-20210318152339780)
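A minimal single-head implementation of the walkthrough above (toy shapes; the 1/sqrt(d) scaling is common practice, and the lecture's plain dot product works too):

```python
import torch
import torch.nn.functional as F

def self_attention(a, Wq, Wk, Wv):
    # a: (seq_len, d_in); Wq/Wk/Wv: (d_in, d_k) projection matrices
    q, k, v = a @ Wq, a @ Wk, a @ Wv
    alpha = q @ k.T                                # attention scores, self included
    alpha = F.softmax(alpha / k.shape[-1] ** 0.5, dim=-1)  # normalize per row
    return alpha @ v                               # each b_i is a weighted sum of v

a = torch.randn(4, 8)                              # four input vectors a1..a4
Wq, Wk, Wv = (torch.randn(8, 8) for _ in range(3))
b = self_attention(a, Wq, Wk, Wv)                  # shape (4, 8)
```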
# Theory of ML

# Transformer

## Normalization

### Batch Normalization: training

Useful for HW3 (CNN).

- Problem: when different input features take values on very different scales, each weight w's effect on the loss also swings between large and small

  How do we put the inputs into one common range?

  (figure: image-20210330145852476)

- Method: feature normalization

  Standardize each dimension across the different examples [standardization]

  (figure: image-20210330150249812)

- Detail: the normalization can go either before or after the activation function
- Adding BN: introduces the mean and variance (both vectors), so there are more parameters; usually computed once per batch, hence "batch normalization"

### Testing

- Use a moving average to estimate the mean and variance; PyTorch computes this automatically

  (figure: image-20210330151521760)

- Why does BN work?

  See the papers; apparently it was something of an accidental discovery lol

## Transformer

### Overview

- A seq2seq model: the model decides the output length
- Applications:

  - speech recognition [speech to text]
  - machine translation [text to text]
  - speech translation [speech to another language's text] (some languages have no written form)
  - speech synthesis [text to speech]
  - chatbots [input -> seq2seq -> reply]
  - QA [question & context -> seq2seq -> answer]
  - grammar parsing ["hard train it"; paper: Grammar as a foreign language]
  - multi-label classification [the model decides the number of labels]
  - object detection [https://arxiv.org/abs/2005.12872]
- seq2seq origin: https://arxiv.org/abs/1409.3215

### encoder

- The encoder part == BERT's architecture:

  (figure: image-20210416173444271)

- Papers on architecture variants:
  - https://arxiv.org/abs/2002.04745 [On Layer Normalization in the Transformer Architecture]
  - https://arxiv.org/abs/2003.07845 [PowerNorm: Rethinking Batch Normalization in Transformers]

### decoder

- Decoder architecture:

  (figure: image-20210418162535707)

  First give the decoder a Begin-Of-Sentence token (a one-hot vector you design); the decoder then outputs a sequence vector, softmax picks the most likely prediction, and that output becomes the next input

- Structure:

  (figure: image-20210418165505367)

- masked self-attention: attention looks only at the earlier part of the sequence; as in the figure, producing b2 attends only to a1 and a2 (a sketch follows this section)

  (figure: image-20210418165425721)

- How do we make it stop generating?

  Add an END token

  (figure: image-20210418165816847)

- AT (autoregressive) vs NAT:

  (figure: image-20210418170104702)

  - NAT advantages: parallel decoding and a controllable output length, but it usually performs worse than AT

- The hand-off part: cross attention

  (figure: image-20210418170555481)

- Cross attention need not draw only from the encoder's last layer; see https://arxiv.org/abs/2005.08081
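A minimal sketch of the masked (causal) self-attention described in the decoder section (toy shapes):

```python
import torch
import torch.nn.functional as F

def masked_self_attention(q, k, v):
    # q, k, v: (seq_len, d)
    scores = q @ k.T / k.shape[-1] ** 0.5
    causal = torch.tril(torch.ones(len(q), len(q), dtype=torch.bool))
    scores = scores.masked_fill(~causal, float("-inf"))  # hide future positions
    return F.softmax(scores, dim=-1) @ v

x = torch.randn(5, 16)
out = masked_self_attention(x, x, x)   # row 1 (b2) only attends to a1 and a2
```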
### Training

- Loss function: minimize cross-entropy
- teacher forcing: feed the ground truth as the decoder's input

  (figure: image-20210418172011357)

### tips

- Copy mechanism:

  - chatbots (directly copy what the model does not understand), summarization
  - to dig deeper:

  (figure: image-20210418172313805)

- guided attention: force the attention to learn the expected pattern [monotonic attention, location-aware attention]
- beam search

  (figure: image-20210418173244744)

- The original work evaluates with the BLEU score rather than cross-entropy; higher BLEU is better, and it measures the distance between sentences. But BLEU is not differentiable, so just "hard train it" with reinforcement learning

  (figure: image-20210418173927797)

- exposure bias: one early mistake cascades into more mistakes; what to do?

  Add some erroneous tokens already during training

  (figure: image-20210418174056060)

  Solution: scheduled sampling

  (figure: image-20210418174132556)

# GAN

## Basics

- Use a network as a generator: input x and z, where z is drawn from a simple distribution (Gaussian, etc.), so y is itself a distribution

  (figure: image-20210421165116091)

- Why does the output need to be a distribution?

  Example: when predicting the next frames of a Fall Guys video, an ordinary network trained on characters that sometimes turn left and sometimes turn right predicts a character turning both ways at once

  (figure: image-20210421165723286)

  If the output is a distribution, the model can instead predict probabilities

  (figure: image-20210421165851367)

- When to use it: when the task needs creativity (the same input has several correct outputs)

  - drawing [e.g. "red eyes": Kaguya, Kurapika]
  - chatbots ["who is Kaguya?": a Naruto backstory or a Hunter x Hunter sequel are both valid answers]

- GANs among generative models, via anime-face generation [unconditional generation: drop x for now]

  - the input z is a low-dimensional vector, the output y a high-dimensional one (an image)
  - the generator's job: find a way to map the simple distribution onto the high-dimensional vectors

    (figure: image-20210421170658678)

  - Discriminator: a neural network (e.g. a CNN) that takes an image and outputs a scalar representing how real the image looks

    (figure: image-20210421170910325)

- The basic idea of GANs:

  - the generator tries to fool the discriminator's checks, hence "adversarial": the two are opponents. Think of the generator as a counterfeiter and the discriminator as the police; both get better through the contest

    (figure: image-20210421171651651)

  - Algorithm: first initialize the Generator and Discriminator

    - step 1: fix G, update D. D can use regression or classification to learn the difference between real and generated data

      (figure: image-20210421172245752)

    - step 2: fix D, train G. D acts as the scoring criterion, and G trains to score as high as possible. The network merges the parts in the yellow boxes, with the intermediate image treated as a hidden layer

      (figure: image-20210421172809053)

    - repeat both steps

- Existing applications:

  - Progressive GAN: photorealistic faces
  - StyleGAN: anime faces, https://www.gwern.net/Faces
  - BigGAN

## Theory

### Basic theory

- Take 1-D distributions as an example: P_G is the generated distribution, P_data the real data, and the divergence is some distance between P_G and P_data

  (figure: image-20210421213553242)

- How to compute the divergence: just sample, from the image library and from the generator
- discriminator: give P_data high scores and P_G low scores, then train the discriminator to score well. The objective V is shown in the figure

  (figure: image-20210421221545311)

  Rearranged, finding G* in fact reduces to finding D*; substituting the max over D back in gives:

  (figure: image-20210422150246957)

### tips: WGAN

- Problem: P_G and P_data overlap very little [both are only samples, so the true overlap is unknown]

  When the two distributions do not overlap, the JS divergence == log2, so it tends to give the same score to all bad generators

  (figure: image-20210422151804599)

- The Wasserstein distance is the average minimal cost of moving distribution P to distribution Q, as in the figure above

  WGAN simply measures the P_G-to-P_data distance with the W distance instead of JS divergence. The formula is in the figure, where D must be sufficiently smooth (in the figure's notation, x stands for the output y)

  (figure: image-20210422152302248)

  If P_data and P_G do not overlap, the objective pushes the first D(x) term toward infinity and the second toward negative infinity, so it never converges. Hence D must be smooth enough:

  (figure: image-20210422152447577)

  - clip the weights to ±c (the original paper)
  - gradient penalty (https://arxiv.org/abs/1704.00028), sketched below
  - SNGAN (https://arxiv.org/abs/1802.05957)
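A minimal sketch of the gradient-penalty idea from the WGAN-GP paper cited above, assuming 4-D image batches and a critic network:

```python
import torch

def gradient_penalty(critic, real, fake, lam=10.0):
    # Interpolate between real and fake samples.
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    # Gradient of the critic's output w.r.t. the interpolated inputs.
    grads, = torch.autograd.grad(scores.sum(), x_hat, create_graph=True)
    norms = grads.flatten(1).norm(2, dim=1)
    return lam * ((norms - 1) ** 2).mean()   # push the gradient norm toward 1
```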
## Evaluating generators, and conditional generation

- A GAN problem: G and D easily degrade together; they need to be evenly matched opponents
- The hardest case is generating text: G produces text and D judges its distance from real text. The gradient cannot be computed: when the decoder's parameters change slightly, the token chosen by the final max does not change, so nothing changes from D's point of view

  (figure: image-20210422154956514)

  Reinforcement learning can work around this, but RL itself is very hard to train

- Until ScratchGAN (https://arxiv.org/abs/1905.09922)
- Elective material:

  (figure: image-20210422155316232) (figure: image-20210422155331319)

- A possible alternative: pair random vectors with images and "hard train" with supervised learning

  (figure: image-20210422155928340)

  References: https://arxiv.org/abs/1707.05776, https://arxiv.org/abs/2007.02798

- How do we judge the quality of generated images?

  - feed the image to an image classifier: the more concentrated the predicted class distribution, the better

    (figure: image-20210422160508163)

  - diversity: the flatter the average distribution over many images, the better

    (figure: image-20210422161205431)

  - the Inception Score (IS) criterion: high quality (each image has a sharp class) and high diversity (the aggregate is flat)
  - the FID criterion: the distance between the vectors images produce in a CNN (taken before softmax)

    (figure: image-20210422161931130)

- Remaining problems

  - mode collapse: during training, generation may cluster around a single real data point, so the same few images come out again and again

    (figure: image-20210422160743237)

    Workaround: stop training before the problem hits and keep the previous model

  - mode dropping (harder to detect): generation may cover only part of the real data, e.g. in the figure every face at step t is pale and every face at t+1 is yellow-tinted

    (figure: image-20210422160957357)

  - another problem: G may produce images identical to the real ones; then what was the point of the GAN?

    (figure: image-20210422171112211)

## conditional GAN

Add a condition x; sampling from z then gives a conditional distribution for y

(figure: image-20210424172112068)

- The network structure must change accordingly: the discriminator judges not only whether the generated image is realistic, but also whether it matches the input condition x. Correspondingly, labels must come in pairs.

  (figure: image-20210424172720756)

- Beyond producing images from text, images can be produced from images (black-and-white to color, sketch to photo, dehazing: the pix2pix family of applications)

  - with supervised learning alone, the results can be very blurry
  - so use GAN + supervised learning (https://arxiv.org/abs/1611.07004)

  (figure: image-20210424173112520)

- Surprising applications:

  - sound -> image

    (figure: image-20210424173306104)

  - animating a still photo (https://arxiv.org/abs/1905.08233)

## Cycle GAN

- Unsupervised learning: no labels available for the samples [e.g. style transfer: turn real photos into anime style]
- The problem: with the setup in the figure below, D only teaches G to produce anime-looking images that may have nothing to do with the input.

  (figure: image-20210425170250536)

- Cycle GAN: add a second G that reconstructs the input, keeping the output as close to the original image as possible (a sketch of this cycle-consistency loss follows this section)

  (figure: image-20210425171247571)

  In theory G could learn some other transformation, such as a left-right flip, but in practice it makes little difference

- Bidirectional Cycle GAN:

  (figure: image-20210425171554356)

- Other GANs (the first three share the same idea):

  - Disco GAN: https://arxiv.org/abs/1703.05192
  - Dual GAN: https://arxiv.org/abs/1704.02510
  - Cycle GAN: https://arxiv.org/abs/1703.10593
  - StarGAN: https://arxiv.org/abs/1711.09020
  - demo: http://selfie2anime.com/

- Text style transfer: negative -> positive

  (figure: image-20210425172148679)

- Other applications:
  - generating summary sentences: https://arxiv.org/abs/1810.02851
  - unsupervised translation: https://arxiv.org/abs/1710.04087, https://arxiv.org/abs/11041
  - unsupervised speech recognition: https://arxiv.org/abs/1804.00316, https://arxiv.org/abs/1812.09323, https://arxiv.org/abs/1904.04100
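A minimal sketch of the cycle-consistency loss, where G: X->Y and F: Y->X are assumed generator networks and x, y are batches from the two domains:

```python
import torch
import torch.nn.functional as F_nn

def cycle_loss(G, F, x, y, lam=10.0):
    # x -> G(x) -> F(G(x)) should come back to x, and symmetrically for y.
    loss_x = F_nn.l1_loss(F(G(x)), x)
    loss_y = F_nn.l1_loss(G(F(y)), y)
    return lam * (loss_x + loss_y)   # added to the usual adversarial losses
```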
# Self-supervised learning

## Overview

- The Sesame Street family (lmao, the self-supervised model family really is named this way)

  (figure: image-20210425204148286)

- Models keep getting bigger... besides the figure there are also GPT-3 and Switch Transformer

  (figure: image-20210425205040915)

## BERT basics

- self-supervised vs unsupervised:
  - self-supervised: make your own labels
  - unsupervised: no labels at all
  - self-supervision is one kind of unsupervised method
  - e.g. as below, split x into two parts, x' and x''; the former is the input and the latter provides the label

  (figure: image-20210426155345732)

- masking input: randomly mask (cover) input tokens, or randomly replace them with another token

  (figure: image-20210426160707409)

- next sentence prediction: train **[CLS]** to say whether the sentences are connected (whether sentence 1 and sentence 2 are adjacent); this trick may not be very useful

  (figure: image-20210426161213448)

- What BERT learns: how to fill in blanks, which can then serve all kinds of downstream tasks (via fine-tuning). [HW7]

  (figure: image-20210426191844194)

- Benchmark: [GLUE](https://gluebenchmark.com/), with the Chinese version [CLUE](https://cluebenchmarks.com/). The GLUE score is the average over 9 tasks

  (figure: image-20210426192608975)

- BERT use case 1: sequence in, class out [sentiment analysis]. The BERT part is pre-trained and the Linear+softmax is randomly initialized; the whole structure, shown below, is then trained as a semi-supervised system.

  (figure: image-20210426193002439)

  https://arxiv.org/abs/1908.05620 shows pre-training beats random initialization

  (figure: image-20210426193158475)

- BERT use case 2: sequence in, sequence of the same length out [POS tagging]
- BERT use case 3: two sequences in, class out [natural language inference: take two sentences and infer agreement or contradiction]
- BERT use case 4: extraction-based QA [HW7]: the output is two integers s and e, and the answer is the span between the s-th and e-th tokens

  (figure: image-20210426200251921)

- BERT's input length is not unlimited; long inputs can be cut up before being fed in
- For seq2seq pre-training, corrupt the encoder's input (rotation, deletion, shuffling, etc.) and train the decoder's output to match the original input

## Why BERT works

- Because BERT == the encoder, the two 果 tokens fed in on each side attend to their contexts, so the resulting embedding vectors differ; compare them with cosine similarity

  (figure: image-20210426210304600)

- So BERT's output vector can be taken to represent that token's meaning, and masking has the intuitive reading: use the context to decide what the word should be (fill in the blank)

- Electives:

  (figure: image-20210426211404023)

- Multi-lingual BERT: a BERT trained on many languages, doing cloze in 104 of them.

  Experiments found that after training QA in English, it can do QA in Chinese

  Here "pre-train" means the fill-in-the-blank training

  (figure: image-20210426211715064)

- A fun phenomenon: train BERT separately on Chinese and English, compute the difference between their average embeddings, then add that difference inside an English task and it turns into the Chinese task

  (figure: image-20210426213140462)

  This enables a kind of unsupervised translation

## GPT-3

Recap: BERT does cloze; GPT is a model that predicts the continuation [generative ability]. Mascot: a unicorn.

- Basic structure: masked attention, then predict what the next token is

  (figure: image-20210506104256839)

- in-context learning: given a task description and a few examples, predict the answer

  (figure: image-20210506105728380)

- Electives:

  (figure: image-20210506105922818)

- Further reading: a summary of self-supervised learning's applications across domains

  (figure: image-20210506110126950)

  - SimCLR [https://arxiv.org/abs/2002.05709]: visual representation learning
  - BYOL [https://arxiv.org/abs/2006.07733]: a new approach to self-supervised learning

## Auto-Encoders

Recap: self-supervised learning is also called pre-training and needs no labeled data; BERT can do cloze and GPT can complete sentences, and with a little fine-tuning they handle downstream tasks

(figure: image-20210506111412516)

- Auto-encoder (reconstruction): much like Cycle GAN. The encoder turns an image into a low-dimensional vector [called the embedding, code, or representation], the decoder turns the vector back into an image, and the decoder's output must stay as close as possible to the original.

  In other words, compress the high-dimensional vector into a low-dimensional middle vector, then use this dimension reduction's inverse to produce the new image

  (figure: image-20210506152423892)

- Electives: PCA and other dimension-reduction techniques

  (figure: image-20210506152548518)

- Why do auto-encoders work?

  One explanation is that however the original images vary, the variation spans only a handful of factors, so the encoder can reduce the complicated to the simple

- de-noising auto-encoder [https://dl.acm.org/doi/10.1145/1390156.1390294]: add noise to the image before it enters the encoder

  (figure: image-20210506153601073)

  - BERT can also be viewed as a de-noising auto-encoder (the mask is the noise)

  (figure: image-20210506153901903)

## Other auto-encoder applications

- Auto-encoders are usable in CV, NLP, and speech
- Feature disentanglement: work out what each dimension of the middle code means; e.g. for speech, the first 50 dimensions carry the content and the last 50 the speaker's identity

  (figure: image-20210506154337814)

  - Application: voice conversion (a Conan-style voice changer): keep the content part of the code and swap in another person's speaker part, converting the timbre to someone else's voice

    (figure: image-20210506161158597)

- Discrete latent representation: force the middle vector to be binary, or one-hot (making it learn categories by itself) [i.e. the middle vector becomes discrete, and is still learned]

  (figure: image-20210506214657219)

  - VQVAE: learn a codebook; the decoder consumes the codebook vector most similar to the encoder's output, with a fixed number of vectors in the codebook. On speech, the codebook can learn basic phoneme-like units

    (figure: image-20210507150902646)

- Text as representation: given a pile of documents, the network can learn summarization by itself, but its summaries may be readable only to itself; so bring in a GAN-style discriminator to push the middle text toward human-written language [the Cycle GAN idea]

  (figure: image-20210507151343453)

- More applications:

  - take the decoder out of the auto-encoder and use it as a generator

    (figure: image-20210507151651643)

  - compression and decompression (lossy)
  - Anomaly detection: decide whether new data resembles the original data (normal or anomalous)

    (figure: image-20210507152115058)

    Useful for fraud detection (credit-card transactions), network intrusion detection, cancer-cell detection. (Is this binary classification? Not really: training sets have plenty of normal data and very few anomalies, hence the name "one-class")

    (figure: image-20210507152337617)

    Approach: train on normal images; at test time, if an anomalous image's reconstruction differs a lot from the input, judge it an anomaly

    (figure: image-20210507152707163)

- Electives:

  (figure: image-20210507152735179)

# Explainable AI / adversarial attacks

# Domain Adaptation

## Overview

- The domain shift problem: training and test sets have different distributions [e.g. a classifier for black-and-white digits does poorly on colored digit images]
- domain adaptation: akin to transfer learning; use a model trained on dataset A in scenario B
- Kinds of shift:

  - the input distributions differ [:star:]
  - the output distributions differ [e.g. every digit is equally likely in training, while some digits dominate at test time]
  - the input-output relation differs

  (figure: image-20210607212712004)

  Here we focus only on the first case

## Basic idea

Cases:

- when target-domain data is plentiful and labeled, just train on it directly
- when target-domain data is scarce but labeled, use it to fine-tune the source-domain model [like BERT]; beware of overfitting
- when target-domain data is plentiful but unlabeled [:star:]

Basic approach: train a network that strips away the differences (such as color information) and extracts features with the same distribution for both domains

(figure: image-20210607214340388)

Domain Adversarial Training (a minimal sketch of its gradient-reversal trick closes these notes):

- the domain classifier is a binary classifier, much like the discriminator in a GAN

  (figure: image-20210607215353480)

Refinement 1: the more concentrated (confident) the predictions, the better

(figure: image-20210607215901782)

Refinement 2: how to handle the training and test sets having different label sets

(figure: image-20210607220327372)

Going further: if there is no test-domain data at all, use Domain Generalization

- case 1: multiple training domains; learn the differences between domains
- case 2: multiple test domains; treat it like data augmentation

(figure: image-20210607220824914)

# Privacy vs. ML

# RL

# Quantum ML

# Life-Long/Compression

# Meta Learning
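A minimal sketch of the gradient-reversal trick behind Domain Adversarial Training (toy shapes; the module names are hypothetical):

```python
import torch

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad):
        return -grad   # flip the gradient so the extractor fools the domain classifier

feat = torch.nn.Linear(784, 128)             # feature extractor
label_clf = torch.nn.Linear(128, 10)         # label predictor (source labels only)
domain_clf = torch.nn.Linear(128, 2)         # domain classifier (source vs target)

x = torch.randn(16, 784)
h = torch.relu(feat(x))
y_label = label_clf(h)                       # trained with the usual label loss
y_domain = domain_clf(GradReverse.apply(h))  # adversarial domain loss flows back reversed
```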