├── MAML-miniImageNet-5-way-1-shot.ipynb ├── MAML-miniImageNet-5-way-5-shot-.ipynb ├── MAML-omniglot-ADAM-20way-1-shot.ipynb ├── MAML-omniglot-ADAM-20way-5shot-16batch-.ipynb ├── MAML-omniglot-ADAM-5way-32batch.ipynb ├── MAML-v1.ipynb ├── Preprocess.ipynb └── README.md /MAML-omniglot-ADAM-20way-1-shot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2020-04-14T03:22:31.272106Z", 9 | "start_time": "2020-04-14T03:22:30.955368Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import torch\n", 15 | "import numpy as np\n", 16 | "import os\n", 17 | "import zipfile\n", 18 | "\n", 19 | "# root_path = './../datasets'\n", 20 | "# processed_folder = os.path.join(root_path)\n", 21 | "\n", 22 | "# zip_ref = zipfile.ZipFile(os.path.join(root_path,'omniglot.zip'), 'r')\n", 23 | "# zip_ref.extractall(root_path)\n", 24 | "# zip_ref.close()\n", 25 | "root_dir = './../datasets/omniglot/python'\n", 26 | "root_dir_train = os.path.join(root_dir,'images_background')" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": { 33 | "ExecuteTime": { 34 | "end_time": "2020-04-07T06:55:03.418079Z", 35 | "start_time": "2020-04-07T06:55:03.346124Z" 36 | } 37 | }, 38 | "outputs": [], 39 | "source": [] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "### 数据预处理\n", 46 | "拿到原始数据之后先将下面的代码取消注释,进行数据预处理。" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": { 53 | "ExecuteTime": { 54 | "end_time": "2020-04-13T07:07:01.042611Z", 55 | "start_time": "2020-04-13T07:07:01.035061Z" 56 | }, 57 | "scrolled": true 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "\n", 62 | "# # 数据预处理\n", 63 | "# import torchvision.transforms as transforms\n", 64 | "# from PIL import Image\n", 65 | "\n", 66 | "# '''\n", 67 | "# an 
example of img_items:\n", 68 | "# ( '0709_17.png',\n", 69 | "# 'Alphabet_of_the_Magi/character01',\n", 70 | "# './../datasets/omniglot/python/images_background/Alphabet_of_the_Magi/character01')\n", 71 | "# '''\n", 72 | "\n", 73 | "\n", 74 | "# root_dir_train = os.path.join(root_dir, 'images_background')\n", 75 | "# root_dir_test = os.path.join(root_dir, 'images_evaluation')\n", 76 | "\n", 77 | "# def find_classes(root_dir_train):\n", 78 | "# img_items = []\n", 79 | "# for (root, dirs, files) in os.walk(root_dir_train): \n", 80 | "# for file in files:\n", 81 | "# if (file.endswith(\"png\")):\n", 82 | "# r = root.split('/')\n", 83 | "# img_items.append((file, r[-2] + \"/\" + r[-1], root))\n", 84 | "# print(\"== Found %d items \" % len(img_items))\n", 85 | "# return img_items\n", 86 | "\n", 87 | "# ## 构建一个词典{class:idx}\n", 88 | "# def index_classes(items):\n", 89 | "# class_idx = {}\n", 90 | "# count = 0\n", 91 | "# for item in items:\n", 92 | "# if item[1] not in class_idx:\n", 93 | "# class_idx[item[1]] = count\n", 94 | "# count += 1\n", 95 | "# print('== Found {} classes'.format(len(class_idx)))\n", 96 | "# return class_idx\n", 97 | " \n", 98 | "\n", 99 | "# img_items_train = find_classes(root_dir_train) # [(file1, label1, root1),..]\n", 100 | "# img_items_test = find_classes(root_dir_test)\n", 101 | "\n", 102 | "# class_idx_train = index_classes(img_items_train)\n", 103 | "# class_idx_test = index_classes(img_items_test)\n", 104 | "\n", 105 | "\n", 106 | "# def generate_temp(img_items,class_idx):\n", 107 | "# temp = dict()\n", 108 | "# for imgname, classes, dirs in img_items:\n", 109 | "# img = '{}/{}'.format(dirs, imgname)\n", 110 | "# label = class_idx[classes]\n", 111 | "# transform = transforms.Compose([lambda img: Image.open(img).convert('L'),\n", 112 | "# lambda img: img.resize((28,28)),\n", 113 | "# lambda img: np.reshape(img, (28,28,1)),\n", 114 | "# lambda img: np.transpose(img, [2,0,1]),\n", 115 | "# lambda img: img/255.\n", 116 | "# ])\n", 117 | "# img 
= transform(img)\n", 118 | "# if label in temp.keys():\n", 119 | "# temp[label].append(img)\n", 120 | "# else:\n", 121 | "# temp[label] = [img]\n", 122 | "# print('begin to generate omniglot.npy')\n", 123 | "# return temp\n", 124 | "# ## 每个字符包含20个样本\n", 125 | "\n", 126 | "# temp_train = generate_temp(img_items_train, class_idx_train)\n", 127 | "# temp_test = generate_temp(img_items_test, class_idx_test)\n", 128 | "\n", 129 | "# img_list = []\n", 130 | "# for label, imgs in temp_train.items():\n", 131 | "# img_list.append(np.array(imgs))\n", 132 | "# img_list = np.array(img_list).astype(np.float) # [[20 imgs],..., 1623 classes in total]\n", 133 | "# print('data shape:{}'.format(img_list.shape)) # (964, 20, 1, 28, 28)\n", 134 | "# np.save(os.path.join(root_dir, 'omniglot_train.npy'), img_list)\n", 135 | "# print('end.')\n", 136 | "\n", 137 | "\n", 138 | "# img_list = []\n", 139 | "# for label, imgs in temp_test.items():\n", 140 | "# img_list.append(np.array(imgs))\n", 141 | "# img_list = np.array(img_list).astype(np.float) # [[20 imgs],..., 1623 classes in total]\n", 142 | "# print('data shape:{}'.format(img_list.shape)) # (659, 20, 1, 28, 28)\n", 143 | "\n", 144 | "# np.save(os.path.join(root_dir, 'omniglot_test.npy'), img_list)\n", 145 | "# print('end.')" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "ExecuteTime": { 153 | "end_time": "2020-04-07T08:02:03.283025Z", 154 | "start_time": "2020-04-07T08:02:03.276106Z" 155 | } 156 | }, 157 | "outputs": [], 158 | "source": [] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 2, 163 | "metadata": { 164 | "ExecuteTime": { 165 | "end_time": "2020-04-14T03:23:43.874491Z", 166 | "start_time": "2020-04-14T03:23:41.975288Z" 167 | } 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "img_list_train = np.load(os.path.join(root_dir, 'omniglot_train.npy')) # (964, 20, 1, 28, 28)\n", 172 | "img_list_test = np.load(os.path.join(root_dir, 
'omniglot_test.npy')) # (659, 20, 1, 28, 28)\n", 173 | "\n", 174 | "x_train = img_list_train\n", 175 | "x_test = img_list_test\n", 176 | "# num_classes = img_list.shape[0]\n", 177 | "datasets = {'train': x_train, 'test': x_test}" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 3, 197 | "metadata": { 198 | "ExecuteTime": { 199 | "end_time": "2020-04-14T03:23:51.352052Z", 200 | "start_time": "2020-04-14T03:23:50.310935Z" 201 | }, 202 | "code_folding": [] 203 | }, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "DB: train (964, 20, 1, 28, 28) test (659, 20, 1, 28, 28)\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "### 准备数据迭代器\n", 215 | "n_way = 20\n", 216 | "k_spt = 1 ## support data 的个数\n", 217 | "k_query = 15 ## query data 的个数\n", 218 | "imgsz = 28\n", 219 | "resize = imgsz\n", 220 | "task_num = 16\n", 221 | "batch_size = task_num\n", 222 | "\n", 223 | "indexes = {\"train\": 0, \"test\": 0}\n", 224 | "datasets = {\"train\": x_train, \"test\": x_test}\n", 225 | "print(\"DB: train\", x_train.shape, \"test\", x_test.shape)\n", 226 | "\n", 227 | "\n", 228 | "def load_data_cache(dataset):\n", 229 | " \"\"\"\n", 230 | " Collects several batches data for N-shot learning\n", 231 | " :param dataset: [cls_num, 20, 84, 84, 1]\n", 232 | " :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks\n", 233 | " \"\"\"\n", 234 | " # take 5 way 1 shot as example: 5 * 1\n", 235 | " setsz = k_spt * n_way\n", 236 | " querysz = k_query * n_way\n", 237 | " data_cache = []\n", 238 | "\n", 239 | " # print('preload next 10 caches of batch_size of batch.')\n", 240 | " for sample in 
range(10): # num of epochs\n", 241 | "\n", 242 | " x_spts, y_spts, x_qrys, y_qrys = [], [], [], []\n", 243 | " for i in range(batch_size): # one batch means one set\n", 244 | "\n", 245 | " x_spt, y_spt, x_qry, y_qry = [], [], [], []\n", 246 | " selected_cls = np.random.choice(dataset.shape[0], n_way, replace = False) \n", 247 | "\n", 248 | " for j, cur_class in enumerate(selected_cls):\n", 249 | "\n", 250 | " selected_img = np.random.choice(20, k_spt + k_query, replace = False)\n", 251 | "\n", 252 | " # 构造support集和query集\n", 253 | " x_spt.append(dataset[cur_class][selected_img[:k_spt]])\n", 254 | " x_qry.append(dataset[cur_class][selected_img[k_spt:]])\n", 255 | " y_spt.append([j for _ in range(k_spt)])\n", 256 | " y_qry.append([j for _ in range(k_query)])\n", 257 | "\n", 258 | " # shuffle inside a batch\n", 259 | " perm = np.random.permutation(n_way * k_spt)\n", 260 | " x_spt = np.array(x_spt).reshape(n_way * k_spt, 1, resize, resize)[perm]\n", 261 | " y_spt = np.array(y_spt).reshape(n_way * k_spt)[perm]\n", 262 | " perm = np.random.permutation(n_way * k_query)\n", 263 | " x_qry = np.array(x_qry).reshape(n_way * k_query, 1, resize, resize)[perm]\n", 264 | " y_qry = np.array(y_qry).reshape(n_way * k_query)[perm]\n", 265 | " \n", 266 | " # append [sptsz, 1, 84, 84] => [batch_size, setsz, 1, 84, 84]\n", 267 | " x_spts.append(x_spt)\n", 268 | " y_spts.append(y_spt)\n", 269 | " x_qrys.append(x_qry)\n", 270 | " y_qrys.append(y_qry)\n", 271 | "\n", 272 | "# print(x_spts[0].shape)\n", 273 | " # [b, setsz = n_way * k_spt, 1, 84, 84]\n", 274 | " x_spts = np.array(x_spts).astype(np.float32).reshape(batch_size, setsz, 1, resize, resize)\n", 275 | " y_spts = np.array(y_spts).astype(np.int).reshape(batch_size, setsz)\n", 276 | " # [b, qrysz = n_way * k_query, 1, 84, 84]\n", 277 | " x_qrys = np.array(x_qrys).astype(np.float32).reshape(batch_size, querysz, 1, resize, resize)\n", 278 | " y_qrys = np.array(y_qrys).astype(np.int).reshape(batch_size, querysz)\n", 279 | "# 
print(x_qrys.shape)\n", 280 | " data_cache.append([x_spts, y_spts, x_qrys, y_qrys])\n", 281 | "\n", 282 | " return data_cache\n", 283 | "\n", 284 | "datasets_cache = {\"train\": load_data_cache(x_train), # current epoch data cached\n", 285 | " \"test\": load_data_cache(x_test)}\n", 286 | "\n", 287 | "def next(mode='train'):\n", 288 | " \"\"\"\n", 289 | " Gets next batch from the dataset with name.\n", 290 | " :param mode: The name of the splitting (one of \"train\", \"val\", \"test\")\n", 291 | " :return:\n", 292 | " \"\"\"\n", 293 | " # update cache if indexes is larger than len(data_cache)\n", 294 | " if indexes[mode] >= len(datasets_cache[mode]):\n", 295 | " indexes[mode] = 0\n", 296 | " datasets_cache[mode] = load_data_cache(datasets[mode])\n", 297 | "\n", 298 | " next_batch = datasets_cache[mode][indexes[mode]]\n", 299 | " indexes[mode] += 1\n", 300 | "\n", 301 | " return next_batch\n" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 17, 328 | "metadata": { 329 | "ExecuteTime": { 330 | "end_time": "2020-04-14T03:27:10.793796Z", 331 | "start_time": "2020-04-14T03:27:10.756874Z" 332 | }, 333 | "code_folding": [] 334 | }, 335 | "outputs": [], 336 | "source": [ 337 | "import torch\n", 338 | "from torch import nn\n", 339 | "from torch.nn import functional as F\n", 340 | "from copy import deepcopy,copy\n", 341 | " \n", 342 | "\n", 343 | "class BaseNet(nn.Module):\n", 344 | " def __init__(self):\n", 345 | " super(BaseNet, self).__init__()\n", 346 | " self.vars = nn.ParameterList() ## 包含了所有需要被优化的tensor\n", 347 | " self.vars_bn = 
nn.ParameterList()\n", 348 | " \n", 349 | " # 第1个conv2d\n", 350 | " weight = nn.Parameter(torch.ones(64, 1, 3, 3))\n", 351 | " nn.init.kaiming_normal_(weight)\n", 352 | " bias = nn.Parameter(torch.zeros(64))\n", 353 | " self.vars.extend([weight,bias])\n", 354 | " \n", 355 | " # 第1个BatchNorm层\n", 356 | " weight = nn.Parameter(torch.ones(64))\n", 357 | " bias = nn.Parameter(torch.zeros(64))\n", 358 | " self.vars.extend([weight,bias])\n", 359 | " \n", 360 | " running_mean = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 361 | " running_var = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 362 | " self.vars_bn.extend([running_mean, running_var])\n", 363 | " \n", 364 | " # 第2个conv2d\n", 365 | " weight = nn.Parameter(torch.ones(64, 64, 3, 3))\n", 366 | " nn.init.kaiming_normal_(weight)\n", 367 | " bias = nn.Parameter(torch.zeros(64))\n", 368 | " self.vars.extend([weight,bias])\n", 369 | " \n", 370 | " # 第2个BatchNorm层\n", 371 | " weight = nn.Parameter(torch.ones(64))\n", 372 | " bias = nn.Parameter(torch.zeros(64))\n", 373 | " self.vars.extend([weight,bias])\n", 374 | " \n", 375 | " running_mean = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 376 | " running_var = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 377 | " self.vars_bn.extend([running_mean, running_var])\n", 378 | " \n", 379 | " # 第3个conv2d\n", 380 | " weight = nn.Parameter(torch.ones(64, 64, 3, 3))\n", 381 | " nn.init.kaiming_normal_(weight)\n", 382 | " bias = nn.Parameter(torch.zeros(64))\n", 383 | " self.vars.extend([weight,bias])\n", 384 | " \n", 385 | " # 第3个BatchNorm层\n", 386 | " weight = nn.Parameter(torch.ones(64))\n", 387 | " bias = nn.Parameter(torch.zeros(64))\n", 388 | " self.vars.extend([weight,bias])\n", 389 | " \n", 390 | " running_mean = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 391 | " running_var = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 392 | " self.vars_bn.extend([running_mean, running_var])\n", 393 | " \n", 394 | " # 
第4个conv2d\n", 395 | " weight = nn.Parameter(torch.ones(64, 64, 3, 3))\n", 396 | " nn.init.kaiming_normal_(weight)\n", 397 | " bias = nn.Parameter(torch.zeros(64))\n", 398 | " self.vars.extend([weight,bias])\n", 399 | " \n", 400 | " # 第4个BatchNorm层\n", 401 | " weight = nn.Parameter(torch.ones(64))\n", 402 | " bias = nn.Parameter(torch.zeros(64))\n", 403 | " self.vars.extend([weight,bias])\n", 404 | " \n", 405 | " running_mean = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 406 | " running_var = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 407 | " self.vars_bn.extend([running_mean, running_var])\n", 408 | " \n", 409 | " ##linear\n", 410 | " weight = nn.Parameter(torch.ones([20,64]))\n", 411 | " bias = nn.Parameter(torch.zeros(20))\n", 412 | " self.vars.extend([weight,bias])\n", 413 | " \n", 414 | " def forward(self, x, params = None, bn_training=True):\n", 415 | " '''\n", 416 | " :bn_training: set False to not update\n", 417 | " :return: \n", 418 | " '''\n", 419 | " if params is None:\n", 420 | " params = self.vars\n", 421 | " \n", 422 | " weight, bias = params[0], params[1] # 第1个CONV层\n", 423 | " x = F.conv2d(x, weight, bias, stride = 1, padding = 1)\n", 424 | " weight, bias = params[2], params[3] # 第1个BN层\n", 425 | " running_mean, running_var = self.vars_bn[0], self.vars_bn[1]\n", 426 | " x = F.batch_norm(x, running_mean, running_var, weight=weight,bias =bias, training= bn_training, momentum = 1)\n", 427 | " x = F.relu(x, inplace = [True]) #第1个relu\n", 428 | " x = F.max_pool2d(x,kernel_size=2) #第1个MAX_POOL层 \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " weight, bias = params[4], params[5] # 第2个CONV层\n", 433 | " x = F.conv2d(x, weight, bias, stride = 1, padding = 1)\n", 434 | " weight, bias = params[6], params[7] # 第2个BN层\n", 435 | " running_mean, running_var = self.vars_bn[2], self.vars_bn[3]\n", 436 | " x = F.batch_norm(x, running_mean, running_var, weight=weight,bias =bias, training= bn_training, momentum=1)\n", 437 | " x = F.relu(x, 
inplace = [True]) #第2个relu\n", 438 | " x = F.max_pool2d(x,kernel_size=2) #第2个MAX_POOL层 \n", 439 | " \n", 440 | " \n", 441 | " weight, bias = params[8], params[9] # 第3个CONV层\n", 442 | " x = F.conv2d(x, weight, bias, stride = 1, padding = 1)\n", 443 | " weight, bias = params[10], params[11] # 第3个BN层\n", 444 | " running_mean, running_var = self.vars_bn[4], self.vars_bn[5]\n", 445 | " x = F.batch_norm(x, running_mean, running_var, weight=weight,bias =bias, training= bn_training,momentum=1)\n", 446 | " x = F.relu(x, inplace = [True]) #第3个relu,\n", 447 | " x = F.max_pool2d(x,kernel_size=2) #第3个MAX_POOL层\n", 448 | " \n", 449 | " \n", 450 | " weight, bias = params[12], params[13] # 第4个CONV层\n", 451 | " x = F.conv2d(x, weight, bias, stride = 1, padding = 1)\n", 452 | " weight, bias = params[14], params[15] # 第4个BN层\n", 453 | " running_mean, running_var = self.vars_bn[6], self.vars_bn[7]\n", 454 | " x = F.batch_norm(x, running_mean, running_var, weight=weight,bias =bias, training= bn_training)\n", 455 | " x = F.max_pool2d(x,kernel_size=2) #第4个MAX_POOL层\n", 456 | " \n", 457 | " x = F.relu(x, inplace = [True]) #第4个relu\n", 458 | " \n", 459 | " x = x.view(x.size(0), -1) ## flatten\n", 460 | " weight, bias = params[-2], params[-1] # linear\n", 461 | " x = F.linear(x, weight, bias)\n", 462 | " \n", 463 | " output = x\n", 464 | " \n", 465 | " return output\n", 466 | " \n", 467 | " \n", 468 | " def parameters(self):\n", 469 | " \n", 470 | " return self.vars\n" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": null, 476 | "metadata": { 477 | "ExecuteTime": { 478 | "end_time": "2020-02-29T12:00:30.197710Z", 479 | "start_time": "2020-02-29T12:00:30.186076Z" 480 | } 481 | }, 482 | "outputs": [], 483 | "source": [] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": null, 488 | "metadata": { 489 | "ExecuteTime": { 490 | "end_time": "2020-02-29T05:41:40.773998Z", 491 | "start_time": "2020-02-29T05:41:40.762077Z" 492 | } 493 | }, 494 | 
"outputs": [], 495 | "source": [] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 18, 500 | "metadata": { 501 | "ExecuteTime": { 502 | "end_time": "2020-04-14T03:27:13.001807Z", 503 | "start_time": "2020-04-14T03:27:12.969355Z" 504 | }, 505 | "code_folding": [] 506 | }, 507 | "outputs": [], 508 | "source": [ 509 | "class MetaLearner(nn.Module):\n", 510 | " def __init__(self):\n", 511 | " super(MetaLearner, self).__init__()\n", 512 | " self.update_step = 5 ## task-level inner update steps\n", 513 | " self.update_step_test = 5\n", 514 | " self.net = BaseNet()\n", 515 | " self.meta_lr = 0.001\n", 516 | " self.base_lr = 0.1\n", 517 | " self.meta_optim = torch.optim.Adam(self.net.parameters(), lr = self.meta_lr)\n", 518 | "# self.meta_optim = torch.optim.SGD(self.net.parameters(), lr = self.meta_lr, momentum = 0.9, weight_decay=0.0005)\n", 519 | " \n", 520 | " def forward(self,x_spt, y_spt, x_qry, y_qry):\n", 521 | " # 初始化\n", 522 | " task_num, ways, shots, h, w = x_spt.size()\n", 523 | " query_size = x_qry.size(1) # 75 = 15 * 5\n", 524 | " loss_list_qry = [0 for _ in range(self.update_step + 1)]\n", 525 | " correct_list = [0 for _ in range(self.update_step + 1)]\n", 526 | " \n", 527 | " for i in range(task_num):\n", 528 | " ## 第0步更新\n", 529 | " y_hat = self.net(x_spt[i], params = None, bn_training=True) # (ways * shots, ways)\n", 530 | " loss = F.cross_entropy(y_hat, y_spt[i]) \n", 531 | " grad = torch.autograd.grad(loss, self.net.parameters())\n", 532 | " tuples = zip(grad, self.net.parameters()) ## 将梯度和参数\\theta一一对应起来\n", 533 | " # fast_weights这一步相当于求了一个\\theta - \\alpha*\\nabla(L)\n", 534 | " fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))\n", 535 | " # 在query集上测试,计算准确率\n", 536 | " # 这一步使用更新前的数据\n", 537 | " with torch.no_grad():\n", 538 | " y_hat = self.net(x_qry[i], self.net.parameters(), bn_training = True)\n", 539 | " loss_qry = F.cross_entropy(y_hat, y_qry[i])\n", 540 | " loss_list_qry[0] += loss_qry\n", 541 | " 
pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)\n", 542 | " correct = torch.eq(pred_qry, y_qry[i]).sum().item()\n", 543 | " correct_list[0] += correct\n", 544 | " \n", 545 | " # 使用更新后的数据在query集上测试。\n", 546 | " with torch.no_grad():\n", 547 | " y_hat = self.net(x_qry[i], fast_weights, bn_training = True)\n", 548 | " loss_qry = F.cross_entropy(y_hat, y_qry[i])\n", 549 | " loss_list_qry[1] += loss_qry\n", 550 | " pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)\n", 551 | " correct = torch.eq(pred_qry, y_qry[i]).sum().item()\n", 552 | " correct_list[1] += correct \n", 553 | " \n", 554 | " for k in range(1, self.update_step):\n", 555 | " \n", 556 | " y_hat = self.net(x_spt[i], params = fast_weights, bn_training=True)\n", 557 | " loss = F.cross_entropy(y_hat, y_spt[i])\n", 558 | " grad = torch.autograd.grad(loss, fast_weights)\n", 559 | " tuples = zip(grad, fast_weights) \n", 560 | " fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))\n", 561 | " \n", 562 | " if k < self.update_step - 1:\n", 563 | " with torch.no_grad():\n", 564 | " y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)\n", 565 | " loss_qry = F.cross_entropy(y_hat, y_qry[i])\n", 566 | " loss_list_qry[k+1] += loss_qry\n", 567 | " else:\n", 568 | " y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)\n", 569 | " loss_qry = F.cross_entropy(y_hat, y_qry[i])\n", 570 | " loss_list_qry[k+1] += loss_qry\n", 571 | " \n", 572 | " with torch.no_grad():\n", 573 | " pred_qry = F.softmax(y_hat,dim=1).argmax(dim=1)\n", 574 | " correct = torch.eq(pred_qry, y_qry[i]).sum().item()\n", 575 | " correct_list[k+1] += correct\n", 576 | "# print('hello')\n", 577 | " \n", 578 | " loss_qry = loss_list_qry[-1] / task_num\n", 579 | " self.meta_optim.zero_grad() # 梯度清零\n", 580 | " loss_qry.backward()\n", 581 | " self.meta_optim.step()\n", 582 | " \n", 583 | " accs = np.array(correct_list) / (query_size * task_num)\n", 584 | " loss = 
np.array(loss_list_qry) / ( task_num)\n", 585 | " return accs,loss\n", 586 | "\n", 587 | " \n", 588 | " \n", 589 | " def finetunning(self, x_spt, y_spt, x_qry, y_qry):\n", 590 | " assert len(x_spt.shape) == 4\n", 591 | " \n", 592 | " query_size = x_qry.size(0)\n", 593 | " correct_list = [0 for _ in range(self.update_step_test + 1)]\n", 594 | " \n", 595 | " new_net = deepcopy(self.net)\n", 596 | " y_hat = new_net(x_spt)\n", 597 | " loss = F.cross_entropy(y_hat, y_spt)\n", 598 | " grad = torch.autograd.grad(loss, new_net.parameters())\n", 599 | " fast_weights = list(map(lambda p:p[1] - self.base_lr * p[0], zip(grad, new_net.parameters())))\n", 600 | " \n", 601 | " # 在query集上测试,计算准确率\n", 602 | " # 这一步使用更新前的数据\n", 603 | " with torch.no_grad():\n", 604 | " y_hat = new_net(x_qry, params = new_net.parameters(), bn_training = True)\n", 605 | " pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)\n", 606 | " correct = torch.eq(pred_qry, y_qry).sum().item()\n", 607 | " correct_list[0] += correct\n", 608 | "\n", 609 | " # 使用更新后的数据在query集上测试。\n", 610 | " with torch.no_grad():\n", 611 | " y_hat = new_net(x_qry, params = fast_weights, bn_training = True)\n", 612 | " pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)\n", 613 | " correct = torch.eq(pred_qry, y_qry).sum().item()\n", 614 | " correct_list[1] += correct\n", 615 | "\n", 616 | " for k in range(1, self.update_step_test):\n", 617 | " y_hat = new_net(x_spt, params = fast_weights, bn_training=True)\n", 618 | " loss = F.cross_entropy(y_hat, y_spt)\n", 619 | " grad = torch.autograd.grad(loss, fast_weights)\n", 620 | " fast_weights = list(map(lambda p:p[1] - self.base_lr * p[0], zip(grad, fast_weights)))\n", 621 | " \n", 622 | " y_hat = new_net(x_qry, fast_weights, bn_training=True)\n", 623 | " \n", 624 | " with torch.no_grad():\n", 625 | " pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)\n", 626 | " correct = torch.eq(pred_qry, y_qry).sum().item()\n", 627 | " correct_list[k+1] += correct\n", 628 | " \n", 
629 | " del new_net\n", 630 | " accs = np.array(correct_list) / query_size\n", 631 | " return accs\n", 632 | " " 633 | ] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "execution_count": 19, 638 | "metadata": { 639 | "ExecuteTime": { 640 | "end_time": "2020-04-14T03:27:15.093572Z", 641 | "start_time": "2020-04-14T03:27:15.090246Z" 642 | } 643 | }, 644 | "outputs": [], 645 | "source": [ 646 | "# net = torch.load('./trained_models/MTL-5000epochs.pt')" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "metadata": { 653 | "ExecuteTime": { 654 | "start_time": "2020-04-14T03:27:13.031Z" 655 | }, 656 | "scrolled": true 657 | }, 658 | "outputs": [ 659 | { 660 | "name": "stdout", 661 | "output_type": "stream", 662 | "text": [ 663 | "epoch: 0\n", 664 | "[0.05 0.08625 0.1075 0.13270833 0.161875 0.19520833]\n", 665 | "在mean process之前: (992, 6)\n", 666 | "测试集准确率: [0.04962 0.1088 0.1405 0.1738 0.2063 0.2367 ]\n" 667 | ] 668 | } 669 | ], 670 | "source": [ 671 | "## omniglot\n", 672 | "import random\n", 673 | "random.seed(1337)\n", 674 | "np.random.seed(1337)\n", 675 | "\n", 676 | "import time\n", 677 | "device = torch.device('cuda:2')\n", 678 | "\n", 679 | "meta = MetaLearner().to(device)\n", 680 | "\n", 681 | "epochs = 60001\n", 682 | "for step in range(epochs):\n", 683 | " start = time.time()\n", 684 | " x_spt, y_spt, x_qry, y_qry = next('train')\n", 685 | " x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device),\\\n", 686 | " torch.from_numpy(y_spt).to(device),\\\n", 687 | " torch.from_numpy(x_qry).to(device),\\\n", 688 | " torch.from_numpy(y_qry).to(device)\n", 689 | " accs,loss = meta(x_spt, y_spt, x_qry, y_qry)\n", 690 | " end = time.time()\n", 691 | " if step % 100 == 0:\n", 692 | " print(\"epoch:\" ,step)\n", 693 | " print(accs)\n", 694 | "# print(loss)\n", 695 | " \n", 696 | " if step % 1000 == 0:\n", 697 | " accs = []\n", 698 | " for _ in range(1000//task_num):\n", 699 | " # db_train.next('test')\n", 700 | " x_spt, 
y_spt, x_qry, y_qry = next('test')\n", 701 | " x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device),\\\n", 702 | " torch.from_numpy(y_spt).to(device),\\\n", 703 | " torch.from_numpy(x_qry).to(device),\\\n", 704 | " torch.from_numpy(y_qry).to(device)\n", 705 | "\n", 706 | " \n", 707 | " for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):\n", 708 | " test_acc = meta.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)\n", 709 | " accs.append(test_acc)\n", 710 | " print('在mean process之前:',np.array(accs).shape)\n", 711 | " accs = np.array(accs).mean(axis=0).astype(np.float16)\n", 712 | " print('测试集准确率:',accs)" 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "execution_count": null, 718 | "metadata": { 719 | "ExecuteTime": { 720 | "end_time": "2020-03-01T03:00:56.266331Z", 721 | "start_time": "2020-03-01T03:00:56.205955Z" 722 | } 723 | }, 724 | "outputs": [], 725 | "source": [ 726 | "\n" 727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": null, 732 | "metadata": {}, 733 | "outputs": [], 734 | "source": [] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "execution_count": null, 739 | "metadata": {}, 740 | "outputs": [], 741 | "source": [] 742 | } 743 | ], 744 | "metadata": { 745 | "kernelspec": { 746 | "display_name": "ML3.6", 747 | "language": "python", 748 | "name": "ml3.6" 749 | }, 750 | "language_info": { 751 | "codemirror_mode": { 752 | "name": "ipython", 753 | "version": 3 754 | }, 755 | "file_extension": ".py", 756 | "mimetype": "text/x-python", 757 | "name": "python", 758 | "nbconvert_exporter": "python", 759 | "pygments_lexer": "ipython3", 760 | "version": "3.6.9" 761 | }, 762 | "latex_envs": { 763 | "LaTeX_envs_menu_present": true, 764 | "autoclose": false, 765 | "autocomplete": true, 766 | "bibliofile": "biblio.bib", 767 | "cite_by": "apalike", 768 | "current_citInitial": 1, 769 | "eqLabelWithNumbers": true, 770 | "eqNumInitial": 1, 771 | "hotkeys": { 772 | "equation": "Ctrl-E", 
773 | "itemize": "Ctrl-I" 774 | }, 775 | "labels_anchors": false, 776 | "latex_user_defs": false, 777 | "report_style_numbering": false, 778 | "user_envs_cfg": false 779 | }, 780 | "toc": { 781 | "base_numbering": 1, 782 | "nav_menu": {}, 783 | "number_sections": true, 784 | "sideBar": true, 785 | "skip_h1_title": false, 786 | "title_cell": "Table of Contents", 787 | "title_sidebar": "Contents", 788 | "toc_cell": false, 789 | "toc_position": {}, 790 | "toc_section_display": true, 791 | "toc_window_display": false 792 | }, 793 | "varInspector": { 794 | "cols": { 795 | "lenName": 16, 796 | "lenType": 16, 797 | "lenVar": 40 798 | }, 799 | "kernels_config": { 800 | "python": { 801 | "delete_cmd_postfix": "", 802 | "delete_cmd_prefix": "del ", 803 | "library": "var_list.py", 804 | "varRefreshCmd": "print(var_dic_list())" 805 | }, 806 | "r": { 807 | "delete_cmd_postfix": ") ", 808 | "delete_cmd_prefix": "rm(", 809 | "library": "var_list.r", 810 | "varRefreshCmd": "cat(var_dic_list()) " 811 | } 812 | }, 813 | "types_to_exclude": [ 814 | "module", 815 | "function", 816 | "builtin_function_or_method", 817 | "instance", 818 | "_Feature" 819 | ], 820 | "window_display": false 821 | } 822 | }, 823 | "nbformat": 4, 824 | "nbformat_minor": 2 825 | } 826 | -------------------------------------------------------------------------------- /MAML-omniglot-ADAM-20way-5shot-16batch-.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2020-04-16T02:33:53.570525Z", 9 | "start_time": "2020-04-16T02:33:53.256842Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import torch\n", 15 | "import numpy as np\n", 16 | "import os\n", 17 | "import zipfile\n", 18 | "\n", 19 | "# root_path = './../datasets'\n", 20 | "# processed_folder = os.path.join(root_path)\n", 21 | "\n", 22 | "# zip_ref = 
zipfile.ZipFile(os.path.join(root_path,'omniglot.zip'), 'r')\n", 23 | "# zip_ref.extractall(root_path)\n", 24 | "# zip_ref.close()\n", 25 | "root_dir = './../datasets/omniglot/python'\n", 26 | "root_dir_train = os.path.join(root_dir,'images_background')" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": { 33 | "ExecuteTime": { 34 | "end_time": "2020-04-07T06:55:03.418079Z", 35 | "start_time": "2020-04-07T06:55:03.346124Z" 36 | } 37 | }, 38 | "outputs": [], 39 | "source": [] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "### 数据预处理\n", 46 | "拿到原始数据之后先将下面的代码取消注释,进行数据预处理。" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": { 53 | "ExecuteTime": { 54 | "end_time": "2020-04-16T02:33:55.423210Z", 55 | "start_time": "2020-04-16T02:33:55.405143Z" 56 | }, 57 | "scrolled": true 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "\n", 62 | "# # 数据预处理\n", 63 | "# import torchvision.transforms as transforms\n", 64 | "# from PIL import Image\n", 65 | "\n", 66 | "# '''\n", 67 | "# an example of img_items:\n", 68 | "# ( '0709_17.png',\n", 69 | "# 'Alphabet_of_the_Magi/character01',\n", 70 | "# './../datasets/omniglot/python/images_background/Alphabet_of_the_Magi/character01')\n", 71 | "# '''\n", 72 | "\n", 73 | "\n", 74 | "# root_dir_train = os.path.join(root_dir, 'images_background')\n", 75 | "# root_dir_test = os.path.join(root_dir, 'images_evaluation')\n", 76 | "\n", 77 | "# def find_classes(root_dir_train):\n", 78 | "# img_items = []\n", 79 | "# for (root, dirs, files) in os.walk(root_dir_train): \n", 80 | "# for file in files:\n", 81 | "# if (file.endswith(\"png\")):\n", 82 | "# r = root.split('/')\n", 83 | "# img_items.append((file, r[-2] + \"/\" + r[-1], root))\n", 84 | "# print(\"== Found %d items \" % len(img_items))\n", 85 | "# return img_items\n", 86 | "\n", 87 | "# ## 构建一个词典{class:idx}\n", 88 | "# def index_classes(items):\n", 89 | "# class_idx 
= {}\n", 90 | "# count = 0\n", 91 | "# for item in items:\n", 92 | "# if item[1] not in class_idx:\n", 93 | "# class_idx[item[1]] = count\n", 94 | "# count += 1\n", 95 | "# print('== Found {} classes'.format(len(class_idx)))\n", 96 | "# return class_idx\n", 97 | " \n", 98 | "\n", 99 | "# img_items_train = find_classes(root_dir_train) # [(file1, label1, root1),..]\n", 100 | "# img_items_test = find_classes(root_dir_test)\n", 101 | "\n", 102 | "# class_idx_train = index_classes(img_items_train)\n", 103 | "# class_idx_test = index_classes(img_items_test)\n", 104 | "\n", 105 | "\n", 106 | "# def generate_temp(img_items,class_idx):\n", 107 | "# temp = dict()\n", 108 | "# for imgname, classes, dirs in img_items:\n", 109 | "# img = '{}/{}'.format(dirs, imgname)\n", 110 | "# label = class_idx[classes]\n", 111 | "# transform = transforms.Compose([lambda img: Image.open(img).convert('L'),\n", 112 | "# lambda img: img.resize((28,28)),\n", 113 | "# lambda img: np.reshape(img, (28,28,1)),\n", 114 | "# lambda img: np.transpose(img, [2,0,1]),\n", 115 | "# lambda img: img/255.\n", 116 | "# ])\n", 117 | "# img = transform(img)\n", 118 | "# if label in temp.keys():\n", 119 | "# temp[label].append(img)\n", 120 | "# else:\n", 121 | "# temp[label] = [img]\n", 122 | "# print('begin to generate omniglot.npy')\n", 123 | "# return temp\n", 124 | "# ## 每个字符包含20个样本\n", 125 | "\n", 126 | "# temp_train = generate_temp(img_items_train, class_idx_train)\n", 127 | "# temp_test = generate_temp(img_items_test, class_idx_test)\n", 128 | "\n", 129 | "# img_list = []\n", 130 | "# for label, imgs in temp_train.items():\n", 131 | "# img_list.append(np.array(imgs))\n", 132 | "# img_list = np.array(img_list).astype(np.float) # [[20 imgs],..., 1623 classes in total]\n", 133 | "# print('data shape:{}'.format(img_list.shape)) # (964, 20, 1, 28, 28)\n", 134 | "# np.save(os.path.join(root_dir, 'omniglot_train.npy'), img_list)\n", 135 | "# print('end.')\n", 136 | "\n", 137 | "\n", 138 | "# img_list = []\n", 139 
| "# for label, imgs in temp_test.items():\n", 140 | "# img_list.append(np.array(imgs))\n", 141 | "# img_list = np.array(img_list).astype(np.float) # [[20 imgs],..., 1623 classes in total]\n", 142 | "# print('data shape:{}'.format(img_list.shape)) # (659, 20, 1, 28, 28)\n", 143 | "\n", 144 | "# np.save(os.path.join(root_dir, 'omniglot_test.npy'), img_list)\n", 145 | "# print('end.')" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "ExecuteTime": { 153 | "end_time": "2020-04-07T08:02:03.283025Z", 154 | "start_time": "2020-04-07T08:02:03.276106Z" 155 | } 156 | }, 157 | "outputs": [], 158 | "source": [] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 3, 163 | "metadata": { 164 | "ExecuteTime": { 165 | "end_time": "2020-04-16T02:34:05.401785Z", 166 | "start_time": "2020-04-16T02:34:03.474067Z" 167 | } 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "img_list_train = np.load(os.path.join(root_dir, 'omniglot_train.npy')) # (964, 20, 1, 28, 28)\n", 172 | "img_list_test = np.load(os.path.join(root_dir, 'omniglot_test.npy')) # (659, 20, 1, 28, 28)\n", 173 | "\n", 174 | "x_train = img_list_train\n", 175 | "x_test = img_list_test\n", 176 | "# num_classes = img_list.shape[0]\n", 177 | "datasets = {'train': x_train, 'test': x_test}" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 4, 197 | "metadata": { 198 | "ExecuteTime": { 199 | "end_time": "2020-04-16T02:34:22.871343Z", 200 | "start_time": "2020-04-16T02:34:21.674782Z" 201 | }, 202 | "code_folding": [] 203 | }, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "DB: train (964, 20, 1, 28, 
### Build the data iterators for N-way K-shot episodes.
n_way = 20        # classes per task
k_spt = 5         # support samples per class
k_query = 15      # query samples per class
imgsz = 28
resize = imgsz
task_num = 16     # tasks per meta-batch
batch_size = task_num

indexes = {"train": 0, "test": 0}
datasets = {"train": x_train, "test": x_test}
print("DB: train", x_train.shape, "test", x_test.shape)


def load_data_cache(dataset):
    """
    Collects several meta-batches of episodes for N-shot learning.
    :param dataset: array of shape [cls_num, 20, 1, 28, 28] (20 images per class)
    :return: a list of [support_set_x, support_set_y, target_x, target_y]
             meta-batches ready to be fed to the network
    """
    setsz = k_spt * n_way
    querysz = k_query * n_way
    data_cache = []

    # Pre-build 10 meta-batches per cache refill.
    for sample in range(10):

        x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
        for i in range(batch_size):  # one meta-batch = `task_num` tasks

            x_spt, y_spt, x_qry, y_qry = [], [], [], []
            selected_cls = np.random.choice(dataset.shape[0], n_way, replace=False)

            for j, cur_class in enumerate(selected_cls):

                # Each class holds exactly 20 samples; split into support/query.
                selected_img = np.random.choice(20, k_spt + k_query, replace=False)

                x_spt.append(dataset[cur_class][selected_img[:k_spt]])
                x_qry.append(dataset[cur_class][selected_img[k_spt:]])
                y_spt.append([j for _ in range(k_spt)])
                y_qry.append([j for _ in range(k_query)])

            # Shuffle support and query sets within the task.
            perm = np.random.permutation(n_way * k_spt)
            x_spt = np.array(x_spt).reshape(n_way * k_spt, 1, resize, resize)[perm]
            y_spt = np.array(y_spt).reshape(n_way * k_spt)[perm]
            perm = np.random.permutation(n_way * k_query)
            x_qry = np.array(x_qry).reshape(n_way * k_query, 1, resize, resize)[perm]
            y_qry = np.array(y_qry).reshape(n_way * k_query)[perm]

            # append [sptsz, 1, 28, 28] => [batch_size, setsz, 1, 28, 28]
            x_spts.append(x_spt)
            y_spts.append(y_spt)
            x_qrys.append(x_qry)
            y_qrys.append(y_qry)

        # [b, setsz = n_way * k_spt, 1, 28, 28]
        # BUGFIX: np.int / np.float were deprecated in NumPy 1.20 and removed in
        # 1.24 -- use the builtin `int` (equivalent: platform default int dtype).
        x_spts = np.array(x_spts).astype(np.float32).reshape(batch_size, setsz, 1, resize, resize)
        y_spts = np.array(y_spts).astype(int).reshape(batch_size, setsz)
        # [b, qrysz = n_way * k_query, 1, 28, 28]
        x_qrys = np.array(x_qrys).astype(np.float32).reshape(batch_size, querysz, 1, resize, resize)
        y_qrys = np.array(y_qrys).astype(int).reshape(batch_size, querysz)
        data_cache.append([x_spts, y_spts, x_qrys, y_qrys])

    return data_cache

datasets_cache = {"train": load_data_cache(x_train),  # current epoch data cached
                  "test": load_data_cache(x_test)}

def next(mode='train'):
    """
    Gets the next meta-batch from the named split, refilling the cache
    when it has been fully consumed.
    :param mode: the name of the split (one of "train", "test")
    :return: [x_spt, y_spt, x_qry, y_qry] for one meta-batch
    """
    # Refill the cache once all pre-built batches have been consumed.
    if indexes[mode] >= len(datasets_cache[mode]):
        indexes[mode] = 0
        datasets_cache[mode] = load_data_cache(datasets[mode])

    next_batch = datasets_cache[mode][indexes[mode]]
    indexes[mode] += 1

    return next_batch
class BaseNet(nn.Module):
    """4-conv Omniglot backbone with explicitly managed (functional) parameters.

    Parameters live in ``self.vars`` (meta-learned) and ``self.vars_bn``
    (BatchNorm running statistics, never meta-learned), so that ``forward``
    can be evaluated with externally supplied fast weights -- the key
    requirement of MAML's inner loop.
    """

    def __init__(self):
        super(BaseNet, self).__init__()
        self.vars = nn.ParameterList()     # every tensor the optimizer updates
        self.vars_bn = nn.ParameterList()  # BN running stats (requires_grad=False)

        in_channels = 1
        for _ in range(4):
            # conv2d 3x3, 64 filters
            weight = nn.Parameter(torch.ones(64, in_channels, 3, 3))
            nn.init.kaiming_normal_(weight)
            bias = nn.Parameter(torch.zeros(64))
            self.vars.extend([weight, bias])

            # BatchNorm affine parameters (gamma=1, beta=0)
            weight = nn.Parameter(torch.ones(64))
            bias = nn.Parameter(torch.zeros(64))
            self.vars.extend([weight, bias])

            # BUGFIX: running_var must start at 1 (identity normalization),
            # not 0 -- the original zeros give wrong eval-mode statistics.
            running_mean = nn.Parameter(torch.zeros(64), requires_grad=False)
            running_var = nn.Parameter(torch.ones(64), requires_grad=False)
            self.vars_bn.extend([running_mean, running_var])

            in_channels = 64

        # final linear classifier: 64 features -> 20 ways
        weight = nn.Parameter(torch.ones([20, 64]))
        bias = nn.Parameter(torch.zeros(20))
        self.vars.extend([weight, bias])

    def forward(self, x, params=None, bn_training=True):
        """Run the network with the given parameter list.

        :param x: input batch, [b, 1, 28, 28]
        :param params: optional flat list of fast weights laid out exactly like
                       ``self.vars``; ``None`` means use ``self.vars``
        :param bn_training: set False to use (and not update) running BN stats
        :return: logits, [b, 20]
        """
        if params is None:
            params = self.vars

        # Four identical conv -> BN -> ReLU -> max-pool stages.
        # BUGFIX vs original: ``inplace`` takes a bool, not ``[True]``;
        # ``momentum=1`` is now applied uniformly (the 4th BN omitted it);
        # relu/max-pool order is unified (ReLU is monotone, so the two
        # orders produce identical outputs).
        for block in range(4):
            weight, bias = params[4 * block], params[4 * block + 1]
            x = F.conv2d(x, weight, bias, stride=1, padding=1)

            weight, bias = params[4 * block + 2], params[4 * block + 3]
            running_mean = self.vars_bn[2 * block]
            running_var = self.vars_bn[2 * block + 1]
            # momentum=1: running stats are fully replaced by the batch stats,
            # the convention used by MAML implementations.
            x = F.batch_norm(x, running_mean, running_var, weight=weight,
                             bias=bias, training=bn_training, momentum=1)
            x = F.relu(x, inplace=True)
            x = F.max_pool2d(x, kernel_size=2)

        x = x.view(x.size(0), -1)  # flatten: [b, 64]
        weight, bias = params[-2], params[-1]
        output = F.linear(x, weight, bias)

        return output

    def parameters(self):
        # Override nn.Module.parameters so external optimizers and
        # torch.autograd.grad see exactly the meta-learned tensors.
        return self.vars
class MetaLearner(nn.Module):
    """MAML meta-learner: inner-loop SGD on fast weights, outer-loop Adam."""

    def __init__(self):
        super(MetaLearner, self).__init__()
        self.update_step = 5        # task-level inner update steps (training)
        self.update_step_test = 5   # inner update steps during fine-tuning
        self.net = BaseNet()
        self.meta_lr = 0.0008       # outer-loop (meta) learning rate
        self.base_lr = 0.075        # inner-loop learning rate
        self.meta_optim = torch.optim.Adam(self.net.parameters(), lr=self.meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        """One meta-training step over a batch of tasks.

        :param x_spt: [task_num, setsz, 1, h, w] support images
        :param y_spt: [task_num, setsz] support labels
        :param x_qry: [task_num, qrysz, 1, h, w] query images
        :param y_qry: [task_num, qrysz] query labels
        :return: (accs, loss) -- query accuracy / mean query loss per inner step
        """
        task_num, ways, shots, h, w = x_spt.size()
        query_size = x_qry.size(1)
        loss_list_qry = [0 for _ in range(self.update_step + 1)]
        correct_list = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):
            # step 0: adapt from the meta-parameters
            y_hat = self.net(x_spt[i], params=None, bn_training=True)  # (setsz, ways)
            loss = F.cross_entropy(y_hat, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            tuples = zip(grad, self.net.parameters())
            # fast_weights = theta - alpha * grad(L)
            fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))

            # query performance BEFORE any adaptation
            with torch.no_grad():
                y_hat = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_qry = F.cross_entropy(y_hat, y_qry[i])
                loss_list_qry[0] += loss_qry
                pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
                correct_list[0] += torch.eq(pred_qry, y_qry[i]).sum().item()

            # query performance after the first inner update
            with torch.no_grad():
                y_hat = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_qry = F.cross_entropy(y_hat, y_qry[i])
                loss_list_qry[1] += loss_qry
                pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
                correct_list[1] += torch.eq(pred_qry, y_qry[i]).sum().item()

            for k in range(1, self.update_step):
                y_hat = self.net(x_spt[i], params=fast_weights, bn_training=True)
                loss = F.cross_entropy(y_hat, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                tuples = zip(grad, fast_weights)
                fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))

                if k < self.update_step - 1:
                    # intermediate steps: evaluate query loss without a graph
                    with torch.no_grad():
                        y_hat = self.net(x_qry[i], params=fast_weights, bn_training=True)
                        loss_qry = F.cross_entropy(y_hat, y_qry[i])
                        loss_list_qry[k + 1] += loss_qry
                else:
                    # last step: keep the graph so the meta-gradient can flow
                    y_hat = self.net(x_qry[i], params=fast_weights, bn_training=True)
                    loss_qry = F.cross_entropy(y_hat, y_qry[i])
                    loss_list_qry[k + 1] += loss_qry

                with torch.no_grad():
                    pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
                    correct_list[k + 1] += torch.eq(pred_qry, y_qry[i]).sum().item()

        # meta-update on the final-step query loss, averaged over tasks
        loss_qry = loss_list_qry[-1] / task_num
        self.meta_optim.zero_grad()
        loss_qry.backward()
        self.meta_optim.step()

        accs = np.array(correct_list) / (query_size * task_num)
        # BUGFIX: loss_list_qry holds autograd tensors (the last entry still
        # carries a graph); np.array() cannot convert grad-requiring tensors.
        # Detach each entry to a Python float before building the array.
        loss = np.array([float(l) for l in loss_list_qry]) / task_num
        return accs, loss

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        """Fine-tune a deep copy of the meta-model on ONE task and report
        query accuracy after each inner step (meta-parameters untouched).

        :param x_spt: [setsz, 1, h, w] support images of a single task
        :return: accuracy array of length ``update_step_test + 1``
        """
        assert len(x_spt.shape) == 4

        query_size = x_qry.size(0)
        correct_list = [0 for _ in range(self.update_step_test + 1)]

        # Copy the net so fine-tuning never mutates the meta-parameters.
        new_net = deepcopy(self.net)
        y_hat = new_net(x_spt)
        loss = F.cross_entropy(y_hat, y_spt)
        grad = torch.autograd.grad(loss, new_net.parameters())
        fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0],
                                zip(grad, new_net.parameters())))

        # query accuracy BEFORE adaptation
        with torch.no_grad():
            y_hat = new_net(x_qry, params=new_net.parameters(), bn_training=True)
            pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
            correct_list[0] += torch.eq(pred_qry, y_qry).sum().item()

        # query accuracy after the first update
        with torch.no_grad():
            y_hat = new_net(x_qry, params=fast_weights, bn_training=True)
            pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
            correct_list[1] += torch.eq(pred_qry, y_qry).sum().item()

        for k in range(1, self.update_step_test):
            y_hat = new_net(x_spt, params=fast_weights, bn_training=True)
            loss = F.cross_entropy(y_hat, y_spt)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0],
                                    zip(grad, fast_weights)))

            y_hat = new_net(x_qry, fast_weights, bn_training=True)

            with torch.no_grad():
                pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
                correct_list[k + 1] += torch.eq(pred_qry, y_qry).sum().item()

        del new_net
        accs = np.array(correct_list) / query_size
        return accs
