├── README-图1.png
├── README-图2.png
├── README-图3.png
├── README-图4.png
├── README-图5.png
├── README.md
├── README.pdf
├── Task1
│   └── Task1 气象数据分析常用工具.ipynb
├── Task2
│   └── Task2 数据分析.ipynb
├── Task3
│   ├── Task3 模型建立之CNN+LSTM.ipynb
│   └── fig
│       ├── Task3-CNN+LSTM模型.png
│       ├── Task3-样本拼接示意图.png
│       └── Task3-滑窗构造训练样本.png
├── Task4
│   ├── Task4 模型建立之TCNN+RNN.ipynb
│   └── fig
│       ├── Task4-CNN单元.png
│       ├── Task4-TCNN+RNN模型.png
│       ├── Task4-TCNN层.png
│       ├── Task4-TCN单元.png
│       ├── Task4-扩张卷积.png
│       └── Task4-残差连接.png
└── Task5
    ├── Task5 模型建立之SA-ConvLSTM.ipynb
    └── fig
        ├── Task5-LSTM与ConvLSTM公式比较.png
        ├── Task5-SA-ConvLSTM模型.png
        ├── Task5-SAM模块.png
        └── Task5-Seq2Seq基础结构.png


/README-图1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图1.png
--------------------------------------------------------------------------------


/README-图2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图2.png
--------------------------------------------------------------------------------


/README-图3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图3.png
--------------------------------------------------------------------------------


/README-图4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图4.png
--------------------------------------------------------------------------------


/README-图5.png:
--------------------------------------------------------------------------------
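https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图5.png
--------------------------------------------------------------------------------


/README.md:
--------------------------------------------------------------------------------
# Competition Introduction

The 2021 "AI Earth" Artificial Intelligence Innovation Challenge, themed "AI for Accurate Meteorological and Ocean Forecasting", explores how artificial intelligence can be applied in meteorology and oceanography.

The competition is built around the El Niño-Southern Oscillation (ENSO). ENSO is the collective name for two phenomena, El Niño (EN) and the Southern Oscillation (SO). El Niño is a persistent, anomalous warming of sea surface temperature in the central and eastern equatorial Pacific. The Southern Oscillation is a seesaw in sea-level pressure between the tropical eastern Pacific and the tropical western Pacific: when pressure rises in one region, it falls in the other. The two are the oceanic and atmospheric expressions of the same climate anomaly and are closely coupled, which is why they are combined under the single name El Niño-Southern Oscillation.

ENSO causes extreme weather across most of the world and has a major influence on global weather, climate, and crop yields. Accurate ENSO prediction is therefore key to improving climate forecasts for East Asia and the rest of the world, and to disaster prevention and mitigation. The Nino3.4 index is an important indicator for monitoring ENSO. It is the mean sea surface temperature anomaly over the Nino3.4 region (170°W-120°W, 5°S-5°N) and reflects how anomalous the sea surface temperature is; if the Nino3.4 index exceeds 0.5 °C for 5 consecutive months, an ENSO event is declared. The goal of this competition is to use historical climate observations and model simulation data to predict the Nino3.4 index for each of the next 1-24 months, given the spatiotemporal sequence of the past 12 months up to and including time T.

![Figure 1](README-图1.png)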

Figure 1: The Nino3.4 region
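
The event rule above (Nino3.4 index above 0.5 °C for 5 consecutive months) amounts to a run-length check over a monthly index series. The sketch below illustrates that check and is not official competition code: the helper name `find_enso_events` is hypothetical, "exceeds" is read as strictly greater than the threshold, and only the warm-side rule exactly as stated here is checked.

```python
import numpy as np

def find_enso_events(nino34, threshold=0.5, min_months=5):
    """Return (start, end) month indices, end exclusive, of every run in which
    the index stays strictly above `threshold` for at least `min_months` months."""
    above = np.asarray(nino34, dtype=float) > threshold
    events = []
    start = None
    for t, flag in enumerate(above):
        if flag and start is None:
            start = t                          # a run above the threshold begins
        elif not flag and start is not None:
            if t - start >= min_months:        # keep the run only if it is long enough
                events.append((start, t))
            start = None
    # a run may still be open when the series ends
    if start is not None and len(above) - start >= min_months:
        events.append((start, len(above)))
    return events

index = [0.1, 0.6, 0.7, 0.9, 1.2, 0.8, 0.3, 0.6, 0.7, 0.2]
print(find_enso_events(index))  # [(1, 6)]
```

Months 1-5 qualify (five values above 0.5), while the later two-month run at indices 7-8 is too short to count.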

![Figure 2](README-图2.png)

Figure 2: The Nino3.4 index (image source: weatherzone.com.au)

![Figure 3](README-图3.png)

Figure 3: Schematic of the competition task
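
From the information above, we can see that this study group tackles a spatiotemporal sequence forecasting task.

# Competition Task

## Data Overview

The training data consist of 140 years of historical simulations from 17 CMIP5 models, 151 years of historical simulations from 15 CMIP6 models, and 100 years of historical observation-assimilation data reconstructed by the US SODA model, all stored in nc (NetCDF) format. CMIP5 and CMIP6 are the 5th and 6th phases of the Coupled Model Intercomparison Project run by the World Climate Research Programme (WCRP); both provide simulations of many climate variables from a range of different climate models. The data contain four climate variables: sea surface temperature anomaly (SST), heat content anomaly (T300), zonal wind anomaly (Ua), and meridional wind anomaly (Va), each with dimensions (year, month, lat, lon). For the training data, the Nino3.4 index of the corresponding months is provided as the label. In short, each training sample holds the SST, T300, Ua, and Va values at a given year, month, latitude, and longitude, and its label is the Nino3.4 index for the same year and month.

Note that the second dimension, month, has length 36 rather than 12: it holds three consecutive years of data starting at the current year. For example, in the SODA training data, year 0 contains the monthly observations for years 1-3, year 1 contains years 2-4, and so on. In other words, consecutive samples overlap in time.

![Figure 4](README-图4.png)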

Figure 4: Time span covered by each sample
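
Because each `year` entry holds 36 months and consecutive entries overlap by 24 months (Figure 4), a continuous monthly series can be recovered by keeping the first 12 months of every entry and appending the last 24 months of the final entry. Below is a minimal NumPy sketch, assuming an array shaped (year, 36, ...) as described above; the helper name `to_continuous` is hypothetical, and any trailing axes (for example lat and lon) are carried along unchanged.

```python
import numpy as np

def to_continuous(windows):
    """Turn overlapping 36-month windows of shape (year, 36, ...) into one
    continuous monthly series of shape (year * 12 + 24, ...)."""
    windows = np.asarray(windows)
    # first 12 months of every window: consecutive, non-overlapping years
    head = windows[:, :12].reshape(-1, *windows.shape[2:])
    # the final window also covers the two years after its first year
    tail = windows[-1, 12:]
    return np.concatenate([head, tail], axis=0)

# toy data: window y starts at month 12 * y of the series 0, 1, 2, ...
series = np.arange(5 * 12 + 24)
windows = np.stack([series[12 * y: 12 * y + 36] for y in range(5)])
print(windows.shape, to_continuous(windows).shape)  # (5, 36) (84,)
```

For 100 years of SODA data this yields 100 × 12 + 24 = 1224 months, with no month repeated.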
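
Another point to note is that the Nino3.4 index is the average SST over the Nino3.4 region for the three consecutive months starting from the current month. This means we do not have to predict the Nino3.4 index directly: we can instead predict SST and derive the Nino3.4 index from it.

The test data consist of $N$ randomly sampled 12-month time series drawn from several international ocean data assimilation products. They are stored in npy format with shape (12, lat, lon, 4): the first dimension is 12 consecutive months, and the fourth dimension holds the 4 climate variables in the order SST, T300, Ua, Va. Test files are named like test_00001_01_12.npy, where 00001 is the sample ID, 01 is the starting month, and 12 is the ending month.

## Evaluation Metric

The evaluation metric for this competition is:
$$
Score = \frac{2}{3} \times accskill - RMSE
$$
where $accskill$ is the correlation skill score, computed as:
$$
accskill = \sum_{i=1}^{24} a_i \times \ln(i) \times cor_i, \qquad
a_i = \begin{cases} 1.5, & i \leq 4 \\ 2, & 5 \leq i \leq 11 \\ 3, & 12 \leq i \leq 18 \\ 4, & 19 \leq i \leq 24 \end{cases}
$$
The coefficient $a_i$ grows with the lead month $i$, so the further ahead a model can predict accurately, the higher its score. Because $\ln(1) = 0$, the correlation at the first lead month does not contribute to $accskill$.

$cor_i$ is the correlation coefficient between the predicted and true values of the $N$ test samples at lead month $i$:
$$
cor_i = \frac{\sum_{j=1}^N(y_{true,j}-\bar{y}_{true})(y_{pred,j}-\bar{y}_{pred})}{\sqrt{\sum_{j=1}^N(y_{true,j}-\bar{y}_{true})^2\sum_{j=1}^N(y_{pred,j}-\bar{y}_{pred})^2}}
$$
where $y_{true,j}$ is the true Nino3.4 index of sample $j$ at lead month $i$, $\bar{y}_{true}$ is the mean true Nino3.4 index of the $N$ test samples at that lead month, $y_{pred,j}$ is the predicted Nino3.4 index of sample $j$ at lead month $i$, and $\bar{y}_{pred}$ is the mean predicted Nino3.4 index of the $N$ test samples at that lead month.

$RMSE$ is the sum of the root mean squared errors over the 24 lead months:
$$
RMSE = \sum_{i=1}^{24} rmse_i, \qquad
rmse_i = \sqrt{\frac{1}{N}\sum_{j=1}^N(y_{true,j}-y_{pred,j})^2}
$$

![Figure 5](README-图5.png)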

Figure 5: How the evaluation metric is computed
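
This section notes that the Nino3.4 index can be derived from SST instead of being predicted directly. The NumPy sketch below shows one way to do that, and several of its details are assumptions rather than facts from this README. It assumes the SST grid has 24 latitudes from 55°S to 60°N and 72 longitudes from 0°E to 355°E in 5° steps, matching the (24, 72) spatial shape used in the Task4 notebook; check this against the `lat`/`lon` variables in the nc files. It also assumes that 170°W-120°W corresponds to 190°E-240°E, that the area average is unweighted, and that the three-month mean runs forward from each month as described above. `nino34_from_sst` is a hypothetical helper.

```python
import numpy as np

# Assumed grid (hypothetical; verify against the dataset's lat/lon variables)
LAT = np.arange(-55, 61, 5)    # 24 latitudes, 55°S .. 60°N
LON = np.arange(0, 360, 5)     # 72 longitudes, 0°E .. 355°E

def nino34_from_sst(sst, lat=LAT, lon=LON):
    """sst: array of shape (months, lat, lon) holding SST anomalies.
    Returns the 3-month mean of the Nino3.4 box average starting at each month."""
    sst = np.asarray(sst, dtype=float)
    lat_mask = (lat >= -5) & (lat <= 5)          # 5°S - 5°N
    lon_mask = (lon >= 190) & (lon <= 240)       # 170°W - 120°W on a 0-360 axis
    box = sst[:, lat_mask][:, :, lon_mask]       # (months, 3, 11) on this grid
    monthly = np.nanmean(box, axis=(1, 2))       # unweighted box mean, land NaNs ignored
    # mean over months t, t+1, t+2; the last two months have no full window
    return np.convolve(monthly, np.ones(3) / 3, mode="valid")

# toy check: put the month number inside the Nino3.4 box
sst = np.zeros((12, 24, 72))
sst[:, 10:13, 38:49] = np.arange(12)[:, None, None]
print(nino34_from_sst(sst))  # 3-month means: 1.0, 2.0, ..., 10.0
```

A 12-month input yields only 10 complete three-month means, so a model that predicts SST must forecast two extra months past the last target month to produce the final index values.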
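
The evaluation metric defined in this section can also be computed in vectorized form. Below is a NumPy sketch, assuming `y_true` and `y_pred` are arrays of shape (N, 24) (samples × lead months). The name `competition_score` is hypothetical. Its result should match the loop-based `score` function in the Task4 notebook, because both use the same per-month correlation and RMSE.

```python
import numpy as np

def competition_score(y_true, y_pred):
    """Score = 2/3 * accskill - RMSE for arrays of shape (N, 24)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    months = np.arange(1, 25)                        # lead month i = 1..24
    a = np.repeat([1.5, 2.0, 3.0, 4.0], [4, 7, 7, 6])  # a_i for months 1-4, 5-11, 12-18, 19-24
    dt = y_true - y_true.mean(axis=0)                # deviations from the per-month means
    dp = y_pred - y_pred.mean(axis=0)
    cor = (dt * dp).sum(axis=0) / np.sqrt((dt ** 2).sum(axis=0) * (dp ** 2).sum(axis=0))
    accskill = np.sum(a * np.log(months) * cor)      # ln(1) = 0: month 1 carries no weight
    rmse = np.sqrt(((y_true - y_pred) ** 2).mean(axis=0)).sum()
    return 2.0 / 3.0 * accskill - rmse

rng = np.random.default_rng(0)
y = rng.normal(size=(100, 24))
print(competition_score(y, y))                                   # best case: every cor_i = 1, RMSE = 0
print(competition_score(y, y + 0.5 * rng.normal(size=y.shape)))  # noisy predictions score lower
```

A perfect prediction sets the upper bound at $\frac{2}{3}\sum_i a_i \ln(i)$. Any error lowers $cor_i$ and raises $RMSE$ at the same time.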

## Problem Analysis

From the competition information above, we need to solve the following problems:

- For a spatiotemporal sequence forecasting problem, how do we extract temporal information, and how do we extract spatial information?
- The given features are four widely accepted, general-purpose climate variables from meteorology, so it is hard to derive new features from them. Without constructing new features, how can we extract more information from the ones we have?
- The training set is small: there are only $140\times17+151\times15+100=4745$ training samples in total. For a forecasting problem with little data, we usually need to consider two things:
  - How can we increase the amount of data?
  - How can we build a model that is small (few parameters, reducing the risk of overfitting) yet deep (able to extract sufficiently rich information)?

# Learning Objectives

We reviewed and consolidated the solutions of the top competitors into the five small goals of this study group. We hope you will keep the questions above in mind while studying and find the answers along the way.

We hope you will gain the following from this study group:

1. Master the common tools for meteorological data analysis.
2. Be able to analyze spatiotemporal data.
3. Master the models used in this study group.
4. Learn some ideas and methods for choosing and building models for spatiotemporal sequence forecasting problems.

We also hope you will go beyond the given tasks, thinking further about and extending what you learn.

# Study Group Schedule

### Task01: Common tools for meteorological data analysis (2 days)

### Task02: Data analysis (2 days)

### Task03: Modeling with CNN + LSTM (3 days)

### Task04: Modeling with TCNN + RNN (4 days)

### Task05: Modeling with SA-ConvLSTM (4 days)

# Related Resources

I. Competition website

https://tianchi.aliyun.com/competition/entrance/531871/information

II. Open-source competition solutions

1. https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.15.561d53309Kn9hK&postId=210391 (swg-lhl, Rank 1)
2. https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.9.561d53309Kn9hK&postId=210734 (ailab)
3. https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.6.561d53309Kn9hK&postId=210836 (with source code, 神之一手YueTan, Rank 5)
4. https://tianchi.aliyun.com/notebook-ai/detail?spm=5176.12586969.1002.18.561d5330HKwYOW&postId=196536 (with source code, 学习AI的打工人)
5. https://github.com/jerrywn121/TianChi_AIEarth?spm=5176.21852664.0.0.6b612aedW6oyIQ (with source code, 吴先生的队伍)

# Project Contributions

- Project setup and integration: 曾海如

--------------------------------------------------------------------------------


/README.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README.pdf
--------------------------------------------------------------------------------


/Task3/fig/Task3-CNN+LSTM模型.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task3/fig/Task3-CNN+LSTM模型.png
--------------------------------------------------------------------------------


/Task3/fig/Task3-样本拼接示意图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task3/fig/Task3-样本拼接示意图.png
--------------------------------------------------------------------------------


/Task3/fig/Task3-滑窗构造训练样本.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task3/fig/Task3-滑窗构造训练样本.png
--------------------------------------------------------------------------------


/Task4/Task4 模型建立之TCNN+RNN.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datawhale Weather and Ocean Forecasting - Task4: Modeling with TCNN+RNN\n",
    "In this task we study the winning solution of top competitor \"swg-lhl\", which uses a TCNN+RNN model.\n",
    "\n",
    "In Task3 we studied the CNN+LSTM model, but LSTM layers have a large number of parameters, which causes two problems. First, a model with many parameters overfits easily when data are scarce. Second, to limit overfitting on a small dataset we cannot build a deeper model, which makes it hard to extract richer information. Unlike an LSTM, a CNN's parameter count depends only on the filter size (and the number of channels), not on the sequence length, and CNNs perform well on many kinds of tasks, so we can also use convolutions to extract temporal information. However, if a 3D convolution extracts temporal and spatial information at the same time, a filter of size (T_f, H_f, W_f) has T_f×H_f×W_f weights for every pair of input and output channels, which is still a lot. To cut the parameters per layer further and make the model deeper, the top solution studied here convolves over time and space separately, which needs only about T_f + H_f×W_f weights per channel pair: a TCN unit extracts temporal information, and its output is fed into a CNN unit that extracts spatial information. This serial TCN unit + CNN unit structure is called a TCNN layer, and stacking several TCNN layers extracts temporal and spatial information alternately. In addition, because spatiotemporal information at different time scales may affect the prediction differently, the solution uses three RNN layers to extract features at three time scales, concatenates the three, and predicts the Nino3.4 index with a fully connected layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning Objectives\n",
    "1. Learn how the top solution builds its model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "1. Data processing\n",
    "   - Data flattening\n",
    "   - Filling missing values\n",
    "   - Building the datasets\n",
    "2. Model building\n",
    "   - Defining the evaluation function\n",
    "   - Model construction\n",
    "   - Model training\n",
    "   - Model evaluation\n",
    "3. Summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Code Examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Processing\n",
    "Data processing in this top solution consists of three parts:\n",
    "1. Data flattening.\n",
    "2. Filling missing values.\n",
    "3. Building the datasets.\n",
    "\n",
    "Apart from not constructing any new features, this solution processes the data essentially as in Task3, so we only go through it briefly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import netCDF4 as nc\n",
    "import random\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from sklearn.metrics import mean_squared_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the random seeds for reproducibility\n",
    "SEED = 22\n",
    "\n",
    "def seed_everything(seed=42):\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "seed_everything(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CUDA is available! Training on GPU ...\n"
     ]
    }
   ],
   "source": [
    "# Check whether a GPU is available\n",
    "train_on_gpu = torch.cuda.is_available()\n",
    "\n",
    "if not train_on_gpu:\n",
    "    print('CUDA is not available. Training on CPU ...')\n",
    "else:\n",
    "    print('CUDA is available! Training on GPU ...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the data\n",
    "\n",
    "# Directory containing the data files\n",
    "path = '/kaggle/input/ninoprediction/'\n",
    "soda_train = nc.Dataset(path + 'SODA_train.nc')\n",
    "soda_label = nc.Dataset(path + 'SODA_label.nc')\n",
    "cmip_train = nc.Dataset(path + 'CMIP_train.nc')\n",
    "cmip_label = nc.Dataset(path + 'CMIP_label.nc')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Data Flattening\n",
    "For each climate model, concatenate the first 12 months of every yearly record into one continuous monthly sequence, so that training samples can later be cut from it with a sliding window."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_flatted(train_ds, label_ds, info, start_idx=0):\n",
    "    keys = ['sst', 't300', 'ua', 'va']\n",
    "    label_key = 'nino'\n",
    "    # number of years\n",
    "    years = info[1]\n",
    "    # number of climate models\n",
    "    models = info[2]\n",
    "\n",
    "    train_list = []\n",
    "    label_list = []\n",
    "\n",
    "    # Concatenate the data of each climate model into one sequence\n",
    "    for model_i in range(models):\n",
    "        blocks = []\n",
    "\n",
    "        # For each variable, take the first 12 months of every record and concatenate them\n",
    "        for key in keys:\n",
    "            block = train_ds[key][start_idx + model_i * years: start_idx + (model_i + 1) * years, :12].reshape(-1, 24, 72, 1).data\n",
    "            blocks.append(block)\n",
    "\n",
    "        # Stack all variables along the last axis\n",
    "        train_flatted = np.concatenate(blocks, axis=-1)\n",
    "\n",
    "        # Concatenate the labels of months 12-23, then append the labels of the last 12 months (24-35) of the final year;\n",
    "        # together with the final year's months 12-23, they form the targets for the final year's first 12 months\n",
    "        label_flatted = np.concatenate([\n",
    "            label_ds[label_key][start_idx + model_i * years: start_idx + (model_i + 1) * years, 12: 24].reshape(-1).data,\n",
    "            label_ds[label_key][start_idx + (model_i + 1) * years - 1, 24: 36].reshape(-1).data\n",
    "        ], axis=0)\n",
    "\n",
    "        train_list.append(train_flatted)\n",
    "        label_list.append(label_flatted)\n",
    "\n",
    "    return train_list, label_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1, 1200, 24, 72, 4), (15, 1812, 24, 72, 4), (17, 1680, 24, 72, 4))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "soda_info = ('soda', 100, 1)\n",
    "cmip6_info = ('cmip6', 151, 15)\n",
    "cmip5_info = ('cmip5', 140, 17)\n",
    "\n",
    "soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)\n",
    "cmip6_trains, cmip6_labels = make_flatted(cmip_train, cmip_label, cmip6_info)\n",
    "# In CMIP_train.nc the CMIP5 records follow the CMIP6 records, hence the start offset\n",
    "cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])\n",
    "\n",
    "# Flattened shape: (number of models, sequence length, lat, lon, number of variables), where sequence length = number of years × 12\n",
    "np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Filling Missing Values\n",
    "Fill missing values (NaN) with 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of null in soda_trains after fillna: 0\n"
     ]
    }
   ],
   "source": [
    "# Fill missing values in the SODA data\n",
    "soda_trains = np.array(soda_trains)\n",
    "soda_trains_nan = np.isnan(soda_trains)\n",
    "soda_trains[soda_trains_nan] = 0\n",
    "print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of null in cmip6_trains after fillna: 0\n"
     ]
    }
   ],
   "source": [
    "# Fill missing values in the CMIP6 data\n",
    "cmip6_trains = np.array(cmip6_trains)\n",
    "cmip6_trains_nan = np.isnan(cmip6_trains)\n",
    "cmip6_trains[cmip6_trains_nan] = 0\n",
    "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of null in cmip5_trains after fillna: 0\n"
     ]
    }
   ],
   "source": [
    "# Fill missing values in the CMIP5 data\n",
    "cmip5_trains = np.array(cmip5_trains)\n",
    "cmip5_trains_nan = np.isnan(cmip5_trains)\n",
    "cmip5_trains[cmip5_trains_nan] = 0\n",
    "print('Number of null in cmip5_trains after fillna:', np.sum(np.isnan(cmip5_trains)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Building the Datasets\n",
    "Build the training and validation sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training set\n",
    "\n",
    "X_train = []\n",
    "y_train = []\n",
    "# Draw 100 random windows from each of the 17 CMIP5 models\n",
    "# (np.random.choice samples with replacement by default, so a window may repeat)\n",
    "for model_i in range(17):\n",
    "    samples = np.random.choice(cmip5_trains.shape[1]-12, size=100)\n",
    "    for ind in samples:\n",
    "        # input: 12 months starting at ind; target: Nino3.4 index of the following 24 months\n",
    "        X_train.append(cmip5_trains[model_i, ind: ind+12])\n",
    "        y_train.append(cmip5_labels[model_i][ind: ind+24])\n",
    "# Draw 100 random windows from each of the 15 CMIP6 models\n",
    "for model_i in range(15):\n",
    "    samples = np.random.choice(cmip6_trains.shape[1]-12, size=100)\n",
    "    for ind in samples:\n",
    "        X_train.append(cmip6_trains[model_i, ind: ind+12])\n",
    "        y_train.append(cmip6_labels[model_i][ind: ind+24])\n",
    "X_train = np.array(X_train)\n",
    "y_train = np.array(y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
| "metadata": { 487 | "execution": { 488 | "iopub.execute_input": "2021-11-07T11:33:38.819204Z", 489 | "iopub.status.busy": "2021-11-07T11:33:38.818020Z", 490 | "iopub.status.idle": "2021-11-07T11:33:38.840229Z", 491 | "shell.execute_reply": "2021-11-07T11:33:38.839737Z", 492 | "shell.execute_reply.started": "2021-11-07T11:23:33.801270Z" 493 | }, 494 | "papermill": { 495 | "duration": 0.055138, 496 | "end_time": "2021-11-07T11:33:38.840360", 497 | "exception": false, 498 | "start_time": "2021-11-07T11:33:38.785222", 499 | "status": "completed" 500 | }, 501 | "tags": [] 502 | }, 503 | "outputs": [], 504 | "source": [ 505 | "# 构造验证集\n", 506 | "\n", 507 | "X_valid = []\n", 508 | "y_valid = []\n", 509 | "samples = np.random.choice(soda_trains.shape[1]-12, size=100)\n", 510 | "for ind in samples:\n", 511 | " X_valid.append(soda_trains[0, ind: ind+12])\n", 512 | " y_valid.append(soda_labels[0][ind: ind+24])\n", 513 | "X_valid = np.array(X_valid)\n", 514 | "y_valid = np.array(y_valid)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 13, 520 | "metadata": { 521 | "execution": { 522 | "iopub.execute_input": "2021-11-07T11:33:38.898632Z", 523 | "iopub.status.busy": "2021-11-07T11:33:38.897899Z", 524 | "iopub.status.idle": "2021-11-07T11:33:38.900564Z", 525 | "shell.execute_reply": "2021-11-07T11:33:38.901077Z", 526 | "shell.execute_reply.started": "2021-11-07T11:23:33.839695Z" 527 | }, 528 | "papermill": { 529 | "duration": 0.034011, 530 | "end_time": "2021-11-07T11:33:38.901204", 531 | "exception": false, 532 | "start_time": "2021-11-07T11:33:38.867193", 533 | "status": "completed" 534 | }, 535 | "tags": [] 536 | }, 537 | "outputs": [ 538 | { 539 | "data": { 540 | "text/plain": [ 541 | "((3200, 12, 24, 72, 4), (3200, 24), (100, 12, 24, 72, 4), (100, 24))" 542 | ] 543 | }, 544 | "execution_count": 13, 545 | "metadata": {}, 546 | "output_type": "execute_result" 547 | } 548 | ], 549 | "source": [ 550 | "# 查看数据集维度\n", 551 | "X_train.shape, 
y_train.shape, X_valid.shape, y_valid.shape" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": 15, 557 | "metadata": { 558 | "execution": { 559 | "iopub.execute_input": "2021-11-07T11:33:39.033073Z", 560 | "iopub.status.busy": "2021-11-07T11:33:39.032319Z", 561 | "iopub.status.idle": "2021-11-07T11:33:40.963743Z", 562 | "shell.execute_reply": "2021-11-07T11:33:40.962736Z", 563 | "shell.execute_reply.started": "2021-11-07T11:23:33.879524Z" 564 | }, 565 | "papermill": { 566 | "duration": 1.963702, 567 | "end_time": "2021-11-07T11:33:40.963922", 568 | "exception": false, 569 | "start_time": "2021-11-07T11:33:39.000220", 570 | "status": "completed" 571 | }, 572 | "tags": [] 573 | }, 574 | "outputs": [], 575 | "source": [ 576 | "# 保存数据集\n", 577 | "np.save('X_train_sample.npy', X_train)\n", 578 | "np.save('y_train_sample.npy', y_train)\n", 579 | "np.save('X_valid_sample.npy', X_valid)\n", 580 | "np.save('y_valid_sample.npy', y_valid)" 581 | ] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": { 586 | "papermill": { 587 | "duration": 0.442674, 588 | "end_time": "2021-11-07T11:33:43.003352", 589 | "exception": false, 590 | "start_time": "2021-11-07T11:33:42.560678", 591 | "status": "completed" 592 | }, 593 | "tags": [] 594 | }, 595 | "source": [ 596 | "### 模型构建" 597 | ] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": {}, 602 | "source": [ 603 | "这一部分我们来重点学习一下该方案的模型结构。" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": 16, 609 | "metadata": { 610 | "execution": { 611 | "iopub.execute_input": "2021-11-07T11:33:43.228569Z", 612 | "iopub.status.busy": "2021-11-07T11:33:43.223113Z", 613 | "iopub.status.idle": "2021-11-07T11:33:58.100748Z", 614 | "shell.execute_reply": "2021-11-07T11:33:58.101185Z", 615 | "shell.execute_reply.started": "2021-11-07T11:23:42.357343Z" 616 | }, 617 | "papermill": { 618 | "duration": 15.039399, 619 | "end_time": "2021-11-07T11:33:58.101351", 620 | "exception": 
false, 621 | "start_time": "2021-11-07T11:33:43.061952", 622 | "status": "completed" 623 | }, 624 | "tags": [] 625 | }, 626 | "outputs": [], 627 | "source": [ 628 | "# 读取数据集\n", 629 | "X_train = np.load('../input/ai-earth-task04-samples/X_train_sample.npy')\n", 630 | "y_train = np.load('../input/ai-earth-task04-samples/y_train_sample.npy')\n", 631 | "X_valid = np.load('../input/ai-earth-task04-samples/X_valid_sample.npy')\n", 632 | "y_valid = np.load('../input/ai-earth-task04-samples/y_valid_sample.npy')" 633 | ] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "execution_count": 17, 638 | "metadata": { 639 | "execution": { 640 | "iopub.execute_input": "2021-11-07T11:33:58.162900Z", 641 | "iopub.status.busy": "2021-11-07T11:33:58.161742Z", 642 | "iopub.status.idle": "2021-11-07T11:33:58.164817Z", 643 | "shell.execute_reply": "2021-11-07T11:33:58.165196Z", 644 | "shell.execute_reply.started": "2021-11-07T11:23:56.303716Z" 645 | }, 646 | "papermill": { 647 | "duration": 0.036534, 648 | "end_time": "2021-11-07T11:33:58.165327", 649 | "exception": false, 650 | "start_time": "2021-11-07T11:33:58.128793", 651 | "status": "completed" 652 | }, 653 | "tags": [] 654 | }, 655 | "outputs": [ 656 | { 657 | "data": { 658 | "text/plain": [ 659 | "((3200, 12, 24, 72, 4), (3200, 24), (100, 12, 24, 72, 4), (100, 24))" 660 | ] 661 | }, 662 | "execution_count": 17, 663 | "metadata": {}, 664 | "output_type": "execute_result" 665 | } 666 | ], 667 | "source": [ 668 | "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": 18, 674 | "metadata": { 675 | "execution": { 676 | "iopub.execute_input": "2021-11-07T11:33:58.225845Z", 677 | "iopub.status.busy": "2021-11-07T11:33:58.224285Z", 678 | "iopub.status.idle": "2021-11-07T11:33:58.226437Z", 679 | "shell.execute_reply": "2021-11-07T11:33:58.226877Z", 680 | "shell.execute_reply.started": "2021-11-07T11:23:56.313927Z" 681 | }, 682 | "papermill": { 683 | 
"duration": 0.03485, 684 | "end_time": "2021-11-07T11:33:58.227001", 685 | "exception": false, 686 | "start_time": "2021-11-07T11:33:58.192151", 687 | "status": "completed" 688 | }, 689 | "tags": [] 690 | }, 691 | "outputs": [], 692 | "source": [ 693 | "# 构造数据管道\n", 694 | "class AIEarthDataset(Dataset):\n", 695 | " def __init__(self, data, label):\n", 696 | " self.data = torch.tensor(data, dtype=torch.float32)\n", 697 | " self.label = torch.tensor(label, dtype=torch.float32)\n", 698 | "\n", 699 | " def __len__(self):\n", 700 | " return len(self.label)\n", 701 | " \n", 702 | " def __getitem__(self, idx):\n", 703 | " return self.data[idx], self.label[idx]" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": 19, 709 | "metadata": { 710 | "execution": { 711 | "iopub.execute_input": "2021-11-07T11:33:58.289091Z", 712 | "iopub.status.busy": "2021-11-07T11:33:58.288549Z", 713 | "iopub.status.idle": "2021-11-07T11:33:59.029333Z", 714 | "shell.execute_reply": "2021-11-07T11:33:59.028812Z", 715 | "shell.execute_reply.started": "2021-11-07T11:23:56.324788Z" 716 | }, 717 | "papermill": { 718 | "duration": 0.775212, 719 | "end_time": "2021-11-07T11:33:59.029494", 720 | "exception": false, 721 | "start_time": "2021-11-07T11:33:58.254282", 722 | "status": "completed" 723 | }, 724 | "tags": [] 725 | }, 726 | "outputs": [], 727 | "source": [ 728 | "batch_size = 32\n", 729 | "\n", 730 | "trainset = AIEarthDataset(X_train, y_train)\n", 731 | "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n", 732 | "\n", 733 | "validset = AIEarthDataset(X_valid, y_valid)\n", 734 | "validloader = DataLoader(validset, batch_size=batch_size, shuffle=True)" 735 | ] 736 | }, 737 | { 738 | "cell_type": "markdown", 739 | "metadata": {}, 740 | "source": [ 741 | "#### 构造评估函数" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": 24, 747 | "metadata": { 748 | "execution": { 749 | "iopub.execute_input": "2021-11-07T11:33:59.402757Z", 750 | 
"iopub.status.busy": "2021-11-07T11:33:59.401182Z", 751 | "iopub.status.idle": "2021-11-07T11:33:59.404701Z", 752 | "shell.execute_reply": "2021-11-07T11:33:59.405088Z", 753 | "shell.execute_reply.started": "2021-11-07T11:23:57.219014Z" 754 | }, 755 | "papermill": { 756 | "duration": 0.037919, 757 | "end_time": "2021-11-07T11:33:59.405211", 758 | "exception": false, 759 | "start_time": "2021-11-07T11:33:59.367292", 760 | "status": "completed" 761 | }, 762 | "tags": [] 763 | }, 764 | "outputs": [], 765 | "source": [ 766 | "def rmse(y_true, y_preds):\n", 767 | " return np.sqrt(mean_squared_error(y_pred = y_preds, y_true = y_true))\n", 768 | "\n", 769 | "# 评估函数\n", 770 | "def score(y_true, y_preds):\n", 771 | " # 相关性技巧评分\n", 772 | " accskill_score = 0\n", 773 | " # RMSE\n", 774 | " rmse_scores = 0\n", 775 | " a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6\n", 776 | " y_true_mean = np.mean(y_true, axis=0)\n", 777 | " y_pred_mean = np.mean(y_preds, axis=0)\n", 778 | " for i in range(24):\n", 779 | " fenzi = np.sum((y_true[:, i] - y_true_mean[i]) * (y_preds[:, i] - y_pred_mean[i]))\n", 780 | " fenmu = np.sqrt(np.sum((y_true[:, i] - y_true_mean[i])**2) * np.sum((y_preds[:, i] - y_pred_mean[i])**2))\n", 781 | " cor_i = fenzi / fenmu\n", 782 | " accskill_score += a[i] * np.log(i+1) * cor_i\n", 783 | " rmse_score = rmse(y_true[:, i], y_preds[:, i])\n", 784 | " rmse_scores += rmse_score\n", 785 | " return 2/3.0 * accskill_score - rmse_scores" 786 | ] 787 | }, 788 | { 789 | "cell_type": "markdown", 790 | "metadata": {}, 791 | "source": [ 792 | "#### 模型构造\n", 793 | "\n", 794 | "该TOP方案采用TCN单元+CNN单元串行组成TCNN层,通过堆叠多层的TCNN层来交替地提取时间和空间信息,并将提取到的时空信息用RNN来抽取出三种不同时间尺度的特征表达。" 795 | ] 796 | }, 797 | { 798 | "cell_type": "markdown", 799 | "metadata": {}, 800 | "source": [ 801 | "- **TCN单元**\n", 802 | "\n", 803 | "TCN模型全称时间卷积网络(Temporal Convolutional Network),与RNN一样是时序模型。TCN以CNN为基础,为了适应序列问题,它从以下三方面做出了改进:\n", 804 | "\n", 805 | "1. 
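Causal convolution\n", 806 | "\n", 807 | "TCN handles sequence problems in which the output has the same length as the input: every hidden layer has as many nodes as there are input time steps, and the value of a hidden node at time t depends only on the nodes of the previous layer at time t and earlier. In other words, TCN computes the current output only from past causes (values at time t and before), which is why this is called causal convolution.\n", 808 | "\n", 809 | "2. Dilated convolution\n", 810 | "\n", 811 | "The receptive field of a conventional CNN is limited by the kernel size, so pooling layers are usually added to enlarge it, but pooling loses information. To avoid this, TCN enlarges the receptive field with dilated convolution, which lets it capture information over longer time spans. A dilated convolution samples its input at fixed intervals controlled by the dilation factor d:\n", 812 | "$$\n", 813 | "F(s) = (X *_d f)(s) = \\sum_{i=0}^{k-1} f(i) \\cdot X_{s - d \\cdot i}\n", 814 | "$$\n", 815 | "where X is the input of the current layer, f is the convolution kernel, k is the kernel size of the current layer, and s is the time step of the current node. In other words, a hidden layer with dilation factor d and kernel size k samples the previous layer once every d steps, taking k samples in total (times s, s-d, ..., s-(k-1)d) as the input for time s, so a single layer spans (k-1)d+1 time steps. The receptive field of TCN is therefore determined jointly by the kernel size k and the dilation factor d, which lets it capture longer-range dependencies; the TCN paper increases d exponentially with depth (1, 2, 4, ...).\n", 816 | "\n", 817 | "\n", 818 | "\n", 819 | "3. Residual connections\n", 820 | "\n", 821 | "Deeper networks can extract richer features, but they are also prone to vanishing or exploding gradients, and residual connections are currently an effective remedy. The TCN residual block contains two convolutional layers and uses weight normalization (WeightNorm) and Dropout for regularization, as shown in the figure below.\n", 822 | "\n", 823 | "\n", 824 | "\n", 825 | "Overall, TCN adapts convolution to sequence problems. It keeps the CNN advantage of a small parameter count, so deeper networks can be built; it is less prone than an RNN to vanishing and exploding gradients; and its receptive field is flexible enough to suit different tasks. The comparisons reported in the TCN paper show it outperforming recurrent models such as RNN, LSTM, and GRU on many benchmark datasets.\n", 826 | "\n", 827 | "For a deeper look at TCN, see:\n", 828 | " \n", 829 | " - Paper: https://arxiv.org/pdf/1803.01271.pdf\n", 830 | " - GitHub: https://github.com/locuslab/tcn\n", 831 | " \n", 832 | "The TCN unit built in this solution is not a standard TCN layer; its structure is shown in the figure below. It uses only a single convolutional layer, with BatchNormalization before and after it to improve generalization, plus a residual connection. It is also neither causal nor dilated: with kernel_size=3 and padding=1 each output step also sees the next time step, and no dilation is applied. Because the convolution runs along the time dimension, the input must be reshaped first, and the output must be reshaped back to the input layout (N,T,C,H,W) so that it matches the following layers.\n", 833 | "\n", 834 | "" 835 | ] 836 | }, 837 | { 838 | "cell_type": "code", 839 | "execution_count": 20, 840 | "metadata": { 841 | "execution": { 842 | "iopub.execute_input": "2021-11-07T11:33:59.094709Z", 843 | "iopub.status.busy": "2021-11-07T11:33:59.094119Z", 844 | "iopub.status.idle": "2021-11-07T11:33:59.097687Z", 845 | "shell.execute_reply": "2021-11-07T11:33:59.097168Z", 846 | "shell.execute_reply.started": "2021-11-07T11:23:57.153300Z" 847 | }, 848 | "papermill": { 849 | "duration": 0.039993, 850 | "end_time": "2021-11-07T11:33:59.097803", 851 | "exception": false, 852 | "start_time": "2021-11-07T11:33:59.057810", 853 | "status": "completed" 854 | }, 855 | "tags": [] 856 | },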
857 | "outputs": [], 858 | "source": [ 859 | "# TCN unit: BN -> ReLU -> Conv1d along the time axis -> BN, plus a residual branch\n", 860 | "class TCNBlock(nn.Module):\n", 861 | " def __init__(self, in_channels, out_channels, kernel_size, stride, padding):\n", 862 | " super().__init__()\n", 863 | " self.bn1 = nn.BatchNorm1d(in_channels)\n", 864 | " self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)\n", 865 | " self.bn2 = nn.BatchNorm1d(out_channels)\n", 866 | " \n", 867 | " if in_channels == out_channels and stride == 1:\n", 868 | " self.res = lambda x: x  # identity shortcut\n", 869 | " else:\n", 870 | " self.res = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride)  # 1x1 conv to match channels and stride\n", 871 | " \n", 872 | " def forward(self, x):\n", 873 | " # Reshape (N,T,C,H,W) -> (N*H*W, C, T) so that Conv1d slides along time\n", 874 | " N, T, C, H, W = x.shape\n", 875 | " x = x.permute(0, 3, 4, 2, 1).contiguous()\n", 876 | " x = x.view(N*H*W, C, T)\n", 877 | " \n", 878 | " # Residual branch (note: it reuses bn2, the same BatchNorm as the main branch)\n", 879 | " res = self.res(x) \n", 880 | " res = self.bn2(res)\n", 881 | "\n", 882 | " x = F.relu(self.bn1(x))\n", 883 | " x = self.conv(x)\n", 884 | " x = self.bn2(x)\n", 885 | " \n", 886 | " x = x + res\n", 887 | " \n", 888 | " # Reshape the output back to (N,T,C,H,W)\n", 889 | " _, C_new, T_new = x.shape\n", 890 | " x = x.view(N, H, W, C_new, T_new)\n", 891 | " x = x.permute(0, 4, 3, 1, 2).contiguous()\n", 892 | " \n", 893 | " return x" 894 | ] 895 | }, 896 | { 897 | "cell_type": "markdown", 898 | "metadata": {}, 899 | "source": [ 900 | "- **CNN unit**\n", 901 | "\n", 902 | "The CNN unit is structured like the TCN unit: it has a single convolutional layer, here a 2-D convolution over the spatial dimensions (H, W), and uses BatchNormalization to improve generalization. Like the TCN unit, it also includes a residual connection. Its structure is shown in the figure below:\n", 903 | "\n", 904 | "" 905 | ] 906 | }, 907 | { 908 | "cell_type": "code", 909 | "execution_count": 21, 910 | "metadata": { 911 | "execution": { 912 | "iopub.execute_input": "2021-11-07T11:33:59.160481Z", 913 | "iopub.status.busy": "2021-11-07T11:33:59.158973Z", 914 | "iopub.status.idle": "2021-11-07T11:33:59.163197Z", 915 | "shell.execute_reply": "2021-11-07T11:33:59.163610Z", 916 | "shell.execute_reply.started": "2021-11-07T11:23:57.170593Z" 917 | }, 918 | "papermill": { 919 |
"duration": 0.038796, 920 | "end_time": "2021-11-07T11:33:59.163736", 921 | "exception": false, 922 | "start_time": "2021-11-07T11:33:59.124940", 923 | "status": "completed" 924 | }, 925 | "tags": [] 926 | }, 927 | "outputs": [], 928 | "source": [ 929 | "# CNN unit: BN -> ReLU -> Conv2d over (H, W) -> BN, plus a residual branch\n", 930 | "class CNNBlock(nn.Module):\n", 931 | " def __init__(self, in_channels, out_channels, kernel_size, stride, padding):\n", 932 | " super().__init__()\n", 933 | " self.bn1 = nn.BatchNorm2d(in_channels)\n", 934 | " self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)\n", 935 | " self.bn2 = nn.BatchNorm2d(out_channels)\n", 936 | " \n", 937 | " if (in_channels == out_channels) and (stride == 1):\n", 938 | " self.res = lambda x: x  # identity shortcut\n", 939 | " else:\n", 940 | " self.res = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)  # 1x1 conv to match channels and stride\n", 941 | " \n", 942 | " def forward(self, x):\n", 943 | " # Fold time into the batch: (N,T,C,H,W) -> (N*T, C, H, W)\n", 944 | " N, T, C, H, W = x.shape\n", 945 | " x = x.view(N*T, C, H, W)\n", 946 | " \n", 947 | " # Residual branch (note: it reuses bn2, the same BatchNorm as the main branch)\n", 948 | " res = self.res(x)\n", 949 | " res = self.bn2(res)\n", 950 | "\n", 951 | " x = F.relu(self.bn1(x))\n", 952 | " x = self.conv(x)\n", 953 | " x = self.bn2(x)\n", 954 | " \n", 955 | " x = x + res\n", 956 | " \n", 957 | " # Reshape the output back to (N,T,C,H,W)\n", 958 | " _, C_new, H_new, W_new = x.shape\n", 959 | " x = x.view(N, T, C_new, H_new, W_new)\n", 960 | " \n", 961 | " return x" 962 | ] 963 | }, 964 | { 965 | "cell_type": "markdown", 966 | "metadata": {}, 967 | "source": [ 968 | "- **TCNN layer**\n", 969 | "\n", 970 | "Connecting a TCN unit and a CNN unit in series forms one TCNN layer: the TCN unit mixes information along the time axis, and the CNN unit then mixes it across space.\n", 971 | "\n", 972 | "" 973 | ] 974 | }, 975 | { 976 | "cell_type": "code", 977 | "execution_count": 22, 978 | "metadata": { 979 | "execution": { 980 | "iopub.execute_input": "2021-11-07T11:33:59.223231Z", 981 | "iopub.status.busy": "2021-11-07T11:33:59.222402Z", 982 | "iopub.status.idle": "2021-11-07T11:33:59.224980Z", 983 | "shell.execute_reply": "2021-11-07T11:33:59.224529Z", 984 | "shell.execute_reply.started": 985 |
"2021-11-07T11:23:57.182192Z" 985 | }, 986 | "papermill": { 987 | "duration": 0.034509, 988 | "end_time": "2021-11-07T11:33:59.225082", 989 | "exception": false, 990 | "start_time": "2021-11-07T11:33:59.190573", 991 | "status": "completed" 992 | }, 993 | "tags": [] 994 | }, 995 | "outputs": [], 996 | "source": [ 997 | "class TCNNBlock(nn.Module):  # TCN unit (time axis) followed by CNN unit (spatial axes)\n", 998 | " def __init__(self, in_channels, out_channels, kernel_size, stride_tcn, stride_cnn, padding):\n", 999 | " super().__init__()\n", 1000 | " self.tcn = TCNBlock(in_channels, out_channels, kernel_size, stride_tcn, padding)\n", 1001 | " self.cnn = CNNBlock(out_channels, out_channels, kernel_size, stride_cnn, padding)\n", 1002 | " \n", 1003 | " def forward(self, x):\n", 1004 | " x = self.tcn(x)\n", 1005 | " x = self.cnn(x)\n", 1006 | " return x" 1007 | ] 1008 | }, 1009 | { 1010 | "cell_type": "markdown", 1011 | "metadata": {}, 1012 | "source": [ 1013 | "- **TCNN+RNN model**\n", 1014 | "\n", 1015 | "The overall model structure is shown in the figure below:\n", 1016 | "\n", 1017 | "\n", 1018 | "\n", 1019 | "1. TCNN part\n", 1020 | "\n", 1021 | "The TCNN part is laid out like a conventional CNN and is very regular: it extracts richer features by gradually increasing the number of channels (64, then 128, then 256), while the stride-2 CNN units in tcnn2 and tcnn4 halve the spatial resolution. Note that the input data has the layout (N,T,H,W,C); to match the input format of the convolutional layers, it must be converted to (N,T,C,H,W).\n", 1022 | "\n", 1023 | "2.
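GAP layer\n", 1024 | "\n", 1025 | "GAP stands for Global Average Pooling. It averages each channel's feature map: if the TCNN part outputs data of shape (N,T,C,H,W), the GAP layer averages all values of each H×W feature map, so the output shape becomes (N,T,C). GAP was introduced in the paper *Network in Network* (https://arxiv.org/pdf/1312.4400.pdf) to replace the fully connected layers at the end of a conventional CNN, and it has since been widely adopted in CNN architectures such as GoogLeNet and ResNet.\n", 1026 | "\n", 1027 | "Why can GAP replace a fully connected layer? In a conventional CNN, after several convolution and pooling layers, a Flatten layer unrolls the feature maps into a single vector that is fed to a fully connected layer. For a sample of shape (C,H,W), the flattened vector has length C×H×W; with U units in the fully connected layer, the layer has C×H×W×U weights, a parameter count large enough to make the model overfit easily. GAP, by contrast, introduces no parameters at all, which helps reduce overfitting and also speeds up training. In addition, a fully connected layer is something of a black box: it is hard to explain how class information flows back to the convolutional layers. GAP is easy to interpret, because each output value simply summarizes what one channel extracted after the stack of convolutions. For a more detailed discussion, see https://www.zhihu.com/question/373188099\n", 1028 | "\n", 1029 | "PyTorch has no dedicated GAP layer like Keras's GlobalAveragePooling2D, but adaptive_avg_pool2d can stand in for one: it pools a feature map down to a given output size, so setting output_size to (1,1) is equivalent to GAP (as is taking the mean over the last two dimensions). For details, see https://pytorch.org/docs/stable/generated/torch.nn.functional.adaptive_avg_pool2d.html?highlight=adaptive_avg_pool2d#torch.nn.functional.adaptive_avg_pool2d\n", 1030 | "\n", 1031 | "3. RNN part\n", 1032 | "\n", 1033 | "So far we have worked with sequences of length 12, where each time step carries one month of information. Sequences at different time scales carry somewhat different information. For example, if one year of SST values is represented by a sequence of length 6, each time step covers two months, and the within-year SST trend seen at this scale is not exactly the same as the one seen in the length-12 sequence. To mine as much information as possible, this top solution uses a MaxPool layer to obtain sequences at three time scales (lengths 12, 6, and 3) and an RNN layer to extract a feature representation from each. RNNs are well suited to automatic feature extraction from ordered sequences: for an input sample of shape (T,C1), an RNN layer with C2 hidden units produces a vector of length C2 (its final hidden state). Because an RNN passes information forward step by step, this vector can capture the dependencies within the sequence. The same RNN layer, with shared weights, is applied at all three scales.\n", 1034 | "\n", 1035 | "Each of the three time scales is now represented by one vector; concatenating the three vectors and passing the result through a fully connected layer yields the 24-month prediction sequence." 1036 | ] 1037 | }, 1038 | { 1039 | "cell_type": "code", 1040 | "execution_count": 23, 1041 | "metadata": { 1042 | "execution": { 1043 | "iopub.execute_input": "2021-11-07T11:33:59.339242Z", 1044 | "iopub.status.busy": "2021-11-07T11:33:59.337656Z", 1045 | "iopub.status.idle": "2021-11-07T11:33:59.339842Z", 1046 | "shell.execute_reply": "2021-11-07T11:33:59.340257Z", 1047 | "shell.execute_reply.started": "2021-11-07T11:23:57.199137Z" 1048 | }, 1049 | "papermill": { 1050 | "duration": 0.088367, 1051 | "end_time": "2021-11-07T11:33:59.340377", 1052 | "exception": 1053 |
false, 1053 | "start_time": "2021-11-07T11:33:59.252010", 1054 | "status": "completed" 1055 | }, 1056 | "tags": [] 1057 | }, 1058 | "outputs": [], 1059 | "source": [ 1060 | "# Build the model\n", 1061 | "class Model(nn.Module):\n", 1062 | " def __init__(self):\n", 1063 | " super().__init__()\n", 1064 | " self.conv = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3)\n", 1065 | " self.tcnn1 = TCNNBlock(64, 64, 3, 1, 1, 1)\n", 1066 | " self.tcnn2 = TCNNBlock(64, 128, 3, 1, 2, 1)\n", 1067 | " self.tcnn3 = TCNNBlock(128, 128, 3, 1, 1, 1)\n", 1068 | " self.tcnn4 = TCNNBlock(128, 256, 3, 1, 2, 1)\n", 1069 | " self.tcnn5 = TCNNBlock(256, 256, 3, 1, 1, 1)\n", 1070 | " self.rnn = nn.RNN(256, 256, batch_first=True)\n", 1071 | " self.maxpool = nn.MaxPool1d(2)\n", 1072 | " self.fc = nn.Linear(256*3, 24)\n", 1073 | " \n", 1074 | " def forward(self, x):\n", 1075 | " # Reshape (N,T,H,W,C) -> (N*T, C, H, W)\n", 1076 | " N, T, H, W, C = x.shape\n", 1077 | " x = x.permute(0, 1, 4, 2, 3).contiguous()\n", 1078 | " x = x.view(N*T, C, H, W)\n", 1079 | " \n", 1080 | " # Initial 7x7 stride-2 convolution, applied to each month independently\n", 1081 | " x = self.conv(x)\n", 1082 | " _, C_new, H_new, W_new = x.shape\n", 1083 | " x = x.view(N, T, C_new, H_new, W_new)\n", 1084 | " \n", 1085 | " # TCNN part (each loop reuses the same block, so its repetitions share weights)\n", 1086 | " for i in range(3):\n", 1087 | " x = self.tcnn1(x)\n", 1088 | " x = self.tcnn2(x)\n", 1089 | " for i in range(2):\n", 1090 | " x = self.tcnn3(x)\n", 1091 | " x = self.tcnn4(x)\n", 1092 | " for i in range(2):\n", 1093 | " x = self.tcnn5(x)\n", 1094 | " \n", 1095 | " # Global average pooling over H and W\n", 1096 | " x = F.adaptive_avg_pool2d(x, (1, 1)).flatten(2)  # (N,T,C,1,1) -> (N,T,C); unlike .squeeze(), keeps N even when N == 1\n", 1097 | " \n", 1098 | " # RNN part: feature vectors at three time scales (lengths T, T/2, T/4); MaxPool1d pools the last dim, hence the transposes between (N,T,C) and (N,C,T)\n", 1099 | " hidden_state = []\n", 1100 | " for i in range(3):\n", 1101 | " x, h = self.rnn(x)\n", 1102 | " h = h.squeeze(0)  # final hidden state: (1, N, 256) -> (N, 256)\n", 1103 | " hidden_state.append(h)\n", 1104 | " x = self.maxpool(x.transpose(1, 2)).transpose(1, 2)\n", 1105 | " \n", 1106 | " x = torch.cat(hidden_state, dim=1)\n", 1107 | " x = self.fc(x)\n", 1108 | " \n", 1109 | " return x" 1110 | ] 1111 | }, 1112 | {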
1113 | "cell_type": "code", 1114 | "execution_count": 25, 1115 | "metadata": { 1116 | "execution": { 1117 | "iopub.execute_input": "2021-11-07T11:33:59.467079Z", 1118 | "iopub.status.busy": "2021-11-07T11:33:59.466474Z", 1119 | "iopub.status.idle": "2021-11-07T11:33:59.507567Z", 1120 | "shell.execute_reply": "2021-11-07T11:33:59.507973Z", 1121 | "shell.execute_reply.started": "2021-11-07T11:23:57.232949Z" 1122 | }, 1123 | "papermill": { 1124 | "duration": 0.076245, 1125 | "end_time": "2021-11-07T11:33:59.508101", 1126 | "exception": false, 1127 | "start_time": "2021-11-07T11:33:59.431856", 1128 | "status": "completed" 1129 | }, 1130 | "tags": [] 1131 | }, 1132 | "outputs": [ 1133 | { 1134 | "name": "stdout", 1135 | "output_type": "stream", 1136 | "text": [ 1137 | "Model(\n", 1138 | " (conv): Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n", 1139 | " (tcnn1): TCNNBlock(\n", 1140 | " (tcn): TCNBlock(\n", 1141 | " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1142 | " (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n", 1143 | " (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1144 | " )\n", 1145 | " (cnn): CNNBlock(\n", 1146 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1147 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1148 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1149 | " )\n", 1150 | " )\n", 1151 | " (tcnn2): TCNNBlock(\n", 1152 | " (tcn): TCNBlock(\n", 1153 | " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1154 | " (conv): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", 1155 | " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1156 | " (res): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", 1157 | " 
)\n", 1158 | " (cnn): CNNBlock(\n", 1159 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1160 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", 1161 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1162 | " (res): Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 2))\n", 1163 | " )\n", 1164 | " )\n", 1165 | " (tcnn3): TCNNBlock(\n", 1166 | " (tcn): TCNBlock(\n", 1167 | " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1168 | " (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n", 1169 | " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1170 | " )\n", 1171 | " (cnn): CNNBlock(\n", 1172 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1173 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1174 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1175 | " )\n", 1176 | " )\n", 1177 | " (tcnn4): TCNNBlock(\n", 1178 | " (tcn): TCNBlock(\n", 1179 | " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1180 | " (conv): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", 1181 | " (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1182 | " (res): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", 1183 | " )\n", 1184 | " (cnn): CNNBlock(\n", 1185 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1186 | " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", 1187 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1188 | " (res): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2))\n", 1189 | " )\n", 1190 | " )\n", 
1191 | " (tcnn5): TCNNBlock(\n", 1192 | " (tcn): TCNBlock(\n", 1193 | " (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1194 | " (conv): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n", 1195 | " (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1196 | " )\n", 1197 | " (cnn): CNNBlock(\n", 1198 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1199 | " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1200 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1201 | " )\n", 1202 | " )\n", 1203 | " (rnn): RNN(256, 256, batch_first=True)\n", 1204 | " (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 1205 | " (fc): Linear(in_features=768, out_features=24, bias=True)\n", 1206 | ")\n" 1207 | ] 1208 | } 1209 | ], 1210 | "source": [ 1211 | "model = Model()\n", 1212 | "print(model)" 1213 | ] 1214 | }, 1215 | { 1216 | "cell_type": "markdown", 1217 | "metadata": {}, 1218 | "source": [ 1219 | "#### Model training" 1220 | ] 1221 | }, 1222 | { 1223 | "cell_type": "code", 1224 | "execution_count": 26, 1225 | "metadata": { 1226 | "execution": { 1227 | "iopub.execute_input": "2021-11-07T11:33:59.569324Z", 1228 | "iopub.status.busy": "2021-11-07T11:33:59.568443Z", 1229 | "iopub.status.idle": "2021-11-07T11:33:59.570995Z", 1230 | "shell.execute_reply": "2021-11-07T11:33:59.570559Z", 1231 | "shell.execute_reply.started": "2021-11-07T11:23:59.196182Z" 1232 | }, 1233 | "papermill": { 1234 | "duration": 0.035076, 1235 | "end_time": "2021-11-07T11:33:59.571104", 1236 | "exception": false, 1237 | "start_time": "2021-11-07T11:33:59.536028", 1238 | "status": "completed" 1239 | }, 1240 | "tags": [] 1241 | }, 1242 | "outputs": [], 1243 | "source": [ 1244 | "# Loss: RMSE computed separately for each forecast month, then summed over the 24 months\n", 1245 | "def RMSELoss(y_pred,y_true):\n", 1246 | " loss = 
torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()\n", 1247 | " return loss" 1248 | ] 1249 | }, 1250 | { 1251 | "cell_type": "code", 1252 | "execution_count": 27, 1253 | "metadata": { 1254 | "execution": { 1255 | "iopub.execute_input": "2021-11-07T11:33:59.635947Z", 1256 | "iopub.status.busy": "2021-11-07T11:33:59.635133Z", 1257 | "iopub.status.idle": "2021-11-07T11:41:54.987872Z", 1258 | "shell.execute_reply": "2021-11-07T11:41:54.987159Z", 1259 | "shell.execute_reply.started": "2021-11-07T11:24:00.777235Z" 1260 | }, 1261 | "papermill": { 1262 | "duration": 475.387685, 1263 | "end_time": "2021-11-07T11:41:54.988052", 1264 | "exception": false, 1265 | "start_time": "2021-11-07T11:33:59.600367", 1266 | "status": "completed" 1267 | }, 1268 | "tags": [] 1269 | }, 1270 | "outputs": [ 1271 | { 1272 | "name": "stdout", 1273 | "output_type": "stream", 1274 | "text": [ 1275 | "Epoch: 1/10\n" 1276 | ] 1277 | }, 1278 | { 1279 | "name": "stderr", 1280 | "output_type": "stream", 1281 | "text": [ 1282 | "100%|██████████| 100/100 [00:51<00:00, 1.95it/s]\n" 1283 | ] 1284 | }, 1285 | { 1286 | "name": "stdout", 1287 | "output_type": "stream", 1288 | "text": [ 1289 | "Training Loss: 18.099\n" 1290 | ] 1291 | }, 1292 | { 1293 | "name": "stderr", 1294 | "output_type": "stream", 1295 | "text": [ 1296 | "4it [00:00, 6.19it/s]\n" 1297 | ] 1298 | }, 1299 | { 1300 | "name": "stdout", 1301 | "output_type": "stream", 1302 | "text": [ 1303 | "Validation Loss: 16.756\n", 1304 | "Score: -4.320\n", 1305 | "Epoch: 2/10\n" 1306 | ] 1307 | }, 1308 | { 1309 | "name": "stderr", 1310 | "output_type": "stream", 1311 | "text": [ 1312 | "100%|██████████| 100/100 [00:45<00:00, 2.17it/s]\n" 1313 | ] 1314 | }, 1315 | { 1316 | "name": "stdout", 1317 | "output_type": "stream", 1318 | "text": [ 1319 | "Training Loss: 16.955\n" 1320 | ] 1321 | }, 1322 | { 1323 | "name": "stderr", 1324 | "output_type": "stream", 1325 | "text": [ 1326 | "4it [00:00, 6.41it/s]\n" 1327 | ] 1328 | }, 1329 | { 1330 | "name": 
"stdout", 1331 | "output_type": "stream", 1332 | "text": [ 1333 | "Validation Loss: 17.657\n", 1334 | "Score: -32.332\n", 1335 | "Epoch: 3/10\n" 1336 | ] 1337 | }, 1338 | { 1339 | "name": "stderr", 1340 | "output_type": "stream", 1341 | "text": [ 1342 | "100%|██████████| 100/100 [00:45<00:00, 2.18it/s]\n" 1343 | ] 1344 | }, 1345 | { 1346 | "name": "stdout", 1347 | "output_type": "stream", 1348 | "text": [ 1349 | "Training Loss: 16.639\n" 1350 | ] 1351 | }, 1352 | { 1353 | "name": "stderr", 1354 | "output_type": "stream", 1355 | "text": [ 1356 | "4it [00:00, 6.29it/s]\n" 1357 | ] 1358 | }, 1359 | { 1360 | "name": "stdout", 1361 | "output_type": "stream", 1362 | "text": [ 1363 | "Validation Loss: 19.156\n", 1364 | "Score: -25.483\n", 1365 | "Epoch: 4/10\n" 1366 | ] 1367 | }, 1368 | { 1369 | "name": "stderr", 1370 | "output_type": "stream", 1371 | "text": [ 1372 | "100%|██████████| 100/100 [00:45<00:00, 2.17it/s]\n" 1373 | ] 1374 | }, 1375 | { 1376 | "name": "stdout", 1377 | "output_type": "stream", 1378 | "text": [ 1379 | "Training Loss: 16.173\n" 1380 | ] 1381 | }, 1382 | { 1383 | "name": "stderr", 1384 | "output_type": "stream", 1385 | "text": [ 1386 | "4it [00:00, 6.29it/s]\n" 1387 | ] 1388 | }, 1389 | { 1390 | "name": "stdout", 1391 | "output_type": "stream", 1392 | "text": [ 1393 | "Validation Loss: 18.130\n", 1394 | "Score: -15.470\n", 1395 | "Epoch: 5/10\n" 1396 | ] 1397 | }, 1398 | { 1399 | "name": "stderr", 1400 | "output_type": "stream", 1401 | "text": [ 1402 | "100%|██████████| 100/100 [00:45<00:00, 2.17it/s]\n" 1403 | ] 1404 | }, 1405 | { 1406 | "name": "stdout", 1407 | "output_type": "stream", 1408 | "text": [ 1409 | "Training Loss: 15.818\n" 1410 | ] 1411 | }, 1412 | { 1413 | "name": "stderr", 1414 | "output_type": "stream", 1415 | "text": [ 1416 | "4it [00:00, 6.28it/s]\n" 1417 | ] 1418 | }, 1419 | { 1420 | "name": "stdout", 1421 | "output_type": "stream", 1422 | "text": [ 1423 | "Validation Loss: 17.367\n", 1424 | "Score: -14.745\n", 1425 | "Epoch: 
6/10\n" 1426 | ] 1427 | }, 1428 | { 1429 | "name": "stderr", 1430 | "output_type": "stream", 1431 | "text": [ 1432 | "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n" 1433 | ] 1434 | }, 1435 | { 1436 | "name": "stdout", 1437 | "output_type": "stream", 1438 | "text": [ 1439 | "Training Loss: 15.464\n" 1440 | ] 1441 | }, 1442 | { 1443 | "name": "stderr", 1444 | "output_type": "stream", 1445 | "text": [ 1446 | "4it [00:00, 6.28it/s]\n" 1447 | ] 1448 | }, 1449 | { 1450 | "name": "stdout", 1451 | "output_type": "stream", 1452 | "text": [ 1453 | "Validation Loss: 18.289\n", 1454 | "Score: -4.441\n", 1455 | "Epoch: 7/10\n" 1456 | ] 1457 | }, 1458 | { 1459 | "name": "stderr", 1460 | "output_type": "stream", 1461 | "text": [ 1462 | "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n" 1463 | ] 1464 | }, 1465 | { 1466 | "name": "stdout", 1467 | "output_type": "stream", 1468 | "text": [ 1469 | "Training Loss: 15.175\n" 1470 | ] 1471 | }, 1472 | { 1473 | "name": "stderr", 1474 | "output_type": "stream", 1475 | "text": [ 1476 | "4it [00:00, 6.26it/s]\n" 1477 | ] 1478 | }, 1479 | { 1480 | "name": "stdout", 1481 | "output_type": "stream", 1482 | "text": [ 1483 | "Validation Loss: 18.604\n", 1484 | "Score: -21.144\n", 1485 | "Epoch: 8/10\n" 1486 | ] 1487 | }, 1488 | { 1489 | "name": "stderr", 1490 | "output_type": "stream", 1491 | "text": [ 1492 | "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n" 1493 | ] 1494 | }, 1495 | { 1496 | "name": "stdout", 1497 | "output_type": "stream", 1498 | "text": [ 1499 | "Training Loss: 15.004\n" 1500 | ] 1501 | }, 1502 | { 1503 | "name": "stderr", 1504 | "output_type": "stream", 1505 | "text": [ 1506 | "4it [00:00, 6.27it/s]\n" 1507 | ] 1508 | }, 1509 | { 1510 | "name": "stdout", 1511 | "output_type": "stream", 1512 | "text": [ 1513 | "Validation Loss: 18.593\n", 1514 | "Score: -27.508\n", 1515 | "Epoch: 9/10\n" 1516 | ] 1517 | }, 1518 | { 1519 | "name": "stderr", 1520 | "output_type": "stream", 1521 | "text": [ 1522 | "100%|██████████| 
100/100 [00:46<00:00, 2.17it/s]\n" 1523 | ] 1524 | }, 1525 | { 1526 | "name": "stdout", 1527 | "output_type": "stream", 1528 | "text": [ 1529 | "Training Loss: 14.578\n" 1530 | ] 1531 | }, 1532 | { 1533 | "name": "stderr", 1534 | "output_type": "stream", 1535 | "text": [ 1536 | "4it [00:00, 6.28it/s]\n" 1537 | ] 1538 | }, 1539 | { 1540 | "name": "stdout", 1541 | "output_type": "stream", 1542 | "text": [ 1543 | "Validation Loss: 18.264\n", 1544 | "Score: -19.113\n", 1545 | "Epoch: 10/10\n" 1546 | ] 1547 | }, 1548 | { 1549 | "name": "stderr", 1550 | "output_type": "stream", 1551 | "text": [ 1552 | "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n" 1553 | ] 1554 | }, 1555 | { 1556 | "name": "stdout", 1557 | "output_type": "stream", 1558 | "text": [ 1559 | "Training Loss: 14.330\n" 1560 | ] 1561 | }, 1562 | { 1563 | "name": "stderr", 1564 | "output_type": "stream", 1565 | "text": [ 1566 | "4it [00:00, 6.27it/s]\n" 1567 | ] 1568 | }, 1569 | { 1570 | "name": "stdout", 1571 | "output_type": "stream", 1572 | "text": [ 1573 | "Validation Loss: 17.739\n", 1574 | "Score: -18.628\n" 1575 | ] 1576 | } 1577 | ], 1578 | "source": [ 1579 | "model_weights = './task04_model_weights.pth'\n", 1580 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 1581 | "model = Model().to(device)\n", 1582 | "criterion = RMSELoss\n", 1583 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", 1584 | "epochs = 10\n", 1585 | "train_losses, valid_losses = [], []\n", 1586 | "scores = []\n", 1587 | "best_score = float('-inf')\n", 1588 | "preds = np.zeros((len(y_valid),24))\n", 1589 | "\n", 1590 | "for epoch in range(epochs):\n", 1591 | " print('Epoch: {}/{}'.format(epoch+1, epochs))\n", 1592 | " \n", 1593 | " # Training\n", 1594 | " model.train()\n", 1595 | " losses = 0\n", 1596 | " for data, labels in tqdm(trainloader):\n", 1597 | " data = data.to(device)\n", 1598 | " labels = labels.to(device)\n", 1599 | " optimizer.zero_grad()\n", 1600 | " pred = model(data)\n", 1601 | " loss = 
criterion(pred, labels)\n", 1602 | " losses += loss.cpu().detach().numpy()\n", 1603 | " loss.backward()\n", 1604 | " optimizer.step()\n", 1605 | " train_loss = losses / len(trainloader)\n", 1606 | " train_losses.append(train_loss)\n", 1607 | " print('Training Loss: {:.3f}'.format(train_loss))\n", 1608 | " \n", 1609 | " # Validation\n", 1610 | " model.eval()\n", 1611 | " losses = 0\n", 1612 | " with torch.no_grad():\n", 1613 | " for i, data in tqdm(enumerate(validloader)):\n", 1614 | " data, labels = data\n", 1615 | " data = data.to(device)\n", 1616 | " labels = labels.to(device)\n", 1617 | " pred = model(data)\n", 1618 | " loss = criterion(pred, labels)\n", 1619 | " losses += loss.cpu().detach().numpy()\n", 1620 | " preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()  # requires validloader to iterate in order (shuffle=False) so these rows match y_valid\n", 1621 | " valid_loss = losses / len(validloader)\n", 1622 | " valid_losses.append(valid_loss)\n", 1623 | " print('Validation Loss: {:.3f}'.format(valid_loss))\n", 1624 | " s = score(y_valid, preds)\n", 1625 | " scores.append(s)\n", 1626 | " print('Score: {:.3f}'.format(s))\n", 1627 | " \n", 1628 | " # Save the weights with the best validation score so far\n", 1629 | " if s > best_score:\n", 1630 | " best_score = s\n", 1631 | " checkpoint = {'best_score': s,\n", 1632 | " 'state_dict': model.state_dict()}\n", 1633 | " torch.save(checkpoint, model_weights)" 1634 | ] 1635 | }, 1636 | { 1637 | "cell_type": "code", 1638 | "execution_count": 28, 1639 | "metadata": { 1640 | "execution": { 1641 | "iopub.execute_input": "2021-11-07T11:41:55.621271Z", 1642 | "iopub.status.busy": "2021-11-07T11:41:55.620191Z", 1643 | "iopub.status.idle": "2021-11-07T11:41:55.623008Z", 1644 | "shell.execute_reply": "2021-11-07T11:41:55.623387Z", 1645 | "shell.execute_reply.started": "2021-11-07T11:31:56.277287Z" 1646 | }, 1647 | "papermill": { 1648 | "duration": 0.31815, 1649 | "end_time": "2021-11-07T11:41:55.623547", 1650 | "exception": false, 1651 | "start_time": "2021-11-07T11:41:55.305397", 1652 | "status": "completed" 1653 | }, 1654 | "tags": [] 1655 | },
1656 | "outputs": [], 1657 | "source": [ 1658 | "# Plot the training/validation curves\n", 1659 | "def training_vis(train_losses, valid_losses):\n", 1660 | " # Plot the loss curves\n", 1661 | " fig = plt.figure(figsize=(8,4))\n", 1662 | " # subplot loss\n", 1663 | " ax1 = fig.add_subplot(121)\n", 1664 | " ax1.plot(train_losses, label='train_loss')\n", 1665 | " ax1.plot(valid_losses,label='val_loss')\n", 1666 | " ax1.set_xlabel('Epochs')\n", 1667 | " ax1.set_ylabel('Loss')\n", 1668 | " ax1.set_title('Loss on Training and Validation Data')\n", 1669 | " ax1.legend()\n", 1670 | " plt.tight_layout()" 1671 | ] 1672 | }, 1673 | { 1674 | "cell_type": "code", 1675 | "execution_count": 29, 1676 | "metadata": { 1677 | "execution": { 1678 | "iopub.execute_input": "2021-11-07T11:41:56.271045Z", 1679 | "iopub.status.busy": "2021-11-07T11:41:56.270383Z", 1680 | "iopub.status.idle": "2021-11-07T11:41:56.549178Z", 1681 | "shell.execute_reply": "2021-11-07T11:41:56.549615Z", 1682 | "shell.execute_reply.started": "2021-11-07T11:31:56.286373Z" 1683 | }, 1684 | "papermill": { 1685 | "duration": 0.611301, 1686 | "end_time": "2021-11-07T11:41:56.549765", 1687 | "exception": false, 1688 | "start_time": "2021-11-07T11:41:55.938464", 1689 | "status": "completed" 1690 | }, 1691 | "tags": [] 1692 | }, 1693 | "outputs": [ 1694 | { 1695 | "data": { 1696 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAS8AAAEYCAYAAAANoXDNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA51ElEQVR4nO3dd3hU1dbA4d9KIyGEHnoJNbQQSkBAelGaYENERVRsoGLlil4/9d7rtVcsSBELIIrYBRWpAQUh9E7oRFoSeifJ/v7YgzdAAikzc2aS9T5PHmbOzJy9zsywZp99dhFjDEop5W8CnA5AKaXyQpOXUsovafJSSvklTV5KKb+kyUsp5Zc0eSml/JImr0JKRH4WkUHufq6TRGS7iHT1wH7nisjdrtu3isiMnDw3D+VUE5FjIhKY11gLk0KVvDz15fYW1xf73F+GiJzMdP/W3OzLGNPDGPOpu5/ri0RkhIjEZ7G9rIicEZFGOd2XMWaSMeYqN8V13vfRGLPTGFPMGJPujv1fUJYRkeOu70qqiMwSkf65eH1HEUlyd1z5UaiSl79zfbGLGWOKATuBazJtm3TueSIS5FyUPmki0EZEalyw/WZgtTFmjQMxOSHW9d2JBj4B3hOR55wNKe80eQEiUkRE3haR3a6/t0WkiOuxsiLyk4gcEpEDIjJfRAJcjz0pIn+JyFER2SgiXbLZfwkR+UxEkkVkh4g8k2kfd4jIAhF5XUQOisg2EemRy/g7ikiSK569wMciUsoVd7Jrvz+JSJVMr8l8KnTJGHL53BoiEu96T2aKyPsiMjGbuHMS439E5HfX/maISNlMjw90vZ+pIvLP7N4fY0wSMBsYeMFDtwOfXS6OC2K+Q0QWZLrfTUQ2iMhhEXkPkEyP1RKR2a74UkRkkoiUdD02AagG/OiqDf1DRKJcNaQg13MqicgPru/dZhG5J9O+nxeRKa7v1VERWSsicdm9Bxe8HynGmAnAEOApESnj2uedIrLetb+tInKfa3s48DNQSf5X068kIi1FZKHr/8YeEXlPREJyEoM7aPKy/gm0ApoAsUBL4BnXY48DSUAkUB54GjAiEg08CLQwxkQAVwPbs9n/u0AJoCbQAfuf5s5Mj18BbATKAq8CH4mIXLiTy6gAlAaqA/diP9uPXferASeB9y7x+tzEcKnnfg4sBsoAz3NxwsgsJzHegn2vygEhwBMAItIAGOXafyVXeVkmHJdPM8fi+vyauOLN7Xt1bh9lgW+w35WywBbgysxPAV5yxVcfqIp9TzDGDOT82vOrWRTxBfa7Vwm4EXhRRDpneryP6zklgR9yEvMFvgeCsN93gP1Ab6A49j1/S0SaGWOOAz2A3Zlq+ruBdOBR17G3BroAQ3MZQ94ZYwrNHza5dM1i+xagZ6b7VwPbXbf/jf2Qa1/wmtrYD7srEHyJMgOBM0CDTNvuA+a6bt8BbM70WFHAABVyeixAR1cZoZd4fhPgYKb7c4G7cxJDTp+L/Y+fBhTN9PhEYGIOP5+sYnwm0/2hwC+u288CX2R6LNz1Hlz0+WaK8wjQxnX/v8D3eXyvFrhu3w4syvQ8wSabu7PZ77XA8uy+j0CU670Mwia6dCAi0+MvAZ+4bj8PzMz0WAPg5CXeW8MF32HX9r3Ardm85jvg4UzfsaTLfH6PAN/m5LN2x5/WvKxKwI5M93e4tgG8BmwGZriq0iMAjDGbsR/W88B+EflCRCpxsbJAcBb7r5zp/t5zN4wxJ1w3i+XyGJKNMafO3RGRoiIy2nVadQSIB0pK9leychNDds+tBBzItA1gV3YB5zDGvZlun8gUU6XM+za2dpCaXVmumL4CbnfVEm8FPstFHFm5MAaT+b6IlHd9L/5y7Xci9vuQE+fey6OZtmX7vcG+N6GSi/ZOEQnGnlEccN3vISKLXKeph4Cel4pXROq6TrH3uo7vxUs93900eVm7sacM51RzbcMYc9QY87gxpia2mv6YuNq2jDGfG2Paul5rgFey2Hc
KcDaL/f/l5mO4cHqQx7ENs1cYY4oD7V3bc3s6mht7gNIiUjTTtqqXeH5+YtyTed+uMstc5jWfAjcB3YAI4Md8xnFhDML5x/si9nOJce33tgv2eakpXXZj38uITNvc/b3pi60pLxbbxvs18DpQ3hhTEpieKd6sYh0FbADquI7vaTz7/TpPYUxewSISmukvCJgMPCMika52jGexv5KISG8Rqe36Yh7GVuUzRCRaRDq7PvRT2HaSjAsLM/ay9xTgvyISISLVgcfO7d+DIlwxHRKR0oDHryoZY3YACcDzIhIiIq2BazwU41Sgt4i0dTUS/5vLf5/nA4eAMdhTzjP5jGMa0FBErnd9j4ZhT5/PiQCOAYdFpDIw/ILX78O2g17EGLML+AN4yfU9bQwMxg3fGxEpLbZrzfvAK8aYVGx7YhEgGUgTexEmc5eQfUAZESlxwfEdAY6JSD3sBQCvKYzJazr2i3ru73ngBex/ulXAamCZaxtAHWAm9ku4EPjAGDMH+0G/jK1Z7cU2KD+VTZkPAceBrcACbCPxePce1kXeBsJc8S0CfvFweefcim28TcW+h18Cp7N57tvkMUZjzFrgAex7uQc4iG1vutRrDPZUsbrr33zFYYxJAfphvwep2O/K75me8i+gGfZHbxq2cT+zl7A/modE5IksihiAbQfbDXwLPGeMmZmT2LKxUkSOYZtB7gYeNcY86zqWo9jkOwX7Xt6CvQhw7lg3YH/kt7rirYS9eHILcBQYi/2svUZcDW1KeYSIfAlsMMb4bX8i5ZsKY81LeZCItBDbvylARLpj21W+czgsVQBpT2zlbhWwp0dlsKdxQ4wxy50NSRVEetqolPJLetqolPJLfnHaWLZsWRMVFeV0GEopL1u6dGmKMSYyq8f8InlFRUWRkJDgdBhKKS8TkR3ZPaanjUopv6TJSynllzR5KaX8kiYvpZRf0uSllPJLmryUUn5Jk5dSyi9p8vJlB7bC/g1OR6GUT9Lk5auMgcm3wMc94MQBp6NRyudo8vJVO36H5PVw8gDM+pfT0SjlczR5+aol4yC0JMQNhqWfQpIOj1IqM01evujoXlj/IzS9Dbr9CyIqwE+PQobbV4FXym95LHmJyHgR2S8iazJti3WtsLtaRH4UkeKeKt+vLf0UMtIg7i4oEgFXvwh7V8GSj5yOTCmf4cma1ydA9wu2jQNGGGNisAsKXLiaiko/C0s/hlpdoEwtu63hdVCzE8x+AY7uczY+pXyEx5KXMSYe12KWmdTFLugJ8Btwg6fK91sbp8PRPdDi7v9tE4Ger0PaSfjt/5yLTSkf4u02r7XYBRnALhmV7YKkInKviCSISEJycrJXgvMJS8ZBiapQ9+rzt5etDVc+DKu+hO0LnIlNKR/i7eR1FzBURJZiF6w8k90TjTFjjDFxxpi4yMgsJ1IseJI3wrZ4iLsTArJYab7tY1CyGkx73J5eKlWIeTV5GWM2GGOuMsY0xy5gucWb5fu8JR9BYAg0vT3rx0OKQo9XIXkDLPrAu7Ep5WO8mrxEpJzr3wDgGeBDb5bv004fg5WTocG1UOwSNc3oHhDdE+a+DIcvuUC0UgWaJ7tKTAYWAtEikiQig4EBIrIJ2IBdwvxjT5Xvd1ZPgdNHzm+oz073l+3woV+e8nxcSvkojy3AYYwZkM1D73iqTL9ljD1lrBADVVte/vmlqkP7J2D2fyBxJtTp6vkYlfIx2sPeF+z6E/atsbUukZy9ps1DUKY2TH8Czp7ybHxK+SBNXr5g8VgoUgJi+uX8NUFFbN+vg9vg97c9FppSvkqTl9OO7Yd130OTWyAkPHevrdUJGl4P89+0c38pVYho8nLass8g4yy0GJy311/9XwgMhp+ftG1nShUSmryclJ4GCR9DjQ5Qtk7e9lG8EnR6GhJnwIaf3BufUj5Mk5eTEn+FI0nQ8p787aflfVCuIfw8As4cd09sSvk4TV5OWjwWileGuj3yt5/AIOj1hk2E8151T2xK+Th
NXk5J2Qxb50DzO23yya/qraHJrbDwPV20QxUKmryckjAeAoKgWTbjGPOi278hpJjt+6WN96qA81gPe3UJZ07AiolQvw9ElHfffsPLQpdnYdpjsHoqNM5FvzGVc+lnYeUXsHe1vdIbEAgBwfbHKDAo0+0LH8vqftDFtzPfDwyG4lUgQOsZF9Lk5YQ1U+HU4fw31Gel+R2wfCL8+jTUvQpCS7i/jMLqXNKKfxUO7YSQCDAZdsrujLP2tidE1oPO/wf1euV8BEYhoMnL24yxDfXlGkC11u7ff0Cgbbwf2xnmvAg9XnF/GYVNepodOD/vFTi4HSo1hZ5vQJ1u5yeTjHOJzJXMMtJtwrvovmtbeubnZnP/1GFYPBq+vBUqx0HX56BGe8feCl+iycvbkhLsYhq93vDcr2jlZnbxjsVjbM/9irGeKaegy0i3p9/zXoEDW6BCYxjwBdTtnvVnFxAAASFAiHvjiLsLVn5up0H69Bqo1dk2D1Rq6t5y/IyeSHvbknH2dKNxf8+W0+X/IKy0nXU1w0OnMwXVuaT1/hXw7b0QXBT6T4L74u18at4+dQt0Xdh5aBlc9V/YvQLGdIQpt0NKondj8SGavLzpeCqs/QZib7ZLmnlSWCm46j+QtASWT/BsWQVFRgas+QY+aA1fD7aN5Td9ZpNW/d7OtzcFh0KbB+HhldDhSdg8yybY7x8slBNTavLypuWfQfqZvI9jzK3YAVCtDcx8ziZOlbWMDDs4/sMrYeqdNknd+DHc/zs06Ot7V/pCi9shYcNWQMt77aIsI5vBr/8sVJ+zj30qBVhGuu3bFdUOytX3Tpki0Ot1OHUEZj3vnTL9iTGw/icY3c6egqWfhRs+giF/QKPrfS9pXahYJPR4GR5aaqdTWvQBvBMLc1+B00edjs7jfPzTKUASf7OX171V6zqnfENoNcTOXrFriXfL9lXGwMafYXR7exXv7Em4bgw88CfE3Jj1yk2+rGQ1uPZ9GLIQanWEuS/CO01g0ShIO+10dB6jyctbloyDYhWgXm/vl91xBERUtJ1X09O8X76vMAY2zYCxnWDyzXbNgGtHwQOLIba//yWtC5WrB/0nwt2z7Y/WLyPg3eawfJKt+Rcwmry84cBW2DzTdiANDPZ++UUioPtLtotGwkfeL99pxtj3f1xX+LwfnEiFPu/Bgwm2K4k7xpb6kirNYdAPMPA7O+ri+6H2IsT6HwvUsDFNXt6QMB4kAJoPci6GBtdCzU4w+wU4us+5OLzJGNgyBz66CibeAMf2wTXvwINLodlAZ35IvKlWJ7hnDtw0ATDw5W0wrgtsned0ZG6hycvTzp60w3Xq97YTBzpFxM55n3YKZjzjXBzesi0ePu4BE66FI39BrzdtP6nmd0CQmzuR+jIRaNDHtof1fd/+cH3WBz7rC38tczq6fNHk5Wlrv4WTB3O2HqOnla0NVz5sh7psm+90NJ6RugU+6W17oh/cbhP2sOX2QklhSloXCgyCprfZK5NXv2QHlY/tBF8OhORNTkeXJ2L84Bw4Li7OJCQkOB1G3ozpZGc3feBP5zs5gp3R4oMrICgM7l9QsP5Dp5+1YzoP7YCOT9taVnCo01H5ptNHYeH78Md7cPa4bfvr9h8oWtrpyM4jIkuNMXFZPaY1L0/6aynsXpa79Rg9LaQo9HgNUjbafkEFyfw37EWJvu9Dq/s1cV1KkQh7FfrhldBqKKz80l7QSNnsdGQ5psnLk5aMh+Bwexnel0R3h+iedsDxoV1OR+Meu1dA/GsQcxPUv8bpaPxHeBm7AtWgH+HUIdugvy3e6ahyRJOXp5w4YOftanyTb86p1f1lezXu16ecjiT/0k7Dd0OgaFmdAiivqreGu2dBRAWYcB0s/dTpiC5Lk5enrJhkr+z5QkN9VkpVh/ZP2L4/ib85HU3+zH0Z9q+DPiN9rs3Gr5SuAYNn2KX4fhxmx0r6cOdWTV6ekJEBSz6ykw1WaOR0NNlr8xCUqWPnvD970ulo8iY
pAX5/G5rcBnWvdjoa/xdaAm6ZYpfTW/gefHGLz46T1OTlCVtmw8FtvlvrOieoiB24fXC77bzqb86ehG/vh4hK0P1Fp6MpOAKDoOertptJ4m8wvrtPto0WqOSVuO8oA8Ys4q9DDtciloyD8Ei7wIavq9kR4gbbX9lVXzkdTe7MfgFSE6Hvu77ZrujvWt4Dt35lJxQY29nWcn1IgUpeRYsEsXTnQd6c4WCnu4M7YNMv0GyQ//Sh6v6ynffrhwf9p9f1jj9sP6W4u+y0yMozaneBwb/ZLjYf97QzzPoIjyUvERkvIvtFZE2mbU1EZJGIrBCRBBFp6c4yK5cM4842UXyzPIl1u4+4c9c5t/Rj26cr7k5nys+LoBA7Y2h4JHxxq++PfTxzHL4baqeC6fYfp6Mp+MrVszNVVG5uZ5id+7JPDPD2ZM3rE6D7BdteBf5ljGkCPOu671ZDO9ameGgwr/ziwKrRaaftvFnRPaFEFe+Xnx/FIuHmz21fny9v8+15oGY+b9sUr/0AihRzOprCIbwM3P4dxN4Cc1+Cr+92/CKPx5KXMSYeOHDhZqC463YJYLe7yy1RNJgHO9Vm3qZkft+c4u7dX9ra7+x0K96ecNBdKja2CSFpsZ37ywd+XS+ydZ5dFemKIRDV1uloCpegIvb70fV524fxk96O1tK93eb1CPCaiOwCXgey7SEpIve6Ti0TkpOTc1XIwNbVqVwyjJd+Xk9Ghhf/Ay4ZB2VqQ42O3ivT3RpeB+2H25kw/hztdDTnO3XELjZRupZd+kt5nwi0fdROerh/ne2Rv3fN5V/nAd5OXkOAR40xVYFHgWxnxjPGjDHGxBlj4iIjI3NVSGhwIE9cXZc1fx3hx1Vur9xlbc9KW2OJG+z7c59fTsenIbqXXXV761yno/mfGc/AkSQ7+2lIUaejKdzqXwN3/mwXxx1/NWz8xeshePt/2SDgG9ftrwC3Nthn1je2Mg0qFue1XzdyOs0LvYSXjLMzNTQZ4PmyPC0gAK4fDWXrwpRBdiZYpyXOhGWfQusHodoVTkejACo1gXtm27ONyTfbGSq82NTg7eS1G+jgut0Z8NiKmQEBwlM965F08CQTF+30VDHWyUO2j1Tjfna9xIKgSAQM+NzenuxwL+uTh+CHhyCyHnT6p3NxqIsVr2RrYPWvgRn/hB8ftlMTeYEnu0pMBhYC0SKSJCKDgXuAN0RkJfAicK+nygdoVyeSdnXK8u7sRA6f9OAbunIypJ30/R71uVW6JvT7BFI2wTf3Obfy9i8j7BTO147SaW58UUhR6PcptHvc1o4nXm8nJvAwT15tHGCMqWiMCTbGVDHGfGSMWWCMaW6MiTXGXGGMWeqp8s95sns9Dp04y4fztnimgIwMe8pYpSVUjPVMGU6q1clOmbJxmr1E7m0bptsfh3aPQeVm3i9f5UxAgL2Ict1o2LnIK3OD+XnL8uU1qlyC65pWZvyCbew57IF+KdvmQermglfryuyK++3A5/hX7bTW3nLigD0NKR8D7f/hvXJV3sXeDLf/4JW5wQp88gJ4rFtdjMEzw4aWjIOiZeyy8AWVCPR+09YuvxsKe1Z5p9zpT9j5/68b5T9DrZTX5gYrFMmraumiDGpTna+XJbFxrxsbng8nwcbp0Oz2gt8WE1TE9u0JLWmHEB33cAfgtd/Bmq+hw5NQIcazZSn388LcYIUieQE80Kk2xYoEuXfY0NJP7KXh5n40jjE/IsrDzZPg+H6YcjuknfFMOceSbQ//ik1sh0jlnzw8N1ihSV4li4YwtFNtZm/Yz8ItqfnfYdoZWx2ue7WdlbSwqNzMrja943f45Un3798Y+OkR+yW/7sOCt5p1YZPV3GBuGlJUaJIXwB1toqhUItQ9w4bW/2BrIC3ucU9w/qRxP7v+Y8J4O2OsO63+Cjb8ZPtzlavv3n0r55ybG6x4JQgr6ZZdFqrkFRocyGNXRbMq6TDTVu/J386WfASlogrvXFJdnoM6V8HP/4Dtv7t
nn0f22Eb6Ki3tFNWqYKndxSawoCJu2V2hSl4A1zWtTL0KEbz260bOpOWx0+W+tbDzj4IxjjGvAgLhhnFQqgZMGWhn28wPY2zDbtoZ2xk1INA9caoCq9D9zwsMEEb0qMfOAyf4/M8dud9BRoYdsBwcbpdPL8xCS8CAyZCeZocQnTme930tnwiJM6Drc1C2tvtiVAVWoUteAB3qRtKmVhlGzt7M0VO5HDa0ZKydaeHqF3SZLYCydeDGj2DfGtsHLC8Dcw/tgl+egupt7ZUppXKgUCYvEeGpHvU5cPwMo+flYsaE5E3w27O2raewdI/IiTrdoNu/YN13EP967l5rjJ0732RA3/cK72m4yrVC+02JqVKCPrGVGLdgK3sPn7r8C9LPwrf3QnBR6POu7XWu/qfNMGjcH+a8ABum5fx1CeNtTfaq/9iOjUrlUKFNXgDDr44mPcPw9swcDBuKfx12L4feb9lhD+p8InDNO1CpKXxzL+xbd/nXHNgGM/4PanayqwAplQuFOnlVLV2Uga2imJKwi8R9l+j5m7QU4l+DxjdDw2u9Fp/fCQ6zi3iEhMMXAy49LUpGhp3SOSDQni5qTVblUqFOXgAPdq5NeMglhg2dOWFPFyMq2p7C6tKKV4L+k+DIbvjqDnslMiuLR8OOBXD1i/630pLyCYU+eZUOD+H+jrWYuX4/f27NYtjQb8/aKW+u/UBXZc6pqi2g99t2uqAZz1z8eMpmmPkve+GjsHc3UXlW6JMXwF1X1qBC8VBe+nkDJvOl/s2zbNeIVkOhZofsd6Au1vRW+779Ocr24TonIx2+G2J7WV8zUk8XVZ5p8gLCQgJ5rFtdVuw6xM9r9tqNJw7A9w9A2WhdZiuvuv3HNsb/9CjsWmy3/fGuXWWp52tQvKKz8Sm/psnL5YbmVahbvhiv/rKBs+kZdozd8WS4foxtiFa5FxgEN46H4pXtHGCbZ8Kc/0K93hDTz+nolJ/T5OVybtjQ9tQTLPx+tGsivBF2eSeVd0VL2yFEZ0/AxBvsqkS939bTRZVvmrwy6RRdju7VMmiy6j+kV4rTifDcpVx9O4g7JML2BSuWu0WElcqKzvSWiRjDq0GjCTJpfFp+BHfpRHjuE90Dntyukwsqt9GaV2ZLxlF893y+Lz+U1xLS2X8kB8OGVM5p4lJupMnrnJRE26erdjda93uCs+kZvD3LYwt6K6XySZMX2EHX39xrVwDq+x5RkcW4rVV1vlyyi837jzkdnVIqC5q8AOa/AbuXnTfo+qHOtQkLDuRVd642pJRyG01efy2Fea9CzE3Q8Lq/N5cpVoT7O9Rkxrp9JGy/xABjpZQjCnfyOnMCvrnP1rZ6vnbRw3e1rUG5iCK8OH39+cOGlFKOK9zJa+bzkJpoB11nsRxT0ZAgHu1Wl2U7D/HrWvesNaeUco/Cm7y2zLbTslwxBGp2zPZp/ZpXoXa5TMOGlFI+oXAmr5MH4TvXoOuuz13yqUGBATzZvR5bU47z5ZJdXgpQKXU5hTN5TXvCrnZ9/egcDbruWr8cLaJK8fbMRI6fzmZyPaWUV3kseYnIeBHZLyJrMm37UkRWuP62i8gKT5WfrTVfw5qp0OFJO996DogIT/WsT8qx04ydn4vVhpRSHuPJmtcnQPfMG4wx/Y0xTYwxTYCvgW88WP7FjuyGnx6DynHQ9rFcvbRZtVL0aFSBMfFbST562kMBKqVyymPJyxgTD2TZQUpEBLgJmOyp8rMIyC74kHYarhudp3F2w6+O5kxaBu/MysFqQ0opj3KqzasdsM8Yk+3gQRG5V0QSRCQhOTk5/yUuGQdbZtn1AfO4nHzNyGIMaFmNyYt3sTVZhw0p5SSnktcALlPrMsaMMcbEGWPiIiPzOf9TSqJdH7BWF2hxd752NaxLHUKDAnhx+gbStOuEUo7xevISkSDgeuBLrxSYngbf3mcXfOj7fr5n8IyMKMLQTrWZuX4fXd+cx7fLk0jP0N73SnlbjpKXiISLSIDrdl0
R6SMiwXkssyuwwRiTlMfX5878N+z4xd5vuW3Bh6EdazH29jjCQoJ49MuVXPXWPH5YuVuTmFJelNOaVzwQKiKVgRnAQOzVxGyJyGRgIRAtIkkiMtj10M14q6H+r2Uw7xW72EOj6922WxGhW4PyTHuoLR/e1oyggACGTV5Oj3fimb56DxmaxJTyOMnJgGMRWWaMaSYiDwFhxphXRWSFq8uDx8XFxZmEhITcvejsSRjdHk4fg6F/QFgpzwQHZGQYpq/Zw9szE9m8/xj1KkTwaLe6XNWgPKILTSiVZyKy1BgTl9VjOa15iYi0Bm4Fprm2BbojOI+Z+TykbHINuvZc4gIICBB6N67Er4+0552bm3AmLYP7Jiyl97sLmLlun85IoZQH5DR5PQI8BXxrjFkrIjWBOR6LKr+2zIY/P4SW90GtTl4rNjBA6NukMjMebc8b/WI5eiqNuz9LoO/7vzNn435NYkq5UY5OG897gW24L2aMOeKZkC6Wq9PGkwfhgzYQEg73xUNIUc8Gdwln0zP4dtlfjJydSNLBkzStVpLHutWlbe2yejqpVA7k+7RRRD4XkeIiEg6sAdaJyHB3Buk204fDsX120LWDiQsgODCAm1pUZfbjHXnxuhj2HT7FwI8Wc9PohfyxJcXR2JTydzk9bWzgqmldC/wM1MBecfQtW2bD6q/soOvKzZ2O5m8hQQHcckU15gzvyH/6NmTngRPcMvZPbh6zkMXbdIpppfIip8kr2NWv61rgB2PMWcD3GnBqdIC+H0C7x52OJEtFggIZ2DqKecM78dw1DdiSfJybRi/ktnF/snSHJjGlciOnyWs0sB0IB+JFpDrgtTavHAsIhKa3+vzipqHBgdx5ZQ3ih3fimV71Wb/nCDeMWsig8YtZseuQ0+Ep5Rdy3WD/9wtFgowxXpmZL0/9vPzIiTNpfLZwB6PnbeHgibN0qVeOR7vVpVHlEk6HppSjLtVgn9NOqiWA54D2rk3zgH8bYw67LcpLKOjJ65xjp9P49I/tjInfyuGTZ+nWoDxPXBVNdIUIp0NTyhHu6KQ6HjiKnYPrJuwp48fuCU+dU6xIEA90qs38JzvxWLe6LNqayjXvLmDCwu3aR0ypC+S05nXRUCCfHx5UABw4fobHp6xgzsZkesZU4OUbGlM8NK/j4ZXyP+6oeZ0UkbaZdnglcNIdwanslQ4P4aNBLRjRox6/rt1H75ELWJV0yOmwlPIJOU1e9wPvuxbN2A68B9znsajU3wIChPs71GLKfa1IS8/ghlF/8Mnv2/Q0UhV6OUpexpiVxphYoDHQ2BjTFOjs0cjUeZpXL820Ye3oUDeS539cx5CJyzh88qzTYSnlmFzNpGqMOZJpTGPult9R+VYqPISxt8fxTK/6zFy/j14j52u/MFVo5WcaaB1Z7AAR4e52NZlyf2uMgX4f/sG4+Vv1NFIVOvlJXvq/xUHNqpVi+rB2dIwuxwvT1nPPZ0s5dOKM02Ep5TWXTF4iclREjmTxdxSo5KUYVTZKFA1mzMDmPNu7AfM27afXyAUs23nQ6bCU8opLJi9jTIQxpngWfxHGGN8eQFhIiAh3ta3B1PvbEBAAN324kDHxW3QefVXgObVuo3Kz2Kol+emhdnStX54Xp2/g7s8SOHhcTyNVwaXJqwApERbMqNua8a8+DVmQmELPkfN1qh1VYGnyKmBEhEFtovh6SBtCggK4afQiRs3V00hV8GjyKqBiqpTgx4fa0r1RBV75ZQN3fbqE1GOnnQ5LKbfR5FWAFQ8N5r0BTXnh2kb8sSWVXiMX6LTTqsDQ5FXAiQi3tarOt0PbEBYSyICxi3h/zmY9jVR+T5NXIdGwkj2N7BVTkdd+3cigjxeToqeRyo9p8ipEihUJ4p2bm/DS9TEs3naAnu/MZ+GWVKfDUipPNHkVMiLCgJbV+O6BKykWGsSt4xYxclYi6XoaqfyMJq9Cqn7F4vz4YFv6NqnMm79t4voPftc+YcqvaPIqxMKLBPHmTbG
8c3MT9h05zQ2jFvLQ5OUkHTzhdGhKXZYmr0JOROjbpDKzn+jAsC51mLF2L13emMcbMzZy/LRXVrZTKk80eSkAioYE8Vi3usx+oiPdG1Xg3dmb6fT6XKYuTdJuFconafJS56lcMox3bm7K10PaULFkGE98tZJrP/idJdu1PUz5Fo8lLxEZLyL7RWTNBdsfEpENIrJWRF71VPkqf5pXL8W3Q9rwdv8m7D9ymn4fLuSBz5dpe5jyGZ6seX0CdM+8QUQ6AX2BWGNMQ+B1D5av8ikgQLi2qW0Pe7hLHWat30fnN+bx+q/aHqac57HkZYyJBy481xgCvGyMOe16zn5Pla/cp2hIEI92q8vsxzvSs1EF3puzmY6vz+WrhF3aHqYc4+02r7pAOxH5U0TmiUiL7J4oIveKSIKIJCQnJ3sxRJWdSiXDePvmpnwztA2VS4YxfOoq+ryvg72VM7ydvIKA0kArYDgwRUSyXIXIGDPGGBNnjImLjIz0ZozqMppVK8W3Q9vwzs1NSD12hptGL+SBScvYdUDbw5T3eDt5JQHfGGsxkAGU9XIMyg3+7h/2eEce7VqX2Rv20+XNebz6ywaOaXuY8gJvJ6/vgE4AIlIXCAFSvByDcqOwkEAe7lqH2U90oFdMRT6Yu4VOr89liraHKQ/zZFeJycBCIFpEkkRkMDAeqOnqPvEFMMjoaqkFQsUSYbzVvwnfDm1DlVJh/GPqKq55bwF/btVZK5RniD/kjri4OJOQkOB0GCqHjDH8sHI3r/y8gd2HT9EzpgJP9ahP1dJFnQ5N+RkRWWqMicvqMV17UbndufawqxpUYOz8rYyau4WZ6/YzuF0NhnasRURosNMhqgJAhwcpjwkLCWRYlzrMeaIjvWMrMsrVHjZ58U6dP0zlmyYv5XEVSoTy5k1N+P6BK4kqE85T36ym18j5/L5Zr9WovNPkpbwmtmpJvrq/Ne/d0pSjp9K4ddyf3P3pErYmH3M6NOWHNHkprxIRejeuxKzHO/CP7tEs2nqAq96K598/ruPwibNOh6f8iCYv5YjQ4ECGdqzNnCc60i+uCp/8sY0Or8/hk9+3cTY9w+nwlB/Q5KUcFRlRhJeub8y0Ye1oWKk4z/+4ju5vxzN7wz78oRuPco4mL+UT6lcszsTBVzD29jgyDNz1SQK3j1/Mxr1HnQ5N+ShNXspniAjdGpTn10fa83+9G7By1yF6vBPPP79dTaoukKsuoMlL+ZyQoAAGt63BvOGduL11FF8s2UXH1+Yyet4WTqelOx2e8hGavJTPKhUewvN9GvLrI+1pUaM0L/28gW5vxvPz6j3aHqY0eSnfV7tcMcbf0YLP7mpJaHAAQyYto/+YRaxOOux0aMpBmryU32hfN5Lpw9rxwrWN2LL/GH3eX8DjU1ay78gpp0NTDtDkpfxKUGAAt7WqzpzhHbm3XU1+XLmbjq/NZeSsRE6e0fawwkSTl/JLxUODeapnfX57rD0doyN587dNdH5jLt8t/0snQSwkNHkpv1a9TDijbmvOl/e2okyxEB75cgXXffA7v67dq0msgNPJCFWBkZFh+HpZEu/MSiTp4ElqRYZzX/ta9G1aiSJBgU6Hp/LgUpMRavJSBU5aegbTVu/hw3lbWb/nCOWLF2Fw2xoMaFlNJ0L0M5q8VKFkjCE+MYUP525h4dZUIkKDGNiqOndcGUW5iFCnw1M5oMlLFXordx1idPwWfl6zl+DAAG5oVoV729ekRtlwp0NTl6DJSymXbSnHGRO/la+XJXE2PYMejSpwf4daNK5S0unQVBY0eSl1gf1HT/HJ79uZsGgHR0+l0bpmGe7vWIv2dcqSzSLuygGavJTKxtFTZ5m8eCcfLdjGviOnaVCxOPd1qEmvmIoEBWpPIqdp8lLqMk6npfP98t2Mjt/CluTjVC0dxj3tatKveVXCQrSbhVM0eSmVQxkZhpnr9/HhvC0s23mI0uEhDGodxe2tq1MqPMTp8AodTV5K5ZIxhiXbD/LhvC3
M3rCfsOBAbm5Zlbvb1aRyyTCnwys0dMVspXJJRGhZozQta5Rm496jjI7fwoSFO5iwcAd9YitxX4daRFeIcDrMQk1rXkrl0F+HTvLR/G18sWQnJ86kc0WN0nSIjqR9nUgaVCxOQIBepXQ3PW1Uyo0OnTjDhIU7mL5mL+v3HAGgTHgIbeuUpX2dSNrVKUu54tqD3x00eSnlIfuPnGLB5hTiNyUzPzGF1ONnAKhXIYL2dW2tLC6qFKHBesUyLzR5KeUFGRmGdXuOMD/RJrOEHQc4m24IDQ7gihplaFenLB3qRlK7XDHtCJtDmryUcsCJM2ks2ppK/KYU4hOT2Zp8HICKJUJpV6cs7epE0rZ2We2CcQmOJC8RGQ/0BvYbYxq5tj0P3AMku572tDFm+uX2pclLFQRJB0+wINEmsgWJKRw5lYYINK5cgnZ1ImlfN5Km1UoSrD37/+ZU8moPHAM+uyB5HTPGvJ6bfWWVvM6ePUtSUhKnTuniC/kVGhpKlSpVCA7Wua68JT3DsDLpEPNdtbIVuw6RnmEoViSI1rXK0L5OWdrXjaR6mcI964Uj/byMMfEiEuWp/SclJREREUFUVJS2H+SDMYbU1FSSkpKoUaOG0+EUGoEBQrNqpWhWrRQPd63D4ZNnWbglhXhXe9lv6/YBUK10UQa3rcHAVtW1K8YFnOik+qCI3A4kAI8bYw5m9SQRuRe4F6BatWoXPX7q1ClNXG4gIpQpU4bk5OTLP1l5TImwYLo3qkj3RhUxxrA99QTxm5KZtnoPz/2wll/X7uW1frHauz8Tb59cjwJqAU2APcAb2T3RGDPGGBNnjImLjIzM8jmauNxD30ffIiLUKBvOoDZRfHlvK166PoaVuw7R/a14vkrYpauFu3g1eRlj9hlj0o0xGcBYoKU3y1fK34gIA1pW45dH2lO/UnGGT13FPZ8tZf9Rbev1avISkYqZ7l4HrPFm+Ur5q6qli/LFPa14pld94hOTufqteKav3uN0WI7yWPISkcnAQiBaRJJEZDDwqoisFpFVQCfgUU+V72mHDh3igw8+yPXrevbsyaFDh3L9ujvuuIOpU6fm+nWq4AgIEO5uV5Ppw9pStXRRhk5axsNfLOfwibNOh+YIT15tHJDF5o88Uda/flzLut1H3LrPBpWK89w1DbN9/FzyGjp06Hnb09LSCArK/m2dPv2y3dqUuqTa5SL4ekgbRs3dwshZiSzamsorNzSmY3Q5p0PzKu0Nl0cjRoxgy5YtNGnShBYtWtCuXTv69OlDgwYNALj22mtp3rw5DRs2ZMyYMX+/LioqipSUFLZv3079+vW55557aNiwIVdddRUnT57MUdmzZs2iadOmxMTEcNddd3H69Om/Y2rQoAGNGzfmiSeeAOCrr76iUaNGxMbG0r59eze/C8opwYEBDOtSh+8euJLiocHc8fESnv52NcdPpzkdmvcYY3z+r3nz5uZC69atu2ibN23bts00bNjQGGPMnDlzTNGiRc3WrVv/fjw1NdUYY8yJEydMw4YNTUpKijHGmOrVq5vk5GSzbds2ExgYaJYvX26MMaZfv35mwoQJ2ZY3aNAg89VXX5mTJ0+aKlWqmI0bNxpjjBk4cKB56623TEpKiqlbt67JyMgwxhhz8OBBY4wxjRo1MklJSedty4rT76fKu5Nn0syL09aZqBE/mbavzDJ/bk11OiS3ARJMNnlBa15u0rJly/M6eY4cOZLY2FhatWrFrl27SExMvOg1NWrUoEmTJgA0b96c7du3X7acjRs3UqNGDerWrQvAoEGDiI+Pp0SJEoSGhjJ48GC++eYbihYtCsCVV17JHXfcwdixY0lPT8//gSqfExocyFM96zPlvtYIQv8xC/nvtHWcOluwP29NXm4SHv6/YRxz585l5syZLFy4kJUrV9K0adMshzEVKVLk79uBgYGkpeW9yh8UFMTixYu58cYb+emnn+jevTsAH374IS+88AK7du2iefPmpKam5rkM5dtaRJX
m54fbcUvLaoydv43e7y5gVdIhp8PyGE1eeRQREcHRo0ezfOzw4cOUKlWKokWLsmHDBhYtWuS2cqOjo9m+fTubN28GYMKECXTo0IFjx45x+PBhevbsyVtvvcXKlSsB2LJlC1dccQX//ve/iYyMZNeuXW6LRfme8CJB/Pe6GD69qyXHTqVx3Qd/8NZvmzibnuF0aG6nc9jnUZkyZbjyyitp1KgRYWFhlC9f/u/Hunfvzocffkj9+vWJjo6mVatWbis3NDSUjz/+mH79+pGWlkaLFi24//77OXDgAH379uXUqVMYY3jzzTcBGD58OImJiRhj6NKlC7GxsW6LRfmuDnUj+fWR9jz/41remZXI7A37efOmWOqULzjz7vvtfF7r16+nfv36DkVU8Oj7WXD9smYPT3+7hmOn0xh+VTR3ta1BoJ8M8r7UrBJ62qhUAde9UUV+faQ9HepG8t/p6xkwZhE7U084HVa+afLyMQ888ABNmjQ57+/jjz92Oizl5yIjijBmYHNe7xfL+j1H6P5OPJP+3OHXg7y1zcvHvP/++06HoAooEeHG5lVoXasM/5i6kn9+u4YZa/fxyg2NqVDC/1Y70pqXUoVM5ZJhTLjrCv7dtyF/bkvl6rfj+X7FX35XC9PkpVQhFBAg3N46ip8fbk+tyHAe/mIFXd+cx7j5WzngWr7N12nyUqoQq1E2nK/ub8Pr/WIpERbMC9PW0+rFWQybvJyFW1J9ujambV5KFXKBAbYt7MbmVdiw9whfLN7FN8uS+GHlbmqUDWdAy6rc0KwKZYoVufzOvEhrXl5SrFixbB/bvn07jRo18mI0SmWtXoXiPN+nIYv/2ZU3b4qlbLEQXpy+gVYvzeKBz5fx++YUMjJ8ozZWMGpeP4+Avavdu88KMdDjZffuUyk/ERocyPXNqnB9syok7jvK5MW7+GZ5EtNW7aF6maLc3KIaNzavQmSEc7UxrXnl0YgRI87r1vD888/zwgsv0KVLF5o1a0ZMTAzff/99rvd76tQp7rzzTmJiYmjatClz5swBYO3atbRs2ZImTZrQuHFjEhMTOX78OL169SI2NpZGjRrx5Zdfuu34lDqnTvkInr2mAYue6sI7NzehQvFQXvllA61fmsWQiUuJ35TsTG0su7lyfOnPF+fzWrZsmWnfvv3f9+vXr2927txpDh8+bIwxJjk52dSqVevv+bXCw8Oz3VfmucFef/11c+eddxpjjFm/fr2pWrWqOXnypHnwwQfNxIkTjTHGnD592pw4ccJMnTrV3H333X/v59ChQ3k+HqffT+VfEvcdNS/8tNY0+devpvqTdh6x92Ynmn2HT7q1HHQ+L/dr2rQp+/fvZ/fu3axcuZJSpUpRoUIFnn76aRo3bkzXrl3566+/2LdvX672u2DBAm677TYA6tWrR/Xq1dm0aROtW7fmxRdf5JVXXmHHjh2EhYURExPDb7/9xpNPPsn8+fMpUaKEJw5VqYvULleMf/ZqwKKnuzByQFOqlirKa79upPXLs7n3swTmbNxPuodrYwWjzcsh/fr1Y+rUqezdu5f+/fszadIkkpOTWbp0KcHBwURFRWU5j1de3HLLLVxxxRVMmzaNnj17Mnr0aDp37syyZcuYPn06zzzzDF26dOHZZ591S3lK5USRoED6xFaiT2wltqUc54slO5makMSMdfuoXDKM/i2qclNcVY/04NfklQ/9+/fnnnvuISUlhXnz5jFlyhTKlStHcHAwc+bMYceOHbneZ7t27Zg0aRKdO3dm06ZN7Ny5k+joaLZu3UrNmjUZNmwYO3fuZNWqVdSrV4/SpUtz2223UbJkScaNG+eBo1QqZ2qUDeepHvV5vFs0v63bxxdLdvLmb5t4e+YmOtcrx4CW1egYXc5tM1po8sqHhg0bcvToUSpXrkzFihW59dZbueaaa4iJiSEuLo569erlep9Dhw5lyJAhxMTEEBQUxCeffEKRIkWYMmUKEyZMIDg4+O/T0yVLljB8+HACAgI
IDg5m1KhRHjhKpXInJCiAXo0r0qtxRXakHufLJbuYkpDEzPUJVCwRyncPXEn54vmviel8XgrQ91N51tn0DGat30d8Ygr/vbYRIjmrfV1qPi+teSmlPC44MIDujSrSvVFFt+1Tk5cXrV69moEDB563rUiRIvz5558ORaSU//Lr5GWMyXH10xfExMSwYsUKp8O4iD80HSh1Ib/t5xUaGkpqqm+PevcHxhhSU1MJDfW/yehU4ea3Na8qVaqQlJREcnKy06H4vdDQUKpUqeJ0GErlit8mr+Dg4PNWqFZKFS5+e9qolCrcNHkppfySJi+llF/yix72IpIM5GagYFkgxUPheFtBOhYoWMdTkI4FfPN4qhtjIrN6wC+SV26JSEJ2Qwr8TUE6FihYx1OQjgX873j0tFEp5Zc0eSml/FJBTV5jnA7AjQrSsUDBOp6CdCzgZ8dTINu8lFIFX0GteSmlCjhNXkopv1SgkpeIdBeRjSKyWURGOB1PfohIVRGZIyLrRGStiDzsdEz5JSKBIrJcRH5yOpb8EpGSIjJVRDaIyHoRae10THklIo+6vmNrRGSyiPjFFCMFJnmJSCDwPtADaAAMEJEGzkaVL2nA48aYBkAr4AE/Px6Ah4H1TgfhJu8Avxhj6gGx+OlxiUhlYBgQZ4xpBAQCNzsbVc4UmOQFtAQ2G2O2GmPOAF8AfR2OKc+MMXuMMctct49i/3NUdjaqvBORKkAvwO+XOBKREkB74CMAY8wZY8whR4PKnyAgTESCgKLAbofjyZGClLwqA7sy3U/Cj/+zZyYiUUBTwJ/ni34b+AeQ4XAc7lADSAY+dp0GjxORcKeDygtjzF/A68BOYA9w2Bgzw9mocqYgJa8CSUSKAV8DjxhjjjgdT16ISG9gvzFmqdOxuEkQ0AwYZYxpChwH/LKNVURKYc9QagCVgHARuc3ZqHKmICWvv4Cqme5XcW3zWyISjE1ck4wx3zgdTz5cCfQRke3Y0/nOIjLR2ZDyJQlIMsacqwlPxSYzf9QV2GaMSTbGnAW+Ado4HFOOFKTktQSoIyI1RCQE2+j4g8Mx5ZnYlUU+AtYbY950Op78MMY8ZYypYoyJwn4us40xfvHrnhVjzF5gl4hEuzZ1AdY5GFJ+7ARaiUhR13euC35y8cFvp4G+kDEmTUQeBH7FXjEZb4xZ63BY+XElMBBYLSIrXNueNsZMdy4klclDwCTXD+VW4E6H48kTY8yfIjIVWIa9wr0cPxkmpMODlFJ+qSCdNiqlChFNXkopv6TJSynllzR5KaX8kiYvpZRf0uSlPEpE0kVkRaY/t/VEF5EoEVnjrv0p/1Jg+nkpn3XSGNPE6SBUwaM1L+UIEdkuIq+KyGoRWSwitV3bo0RktoisEpFZIlLNtb28iHwrIitdf+eGsASKyFjXfFQzRCTM9fxhrrnQVonIFw4dpvIgTV7K08IuOG3sn+mxw8aYGOA97KwTAO8CnxpjGgOTgJGu7SOBecaYWOw4wnOjJ+oA7xtjGgKHgBtc20cATV37ud8zh6acpD3slUeJyDFjTLEstm8HOhtjtroGoO81xpQRkRSgojHmrGv7HmNMWdeq6VWMMacz7SMK+M0YU8d1/0kg2Bjzgoj8AhwDvgO+M8Yc8/ChKi/Tmpdyksnmdm6cznQ7nf+14/bCzqzbDFjimmhPFSCavJST+mf6d6Hr9h/8bxriW4H5rtuzgCHw91z4JbLbqYgEAFWNMXOAJ4ESwEW1P+Xf9NdIeVpYplkxwM77fq67RCkRWYWtPQ1wbXsIO0PpcOxspedma3gYGCMig7E1rCHYmT+zEghMdCU4AUb6+TTNKgva5qUc4WrzijPGpDgdi/JPetqolPJLWvNSSvklrXkppfySJi+llF/S5KWU8kuavJRSfkmTl1LKL/0/rC1qDE44mlAAAAAASUVORK5CYII=\n", 1697 | "text/plain": [ 1698 | "
" 1699 | ] 1700 | }, 1701 | "metadata": { 1702 | "needs_background": "light" 1703 | }, 1704 | "output_type": "display_data" 1705 | } 1706 | ], 1707 | "source": [ 1708 | "training_vis(train_losses, valid_losses)" 1709 | ] 1710 | }, 1711 | { 1712 | "cell_type": "markdown", 1713 | "metadata": {}, 1714 | "source": [ 1715 | "#### Model Evaluation\n", 1716 | "\n", 1717 | "Evaluate the model on the test set." 1718 | ] 1719 | }, 1720 | { 1721 | "cell_type": "code", 1722 | "execution_count": 30, 1723 | "metadata": { 1724 | "execution": { 1725 | "iopub.execute_input": "2021-11-07T11:41:57.180629Z", 1726 | "iopub.status.busy": "2021-11-07T11:41:57.180063Z", 1727 | "iopub.status.idle": "2021-11-07T11:41:57.544721Z", 1728 | "shell.execute_reply": "2021-11-07T11:41:57.544233Z", 1729 | "shell.execute_reply.started": "2021-11-07T11:31:56.594261Z" 1730 | }, 1731 | "papermill": { 1732 | "duration": 0.680346, 1733 | "end_time": "2021-11-07T11:41:57.544847", 1734 | "exception": false, 1735 | "start_time": "2021-11-07T11:41:56.864501", 1736 | "status": "completed" 1737 | }, 1738 | "tags": [] 1739 | }, 1740 | "outputs": [ 1741 | { 1742 | "data": { 1743 | "text/plain": [ 1744 | "" 1745 | ] 1746 | }, 1747 | "execution_count": 30, 1748 | "metadata": {}, 1749 | "output_type": "execute_result" 1750 | } 1751 | ], 1752 | "source": [ 1753 | "# Load the best model weights\n", 1754 | "checkpoint = torch.load('../input/ai-earth-model-weights/task04_model_weights.pth')\n", 1755 | "model = Model()\n", 1756 | "model.load_state_dict(checkpoint['state_dict'])" 1757 | ] 1758 | }, 1759 | { 1760 | "cell_type": "code", 1761 | "execution_count": 31, 1762 | "metadata": { 1763 | "execution": { 1764 | "iopub.execute_input": "2021-11-07T11:41:58.181883Z", 1765 | "iopub.status.busy": "2021-11-07T11:41:58.181002Z", 1766 | "iopub.status.idle": "2021-11-07T11:41:58.182836Z", 1767 | "shell.execute_reply": "2021-11-07T11:41:58.183296Z", 1768 | "shell.execute_reply.started": "2021-11-07T11:31:57.001527Z" 1769 | }, 1770 | "papermill": { 1771 | "duration": 0.323811, 1772 | 
"end_time": "2021-11-07T11:41:58.183429", 1773 | "exception": false, 1774 | "start_time": "2021-11-07T11:41:57.859618", 1775 | "status": "completed" 1776 | }, 1777 | "tags": [] 1778 | }, 1779 | "outputs": [], 1780 | "source": [ 1781 | "# Path to the test set\n", 1782 | "test_path = '../input/ai-earth-tests/'\n", 1783 | "# Path to the test set labels\n", 1784 | "test_label_path = '../input/ai-earth-tests-labels/'" 1785 | ] 1786 | }, 1787 | { 1788 | "cell_type": "code", 1789 | "execution_count": 32, 1790 | "metadata": { 1791 | "execution": { 1792 | "iopub.execute_input": "2021-11-07T11:41:58.817618Z", 1793 | "iopub.status.busy": "2021-11-07T11:41:58.816950Z", 1794 | "iopub.status.idle": "2021-11-07T11:42:00.269201Z", 1795 | "shell.execute_reply": "2021-11-07T11:42:00.268643Z", 1796 | "shell.execute_reply.started": "2021-11-07T11:31:57.010291Z" 1797 | }, 1798 | "papermill": { 1799 | "duration": 1.770946, 1800 | "end_time": "2021-11-07T11:42:00.269350", 1801 | "exception": false, 1802 | "start_time": "2021-11-07T11:41:58.498404", 1803 | "status": "completed" 1804 | }, 1805 | "tags": [] 1806 | }, 1807 | "outputs": [], 1808 | "source": [ 1809 | "import os\n", 1810 | "\n", 1811 | "# Load the test samples and their corresponding labels (matched by file name)\n", 1812 | "files = os.listdir(test_path)\n", 1813 | "X_test = []\n", 1814 | "y_test = []\n", 1815 | "for file in files:\n", 1816 | " X_test.append(np.load(test_path + file))\n", 1817 | " y_test.append(np.load(test_label_path + file))" 1818 | ] 1819 | }, 1820 | { 1821 | "cell_type": "code", 1822 | "execution_count": 33, 1823 | "metadata": { 1824 | "execution": { 1825 | "iopub.execute_input": "2021-11-07T11:42:00.911046Z", 1826 | "iopub.status.busy": "2021-11-07T11:42:00.909973Z", 1827 | "iopub.status.idle": "2021-11-07T11:42:00.973831Z", 1828 | "shell.execute_reply": "2021-11-07T11:42:00.973361Z", 1829 | "shell.execute_reply.started": "2021-11-07T11:31:58.325122Z" 1830 | }, 1831 | "papermill": { 1832 | "duration": 0.395072, 1833 | "end_time": "2021-11-07T11:42:00.973969", 1834 | "exception": false, 1835 | 
"start_time": "2021-11-07T11:42:00.578897", 1836 | "status": "completed" 1837 | }, 1838 | "tags": [] 1839 | }, 1840 | "outputs": [ 1841 | { 1842 | "data": { 1843 | "text/plain": [ 1844 | "((103, 12, 24, 72, 4), (103, 24))" 1845 | ] 1846 | }, 1847 | "execution_count": 33, 1848 | "metadata": {}, 1849 | "output_type": "execute_result" 1850 | } 1851 | ], 1852 | "source": [ 1853 | "X_test = np.array(X_test)\n", 1854 | "y_test = np.array(y_test)\n", 1855 | "X_test.shape, y_test.shape" 1856 | ] 1857 | }, 1858 | { 1859 | "cell_type": "code", 1860 | "execution_count": 34, 1861 | "metadata": { 1862 | "execution": { 1863 | "iopub.execute_input": "2021-11-07T11:42:01.612810Z", 1864 | "iopub.status.busy": "2021-11-07T11:42:01.611682Z", 1865 | "iopub.status.idle": "2021-11-07T11:42:01.638035Z", 1866 | "shell.execute_reply": "2021-11-07T11:42:01.637526Z", 1867 | "shell.execute_reply.started": "2021-11-07T11:31:58.416006Z" 1868 | }, 1869 | "papermill": { 1870 | "duration": 0.3441, 1871 | "end_time": "2021-11-07T11:42:01.638176", 1872 | "exception": false, 1873 | "start_time": "2021-11-07T11:42:01.294076", 1874 | "status": "completed" 1875 | }, 1876 | "tags": [] 1877 | }, 1878 | "outputs": [], 1879 | "source": [ 1880 | "testset = AIEarthDataset(X_test, y_test)\n", 1881 | "testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)" 1882 | ] 1883 | }, 1884 | { 1885 | "cell_type": "code", 1886 | "execution_count": 35, 1887 | "metadata": { 1888 | "execution": { 1889 | "iopub.execute_input": "2021-11-07T11:42:02.286777Z", 1890 | "iopub.status.busy": "2021-11-07T11:42:02.285865Z", 1891 | "iopub.status.idle": "2021-11-07T11:42:02.621942Z", 1892 | "shell.execute_reply": "2021-11-07T11:42:02.622610Z", 1893 | "shell.execute_reply.started": "2021-11-07T11:31:58.447987Z" 1894 | }, 1895 | "papermill": { 1896 | "duration": 0.666798, 1897 | "end_time": "2021-11-07T11:42:02.622817", 1898 | "exception": false, 1899 | "start_time": "2021-11-07T11:42:01.956019", 1900 | "status": 
"completed" 1901 | }, 1902 | "tags": [] 1903 | }, 1904 | "outputs": [ 1905 | { 1906 | "name": "stderr", 1907 | "output_type": "stream", 1908 | "text": [ 1909 | "4it [00:00, 12.75it/s]" 1910 | ] 1911 | }, 1912 | { 1913 | "name": "stdout", 1914 | "output_type": "stream", 1915 | "text": [ 1916 | "Score: 20.274\n" 1917 | ] 1918 | }, 1919 | { 1920 | "name": "stderr", 1921 | "output_type": "stream", 1922 | "text": [ 1923 | "\n" 1924 | ] 1925 | } 1926 | ], 1927 | "source": [ 1928 | "# Evaluate the model on the test set\n", 1929 | "model.eval()\n", 1930 | "model.to(device)\n", 1931 | "preds = np.zeros((len(y_test),24))\n", 1932 | "for i, data in tqdm(enumerate(testloader)):\n", 1933 | " data, labels = data\n", 1934 | " data = data.to(device)\n", 1935 | " labels = labels.to(device)\n", 1936 | " pred = model(data)\n", 1937 | " preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n", 1938 | "s = score(y_test, preds)\n", 1939 | "print('Score: {:.3f}'.format(s))" 1940 | ] 1941 | }, 1942 | { 1943 | "cell_type": "markdown", 1944 | "metadata": { 1945 | "papermill": { 1946 | "duration": 0.311337, 1947 | "end_time": "2021-11-07T11:42:03.280397", 1948 | "exception": false, 1949 | "start_time": "2021-11-07T11:42:02.969060", 1950 | "status": "completed" 1951 | }, 1952 | "tags": [] 1953 | }, 1954 | "source": [ 1955 | "## Summary\n", 1956 | "\n", 1957 | "- This solution is designed around the small amount of data and the small number of features. It applies convolutions separately over time and over space, extracting temporal and spatial information alternately. It uses GAP (global average pooling) layers to reduce the dimensionality of the extracted information, so each layer stays as small as possible in parameters while more layers can be stacked to extract richer features.\n", 1958 | "- Because sequences at different time scales carry different information, this solution uses pooling layers to change the time scale and RNNs to extract information. It then combines the sequence information from three different time scales to produce the final prediction sequence.\n", 1959 | "- This solution likewise uses a custom-designed model. Its construction takes full account of the dataset and the problem background, and it flexibly combines different network layers to handle specific sub-problems. This way of building models requires a fairly deep understanding of what each type of network layer does, and the way the layers are used here is well worth studying." 1960 | ] 1961 | }, 1962 | { 1963 | "cell_type": "markdown", 1964 | "metadata": {}, 1965 | "source": [ 1966 | "## References\n", 1967 | "\n", 1968 | "1. 
Top-1 solution write-up: https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.6.561d482cp7CFlx&postId=210391" 1969 | ] 1970 | } 1971 | ], 1972 | "metadata": { 1973 | "kernelspec": { 1974 | "display_name": "Python 3", 1975 | "language": "python", 1976 | "name": "python3" 1977 | }, 1978 | "language_info": { 1979 | "codemirror_mode": { 1980 | "name": "ipython", 1981 | "version": 3 1982 | }, 1983 | "file_extension": ".py", 1984 | "mimetype": "text/x-python", 1985 | "name": "python", 1986 | "nbconvert_exporter": "python", 1987 | "pygments_lexer": "ipython3", 1988 | "version": "3.7.3" 1989 | }, 1990 | "papermill": { 1991 | "default_parameters": {}, 1992 | "duration": 571.585002, 1993 | "end_time": "2021-11-07T11:42:05.308202", 1994 | "environment_variables": {}, 1995 | "exception": null, 1996 | "input_path": "__notebook__.ipynb", 1997 | "output_path": "__notebook__.ipynb", 1998 | "parameters": {}, 1999 | "start_time": "2021-11-07T11:32:33.723200", 2000 | "version": "2.3.3" 2001 | } 2002 | }, 2003 | "nbformat": 4, 2004 | "nbformat_minor": 5 2005 | } 2006 | -------------------------------------------------------------------------------- /Task4/fig/Task4-CNN单元.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-CNN单元.png -------------------------------------------------------------------------------- /Task4/fig/Task4-TCNN+RNN模型.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-TCNN+RNN模型.png -------------------------------------------------------------------------------- /Task4/fig/Task4-TCNN层.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-TCNN层.png -------------------------------------------------------------------------------- /Task4/fig/Task4-TCN单元.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-TCN单元.png -------------------------------------------------------------------------------- /Task4/fig/Task4-扩张卷积.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-扩张卷积.png -------------------------------------------------------------------------------- /Task4/fig/Task4-残差连接.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-残差连接.png -------------------------------------------------------------------------------- /Task5/Task5 模型建立之SA-ConvLSTM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Datawhale Weather and Ocean Forecasting - Task5 Model Building: SA-ConvLSTM\n", 8 | "\n", 9 | "In this task we study the modeling approach of the top-ranked team “吴先生的队伍” (Mr. Wu's Team), whose model is SA-ConvLSTM.\n", 10 | "\n", 11 | "The previous two top solutions treated the problem as a multi-output task, building a neural network that outputs all 24 Nino3.4 predictions directly. The problem with this approach is that sequence problems are usually temporally dependent. A multi-output model effectively treats the 24 predictions as completely independent, whereas in reality they depend on one another sequentially: each prediction is usually influenced by the prediction at the previous time step. This top solution therefore uses a Seq2Seq structure to account for the sequential dependency among the output predictions.\n", 12 | "\n", 13 | "The Seq2Seq structure has two parts, an Encoder and a Decoder. The Encoder encodes the input sequence into a vector, and the Decoder decodes that vector into an output prediction sequence. The key to applying Seq2Seq to a particular sequence problem is the Cell used at each time step. As mentioned earlier, spatial information is usually extracted with CNNs and temporal information with RNNs or LSTMs. Combining the two gives ConvLSTM, a classic model for spatiotemporal sequences. The SA-ConvLSTM model studied here improves on ConvLSTM by adding a self-attention mechanism, which strengthens the model's ability to capture long-range spatial dependencies.\n", 14 | "\n", 15 | "Another difference from the previous two top solutions is that this solution does not predict the Nino3.4 index directly. Instead, it predicts sst and derives the Nino3.4 index series from the predicted sst." 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## Learning Objectives\n", 23 | "1. Learn how the top solution builds its model" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Contents\n", 31 | "1. Data processing\n", 32 | " - Data flattening\n", 33 | " - Filling missing values\n", 34 | " - Dataset construction\n", 35 | "2. Model building\n", 36 | " - Evaluation function\n", 37 | " - Model construction\n", 38 | " - Model training\n", 39 | " - Model evaluation\n", 40 | "3. Summary" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## Code Examples" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "### Data Processing\n", 55 | "Data processing in this top solution consists of three parts:\n", 56 | "1. Data flattening.\n", 57 | "2. Filling missing values.\n", 58 | "3. 
Dataset construction." 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 1, 64 | "metadata": { 65 | "execution": { 66 | "iopub.execute_input": "2021-11-29T03:04:34.698663Z", 67 | "iopub.status.busy": "2021-11-29T03:04:34.697133Z", 68 | "iopub.status.idle": "2021-11-29T03:04:37.035400Z", 69 | "shell.execute_reply": "2021-11-29T03:04:37.034767Z", 70 | "shell.execute_reply.started": "2021-11-29T01:02:51.883602Z" 71 | }, 72 | "papermill": { 73 | "duration": 2.370278, 74 | "end_time": "2021-11-29T03:04:37.035673", 75 | "exception": false, 76 | "start_time": "2021-11-29T03:04:34.665395", 77 | "status": "completed" 78 | }, 79 | "tags": [] 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "import netCDF4 as nc\n", 84 | "import random\n", 85 | "import os\n", 86 | "from tqdm import tqdm\n", 87 | "import pandas as pd\n", 88 | "import numpy as np\n", 89 | "import math\n", 90 | "import matplotlib.pyplot as plt\n", 91 | "%matplotlib inline\n", 92 | "\n", 93 | "import torch\n", 94 | "from torch import nn, optim\n", 95 | "import torch.nn.functional as F\n", 96 | "from torch.utils.data import Dataset, DataLoader\n", 97 | "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", 98 | "\n", 99 | "from sklearn.metrics import mean_squared_error" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 2, 105 | "metadata": { 106 | "execution": { 107 | "iopub.execute_input": "2021-11-29T03:04:37.102995Z", 108 | "iopub.status.busy": "2021-11-29T03:04:37.102144Z", 109 | "iopub.status.idle": "2021-11-29T03:04:37.107646Z", 110 | "shell.execute_reply": "2021-11-29T03:04:37.107161Z", 111 | "shell.execute_reply.started": "2021-11-29T01:02:54.06493Z" 112 | }, 113 | "papermill": { 114 | "duration": 0.040737, 115 | "end_time": "2021-11-29T03:04:37.107761", 116 | "exception": false, 117 | "start_time": "2021-11-29T03:04:37.067024", 118 | "status": "completed" 119 | }, 120 | "tags": [] 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "# Fix the random seeds for reproducibility\n", 125 | "SEED = 
22\n", 126 | "\n", 127 | "def seed_everything(seed=42):\n", 128 | " random.seed(seed)\n", 129 | " os.environ['PYTHONHASHSEED'] = str(seed)\n", 130 | " np.random.seed(seed)\n", 131 | " torch.manual_seed(seed)\n", 132 | " torch.cuda.manual_seed(seed)\n", 133 | " torch.backends.cudnn.deterministic = True\n", 134 | " \n", 135 | "seed_everything(SEED)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 3, 141 | "metadata": { 142 | "execution": { 143 | "iopub.execute_input": "2021-11-29T03:04:37.222525Z", 144 | "iopub.status.busy": "2021-11-29T03:04:37.221100Z", 145 | "iopub.status.idle": "2021-11-29T03:04:37.225844Z", 146 | "shell.execute_reply": "2021-11-29T03:04:37.226442Z", 147 | "shell.execute_reply.started": "2021-11-29T01:02:54.074875Z" 148 | }, 149 | "papermill": { 150 | "duration": 0.090198, 151 | "end_time": "2021-11-29T03:04:37.226602", 152 | "exception": false, 153 | "start_time": "2021-11-29T03:04:37.136404", 154 | "status": "completed" 155 | }, 156 | "tags": [] 157 | }, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "output_type": "stream", 162 | "text": [ 163 | "CUDA is available! Training on GPU ...\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "# Check whether CUDA is available\n", 169 | "train_on_gpu = torch.cuda.is_available()\n", 170 | "\n", 171 | "if not train_on_gpu:\n", 172 | " print('CUDA is not available. Training on CPU ...')\n", 173 | "else:\n", 174 | " print('CUDA is available! 
Training on GPU ...')" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 4, 180 | "metadata": { 181 | "execution": { 182 | "iopub.execute_input": "2021-11-29T03:04:37.353082Z", 183 | "iopub.status.busy": "2021-11-29T03:04:37.352143Z", 184 | "iopub.status.idle": "2021-11-29T03:04:37.432852Z", 185 | "shell.execute_reply": "2021-11-29T03:04:37.434332Z", 186 | "shell.execute_reply.started": "2021-11-28T10:13:13.644947Z" 187 | }, 188 | "papermill": { 189 | "duration": 0.179146, 190 | "end_time": "2021-11-29T03:04:37.434792", 191 | "exception": false, 192 | "start_time": "2021-11-29T03:04:37.255646", 193 | "status": "completed" 194 | }, 195 | "tags": [] 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "# Read the data\n", 200 | "\n", 201 | "# Directory containing the data files\n", 202 | "path = '/kaggle/input/ninoprediction/'\n", 203 | "soda_train = nc.Dataset(path + 'SODA_train.nc')\n", 204 | "soda_label = nc.Dataset(path + 'SODA_label.nc')\n", 205 | "cmip_train = nc.Dataset(path + 'CMIP_train.nc')\n", 206 | "cmip_label = nc.Dataset(path + 'CMIP_label.nc')" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "#### Data Flattening\n", 214 | "The dataset is built with a sliding window, so the samples are first flattened into one continuous monthly series per climate model. This solution uses only the sst feature, and only the data with lon in the range [90, 330], probably to save computational resources." 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 5, 220 | "metadata": { 221 | "execution": { 222 | "iopub.execute_input": "2021-11-29T03:04:37.548239Z", 223 | "iopub.status.busy": "2021-11-29T03:04:37.546737Z", 224 | "iopub.status.idle": "2021-11-29T03:04:37.551951Z", 225 | "shell.execute_reply": "2021-11-29T03:04:37.553081Z", 226 | "shell.execute_reply.started": "2021-11-27T13:38:32.620904Z" 227 | }, 228 | "papermill": { 229 | "duration": 0.065069, 230 | "end_time": "2021-11-29T03:04:37.553274", 231 | "exception": false, 232 | "start_time": "2021-11-29T03:04:37.488205", 233 | "status": "completed" 234 | }, 235 | "tags": [] 236 | }, 237 | "outputs": [], 238 | "source": [ 239 | "def make_flatted(train_ds, 
label_ds, info, start_idx=0):\n", 240 | "    # Use only the sst feature\n", 241 | "    keys = ['sst']\n", 242 | "    label_key = 'nino'\n", 243 | "    # Number of years\n", 244 | "    years = info[1]\n", 245 | "    # Number of models\n", 246 | "    models = info[2]\n", 247 | "    \n", 248 | "    train_list = []\n", 249 | "    label_list = []\n", 250 | "    \n", 251 | "    # Concatenate the records of each model along the time axis\n", 252 | "    for model_i in range(models):\n", 253 | "        blocks = []\n", 254 | "        \n", 255 | "        # For each feature, take the first 12 months of every record and concatenate them; keep only lon in [95, 330] (lon indices 19 to 66)\n", 256 | "        for key in keys:\n", 257 | "            block = train_ds[key][start_idx + model_i * years: start_idx + (model_i + 1) * years, :12, :, 19: 67].reshape(-1, 24, 48, 1).data\n", 258 | "            blocks.append(block)\n", 259 | "        \n", 260 | "        # Concatenate all features along the last dimension\n", 261 | "        train_flatted = np.concatenate(blocks, axis=-1)\n", 262 | "        \n", 263 | "        # Concatenate the labels of months 12-23 of every record, then append months 24-35 of the last record (together with the last record's months 12-23, these form the prediction targets for the last record's first 12 months)\n", 264 | "        label_flatted = np.concatenate([\n", 265 | "            label_ds[label_key][start_idx + model_i * years: start_idx + (model_i + 1) * years, 12: 24].reshape(-1).data,\n", 266 | "            label_ds[label_key][start_idx + (model_i + 1) * years - 1, 24: 36].reshape(-1).data\n", 267 | "        ], axis=0)\n", 268 | "        \n", 269 | "        train_list.append(train_flatted)\n", 270 | "        label_list.append(label_flatted)\n", 271 | "    \n", 272 | "    return train_list, label_list" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 6, 278 | "metadata": { 279 | "execution": { 280 | "iopub.execute_input": "2021-11-29T03:04:37.661954Z", 281 | "iopub.status.busy": "2021-11-29T03:04:37.660977Z", 282 | "iopub.status.idle": "2021-11-29T03:05:11.515409Z", 283 | "shell.execute_reply": "2021-11-29T03:05:11.515853Z", 284 | "shell.execute_reply.started": "2021-11-27T13:38:33.844185Z" 285 | }, 286 | "papermill": { 287 | "duration": 33.912013, 288 | "end_time": "2021-11-29T03:05:11.516001", 289 | "exception": false, 290 | "start_time": "2021-11-29T03:04:37.603988", 291 | "status": "completed" 292 | }, 293 | "tags": [] 294 | }, 295 | "outputs": [ 296 | {
297 | "data": { 298 | "text/plain": [ 299 | "((1, 1200, 24, 48, 1), (15, 1812, 24, 48, 1), (17, 1680, 24, 48, 1))" 300 | ] 301 | }, 302 | "execution_count": 6, 303 | "metadata": {}, 304 | "output_type": "execute_result" 305 | } 306 | ], 307 | "source": [ 308 | "soda_info = ('soda', 100, 1)\n", 309 | "cmip6_info = ('cmip6', 151, 15)\n", 310 | "cmip5_info = ('cmip5', 140, 17)\n", 311 | "\n", 312 | "soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)\n", 313 | "cmip6_trains, cmip6_labels = make_flatted(cmip_train, cmip_label, cmip6_info)\n", 314 | "cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])\n", 315 | "\n", 316 | "# Flattened shape: (models, sequence length, lat, lon, features), where sequence length = years x 12\n", 317 | "np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": {}, 323 | "source": [ 324 | "#### Filling Missing Values\n", 325 | "Replace missing (NaN) values with 0." 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 8, 331 | "metadata": { 332 | "execution": { 333 | "iopub.execute_input": "2021-11-29T03:05:11.638562Z", 334 | "iopub.status.busy": "2021-11-29T03:05:11.637553Z", 335 | "iopub.status.idle": "2021-11-29T03:05:11.644302Z", 336 | "shell.execute_reply": "2021-11-29T03:05:11.644742Z", 337 | "shell.execute_reply.started": "2021-11-27T13:39:22.665855Z" 338 | }, 339 | "papermill": { 340 | "duration": 0.040786, 341 | "end_time": "2021-11-29T03:05:11.644893", 342 | "exception": false, 343 | "start_time": "2021-11-29T03:05:11.604107", 344 | "status": "completed" 345 | }, 346 | "tags": [] 347 | }, 348 | "outputs": [ 349 | { 350 | "name": "stdout", 351 | "output_type": "stream", 352 | "text": [ 353 | "Number of null in soda_trains after fillna: 0\n" 354 | ] 355 | } 356 | ], 357 | "source": [ 358 | "# Fill missing values in the SODA data\n", 359 | "soda_trains = np.array(soda_trains)\n", 360 | "soda_trains_nan = np.isnan(soda_trains)\n", 361 | "soda_trains[soda_trains_nan] = 
0\n", 362 | "print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 9, 368 | "metadata": { 369 | "execution": { 370 | "iopub.execute_input": "2021-11-29T03:05:11.709054Z", 371 | "iopub.status.busy": "2021-11-29T03:05:11.707767Z", 372 | "iopub.status.idle": "2021-11-29T03:05:11.862744Z", 373 | "shell.execute_reply": "2021-11-29T03:05:11.863294Z", 374 | "shell.execute_reply.started": "2021-11-27T13:39:24.110039Z" 375 | }, 376 | "papermill": { 377 | "duration": 0.18937, 378 | "end_time": "2021-11-29T03:05:11.863480", 379 | "exception": false, 380 | "start_time": "2021-11-29T03:05:11.674110", 381 | "status": "completed" 382 | }, 383 | "tags": [] 384 | }, 385 | "outputs": [ 386 | { 387 | "name": "stdout", 388 | "output_type": "stream", 389 | "text": [ 390 | "Number of null in cmip6_trains after fillna: 0\n" 391 | ] 392 | } 393 | ], 394 | "source": [ 395 | "# Fill missing values in the CMIP6 data\n", 396 | "cmip6_trains = np.array(cmip6_trains)\n", 397 | "cmip6_trains_nan = np.isnan(cmip6_trains)\n", 398 | "cmip6_trains[cmip6_trains_nan] = 0\n", 399 | "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 10, 405 | "metadata": { 406 | "execution": { 407 | "iopub.execute_input": "2021-11-29T03:05:11.927752Z", 408 | "iopub.status.busy": "2021-11-29T03:05:11.925353Z", 409 | "iopub.status.idle": "2021-11-29T03:05:12.091117Z", 410 | "shell.execute_reply": "2021-11-29T03:05:12.091855Z", 411 | "shell.execute_reply.started": "2021-11-27T13:39:24.520724Z" 412 | }, 413 | "papermill": { 414 | "duration": 0.197975, 415 | "end_time": "2021-11-29T03:05:12.092014", 416 | "exception": false, 417 | "start_time": "2021-11-29T03:05:11.894039", 418 | "status": "completed" 419 | }, 420 | "tags": [] 421 | }, 422 | "outputs": [ 423 | { 424 | "name": "stdout", 425 | "output_type": "stream", 426 | "text": 
[ 427 | "Number of null in cmip5_trains after fillna: 0\n" 428 | ] 429 | } 430 | ], 431 | "source": [ 432 | "# Fill missing values in the CMIP5 data\n", 433 | "cmip5_trains = np.array(cmip5_trains)\n", 434 | "cmip5_trains_nan = np.isnan(cmip5_trains)\n", 435 | "cmip5_trains[cmip5_trains_nan] = 0\n", 436 | "print('Number of null in cmip5_trains after fillna:', np.sum(np.isnan(cmip5_trains)))" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": {}, 442 | "source": [ 443 | "#### Building the Dataset\n", 444 | "Build the training and validation sets. Each input sample has a sequence length of 38 = 12 + 26: the input sst sequence is 12 months long and the output sst sequence is 26 months long (the model later averages its sst predictions over 3-month windows, turning 26 monthly values into 24 Nino3.4 values). Because training uses teacher forcing (explained in detail in the model construction section), the input samples built here also contain the ground-truth values of the output sst sequence." 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 11, 450 | "metadata": { 451 | "execution": { 452 | "iopub.execute_input": "2021-11-29T03:05:12.165242Z", 453 | "iopub.status.busy": "2021-11-29T03:05:12.164045Z", 454 | "iopub.status.idle": "2021-11-29T03:05:12.480257Z", 455 | "shell.execute_reply": "2021-11-29T03:05:12.479767Z", 456 | "shell.execute_reply.started": "2021-11-27T13:39:25.418945Z" 457 | }, 458 | "papermill": { 459 | "duration": 0.361254, 460 | "end_time": "2021-11-29T03:05:12.480405", 461 | "exception": false, 462 | "start_time": "2021-11-29T03:05:12.119151", 463 | "status": "completed" 464 | }, 465 | "tags": [] 466 | }, 467 | "outputs": [], 468 | "source": [ 469 | "# Build the training set\n", 470 | "\n", 471 | "X_train = []\n", 472 | "y_train = []\n", 473 | "# Draw 100 samples from each of the 17 CMIP5 models (start indices drawn with replacement)\n", 474 | "for model_i in range(17):\n", 475 | "    samples = np.random.choice(cmip5_trains.shape[1]-38, size=100)\n", 476 | "    for ind in samples:\n", 477 | "        X_train.append(cmip5_trains[model_i, ind: ind+38])\n", 478 | "        y_train.append(cmip5_labels[model_i][ind: ind+24])\n", 479 | "# Draw 100 samples from each of the 15 CMIP6 models (start indices drawn with replacement)\n", 480 | "for model_i in range(15):\n", 481 | "    samples = np.random.choice(cmip6_trains.shape[1]-38, size=100)\n", 482 | "    for ind in samples:\n", 483 | "        X_train.append(cmip6_trains[model_i, ind: ind+38])\n", 
484 | "        y_train.append(cmip6_labels[model_i][ind: ind+24])\n", 485 | "X_train = np.array(X_train)\n", 486 | "y_train = np.array(y_train)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 12, 492 | "metadata": { 493 | "execution": { 494 | "iopub.execute_input": "2021-11-29T03:05:12.541232Z", 495 | "iopub.status.busy": "2021-11-29T03:05:12.540676Z", 496 | "iopub.status.idle": "2021-11-29T03:05:12.548103Z", 497 | "shell.execute_reply": "2021-11-29T03:05:12.547520Z", 498 | "shell.execute_reply.started": "2021-11-27T13:39:26.341849Z" 499 | }, 500 | "papermill": { 501 | "duration": 0.040262, 502 | "end_time": "2021-11-29T03:05:12.548224", 503 | "exception": false, 504 | "start_time": "2021-11-29T03:05:12.507962", 505 | "status": "completed" 506 | }, 507 | "tags": [] 508 | }, 509 | "outputs": [], 510 | "source": [ 511 | "# Build the validation set (sampled from SODA)\n", 512 | "\n", 513 | "X_valid = []\n", 514 | "y_valid = []\n", 515 | "samples = np.random.choice(soda_trains.shape[1]-38, size=100)\n", 516 | "for ind in samples:\n", 517 | "    X_valid.append(soda_trains[0, ind: ind+38])\n", 518 | "    y_valid.append(soda_labels[0][ind: ind+24])\n", 519 | "X_valid = np.array(X_valid)\n", 520 | "y_valid = np.array(y_valid)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 13, 526 | "metadata": { 527 | "execution": { 528 | "iopub.execute_input": "2021-11-29T03:05:12.606407Z", 529 | "iopub.status.busy": "2021-11-29T03:05:12.605555Z", 530 | "iopub.status.idle": "2021-11-29T03:05:12.611580Z", 531 | "shell.execute_reply": "2021-11-29T03:05:12.611152Z", 532 | "shell.execute_reply.started": "2021-11-27T13:39:27.247585Z" 533 | }, 534 | "papermill": { 535 | "duration": 0.036214, 536 | "end_time": "2021-11-29T03:05:12.611721", 537 | "exception": false, 538 | "start_time": "2021-11-29T03:05:12.575507", 539 | "status": "completed" 540 | }, 541 | "tags": [] 542 | }, 543 | "outputs": [ 544 | { 545 | "data": { 546 | "text/plain": [ 547 | "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 
24, 48, 1), (100, 24))" 548 | ] 549 | }, 550 | "execution_count": 13, 551 | "metadata": {}, 552 | "output_type": "execute_result" 553 | } 554 | ], 555 | "source": [ 556 | "# Inspect the dataset shapes\n", 557 | "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 15, 563 | "metadata": { 564 | "execution": { 565 | "iopub.execute_input": "2021-11-29T03:05:12.737322Z", 566 | "iopub.status.busy": "2021-11-29T03:05:12.736558Z", 567 | "iopub.status.idle": "2021-11-29T03:05:13.516712Z", 568 | "shell.execute_reply": "2021-11-29T03:05:13.517217Z", 569 | "shell.execute_reply.started": "2021-11-27T13:39:38.421657Z" 570 | }, 571 | "papermill": { 572 | "duration": 0.812187, 573 | "end_time": "2021-11-29T03:05:13.517368", 574 | "exception": false, 575 | "start_time": "2021-11-29T03:05:12.705181", 576 | "status": "completed" 577 | }, 578 | "tags": [] 579 | }, 580 | "outputs": [], 581 | "source": [ 582 | "# Save the datasets\n", 583 | "np.save('X_train_sample.npy', X_train)\n", 584 | "np.save('y_train_sample.npy', y_train)\n", 585 | "np.save('X_valid_sample.npy', X_valid)\n", 586 | "np.save('y_valid_sample.npy', y_valid)" 587 | ] 588 | }, 589 | { 590 | "cell_type": "markdown", 591 | "metadata": {}, 592 | "source": [ 593 | "### Model Building" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": 16, 599 | "metadata": { 600 | "execution": { 601 | "iopub.execute_input": "2021-11-29T03:05:13.577516Z", 602 | "iopub.status.busy": "2021-11-29T03:05:13.576992Z", 603 | "iopub.status.idle": "2021-11-29T03:05:21.917657Z", 604 | "shell.execute_reply": "2021-11-29T03:05:21.918265Z", 605 | "shell.execute_reply.started": "2021-11-29T01:03:01.505192Z" 606 | }, 607 | "papermill": { 608 | "duration": 8.372964, 609 | "end_time": "2021-11-29T03:05:21.918443", 610 | "exception": false, 611 | "start_time": "2021-11-29T03:05:13.545479", 612 | "status": "completed" 613 | }, 614 | "tags": [] 615 | }, 616 | "outputs": [], 617 | "source": [ 
618 | "# Load the saved datasets\n", 619 | "X_train = np.load('../input/ai-earth-task05-samples/X_train_sample.npy')\n", 620 | "y_train = np.load('../input/ai-earth-task05-samples/y_train_sample.npy')\n", 621 | "X_valid = np.load('../input/ai-earth-task05-samples/X_valid_sample.npy')\n", 622 | "y_valid = np.load('../input/ai-earth-task05-samples/y_valid_sample.npy')" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 17, 628 | "metadata": { 629 | "execution": { 630 | "iopub.execute_input": "2021-11-29T03:05:21.983898Z", 631 | "iopub.status.busy": "2021-11-29T03:05:21.982953Z", 632 | "iopub.status.idle": "2021-11-29T03:05:21.986939Z", 633 | "shell.execute_reply": "2021-11-29T03:05:21.986453Z", 634 | "shell.execute_reply.started": "2021-11-29T01:03:11.548945Z" 635 | }, 636 | "papermill": { 637 | "duration": 0.039398, 638 | "end_time": "2021-11-29T03:05:21.987066", 639 | "exception": false, 640 | "start_time": "2021-11-29T03:05:21.947668", 641 | "status": "completed" 642 | }, 643 | "tags": [] 644 | }, 645 | "outputs": [ 646 | { 647 | "data": { 648 | "text/plain": [ 649 | "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))" 650 | ] 651 | }, 652 | "execution_count": 17, 653 | "metadata": {}, 654 | "output_type": "execute_result" 655 | } 656 | ], 657 | "source": [ 658 | "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": 18, 664 | "metadata": { 665 | "execution": { 666 | "iopub.execute_input": "2021-11-29T03:05:22.341929Z", 667 | "iopub.status.busy": "2021-11-29T03:05:22.340932Z", 668 | "iopub.status.idle": "2021-11-29T03:05:22.346140Z", 669 | "shell.execute_reply": "2021-11-29T03:05:22.346878Z", 670 | "shell.execute_reply.started": "2021-11-29T01:03:11.560457Z" 671 | }, 672 | "papermill": { 673 | "duration": 0.143838, 674 | "end_time": "2021-11-29T03:05:22.347113", 675 | "exception": false, 676 | "start_time": "2021-11-29T03:05:22.203275", 677 | "status": 
"completed" 678 | }, 679 | "tags": [] 680 | }, 681 | "outputs": [], 682 | "source": [ 683 | "# Build the data pipeline\n", 684 | "class AIEarthDataset(Dataset):\n", 685 | "    def __init__(self, data, label):\n", 686 | "        self.data = torch.tensor(data, dtype=torch.float32)\n", 687 | "        self.label = torch.tensor(label, dtype=torch.float32)\n", 688 | "\n", 689 | "    def __len__(self):\n", 690 | "        return len(self.label)\n", 691 | "    \n", 692 | "    def __getitem__(self, idx):\n", 693 | "        return self.data[idx], self.label[idx]" 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 19, 699 | "metadata": { 700 | "execution": { 701 | "iopub.execute_input": "2021-11-29T03:05:22.583350Z", 702 | "iopub.status.busy": "2021-11-29T03:05:22.582298Z", 703 | "iopub.status.idle": "2021-11-29T03:05:23.243100Z", 704 | "shell.execute_reply": "2021-11-29T03:05:23.243851Z", 705 | "shell.execute_reply.started": "2021-11-29T01:03:23.691846Z" 706 | }, 707 | "papermill": { 708 | "duration": 0.825537, 709 | "end_time": "2021-11-29T03:05:23.244098", 710 | "exception": false, 711 | "start_time": "2021-11-29T03:05:22.418561", 712 | "status": "completed" 713 | }, 714 | "tags": [] 715 | }, 716 | "outputs": [], 717 | "source": [ 718 | "batch_size = 2\n", 719 | "\n", 720 | "trainset = AIEarthDataset(X_train, y_train)\n", 721 | "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n", 722 | "\n", 723 | "validset = AIEarthDataset(X_valid, y_valid)\n", 724 | "validloader = DataLoader(validset, batch_size=batch_size, shuffle=True)" 725 | ] 726 | }, 727 | { 728 | "cell_type": "markdown", 729 | "metadata": {}, 730 | "source": [ 731 | "#### Evaluation Function" 732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": 24, 737 | "metadata": { 738 | "execution": { 739 | "iopub.execute_input": "2021-11-29T03:05:23.655820Z", 740 | "iopub.status.busy": "2021-11-29T03:05:23.655241Z", 741 | "iopub.status.idle": "2021-11-29T03:05:23.658416Z", 742 | "shell.execute_reply": 
"2021-11-29T03:05:23.658859Z", 743 | "shell.execute_reply.started": "2021-11-29T01:03:26.481561Z" 744 | }, 745 | "papermill": { 746 | "duration": 0.040887, 747 | "end_time": "2021-11-29T03:05:23.658990", 748 | "exception": false, 749 | "start_time": "2021-11-29T03:05:23.618103", 750 | "status": "completed" 751 | }, 752 | "tags": [] 753 | }, 754 | "outputs": [], 755 | "source": [ 756 | "def rmse(y_true, y_preds):\n", 757 | "    return np.sqrt(mean_squared_error(y_pred = y_preds, y_true = y_true))\n", 758 | "\n", 759 | "# Evaluation function (competition score)\n", 760 | "def score(y_true, y_preds):\n", 761 | "    # Correlation skill score\n", 762 | "    accskill_score = 0\n", 763 | "    # Sum of the per-month RMSEs\n", 764 | "    rmse_scores = 0\n", 765 | "    a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6  # weight a_i for each of the 24 lead months\n", 766 | "    y_true_mean = np.mean(y_true, axis=0)\n", 767 | "    y_pred_mean = np.mean(y_preds, axis=0)\n", 768 | "    for i in range(24):\n", 769 | "        fenzi = np.sum((y_true[:, i] - y_true_mean[i]) * (y_preds[:, i] - y_pred_mean[i]))  # numerator of the Pearson correlation\n", 770 | "        fenmu = np.sqrt(np.sum((y_true[:, i] - y_true_mean[i])**2) * np.sum((y_preds[:, i] - y_pred_mean[i])**2))  # denominator of the Pearson correlation\n", 771 | "        cor_i = fenzi / fenmu\n", 772 | "        accskill_score += a[i] * np.log(i+1) * cor_i\n", 773 | "        rmse_score = rmse(y_true[:, i], y_preds[:, i])\n", 774 | "        rmse_scores += rmse_score\n", 775 | "    return 2/3.0 * accskill_score - rmse_scores" 776 | ] 777 | }, 778 | { 779 | "cell_type": "markdown", 780 | "metadata": { 781 | "papermill": { 782 | "duration": 0.028556, 783 | "end_time": "2021-11-29T03:05:23.310560", 784 | "exception": false, 785 | "start_time": "2021-11-29T03:05:23.282004", 786 | "status": "completed" 787 | }, 788 | "tags": [] 789 | }, 790 | "source": [ 791 | "#### Model Construction" 792 | ] 793 | }, 794 | { 795 | "cell_type": "markdown", 796 | "metadata": {}, 797 | "source": [ 798 | "Unlike the multi-output neural networks built in the two previous top solutions, this top solution uses a Seq2Seq architecture. In this competition the input sequence has length 12 and the output sequence has length 26, and the solution stacks four hidden layers, so a basic Seq2Seq structure looks like the figure below:\n", 799 | "\n", 800 | "\n", 801 | "\n", 802 | 
"To apply a Seq2Seq structure to a given problem, the key decision is which cell to use. This top solution uses SA-ConvLSTM (Self-Attention ConvLSTM), proposed by researchers at Tsinghua University; the original paper is available at https://ojs.aaai.org//index.php/AAAI/article/view/6819\n", 803 | "\n", 804 | "SA-ConvLSTM improves on ConvLSTM, a classic spatiotemporal sequence model proposed by Dr. Xingjian Shi. To capture the temporal dependencies of spatial information, it adds a SAM module on top of ConvLSTM that memorizes aggregated spatial features. The original ConvLSTM paper is available at https://arxiv.org/pdf/1506.04214.pdf\n", 805 | "\n", 806 | "1. ConvLSTM model\n", 807 | "\n", 808 | "LSTM is a classic sequence model: its three gates let it capture long-term temporal dependencies well, and compared with a vanilla RNN it largely mitigates the vanishing-gradient problem. For a single input sample, each LSTM gate applies a fully connected transform to the input vector at every time step. In this competition the input X has shape (N, T, H, W, C), so at each time step a single sample feeds the LSTM a spatial field of shape (H, W, C). Fully connected layers are poor at extracting features from such spatial data, whereas convolutions need far fewer parameters and, when stacked, progressively extract more complex features. The natural idea is therefore to replace the fully connected operations inside the LSTM with convolutions, which makes it suitable for spatiotemporal sequence problems. This is exactly what ConvLSTM does, and it has proven very effective in practice.\n", 809 | "\n", 810 | "\n", 811 | "\n", 812 | "2. SAM module\n", 813 | "\n", 814 | "However, ConvLSTM has two problems:\n", 815 | "\n", 816 | "First, the receptive field of a convolutional layer is limited by its kernel size, so capturing global features requires stacking several convolutional layers. For example, if the first layer uses a 3×3 kernel, each of its nodes sees only a 3×3 patch of the input; a second 3×3 layer lets each of its nodes see a 3×3 patch of first-layer nodes, which, with stride 1 in the first layer, covers a 5×5 patch of the input. The second layer therefore perceives a larger spatial range than the first, at the cost of more parameters. For a plain CNN an extra layer adds only one kernel's worth of parameters, but ConvLSTM computes four gates per layer, so every extra layer is much more expensive; the added parameters raise the risk of overfitting while the gains remain modest.\n", 817 | "\n", 818 | "Second, the convolution operates only on the spatial information of the current time step and ignores past spatial information, so it struggles to capture how spatial patterns depend on each other over time.\n", 819 | "\n", 820 | "Therefore, to capture both global and local spatial dependencies and to improve prediction over large spatial domains and long time horizons, SA-ConvLSTM adds a SAM (self-attention memory) module to ConvLSTM.\n", 821 | "\n", 822 | "\n", 823 | "\n", 824 | "The SAM module introduces a new memory cell M that stores spatial information together with its temporal dependencies. It takes as input the hidden state $H_t$ produced by ConvLSTM at the current time step and the memory $M_{t-1}$ from the previous step. First, $H_t$ passes through self-attention to produce the feature $Z_h$; self-attention increases the weight of the parts of $H_t$ that are most relevant to its other parts. At the same time, $H_t$ serves as the query in an attention step over $M_{t-1}$, producing the feature $Z_m$, which emphasizes the parts of $M_{t-1}$ most strongly related to $H_t$. Concatenating $Z_h$ and $Z_m$ (followed by a 1×1 convolution) gives their aggregated feature $Z$, which contains both current-step information and global spatiotemporal memory. Finally, LSTM-style gating driven by $Z$ updates the hidden state and the memory cell, producing the updated hidden state $\\hat{H_t}$ and the current memory $M_t$. The SAM equations are:\n", 825 | "\n", 826 | "$$\n", 827 | "\\begin{aligned}\n", 828 | "& i'_t = \\sigma (W_{m;zi} \\ast Z + W_{m;hi} \\ast H_t + b_{m;i}) \\\\\n", 829 | "& g'_t = \\tanh (W_{m;zg} \\ast Z + W_{m;hg} 
\\ast H_t + b_{m;g}) \\\\\n", 830 | "& M_t = (1 - i'_t) \\circ M_{t-1} + i'_t \\circ g'_t \\\\\n", 831 | "& o'_t = \\sigma (W_{m;zo} \\ast Z + W_{m;ho} \\ast H_t + b_{m;o}) \\\\\n", 832 | "& \\hat{H_t} = o'_t \\circ M_t\n", 833 | "\\end{aligned}\n", 834 | "$$\n", 835 | "\n", 836 | "For more on attention and self-attention, see the following links:\n", 837 | "\n", 838 | " - Attention mechanisms in deep learning: https://blog.csdn.net/malefactor/article/details/78767781\n", 839 | " - Current mainstream attention methods: https://www.zhihu.com/question/68482809\n", 840 | "\n", 841 | "3. SA-ConvLSTM model\n", 842 | "\n", 843 | "Combining the two gives the SA-ConvLSTM model:\n", 844 | "\n", 845 | "" 846 | ] 847 | }, 848 | { 849 | "cell_type": "code", 850 | "execution_count": 20, 851 | "metadata": { 852 | "execution": { 853 | "iopub.execute_input": "2021-11-29T03:05:23.372772Z", 854 | "iopub.status.busy": "2021-11-29T03:05:23.371873Z", 855 | "iopub.status.idle": "2021-11-29T03:05:23.373700Z", 856 | "shell.execute_reply": "2021-11-29T03:05:23.374122Z", 857 | "shell.execute_reply.started": "2021-11-29T01:03:24.585147Z" 858 | }, 859 | "papermill": { 860 | "duration": 0.035787, 861 | "end_time": "2021-11-29T03:05:23.374254", 862 | "exception": false, 863 | "start_time": "2021-11-29T03:05:23.338467", 864 | "status": "completed" 865 | }, 866 | "tags": [] 867 | }, 868 | "outputs": [], 869 | "source": [ 870 | "# Attention mechanism\n", 871 | "def attn(query, key, value):\n", 872 | "    # query, key and value all have shape (N, C, H*W); let S = H*W\n", 873 | "    # Scaled dot-product scores: scores(i) = key(i)^T query / sqrt(C)\n", 874 | "    scores = torch.matmul(query.transpose(1, 2), key / math.sqrt(query.size(1))) # (N, S, S)\n", 875 | "    # Normalize the scores into attention weights\n", 876 | "    attn = F.softmax(scores, dim=-1)\n", 877 | "    output = torch.matmul(attn, value.transpose(1, 2)) # (N, S, C)\n", 878 | "    return output.transpose(1, 2) # (N, C, S)" 879 | ] 880 | }, 881 | { 882 | "cell_type": "code", 883 | "execution_count": 21, 884 | "metadata": { 885 | "execution": { 886 | "iopub.execute_input": "2021-11-29T03:05:23.440765Z", 887 | "iopub.status.busy": 
"2021-11-29T03:05:23.440042Z", 888 | "iopub.status.idle": "2021-11-29T03:05:23.442191Z", 889 | "shell.execute_reply": "2021-11-29T03:05:23.442569Z", 890 | "shell.execute_reply.started": "2021-11-29T01:03:25.147999Z" 891 | }, 892 | "papermill": { 893 | "duration": 0.041095, 894 | "end_time": "2021-11-29T03:05:23.442725", 895 | "exception": false, 896 | "start_time": "2021-11-29T03:05:23.401630", 897 | "status": "completed" 898 | }, 899 | "tags": [] 900 | }, 901 | "outputs": [], 902 | "source": [ 903 | "# SAM module\n", 904 | "class SAAttnMem(nn.Module):\n", 905 | "    def __init__(self, input_dim, d_model, kernel_size):\n", 906 | "        super().__init__()\n", 907 | "        pad = kernel_size[0] // 2, kernel_size[1] // 2\n", 908 | "        self.d_model = d_model\n", 909 | "        self.input_dim = input_dim\n", 910 | "        # 1x1 convolution implementing the fully connected projection W_h H_t (gives Q_h, K_h, V_h)\n", 911 | "        self.conv_h = nn.Conv2d(input_dim, d_model*3, kernel_size=1)\n", 912 | "        # 1x1 convolution implementing the fully connected projection W_m M_{t-1} (gives K_m, V_m)\n", 913 | "        self.conv_m = nn.Conv2d(input_dim, d_model*2, kernel_size=1)\n", 914 | "        # 1x1 convolution implementing the fully connected projection W_z [Z_h, Z_m]\n", 915 | "        self.conv_z = nn.Conv2d(d_model*2, d_model, kernel_size=1)\n", 916 | "        # Each of the three gate outputs (i', g', o') must have input_dim channels, matching the input\n", 917 | "        self.conv_output = nn.Conv2d(input_dim+d_model, input_dim*3, kernel_size=kernel_size, padding=pad)\n", 918 | "    \n", 919 | "    def forward(self, h, m):\n", 920 | "        # self.conv_h(h) computes W_h H_t; splitting it along dim=1 into chunks of size self.d_model gives three tensors of shape (N, d_model, H, W): Q_h, K_h, V_h\n", 921 | "        hq, hk, hv = torch.split(self.conv_h(h), self.d_model, dim=1)\n", 922 | "        # Obtain K_m and V_m the same way\n", 923 | "        mk, mv = torch.split(self.conv_m(m), self.d_model, dim=1)\n", 924 | "        N, C, H, W = hq.size()\n", 925 | "        # Self-attention over H_t gives Z_h\n", 926 | "        Zh = attn(hq.view(N, C, -1), hk.view(N, C, -1), hv.view(N, C, -1)) # (N, C, S), C=d_model\n", 927 | "        # Attention with H_t as the query over M_{t-1} gives Z_m\n", 928 | "        Zm = attn(hq.view(N, C, -1), mk.view(N, C, -1), mv.view(N, C, -1)) # (N, C, S), C=d_model\n", 929 | "        # Concatenate Z_h and Z_m and apply a 1x1 convolution (fully connected projection) to get the aggregated feature Z\n", 
930 | "        Z = self.conv_z(torch.cat([Zh.view(N, C, H, W), Zm.view(N, C, H, W)], dim=1)) # (N, C, H, W), C=d_model\n", 931 | "        # Compute i'_t, g'_t, o'_t\n", 932 | "        i, g, o = torch.split(self.conv_output(torch.cat([Z, h], dim=1)), self.input_dim, dim=1) # (N, C, H, W), C=input_dim\n", 933 | "        i = torch.sigmoid(i)\n", 934 | "        g = torch.tanh(g)\n", 935 | "        # Updated memory M_t\n", 936 | "        m_next = i * g + (1 - i) * m\n", 937 | "        # Updated hidden state H_t\n", 938 | "        h_next = torch.sigmoid(o) * m_next\n", 939 | "        return h_next, m_next" 940 | ] 941 | }, 942 | { 943 | "cell_type": "code", 944 | "execution_count": 22, 945 | "metadata": { 946 | "execution": { 947 | "iopub.execute_input": "2021-11-29T03:05:23.509738Z", 948 | "iopub.status.busy": "2021-11-29T03:05:23.509080Z", 949 | "iopub.status.idle": "2021-11-29T03:05:23.512667Z", 950 | "shell.execute_reply": "2021-11-29T03:05:23.512182Z", 951 | "shell.execute_reply.started": "2021-11-29T01:03:25.667808Z" 952 | }, 953 | "papermill": { 954 | "duration": 0.042616, 955 | "end_time": "2021-11-29T03:05:23.512781", 956 | "exception": false, 957 | "start_time": "2021-11-29T03:05:23.470165", 958 | "status": "completed" 959 | }, 960 | "tags": [] 961 | }, 962 | "outputs": [], 963 | "source": [ 964 | "# SA-ConvLSTM Cell\n", 965 | "class SAConvLSTMCell(nn.Module):\n", 966 | "    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n", 967 | "        super().__init__()\n", 968 | "        self.input_dim = input_dim\n", 969 | "        self.hidden_dim = hidden_dim\n", 970 | "        pad = kernel_size[0] // 2, kernel_size[1] // 2\n", 971 | "        # Convolution W_x*X_t + W_h*H_{t-1}, computed jointly on the concatenated [X_t, H_{t-1}]\n", 972 | "        self.conv = nn.Conv2d(in_channels=input_dim+hidden_dim, out_channels=4*hidden_dim, kernel_size=kernel_size, padding=pad)\n", 973 | "        self.sa = SAAttnMem(input_dim=hidden_dim, d_model=d_attn, kernel_size=kernel_size)\n", 974 | "    \n", 975 | "    def initialize(self, inputs):\n", 976 | "        device = inputs.device\n", 977 | "        N, _, H, W = inputs.size()\n", 978 | "        # Initialize the hidden state H_t\n", 979 | "        self.hidden_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n", 980 | "        # Initialize the cell state c_t\n", 
981 | "        self.cell_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n", 982 | "        # Initialize the memory state M_t\n", 983 | "        self.memory_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n", 984 | "    \n", 985 | "    def forward(self, inputs, first_step=False):\n", 986 | "        # At the first time step, initialize H_t, c_t and M_t\n", 987 | "        if first_step:\n", 988 | "            self.initialize(inputs)\n", 989 | "        \n", 990 | "        # ConvLSTM part\n", 991 | "        # Concatenate X_t and H_{t-1}\n", 992 | "        combined = torch.cat([inputs, self.hidden_state], dim=1) # (N, C, H, W), C=input_dim+hidden_dim\n", 993 | "        # Apply the convolution\n", 994 | "        combined_conv = self.conv(combined) \n", 995 | "        # Split into the four gates i_t, f_t, o_t, g_t\n", 996 | "        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)\n", 997 | "        i = torch.sigmoid(cc_i)\n", 998 | "        f = torch.sigmoid(cc_f)\n", 999 | "        o = torch.sigmoid(cc_o)\n", 1000 | "        g = torch.tanh(cc_g)\n", 1001 | "        # Cell state at the current step: c_t = f_t * c_{t-1} + i_t * g_t\n", 1002 | "        self.cell_state = f * self.cell_state + i * g\n", 1003 | "        # Hidden state at the current step: H_t = o_t * tanh(c_t)\n", 1004 | "        self.hidden_state = o * torch.tanh(self.cell_state)\n", 1005 | "        \n", 1006 | "        # SAM part: update H_t and M_t\n", 1007 | "        self.hidden_state, self.memory_state = self.sa(self.hidden_state, self.memory_state)\n", 1008 | "        \n", 1009 | "        return self.hidden_state" 1010 | ] 1011 | }, 1012 | { 1013 | "cell_type": "markdown", 1014 | "metadata": {}, 1015 | "source": [ 1016 | "Seq2Seq models can be trained in two modes. The first is free running, the conventional approach, which feeds the previous step's output $\\hat{y_{t-1}}$ as the next step's input. The problem is that early in training $\\hat{y_{t-1}}$ is far from the true label $y_{t-1}$, so using it as input makes subsequent outputs drift further and further from the desired targets. This motivates the second mode: teacher forcing.\n", 1017 | "\n", 1018 | "Teacher forcing feeds the true label $y_{t-1}$ directly as the next step's input, so the teacher (the ground truth) keeps the model from drifting. But the teacher cannot guide the student forever and should gradually let go, so Scheduled Sampling is used to control the probability of feeding the true label. Let ratio denote the scheduled-sampling ratio. At the start of training ratio = 1 and the model is fully guided by the teacher; as the number of training epochs grows, ratio decays (this solution uses linear decay, subtracting a fixed decay_rate each time). At each time step a binary value is drawn from a Bernoulli distribution with success probability ratio: if it is 1 the input is the true label $y_{t-1}$, otherwise it is $\\hat{y_{t-1}}$." 1019 | ] 1020 | }, 1021 | { 1022 | "cell_type": "code", 
1023 | "execution_count": 23, 1024 | "metadata": { 1025 | "execution": { 1026 | "iopub.execute_input": "2021-11-29T03:05:23.587567Z", 1027 | "iopub.status.busy": "2021-11-29T03:05:23.586781Z", 1028 | "iopub.status.idle": "2021-11-29T03:05:23.588776Z", 1029 | "shell.execute_reply": "2021-11-29T03:05:23.589156Z", 1030 | "shell.execute_reply.started": "2021-11-29T01:03:26.065997Z" 1031 | }, 1032 | "papermill": { 1033 | "duration": 0.047514, 1034 | "end_time": "2021-11-29T03:05:23.589277", 1035 | "exception": false, 1036 | "start_time": "2021-11-29T03:05:23.541763", 1037 | "status": "completed" 1038 | }, 1039 | "tags": [] 1040 | }, 1041 | "outputs": [], 1042 | "source": [ 1043 | "# Build the SA-ConvLSTM model\n", 1044 | "class SAConvLSTM(nn.Module):\n", 1045 | "    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n", 1046 | "        super().__init__()\n", 1047 | "        self.input_dim = input_dim\n", 1048 | "        self.hidden_dim = hidden_dim\n", 1049 | "        self.num_layers = len(hidden_dim)\n", 1050 | "        \n", 1051 | "        layers = []\n", 1052 | "        for i in range(self.num_layers):\n", 1053 | "            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]\n", 1054 | "            layers.append(SAConvLSTMCell(input_dim=cur_input_dim, hidden_dim=self.hidden_dim[i], d_attn = d_attn, kernel_size=kernel_size)) \n", 1055 | "        self.layers = nn.ModuleList(layers)\n", 1056 | "        \n", 1057 | "        self.conv_output = nn.Conv2d(self.hidden_dim[-1], 1, kernel_size=1)\n", 1058 | "    \n", 1059 | "    def forward(self, input_x, device=torch.device('cuda:0'), input_frames=12, future_frames=26, output_frames=37, teacher_forcing=False, scheduled_sampling_ratio=0, train=True):\n", 1060 | "        # Reshape the input X from (N, T, H, W, C) to (N, T, C, H, W)\n", 1061 | "        input_x = input_x.permute(0, 1, 4, 2, 3).contiguous()\n", 1062 | "        \n", 1063 | "        # Use teacher forcing only during training\n", 1064 | "        if train:\n", 1065 | "            if teacher_forcing and scheduled_sampling_ratio > 1e-6:\n", 
1066 | "                teacher_forcing_mask = torch.bernoulli(scheduled_sampling_ratio * torch.ones(input_x.size(0), future_frames-1, 1, 1, 1))\n", 1067 | "            else:\n", 1068 | "                teacher_forcing = False\n", 1069 | "        else:\n", 1070 | "            teacher_forcing = False\n", 1071 | "        \n", 1072 | "        total_steps = input_frames + future_frames - 1\n", 1073 | "        outputs = [None] * total_steps\n", 1074 | "        \n", 1075 | "        # Loop over the time steps\n", 1076 | "        for t in range(total_steps):\n", 1077 | "            # During the first 12 months (input_frames), feed the observed input X_t\n", 1078 | "            if t < input_frames:\n", 1079 | "                input_ = input_x[:, t].to(device)\n", 1080 | "            # Without teacher forcing, feed the previous step's prediction as the current input\n", 1081 | "            elif not teacher_forcing:\n", 1082 | "                input_ = outputs[t-1]\n", 1083 | "            # With teacher forcing, feed the ground truth of frame t with probability ratio, otherwise the previous step's prediction\n", 1084 | "            else:\n", 1085 | "                mask = teacher_forcing_mask[:, t-input_frames].float().to(device)\n", 1086 | "                input_ = input_x[:, t].to(device) * mask + outputs[t-1] * (1-mask)\n", 1087 | "            first_step = (t==0)\n", 1088 | "            input_ = input_.float()\n", 1089 | "            \n", 1090 | "            # Pass the current input through the stacked hidden layers\n", 1091 | "            for layer_idx in range(self.num_layers):\n", 1092 | "                input_ = self.layers[layer_idx](input_, first_step=first_step)\n", 1093 | "            \n", 1094 | "            # Record the output of each time step (in evaluation, only from step input_frames-1 onward)\n", 1095 | "            if train or (t >= (input_frames - 1)):\n", 1096 | "                outputs[t] = self.conv_output(input_)\n", 1097 | "        \n", 1098 | "        outputs = [x for x in outputs if x is not None]\n", 1099 | "        \n", 1100 | "        # Check the length of the output sequence\n", 1101 | "        if train:\n", 1102 | "            assert len(outputs) == output_frames\n", 1103 | "        else:\n", 1104 | "            assert len(outputs) == future_frames\n", 1105 | "        \n", 1106 | "        # Stack the predicted sst sequence\n", 1107 | "        outputs = torch.stack(outputs, dim=1)[:, :, 0] # (N, 37, H, W) in training, (N, 26, H, W) in evaluation\n", 1108 | "        # Average the predicted sst over the Nino3.4 region, then take 3-month running means to obtain the predicted Nino3.4 index sequence\n", 1109 | "        nino_pred = outputs[:, -future_frames:, 10:13, 19:30].mean(dim=[2, 3]) # (N, 26)\n", 1110 | "        nino_pred = nino_pred.unfold(dimension=1, size=3, step=1).mean(dim=2) # (N, 24)\n", 1111 | "        \n", 1112 | "        return nino_pred" 1113 | ] 1114 | }, 1115 | { 1116 | "cell_type": "code", 1117 | "execution_count": 25, 
1118 | "metadata": { 1119 | "execution": { 1120 | "iopub.execute_input": "2021-11-29T03:05:23.726291Z", 1121 | "iopub.status.busy": "2021-11-29T03:05:23.725688Z", 1122 | "iopub.status.idle": "2021-11-29T03:05:23.753509Z", 1123 | "shell.execute_reply": "2021-11-29T03:05:23.753976Z", 1124 | "shell.execute_reply.started": "2021-11-29T01:03:29.448921Z" 1125 | }, 1126 | "papermill": { 1127 | "duration": 0.066105, 1128 | "end_time": "2021-11-29T03:05:23.754109", 1129 | "exception": false, 1130 | "start_time": "2021-11-29T03:05:23.688004", 1131 | "status": "completed" 1132 | }, 1133 | "tags": [] 1134 | }, 1135 | "outputs": [ 1136 | { 1137 | "name": "stdout", 1138 | "output_type": "stream", 1139 | "text": [ 1140 | "SAConvLSTM(\n", 1141 | " (layers): ModuleList(\n", 1142 | " (0): SAConvLSTMCell(\n", 1143 | " (conv): Conv2d(65, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1144 | " (sa): SAAttnMem(\n", 1145 | " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1146 | " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", 1147 | " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", 1148 | " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1149 | " )\n", 1150 | " )\n", 1151 | " (1): SAConvLSTMCell(\n", 1152 | " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1153 | " (sa): SAAttnMem(\n", 1154 | " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1155 | " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", 1156 | " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", 1157 | " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1158 | " )\n", 1159 | " )\n", 1160 | " (2): SAConvLSTMCell(\n", 1161 | " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1162 | " (sa): SAAttnMem(\n", 1163 | " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1164 | " 
(conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", 1165 | " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", 1166 | " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1167 | " )\n", 1168 | " )\n", 1169 | " (3): SAConvLSTMCell(\n", 1170 | " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1171 | " (sa): SAAttnMem(\n", 1172 | " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1173 | " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", 1174 | " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", 1175 | " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 1176 | " )\n", 1177 | " )\n", 1178 | " )\n", 1179 | " (conv_output): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))\n", 1180 | ")\n" 1181 | ] 1182 | } 1183 | ], 1184 | "source": [ 1185 | "# Number of input channels (SST only)\n", 1186 | "input_dim = 1\n", 1187 | "# Hidden channels of each of the 4 stacked SA-ConvLSTM layers\n", 1188 | "hidden_dim = (64, 64, 64, 64)\n", 1189 | "# Channel dimension of the self-attention memory module\n", 1190 | "d_attn = 32\n", 1191 | "# Convolution kernel size\n", 1192 | "kernel_size = (3, 3)\n", 1193 | "\n", 1194 | "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n", 1195 | "print(model)" 1196 | ] 1197 | }, 1198 | { 1199 | "cell_type": "markdown", 1200 | "metadata": {}, 1201 | "source": [ 1202 | "#### Model training" 1203 | ] 1204 | }, 1205 | { 1206 | "cell_type": "code", 1207 | "execution_count": 26, 1208 | "metadata": { 1209 | "execution": { 1210 | "iopub.execute_input": "2021-11-29T03:05:23.816671Z", 1211 | "iopub.status.busy": "2021-11-29T03:05:23.815927Z", 1212 | "iopub.status.idle": "2021-11-29T03:05:23.818479Z", 1213 | "shell.execute_reply": "2021-11-29T03:05:23.818058Z", 1214 | "shell.execute_reply.started": "2021-11-29T01:03:31.476806Z" 1215 | }, 1216 | "papermill": { 1217 | "duration": 0.035723, 1218 | "end_time": "2021-11-29T03:05:23.818579", 1219 | "exception": false, 1220 | "start_time": "2021-11-29T03:05:23.782856", 1221 | "status": 
"completed" 1222 | }, 1223 | "tags": [] 1224 | }, 1225 | "outputs": [], 1226 | "source": [ 1227 | "# Loss: RMSE of each forecast month (taken over the batch), summed over the 24 months\n", 1228 | "def RMSELoss(y_pred, y_true):\n", 1229 | " loss = torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()\n", 1230 | " return loss" 1231 | ] 1232 | }, 1233 | { 1234 | "cell_type": "code", 1235 | "execution_count": 27, 1236 | "metadata": { 1237 | "execution": { 1238 | "iopub.execute_input": "2021-11-29T03:05:23.893469Z", 1239 | "iopub.status.busy": "2021-11-29T03:05:23.892684Z", 1240 | "iopub.status.idle": "2021-11-29T04:55:28.956056Z", 1241 | "shell.execute_reply": "2021-11-29T04:55:28.956434Z" 1242 | }, 1243 | "papermill": { 1244 | "duration": 6605.109145, 1245 | "end_time": "2021-11-29T04:55:28.956614", 1246 | "exception": false, 1247 | "start_time": "2021-11-29T03:05:23.847469", 1248 | "status": "completed" 1249 | }, 1250 | "tags": [] 1251 | }, 1252 | "outputs": [ 1253 | { 1254 | "name": "stdout", 1255 | "output_type": "stream", 1256 | "text": [ 1257 | "Epoch: 1/5\n" 1258 | ] 1259 | }, 1260 | { 1261 | "name": "stderr", 1262 | "output_type": "stream", 1263 | "text": [ 1264 | "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n" 1265 | ] 1266 | }, 1267 | { 1268 | "name": "stdout", 1269 | "output_type": "stream", 1270 | "text": [ 1271 | "Training Loss: 3.289\n" 1272 | ] 1273 | }, 1274 | { 1275 | "name": "stderr", 1276 | "output_type": "stream", 1277 | "text": [ 1278 | "50it [00:11, 4.47it/s]\n" 1279 | ] 1280 | }, 1281 | { 1282 | "name": "stdout", 1283 | "output_type": "stream", 1284 | "text": [ 1285 | "Validation Loss: 44.009\n", 1286 | "Score: -43.458\n", 1287 | "Epoch: 2/5\n" 1288 | ] 1289 | }, 1290 | { 1291 | "name": "stderr", 1292 | "output_type": "stream", 1293 | "text": [ 1294 | "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n" 1295 | ] 1296 | }, 1297 | { 1298 | "name": "stdout", 1299 | "output_type": "stream", 1300 | "text": [ 1301 | "Training Loss: 3.084\n" 1302 | ] 1303 | }, 1304 | { 1305 | "name": "stderr", 1306 | 
"output_type": "stream", 1307 | "text": [ 1308 | "50it [00:11, 4.33it/s]\n" 1309 | ] 1310 | }, 1311 | { 1312 | "name": "stdout", 1313 | "output_type": "stream", 1314 | "text": [ 1315 | "Validation Loss: 25.011\n", 1316 | "Score: -19.966\n", 1317 | "Epoch: 3/5\n" 1318 | ] 1319 | }, 1320 | { 1321 | "name": "stderr", 1322 | "output_type": "stream", 1323 | "text": [ 1324 | "100%|██████████| 1600/1600 [21:46<00:00, 1.22it/s]\n" 1325 | ] 1326 | }, 1327 | { 1328 | "name": "stdout", 1329 | "output_type": "stream", 1330 | "text": [ 1331 | "Training Loss: 13.461\n" 1332 | ] 1333 | }, 1334 | { 1335 | "name": "stderr", 1336 | "output_type": "stream", 1337 | "text": [ 1338 | "50it [00:12, 4.16it/s]\n" 1339 | ] 1340 | }, 1341 | { 1342 | "name": "stdout", 1343 | "output_type": "stream", 1344 | "text": [ 1345 | "Validation Loss: 15.438\n", 1346 | "Score: -14.139\n", 1347 | "Epoch: 4/5\n" 1348 | ] 1349 | }, 1350 | { 1351 | "name": "stderr", 1352 | "output_type": "stream", 1353 | "text": [ 1354 | "100%|██████████| 1600/1600 [21:54<00:00, 1.22it/s]\n" 1355 | ] 1356 | }, 1357 | { 1358 | "name": "stdout", 1359 | "output_type": "stream", 1360 | "text": [ 1361 | "Training Loss: 17.627\n" 1362 | ] 1363 | }, 1364 | { 1365 | "name": "stderr", 1366 | "output_type": "stream", 1367 | "text": [ 1368 | "50it [00:12, 3.99it/s]\n" 1369 | ] 1370 | }, 1371 | { 1372 | "name": "stdout", 1373 | "output_type": "stream", 1374 | "text": [ 1375 | "Validation Loss: 15.389\n", 1376 | "Score: -22.500\n", 1377 | "Epoch: 5/5\n" 1378 | ] 1379 | }, 1380 | { 1381 | "name": "stderr", 1382 | "output_type": "stream", 1383 | "text": [ 1384 | "100%|██████████| 1600/1600 [21:55<00:00, 1.22it/s]\n" 1385 | ] 1386 | }, 1387 | { 1388 | "name": "stdout", 1389 | "output_type": "stream", 1390 | "text": [ 1391 | "Training Loss: 17.592\n" 1392 | ] 1393 | }, 1394 | { 1395 | "name": "stderr", 1396 | "output_type": "stream", 1397 | "text": [ 1398 | "50it [00:11, 4.48it/s]" 1399 | ] 1400 | }, 1401 | { 1402 | "name": "stdout", 1403 | 
"output_type": "stream", 1404 | "text": [ 1405 | "Validation Loss: 15.252\n", 1406 | "Score: -14.459\n" 1407 | ] 1408 | }, 1409 | { 1410 | "name": "stderr", 1411 | "output_type": "stream", 1412 | "text": [ 1413 | "\n" 1414 | ] 1415 | } 1416 | ], 1417 | "source": [ 1418 | "model_weights = './task05_model_weights.pth'\n", 1419 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 1420 | "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size).to(device)\n", 1421 | "criterion = RMSELoss\n", 1422 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", 1423 | "lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=0, verbose=True, min_lr=0.0001)\n", 1424 | "epochs = 5\n", 1425 | "ratio, decay_rate = 1, 8e-5\n", 1426 | "train_losses, valid_losses = [], []\n", 1427 | "scores = []\n", 1428 | "best_score = float('-inf')\n", 1429 | "preds = np.zeros((len(y_valid),24))\n", 1430 | "\n", 1431 | "for epoch in range(epochs):\n", 1432 | " print('Epoch: {}/{}'.format(epoch+1, epochs))\n", 1433 | " \n", 1434 | " # Training\n", 1435 | " model.train()\n", 1436 | " losses = 0\n", 1437 | " for data, labels in tqdm(trainloader):\n", 1438 | " data = data.to(device)\n", 1439 | " labels = labels.to(device)\n", 1440 | " optimizer.zero_grad()\n", 1441 | " # Linearly decay the teacher-forcing (scheduled sampling) ratio after every batch\n", 1442 | " ratio = max(ratio-decay_rate, 0)\n", 1443 | " pred = model(data, teacher_forcing=True, scheduled_sampling_ratio=ratio, train=True)\n", 1444 | " loss = criterion(pred, labels)\n", 1445 | " losses += loss.cpu().detach().numpy()\n", 1446 | " loss.backward()\n", 1447 | " optimizer.step()\n", 1448 | " train_loss = losses / len(trainloader)\n", 1449 | " train_losses.append(train_loss)\n", 1450 | " print('Training Loss: {:.3f}'.format(train_loss))\n", 1451 | " \n", 1452 | " # Validation\n", 1453 | " model.eval()\n", 1454 | " losses = 0\n", 1455 | " with torch.no_grad():\n", 1456 | " for i, data in tqdm(enumerate(validloader)):\n", 1457 | " data, labels = data\n", 1458 | " data = 
data.to(device)\n", 1459 | " labels = labels.to(device)\n", 1460 | " pred = model(data, train=False)\n", 1461 | " loss = criterion(pred, labels)\n", 1462 | " losses += loss.cpu().detach().numpy()\n", 1463 | " preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n", 1464 | " valid_loss = losses / len(validloader)\n", 1465 | " valid_losses.append(valid_loss)\n", 1466 | " print('Validation Loss: {:.3f}'.format(valid_loss))\n", 1467 | " s = score(y_valid, preds)\n", 1468 | " scores.append(s)\n", 1469 | " print('Score: {:.3f}'.format(s))\n", 1470 | " lr_scheduler.step(s) # reduce the learning rate when the validation score stops improving\n", 1471 | " # Save the weights of the best-scoring model\n", 1472 | " if s > best_score:\n", 1473 | " best_score = s\n", 1474 | " checkpoint = {'best_score': s,\n", 1475 | " 'state_dict': model.state_dict()}\n", 1476 | " torch.save(checkpoint, model_weights)" 1477 | ] 1478 | }, 1479 | { 1480 | "cell_type": "code", 1481 | "execution_count": 28, 1482 | "metadata": { 1483 | "execution": { 1484 | "iopub.execute_input": "2021-11-29T04:55:33.957872Z", 1485 | "iopub.status.busy": "2021-11-29T04:55:33.957066Z", 1486 | "iopub.status.idle": "2021-11-29T04:55:33.960119Z", 1487 | "shell.execute_reply": "2021-11-29T04:55:33.959684Z", 1488 | "shell.execute_reply.started": "2021-11-28T14:00:36.33194Z" 1489 | }, 1490 | "papermill": { 1491 | "duration": 2.38263, 1492 | "end_time": "2021-11-29T04:55:33.960247", 1493 | "exception": false, 1494 | "start_time": "2021-11-29T04:55:31.577617", 1495 | "status": "completed" 1496 | }, 1497 | "tags": [] 1498 | }, 1499 | "outputs": [], 1500 | "source": [ 1501 | "# Plot the training/validation curves\n", 1502 | "def training_vis(train_losses, valid_losses):\n", 1503 | " # Plot the loss curves\n", 1504 | " fig = plt.figure(figsize=(8,4))\n", 1505 | " # subplot loss\n", 1506 | " ax1 = fig.add_subplot(121)\n", 1507 | " ax1.plot(train_losses, label='train_loss')\n", 1508 | " ax1.plot(valid_losses,label='val_loss')\n", 1509 | " ax1.set_xlabel('Epochs')\n", 1510 | " 
ax1.set_ylabel('Loss')\n", 1511 | " ax1.set_title('Loss on Training and Validation Data')\n", 1512 
| " ax1.legend()\n", 1513 | " plt.tight_layout()" 1514 | ] 1515 | }, 1516 | { 1517 | "cell_type": "code", 1518 | "execution_count": 29, 1519 | "metadata": { 1520 | "execution": { 1521 | "iopub.execute_input": "2021-11-29T04:55:38.227343Z", 1522 | "iopub.status.busy": "2021-11-29T04:55:38.226636Z", 1523 | "iopub.status.idle": "2021-11-29T04:55:38.470256Z", 1524 | "shell.execute_reply": "2021-11-29T04:55:38.469252Z", 1525 | "shell.execute_reply.started": "2021-11-28T14:00:43.42651Z" 1526 | }, 1527 | "papermill": { 1528 | "duration": 2.378943, 1529 | "end_time": "2021-11-29T04:55:38.470387", 1530 | "exception": false, 1531 | "start_time": "2021-11-29T04:55:36.091444", 1532 | "status": "completed" 1533 | }, 1534 | "tags": [] 1535 | }, 1536 | "outputs": [ 1537 | { 1538 | "data": { 1539 | "image/png": "[base64 PNG omitted: one panel titled 'Loss on Training and Validation Data', showing train_loss and val_loss against Epochs]\n", 1540 | "text/plain": [ 1541 | "
" 1542 | ] 1543 | }, 1544 | "metadata": { 1545 | "needs_background": "light" 1546 | }, 1547 | "output_type": "display_data" 1548 | } 1549 | ], 1550 | "source": [ 1551 | "training_vis(train_losses, valid_losses)" 1552 | ] 1553 | }, 1554 | { 1555 | "cell_type": "markdown", 1556 | "metadata": {}, 1557 | "source": [ 1558 | "#### Model evaluation\n", 1559 | "\n", 1560 | "Evaluate the model on the test set." 1561 | ] 1562 | }, 1563 | { 1564 | "cell_type": "code", 1565 | "execution_count": 31, 1566 | "metadata": { 1567 | "execution": { 1568 | "iopub.execute_input": "2021-11-29T04:55:47.340416Z", 1569 | "iopub.status.busy": "2021-11-29T04:55:47.339537Z", 1570 | "iopub.status.idle": "2021-11-29T04:55:47.606447Z", 1571 | "shell.execute_reply": "2021-11-29T04:55:47.607038Z", 1572 | "shell.execute_reply.started": "2021-11-28T14:01:44.127754Z" 1573 | }, 1574 | "papermill": { 1575 | "duration": 2.453872, 1576 | "end_time": "2021-11-29T04:55:47.607210", 1577 | "exception": false, 1578 | "start_time": "2021-11-29T04:55:45.153338", 1579 | "status": "completed" 1580 | }, 1581 | "tags": [] 1582 | }, 1583 | "outputs": [ 1584 | { 1585 | "data": { 1586 | "text/plain": [ 1587 | "" 1588 | ] 1589 | }, 1590 | "execution_count": 31, 1591 | "metadata": {}, 1592 | "output_type": "execute_result" 1593 | } 1594 | ], 1595 | "source": [ 1596 | "# Load the checkpoint with the best validation score (map_location lets a GPU checkpoint load on a CPU-only machine)\n", 1597 | "checkpoint = torch.load('../input/ai-earth-model-weights/task05_model_weights.pth', map_location=device)\n", 1598 | "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n", 1599 | "model.load_state_dict(checkpoint['state_dict'])" 1600 | ] 1601 | }, 1602 | { 1603 | "cell_type": "code", 1604 | "execution_count": 32, 1605 | "metadata": { 1606 | "execution": { 1607 | "iopub.execute_input": "2021-11-29T04:55:51.849996Z", 1608 | "iopub.status.busy": "2021-11-29T04:55:51.849073Z", 1609 | "iopub.status.idle": "2021-11-29T04:55:51.851492Z", 1610 | "shell.execute_reply": "2021-11-29T04:55:51.850969Z", 1611 | "shell.execute_reply.started": "2021-11-28T14:06:59.931318Z" 1612 | }, 1613 | 
"papermill": { 1614 | "duration": 2.125413, 1615 | "end_time": "2021-11-29T04:55:51.851629", 1616 | "exception": false, 1617 | "start_time": "2021-11-29T04:55:49.726216", 1618 | "status": "completed" 1619 | }, 1620 | "tags": [] 1621 | }, 1622 | "outputs": [], 1623 | "source": [ 1624 | "# Directory of the test-set inputs\n", 1625 | "test_path = '../input/ai-earth-tests/'\n", 1626 | "# Directory of the test-set labels\n", 1627 | "test_label_path = '../input/ai-earth-tests-labels/'" 1628 | ] 1629 | }, 1630 | { 1631 | "cell_type": "code", 1632 | "execution_count": 33, 1633 | "metadata": { 1634 | "execution": { 1635 | "iopub.execute_input": "2021-11-29T04:55:56.429364Z", 1636 | "iopub.status.busy": "2021-11-29T04:55:56.428800Z", 1637 | "iopub.status.idle": "2021-11-29T04:55:58.115486Z", 1638 | "shell.execute_reply": "2021-11-29T04:55:58.115007Z", 1639 | "shell.execute_reply.started": "2021-11-28T14:07:13.415385Z" 1640 | }, 1641 | "papermill": { 1642 | "duration": 4.135325, 1643 | "end_time": "2021-11-29T04:55:58.115667", 1644 | "exception": false, 1645 | "start_time": "2021-11-29T04:55:53.980342", 1646 | "status": "completed" 1647 | }, 1648 | "tags": [] 1649 | }, 1650 | "outputs": [], 1651 | "source": [ 1652 | "import os\n", 1653 | "\n", 1654 | "# Read the test inputs and their labels (each label file has the same name as its input file)\n", 1655 | "files = os.listdir(test_path)\n", 1656 | "X_test = []\n", 1657 | "y_test = []\n", 1658 | "for file in files:\n", 1659 | " X_test.append(np.load(test_path + file))\n", 1660 | " y_test.append(np.load(test_label_path + file))" 1661 | ] 1662 | }, 1663 | { 1664 | "cell_type": "code", 1665 | "execution_count": 34, 1666 | "metadata": { 1667 | "execution": { 1668 | "iopub.execute_input": "2021-11-29T04:56:02.560786Z", 1669 | "iopub.status.busy": "2021-11-29T04:56:02.559461Z", 1670 | "iopub.status.idle": "2021-11-29T04:56:02.587431Z", 1671 | "shell.execute_reply": "2021-11-29T04:56:02.588024Z", 1672 | "shell.execute_reply.started": "2021-11-28T14:07:17.046359Z" 1673 | }, 1674 | "papermill": { 1675 | "duration": 2.329175, 1676 | "end_time": 
"2021-11-29T04:56:02.588201", 1677 | "exception": false, 1678 | "start_time": "2021-11-29T04:56:00.259026", 1679 | "status": "completed" 1680 | }, 1681 | "tags": [] 1682 | }, 1683 | "outputs": [ 1684 | { 1685 | "data": { 1686 | "text/plain": [ 1687 | "((103, 12, 24, 48, 1), (103, 24))" 1688 | ] 1689 | }, 1690 | "execution_count": 34, 1691 | "metadata": {}, 1692 | "output_type": "execute_result" 1693 | } 1694 | ], 1695 | "source": [ 1696 | "X_test = np.array(X_test)[:, :, :, 19:67, :1]\n", 1697 | "y_test = np.array(y_test)\n", 1698 | "X_test.shape, y_test.shape" 1699 | ] 1700 | }, 1701 | { 1702 | "cell_type": "code", 1703 | "execution_count": 35, 1704 | "metadata": { 1705 | "execution": { 1706 | "iopub.execute_input": "2021-11-29T04:56:07.675344Z", 1707 | "iopub.status.busy": "2021-11-29T04:56:07.674488Z", 1708 | "iopub.status.idle": "2021-11-29T04:56:07.682481Z", 1709 | "shell.execute_reply": "2021-11-29T04:56:07.682895Z", 1710 | "shell.execute_reply.started": "2021-11-28T14:07:31.503452Z" 1711 | }, 1712 | "papermill": { 1713 | "duration": 2.455352, 1714 | "end_time": "2021-11-29T04:56:07.683041", 1715 | "exception": false, 1716 | "start_time": "2021-11-29T04:56:05.227689", 1717 | "status": "completed" 1718 | }, 1719 | "tags": [] 1720 | }, 1721 | "outputs": [], 1722 | "source": [ 1723 | "testset = AIEarthDataset(X_test, y_test)\n", 1724 | "testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)" 1725 | ] 1726 | }, 1727 | { 1728 | "cell_type": "code", 1729 | "execution_count": null, 1730 | "metadata": {}, 1731 | "outputs": [], 1732 | "source": [ 1733 | "# Evaluate the model on the test set (inference only, so gradient tracking is disabled)\n", 1734 | "model = model.to(device).eval()\n", 1735 | "preds = np.zeros((len(y_test), 24))\n", 1736 | "with torch.no_grad():\n", 1737 | " for i, data in tqdm(enumerate(testloader)):\n", 1738 | " data, labels = data\n", 1739 | " data = data.to(device)\n", 1740 | " labels = labels.to(device)\n", 1741 | " pred = model(data, train=False)\n", 1742 | " preds[i*batch_size:(i+1)*batch_size] = 
pred.detach().cpu().numpy()\n", 1743 | "s = score(y_test, preds)\n", 1744 | "print('Score: {:.3f}'.format(s))" 1745 | ] 1746 | }, 1747 | { 1748 | "cell_type": "markdown", 1749 | "metadata": { 1750 | "papermill": { 1751 | "duration": null, 1752 | "end_time": null, 1753 | "exception": null, 1754 | "start_time": null, 1755 | "status": "pending" 1756 | }, 1757 | "tags": [] 1758 | }, 1759 | "source": [ 1760 | "## Summary\n", 1761 | "\n", 1762 | "This top-ranked solution did not design its own model; instead it used an existing model from the spatiotemporal sequence forecasting field (SA-ConvLSTM). Another top team, “ailab”, likewise used an existing model, PredRNN++. For some classic models in spatiotemporal sequence forecasting, see https://www.zhihu.com/column/c_1208033701705162752" 1763 | ] 1764 | }, 1765 | { 1766 | "cell_type": "markdown", 1767 | "metadata": {}, 1768 | "source": [ 1769 | "## Exercise\n", 1770 | "\n", 1771 | "This top-ranked solution uses SST as the prediction target and derives the Nino3.4 index from it indirectly. If you have time to spare, try using the SA-ConvLSTM model to predict the Nino3.4 index directly." 1772 | ] 1773 | }, 1774 | { 1775 | "cell_type": "markdown", 1776 | "metadata": {}, 1777 | "source": [ 1778 | "## References\n", 1779 | "\n", 1780 | "1. Solution write-up by team 吴先生 (Mr. Wu): https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.9.561d5330dF9lX1&postId=231465\n", 1781 | "2. 
ailab团队思路分享:https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.15.561d5330dF9lX1&postId=210734" 1782 | ] 1783 | } 1784 | ], 1785 | "metadata": { 1786 | "kernelspec": { 1787 | "display_name": "Python 3", 1788 | "language": "python", 1789 | "name": "python3" 1790 | }, 1791 | "language_info": { 1792 | "codemirror_mode": { 1793 | "name": "ipython", 1794 | "version": 3 1795 | }, 1796 | "file_extension": ".py", 1797 | "mimetype": "text/x-python", 1798 | "name": "python", 1799 | "nbconvert_exporter": "python", 1800 | "pygments_lexer": "ipython3", 1801 | "version": "3.7.3" 1802 | }, 1803 | "papermill": { 1804 | "default_parameters": {}, 1805 | "duration": 6708.571081, 1806 | "end_time": "2021-11-29T04:56:15.789285", 1807 | "environment_variables": {}, 1808 | "exception": true, 1809 | "input_path": "__notebook__.ipynb", 1810 | "output_path": "__notebook__.ipynb", 1811 | "parameters": {}, 1812 | "start_time": "2021-11-29T03:04:27.218204", 1813 | "version": "2.3.3" 1814 | } 1815 | }, 1816 | "nbformat": 4, 1817 | "nbformat_minor": 5 1818 | } 1819 | -------------------------------------------------------------------------------- /Task5/fig/Task5-LSTM与ConvLSTM公式比较.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task5/fig/Task5-LSTM与ConvLSTM公式比较.png -------------------------------------------------------------------------------- /Task5/fig/Task5-SA-ConvLSTM模型.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task5/fig/Task5-SA-ConvLSTM模型.png -------------------------------------------------------------------------------- /Task5/fig/Task5-SAM模块.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task5/fig/Task5-SAM模块.png -------------------------------------------------------------------------------- /Task5/fig/Task5-Seq2Seq基础结构.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task5/fig/Task5-Seq2Seq基础结构.png --------------------------------------------------------------------------------
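The Summary and Assignment cells of the Task5 notebook note that the TOP solution forecasts the sst field and derives the nino3.4 index from it indirectly. The NumPy sketch below illustrates one way such a derivation can work: average the SST anomaly over the Nino3.4 box (170°W - 120°W, 5°S - 5°N, as defined in the README), then smooth with a 3-month running mean. The grid layout (a 24 × 72 grid at 5° spacing starting from 55°S and 0°E, cropped to longitude columns 19:67), the resulting row/column slices, the unweighted box mean, and the helper name `nino34_from_sst` are all assumptions for illustration, not code taken from the notebook.

```python
import numpy as np

# Hypothetical defaults (assumptions, not from the notebook): a 5-degree grid
# with latitude rows from 55S to 60N and longitude columns from 0E to 355E,
# cropped to columns 19:67. Under that layout, 5S-5N -> rows 10:13 and
# 170W-120W (190E-240E) -> cropped columns 19:30.
NINO34_LAT = slice(10, 13)
NINO34_LON = slice(19, 30)


def nino34_from_sst(sst, lat=NINO34_LAT, lon=NINO34_LON, window=3):
    """Derive a Nino3.4-style index from SST anomaly fields.

    sst: array of shape (N, T, H, W), N samples by T months.
    Returns shape (N, T - window + 1): the box-mean SST anomaly smoothed
    by a `window`-month running mean (only fully covered windows kept).
    """
    sst = np.asarray(sst, dtype=float)
    # Unweighted spatial mean over the box for every sample and month -> (N, T)
    box_mean = sst[:, :, lat, lon].mean(axis=(2, 3))
    # Running mean along the time axis; mode="valid" drops partial windows
    kernel = np.ones(window) / window
    return np.stack([np.convolve(row, kernel, mode="valid") for row in box_mean])


# Usage: 26 predicted months per sample would yield 24 index values
sst_pred = np.random.randn(4, 26, 24, 48)
print(nino34_from_sst(sst_pred).shape)  # (4, 24)
```

A cosine-latitude area weight could be added to the box mean, but within 5° of the equator the weights differ by less than 0.5%, so the unweighted mean is a reasonable simplification here.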