6)\n", 667 | "测试集准确率: [0.05008 0.1423 0.1737 0.2052 0.236 0.2651 ]\n", 668 | "epoch: 100\n", 669 | "[0.045625 0.50666667 0.65104167 0.74479167 0.79166667 0.819375 ]\n", 670 | "epoch: 200\n", 671 | "[0.05104167 0.70208333 0.84291667 0.90270833 0.92166667 0.931875 ]\n", 672 | "epoch: 300\n", 673 | "[0.05 0.28020833 0.728125 0.81791667 0.85729167 0.876875 ]\n", 674 | "epoch: 400\n", 675 | "[0.05 0.56208333 0.80291667 0.89854167 0.94770833 0.960625 ]\n", 676 | "epoch: 500\n", 677 | "[0.05 0.53666667 0.81395833 0.926875 0.948125 0.95958333]\n", 678 | "epoch: 600\n", 679 | "[0.05 0.54020833 0.85041667 0.95395833 0.9625 0.966875 ]\n", 680 | "epoch: 700\n", 681 | "[0.05 0.4775 0.84375 0.9425 0.94833333 0.94729167]\n", 682 | "epoch: 800\n", 683 | "[0.05 0.47833333 0.82541667 0.90229167 0.908125 0.9175 ]\n", 684 | "epoch: 900\n", 685 | "[0.05 0.46354167 0.84291667 0.91666667 0.91041667 0.92979167]\n", 686 | "epoch: 1000\n", 687 | "[0.05 0.44541667 0.80625 0.87041667 0.87979167 0.90583333]\n", 688 | "在mean process之前: (992, 6)\n", 689 | "测试集准确率: [0.05 0.3694 0.729 0.783 0.8086 0.857 ]\n", 690 | "epoch: 1100\n", 691 | "[0.05 0.441875 0.80333333 0.85416667 0.89520833 0.92708333]\n", 692 | "epoch: 1200\n", 693 | "[0.05 0.143125 0.72583333 0.78479167 0.84958333 0.900625 ]\n", 694 | "epoch: 1300\n", 695 | "[0.05 0.05 0.05 0.68291667 0.82416667 0.8975 ]\n", 696 | "epoch: 1400\n", 697 | "[0.05 0.05 0.05 0.773125 0.896875 0.9275 ]\n", 698 | "epoch: 1500\n", 699 | "[0.05 0.05 0.05 0.82208333 0.91541667 0.93770833]\n", 700 | "epoch: 1600\n", 701 | "[0.05 0.05 0.05 0.84416667 0.91583333 0.93916667]\n", 702 | "epoch: 1700\n", 703 | "[0.05 0.05 0.05 0.840625 0.91958333 0.94125 ]\n", 704 | "epoch: 1800\n", 705 | "[0.05 0.05 0.05 0.8875 0.93375 0.94729167]\n", 706 | "epoch: 1900\n", 707 | "[0.05 0.05 0.05 0.87916667 0.944375 0.954375 ]\n", 708 | "epoch: 2000\n", 709 | "[0.05 0.05 0.05 0.90416667 0.92833333 0.94395833]\n", 710 | "在mean process之前: (992, 6)\n", 711 | "测试集准确率: [0.05 0.05 0.05 
0.8057 0.8794 0.906 ]\n", 712 | "epoch: 2100\n", 713 | "[0.05 0.05 0.05 0.91 0.94291667 0.95458333]\n", 714 | "epoch: 2200\n", 715 | "[0.05 0.05 0.05 0.90625 0.95458333 0.964375 ]\n", 716 | "epoch: 2300\n", 717 | "[0.05 0.05 0.05 0.91291667 0.94145833 0.94395833]\n", 718 | "epoch: 2400\n", 719 | "[0.05 0.05 0.05 0.935625 0.95541667 0.96125 ]\n", 720 | "epoch: 2500\n", 721 | "[0.05 0.05 0.05 0.94041667 0.96916667 0.97625 ]\n", 722 | "epoch: 2600\n", 723 | "[0.05 0.05 0.05 0.94104167 0.96875 0.97270833]\n", 724 | "epoch: 2700\n", 725 | "[0.05 0.05 0.05 0.95479167 0.96895833 0.97395833]\n", 726 | "epoch: 2800\n", 727 | "[0.05 0.05 0.05 0.955 0.97791667 0.97770833]\n", 728 | "epoch: 2900\n", 729 | "[0.05 0.05 0.05 0.95583333 0.97125 0.978125 ]\n", 730 | "epoch: 3000\n", 731 | "[0.05 0.05 0.05 0.946875 0.969375 0.975 ]\n", 732 | "在mean process之前: (992, 6)\n", 733 | "测试集准确率: [0.05 0.05 0.05 0.8594 0.9204 0.933 ]\n", 734 | "epoch: 3100\n", 735 | "[0.05 0.05 0.05 0.95583333 0.97541667 0.97229167]\n", 736 | "epoch: 3200\n", 737 | "[0.05 0.05 0.05 0.955 0.976875 0.97770833]\n", 738 | "epoch: 3300\n", 739 | "[0.05 0.05 0.05 0.96083333 0.98229167 0.98208333]\n", 740 | "epoch: 3400\n", 741 | "[0.05 0.05 0.05 0.94354167 0.97020833 0.97583333]\n", 742 | "epoch: 3500\n", 743 | "[0.05 0.05 0.05 0.96895833 0.97583333 0.97479167]\n", 744 | "epoch: 3600\n", 745 | "[0.05 0.05 0.05 0.95833333 0.97104167 0.97645833]\n", 746 | "epoch: 3700\n", 747 | "[0.05 0.05 0.05 0.95708333 0.97604167 0.97895833]\n", 748 | "epoch: 3800\n", 749 | "[0.05 0.05 0.05 0.96770833 0.97895833 0.98083333]\n", 750 | "epoch: 3900\n", 751 | "[0.05 0.05 0.05 0.96 0.97666667 0.97979167]\n", 752 | "epoch: 4000\n", 753 | "[0.05 0.05 0.05 0.97333333 0.97875 0.976875 ]\n", 754 | "在mean process之前: (992, 6)\n", 755 | "测试集准确率: [0.05 0.05 0.05 0.8857 0.933 0.943 ]\n", 756 | "epoch: 4100\n", 757 | "[0.05 0.05 0.05 0.96625 0.9775 0.98041667]\n", 758 | "epoch: 4200\n", 759 | "[0.05 0.05 0.05 0.97125 0.9825 0.98354167]\n", 760 | 
"epoch: 4300\n", 761 | "[0.05 0.05 0.05 0.96604167 0.976875 0.97625 ]\n", 762 | "epoch: 4400\n", 763 | "[0.05 0.05 0.05 0.97520833 0.983125 0.98583333]\n", 764 | "epoch: 4500\n", 765 | "[0.05 0.05 0.05 0.9675 0.97333333 0.98083333]\n", 766 | "epoch: 4600\n", 767 | "[0.05 0.05 0.05 0.96291667 0.96958333 0.97291667]\n", 768 | "epoch: 4700\n", 769 | "[0.05 0.05 0.05 0.96958333 0.98229167 0.98458333]\n", 770 | "epoch: 4800\n", 771 | "[0.05 0.05 0.05 0.97 0.98395833 0.98416667]\n", 772 | "epoch: 4900\n", 773 | "[0.05 0.05 0.05 0.97333333 0.985 0.98479167]\n", 774 | "epoch: 5000\n", 775 | "[0.05 0.05 0.05 0.97854167 0.98604167 0.988125 ]\n", 776 | "在mean process之前: (992, 6)\n", 777 | "测试集准确率: [0.05 0.05 0.05 0.8945 0.9355 0.943 ]\n", 778 | "epoch: 5100\n", 779 | "[0.05 0.05 0.05 0.97125 0.976875 0.98041667]\n", 780 | "epoch: 5200\n", 781 | "[0.05 0.05 0.05 0.97375 0.98020833 0.98208333]\n", 782 | "epoch: 5300\n", 783 | "[0.05 0.05 0.05 0.97583333 0.98520833 0.98895833]\n", 784 | "epoch: 5400\n", 785 | "[0.05 0.05 0.05 0.95729167 0.97083333 0.97479167]\n", 786 | "epoch: 5500\n", 787 | "[0.05 0.05 0.05 0.97270833 0.97729167 0.983125 ]\n", 788 | "epoch: 5600\n", 789 | "[0.05 0.05 0.05 0.97854167 0.99 0.99104167]\n", 790 | "epoch: 5700\n", 791 | "[0.05 0.05 0.05 0.97270833 0.98395833 0.98291667]\n", 792 | "epoch: 5800\n", 793 | "[0.05 0.05 0.05 0.9825 0.98791667 0.98875 ]\n", 794 | "epoch: 5900\n", 795 | "[0.05 0.05 0.05 0.974375 0.98291667 0.98541667]\n", 796 | "epoch: 6000\n", 797 | "[0.05 0.05 0.05 0.97479167 0.98416667 0.98291667]\n", 798 | "在mean process之前: (992, 6)\n", 799 | "测试集准确率: [0.05 0.05 0.05 0.887 0.9316 0.9395]\n", 800 | "epoch: 6100\n", 801 | "[0.05 0.05 0.05 0.98 0.988125 0.98833333]\n", 802 | "epoch: 6200\n", 803 | "[0.05 0.05 0.05 0.973125 0.98125 0.98166667]\n", 804 | "epoch: 6300\n", 805 | "[0.05 0.05 0.05 0.97395833 0.979375 0.98458333]\n", 806 | "epoch: 6400\n", 807 | "[0.05 0.05 0.05 0.97791667 0.98541667 0.98791667]\n", 808 | "epoch: 6500\n", 809 | 
"[0.05 0.05 0.05 0.97791667 0.978125 0.98083333]\n", 810 | "epoch: 6600\n", 811 | "[0.05 0.05 0.05 0.98041667 0.98520833 0.98583333]\n", 812 | "epoch: 6700\n", 813 | "[0.05 0.05 0.05 0.98645833 0.98979167 0.99166667]\n", 814 | "epoch: 6800\n", 815 | "[0.05 0.05 0.05 0.97854167 0.98270833 0.98666667]\n", 816 | "epoch: 6900\n", 817 | "[0.05 0.05 0.05 0.96666667 0.97354167 0.97708333]\n", 818 | "epoch: 7000\n", 819 | "[0.05 0.05 0.05 0.98354167 0.98458333 0.98458333]\n", 820 | "在mean process之前: (992, 6)\n", 821 | "测试集准确率: [0.05 0.05 0.05 0.9087 0.939 0.944 ]\n", 822 | "epoch: 7100\n", 823 | "[0.05 0.05 0.05 0.98 0.985 0.98583333]\n", 824 | "epoch: 7200\n", 825 | "[0.05 0.05 0.05 0.97291667 0.98541667 0.98666667]\n", 826 | "epoch: 7300\n", 827 | "[0.05 0.05 0.05 0.97625 0.98229167 0.98208333]\n", 828 | "epoch: 7400\n", 829 | "[0.05 0.05 0.05 0.98145833 0.98916667 0.98875 ]\n", 830 | "epoch: 7500\n", 831 | "[0.05 0.05 0.05 0.98041667 0.983125 0.98333333]\n", 832 | "epoch: 7600\n", 833 | "[0.05 0.05 0.05 0.97416667 0.97854167 0.98458333]\n", 834 | "epoch: 7700\n", 835 | "[0.05 0.05 0.05 0.979375 0.98520833 0.985625 ]\n", 836 | "epoch: 7800\n", 837 | "[0.05 0.05 0.05 0.975625 0.98270833 0.98145833]\n", 838 | "epoch: 7900\n", 839 | "[0.05 0.05 0.05 0.98520833 0.98729167 0.98875 ]\n", 840 | "epoch: 8000\n", 841 | "[0.05 0.05 0.05 0.98104167 0.98666667 0.98625 ]\n", 842 | "在mean process之前: (992, 6)\n", 843 | "测试集准确率: [0.05 0.05 0.05 0.9126 0.9395 0.9434]\n", 844 | "epoch: 8100\n", 845 | "[0.05 0.05 0.05 0.98 0.98625 0.98729167]\n", 846 | "epoch: 8200\n", 847 | "[0.05 0.05 0.05 0.98375 0.98916667 0.98958333]\n", 848 | "epoch: 8300\n", 849 | "[0.05 0.05 0.05 0.98166667 0.98666667 0.986875 ]\n", 850 | "epoch: 8400\n", 851 | "[0.05 0.05 0.05 0.97020833 0.97604167 0.98041667]\n", 852 | "epoch: 8500\n", 853 | "[0.05 0.05 0.05 0.98375 0.98875 0.99020833]\n", 854 | "epoch: 8600\n", 855 | "[0.05 0.05 0.05 0.98166667 0.98208333 0.98583333]\n", 856 | "epoch: 8700\n", 857 | "[0.05 0.05 
0.05 0.98458333 0.985625 0.9875 ]\n", 858 | "epoch: 8800\n", 859 | "[0.05 0.05 0.05 0.98229167 0.98708333 0.98645833]\n", 860 | "epoch: 8900\n", 861 | "[0.05 0.05 0.05 0.985625 0.99125 0.99270833]\n", 862 | "epoch: 9000\n", 863 | "[0.05 0.05 0.05 0.980625 0.98625 0.98854167]\n", 864 | "在mean process之前: (992, 6)\n", 865 | "测试集准确率: [0.05 0.05 0.05 0.9062 0.9355 0.94 ]\n", 866 | "epoch: 9100\n", 867 | "[0.05 0.05 0.05 0.98708333 0.99229167 0.99083333]\n", 868 | "epoch: 9200\n", 869 | "[0.05 0.05 0.05 0.985625 0.986875 0.98979167]\n", 870 | "epoch: 9300\n", 871 | "[0.05 0.05 0.05 0.986875 0.98479167 0.98583333]\n" 872 | ] 873 | }, 874 | { 875 | "name": "stdout", 876 | "output_type": "stream", 877 | "text": [ 878 | "epoch: 9400\n", 879 | "[0.05 0.05 0.05 0.98520833 0.99145833 0.99270833]\n", 880 | "epoch: 9500\n", 881 | "[0.05 0.05 0.05 0.981875 0.98833333 0.989375 ]\n", 882 | "epoch: 9600\n", 883 | "[0.05 0.05 0.05 0.980625 0.99020833 0.99208333]\n", 884 | "epoch: 9700\n", 885 | "[0.05 0.05 0.05 0.98375 0.98708333 0.98875 ]\n", 886 | "epoch: 9800\n", 887 | "[0.05 0.05 0.05 0.985 0.99125 0.99104167]\n", 888 | "epoch: 9900\n", 889 | "[0.05 0.05 0.05 0.98270833 0.983125 0.98375 ]\n", 890 | "epoch: 10000\n", 891 | "[0.05 0.05 0.05 0.97583333 0.98104167 0.98270833]\n", 892 | "在mean process之前: (992, 6)\n", 893 | "测试集准确率: [0.05 0.05 0.05 0.9106 0.9365 0.941 ]\n", 894 | "epoch: 10100\n", 895 | "[0.05 0.05 0.05 0.97854167 0.9825 0.98125 ]\n", 896 | "epoch: 10200\n", 897 | "[0.05 0.05 0.05 0.97770833 0.98104167 0.98020833]\n", 898 | "epoch: 10300\n", 899 | "[0.05 0.05 0.05 0.981875 0.98458333 0.98604167]\n", 900 | "epoch: 10400\n", 901 | "[0.05 0.05 0.05 0.98020833 0.98479167 0.98625 ]\n", 902 | "epoch: 10500\n", 903 | "[0.05 0.05 0.05 0.98229167 0.98854167 0.99083333]\n", 904 | "epoch: 10600\n", 905 | "[0.05 0.05 0.05 0.99041667 0.99041667 0.99208333]\n", 906 | "epoch: 10700\n", 907 | "[0.05 0.05 0.05 0.97770833 0.98541667 0.98791667]\n", 908 | "epoch: 10800\n", 909 | "[0.05 
0.05 0.05 0.98875 0.99354167 0.99416667]\n", 910 | "epoch: 10900\n", 911 | "[0.05 0.05 0.05 0.99208333 0.99375 0.99416667]\n", 912 | "epoch: 11000\n", 913 | "[0.05 0.05 0.05 0.98145833 0.98645833 0.98416667]\n", 914 | "在mean process之前: (992, 6)\n", 915 | "测试集准确率: [0.05 0.05 0.05 0.914 0.9365 0.9404]\n", 916 | "epoch: 11100\n", 917 | "[0.05 0.05 0.05 0.98375 0.99041667 0.99125 ]\n", 918 | "epoch: 11200\n", 919 | "[0.05 0.05 0.05 0.97979167 0.98875 0.99020833]\n", 920 | "epoch: 11300\n", 921 | "[0.05 0.05 0.05 0.989375 0.98895833 0.99041667]\n", 922 | "epoch: 11400\n", 923 | "[0.05 0.05 0.05 0.99125 0.99291667 0.993125 ]\n", 924 | "epoch: 11500\n", 925 | "[0.05 0.05 0.05 0.99020833 0.99083333 0.990625 ]\n", 926 | "epoch: 11600\n", 927 | "[0.05 0.05 0.05 0.98020833 0.98395833 0.985625 ]\n", 928 | "epoch: 11700\n", 929 | "[0.05 0.05 0.05 0.98916667 0.99104167 0.99083333]\n", 930 | "epoch: 11800\n", 931 | "[0.05 0.05 0.05 0.98291667 0.99083333 0.99166667]\n", 932 | "epoch: 11900\n", 933 | "[0.05 0.05 0.05 0.98854167 0.99208333 0.99229167]\n", 934 | "epoch: 12000\n", 935 | "[0.05 0.05 0.05 0.98291667 0.98708333 0.98958333]\n", 936 | "在mean process之前: (992, 6)\n", 937 | "测试集准确率: [0.05 0.05 0.05 0.9106 0.9326 0.9365]\n", 938 | "epoch: 12100\n", 939 | "[0.05 0.05 0.05 0.989375 0.99104167 0.99208333]\n", 940 | "epoch: 12200\n", 941 | "[0.05 0.05 0.05 0.98958333 0.990625 0.99083333]\n", 942 | "epoch: 12300\n", 943 | "[0.05 0.05 0.05 0.98791667 0.991875 0.99208333]\n", 944 | "epoch: 12400\n", 945 | "[0.05 0.05 0.05 0.99145833 0.99625 0.99625 ]\n", 946 | "epoch: 12500\n", 947 | "[0.05 0.05 0.05 0.98395833 0.98604167 0.98729167]\n", 948 | "epoch: 12600\n", 949 | "[0.05 0.05 0.05 0.99229167 0.99354167 0.991875 ]\n", 950 | "epoch: 12700\n", 951 | "[0.05 0.05 0.05 0.99083333 0.99520833 0.99541667]\n", 952 | "epoch: 12800\n", 953 | "[0.05 0.05 0.05 0.99104167 0.99291667 0.9925 ]\n", 954 | "epoch: 12900\n", 955 | "[0.05 0.05 0.05 0.98979167 0.994375 0.99520833]\n", 956 | "epoch: 
13000\n", 957 | "[0.05 0.05 0.05 0.98875 0.9925 0.99583333]\n", 958 | "在mean process之前: (992, 6)\n", 959 | "测试集准确率: [0.05 0.05 0.05 0.912 0.932 0.9355]\n", 960 | "epoch: 13100\n", 961 | "[0.05 0.05 0.05 0.98645833 0.988125 0.98958333]\n", 962 | "epoch: 13200\n", 963 | "[0.05 0.05 0.05 0.99270833 0.99645833 0.99666667]\n", 964 | "epoch: 13300\n", 965 | "[0.05 0.05 0.05 0.98895833 0.99041667 0.99229167]\n", 966 | "epoch: 13400\n", 967 | "[0.05 0.05 0.05 0.97958333 0.98291667 0.98354167]\n", 968 | "epoch: 13500\n", 969 | "[0.05 0.05 0.05 0.98791667 0.98854167 0.98916667]\n", 970 | "epoch: 13600\n", 971 | "[0.05 0.05 0.05 0.9825 0.98583333 0.985625 ]\n", 972 | "epoch: 13700\n", 973 | "[0.05 0.05 0.05 0.97895833 0.98479167 0.985625 ]\n", 974 | "epoch: 13800\n", 975 | "[0.05 0.05 0.05 0.98833333 0.995625 0.99645833]\n", 976 | "epoch: 13900\n", 977 | "[0.05 0.05 0.05 0.98770833 0.99416667 0.995 ]\n", 978 | "epoch: 14000\n", 979 | "[0.05 0.05 0.05 0.98791667 0.988125 0.988125 ]\n", 980 | "在mean process之前: (992, 6)\n", 981 | "测试集准确率: [0.05 0.05 0.05 0.9087 0.9297 0.9326]\n", 982 | "epoch: 14100\n", 983 | "[0.05 0.05 0.05 0.99083333 0.99229167 0.99270833]\n", 984 | "epoch: 14200\n", 985 | "[0.05 0.05 0.05 0.98125 0.99270833 0.993125 ]\n", 986 | "epoch: 14300\n", 987 | "[0.05 0.05 0.05 0.98916667 0.99333333 0.99291667]\n", 988 | "epoch: 14400\n", 989 | "[0.05 0.05 0.05 0.98333333 0.98479167 0.98854167]\n", 990 | "epoch: 14500\n", 991 | "[0.05 0.05 0.05 0.983125 0.99125 0.99125 ]\n", 992 | "epoch: 14600\n", 993 | "[0.05 0.05 0.05 0.99083333 0.99229167 0.99229167]\n", 994 | "epoch: 14700\n", 995 | "[0.05 0.05 0.05 0.98125 0.98520833 0.99041667]\n", 996 | "epoch: 14800\n", 997 | "[0.05 0.05 0.05 0.985 0.98375 0.98333333]\n", 998 | "epoch: 14900\n", 999 | "[0.05 0.05 0.05 0.990625 0.99229167 0.994375 ]\n", 1000 | "epoch: 15000\n", 1001 | "[0.05 0.05 0.05 0.98541667 0.98375 0.98395833]\n", 1002 | "在mean process之前: (992, 6)\n", 1003 | "测试集准确率: [0.05 0.05002 0.05 0.914 0.931 0.9336 
]\n", 1004 | "epoch: 15100\n", 1005 | "[0.05 0.05 0.05 0.99291667 0.995625 0.99541667]\n", 1006 | "epoch: 15200\n", 1007 | "[0.05 0.05 0.05 0.98083333 0.98479167 0.98375 ]\n", 1008 | "epoch: 15300\n", 1009 | "[0.05 0.05 0.05 0.98479167 0.98541667 0.98625 ]\n", 1010 | "epoch: 15400\n", 1011 | "[0.05 0.05 0.05 0.98729167 0.989375 0.98979167]\n", 1012 | "epoch: 15500\n", 1013 | "[0.05 0.05 0.05 0.99583333 0.99645833 0.99645833]\n", 1014 | "epoch: 15600\n", 1015 | "[0.05 0.05 0.05 0.99083333 0.99291667 0.99375 ]\n", 1016 | "epoch: 15700\n", 1017 | "[0.05 0.05 0.05 0.98291667 0.9875 0.98770833]\n", 1018 | "epoch: 15800\n", 1019 | "[0.05 0.05 0.05 0.99395833 0.99416667 0.99416667]\n", 1020 | "epoch: 15900\n", 1021 | "[0.05 0.05 0.05 0.9875 0.98833333 0.98875 ]\n", 1022 | "epoch: 16000\n", 1023 | "[0.05 0.05 0.05 0.9875 0.99020833 0.990625 ]\n", 1024 | "在mean process之前: (992, 6)\n", 1025 | "测试集准确率: [0.05 0.05008 0.05 0.9116 0.9297 0.932 ]\n", 1026 | "epoch: 16100\n", 1027 | "[0.05 0.05 0.05 0.993125 0.99541667 0.995625 ]\n", 1028 | "epoch: 16200\n", 1029 | "[0.05 0.05 0.05 0.98270833 0.98791667 0.989375 ]\n", 1030 | "epoch: 16300\n", 1031 | "[0.05 0.05 0.05 0.986875 0.99083333 0.99145833]\n", 1032 | "epoch: 16400\n", 1033 | "[0.05 0.05 0.05 0.9875 0.99020833 0.99041667]\n", 1034 | "epoch: 16500\n", 1035 | "[0.05 0.05 0.05 0.99583333 0.996875 0.9975 ]\n", 1036 | "epoch: 16600\n", 1037 | "[0.05 0.05 0.05 0.98666667 0.98791667 0.988125 ]\n", 1038 | "epoch: 16700\n", 1039 | "[0.05 0.05 0.05 0.98666667 0.98958333 0.98979167]\n", 1040 | "epoch: 16800\n", 1041 | "[0.05 0.05 0.05 0.99291667 0.99458333 0.995 ]\n", 1042 | "epoch: 16900\n", 1043 | "[0.05 0.05 0.05 0.99375 0.99541667 0.99541667]\n", 1044 | "epoch: 17000\n", 1045 | "[0.05 0.05 0.05 0.990625 0.990625 0.99208333]\n", 1046 | "在mean process之前: (992, 6)\n", 1047 | "测试集准确率: [0.05 0.05002 0.05 0.9136 0.9316 0.934 ]\n", 1048 | "epoch: 17100\n", 1049 | "[0.05 0.05 0.05 0.99291667 0.99479167 0.99479167]\n", 1050 | "epoch: 
17200\n", 1051 | "[0.05 0.05 0.05 0.978125 0.98125 0.984375]\n", 1052 | "epoch: 17300\n", 1053 | "[0.05 0.05 0.05 0.99229167 0.99541667 0.99541667]\n", 1054 | "epoch: 17400\n", 1055 | "[0.05 0.05 0.05 0.99479167 0.99541667 0.99520833]\n", 1056 | "epoch: 17500\n", 1057 | "[0.05 0.05 0.05 0.98604167 0.99125 0.99125 ]\n", 1058 | "epoch: 17600\n", 1059 | "[0.05 0.05 0.05 0.98625 0.99104167 0.99208333]\n", 1060 | "epoch: 17700\n", 1061 | "[0.05 0.05 0.05 0.99375 0.99583333 0.99604167]\n", 1062 | "epoch: 17800\n", 1063 | "[0.05 0.05 0.05 0.98895833 0.99229167 0.9925 ]\n", 1064 | "epoch: 17900\n", 1065 | "[0.05 0.05 0.05 0.9875 0.99208333 0.99291667]\n", 1066 | "epoch: 18000\n", 1067 | "[0.05 0.05 0.05 0.98229167 0.98520833 0.98604167]\n", 1068 | "在mean process之前: (992, 6)\n", 1069 | "测试集准确率: [0.05 0.05 0.05 0.911 0.9272 0.9307]\n", 1070 | "epoch: 18100\n", 1071 | "[0.05 0.05 0.05 0.99291667 0.99270833 0.99208333]\n", 1072 | "epoch: 18200\n", 1073 | "[0.05 0.05 0.05 0.98833333 0.98645833 0.98729167]\n", 1074 | "epoch: 18300\n", 1075 | "[0.05 0.05 0.05 0.98916667 0.99229167 0.9925 ]\n", 1076 | "epoch: 18400\n", 1077 | "[0.05 0.05 0.05 0.99208333 0.99416667 0.994375 ]\n", 1078 | "epoch: 18500\n", 1079 | "[0.05 0.05 0.05 0.99 0.99375 0.99395833]\n", 1080 | "epoch: 18600\n", 1081 | "[0.05 0.05 0.05 0.98916667 0.99083333 0.99104167]\n" 1082 | ] 1083 | }, 1084 | { 1085 | "name": "stdout", 1086 | "output_type": "stream", 1087 | "text": [ 1088 | "epoch: 18700\n", 1089 | "[0.05 0.05 0.05 0.99145833 0.993125 0.99333333]\n", 1090 | "epoch: 18800\n", 1091 | "[0.05 0.05 0.05 0.99666667 0.99604167 0.99625 ]\n", 1092 | "epoch: 18900\n", 1093 | "[0.05 0.05 0.05 0.9925 0.995 0.99520833]\n", 1094 | "epoch: 19000\n", 1095 | "[0.05 0.05 0.05 0.98354167 0.98666667 0.98708333]\n", 1096 | "在mean process之前: (992, 6)\n", 1097 | "测试集准确率: [0.05 0.05002 0.05 0.91 0.927 0.9297 ]\n", 1098 | "epoch: 19100\n", 1099 | "[0.05 0.05 0.05 0.99104167 0.993125 0.99291667]\n", 1100 | "epoch: 19200\n", 1101 | 
"[0.05 0.05 0.05 0.985625 0.98583333 0.98604167]\n", 1102 | "epoch: 19300\n", 1103 | "[0.05 0.05 0.05 0.99041667 0.99229167 0.9925 ]\n", 1104 | "epoch: 19400\n", 1105 | "[0.05 0.05 0.05 0.988125 0.993125 0.99270833]\n", 1106 | "epoch: 19500\n", 1107 | "[0.05 0.05 0.05 0.98895833 0.990625 0.98979167]\n", 1108 | "epoch: 19600\n", 1109 | "[0.05 0.05 0.05 0.99375 0.994375 0.99395833]\n", 1110 | "epoch: 19700\n", 1111 | "[0.05 0.05 0.05 0.98979167 0.994375 0.99458333]\n", 1112 | "epoch: 19800\n", 1113 | "[0.05 0.05 0.05 0.990625 0.994375 0.99458333]\n", 1114 | "epoch: 19900\n", 1115 | "[0.05 0.05 0.05 0.99125 0.99375 0.99416667]\n", 1116 | "epoch: 20000\n", 1117 | "[0.05 0.05 0.05 0.98645833 0.98625 0.98958333]\n", 1118 | "在mean process之前: (992, 6)\n", 1119 | "测试集准确率: [0.05 0.05 0.05 0.9146 0.929 0.931 ]\n", 1120 | "epoch: 20100\n", 1121 | "[0.05 0.05 0.05 0.99541667 0.99625 0.99625 ]\n", 1122 | "epoch: 20200\n", 1123 | "[0.05 0.05 0.05 0.99 0.991875 0.9925 ]\n", 1124 | "epoch: 20300\n", 1125 | "[0.05 0.05 0.05 0.99270833 0.99395833 0.99625 ]\n", 1126 | "epoch: 20400\n", 1127 | "[0.05 0.05 0.05 0.990625 0.99229167 0.9925 ]\n", 1128 | "epoch: 20500\n", 1129 | "[0.05 0.05 0.05 0.99375 0.994375 0.994375]\n", 1130 | "epoch: 20600\n", 1131 | "[0.05 0.05 0.05 0.98895833 0.99270833 0.991875 ]\n", 1132 | "epoch: 20700\n", 1133 | "[0.05 0.05 0.05 0.990625 0.99083333 0.99145833]\n", 1134 | "epoch: 20800\n", 1135 | "[0.05 0.05 0.05 0.991875 0.99270833 0.993125 ]\n", 1136 | "epoch: 20900\n", 1137 | "[0.05 0.05 0.05 0.99041667 0.99458333 0.99479167]\n", 1138 | "epoch: 21000\n", 1139 | "[0.05 0.05 0.05 0.9875 0.991875 0.99125 ]\n", 1140 | "在mean process之前: (992, 6)\n", 1141 | "测试集准确率: [0.05 0.05 0.05 0.9097 0.9253 0.9277]\n", 1142 | "epoch: 21100\n", 1143 | "[0.05 0.05 0.05 0.99270833 0.99333333 0.99395833]\n", 1144 | "epoch: 21200\n", 1145 | "[0.05 0.05 0.05 0.98541667 0.98541667 0.98583333]\n", 1146 | "epoch: 21300\n", 1147 | "[0.05 0.05 0.05 0.9925 0.99375 0.99395833]\n", 1148 | 
"epoch: 21400\n", 1149 | "[0.05 0.05 0.05 0.98458333 0.98791667 0.98854167]\n", 1150 | "epoch: 21500\n", 1151 | "[0.05 0.05 0.05 0.99229167 0.995625 0.995625 ]\n", 1152 | "epoch: 21600\n", 1153 | "[0.05 0.05 0.05 0.985 0.98520833 0.986875 ]\n", 1154 | "epoch: 21700\n", 1155 | "[0.05 0.05 0.05 0.99520833 0.99541667 0.99625 ]\n", 1156 | "epoch: 21800\n", 1157 | "[0.05 0.05 0.05 0.989375 0.99416667 0.995 ]\n", 1158 | "epoch: 21900\n", 1159 | "[0.05 0.05 0.05 0.98916667 0.99458333 0.99479167]\n", 1160 | "epoch: 22000\n", 1161 | "[0.05 0.05 0.05 0.98291667 0.98416667 0.9875 ]\n", 1162 | "在mean process之前: (992, 6)\n", 1163 | "测试集准确率: [0.05 0.05 0.05 0.9146 0.9277 0.9297]\n", 1164 | "epoch: 22100\n", 1165 | "[0.05 0.05 0.05 0.99083333 0.99333333 0.99333333]\n", 1166 | "epoch: 22200\n", 1167 | "[0.05 0.05 0.05 0.99125 0.994375 0.99333333]\n", 1168 | "epoch: 22300\n", 1169 | "[0.05 0.05 0.05 0.98604167 0.99041667 0.98979167]\n", 1170 | "epoch: 22400\n", 1171 | "[0.05 0.05 0.05 0.99520833 0.99604167 0.99604167]\n", 1172 | "epoch: 22500\n", 1173 | "[0.05 0.05 0.05 0.99125 0.99333333 0.99354167]\n", 1174 | "epoch: 22600\n", 1175 | "[0.05 0.05 0.05 0.99479167 0.99583333 0.99583333]\n", 1176 | "epoch: 22700\n", 1177 | "[0.05 0.05 0.05 0.995 0.994375 0.99520833]\n", 1178 | "epoch: 22800\n", 1179 | "[0.05 0.05 0.05 0.98083333 0.98770833 0.98770833]\n", 1180 | "epoch: 22900\n", 1181 | "[0.05 0.05 0.05 0.99 0.990625 0.99020833]\n", 1182 | "epoch: 23000\n", 1183 | "[0.05 0.05 0.05 0.994375 0.99520833 0.995625 ]\n", 1184 | "在mean process之前: (992, 6)\n", 1185 | "测试集准确率: [0.05 0.05 0.05 0.9067 0.9233 0.926 ]\n", 1186 | "epoch: 23100\n", 1187 | "[0.05 0.05 0.05 0.995 0.99479167 0.99583333]\n", 1188 | "epoch: 23200\n", 1189 | "[0.05 0.05 0.05 0.99020833 0.995 0.99520833]\n", 1190 | "epoch: 23300\n", 1191 | "[0.05 0.05 0.05 0.991875 0.99416667 0.99479167]\n", 1192 | "epoch: 23400\n", 1193 | "[0.05 0.05 0.05 0.99125 0.99104167 0.99145833]\n", 1194 | "epoch: 23500\n", 1195 | "[0.05 0.05 0.05 
0.98875 0.99104167 0.99104167]\n", 1196 | "epoch: 23600\n", 1197 | "[0.05 0.05 0.05 0.98916667 0.98729167 0.98583333]\n", 1198 | "epoch: 23700\n", 1199 | "[0.05 0.05 0.05 0.985625 0.9875 0.98854167]\n", 1200 | "epoch: 23800\n", 1201 | "[0.05 0.05 0.05 0.99291667 0.995 0.995 ]\n", 1202 | "epoch: 23900\n", 1203 | "[0.05 0.05 0.05 0.98875 0.991875 0.99166667]\n", 1204 | "epoch: 24000\n", 1205 | "[0.05 0.05 0.05 0.99770833 0.998125 0.998125 ]\n", 1206 | "在mean process之前: (992, 6)\n", 1207 | "测试集准确率: [0.05 0.05002 0.05 0.908 0.923 0.925 ]\n", 1208 | "epoch: 24100\n", 1209 | "[0.05 0.05 0.05 0.99125 0.9925 0.99395833]\n", 1210 | "epoch: 24200\n", 1211 | "[0.05 0.05 0.05 0.990625 0.99145833 0.99083333]\n", 1212 | "epoch: 24300\n", 1213 | "[0.05 0.05 0.05 0.9925 0.99395833 0.99458333]\n", 1214 | "epoch: 24400\n", 1215 | "[0.05 0.05 0.05 0.99416667 0.99375 0.99416667]\n", 1216 | "epoch: 24500\n", 1217 | "[0.05 0.05 0.05 0.99520833 0.99666667 0.996875 ]\n", 1218 | "epoch: 24600\n", 1219 | "[0.05 0.05 0.05 0.98916667 0.98958333 0.991875 ]\n", 1220 | "epoch: 24700\n", 1221 | "[0.05 0.05 0.05 0.99583333 0.99666667 0.99833333]\n", 1222 | "epoch: 24800\n", 1223 | "[0.05 0.05 0.05 0.995 0.99541667 0.99583333]\n", 1224 | "epoch: 24900\n", 1225 | "[0.05 0.05 0.05 0.99333333 0.99541667 0.99583333]\n" 1226 | ] 1227 | }, 1228 | { 1229 | "ename": "KeyboardInterrupt", 1230 | "evalue": "", 1231 | "output_type": "error", 1232 | "traceback": [ 1233 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 1234 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 1235 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 17\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_qry\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_qry\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0maccs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmeta\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_spt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_spt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_qry\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_qry\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m100\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 1236 | "\u001b[0;32m~/anaconda3/envs/ML3.6/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 542\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 1237 | "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x_spt, y_spt, x_qry, y_qry)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mgrad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfast_weights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mtuples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mfast_weights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mfast_weights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_lr\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mk\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_step\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 1238 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m(p)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mgrad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfast_weights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mtuples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfast_weights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mfast_weights\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_lr\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mk\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_step\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 1239 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 1240 | ] 1241 | } 1242 | ], 1243 | "source": [ 1244 | "## omniglot\n", 1245 | "import random\n", 1246 | "random.seed(1337)\n", 1247 | "np.random.seed(1337)\n", 1248 | "\n", 1249 | "import time\n", 1250 | "device = torch.device('cuda:3')\n", 1251 | "\n", 1252 | "meta = MetaLearner().to(device)\n", 1253 | "\n", 1254 | "epochs = 60001\n", 1255 | "for step in range(epochs):\n", 1256 | " start = time.time()\n", 1257 | " x_spt, y_spt, x_qry, y_qry = next('train')\n", 1258 | " x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device),\\\n", 1259 | " torch.from_numpy(y_spt).to(device),\\\n", 1260 | " torch.from_numpy(x_qry).to(device),\\\n", 1261 | " torch.from_numpy(y_qry).to(device)\n", 1262 | " accs,loss = meta(x_spt, y_spt, x_qry, y_qry)\n", 1263 | " end = time.time()\n", 1264 | " if step % 100 == 0:\n", 1265 | " print(\"epoch:\" ,step)\n", 1266 | " print(accs)\n", 1267 | "# print(loss)\n", 1268 | " \n", 1269 | " if step % 1000 == 0:\n", 1270 | " accs = []\n", 1271 | " for _ in 
range(1000//task_num):\n", 1272 | " # db_train.next('test')\n", 1273 | " x_spt, y_spt, x_qry, y_qry = next('test')\n", 1274 | " x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device),\\\n", 1275 | " torch.from_numpy(y_spt).to(device),\\\n", 1276 | " torch.from_numpy(x_qry).to(device),\\\n", 1277 | " torch.from_numpy(y_qry).to(device)\n", 1278 | "\n", 1279 | " \n", 1280 | " for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):\n", 1281 | " test_acc = meta.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)\n", 1282 | " accs.append(test_acc)\n", 1283 | " print('在mean process之前:',np.array(accs).shape)\n", 1284 | " accs = np.array(accs).mean(axis=0).astype(np.float16)\n", 1285 | " print('测试集准确率:',accs)" 1286 | ] 1287 | }, 1288 | { 1289 | "cell_type": "code", 1290 | "execution_count": null, 1291 | "metadata": { 1292 | "ExecuteTime": { 1293 | "end_time": "2020-03-01T03:00:56.266331Z", 1294 | "start_time": "2020-03-01T03:00:56.205955Z" 1295 | } 1296 | }, 1297 | "outputs": [], 1298 | "source": [ 1299 | "\n" 1300 | ] 1301 | }, 1302 | { 1303 | "cell_type": "code", 1304 | "execution_count": null, 1305 | "metadata": {}, 1306 | "outputs": [], 1307 | "source": [] 1308 | }, 1309 | { 1310 | "cell_type": "code", 1311 | "execution_count": null, 1312 | "metadata": {}, 1313 | "outputs": [], 1314 | "source": [] 1315 | } 1316 | ], 1317 | "metadata": { 1318 | "kernelspec": { 1319 | "display_name": "ML3.6", 1320 | "language": "python", 1321 | "name": "ml3.6" 1322 | }, 1323 | "latex_envs": { 1324 | "LaTeX_envs_menu_present": true, 1325 | "autoclose": false, 1326 | "autocomplete": true, 1327 | "bibliofile": "biblio.bib", 1328 | "cite_by": "apalike", 1329 | "current_citInitial": 1, 1330 | "eqLabelWithNumbers": true, 1331 | "eqNumInitial": 1, 1332 | "hotkeys": { 1333 | "equation": "Ctrl-E", 1334 | "itemize": "Ctrl-I" 1335 | }, 1336 | "labels_anchors": false, 1337 | "latex_user_defs": false, 1338 | "report_style_numbering": false, 1339 | 
"user_envs_cfg": false 1340 | }, 1341 | "toc": { 1342 | "base_numbering": 1, 1343 | "nav_menu": {}, 1344 | "number_sections": true, 1345 | "sideBar": true, 1346 | "skip_h1_title": false, 1347 | "title_cell": "Table of Contents", 1348 | "title_sidebar": "Contents", 1349 | "toc_cell": false, 1350 | "toc_position": {}, 1351 | "toc_section_display": true, 1352 | "toc_window_display": false 1353 | }, 1354 | "varInspector": { 1355 | "cols": { 1356 | "lenName": 16, 1357 | "lenType": 16, 1358 | "lenVar": 40 1359 | }, 1360 | "kernels_config": { 1361 | "python": { 1362 | "delete_cmd_postfix": "", 1363 | "delete_cmd_prefix": "del ", 1364 | "library": "var_list.py", 1365 | "varRefreshCmd": "print(var_dic_list())" 1366 | }, 1367 | "r": { 1368 | "delete_cmd_postfix": ") ", 1369 | "delete_cmd_prefix": "rm(", 1370 | "library": "var_list.r", 1371 | "varRefreshCmd": "cat(var_dic_list()) " 1372 | } 1373 | }, 1374 | "types_to_exclude": [ 1375 | "module", 1376 | "function", 1377 | "builtin_function_or_method", 1378 | "instance", 1379 | "_Feature" 1380 | ], 1381 | "window_display": false 1382 | } 1383 | }, 1384 | "nbformat": 4, 1385 | "nbformat_minor": 2 1386 | } 1387 | -------------------------------------------------------------------------------- /MAML-omniglot-ADAM-5way-32batch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2020-04-11T10:06:47.039498Z", 9 | "start_time": "2020-04-11T10:06:46.712493Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import torch\n", 15 | "import numpy as np\n", 16 | "import os\n", 17 | "import zipfile\n", 18 | "\n", 19 | "# root_path = './../datasets'\n", 20 | "# processed_folder = os.path.join(root_path)\n", 21 | "\n", 22 | "# zip_ref = zipfile.ZipFile(os.path.join(root_path,'omniglot.zip'), 'r')\n", 23 | "# zip_ref.extractall(root_path)\n", 24 | "# zip_ref.close()\n", 
25 | "root_dir = './../datasets/omniglot/python'\n", 26 | "root_dir_train = os.path.join(root_dir,'images_background')" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": { 33 | "ExecuteTime": { 34 | "end_time": "2020-04-07T06:55:03.418079Z", 35 | "start_time": "2020-04-07T06:55:03.346124Z" 36 | } 37 | }, 38 | "outputs": [], 39 | "source": [] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "### 数据预处理" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "metadata": { 52 | "ExecuteTime": { 53 | "end_time": "2020-04-11T10:06:58.751129Z", 54 | "start_time": "2020-04-11T10:06:58.742270Z" 55 | }, 56 | "scrolled": true 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "# # 数据预处理\n", 61 | "# import torchvision.transforms as transforms\n", 62 | "# from PIL import Image\n", 63 | "\n", 64 | "# '''\n", 65 | "# an example of img_items:\n", 66 | "# ( '0709_17.png',\n", 67 | "# 'Alphabet_of_the_Magi/character01',\n", 68 | "# './../datasets/omniglot/python/images_background/Alphabet_of_the_Magi/character01')\n", 69 | "# '''\n", 70 | "\n", 71 | "\n", 72 | "# root_dir_train = os.path.join(root_dir, 'images_background')\n", 73 | "# root_dir_test = os.path.join(root_dir, 'images_evaluation')\n", 74 | "\n", 75 | "# def find_classes(root_dir_train):\n", 76 | "# img_items = []\n", 77 | "# for (root, dirs, files) in os.walk(root_dir_train): \n", 78 | "# for file in files:\n", 79 | "# if (file.endswith(\"png\")):\n", 80 | "# r = root.split('/')\n", 81 | "# img_items.append((file, r[-2] + \"/\" + r[-1], root))\n", 82 | "# print(\"== Found %d items \" % len(img_items))\n", 83 | "# return img_items\n", 84 | "\n", 85 | "# ## 构建一个词典{class:idx}\n", 86 | "# def index_classes(items):\n", 87 | "# class_idx = {}\n", 88 | "# count = 0\n", 89 | "# for item in items:\n", 90 | "# if item[1] not in class_idx:\n", 91 | "# class_idx[item[1]] = count\n", 92 | "# count += 1\n", 93 | "# print('== 
Found {} classes'.format(len(class_idx)))\n", 94 | "# return class_idx\n", 95 | " \n", 96 | "\n", 97 | "# img_items_train = find_classes(root_dir_train) # [(file1, label1, root1),..]\n", 98 | "# img_items_test = find_classes(root_dir_test)\n", 99 | "\n", 100 | "# class_idx_train = index_classes(img_items_train)\n", 101 | "# class_idx_test = index_classes(img_items_test)\n", 102 | "\n", 103 | "\n", 104 | "# def generate_temp(img_items,class_idx):\n", 105 | "# temp = dict()\n", 106 | "# for imgname, classes, dirs in img_items:\n", 107 | "# img = '{}/{}'.format(dirs, imgname)\n", 108 | "# label = class_idx[classes]\n", 109 | "# transform = transforms.Compose([lambda img: Image.open(img).convert('L'),\n", 110 | "# lambda img: img.resize((28,28)),\n", 111 | "# lambda img: np.reshape(img, (28,28,1)),\n", 112 | "# lambda img: np.transpose(img, [2,0,1]),\n", 113 | "# lambda img: img/255.\n", 114 | "# ])\n", 115 | "# img = transform(img)\n", 116 | "# if label in temp.keys():\n", 117 | "# temp[label].append(img)\n", 118 | "# else:\n", 119 | "# temp[label] = [img]\n", 120 | "# print('begin to generate omniglot.npy')\n", 121 | "# return temp\n", 122 | "# ## 每个字符包含20个样本\n", 123 | "\n", 124 | "# temp_train = generate_temp(img_items_train, class_idx_train)\n", 125 | "# temp_test = generate_temp(img_items_test, class_idx_test)\n", 126 | "\n", 127 | "# img_list = []\n", 128 | "# for label, imgs in temp_train.items():\n", 129 | "# img_list.append(np.array(imgs))\n", 130 | "# img_list = np.array(img_list).astype(np.float) # [[20 imgs],..., 1623 classes in total]\n", 131 | "# print('data shape:{}'.format(img_list.shape)) # (964, 20, 1, 28, 28)\n", 132 | "# np.save(os.path.join(root_dir, 'omniglot_train.npy'), img_list)\n", 133 | "# print('end.')\n", 134 | "\n", 135 | "\n", 136 | "# img_list = []\n", 137 | "# for label, imgs in temp_test.items():\n", 138 | "# img_list.append(np.array(imgs))\n", 139 | "# img_list = np.array(img_list).astype(np.float) # [[20 imgs],..., 1623 classes in 
total]\n", 140 | "# print('data shape:{}'.format(img_list.shape)) # (659, 20, 1, 28, 28)\n", 141 | "\n", 142 | "# np.save(os.path.join(root_dir, 'omniglot_test.npy'), img_list)\n", 143 | "# print('end.')" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": { 150 | "ExecuteTime": { 151 | "end_time": "2020-04-07T08:02:03.283025Z", 152 | "start_time": "2020-04-07T08:02:03.276106Z" 153 | } 154 | }, 155 | "outputs": [], 156 | "source": [] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 3, 161 | "metadata": { 162 | "ExecuteTime": { 163 | "end_time": "2020-04-11T10:07:02.119085Z", 164 | "start_time": "2020-04-11T10:07:00.028234Z" 165 | } 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "img_list_train = np.load(os.path.join(root_dir, 'omniglot_train.npy')) # (964, 20, 1, 28, 28)\n", 170 | "img_list_test = np.load(os.path.join(root_dir, 'omniglot_test.npy')) # (659, 20, 1, 28, 28)\n", 171 | "\n", 172 | "x_train = img_list_train\n", 173 | "x_test = img_list_test\n", 174 | "# num_classes = img_list.shape[0]\n", 175 | "datasets = {'train': x_train, 'test': x_test}" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 9, 195 | "metadata": { 196 | "ExecuteTime": { 197 | "end_time": "2020-04-12T07:38:15.494291Z", 198 | "start_time": "2020-04-12T07:38:12.226065Z" 199 | }, 200 | "code_folding": [] 201 | }, 202 | "outputs": [ 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "DB: train (964, 20, 1, 28, 28) test (659, 20, 1, 28, 28)\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "### 准备数据迭代器\n", 213 | "n_way = 5\n", 214 | "k_spt = 1 ## support data 的个数\n", 215 | "k_query = 15 
## query data 的个数\n", 216 | "imgsz = 28\n", 217 | "resize = imgsz\n", 218 | "task_num = 32\n", 219 | "batch_size = task_num\n", 220 | "\n", 221 | "indexes = {\"train\": 0, \"test\": 0}\n", 222 | "datasets = {\"train\": x_train, \"test\": x_test}\n", 223 | "print(\"DB: train\", x_train.shape, \"test\", x_test.shape)\n", 224 | "\n", 225 | "\n", 226 | "def load_data_cache(dataset):\n", 227 | " \"\"\"\n", 228 | " Collects several batches data for N-shot learning\n", 229 | " :param dataset: [cls_num, 20, 84, 84, 1]\n", 230 | " :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks\n", 231 | " \"\"\"\n", 232 | " # take 5 way 1 shot as example: 5 * 1\n", 233 | " setsz = k_spt * n_way\n", 234 | " querysz = k_query * n_way\n", 235 | " data_cache = []\n", 236 | "\n", 237 | " # print('preload next 10 caches of batch_size of batch.')\n", 238 | " for sample in range(50): # num of epochs\n", 239 | "\n", 240 | " x_spts, y_spts, x_qrys, y_qrys = [], [], [], []\n", 241 | " for i in range(batch_size): # one batch means one set\n", 242 | "\n", 243 | " x_spt, y_spt, x_qry, y_qry = [], [], [], []\n", 244 | " selected_cls = np.random.choice(dataset.shape[0], n_way, replace = False) \n", 245 | "\n", 246 | " for j, cur_class in enumerate(selected_cls):\n", 247 | "\n", 248 | " selected_img = np.random.choice(20, k_spt + k_query, replace = False)\n", 249 | "\n", 250 | " # 构造support集和query集\n", 251 | " x_spt.append(dataset[cur_class][selected_img[:k_spt]])\n", 252 | " x_qry.append(dataset[cur_class][selected_img[k_spt:]])\n", 253 | " y_spt.append([j for _ in range(k_spt)])\n", 254 | " y_qry.append([j for _ in range(k_query)])\n", 255 | "\n", 256 | " # shuffle inside a batch\n", 257 | " perm = np.random.permutation(n_way * k_spt)\n", 258 | " x_spt = np.array(x_spt).reshape(n_way * k_spt, 1, resize, resize)[perm]\n", 259 | " y_spt = np.array(y_spt).reshape(n_way * k_spt)[perm]\n", 260 | " perm = np.random.permutation(n_way * k_query)\n", 261 | " 
x_qry = np.array(x_qry).reshape(n_way * k_query, 1, resize, resize)[perm]\n", 262 | " y_qry = np.array(y_qry).reshape(n_way * k_query)[perm]\n", 263 | " \n", 264 | " # append [sptsz, 1, 84, 84] => [batch_size, setsz, 1, 84, 84]\n", 265 | " x_spts.append(x_spt)\n", 266 | " y_spts.append(y_spt)\n", 267 | " x_qrys.append(x_qry)\n", 268 | " y_qrys.append(y_qry)\n", 269 | "\n", 270 | "# print(x_spts[0].shape)\n", 271 | " # [b, setsz = n_way * k_spt, 1, 84, 84]\n", 272 | " x_spts = np.array(x_spts).astype(np.float32).reshape(batch_size, setsz, 1, resize, resize)\n", 273 | " y_spts = np.array(y_spts).astype(np.int).reshape(batch_size, setsz)\n", 274 | " # [b, qrysz = n_way * k_query, 1, 84, 84]\n", 275 | " x_qrys = np.array(x_qrys).astype(np.float32).reshape(batch_size, querysz, 1, resize, resize)\n", 276 | " y_qrys = np.array(y_qrys).astype(np.int).reshape(batch_size, querysz)\n", 277 | "# print(x_qrys.shape)\n", 278 | " data_cache.append([x_spts, y_spts, x_qrys, y_qrys])\n", 279 | "\n", 280 | " return data_cache\n", 281 | "\n", 282 | "datasets_cache = {\"train\": load_data_cache(x_train), # current epoch data cached\n", 283 | " \"test\": load_data_cache(x_test)}\n", 284 | "\n", 285 | "def next(mode='train'):\n", 286 | " \"\"\"\n", 287 | " Gets next batch from the dataset with name.\n", 288 | " :param mode: The name of the splitting (one of \"train\", \"val\", \"test\")\n", 289 | " :return:\n", 290 | " \"\"\"\n", 291 | " # update cache if indexes is larger than len(data_cache)\n", 292 | " if indexes[mode] >= len(datasets_cache[mode]):\n", 293 | " indexes[mode] = 0\n", 294 | " datasets_cache[mode] = load_data_cache(datasets[mode])\n", 295 | "\n", 296 | " next_batch = datasets_cache[mode][indexes[mode]]\n", 297 | " indexes[mode] += 1\n", 298 | "\n", 299 | " return next_batch\n" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [] 308 | }, 309 | { 310 | "cell_type": "code", 311 | 
"execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 14, 326 | "metadata": { 327 | "ExecuteTime": { 328 | "end_time": "2020-04-13T04:23:17.901802Z", 329 | "start_time": "2020-04-13T04:23:17.858709Z" 330 | }, 331 | "code_folding": [] 332 | }, 333 | "outputs": [], 334 | "source": [ 335 | "import torch\n", 336 | "from torch import nn\n", 337 | "from torch.nn import functional as F\n", 338 | "from copy import deepcopy,copy\n", 339 | " \n", 340 | "\n", 341 | "class BaseNet(nn.Module):\n", 342 | " def __init__(self):\n", 343 | " super(BaseNet, self).__init__()\n", 344 | " self.vars = nn.ParameterList() ## 包含了所有需要被优化的tensor\n", 345 | " self.vars_bn = nn.ParameterList()\n", 346 | " \n", 347 | " # 第1个conv2d\n", 348 | " weight = nn.Parameter(torch.ones(64, 1, 3, 3))\n", 349 | " nn.init.kaiming_normal_(weight)\n", 350 | " bias = nn.Parameter(torch.zeros(64))\n", 351 | " self.vars.extend([weight,bias])\n", 352 | " \n", 353 | " # 第1个BatchNorm层\n", 354 | " weight = nn.Parameter(torch.ones(64))\n", 355 | " bias = nn.Parameter(torch.zeros(64))\n", 356 | " self.vars.extend([weight,bias])\n", 357 | " \n", 358 | " running_mean = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 359 | " running_var = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 360 | " self.vars_bn.extend([running_mean, running_var])\n", 361 | " \n", 362 | " # 第2个conv2d\n", 363 | " weight = nn.Parameter(torch.ones(64, 64, 3, 3))\n", 364 | " nn.init.kaiming_normal_(weight)\n", 365 | " bias = nn.Parameter(torch.zeros(64))\n", 366 | " self.vars.extend([weight,bias])\n", 367 | " \n", 368 | " # 第2个BatchNorm层\n", 369 | " weight = nn.Parameter(torch.ones(64))\n", 370 | " bias = nn.Parameter(torch.zeros(64))\n", 371 | " self.vars.extend([weight,bias])\n", 372 | " \n", 373 | 
" running_mean = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 374 | " running_var = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 375 | " self.vars_bn.extend([running_mean, running_var])\n", 376 | " \n", 377 | " # 第3个conv2d\n", 378 | " weight = nn.Parameter(torch.ones(64, 64, 3, 3))\n", 379 | " nn.init.kaiming_normal_(weight)\n", 380 | " bias = nn.Parameter(torch.zeros(64))\n", 381 | " self.vars.extend([weight,bias])\n", 382 | " \n", 383 | " # 第3个BatchNorm层\n", 384 | " weight = nn.Parameter(torch.ones(64))\n", 385 | " bias = nn.Parameter(torch.zeros(64))\n", 386 | " self.vars.extend([weight,bias])\n", 387 | " \n", 388 | " running_mean = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 389 | " running_var = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 390 | " self.vars_bn.extend([running_mean, running_var])\n", 391 | " \n", 392 | " # 第4个conv2d\n", 393 | " weight = nn.Parameter(torch.ones(64, 64, 3, 3))\n", 394 | " nn.init.kaiming_normal_(weight)\n", 395 | " bias = nn.Parameter(torch.zeros(64))\n", 396 | " self.vars.extend([weight,bias])\n", 397 | " \n", 398 | " # 第4个BatchNorm层\n", 399 | " weight = nn.Parameter(torch.ones(64))\n", 400 | " bias = nn.Parameter(torch.zeros(64))\n", 401 | " self.vars.extend([weight,bias])\n", 402 | " \n", 403 | " running_mean = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 404 | " running_var = nn.Parameter(torch.zeros(64), requires_grad= False)\n", 405 | " self.vars_bn.extend([running_mean, running_var])\n", 406 | " \n", 407 | " ##linear\n", 408 | " weight = nn.Parameter(torch.ones([5,64]))\n", 409 | " bias = nn.Parameter(torch.zeros(5))\n", 410 | " self.vars.extend([weight,bias])\n", 411 | " \n", 412 | " def forward(self, x, params = None, bn_training=True):\n", 413 | " '''\n", 414 | " :bn_training: set False to not update\n", 415 | " :return: \n", 416 | " '''\n", 417 | " if params is None:\n", 418 | " params = self.vars\n", 419 | " \n", 420 | " weight, bias = params[0], params[1] # 
第1个CONV层\n", 421 | " x = F.conv2d(x, weight, bias, stride = 1, padding = 1)\n", 422 | " weight, bias = params[2], params[3] # 第1个BN层\n", 423 | " running_mean, running_var = self.vars_bn[0], self.vars_bn[1]\n", 424 | " x = F.batch_norm(x, running_mean, running_var, weight=weight,bias =bias, training= bn_training)\n", 425 | " x = F.relu(x, inplace = [True]) #第1个relu\n", 426 | " x = F.max_pool2d(x,kernel_size=2) #第1个MAX_POOL层 \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " weight, bias = params[4], params[5] # 第2个CONV层\n", 431 | " x = F.conv2d(x, weight, bias, stride = 1, padding = 1)\n", 432 | " weight, bias = params[6], params[7] # 第2个BN层\n", 433 | " running_mean, running_var = self.vars_bn[2], self.vars_bn[3]\n", 434 | " x = F.batch_norm(x, running_mean, running_var, weight=weight,bias =bias, training= bn_training)\n", 435 | " x = F.relu(x, inplace = [True]) #第2个relu\n", 436 | " x = F.max_pool2d(x,kernel_size=2) #第2个MAX_POOL层 \n", 437 | " \n", 438 | " \n", 439 | " weight, bias = params[8], params[9] # 第3个CONV层\n", 440 | " x = F.conv2d(x, weight, bias, stride = 1, padding = 1)\n", 441 | " weight, bias = params[10], params[11] # 第3个BN层\n", 442 | " running_mean, running_var = self.vars_bn[4], self.vars_bn[5]\n", 443 | " x = F.batch_norm(x, running_mean, running_var, weight=weight,bias =bias, training= bn_training)\n", 444 | " x = F.relu(x, inplace = [True]) #第3个relu\n", 445 | " x = F.max_pool2d(x,kernel_size=2) #第3个MAX_POOL层\n", 446 | " \n", 447 | " \n", 448 | " weight, bias = params[12], params[13] # 第4个CONV层\n", 449 | " x = F.conv2d(x, weight, bias, stride = 1, padding = 1)\n", 450 | " weight, bias = params[14], params[15] # 第4个BN层\n", 451 | " running_mean, running_var = self.vars_bn[6], self.vars_bn[7]\n", 452 | " x = F.batch_norm(x, running_mean, running_var, weight=weight,bias =bias, training= bn_training)\n", 453 | " x = F.relu(x, inplace = [True]) #第4个relu\n", 454 | " x = F.max_pool2d(x,kernel_size=2) #第4个MAX_POOL层\n", 455 | " \n", 456 | " \n", 457 | " \n", 
458 | " x = x.view(x.size(0), -1) ## flatten\n", 459 | " weight, bias = params[-2], params[-1] # linear\n", 460 | " x = F.linear(x, weight, bias)\n", 461 | " \n", 462 | " output = x\n", 463 | " \n", 464 | " return output\n", 465 | " \n", 466 | " \n", 467 | " def parameters(self):\n", 468 | " \n", 469 | " return self.vars\n" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": { 476 | "ExecuteTime": { 477 | "end_time": "2020-02-29T12:00:30.197710Z", 478 | "start_time": "2020-02-29T12:00:30.186076Z" 479 | } 480 | }, 481 | "outputs": [], 482 | "source": [] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "metadata": { 488 | "ExecuteTime": { 489 | "end_time": "2020-02-29T05:41:40.773998Z", 490 | "start_time": "2020-02-29T05:41:40.762077Z" 491 | } 492 | }, 493 | "outputs": [], 494 | "source": [] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 15, 499 | "metadata": { 500 | "ExecuteTime": { 501 | "end_time": "2020-04-13T04:23:19.569591Z", 502 | "start_time": "2020-04-13T04:23:19.546142Z" 503 | }, 504 | "code_folding": [] 505 | }, 506 | "outputs": [], 507 | "source": [ 508 | "class MetaLearner(nn.Module):\n", 509 | " def __init__(self):\n", 510 | " super(MetaLearner, self).__init__()\n", 511 | " self.update_step = 5 ## task-level inner update steps\n", 512 | " self.update_step_test = 5\n", 513 | " self.net = BaseNet()\n", 514 | " self.meta_lr = 0.001\n", 515 | " self.base_lr = 0.1\n", 516 | " self.meta_optim = torch.optim.Adam(self.net.parameters(), lr = self.meta_lr)\n", 517 | "# self.meta_optim = torch.optim.SGD(self.net.parameters(), lr = self.meta_lr, momentum = 0.9, weight_decay=0.0005)\n", 518 | " \n", 519 | " def forward(self,x_spt, y_spt, x_qry, y_qry):\n", 520 | " # 初始化\n", 521 | " task_num, ways, shots, h, w = x_spt.size()\n", 522 | " query_size = x_qry.size(1) # 75 = 15 * 5\n", 523 | " loss_list_qry = [0 for _ in range(self.update_step + 1)]\n", 524 | " 
correct_list = [0 for _ in range(self.update_step + 1)]\n", 525 | " \n", 526 | " for i in range(task_num):\n", 527 | " ## 第0步更新\n", 528 | " y_hat = self.net(x_spt[i], params = None, bn_training=True) # (ways * shots, ways)\n", 529 | " loss = F.cross_entropy(y_hat, y_spt[i]) \n", 530 | " grad = torch.autograd.grad(loss, self.net.parameters())\n", 531 | " tuples = zip(grad, self.net.parameters()) ## 将梯度和参数\\theta一一对应起来\n", 532 | " # fast_weights这一步相当于求了一个\\theta - \\alpha*\\nabla(L)\n", 533 | " fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))\n", 534 | " # 在query集上测试,计算准确率\n", 535 | " # 这一步使用更新前的数据\n", 536 | " with torch.no_grad():\n", 537 | " y_hat = self.net(x_qry[i], self.net.parameters(), bn_training = True)\n", 538 | " loss_qry = F.cross_entropy(y_hat, y_qry[i])\n", 539 | " loss_list_qry[0] += loss_qry\n", 540 | " pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)\n", 541 | " correct = torch.eq(pred_qry, y_qry[i]).sum().item()\n", 542 | " correct_list[0] += correct\n", 543 | " \n", 544 | " # 使用更新后的数据在query集上测试。\n", 545 | " with torch.no_grad():\n", 546 | " y_hat = self.net(x_qry[i], fast_weights, bn_training = True)\n", 547 | " loss_qry = F.cross_entropy(y_hat, y_qry[i])\n", 548 | " loss_list_qry[1] += loss_qry\n", 549 | " pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)\n", 550 | " correct = torch.eq(pred_qry, y_qry[i]).sum().item()\n", 551 | " correct_list[1] += correct \n", 552 | " \n", 553 | " for k in range(1, self.update_step):\n", 554 | " \n", 555 | " y_hat = self.net(x_spt[i], params = fast_weights, bn_training=True)\n", 556 | " loss = F.cross_entropy(y_hat, y_spt[i])\n", 557 | " grad = torch.autograd.grad(loss, fast_weights)\n", 558 | " tuples = zip(grad, fast_weights) \n", 559 | " fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))\n", 560 | " \n", 561 | " if k < self.update_step - 1:\n", 562 | " with torch.no_grad():\n", 563 | " y_hat = self.net(x_qry[i], params = fast_weights, 
bn_training = True)\n", 564 | " loss_qry = F.cross_entropy(y_hat, y_qry[i])\n", 565 | " loss_list_qry[k+1] += loss_qry\n", 566 | " else:\n", 567 | " y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)\n", 568 | " loss_qry = F.cross_entropy(y_hat, y_qry[i])\n", 569 | " loss_list_qry[k+1] += loss_qry\n", 570 | " \n", 571 | " with torch.no_grad():\n", 572 | " pred_qry = F.softmax(y_hat,dim=1).argmax(dim=1)\n", 573 | " correct = torch.eq(pred_qry, y_qry[i]).sum().item()\n", 574 | " correct_list[k+1] += correct\n", 575 | "# print('hello')\n", 576 | " \n", 577 | " loss_qry = loss_list_qry[-1] / task_num\n", 578 | " self.meta_optim.zero_grad() # 梯度清零\n", 579 | " loss_qry.backward()\n", 580 | " self.meta_optim.step()\n", 581 | " \n", 582 | " accs = np.array(correct_list) / (query_size * task_num)\n", 583 | " loss = np.array(loss_list_qry) / ( task_num)\n", 584 | " return accs,loss\n", 585 | "\n", 586 | " \n", 587 | " \n", 588 | " def finetunning(self, x_spt, y_spt, x_qry, y_qry):\n", 589 | " assert len(x_spt.shape) == 4\n", 590 | " \n", 591 | " query_size = x_qry.size(0)\n", 592 | " correct_list = [0 for _ in range(self.update_step_test + 1)]\n", 593 | " \n", 594 | " new_net = deepcopy(self.net)\n", 595 | " y_hat = new_net(x_spt)\n", 596 | " loss = F.cross_entropy(y_hat, y_spt)\n", 597 | " grad = torch.autograd.grad(loss, new_net.parameters())\n", 598 | " fast_weights = list(map(lambda p:p[1] - self.base_lr * p[0], zip(grad, new_net.parameters())))\n", 599 | " \n", 600 | " # 在query集上测试,计算准确率\n", 601 | " # 这一步使用更新前的数据\n", 602 | " with torch.no_grad():\n", 603 | " y_hat = new_net(x_qry, params = new_net.parameters(), bn_training = True)\n", 604 | " pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)\n", 605 | " correct = torch.eq(pred_qry, y_qry).sum().item()\n", 606 | " correct_list[0] += correct\n", 607 | "\n", 608 | " # 使用更新后的数据在query集上测试。\n", 609 | " with torch.no_grad():\n", 610 | " y_hat = new_net(x_qry, params = fast_weights, bn_training = 
True)\n", 611 | " pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)\n", 612 | " correct = torch.eq(pred_qry, y_qry).sum().item()\n", 613 | " correct_list[1] += correct\n", 614 | "\n", 615 | " for k in range(1, self.update_step_test):\n", 616 | " y_hat = new_net(x_spt, params = fast_weights, bn_training=True)\n", 617 | " loss = F.cross_entropy(y_hat, y_spt)\n", 618 | " grad = torch.autograd.grad(loss, fast_weights)\n", 619 | " fast_weights = list(map(lambda p:p[1] - self.base_lr * p[0], zip(grad, fast_weights)))\n", 620 | " \n", 621 | " y_hat = new_net(x_qry, fast_weights, bn_training=True)\n", 622 | " \n", 623 | " with torch.no_grad():\n", 624 | " pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)\n", 625 | " correct = torch.eq(pred_qry, y_qry).sum().item()\n", 626 | " correct_list[k+1] += correct\n", 627 | " \n", 628 | " del new_net\n", 629 | " accs = np.array(correct_list) / query_size\n", 630 | " return accs\n", 631 | " " 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 16, 637 | "metadata": { 638 | "ExecuteTime": { 639 | "end_time": "2020-04-13T04:23:20.160756Z", 640 | "start_time": "2020-04-13T04:23:20.158324Z" 641 | } 642 | }, 643 | "outputs": [], 644 | "source": [ 645 | "# net = torch.load('./trained_models/MTL-5000epochs.pt')" 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": 17, 651 | "metadata": { 652 | "ExecuteTime": { 653 | "end_time": "2020-04-13T22:32:22.135112Z", 654 | "start_time": "2020-04-13T04:23:20.608802Z" 655 | }, 656 | "scrolled": true 657 | }, 658 | "outputs": [ 659 | { 660 | "name": "stdout", 661 | "output_type": "stream", 662 | "text": [ 663 | "epoch: 0\n", 664 | "[0.2 0.2875 0.56458333 0.57375 0.61833333 0.6325 ]\n", 665 | "在mean process之前: (992, 6)\n", 666 | "测试集准确率: [0.2 0.3408 0.586 0.5796 0.6177 0.633 ]\n", 667 | "epoch: 100\n", 668 | "[0.2 0.2 0.57625 0.83416667 0.89708333 0.91375 ]\n", 669 | "epoch: 200\n", 670 | "[0.2 0.20541667 0.2 0.75333333 0.8925 0.91875 ]\n", 
671 | "epoch: 300\n", 672 | "[0.2 0.32083333 0.57875 0.88125 0.95708333 0.96916667]\n", 673 | "epoch: 400\n", 674 | "[0.2 0.22875 0.4425 0.86083333 0.93333333 0.95291667]\n", 675 | "epoch: 500\n", 676 | "[0.2 0.32833333 0.65916667 0.91708333 0.95291667 0.96666667]\n", 677 | "epoch: 600\n", 678 | "[0.2 0.37875 0.74375 0.945 0.98 0.98166667]\n", 679 | "epoch: 700\n", 680 | "[0.2 0.27666667 0.50458333 0.95208333 0.97 0.97333333]\n", 681 | "epoch: 800\n", 682 | "[0.2 0.22208333 0.53 0.91458333 0.95041667 0.95875 ]\n", 683 | "epoch: 900\n", 684 | "[0.2 0.2 0.885 0.96166667 0.97291667 0.97583333]\n", 685 | "epoch: 1000\n", 686 | "[0.2 0.2 0.90291667 0.94083333 0.9575 0.96666667]\n", 687 | "在mean process之前: (992, 6)\n", 688 | "测试集准确率: [0.2 0.2 0.8193 0.8965 0.922 0.932 ]\n", 689 | "epoch: 1100\n", 690 | "[0.2 0.2 0.94 0.97708333 0.98166667 0.9825 ]\n", 691 | "epoch: 1200\n", 692 | "[0.2 0.2 0.90666667 0.94625 0.96041667 0.965 ]\n", 693 | "epoch: 1300\n", 694 | "[0.2 0.2 0.91625 0.94208333 0.95791667 0.9625 ]\n", 695 | "epoch: 1400\n", 696 | "[0.2 0.2 0.92958333 0.95333333 0.96541667 0.97041667]\n", 697 | "epoch: 1500\n", 698 | "[0.2 0.2 0.96083333 0.97916667 0.9825 0.98333333]\n", 699 | "epoch: 1600\n", 700 | "[0.2 0.2 0.96541667 0.96916667 0.97375 0.97541667]\n", 701 | "epoch: 1700\n", 702 | "[0.2 0.2 0.935 0.955 0.95916667 0.95875 ]\n", 703 | "epoch: 1800\n", 704 | "[0.2 0.2 0.93458333 0.95583333 0.96625 0.96958333]\n", 705 | "epoch: 1900\n", 706 | "[0.2 0.2 0.94541667 0.95166667 0.965 0.96583333]\n", 707 | "epoch: 2000\n", 708 | "[0.2 0.2 0.94291667 0.97 0.97541667 0.97708333]\n", 709 | "在mean process之前: (992, 6)\n", 710 | "测试集准确率: [0.2 0.2 0.8804 0.9155 0.924 0.9272]\n", 711 | "epoch: 2100\n", 712 | "[0.2 0.2 0.97166667 0.97958333 0.98 0.98083333]\n", 713 | "epoch: 2200\n", 714 | "[0.2 0.2 0.94041667 0.97 0.97416667 0.975 ]\n", 715 | "epoch: 2300\n", 716 | "[0.2 0.2 0.94625 0.97041667 0.97333333 0.97458333]\n", 717 | "epoch: 2400\n", 718 | "[0.2 0.2 0.96416667 
0.98041667 0.98208333 0.98291667]\n", 719 | "epoch: 2500\n", 720 | "[0.2 0.2 0.96333333 0.96916667 0.96916667 0.97 ]\n", 721 | "epoch: 2600\n", 722 | "[0.2 0.20916667 0.95291667 0.97291667 0.97208333 0.975 ]\n", 723 | "epoch: 2700\n", 724 | "[0.2 0.50791667 0.955 0.97708333 0.9775 0.97875 ]\n", 725 | "epoch: 2800\n", 726 | "[0.2 0.63791667 0.96458333 0.97 0.97166667 0.9725 ]\n", 727 | "epoch: 2900\n", 728 | "[0.2 0.67375 0.9725 0.98041667 0.98291667 0.98416667]\n", 729 | "epoch: 3000\n", 730 | "[0.2 0.67458333 0.95 0.9675 0.96875 0.96916667]\n", 731 | "在mean process之前: (992, 6)\n", 732 | "测试集准确率: [0.2 0.599 0.9087 0.9253 0.9287 0.9307]\n", 733 | "epoch: 3100\n", 734 | "[0.2 0.72791667 0.97 0.9825 0.98291667 0.98416667]\n", 735 | "epoch: 3200\n", 736 | "[0.2 0.82291667 0.97083333 0.97166667 0.97208333 0.9725 ]\n", 737 | "epoch: 3300\n", 738 | "[0.2 0.91208333 0.95875 0.96708333 0.96916667 0.97 ]\n", 739 | "epoch: 3400\n", 740 | "[0.2 0.96583333 0.97125 0.97666667 0.9775 0.97791667]\n", 741 | "epoch: 3500\n", 742 | "[0.2 0.9525 0.96666667 0.97458333 0.975 0.97625 ]\n", 743 | "epoch: 3600\n", 744 | "[0.2 0.97541667 0.9725 0.97958333 0.98 0.98041667]\n", 745 | "epoch: 3700\n", 746 | "[0.2 0.96541667 0.96166667 0.965 0.96458333 0.965 ]\n", 747 | "epoch: 3800\n", 748 | "[0.2 0.995 0.99541667 0.99541667 0.99541667 0.99583333]\n", 749 | "epoch: 3900\n", 750 | "[0.2 0.97791667 0.98083333 0.98041667 0.98166667 0.98291667]\n", 751 | "epoch: 4000\n", 752 | "[0.2 0.9825 0.98583333 0.98583333 0.98625 0.98666667]\n", 753 | "在mean process之前: (992, 6)\n", 754 | "测试集准确率: [0.2 0.9395 0.945 0.9473 0.9487 0.949 ]\n", 755 | "epoch: 4100\n", 756 | "[0.2 0.98333333 0.98583333 0.98916667 0.98875 0.98916667]\n", 757 | "epoch: 4200\n", 758 | "[0.2 0.9675 0.97291667 0.97625 0.9775 0.97875 ]\n", 759 | "epoch: 4300\n", 760 | "[0.2 0.98708333 0.98958333 0.99041667 0.98958333 0.98958333]\n", 761 | "epoch: 4400\n", 762 | "[0.2 0.975 0.97625 0.98333333 0.98333333 0.98458333]\n", 763 | "epoch: 
4500\n", 764 | "[0.2 0.98083333 0.98375 0.98375 0.98458333 0.985 ]\n", 765 | "epoch: 4600\n", 766 | "[0.2 0.98125 0.9825 0.9825 0.98166667 0.98125 ]\n", 767 | "epoch: 4700\n", 768 | "[0.2 0.975 0.9775 0.97916667 0.97958333 0.97916667]\n", 769 | "epoch: 4800\n", 770 | "[0.2 0.97291667 0.97916667 0.97958333 0.98041667 0.98041667]\n", 771 | "epoch: 4900\n", 772 | "[0.2 0.99291667 0.9925 0.9925 0.99291667 0.99291667]\n", 773 | "epoch: 5000\n", 774 | "[0.2 0.97458333 0.9775 0.97875 0.98 0.98125 ]\n", 775 | "在mean process之前: (992, 6)\n", 776 | "测试集准确率: [0.2 0.9453 0.951 0.953 0.954 0.9546]\n", 777 | "epoch: 5100\n", 778 | "[0.2 0.98416667 0.98291667 0.98583333 0.98833333 0.9875 ]\n", 779 | "epoch: 5200\n", 780 | "[0.2 0.9875 0.98791667 0.99083333 0.99083333 0.99125 ]\n", 781 | "epoch: 5300\n", 782 | "[0.2 0.98625 0.98208333 0.98208333 0.98291667 0.98291667]\n", 783 | "epoch: 5400\n", 784 | "[0.2 0.98416667 0.985 0.98625 0.98708333 0.9875 ]\n", 785 | "epoch: 5500\n", 786 | "[0.2 0.98208333 0.98708333 0.9875 0.9875 0.98791667]\n", 787 | "epoch: 5600\n", 788 | "[0.2 0.99291667 0.99375 0.99375 0.99375 0.99375 ]\n", 789 | "epoch: 5700\n", 790 | "[0.2 0.995 0.99458333 0.99458333 0.99458333 0.99458333]\n", 791 | "epoch: 5800\n", 792 | "[0.2 0.99125 0.99375 0.99416667 0.99416667 0.99416667]\n", 793 | "epoch: 5900\n", 794 | "[0.2 0.9875 0.98666667 0.9875 0.98833333 0.98875 ]\n", 795 | "epoch: 6000\n", 796 | "[0.2 0.98333333 0.98875 0.98875 0.98875 0.98833333]\n", 797 | "在mean process之前: (992, 6)\n", 798 | "测试集准确率: [0.2 0.9565 0.959 0.9595 0.96 0.96 ]\n", 799 | "epoch: 6100\n", 800 | "[0.2 0.98375 0.985 0.985 0.985 0.98583333]\n", 801 | "epoch: 6200\n", 802 | "[0.2 0.98791667 0.98791667 0.98791667 0.98791667 0.98833333]\n", 803 | "epoch: 6300\n", 804 | "[0.2 0.98916667 0.99 0.99 0.99 0.99 ]\n", 805 | "epoch: 6400\n", 806 | "[0.2 0.97958333 0.98583333 0.98666667 0.98666667 0.98708333]\n", 807 | "epoch: 6500\n", 808 | "[0.2 0.98791667 0.98458333 0.99083333 0.99083333 0.99125 ]\n", 
809 | "epoch: 6600\n", 810 | "[0.2 0.9925 0.99291667 0.99333333 0.99333333 0.99333333]\n", 811 | "epoch: 6700\n", 812 | "[0.2 0.98666667 0.98791667 0.99041667 0.99041667 0.99083333]\n", 813 | "epoch: 6800\n", 814 | "[0.2 0.98541667 0.98583333 0.98541667 0.98541667 0.98541667]\n", 815 | "epoch: 6900\n", 816 | "[0.2 0.98958333 0.99333333 0.99333333 0.99333333 0.99333333]\n", 817 | "epoch: 7000\n", 818 | "[0.2 0.99333333 0.99458333 0.99458333 0.995 0.99541667]\n", 819 | "在mean process之前: (992, 6)\n", 820 | "测试集准确率: [0.2 0.9624 0.9634 0.964 0.964 0.9644]\n", 821 | "epoch: 7100\n", 822 | "[0.2 0.98458333 0.985 0.98625 0.98625 0.98625 ]\n", 823 | "epoch: 7200\n", 824 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 825 | "epoch: 7300\n", 826 | "[0.2 0.97833333 0.97583333 0.98041667 0.98041667 0.98041667]\n", 827 | "epoch: 7400\n", 828 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 829 | "epoch: 7500\n", 830 | "[0.2 0.99333333 0.99333333 0.99333333 0.99333333 0.99333333]\n", 831 | "epoch: 7600\n", 832 | "[0.2 0.98291667 0.9825 0.98333333 0.98333333 0.98333333]\n", 833 | "epoch: 7700\n", 834 | "[0.2 0.99125 0.99125 0.99125 0.99125 0.99166667]\n", 835 | "epoch: 7800\n", 836 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99416667]\n", 837 | "epoch: 7900\n", 838 | "[0.2 0.99041667 0.99166667 0.99166667 0.99166667 0.99208333]\n", 839 | "epoch: 8000\n", 840 | "[0.2 0.98541667 0.98583333 0.98666667 0.98666667 0.98708333]\n", 841 | "在mean process之前: (992, 6)\n", 842 | "测试集准确率: [0.2 0.9614 0.962 0.9624 0.963 0.963 ]\n", 843 | "epoch: 8100\n", 844 | "[0.2 0.99166667 0.99166667 0.99166667 0.99166667 0.99166667]\n", 845 | "epoch: 8200\n", 846 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 847 | "epoch: 8300\n", 848 | "[0.2 0.99375 0.99458333 0.99458333 0.99458333 0.995 ]\n", 849 | "epoch: 8400\n", 850 | "[0.2 0.98541667 0.9875 0.98791667 0.98833333 0.98875 ]\n", 851 | "epoch: 8500\n", 852 | "[0.2 0.99625 0.99625 0.99625 0.99625 
0.99625]\n", 853 | "epoch: 8600\n", 854 | "[0.2 0.99458333 0.995 0.99583333 0.99583333 0.99583333]\n", 855 | "epoch: 8700\n", 856 | "[0.2 0.98791667 0.98666667 0.98625 0.98666667 0.98666667]\n", 857 | "epoch: 8800\n", 858 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 859 | "epoch: 8900\n", 860 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99458333]\n", 861 | "epoch: 9000\n", 862 | "[0.2 0.99 0.99041667 0.99083333 0.99083333 0.99083333]\n", 863 | "在mean process之前: (992, 6)\n", 864 | "测试集准确率: [0.2 0.966 0.9663 0.967 0.967 0.967 ]\n", 865 | "epoch: 9100\n", 866 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 867 | "epoch: 9200\n", 868 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 869 | "epoch: 9300\n", 870 | "[0.2 0.995 0.99541667 0.99541667 0.99541667 0.99541667]\n" 871 | ] 872 | }, 873 | { 874 | "name": "stdout", 875 | "output_type": "stream", 876 | "text": [ 877 | "epoch: 9400\n", 878 | "[0.2 0.99291667 0.99291667 0.99291667 0.99291667 0.99333333]\n", 879 | "epoch: 9500\n", 880 | "[0.2 0.99583333 0.99625 0.99625 0.99625 0.99666667]\n", 881 | "epoch: 9600\n", 882 | "[0.2 0.99125 0.99166667 0.99166667 0.99166667 0.99166667]\n", 883 | "epoch: 9700\n", 884 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 885 | "epoch: 9800\n", 886 | "[0.2 0.9875 0.9875 0.98791667 0.98791667 0.98791667]\n", 887 | "epoch: 9900\n", 888 | "[0.2 0.9925 0.9925 0.9925 0.9925 0.9925]\n", 889 | "epoch: 10000\n", 890 | "[0.2 0.99166667 0.99208333 0.99208333 0.99208333 0.99208333]\n", 891 | "在mean process之前: (992, 6)\n", 892 | "测试集准确率: [0.2 0.9663 0.967 0.967 0.967 0.967 ]\n", 893 | "epoch: 10100\n", 894 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 895 | "epoch: 10200\n", 896 | "[0.2 0.98541667 0.98541667 0.98541667 0.98541667 0.98541667]\n", 897 | "epoch: 10300\n", 898 | "[0.2 0.99333333 0.99416667 0.99416667 0.99416667 0.99416667]\n", 899 | "epoch: 10400\n", 900 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 
0.99458333]\n", 901 | "epoch: 10500\n", 902 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 903 | "epoch: 10600\n", 904 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99416667]\n", 905 | "epoch: 10700\n", 906 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 907 | "epoch: 10800\n", 908 | "[0.2 0.98333333 0.98333333 0.98333333 0.98375 0.98375 ]\n", 909 | "epoch: 10900\n", 910 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 911 | "epoch: 11000\n", 912 | "[0.2 0.99166667 0.99166667 0.99166667 0.99166667 0.99166667]\n", 913 | "在mean process之前: (992, 6)\n", 914 | "测试集准确率: [0.2 0.9707 0.9707 0.9707 0.971 0.971 ]\n", 915 | "epoch: 11100\n", 916 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99583333]\n", 917 | "epoch: 11200\n", 918 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 919 | "epoch: 11300\n", 920 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 921 | "epoch: 11400\n", 922 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 923 | "epoch: 11500\n", 924 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 925 | "epoch: 11600\n", 926 | "[0.2 0.99125 0.99166667 0.99166667 0.99166667 0.99166667]\n", 927 | "epoch: 11700\n", 928 | "[0.2 0.98833333 0.98875 0.98958333 0.98958333 0.98958333]\n", 929 | "epoch: 11800\n", 930 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.9975 ]\n", 931 | "epoch: 11900\n", 932 | "[0.2 0.99375 0.99541667 0.99541667 0.99583333 0.99583333]\n", 933 | "epoch: 12000\n", 934 | "[0.2 0.99166667 0.99166667 0.99166667 0.99166667 0.99166667]\n", 935 | "在mean process之前: (992, 6)\n", 936 | "测试集准确率: [0.2 0.9697 0.9697 0.9697 0.9697 0.9697]\n", 937 | "epoch: 12100\n", 938 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 939 | "epoch: 12200\n", 940 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 941 | "epoch: 12300\n", 942 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 943 | 
"epoch: 12400\n", 944 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 945 | "epoch: 12500\n", 946 | "[0.2 0.99333333 0.99333333 0.99333333 0.99333333 0.99333333]\n", 947 | "epoch: 12600\n", 948 | "[0.2 0.99041667 0.99041667 0.99041667 0.99041667 0.99041667]\n", 949 | "epoch: 12700\n", 950 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 951 | "epoch: 12800\n", 952 | "[0.2 0.98875 0.98875 0.98875 0.98875 0.98916667]\n", 953 | "epoch: 12900\n", 954 | "[0.2 0.99375 0.995 0.995 0.995 0.995 ]\n", 955 | "epoch: 13000\n", 956 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99416667]\n", 957 | "在mean process之前: (992, 6)\n", 958 | "测试集准确率: [0.2 0.9673 0.9673 0.9673 0.9673 0.9673]\n", 959 | "epoch: 13100\n", 960 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 961 | "epoch: 13200\n", 962 | "[0.2 0.99208333 0.99208333 0.99208333 0.99208333 0.99208333]\n", 963 | "epoch: 13300\n", 964 | "[0.2 0.99458333 0.995 0.995 0.995 0.995 ]\n", 965 | "epoch: 13400\n", 966 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 967 | "epoch: 13500\n", 968 | "[0.2 0.99208333 0.99208333 0.99208333 0.99208333 0.99208333]\n", 969 | "epoch: 13600\n", 970 | "[0.2 0.99 0.99 0.99 0.99 0.99]\n", 971 | "epoch: 13700\n", 972 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 973 | "epoch: 13800\n", 974 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 975 | "epoch: 13900\n", 976 | "[0.2 0.99333333 0.99333333 0.99333333 0.99333333 0.99333333]\n", 977 | "epoch: 14000\n", 978 | "[0.2 0.99291667 0.99291667 0.99291667 0.99291667 0.99291667]\n", 979 | "在mean process之前: (992, 6)\n", 980 | "测试集准确率: [0.2 0.97 0.97 0.97 0.97 0.97]\n", 981 | "epoch: 14100\n", 982 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 983 | "epoch: 14200\n", 984 | "[0.2 0.9975 0.9975 0.99791667 0.99791667 0.99791667]\n", 985 | "epoch: 14300\n", 986 | "[0.2 0.98 0.98 0.98 0.98 0.98083333]\n", 987 | "epoch: 14400\n", 988 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 
0.99833333]\n", 989 | "epoch: 14500\n", 990 | "[0.2 0.99625 0.99833333 0.99833333 0.99833333 0.99833333]\n", 991 | "epoch: 14600\n", 992 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 993 | "epoch: 14700\n", 994 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 995 | "epoch: 14800\n", 996 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 997 | "epoch: 14900\n", 998 | "[0.2 0.98916667 0.98916667 0.98916667 0.98916667 0.98916667]\n", 999 | "epoch: 15000\n", 1000 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1001 | "在mean process之前: (992, 6)\n", 1002 | "测试集准确率: [0.2 0.9707 0.9707 0.9707 0.9707 0.9707]\n", 1003 | "epoch: 15100\n", 1004 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1005 | "epoch: 15200\n", 1006 | "[0.2 0.99125 0.99125 0.99125 0.99125 0.99125]\n", 1007 | "epoch: 15300\n", 1008 | "[0.2 0.99708333 0.99666667 0.99625 0.99625 0.99625 ]\n", 1009 | "epoch: 15400\n", 1010 | "[0.2 0.99125 0.99125 0.99125 0.99125 0.99125]\n", 1011 | "epoch: 15500\n", 1012 | "[0.2 0.98916667 0.98916667 0.98916667 0.98916667 0.98916667]\n", 1013 | "epoch: 15600\n", 1014 | "[0.2 0.99458333 0.995 0.995 0.995 0.99416667]\n", 1015 | "epoch: 15700\n", 1016 | "[0.2 0.9925 0.99333333 0.99375 0.99375 0.99375 ]\n", 1017 | "epoch: 15800\n", 1018 | "[0.2 0.99125 0.99125 0.99125 0.99125 0.99166667]\n", 1019 | "epoch: 15900\n", 1020 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1021 | "epoch: 16000\n", 1022 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1023 | "在mean process之前: (992, 6)\n", 1024 | "测试集准确率: [0.2 0.973 0.973 0.973 0.973 0.973]\n", 1025 | "epoch: 16100\n", 1026 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1027 | "epoch: 16200\n", 1028 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1029 | "epoch: 16300\n", 1030 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1031 | "epoch: 16400\n", 1032 | "[0.2 0.99208333 0.99208333 
0.99208333 0.99208333 0.99208333]\n", 1033 | "epoch: 16500\n", 1034 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1035 | "epoch: 16600\n", 1036 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1037 | "epoch: 16700\n", 1038 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1039 | "epoch: 16800\n", 1040 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1041 | "epoch: 16900\n", 1042 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1043 | "epoch: 17000\n", 1044 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1045 | "在mean process之前: (992, 6)\n", 1046 | "测试集准确率: [0.2 0.9707 0.9707 0.9707 0.9707 0.9707]\n", 1047 | "epoch: 17100\n", 1048 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1049 | "epoch: 17200\n", 1050 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1051 | "epoch: 17300\n", 1052 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1053 | "epoch: 17400\n", 1054 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1055 | "epoch: 17500\n", 1056 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1057 | "epoch: 17600\n", 1058 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1059 | "epoch: 17700\n", 1060 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1061 | "epoch: 17800\n", 1062 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1063 | "epoch: 17900\n", 1064 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1065 | "epoch: 18000\n", 1066 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1067 | "在mean process之前: (992, 6)\n", 1068 | "测试集准确率: [0.2 0.9688 0.9688 0.9688 0.9688 0.9688]\n", 1069 | "epoch: 18100\n", 1070 | "[0.2 0.99166667 0.99166667 0.9925 0.9925 0.9925 ]\n", 1071 | "epoch: 18200\n", 1072 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1073 | "epoch: 18300\n", 1074 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1075 | "epoch: 18400\n", 
1076 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1077 | "epoch: 18500\n", 1078 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1079 | "epoch: 18600\n", 1080 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1081 | "epoch: 18700\n", 1082 | "[0.2 0.9925 0.9925 0.99291667 0.99291667 0.99291667]\n", 1083 | "epoch: 18800\n", 1084 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1085 | "epoch: 18900\n", 1086 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1087 | "epoch: 19000\n", 1088 | "[0.2 0.98916667 0.98916667 0.98916667 0.98916667 0.98916667]\n", 1089 | "在mean process之前: (992, 6)\n", 1090 | "测试集准确率: [0.2 0.973 0.973 0.973 0.973 0.973]\n", 1091 | "epoch: 19100\n", 1092 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n" 1093 | ] 1094 | }, 1095 | { 1096 | "name": "stdout", 1097 | "output_type": "stream", 1098 | "text": [ 1099 | "epoch: 19200\n", 1100 | "[0.2 0.98875 0.98875 0.98875 0.98875 0.98875]\n", 1101 | "epoch: 19300\n", 1102 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1103 | "epoch: 19400\n", 1104 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1105 | "epoch: 19500\n", 1106 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1107 | "epoch: 19600\n", 1108 | "[0.2 0.99708333 0.9975 0.9975 0.9975 0.9975 ]\n", 1109 | "epoch: 19700\n", 1110 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1111 | "epoch: 19800\n", 1112 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1113 | "epoch: 19900\n", 1114 | "[0.2 0.99208333 0.99208333 0.99208333 0.99208333 0.99208333]\n", 1115 | "epoch: 20000\n", 1116 | "[0.2 0.99166667 0.99166667 0.99166667 0.99166667 0.99166667]\n", 1117 | "在mean process之前: (992, 6)\n", 1118 | "测试集准确率: [0.2 0.9688 0.9688 0.9688 0.9688 0.9688]\n", 1119 | "epoch: 20100\n", 1120 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1121 | "epoch: 20200\n", 1122 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 
0.99958333]\n", 1123 | "epoch: 20300\n", 1124 | "[0.2 0.99208333 0.99208333 0.99208333 0.99208333 0.99208333]\n", 1125 | "epoch: 20400\n", 1126 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1127 | "epoch: 20500\n", 1128 | "[0.2 0.99041667 0.99041667 0.99 0.99 0.99 ]\n", 1129 | "epoch: 20600\n", 1130 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1131 | "epoch: 20700\n", 1132 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1133 | "epoch: 20800\n", 1134 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1135 | "epoch: 20900\n", 1136 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1137 | "epoch: 21000\n", 1138 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1139 | "在mean process之前: (992, 6)\n", 1140 | "测试集准确率: [0.2 0.9717 0.9717 0.9717 0.9717 0.9717]\n", 1141 | "epoch: 21100\n", 1142 | "[0.2 0.99333333 0.995 0.995 0.995 0.995 ]\n", 1143 | "epoch: 21200\n", 1144 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1145 | "epoch: 21300\n", 1146 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1147 | "epoch: 21400\n", 1148 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1149 | "epoch: 21500\n", 1150 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1151 | "epoch: 21600\n", 1152 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1153 | "epoch: 21700\n", 1154 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1155 | "epoch: 21800\n", 1156 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1157 | "epoch: 21900\n", 1158 | "[0.2 0.98958333 0.99041667 0.99041667 0.99041667 0.99041667]\n", 1159 | "epoch: 22000\n", 1160 | "[0.2 0.99541667 0.99541667 0.99541667 0.99583333 0.99583333]\n", 1161 | "在mean process之前: (992, 6)\n", 1162 | "测试集准确率: [0.2 0.972 0.972 0.972 0.972 0.972]\n", 1163 | "epoch: 22100\n", 1164 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1165 | "epoch: 22200\n", 1166 | "[0.2 
0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1167 | "epoch: 22300\n", 1168 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1169 | "epoch: 22400\n", 1170 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1171 | "epoch: 22500\n", 1172 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1173 | "epoch: 22600\n", 1174 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1175 | "epoch: 22700\n", 1176 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1177 | "epoch: 22800\n", 1178 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1179 | "epoch: 22900\n", 1180 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1181 | "epoch: 23000\n", 1182 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1183 | "在mean process之前: (992, 6)\n", 1184 | "测试集准确率: [0.2 0.97 0.9707 0.9707 0.9707 0.9707]\n", 1185 | "epoch: 23100\n", 1186 | "[0.2 0.98958333 0.98958333 0.98958333 0.99 0.98958333]\n", 1187 | "epoch: 23200\n", 1188 | "[0.2 0.9875 0.9875 0.9875 0.9875 0.9875]\n", 1189 | "epoch: 23300\n", 1190 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1191 | "epoch: 23400\n", 1192 | "[0.2 0.995 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1193 | "epoch: 23500\n", 1194 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1195 | "epoch: 23600\n", 1196 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1197 | "epoch: 23700\n", 1198 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1199 | "epoch: 23800\n", 1200 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1201 | "epoch: 23900\n", 1202 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1203 | "epoch: 24000\n", 1204 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1205 | "在mean process之前: (992, 6)\n", 1206 | "测试集准确率: [0.2 0.9736 0.9736 0.9736 0.9736 0.9736]\n", 1207 | "epoch: 24100\n", 1208 | "[0.2 0.99041667 0.99041667 0.99041667 0.99083333 
0.99083333]\n", 1209 | "epoch: 24200\n", 1210 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1211 | "epoch: 24300\n", 1212 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1213 | "epoch: 24400\n", 1214 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1215 | "epoch: 24500\n", 1216 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1217 | "epoch: 24600\n", 1218 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1219 | "epoch: 24700\n", 1220 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1221 | "epoch: 24800\n", 1222 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1223 | "epoch: 24900\n", 1224 | "[0.2 0.99291667 0.99291667 0.99291667 0.99291667 0.99291667]\n", 1225 | "epoch: 25000\n", 1226 | "[0.2 0.99291667 0.99291667 0.99291667 0.99291667 0.99291667]\n", 1227 | "在mean process之前: (992, 6)\n", 1228 | "测试集准确率: [0.2 0.9707 0.9707 0.9707 0.9707 0.9707]\n", 1229 | "epoch: 25100\n", 1230 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1231 | "epoch: 25200\n", 1232 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1233 | "epoch: 25300\n", 1234 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99416667]\n", 1235 | "epoch: 25400\n", 1236 | "[0.2 0.99083333 0.99083333 0.99125 0.99125 0.99125 ]\n", 1237 | "epoch: 25500\n", 1238 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1239 | "epoch: 25600\n", 1240 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1241 | "epoch: 25700\n", 1242 | "[0.2 0.9925 0.9925 0.9925 0.9925 0.9925]\n", 1243 | "epoch: 25800\n", 1244 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1245 | "epoch: 25900\n", 1246 | "[0.2 0.99666667 0.99708333 0.9975 0.99791667 0.99791667]\n", 1247 | "epoch: 26000\n", 1248 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1249 | "在mean process之前: (992, 6)\n", 1250 | "测试集准确率: [0.2 0.97 0.97 0.97 0.97 0.97]\n", 1251 | "epoch: 26100\n", 1252 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1253 | "epoch: 
26200\n", 1254 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1255 | "epoch: 26300\n", 1256 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1257 | "epoch: 26400\n", 1258 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1259 | "epoch: 26500\n", 1260 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1261 | "epoch: 26600\n", 1262 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1263 | "epoch: 26700\n", 1264 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1265 | "epoch: 26800\n", 1266 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1267 | "epoch: 26900\n", 1268 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1269 | "epoch: 27000\n", 1270 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1271 | "在mean process之前: (992, 6)\n", 1272 | "测试集准确率: [0.2 0.9717 0.9717 0.9717 0.9717 0.9717]\n", 1273 | "epoch: 27100\n", 1274 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1275 | "epoch: 27200\n", 1276 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1277 | "epoch: 27300\n", 1278 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1279 | "epoch: 27400\n", 1280 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1281 | "epoch: 27500\n", 1282 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1283 | "epoch: 27600\n", 1284 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1285 | "epoch: 27700\n", 1286 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1287 | "epoch: 27800\n", 1288 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1289 | "epoch: 27900\n", 1290 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1291 | "epoch: 28000\n", 1292 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1293 | "在mean process之前: (992, 6)\n", 1294 | "测试集准确率: [0.2 0.972 0.972 0.972 0.972 0.972]\n", 1295 | "epoch: 28100\n", 1296 | "[0.2 
0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1297 | "epoch: 28200\n", 1298 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1299 | "epoch: 28300\n", 1300 | "[0.2 0.99333333 0.99333333 0.99333333 0.99333333 0.99333333]\n", 1301 | "epoch: 28400\n", 1302 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1303 | "epoch: 28500\n", 1304 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1305 | "epoch: 28600\n", 1306 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1307 | "epoch: 28700\n", 1308 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1309 | "epoch: 28800\n", 1310 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1311 | "epoch: 28900\n", 1312 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1313 | "epoch: 29000\n", 1314 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1315 | "在mean process之前: (992, 6)\n", 1316 | "测试集准确率: [0.2 0.9697 0.9697 0.9697 0.9697 0.9697]\n", 1317 | "epoch: 29100\n", 1318 | "[0.2 0.98708333 0.98708333 0.98708333 0.98583333 0.98583333]\n" 1319 | ] 1320 | }, 1321 | { 1322 | "name": "stdout", 1323 | "output_type": "stream", 1324 | "text": [ 1325 | "epoch: 29200\n", 1326 | "[0.2 1. 1. 1. 1. 
0.99958333]\n", 1327 | "epoch: 29300\n", 1328 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1329 | "epoch: 29400\n", 1330 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1331 | "epoch: 29500\n", 1332 | "[0.2 0.9925 0.99333333 0.99416667 0.99416667 0.99416667]\n", 1333 | "epoch: 29600\n", 1334 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1335 | "epoch: 29700\n", 1336 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1337 | "epoch: 29800\n", 1338 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1339 | "epoch: 29900\n", 1340 | "[0.2 0.98708333 0.98708333 0.9875 0.98708333 0.98708333]\n", 1341 | "epoch: 30000\n", 1342 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1343 | "在mean process之前: (992, 6)\n", 1344 | "测试集准确率: [0.2 0.97 0.97 0.97 0.97 0.97]\n", 1345 | "epoch: 30100\n", 1346 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1347 | "epoch: 30200\n", 1348 | "[0.2 0.9925 0.9925 0.9925 0.9925 0.9925]\n", 1349 | "epoch: 30300\n", 1350 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1351 | "epoch: 30400\n", 1352 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1353 | "epoch: 30500\n", 1354 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1355 | "epoch: 30600\n", 1356 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1357 | "epoch: 30700\n", 1358 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1359 | "epoch: 30800\n", 1360 | "[0.2 0.99416667 0.99375 0.99375 0.99375 0.99375 ]\n", 1361 | "epoch: 30900\n", 1362 | "[0.2 0.99333333 0.99333333 0.99333333 0.99333333 0.99333333]\n", 1363 | "epoch: 31000\n", 1364 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1365 | "在mean process之前: (992, 6)\n", 1366 | "测试集准确率: [0.2 0.971 0.971 0.971 0.971 0.971]\n", 1367 | "epoch: 31100\n", 1368 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1369 | "epoch: 31200\n", 1370 | "[0.2 0.99875 0.99875 
0.99875 0.99875 0.99875]\n", 1371 | "epoch: 31300\n", 1372 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1373 | "epoch: 31400\n", 1374 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1375 | "epoch: 31500\n", 1376 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1377 | "epoch: 31600\n", 1378 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1379 | "epoch: 31700\n", 1380 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1381 | "epoch: 31800\n", 1382 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1383 | "epoch: 31900\n", 1384 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99416667]\n", 1385 | "epoch: 32000\n", 1386 | "[0.2 0.99208333 0.99208333 0.99208333 0.99208333 0.99208333]\n", 1387 | "在mean process之前: (992, 6)\n", 1388 | "测试集准确率: [0.2 0.9707 0.9707 0.9707 0.9707 0.9707]\n", 1389 | "epoch: 32100\n", 1390 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1391 | "epoch: 32200\n", 1392 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1393 | "epoch: 32300\n", 1394 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1395 | "epoch: 32400\n", 1396 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1397 | "epoch: 32500\n", 1398 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1399 | "epoch: 32600\n", 1400 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1401 | "epoch: 32700\n", 1402 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1403 | "epoch: 32800\n", 1404 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1405 | "epoch: 32900\n", 1406 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1407 | "epoch: 33000\n", 1408 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1409 | "在mean process之前: (992, 6)\n", 1410 | "测试集准确率: [0.2 0.9683 0.9683 0.9683 0.9683 0.9683]\n", 1411 | "epoch: 33100\n", 1412 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 
0.99916667]\n", 1413 | "epoch: 33200\n", 1414 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1415 | "epoch: 33300\n", 1416 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1417 | "epoch: 33400\n", 1418 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1419 | "epoch: 33500\n", 1420 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1421 | "epoch: 33600\n", 1422 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99583333]\n", 1423 | "epoch: 33700\n", 1424 | "[0.2 1. 1. 1. 1. 1. ]\n", 1425 | "epoch: 33800\n", 1426 | "[0.2 0.99625 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1427 | "epoch: 33900\n", 1428 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1429 | "epoch: 34000\n", 1430 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1431 | "在mean process之前: (992, 6)\n", 1432 | "测试集准确率: [0.2 0.967 0.967 0.967 0.967 0.967]\n", 1433 | "epoch: 34100\n", 1434 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1435 | "epoch: 34200\n", 1436 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1437 | "epoch: 34300\n", 1438 | "[0.2 0.99333333 0.99416667 0.99541667 0.99541667 0.99541667]\n", 1439 | "epoch: 34400\n", 1440 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1441 | "epoch: 34500\n", 1442 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1443 | "epoch: 34600\n", 1444 | "[0.2 0.99208333 0.99208333 0.99208333 0.99208333 0.99208333]\n", 1445 | "epoch: 34700\n", 1446 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1447 | "epoch: 34800\n", 1448 | "[0.2 0.99333333 0.99333333 0.99333333 0.99333333 0.99333333]\n", 1449 | "epoch: 34900\n", 1450 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1451 | "epoch: 35000\n", 1452 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1453 | "在mean process之前: (992, 6)\n", 1454 | "测试集准确率: [0.2 0.9727 0.9727 0.9727 0.9727 0.9727]\n", 1455 | "epoch: 35100\n", 1456 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1457 | "epoch: 35200\n", 1458 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1459 | "epoch: 35300\n", 1460 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1461 | "epoch: 35400\n", 1462 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1463 | "epoch: 35500\n", 1464 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1465 | "epoch: 35600\n", 1466 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1467 | "epoch: 35700\n", 1468 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1469 | "epoch: 35800\n", 1470 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1471 | "epoch: 35900\n", 1472 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1473 | "epoch: 36000\n", 1474 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1475 | "在mean process之前: (992, 6)\n", 1476 | "测试集准确率: [0.2 0.9688 0.9688 0.9688 0.9688 0.9688]\n", 1477 | "epoch: 36100\n", 1478 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1479 | "epoch: 36200\n", 1480 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1481 | "epoch: 36300\n", 1482 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1483 | "epoch: 36400\n", 1484 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1485 | "epoch: 36500\n", 1486 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1487 | "epoch: 36600\n", 1488 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1489 | "epoch: 36700\n", 1490 | "[0.2 0.99625 0.99625 0.99625 0.99625 
0.99625]\n", 1491 | "epoch: 36800\n", 1492 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1493 | "epoch: 36900\n", 1494 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1495 | "epoch: 37000\n", 1496 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1497 | "在mean process之前: (992, 6)\n", 1498 | "测试集准确率: [0.2 0.97 0.97 0.97 0.97 0.97]\n", 1499 | "epoch: 37100\n", 1500 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1501 | "epoch: 37200\n", 1502 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1503 | "epoch: 37300\n", 1504 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1505 | "epoch: 37400\n", 1506 | "[0.2 0.99625 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1507 | "epoch: 37500\n", 1508 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1509 | "epoch: 37600\n", 1510 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1511 | "epoch: 37700\n", 1512 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1513 | "epoch: 37800\n", 1514 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1515 | "epoch: 37900\n", 1516 | "[0.2 1. 1. 1. 1. 1. ]\n", 1517 | "epoch: 38000\n", 1518 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1519 | "在mean process之前: (992, 6)\n", 1520 | "测试集准确率: [0.2 0.971 0.971 0.971 0.971 0.971]\n", 1521 | "epoch: 38100\n", 1522 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1523 | "epoch: 38200\n", 1524 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1525 | "epoch: 38300\n", 1526 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1527 | "epoch: 38400\n", 1528 | "[0.2 1. 1. 1. 1. 1. ]\n", 1529 | "epoch: 38500\n", 1530 | "[0.2 0.99333333 0.99416667 0.99416667 0.99458333 0.99458333]\n", 1531 | "epoch: 38600\n", 1532 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1533 | "epoch: 38700\n", 1534 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1535 | "epoch: 38800\n", 1536 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1537 | "epoch: 38900\n", 1538 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1539 | "epoch: 39000\n", 1540 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1541 | "在mean process之前: (992, 6)\n", 1542 | "测试集准确率: [0.2 0.97 0.97 0.97 0.97 0.97]\n", 1543 | "epoch: 39100\n", 1544 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1545 | "epoch: 39200\n", 1546 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1547 | "epoch: 39300\n", 1548 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1549 | "epoch: 39400\n", 1550 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n" 1551 | ] 1552 | }, 1553 | { 1554 | "name": "stdout", 1555 | "output_type": "stream", 1556 | "text": [ 1557 | "epoch: 39500\n", 1558 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1559 | "epoch: 39600\n", 1560 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1561 | "epoch: 39700\n", 1562 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1563 | "epoch: 39800\n", 1564 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1565 | "epoch: 39900\n", 1566 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1567 | "epoch: 40000\n", 1568 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1569 | "在mean process之前: (992, 6)\n", 1570 | "测试集准确率: [0.2 0.966 0.966 0.966 0.966 0.966]\n", 1571 | "epoch: 40100\n", 1572 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1573 | "epoch: 40200\n", 1574 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1575 | "epoch: 40300\n", 1576 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1577 | "epoch: 40400\n", 1578 | "[0.2 0.99375 0.99375 0.99375 0.99375 0.99375]\n", 1579 | "epoch: 40500\n", 1580 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1581 | "epoch: 40600\n", 1582 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1583 | "epoch: 40700\n", 1584 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1585 | "epoch: 40800\n", 1586 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1587 | "epoch: 40900\n", 1588 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1589 | "epoch: 41000\n", 1590 | "[0.2 0.99291667 0.99291667 0.99291667 0.99291667 0.99291667]\n", 1591 | "在mean process之前: (992, 6)\n", 1592 | "测试集准确率: [0.2 0.9688 0.9688 0.9688 0.9688 0.9688]\n", 1593 | "epoch: 41100\n", 1594 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1595 | "epoch: 41200\n", 1596 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1597 | "epoch: 41300\n", 1598 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1599 | "epoch: 41400\n", 1600 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1601 | "epoch: 41500\n", 1602 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1603 | "epoch: 41600\n", 1604 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1605 | "epoch: 41700\n", 1606 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1607 | "epoch: 41800\n", 1608 | "[0.2 0.99958333 
0.99958333 0.99958333 0.99958333 0.99958333]\n", 1609 | "epoch: 41900\n", 1610 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1611 | "epoch: 42000\n", 1612 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1613 | "在mean process之前: (992, 6)\n", 1614 | "测试集准确率: [0.2 0.968 0.968 0.968 0.968 0.968]\n", 1615 | "epoch: 42100\n", 1616 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1617 | "epoch: 42200\n", 1618 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1619 | "epoch: 42300\n", 1620 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1621 | "epoch: 42400\n", 1622 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1623 | "epoch: 42500\n", 1624 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1625 | "epoch: 42600\n", 1626 | "[0.2 0.995 0.995 0.99458333 0.99458333 0.99458333]\n", 1627 | "epoch: 42700\n", 1628 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1629 | "epoch: 42800\n", 1630 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1631 | "epoch: 42900\n", 1632 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1633 | "epoch: 43000\n", 1634 | "[0.2 0.99291667 0.99291667 0.99291667 0.99291667 0.99291667]\n", 1635 | "在mean process之前: (992, 6)\n", 1636 | "测试集准确率: [0.2 0.9697 0.9697 0.9697 0.9697 0.9697]\n", 1637 | "epoch: 43100\n", 1638 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1639 | "epoch: 43200\n", 1640 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99416667]\n", 1641 | "epoch: 43300\n", 1642 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1643 | "epoch: 43400\n", 1644 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1645 | "epoch: 43500\n", 1646 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1647 | "epoch: 43600\n", 1648 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1649 | "epoch: 43700\n", 1650 | "[0.2 0.9975 0.9975 0.9975 0.9975 
0.9975]\n", 1651 | "epoch: 43800\n", 1652 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1653 | "epoch: 43900\n", 1654 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1655 | "epoch: 44000\n", 1656 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1657 | "在mean process之前: (992, 6)\n", 1658 | "测试集准确率: [0.2 0.969 0.969 0.969 0.969 0.969]\n", 1659 | "epoch: 44100\n", 1660 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1661 | "epoch: 44200\n", 1662 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1663 | "epoch: 44300\n", 1664 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1665 | "epoch: 44400\n", 1666 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1667 | "epoch: 44500\n", 1668 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1669 | "epoch: 44600\n", 1670 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1671 | "epoch: 44700\n", 1672 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1673 | "epoch: 44800\n", 1674 | "[0.2 0.98708333 0.98708333 0.98708333 0.98666667 0.98666667]\n", 1675 | "epoch: 44900\n", 1676 | "[0.2 1. 1. 1. 1. 1. ]\n", 1677 | "epoch: 45000\n", 1678 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1679 | "在mean process之前: (992, 6)\n", 1680 | "测试集准确率: [0.2 0.972 0.972 0.972 0.972 0.972]\n", 1681 | "epoch: 45100\n", 1682 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1683 | "epoch: 45200\n", 1684 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1685 | "epoch: 45300\n", 1686 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1687 | "epoch: 45400\n", 1688 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1689 | "epoch: 45500\n", 1690 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1691 | "epoch: 45600\n", 1692 | "[0.2 0.99541667 0.99541667 0.99541667 0.99583333 0.99583333]\n", 1693 | "epoch: 45700\n", 1694 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1695 | "epoch: 45800\n", 1696 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1697 | "epoch: 45900\n", 1698 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1699 | "epoch: 46000\n", 1700 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1701 | "在mean process之前: (992, 6)\n", 1702 | "测试集准确率: [0.2 0.969 0.969 0.969 0.969 0.969]\n", 1703 | "epoch: 46100\n", 1704 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1705 | "epoch: 46200\n", 1706 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1707 | "epoch: 46300\n", 1708 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1709 | "epoch: 46400\n", 1710 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1711 | "epoch: 46500\n", 1712 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99833333]\n", 1713 | "epoch: 46600\n", 1714 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1715 | "epoch: 46700\n", 1716 | "[0.2 0.99 0.99 0.99 0.99 0.99]\n", 1717 | "epoch: 46800\n", 1718 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1719 | "epoch: 46900\n", 1720 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1721 | "epoch: 47000\n", 1722 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1723 | "在mean process之前: (992, 6)\n", 1724 | "测试集准确率: [0.2 0.9697 0.9697 0.9697 0.9697 0.9697]\n", 1725 | "epoch: 47100\n", 1726 | "[0.2 1. 1. 1. 1. 1. ]\n", 1727 | "epoch: 47200\n", 1728 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1729 | "epoch: 47300\n", 1730 | "[0.2 0.99041667 0.99041667 0.99041667 0.99041667 0.99041667]\n", 1731 | "epoch: 47400\n", 1732 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1733 | "epoch: 47500\n", 1734 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1735 | "epoch: 47600\n", 1736 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1737 | "epoch: 47700\n", 1738 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1739 | "epoch: 47800\n", 1740 | "[0.2 0.9975 0.9975 0.99791667 0.99833333 0.99833333]\n", 1741 | "epoch: 47900\n", 1742 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1743 | "epoch: 48000\n", 1744 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1745 | "在mean process之前: (992, 6)\n", 1746 | "测试集准确率: [0.2 0.9673 0.9673 0.9673 0.9673 0.9673]\n", 1747 | "epoch: 48100\n", 1748 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1749 | "epoch: 48200\n", 1750 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1751 | "epoch: 48300\n", 1752 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1753 | "epoch: 48400\n", 1754 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1755 | "epoch: 48500\n", 1756 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1757 | "epoch: 48600\n", 1758 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1759 | "epoch: 48700\n", 1760 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n", 1761 | "epoch: 48800\n", 1762 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1763 | "epoch: 48900\n", 1764 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1765 | "epoch: 49000\n", 1766 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1767 | "在mean process之前: (992, 6)\n", 1768 | "测试集准确率: [0.2 0.971 0.971 0.971 0.971 0.971]\n", 1769 | "epoch: 49100\n", 1770 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1771 | "epoch: 49200\n", 1772 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1773 | "epoch: 49300\n", 1774 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1775 | "epoch: 49400\n", 1776 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1777 | "epoch: 49500\n", 1778 | "[0.2 0.99333333 0.99416667 0.99458333 0.99458333 0.99458333]\n", 1779 | "epoch: 49600\n", 1780 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1781 | "epoch: 49700\n", 1782 | "[0.2 0.99583333 0.99583333 0.99583333 0.99583333 0.99583333]\n" 1783 | ] 1784 | }, 1785 | { 1786 | "name": "stdout", 1787 | "output_type": "stream", 1788 | "text": [ 1789 | "epoch: 49800\n", 1790 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1791 | "epoch: 49900\n", 1792 | "[0.2 0.99083333 0.99083333 0.99083333 0.99083333 0.99083333]\n", 1793 | "epoch: 50000\n", 1794 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1795 | "在mean process之前: (992, 6)\n", 1796 | "测试集准确率: [0.2 0.9683 0.9683 0.9683 0.9683 0.9683]\n", 1797 | "epoch: 50100\n", 1798 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1799 | "epoch: 50200\n", 1800 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1801 | "epoch: 50300\n", 1802 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1803 | "epoch: 50400\n", 1804 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1805 | "epoch: 50500\n", 1806 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1807 | "epoch: 50600\n", 1808 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1809 | "epoch: 50700\n", 1810 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1811 | "epoch: 50800\n", 1812 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1813 | "epoch: 50900\n", 1814 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1815 | "epoch: 51000\n", 1816 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1817 | "在mean process之前: (992, 6)\n", 1818 | "测试集准确率: [0.2 0.969 0.969 0.969 0.9697 0.9697]\n", 1819 | "epoch: 51100\n", 1820 | "[0.2 0.99333333 0.99333333 0.99333333 0.99333333 0.99333333]\n", 1821 | "epoch: 51200\n", 1822 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 
0.99541667]\n", 1823 | "epoch: 51300\n", 1824 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1825 | "epoch: 51400\n", 1826 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1827 | "epoch: 51500\n", 1828 | "[0.2 0.99208333 0.99291667 0.99083333 0.99291667 0.99166667]\n", 1829 | "epoch: 51600\n", 1830 | "[0.2 1. 1. 1. 1. 1. ]\n", 1831 | "epoch: 51700\n", 1832 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1833 | "epoch: 51800\n", 1834 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1835 | "epoch: 51900\n", 1836 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1837 | "epoch: 52000\n", 1838 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1839 | "在mean process之前: (992, 6)\n", 1840 | "测试集准确率: [0.2 0.97 0.97 0.97 0.97 0.97]\n", 1841 | "epoch: 52100\n", 1842 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1843 | "epoch: 52200\n", 1844 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1845 | "epoch: 52300\n", 1846 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1847 | "epoch: 52400\n", 1848 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1849 | "epoch: 52500\n", 1850 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1851 | "epoch: 52600\n", 1852 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1853 | "epoch: 52700\n", 1854 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1855 | "epoch: 52800\n", 1856 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1857 | "epoch: 52900\n", 1858 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1859 | "epoch: 53000\n", 1860 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1861 | "在mean process之前: (992, 6)\n", 1862 | "测试集准确率: [0.2 0.9697 0.9697 0.9697 0.9697 0.9697]\n", 1863 | "epoch: 53100\n", 1864 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1865 | "epoch: 53200\n", 1866 | "[0.2 
0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1867 | "epoch: 53300\n", 1868 | "[0.2 0.9975 0.9975 0.99791667 0.99791667 0.99791667]\n", 1869 | "epoch: 53400\n", 1870 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1871 | "epoch: 53500\n", 1872 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1873 | "epoch: 53600\n", 1874 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1875 | "epoch: 53700\n", 1876 | "[0.2 0.9925 0.99291667 0.99333333 0.99333333 0.99333333]\n", 1877 | "epoch: 53800\n", 1878 | "[0.2 1. 1. 1. 1. 1. ]\n", 1879 | "epoch: 53900\n", 1880 | "[0.2 1. 1. 1. 1. 1. ]\n", 1881 | "epoch: 54000\n", 1882 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99416667]\n", 1883 | "在mean process之前: (992, 6)\n", 1884 | "测试集准确率: [0.2 0.9688 0.9688 0.9688 0.9688 0.9688]\n", 1885 | "epoch: 54100\n", 1886 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1887 | "epoch: 54200\n", 1888 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1889 | "epoch: 54300\n", 1890 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1891 | "epoch: 54400\n", 1892 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1893 | "epoch: 54500\n", 1894 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1895 | "epoch: 54600\n", 1896 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1897 | "epoch: 54700\n", 1898 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1899 | "epoch: 54800\n", 1900 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1901 | "epoch: 54900\n", 1902 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99416667]\n", 1903 | "epoch: 55000\n", 1904 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1905 | "在mean process之前: (992, 6)\n", 1906 | "测试集准确率: [0.2 0.972 0.972 0.972 0.972 0.972]\n", 1907 | "epoch: 55100\n", 1908 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1909 | "epoch: 55200\n", 1910 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1911 | "epoch: 55300\n", 1912 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1913 | "epoch: 55400\n", 1914 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 1915 | "epoch: 55500\n", 1916 | "[0.2 0.99666667 0.99666667 0.99666667 0.99666667 0.99666667]\n", 1917 | "epoch: 55600\n", 1918 | "[0.2 0.9875 0.9875 0.9875 0.9875 0.9875]\n", 1919 | "epoch: 55700\n", 1920 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1921 | "epoch: 55800\n", 1922 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1923 | "epoch: 55900\n", 1924 | "[0.2 1. 1. 1. 1. 1. ]\n", 1925 | "epoch: 56000\n", 1926 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1927 | "在mean process之前: (992, 6)\n", 1928 | "测试集准确率: [0.2 0.969 0.969 0.969 0.969 0.969]\n", 1929 | "epoch: 56100\n", 1930 | "[0.2 0.99708333 0.99708333 0.99708333 0.99708333 0.99708333]\n", 1931 | "epoch: 56200\n", 1932 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1933 | "epoch: 56300\n", 1934 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1935 | "epoch: 56400\n", 1936 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1937 | "epoch: 56500\n", 1938 | "[0.2 0.99416667 0.99416667 0.99416667 0.99416667 0.99416667]\n", 1939 | "epoch: 56600\n", 1940 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1941 | "epoch: 56700\n", 1942 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1943 | "epoch: 56800\n", 1944 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1945 | "epoch: 56900\n", 1946 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1947 | "epoch: 57000\n", 1948 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 1949 | "在mean process之前: (992, 6)\n", 1950 | "测试集准确率: [0.2 0.9688 0.9688 0.9688 0.9688 0.9688]\n", 1951 | "epoch: 57100\n", 1952 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1953 | "epoch: 57200\n", 1954 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1955 | "epoch: 57300\n", 1956 | "[0.2 1. 1. 1. 1. 1. ]\n", 1957 | "epoch: 57400\n", 1958 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1959 | "epoch: 57500\n", 1960 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1961 | "epoch: 57600\n", 1962 | "[0.2 0.995 0.995 0.995 0.995 0.995]\n", 1963 | "epoch: 57700\n", 1964 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1965 | "epoch: 57800\n", 1966 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1967 | "epoch: 57900\n", 1968 | "[0.2 1. 1. 1. 1. 1. 
]\n", 1969 | "epoch: 58000\n", 1970 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1971 | "在mean process之前: (992, 6)\n", 1972 | "测试集准确率: [0.2 0.9697 0.9697 0.9697 0.9697 0.9697]\n", 1973 | "epoch: 58100\n", 1974 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 1975 | "epoch: 58200\n", 1976 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 1977 | "epoch: 58300\n", 1978 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1979 | "epoch: 58400\n", 1980 | "[0.2 0.99458333 0.99458333 0.99458333 0.99458333 0.99458333]\n", 1981 | "epoch: 58500\n", 1982 | "[0.2 0.99625 0.99625 0.99625 0.99625 0.99625]\n", 1983 | "epoch: 58600\n", 1984 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1985 | "epoch: 58700\n", 1986 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 1987 | "epoch: 58800\n", 1988 | "[0.2 1. 1. 1. 1. 1. ]\n", 1989 | "epoch: 58900\n", 1990 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 1991 | "epoch: 59000\n", 1992 | "[0.2 0.99291667 0.99291667 0.99291667 0.99291667 0.99291667]\n", 1993 | "在mean process之前: (992, 6)\n", 1994 | "测试集准确率: [0.2 0.9673 0.9673 0.9673 0.9673 0.9673]\n", 1995 | "epoch: 59100\n", 1996 | "[0.2 0.99916667 0.99916667 0.99916667 0.99916667 0.99916667]\n", 1997 | "epoch: 59200\n", 1998 | "[0.2 1. 1. 1. 1. 1. ]\n", 1999 | "epoch: 59300\n", 2000 | "[0.2 0.99958333 0.99958333 0.99958333 0.99958333 0.99958333]\n", 2001 | "epoch: 59400\n", 2002 | "[0.2 1. 1. 1. 1. 1. 
]\n", 2003 | "epoch: 59500\n", 2004 | "[0.2 0.99833333 0.99833333 0.99833333 0.99833333 0.99833333]\n", 2005 | "epoch: 59600\n", 2006 | "[0.2 0.99875 0.99875 0.99875 0.99875 0.99875]\n", 2007 | "epoch: 59700\n", 2008 | "[0.2 0.99791667 0.99791667 0.99791667 0.99791667 0.99791667]\n", 2009 | "epoch: 59800\n", 2010 | "[0.2 0.9975 0.9975 0.9975 0.9975 0.9975]\n", 2011 | "epoch: 59900\n", 2012 | "[0.2 0.99625 0.99708333 0.99708333 0.9975 0.99791667]\n", 2013 | "epoch: 60000\n", 2014 | "[0.2 0.99541667 0.99541667 0.99541667 0.99541667 0.99541667]\n", 2015 | "在mean process之前: (992, 6)\n", 2016 | "测试集准确率: [0.2 0.968 0.9683 0.9683 0.9683 0.9683]\n" 2017 | ] 2018 | } 2019 | ], 2020 | "source": [ 2021 | "## omniglot\n", 2022 | "import random\n", 2023 | "random.seed(1337)\n", 2024 | "np.random.seed(1337)\n", 2025 | "\n", 2026 | "import time\n", 2027 | "device = torch.device('cuda')\n", 2028 | "\n", 2029 | "meta = MetaLearner().to(device)\n", 2030 | "\n", 2031 | "epochs = 60001\n", 2032 | "for step in range(epochs):\n", 2033 | " start = time.time()\n", 2034 | " x_spt, y_spt, x_qry, y_qry = next('train')\n", 2035 | " x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device),\\\n", 2036 | " torch.from_numpy(y_spt).to(device),\\\n", 2037 | " torch.from_numpy(x_qry).to(device),\\\n", 2038 | " torch.from_numpy(y_qry).to(device)\n", 2039 | " accs,loss = meta(x_spt, y_spt, x_qry, y_qry)\n", 2040 | " end = time.time()\n", 2041 | " if step % 100 == 0:\n", 2042 | " print(\"epoch:\" ,step)\n", 2043 | " print(accs)\n", 2044 | "# print(loss)\n", 2045 | " \n", 2046 | " if step % 1000 == 0:\n", 2047 | " accs = []\n", 2048 | " for _ in range(1000//task_num):\n", 2049 | " # db_train.next('test')\n", 2050 | " x_spt, y_spt, x_qry, y_qry = next('test')\n", 2051 | " x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device),\\\n", 2052 | " torch.from_numpy(y_spt).to(device),\\\n", 2053 | " torch.from_numpy(x_qry).to(device),\\\n", 2054 | " torch.from_numpy(y_qry).to(device)\n", 2055 | 
"\n", 2056 | " \n", 2057 | " for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):\n", 2058 | " test_acc = meta.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)\n", 2059 | " accs.append(test_acc)\n", 2060 | " print('在mean process之前:',np.array(accs).shape)\n", 2061 | " accs = np.array(accs).mean(axis=0).astype(np.float16)\n", 2062 | " print('测试集准确率:',accs)" 2063 | ] 2064 | }, 2065 | { 2066 | "cell_type": "code", 2067 | "execution_count": null, 2068 | "metadata": { 2069 | "ExecuteTime": { 2070 | "end_time": "2020-03-01T03:00:56.266331Z", 2071 | "start_time": "2020-03-01T03:00:56.205955Z" 2072 | } 2073 | }, 2074 | "outputs": [], 2075 | "source": [ 2076 | "\n" 2077 | ] 2078 | }, 2079 | { 2080 | "cell_type": "code", 2081 | "execution_count": null, 2082 | "metadata": {}, 2083 | "outputs": [], 2084 | "source": [] 2085 | }, 2086 | { 2087 | "cell_type": "code", 2088 | "execution_count": null, 2089 | "metadata": {}, 2090 | "outputs": [], 2091 | "source": [] 2092 | } 2093 | ], 2094 | "metadata": { 2095 | "kernelspec": { 2096 | "display_name": "ML3.6", 2097 | "language": "python", 2098 | "name": "ml3.6" 2099 | }, 2100 | "language_info": { 2101 | "codemirror_mode": { 2102 | "name": "ipython", 2103 | "version": 3 2104 | }, 2105 | "file_extension": ".py", 2106 | "mimetype": "text/x-python", 2107 | "name": "python", 2108 | "nbconvert_exporter": "python", 2109 | "pygments_lexer": "ipython3", 2110 | "version": "3.6.9" 2111 | }, 2112 | "latex_envs": { 2113 | "LaTeX_envs_menu_present": true, 2114 | "autoclose": false, 2115 | "autocomplete": true, 2116 | "bibliofile": "biblio.bib", 2117 | "cite_by": "apalike", 2118 | "current_citInitial": 1, 2119 | "eqLabelWithNumbers": true, 2120 | "eqNumInitial": 1, 2121 | "hotkeys": { 2122 | "equation": "Ctrl-E", 2123 | "itemize": "Ctrl-I" 2124 | }, 2125 | "labels_anchors": false, 2126 | "latex_user_defs": false, 2127 | "report_style_numbering": false, 2128 | "user_envs_cfg": false 2129 | }, 2130 | "toc": { 2131 | 
"base_numbering": 1, 2132 | "nav_menu": {}, 2133 | "number_sections": true, 2134 | "sideBar": true, 2135 | "skip_h1_title": false, 2136 | "title_cell": "Table of Contents", 2137 | "title_sidebar": "Contents", 2138 | "toc_cell": false, 2139 | "toc_position": {}, 2140 | "toc_section_display": true, 2141 | "toc_window_display": false 2142 | }, 2143 | "varInspector": { 2144 | "cols": { 2145 | "lenName": 16, 2146 | "lenType": 16, 2147 | "lenVar": 40 2148 | }, 2149 | "kernels_config": { 2150 | "python": { 2151 | "delete_cmd_postfix": "", 2152 | "delete_cmd_prefix": "del ", 2153 | "library": "var_list.py", 2154 | "varRefreshCmd": "print(var_dic_list())" 2155 | }, 2156 | "r": { 2157 | "delete_cmd_postfix": ") ", 2158 | "delete_cmd_prefix": "rm(", 2159 | "library": "var_list.r", 2160 | "varRefreshCmd": "cat(var_dic_list()) " 2161 | } 2162 | }, 2163 | "types_to_exclude": [ 2164 | "module", 2165 | "function", 2166 | "builtin_function_or_method", 2167 | "instance", 2168 | "_Feature" 2169 | ], 2170 | "window_display": false 2171 | } 2172 | }, 2173 | "nbformat": 4, 2174 | "nbformat_minor": 2 2175 | } 2176 | -------------------------------------------------------------------------------- /Preprocess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2020-03-01T06:48:09.450393Z", 9 | "start_time": "2020-03-01T06:48:09.123812Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import torch\n", 15 | "import numpy as np\n", 16 | "import os\n", 17 | "import zipfile\n", 18 | "\n", 19 | "root_path = './../datasets'\n", 20 | "processed_folder = os.path.join(root_path)\n", 21 | "\n", 22 | "zip_ref = zipfile.ZipFile(os.path.join(root_path,'omniglot_standard.zip'), 'r')\n", 23 | "zip_ref.extractall(root_path)\n", 24 | "zip_ref.close()\n", 25 | "\n", 26 | "root_dir = './../datasets/omniglot/python'\n" 27 | ] 28 | }, 
29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "# 数据预处理\n", 36 | "\n", 37 | "import torchvision.transforms as transforms\n", 38 | "from PIL import Image\n", 39 | "\n", 40 | "'''\n", 41 | "an example of img_items:\n", 42 | "( '0709_17.png',\n", 43 | " 'Alphabet_of_the_Magi/character01',\n", 44 | " './../datasets/omniglot/python/images_background/Alphabet_of_the_Magi/character01')\n", 45 | "'''\n", 46 | "def find_classes(root_dir):\n", 47 | " img_items = []\n", 48 | " for (root, dirs, files) in os.walk(root_dir): \n", 49 | " for file in files:\n", 50 | " if (file.endswith(\"png\")):\n", 51 | " r = root.split('/')\n", 52 | " img_items.append((file, r[-2] + \"/\" + r[-1], root))\n", 53 | " print(\"== Found %d items \" % len(img_items))\n", 54 | " return img_items\n", 55 | "\n", 56 | "## 构建一个词典{class:idx}\n", 57 | "def index_classes(items):\n", 58 | " class_idx = {}\n", 59 | " count = 0\n", 60 | " for item in items:\n", 61 | " if item[1] not in class_idx:\n", 62 | " class_idx[item[1]] = count\n", 63 | " count += 1\n", 64 | " print('== Found {} classes'.format(len(class_idx)))\n", 65 | " return class_idx\n", 66 | " \n", 67 | "\n", 68 | "img_items = find_classes(root_dir)\n", 69 | "class_idx = index_classes(img_items)\n", 70 | "\n", 71 | "\n", 72 | "temp = dict()\n", 73 | "for imgname, classes, dirs in img_items:\n", 74 | " img = '{}/{}'.format(dirs, imgname)\n", 75 | " label = class_idx[classes]\n", 76 | " transform = transforms.Compose([lambda img: Image.open(img).convert('L'),\n", 77 | " lambda img: img.resize((28,28)),\n", 78 | " lambda img: np.reshape(img, (28,28,1)),\n", 79 | " lambda img: np.transpose(img, [2,0,1]),\n", 80 | " lambda img: img/255.\n", 81 | " ])\n", 82 | " img = transform(img)\n", 83 | " if label in temp.keys():\n", 84 | " temp[label].append(img)\n", 85 | " else:\n", 86 | " temp[label] = [img]\n", 87 | "print('begin to generate omniglot.npy')\n", 88 | "## 
移除标签信息,每个标签包含20个样本\n", 89 | "img_list = []\n", 90 | "for label, imgs in temp.items():\n", 91 | " img_list.append(np.array(imgs))\n", 92 | "img_list = np.array(img_list).astype(np.float) # [[20 imgs],..., 1623 classes in total]\n", 93 | "print('data shape:{}'.format(img_list.shape)) # (1623, 20, 1, 28, 28)\n", 94 | "temp = []\n", 95 | "np.save(os.path.join(root_dir, 'omniglot.npy'), img_list)\n", 96 | "print('end.')" 97 | ] 98 | } 99 | ], 100 | "metadata": { 101 | "kernelspec": { 102 | "display_name": "wangkai3.6", 103 | "language": "python", 104 | "name": "wangkai3.6" 105 | }, 106 | "language_info": { 107 | "codemirror_mode": { 108 | "name": "ipython", 109 | "version": 3 110 | }, 111 | "file_extension": ".py", 112 | "mimetype": "text/x-python", 113 | "name": "python", 114 | "nbconvert_exporter": "python", 115 | "pygments_lexer": "ipython3", 116 | "version": "3.6.7" 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 2 121 | } 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAML-Pythorch 2 | You can find the omniglot dataset in my repository: [datasets](https://github.com/miguealanmath/Datasets) 3 | 4 | These my best repo performance: 5 | 6 | ### Omniglot 7 | #### 5-way-1-shot 8 | MAML:$98.7 \pm 0.4\%$ 9 | 10 | **ours:$97.5\%$** 11 | 12 | #### 20-way-1-shot 13 | MAML:$95.8 \pm 0.3\%$ 14 | 15 | **ours:$84.8\%$** 16 | 17 | #### 20-way-5-shot 18 | MAML:$98.9 \pm 0.2\%$ 19 | 20 | **ours:$94.4\%$** 21 | 22 | ### miniImageNet 23 | 24 | #### 5-way-1-shot 25 | MAML:$48.70 \pm 1.84\%$ 26 | 27 | **ours:$49.15\%$** 28 | 29 | #### 5-way-5-shot 30 | MAML:$63.11 \pm 0.92\%$ 31 | 32 | **ours:$62.26\%$** 33 | 34 | --------------------------------------------------------------------------------