├── README-图1.png
├── README-图2.png
├── README-图3.png
├── README-图4.png
├── README-图5.png
├── README.md
├── README.pdf
├── Task1
│   └── Task1 气象数据分析常用工具.ipynb
├── Task2
│   └── Task2 数据分析.ipynb
├── Task3
│   ├── Task3 模型建立之CNN+LSTM.ipynb
│   └── fig
│       ├── Task3-CNN+LSTM模型.png
│       ├── Task3-样本拼接示意图.png
│       └── Task3-滑窗构造训练样本.png
├── Task4
│   ├── Task4 模型建立之TCNN+RNN.ipynb
│   └── fig
│       ├── Task4-CNN单元.png
│       ├── Task4-TCNN+RNN模型.png
│       ├── Task4-TCNN层.png
│       ├── Task4-TCN单元.png
│       ├── Task4-扩张卷积.png
│       └── Task4-残差连接.png
└── Task5
    ├── Task5 模型建立之SA-ConvLSTM.ipynb
    └── fig
        ├── Task5-LSTM与ConvLSTM公式比较.png
        ├── Task5-SA-ConvLSTM模型.png
        ├── Task5-SAM模块.png
        └── Task5-Seq2Seq基础结构.png
/README-图1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图1.png
--------------------------------------------------------------------------------
/README-图2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图2.png
--------------------------------------------------------------------------------
/README-图3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图3.png
--------------------------------------------------------------------------------
/README-图4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图4.png
--------------------------------------------------------------------------------
/README-图5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README-图5.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
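1 | # Competition Overview
2 |
3 | The 2021 “AI Earth” Artificial Intelligence Innovation Challenge, themed “AI for accurate weather and ocean prediction”, explores applications of artificial intelligence in meteorology and oceanography.
4 |
5 | The background of this task is the El Niño-Southern Oscillation (ENSO). ENSO is the joint name for El Niño (EN) and the Southern Oscillation (SO). El Niño refers to a sustained abnormal warming of the sea surface temperature in the central and eastern equatorial Pacific. The Southern Oscillation refers to the seesaw pattern in which surface pressure over the tropical eastern Pacific and the tropical western Pacific varies in opposite directions. The two are the oceanic and atmospheric expressions of the same climate anomaly and are closely linked, hence the combined name El Niño-Southern Oscillation.
6 |
7 | ENSO causes extreme weather across much of the world and strongly affects global weather, climate, and food production, so predicting it accurately is key to improving climate prediction and disaster prevention and mitigation in East Asia and worldwide. The Nino3.4 index is an important indicator for monitoring ENSO: it is the mean sea surface temperature anomaly over the Nino3.4 region (170°W - 120°W, 5°S - 5°N) and reflects anomalies in sea surface temperature. If the Nino3.4 index exceeds 0.5℃ for five consecutive months, an ENSO event is declared. The goal of this task is to use historical climate observations and model simulations, taking the spatiotemporal sequence of the 12 months up to and including time T, to predict the Nino3.4 index for the following 1 - 24 months.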
8 |
9 |
10 |
11 |
11 | Figure 1. The Nino3.4 region
12 |
13 |
14 |
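15 | Figure 2. The Nino3.4 index (image source: weatherzone.com.au)
16 |
17 |
18 |
19 | Figure 3. Schematic of the task
20 |
21 | From the above, the task for this team-learning course is spatiotemporal sequence prediction.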
22 |
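23 | # Competition Task
24 |
25 | ## Data Overview
26 |
27 | The training data consist of 140 years of historical simulations from 17 CMIP5 models, 151 years of historical simulations from 15 CMIP6 models, and 100 years of historical observational assimilation data reconstructed by the US SODA model, all stored in nc (NetCDF) format. CMIP5 and CMIP6 are the fifth and sixth phases of the Coupled Model Intercomparison Project of the World Climate Research Programme (WCRP); both provide simulations of many climate variables from many different climate models. The data contain four climate variables: sea surface temperature anomaly (SST), heat content anomaly (T300), zonal wind anomaly (Ua), and meridional wind anomaly (Va), with dimensions (year, month, lat, lon). The training data come with Nino3.4 index labels for the corresponding months. In short, each training sample holds the SST, T300, Ua, and Va values for a given year, month, latitude, and longitude, and its label is the Nino3.4 index for that year and month.
28 |
29 | Note that the second dimension, month, has length 36 rather than 12: it covers three consecutive years starting from the current year. For example, in the SODA training data, year 0 holds the monthly historical observations of years 1 - 3, and year 1 holds those of years 2 - 4. In other words, the samples overlap in time.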
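
The overlap described above can be sketched on toy data. This is a minimal illustration, not the competition code: the helper name `flatten_overlapping` and the toy array (whose values are simply absolute month indices) are assumptions made for the example. It shows that keeping the first 12 months of every sample, plus the remaining 24 months of the last one, recovers one continuous monthly series.

```python
import numpy as np

def flatten_overlapping(samples):
    """Rebuild one continuous monthly series from overlapping 36-month samples.

    samples: array of shape (years, 36, ...), where sample y covers
    years y, y+1 and y+2 month by month.
    """
    samples = np.asarray(samples)
    # The first 12 months of every sample belong to that sample's start year.
    head = samples[:, :12].reshape(-1, *samples.shape[2:])
    # The last sample also contributes its remaining 24 months.
    tail = samples[-1, 12:]
    return np.concatenate([head, tail], axis=0)

# Toy data (hypothetical): 5 samples whose values are absolute month indices.
years = 5
toy = np.stack([np.arange(12 * y, 12 * y + 36) for y in range(years)])
series = flatten_overlapping(toy)
print(series.shape)  # (84,)
```

Once the series is continuous, a sliding window can cut many more (12-month input, 24-month target) pairs out of it than the original per-year samples provide.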
30 |
31 |
32 |
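33 | Figure 4. Time span of the samples
34 |
35 | Also note that the Nino3.4 index is the average of SST over the Nino3.4 region across three consecutive months starting from the current month. We therefore do not have to predict the Nino3.4 index directly; we can predict SST instead and derive the Nino3.4 index from it.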
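
A rough sketch of this derivation follows. It assumes the SST field comes with explicit latitude and longitude coordinate arrays (longitudes in degrees east, 0-360) and uses an unweighted mean over the box; the function name `nino34_from_sst` and the 5° toy grid are illustrative assumptions, not taken from the competition data description.

```python
import numpy as np

def nino34_from_sst(sst, lats, lons):
    """Approximate the Nino3.4 index from gridded SST anomalies.

    sst:  array of shape (months, n_lat, n_lon)
    lats: latitudes in degrees north, shape (n_lat,)
    lons: longitudes in degrees east (0-360), shape (n_lon,)
    """
    sst = np.asarray(sst, dtype=float)
    lats = np.asarray(lats)
    lons = np.asarray(lons)
    # Nino3.4 box: 5S-5N, 170W-120W, i.e. 190E-240E.
    lat_mask = (lats >= -5) & (lats <= 5)
    lon_mask = (lons >= 190) & (lons <= 240)
    # Unweighted mean over the box for every month.
    box_mean = sst[:, lat_mask][:, :, lon_mask].mean(axis=(1, 2))
    # Forward 3-month mean over months t, t+1, t+2.
    return np.convolve(box_mean, np.ones(3) / 3.0, mode="valid")

# Toy example on an assumed 5-degree grid: month t has the constant value t.
lats = np.arange(-55, 61, 5)
lons = np.arange(0, 360, 5)
months = 6
sst = np.ones((months, lats.size, lons.size)) * np.arange(months)[:, None, None]
index = nino34_from_sst(sst, lats, lons)
print(index.shape)  # (4,)
```

Because the mean looks two months ahead, a series of `months` SST fields yields `months - 2` index values.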
36 |
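37 | The test data are $N$ randomly drawn 12-month time series taken from several international ocean data assimilation products, stored in npy format with dimensions (12, lat, lon, 4). The first dimension is 12 consecutive months, and the fourth holds the 4 climate variables in the order SST, T300, Ua, Va. Test files are named like test_00001_01_12.npy, where 00001 is the sample number, 01 the starting month, and 12 the ending month.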
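
The naming scheme can be parsed with a small regular expression. The sketch below is illustrative only: the helper name `parse_test_name` is hypothetical, and the pattern assumes every test file follows exactly the `test_<id>_<start>_<end>.npy` form shown above.

```python
import re

# Pattern for names such as test_00001_01_12.npy (assumed to be the only form).
_TEST_NAME = re.compile(r"^test_(\d+)_(\d{2})_(\d{2})\.npy$")

def parse_test_name(filename):
    """Split a test file name into (sample id, start month, end month)."""
    match = _TEST_NAME.match(filename)
    if match is None:
        raise ValueError(f"unexpected test file name: {filename!r}")
    sample_id, start_month, end_month = match.groups()
    return sample_id, int(start_month), int(end_month)

print(parse_test_name("test_00001_01_12.npy"))  # ('00001', 1, 12)
```

The sample id is kept as a string so that its leading zeros survive when building the matching output file name.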
38 |
39 | ## Evaluation Metric
40 |
41 | The evaluation metric for this task is:
42 | $$
43 | Score = \frac{2}{3} \times accskill - RMSE
44 | $$
45 | where $accskill$ is the correlation skill score, computed as:
46 | $$
47 | accskill = \sum_{i=1}^{24} a \times \ln(i) \times cor_i \\
48 | (i \leq 4, a = 1.5; 5 \leq i \leq 11, a = 2; 12 \leq i \leq 18, a = 3; 19 \leq i, a = 4)
49 | $$
50 | The coefficient $a$ grows with the lead month $i$, so the longer the horizon over which the model predicts accurately, the higher the score.
51 |
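52 | $cor_i$ is the correlation coefficient between the predicted and actual values at lead month $i$ over the $N$ test samples:
53 | $$
54 | cor_i = \frac{\sum_{j=1}^N(y_{truej}-\bar{y}_{true})(y_{predj}-\bar{y}_{pred})}{\sqrt{\sum_{j=1}^N(y_{truej}-\bar{y}_{true})^2\sum_{j=1}^N(y_{predj}-\bar{y}_{pred})^2}}
55 | $$
56 | where $y_{truej}$ is the actual Nino3.4 index of sample $j$ at lead month $i$, $\bar{y}_{true}$ is the mean actual Nino3.4 index of the $N$ test samples at that month, $y_{predj}$ is the predicted Nino3.4 index of sample $j$ at lead month $i$, and $\bar{y}_{pred}$ is the mean predicted Nino3.4 index of the $N$ test samples at that month.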
57 |
58 | $RMSE$ is the cumulative root-mean-square error over the 24 months:
59 | $$
60 | RMSE = \sum_{i=1}^{24}rmse_i \\
61 | rmse_i = \sqrt{\frac{1}{N}\sum_{j=1}^N(y_{truej}-y_{predj})^2}
62 | $$
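
The formulas above can be sketched in vectorized NumPy. This is a minimal illustration under the assumption that predictions and targets are arrays of shape (N, 24); the function name `competition_score` and the random toy data are introduced here for the example only.

```python
import numpy as np

def competition_score(y_true, y_pred):
    """Score = 2/3 * accskill - RMSE for arrays of shape (N, 24)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Lead-month weights a for i = 1..24, as listed under the accskill formula.
    a = np.array([1.5] * 4 + [2.0] * 7 + [3.0] * 7 + [4.0] * 6)
    lead = np.arange(1, 25)
    # Correlation coefficient per lead month, computed across the N samples.
    dt = y_true - y_true.mean(axis=0)
    dp = y_pred - y_pred.mean(axis=0)
    cor = (dt * dp).sum(axis=0) / np.sqrt((dt ** 2).sum(axis=0) * (dp ** 2).sum(axis=0))
    accskill = np.sum(a * np.log(lead) * cor)
    # Per-month RMSE, summed over the 24 lead months.
    rmse = np.sqrt(((y_true - y_pred) ** 2).mean(axis=0)).sum()
    return 2.0 / 3.0 * accskill - rmse

# Toy check (random data): a perfect prediction reaches the upper bound.
rng = np.random.default_rng(0)
y = rng.normal(size=(50, 24))
best = competition_score(y, y)
worst = competition_score(y, -y)
print(best > worst)  # True
```

Since $\ln(1) = 0$, the first lead month contributes nothing to $accskill$ and affects the score only through its RMSE term.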
63 |
64 |
65 | Figure 5. How the evaluation metric is computed
66 |
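67 | ## Task Analysis
68 |
69 | From the task information above, these are the problems we need to solve:
70 |
71 | - For a spatiotemporal sequence prediction problem, how do we extract temporal information, and how do we extract spatial information?
72 | - The features provided are four general-purpose climate variables that are well established in meteorology, so it is hard to construct new features from them. Without new features, how do we extract more information from the ones given?
73 | - The training set is small, with only $140\times17+151\times15+100=4745$ training samples in total. For a prediction problem with little data, we usually need to consider two things:
74 |   - How can we increase the amount of data?
75 |   - How can we build a model that is small (few parameters, lower risk of overfitting) yet deep (able to extract sufficiently rich information)?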
76 |
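77 | # Learning Objectives
78 |
79 | We reviewed and consolidated the solutions of the top competitors into five small goals for this team-learning course. We hope you keep the questions above in mind and find the answers as you learn.
80 |
81 | We hope you will gain the following from this course:
82 |
83 | 1. Familiarity with the common tools for meteorological data analysis.
84 | 2. The ability to analyze spatiotemporal data.
85 | 3. A working knowledge of the models used in this course.
86 | 4. Approaches and methods for selecting and constructing models in spatiotemporal sequence prediction problems.
87 |
88 | We also hope you will go beyond the given tasks and think about and explore the topic further.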
89 |
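90 | # Course Schedule
91 |
92 | ### Task01: Common tools for meteorological data analysis (2 days)
93 |
94 | ### Task02: Data analysis (2 days)
95 |
96 | ### Task03: Model building with CNN + LSTM (3 days)
97 |
98 | ### Task04: Model building with TCNN + RNN (4 days)
99 |
100 | ### Task05: Model building with SA-ConvLSTM (4 days)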
101 |
102 | # Resources
103 |
104 | I. Official competition site
105 |
106 | https://tianchi.aliyun.com/competition/entrance/531871/information
107 |
108 | II. Open-source solutions from the competition
109 |
110 | 1. https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.15.561d53309Kn9hK&postId=210391 (swg-lhl, Rank 1)
111 | 2. https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.9.561d53309Kn9hK&postId=210734 (ailab)
112 | 3. https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.6.561d53309Kn9hK&postId=210836 (source code included; 神之一手YueTan, Rank 5)
113 | 4. https://tianchi.aliyun.com/notebook-ai/detail?spm=5176.12586969.1002.18.561d5330HKwYOW&postId=196536 (source code included; 学习AI的打工人)
114 | 5. https://github.com/jerrywn121/TianChi_AIEarth?spm=5176.21852664.0.0.6b612aedW6oyIQ (source code included; 吴先生的队伍)
115 |
116 | # Contributors
117 |
118 | - Project setup and integration: 曾海如
119 |
--------------------------------------------------------------------------------
/README.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/README.pdf
--------------------------------------------------------------------------------
/Task3/fig/Task3-CNN+LSTM模型.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task3/fig/Task3-CNN+LSTM模型.png
--------------------------------------------------------------------------------
/Task3/fig/Task3-样本拼接示意图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task3/fig/Task3-样本拼接示意图.png
--------------------------------------------------------------------------------
/Task3/fig/Task3-滑窗构造训练样本.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task3/fig/Task3-滑窗构造训练样本.png
--------------------------------------------------------------------------------
/Task4/Task4 模型建立之TCNN+RNN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
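7 | "# Datawhale Weather and Ocean Prediction - Task4: Model Building with TCNN+RNN\n",
8 | "In this task we study the winning solution by the top team “swg-lhl”, which uses a TCNN+RNN model.\n",
9 | "\n",
10 | "In Task3 we studied the CNN+LSTM model. LSTM layers, however, have many parameters, which causes two problems. First, a model with many parameters overfits easily when data are scarce. Second, to limit overfitting on a small dataset we cannot build a deeper model, which makes it hard to extract richer information. Compared with LSTM, the parameter count of a CNN layer depends only on the filter size (and the number of channels), and CNNs tend to perform well across many tasks, so it is natural to use convolution to extract temporal information as well. But if we use 3D convolution to extract temporal and spatial information at once with a filter of size (T_f, H_f, W_f), each filter carries T_f×H_f×W_f weights per input channel, which is still fairly large. To further reduce the parameters per layer and allow a deeper model, the top solution studied here convolves over time and space separately: a TCN unit extracts temporal information, and its output is fed into a CNN unit that extracts spatial information. The serial TCN unit + CNN unit structure is called a TCNN layer, and stacking several TCNN layers extracts temporal and spatial information alternately. In addition, since spatiotemporal information at different time scales may affect the prediction differently, the solution uses three RNN layers to extract features at three time scales, concatenates them, and predicts the Nino3.4 index through a fully connected layer."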
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {},
16 | "source": [
17 | "## Learning objectives\n",
18 | "1. Learn how the top solution builds its model"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
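25 | "## Contents\n",
26 | "1. Data processing\n",
27 | "    - Data flattening\n",
28 | "    - Missing-value filling\n",
29 | "    - Dataset construction\n",
30 | "2. Model building\n",
31 | "    - Evaluation function\n",
32 | "    - Model construction\n",
33 | "    - Model training\n",
34 | "    - Model evaluation\n",
35 | "3. Summary"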
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "## Code walkthrough"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {},
48 | "source": [
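49 | "### Data processing\n",
50 | "Data processing in this top solution has three parts:\n",
51 | "1. Data flattening\n",
52 | "2. Missing-value filling\n",
53 | "3. Dataset construction\n",
54 | "\n",
55 | "Apart from not constructing new features, the data processing here is essentially the same as in Task3, so we do not go into detail."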
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 1,
61 | "metadata": {
62 | "execution": {
63 | "iopub.execute_input": "2021-11-07T11:32:41.439634Z",
64 | "iopub.status.busy": "2021-11-07T11:32:41.431440Z",
65 | "iopub.status.idle": "2021-11-07T11:32:43.645979Z",
66 | "shell.execute_reply": "2021-11-07T11:32:43.645292Z",
67 | "shell.execute_reply.started": "2021-11-07T11:22:11.718250Z"
68 | },
69 | "papermill": {
70 | "duration": 2.245092,
71 | "end_time": "2021-11-07T11:32:43.646157",
72 | "exception": false,
73 | "start_time": "2021-11-07T11:32:41.401065",
74 | "status": "completed"
75 | },
76 | "tags": []
77 | },
78 | "outputs": [],
79 | "source": [
80 | "import netCDF4 as nc\n",
81 | "import random\n",
82 | "import os\n",
83 | "from tqdm import tqdm\n",
84 | "import pandas as pd\n",
85 | "import numpy as np\n",
86 | "import matplotlib.pyplot as plt\n",
87 | "%matplotlib inline\n",
88 | "\n",
89 | "import torch\n",
90 | "from torch import nn, optim\n",
91 | "import torch.nn.functional as F\n",
92 | "from torch.utils.data import Dataset, DataLoader\n",
93 | "\n",
94 | "from sklearn.metrics import mean_squared_error"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 2,
100 | "metadata": {
101 | "execution": {
102 | "iopub.execute_input": "2021-11-07T11:32:43.703083Z",
103 | "iopub.status.busy": "2021-11-07T11:32:43.702422Z",
104 | "iopub.status.idle": "2021-11-07T11:32:43.706718Z",
105 | "shell.execute_reply": "2021-11-07T11:32:43.707099Z",
106 | "shell.execute_reply.started": "2021-11-07T11:22:17.072537Z"
107 | },
108 | "papermill": {
109 | "duration": 0.03493,
110 | "end_time": "2021-11-07T11:32:43.707228",
111 | "exception": false,
112 | "start_time": "2021-11-07T11:32:43.672298",
113 | "status": "completed"
114 | },
115 | "tags": []
116 | },
117 | "outputs": [],
118 | "source": [
119 | "# Fix the random seeds for reproducibility\n",
120 | "SEED = 22\n",
121 | "\n",
122 | "def seed_everything(seed=42):\n",
123 | "    random.seed(seed)\n",
124 | "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
125 | "    np.random.seed(seed)\n",
126 | "    torch.manual_seed(seed)\n",
127 | "    torch.cuda.manual_seed(seed)\n",
128 | "    torch.backends.cudnn.deterministic = True\n",
129 | "\n",
130 | "seed_everything(SEED)"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 3,
136 | "metadata": {
137 | "execution": {
138 | "iopub.execute_input": "2021-11-07T11:32:43.806316Z",
139 | "iopub.status.busy": "2021-11-07T11:32:43.805429Z",
140 | "iopub.status.idle": "2021-11-07T11:32:43.809106Z",
141 | "shell.execute_reply": "2021-11-07T11:32:43.809544Z",
142 | "shell.execute_reply.started": "2021-11-07T11:22:34.845477Z"
143 | },
144 | "papermill": {
145 | "duration": 0.077409,
146 | "end_time": "2021-11-07T11:32:43.809691",
147 | "exception": false,
148 | "start_time": "2021-11-07T11:32:43.732282",
149 | "status": "completed"
150 | },
151 | "tags": []
152 | },
153 | "outputs": [
154 | {
155 | "name": "stdout",
156 | "output_type": "stream",
157 | "text": [
158 | "CUDA is available! Training on GPU ...\n"
159 | ]
160 | }
161 | ],
162 | "source": [
163 | "# Check whether a GPU is available\n",
164 | "train_on_gpu = torch.cuda.is_available()\n",
165 | "\n",
166 | "if not train_on_gpu:\n",
167 | "    print('CUDA is not available. Training on CPU ...')\n",
168 | "else:\n",
169 | "    print('CUDA is available! Training on GPU ...')"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 4,
175 | "metadata": {
176 | "execution": {
177 | "iopub.execute_input": "2021-11-07T11:32:43.938252Z",
178 | "iopub.status.busy": "2021-11-07T11:32:43.937335Z",
179 | "iopub.status.idle": "2021-11-07T11:32:43.991840Z",
180 | "shell.execute_reply": "2021-11-07T11:32:43.991334Z",
181 | "shell.execute_reply.started": "2021-11-07T11:22:38.372763Z"
182 | },
183 | "papermill": {
184 | "duration": 0.156188,
185 | "end_time": "2021-11-07T11:32:43.991963",
186 | "exception": false,
187 | "start_time": "2021-11-07T11:32:43.835775",
188 | "status": "completed"
189 | },
190 | "tags": []
191 | },
192 | "outputs": [],
193 | "source": [
194 | "# Load the data\n",
195 | "\n",
196 | "# Directory that holds the data files\n",
197 | "path = '/kaggle/input/ninoprediction/'\n",
198 | "soda_train = nc.Dataset(path + 'SODA_train.nc')\n",
199 | "soda_label = nc.Dataset(path + 'SODA_label.nc')\n",
200 | "cmip_train = nc.Dataset(path + 'CMIP_train.nc')\n",
201 | "cmip_label = nc.Dataset(path + 'CMIP_label.nc')"
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "metadata": {},
207 | "source": [
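208 | "#### Data flattening\n",
209 | "Flatten the overlapping yearly samples into continuous monthly series so that a sliding window can be used to build the dataset."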
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 5,
215 | "metadata": {
216 | "execution": {
217 | "iopub.execute_input": "2021-11-07T11:32:44.053052Z",
218 | "iopub.status.busy": "2021-11-07T11:32:44.052369Z",
219 | "iopub.status.idle": "2021-11-07T11:32:44.055801Z",
220 | "shell.execute_reply": "2021-11-07T11:32:44.055258Z",
221 | "shell.execute_reply.started": "2021-11-07T11:22:41.386272Z"
222 | },
223 | "papermill": {
224 | "duration": 0.037368,
225 | "end_time": "2021-11-07T11:32:44.055938",
226 | "exception": false,
227 | "start_time": "2021-11-07T11:32:44.018570",
228 | "status": "completed"
229 | },
230 | "tags": []
231 | },
232 | "outputs": [],
233 | "source": [
234 | "def make_flatted(train_ds, label_ds, info, start_idx=0):\n",
235 | "    keys = ['sst', 't300', 'ua', 'va']\n",
236 | "    label_key = 'nino'\n",
237 | "    # Number of years\n",
238 | "    years = info[1]\n",
239 | "    # Number of climate models\n",
240 | "    models = info[2]\n",
241 | "\n",
242 | "    train_list = []\n",
243 | "    label_list = []\n",
244 | "\n",
245 | "    # Concatenate the data that belong to the same climate model\n",
246 | "    for model_i in range(models):\n",
247 | "        blocks = []\n",
248 | "\n",
249 | "        # For each variable, take the first 12 months of every sample and concatenate them\n",
250 | "        for key in keys:\n",
251 | "            block = train_ds[key][start_idx + model_i * years: start_idx + (model_i + 1) * years, :12].reshape(-1, 24, 72, 1).data\n",
252 | "            blocks.append(block)\n",
253 | "\n",
254 | "        # Concatenate all variables along the last axis\n",
255 | "        train_flatted = np.concatenate(blocks, axis=-1)\n",
256 | "\n",
257 | "        # Concatenate the labels of months 12-23 of every sample, then append months 24-35 of the final sample (together with its months 12-23 they form the target for the final sample's first 12 months)\n",
258 | "        label_flatted = np.concatenate([\n",
259 | "            label_ds[label_key][start_idx + model_i * years: start_idx + (model_i + 1) * years, 12: 24].reshape(-1).data,\n",
260 | "            label_ds[label_key][start_idx + (model_i + 1) * years - 1, 24: 36].reshape(-1).data\n",
261 | "        ], axis=0)\n",
262 | "\n",
263 | "        train_list.append(train_flatted)\n",
264 | "        label_list.append(label_flatted)\n",
265 | "\n",
266 | "    return train_list, label_list"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": 6,
272 | "metadata": {
273 | "execution": {
274 | "iopub.execute_input": "2021-11-07T11:32:44.114429Z",
275 | "iopub.status.busy": "2021-11-07T11:32:44.113857Z",
276 | "iopub.status.idle": "2021-11-07T11:33:35.423369Z",
277 | "shell.execute_reply": "2021-11-07T11:33:35.423920Z",
278 | "shell.execute_reply.started": "2021-11-07T11:22:43.054065Z"
279 | },
280 | "papermill": {
281 | "duration": 51.342264,
282 | "end_time": "2021-11-07T11:33:35.424098",
283 | "exception": false,
284 | "start_time": "2021-11-07T11:32:44.081834",
285 | "status": "completed"
286 | },
287 | "tags": []
288 | },
289 | "outputs": [
290 | {
291 | "data": {
292 | "text/plain": [
293 | "((1, 1200, 24, 72, 4), (15, 1812, 24, 72, 4), (17, 1680, 24, 72, 4))"
294 | ]
295 | },
296 | "execution_count": 6,
297 | "metadata": {},
298 | "output_type": "execute_result"
299 | }
300 | ],
301 | "source": [
302 | "soda_info = ('soda', 100, 1)\n",
303 | "cmip6_info = ('cmip6', 151, 15)\n",
304 | "cmip5_info = ('cmip5', 140, 17)\n",
305 | "\n",
306 | "soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)\n",
307 | "cmip6_trains, cmip6_labels = make_flatted(cmip_train, cmip_label, cmip6_info)\n",
308 | "cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])\n",
309 | "\n",
310 | "# The flattened data have shape (models, sequence length, lat, lon, features), where sequence length = years * 12\n",
311 | "np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)"
312 | ]
313 | },
314 | {
315 | "cell_type": "markdown",
316 | "metadata": {},
317 | "source": [
318 | "#### Missing-value filling\n",
319 | "Fill missing values with 0."
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": 8,
325 | "metadata": {
326 | "execution": {
327 | "iopub.execute_input": "2021-11-07T11:33:35.538420Z",
328 | "iopub.status.busy": "2021-11-07T11:33:35.537032Z",
329 | "iopub.status.idle": "2021-11-07T11:33:35.589717Z",
330 | "shell.execute_reply": "2021-11-07T11:33:35.590116Z",
331 | "shell.execute_reply.started": "2021-11-07T11:23:30.781276Z"
332 | },
333 | "papermill": {
334 | "duration": 0.083633,
335 | "end_time": "2021-11-07T11:33:35.590263",
336 | "exception": false,
337 | "start_time": "2021-11-07T11:33:35.506630",
338 | "status": "completed"
339 | },
340 | "tags": []
341 | },
342 | "outputs": [
343 | {
344 | "name": "stdout",
345 | "output_type": "stream",
346 | "text": [
347 | "Number of null in soda_trains after fillna: 0\n"
348 | ]
349 | }
350 | ],
351 | "source": [
352 | "# Fill missing values in the SODA data\n",
353 | "soda_trains = np.array(soda_trains)\n",
354 | "soda_trains_nan = np.isnan(soda_trains)\n",
355 | "soda_trains[soda_trains_nan] = 0\n",
356 | "print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))"
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": 9,
362 | "metadata": {
363 | "execution": {
364 | "iopub.execute_input": "2021-11-07T11:33:35.649057Z",
365 | "iopub.status.busy": "2021-11-07T11:33:35.647940Z",
366 | "iopub.status.idle": "2021-11-07T11:33:36.785018Z",
367 | "shell.execute_reply": "2021-11-07T11:33:36.784016Z",
368 | "shell.execute_reply.started": "2021-11-07T11:23:30.842000Z"
369 | },
370 | "papermill": {
371 | "duration": 1.1683,
372 | "end_time": "2021-11-07T11:33:36.785204",
373 | "exception": false,
374 | "start_time": "2021-11-07T11:33:35.616904",
375 | "status": "completed"
376 | },
377 | "tags": []
378 | },
379 | "outputs": [
380 | {
381 | "name": "stdout",
382 | "output_type": "stream",
383 | "text": [
384 | "Number of null in cmip6_trains after fillna: 0\n"
385 | ]
386 | }
387 | ],
388 | "source": [
389 | "# Fill missing values in the CMIP6 data\n",
390 | "cmip6_trains = np.array(cmip6_trains)\n",
391 | "cmip6_trains_nan = np.isnan(cmip6_trains)\n",
392 | "cmip6_trains[cmip6_trains_nan] = 0\n",
393 | "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": 10,
399 | "metadata": {
400 | "execution": {
401 | "iopub.execute_input": "2021-11-07T11:33:36.844494Z",
402 | "iopub.status.busy": "2021-11-07T11:33:36.843382Z",
403 | "iopub.status.idle": "2021-11-07T11:33:37.982369Z",
404 | "shell.execute_reply": "2021-11-07T11:33:37.983434Z",
405 | "shell.execute_reply.started": "2021-11-07T11:23:31.982897Z"
406 | },
407 | "papermill": {
408 | "duration": 1.170648,
409 | "end_time": "2021-11-07T11:33:37.983683",
410 | "exception": false,
411 | "start_time": "2021-11-07T11:33:36.813035",
412 | "status": "completed"
413 | },
414 | "tags": []
415 | },
416 | "outputs": [
417 | {
418 | "name": "stdout",
419 | "output_type": "stream",
420 | "text": [
421 | "Number of null in cmip5_trains after fillna: 0\n"
422 | ]
423 | }
424 | ],
425 | "source": [
426 | "# Fill missing values in the CMIP5 data\n",
427 | "cmip5_trains = np.array(cmip5_trains)\n",
428 | "cmip5_trains_nan = np.isnan(cmip5_trains)\n",
429 | "cmip5_trains[cmip5_trains_nan] = 0\n",
430 | "print('Number of null in cmip5_trains after fillna:', np.sum(np.isnan(cmip5_trains)))"
431 | ]
432 | },
433 | {
434 | "cell_type": "markdown",
435 | "metadata": {},
436 | "source": [
437 | "#### Dataset construction\n",
438 | "Build the training and validation sets."
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": 11,
444 | "metadata": {
445 | "execution": {
446 | "iopub.execute_input": "2021-11-07T11:33:38.091882Z",
447 | "iopub.status.busy": "2021-11-07T11:33:38.091030Z",
448 | "iopub.status.idle": "2021-11-07T11:33:38.757089Z",
449 | "shell.execute_reply": "2021-11-07T11:33:38.756486Z",
450 | "shell.execute_reply.started": "2021-11-07T11:23:33.105534Z"
451 | },
452 | "papermill": {
453 | "duration": 0.72117,
454 | "end_time": "2021-11-07T11:33:38.757230",
455 | "exception": false,
456 | "start_time": "2021-11-07T11:33:38.036060",
457 | "status": "completed"
458 | },
459 | "tags": []
460 | },
461 | "outputs": [],
462 | "source": [
463 | "# Build the training set\n",
464 | "\n",
465 | "X_train = []\n",
466 | "y_train = []\n",
467 | "# Draw 100 samples from each of the 17 CMIP5 models\n",
468 | "for model_i in range(17):\n",
469 | "    samples = np.random.choice(cmip5_trains.shape[1]-12, size=100)\n",
470 | "    for ind in samples:\n",
471 | "        X_train.append(cmip5_trains[model_i, ind: ind+12])\n",
472 | "        y_train.append(cmip5_labels[model_i][ind: ind+24])\n",
473 | "# Draw 100 samples from each of the 15 CMIP6 models\n",
474 | "for model_i in range(15):\n",
475 | "    samples = np.random.choice(cmip6_trains.shape[1]-12, size=100)\n",
476 | "    for ind in samples:\n",
477 | "        X_train.append(cmip6_trains[model_i, ind: ind+12])\n",
478 | "        y_train.append(cmip6_labels[model_i][ind: ind+24])\n",
479 | "X_train = np.array(X_train)\n",
480 | "y_train = np.array(y_train)"
481 | ]
482 | },
483 | {
484 | "cell_type": "code",
485 | "execution_count": 12,
486 | "metadata": {
487 | "execution": {
488 | "iopub.execute_input": "2021-11-07T11:33:38.819204Z",
489 | "iopub.status.busy": "2021-11-07T11:33:38.818020Z",
490 | "iopub.status.idle": "2021-11-07T11:33:38.840229Z",
491 | "shell.execute_reply": "2021-11-07T11:33:38.839737Z",
492 | "shell.execute_reply.started": "2021-11-07T11:23:33.801270Z"
493 | },
494 | "papermill": {
495 | "duration": 0.055138,
496 | "end_time": "2021-11-07T11:33:38.840360",
497 | "exception": false,
498 | "start_time": "2021-11-07T11:33:38.785222",
499 | "status": "completed"
500 | },
501 | "tags": []
502 | },
503 | "outputs": [],
504 | "source": [
505 | "# Build the validation set from 100 random windows of the SODA data\n",
506 | "\n",
507 | "X_valid = []\n",
508 | "y_valid = []\n",
509 | "samples = np.random.choice(soda_trains.shape[1]-12, size=100)\n",
510 | "for ind in samples:\n",
511 | "    X_valid.append(soda_trains[0, ind: ind+12])\n",
512 | "    y_valid.append(soda_labels[0][ind: ind+24])\n",
513 | "X_valid = np.array(X_valid)\n",
514 | "y_valid = np.array(y_valid)"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": 13,
520 | "metadata": {
521 | "execution": {
522 | "iopub.execute_input": "2021-11-07T11:33:38.898632Z",
523 | "iopub.status.busy": "2021-11-07T11:33:38.897899Z",
524 | "iopub.status.idle": "2021-11-07T11:33:38.900564Z",
525 | "shell.execute_reply": "2021-11-07T11:33:38.901077Z",
526 | "shell.execute_reply.started": "2021-11-07T11:23:33.839695Z"
527 | },
528 | "papermill": {
529 | "duration": 0.034011,
530 | "end_time": "2021-11-07T11:33:38.901204",
531 | "exception": false,
532 | "start_time": "2021-11-07T11:33:38.867193",
533 | "status": "completed"
534 | },
535 | "tags": []
536 | },
537 | "outputs": [
538 | {
539 | "data": {
540 | "text/plain": [
541 | "((3200, 12, 24, 72, 4), (3200, 24), (100, 12, 24, 72, 4), (100, 24))"
542 | ]
543 | },
544 | "execution_count": 13,
545 | "metadata": {},
546 | "output_type": "execute_result"
547 | }
548 | ],
549 | "source": [
550 | "# Inspect the dataset shapes\n",
551 | "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
552 | ]
553 | },
554 | {
555 | "cell_type": "code",
556 | "execution_count": 15,
557 | "metadata": {
558 | "execution": {
559 | "iopub.execute_input": "2021-11-07T11:33:39.033073Z",
560 | "iopub.status.busy": "2021-11-07T11:33:39.032319Z",
561 | "iopub.status.idle": "2021-11-07T11:33:40.963743Z",
562 | "shell.execute_reply": "2021-11-07T11:33:40.962736Z",
563 | "shell.execute_reply.started": "2021-11-07T11:23:33.879524Z"
564 | },
565 | "papermill": {
566 | "duration": 1.963702,
567 | "end_time": "2021-11-07T11:33:40.963922",
568 | "exception": false,
569 | "start_time": "2021-11-07T11:33:39.000220",
570 | "status": "completed"
571 | },
572 | "tags": []
573 | },
574 | "outputs": [],
575 | "source": [
576 | "# Save the datasets\n",
577 | "np.save('X_train_sample.npy', X_train)\n",
578 | "np.save('y_train_sample.npy', y_train)\n",
579 | "np.save('X_valid_sample.npy', X_valid)\n",
580 | "np.save('y_valid_sample.npy', y_valid)"
581 | ]
582 | },
583 | {
584 | "cell_type": "markdown",
585 | "metadata": {
586 | "papermill": {
587 | "duration": 0.442674,
588 | "end_time": "2021-11-07T11:33:43.003352",
589 | "exception": false,
590 | "start_time": "2021-11-07T11:33:42.560678",
591 | "status": "completed"
592 | },
593 | "tags": []
594 | },
595 | "source": [
596 | "### Model building"
597 | ]
598 | },
599 | {
600 | "cell_type": "markdown",
601 | "metadata": {},
602 | "source": [
603 | "In this part we focus on the model architecture of this solution."
604 | ]
605 | },
606 | {
607 | "cell_type": "code",
608 | "execution_count": 16,
609 | "metadata": {
610 | "execution": {
611 | "iopub.execute_input": "2021-11-07T11:33:43.228569Z",
612 | "iopub.status.busy": "2021-11-07T11:33:43.223113Z",
613 | "iopub.status.idle": "2021-11-07T11:33:58.100748Z",
614 | "shell.execute_reply": "2021-11-07T11:33:58.101185Z",
615 | "shell.execute_reply.started": "2021-11-07T11:23:42.357343Z"
616 | },
617 | "papermill": {
618 | "duration": 15.039399,
619 | "end_time": "2021-11-07T11:33:58.101351",
620 | "exception": false,
621 | "start_time": "2021-11-07T11:33:43.061952",
622 | "status": "completed"
623 | },
624 | "tags": []
625 | },
626 | "outputs": [],
627 | "source": [
628 | "# Load the datasets\n",
629 | "X_train = np.load('../input/ai-earth-task04-samples/X_train_sample.npy')\n",
630 | "y_train = np.load('../input/ai-earth-task04-samples/y_train_sample.npy')\n",
631 | "X_valid = np.load('../input/ai-earth-task04-samples/X_valid_sample.npy')\n",
632 | "y_valid = np.load('../input/ai-earth-task04-samples/y_valid_sample.npy')"
633 | ]
634 | },
635 | {
636 | "cell_type": "code",
637 | "execution_count": 17,
638 | "metadata": {
639 | "execution": {
640 | "iopub.execute_input": "2021-11-07T11:33:58.162900Z",
641 | "iopub.status.busy": "2021-11-07T11:33:58.161742Z",
642 | "iopub.status.idle": "2021-11-07T11:33:58.164817Z",
643 | "shell.execute_reply": "2021-11-07T11:33:58.165196Z",
644 | "shell.execute_reply.started": "2021-11-07T11:23:56.303716Z"
645 | },
646 | "papermill": {
647 | "duration": 0.036534,
648 | "end_time": "2021-11-07T11:33:58.165327",
649 | "exception": false,
650 | "start_time": "2021-11-07T11:33:58.128793",
651 | "status": "completed"
652 | },
653 | "tags": []
654 | },
655 | "outputs": [
656 | {
657 | "data": {
658 | "text/plain": [
659 | "((3200, 12, 24, 72, 4), (3200, 24), (100, 12, 24, 72, 4), (100, 24))"
660 | ]
661 | },
662 | "execution_count": 17,
663 | "metadata": {},
664 | "output_type": "execute_result"
665 | }
666 | ],
667 | "source": [
668 | "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
669 | ]
670 | },
671 | {
672 | "cell_type": "code",
673 | "execution_count": 18,
674 | "metadata": {
675 | "execution": {
676 | "iopub.execute_input": "2021-11-07T11:33:58.225845Z",
677 | "iopub.status.busy": "2021-11-07T11:33:58.224285Z",
678 | "iopub.status.idle": "2021-11-07T11:33:58.226437Z",
679 | "shell.execute_reply": "2021-11-07T11:33:58.226877Z",
680 | "shell.execute_reply.started": "2021-11-07T11:23:56.313927Z"
681 | },
682 | "papermill": {
683 | "duration": 0.03485,
684 | "end_time": "2021-11-07T11:33:58.227001",
685 | "exception": false,
686 | "start_time": "2021-11-07T11:33:58.192151",
687 | "status": "completed"
688 | },
689 | "tags": []
690 | },
691 | "outputs": [],
692 | "source": [
693 | "# Build the data pipeline\n",
694 | "class AIEarthDataset(Dataset):\n",
695 | "    def __init__(self, data, label):\n",
696 | "        self.data = torch.tensor(data, dtype=torch.float32)\n",
697 | "        self.label = torch.tensor(label, dtype=torch.float32)\n",
698 | "\n",
699 | "    def __len__(self):\n",
700 | "        return len(self.label)\n",
701 | "\n",
702 | "    def __getitem__(self, idx):\n",
703 | "        return self.data[idx], self.label[idx]"
704 | ]
705 | },
706 | {
707 | "cell_type": "code",
708 | "execution_count": 19,
709 | "metadata": {
710 | "execution": {
711 | "iopub.execute_input": "2021-11-07T11:33:58.289091Z",
712 | "iopub.status.busy": "2021-11-07T11:33:58.288549Z",
713 | "iopub.status.idle": "2021-11-07T11:33:59.029333Z",
714 | "shell.execute_reply": "2021-11-07T11:33:59.028812Z",
715 | "shell.execute_reply.started": "2021-11-07T11:23:56.324788Z"
716 | },
717 | "papermill": {
718 | "duration": 0.775212,
719 | "end_time": "2021-11-07T11:33:59.029494",
720 | "exception": false,
721 | "start_time": "2021-11-07T11:33:58.254282",
722 | "status": "completed"
723 | },
724 | "tags": []
725 | },
726 | "outputs": [],
727 | "source": [
728 | "batch_size = 32\n",
729 | "\n",
730 | "trainset = AIEarthDataset(X_train, y_train)\n",
731 | "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
732 | "\n",
733 | "validset = AIEarthDataset(X_valid, y_valid)\n",
734 | "validloader = DataLoader(validset, batch_size=batch_size, shuffle=True)"
735 | ]
736 | },
737 | {
738 | "cell_type": "markdown",
739 | "metadata": {},
740 | "source": [
741 | "#### Evaluation function"
742 | ]
743 | },
744 | {
745 | "cell_type": "code",
746 | "execution_count": 24,
747 | "metadata": {
748 | "execution": {
749 | "iopub.execute_input": "2021-11-07T11:33:59.402757Z",
750 | "iopub.status.busy": "2021-11-07T11:33:59.401182Z",
751 | "iopub.status.idle": "2021-11-07T11:33:59.404701Z",
752 | "shell.execute_reply": "2021-11-07T11:33:59.405088Z",
753 | "shell.execute_reply.started": "2021-11-07T11:23:57.219014Z"
754 | },
755 | "papermill": {
756 | "duration": 0.037919,
757 | "end_time": "2021-11-07T11:33:59.405211",
758 | "exception": false,
759 | "start_time": "2021-11-07T11:33:59.367292",
760 | "status": "completed"
761 | },
762 | "tags": []
763 | },
764 | "outputs": [],
765 | "source": [
766 | "def rmse(y_true, y_preds):\n",
767 | "    return np.sqrt(mean_squared_error(y_pred = y_preds, y_true = y_true))\n",
768 | "\n",
769 | "# Evaluation function: 2/3 * accskill - cumulative RMSE\n",
770 | "def score(y_true, y_preds):\n",
771 | "    # Correlation skill score\n",
772 | "    accskill_score = 0\n",
773 | "    # Cumulative RMSE\n",
774 | "    rmse_scores = 0\n",
775 | "    a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6\n",
776 | "    y_true_mean = np.mean(y_true, axis=0)\n",
777 | "    y_pred_mean = np.mean(y_preds, axis=0)\n",
778 | "    for i in range(24):\n",
779 | "        fenzi = np.sum((y_true[:, i] - y_true_mean[i]) * (y_preds[:, i] - y_pred_mean[i]))  # numerator of cor_i\n",
780 | "        fenmu = np.sqrt(np.sum((y_true[:, i] - y_true_mean[i])**2) * np.sum((y_preds[:, i] - y_pred_mean[i])**2))  # denominator of cor_i\n",
781 | "        cor_i = fenzi / fenmu\n",
782 | "        accskill_score += a[i] * np.log(i+1) * cor_i\n",
783 | "        rmse_score = rmse(y_true[:, i], y_preds[:, i])\n",
784 | "        rmse_scores += rmse_score\n",
785 | "    return 2/3.0 * accskill_score - rmse_scores"
786 | ]
787 | },
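The lead-time weighting inside `score` can be inspected on its own. The sketch below is pure Python and uses only the coefficients that appear in the function above; the helper name `lead_time_weights` is ours and is not part of the notebook. It shows that the first lead month contributes nothing to the correlation term (because ln 1 = 0) and that later months carry most of the weight.

```python
import math

def lead_time_weights(n_months=24):
    """Return a[i] * ln(i + 1) for each lead month i, as used in score()."""
    a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6
    return [a[i] * math.log(i + 1) for i in range(n_months)]

weights = lead_time_weights()

# Upper bound of the correlation term: every monthly correlation equals 1.
max_accskill = 2 / 3.0 * sum(weights)

print(round(weights[0], 3))    # weight of lead month 1
print(round(weights[23], 3))   # weight of lead month 24
print(round(max_accskill, 3))  # largest possible value of the correlation term
```

Because the RMSE sum is subtracted from this term, a model with weak correlations at long lead times can easily end up with a negative score.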
788 | {
789 | "cell_type": "markdown",
790 | "metadata": {},
791 | "source": [
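792 | "#### Model construction\n",
793 | "\n",
794 | "This top-ranked solution chains a TCN unit and a CNN unit in series to form a TCNN layer. Stacking several TCNN layers alternately extracts temporal and spatial information, and an RNN then distills the extracted spatio-temporal information into feature representations at three different time scales."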
795 | ]
796 | },
797 | {
798 | "cell_type": "markdown",
799 | "metadata": {},
800 | "source": [
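801 | "- **TCN unit**\n",
802 | "\n",
803 | "TCN stands for Temporal Convolutional Network. Like an RNN, it is a sequence model. TCN builds on CNNs and adapts them to sequence problems through the following three changes:\n",
804 | "\n",
805 | "1. Causal convolution\n",
806 | "\n",
807 | "TCN handles sequence problems whose input and output have the same length: every hidden layer has as many nodes as there are input time steps, and the value of a hidden node at time t depends only on the nodes of the previous layer at time t and earlier. In other words, TCN obtains the current result by tracing back its causes (the values at time t and before), which is why this is called causal convolution.\n",
808 | "\n",
809 | "2. Dilated convolution\n",
810 | "\n",
811 | "The receptive field of a conventional CNN is limited by the kernel size and is usually enlarged by adding pooling layers, but pooling loses information. To address this, TCN uses dilated convolution to enlarge the receptive field and capture longer-range temporal information. Dilated convolution samples the input at intervals controlled by the dilation factor d, and is defined as:\n",
812 | "$$\n",
813 | "F(s) = (X *_{d} f)(s) = \\sum_{i=0}^{k-1} f(i) \\cdot X_{s - d \\cdot i}\n",
814 | "$$\n",
815 | "where X is the input to the current layer, f is the convolution kernel, k is the kernel size, and s is the time step of the current node. That is, a hidden layer with dilation factor d and kernel size k samples the previous layer's output once every d points, taking k points in total as the input at time s. The receptive field of a TCN is therefore determined jointly by the kernel size k and the dilation factor d, which lets it capture longer-range dependencies.\n",
816 | "\n",
817 | "(Figure: dilated convolution, Task4-扩张卷积.png)\n",
818 | "\n",
819 | "3. Residual connection\n",
820 | "\n",
821 | "The more layers a network has, the richer the features it can extract, but depth also brings vanishing or exploding gradients. Residual connections are currently one effective remedy. A TCN residual block contains two convolutional layers and uses WeightNorm and Dropout for regularization, as shown in the figure below.\n",
822 | "\n",
823 | "(Figure: residual connection, Task4-残差连接.png)\n",
824 | "\n",
825 | "Overall, TCN adapts convolution to sequence problems. It keeps the small parameter count of CNNs, so deeper networks can be built, and it is less prone to vanishing and exploding gradients than RNNs. Its receptive field is also flexible and can be adjusted to the task. Comparisons on many datasets suggest that TCN often outperforms sequence models such as RNN, LSTM, and GRU.\n",
826 | "\n",
827 | "For a deeper look at TCN, see:\n",
828 | " \n",
829 | " - Paper: https://arxiv.org/pdf/1803.01271.pdf\n",
830 | " - GitHub: https://github.com/locuslab/tcn\n",
831 | " \n",
832 | "The TCN unit built in this solution is not a standard TCN layer. Its structure is shown in the figure below: it uses only a single convolutional layer, with BatchNormalization applied before and after the convolution to improve generalization. Note that the convolution operates along the time dimension, so the input must be reshaped first, and the output must be reshaped back to (N,T,C,H,W) so that it matches the layers that follow.\n",
833 | "\n",
834 | "(Figure: TCN unit, Task4-TCN单元.png)"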
835 | ]
836 | },
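To make the dilated causal convolution formula concrete, here is a minimal pure-Python sketch that evaluates F(s) directly on a short list. It illustrates the definition only and is not the code used by the solution; the names `dilated_causal_conv` and `receptive_field` are ours, and positions before the start of the sequence are assumed to be zero (left zero-padding).

```python
def dilated_causal_conv(x, f, d):
    """Compute F(s) = sum_{i=0}^{k-1} f[i] * x[s - d*i], with x[t] = 0 for t < 0."""
    k = len(f)
    out = []
    for s in range(len(x)):
        total = 0.0
        for i in range(k):
            t = s - d * i
            if t >= 0:  # implicit left zero-padding keeps the output causal
                total += f[i] * x[t]
        out.append(total)
    return out

def receptive_field(kernel_size, dilations):
    """Receptive field of stacked dilated causal convolutions, one conv per layer."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
f = [1.0, 1.0, 1.0]                   # k = 3, all-ones kernel for easy checking
y = dilated_causal_conv(x, f, d=2)    # y[s] = x[s] + x[s-2] + x[s-4]
print(y)
print(receptive_field(3, [1, 2, 4]))  # three layers with dilations 1, 2, 4
```

The output has the same length as the input, and changing a later input value never alters an earlier output, which is exactly the causal property described above.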
837 | {
838 | "cell_type": "code",
839 | "execution_count": 20,
840 | "metadata": {
841 | "execution": {
842 | "iopub.execute_input": "2021-11-07T11:33:59.094709Z",
843 | "iopub.status.busy": "2021-11-07T11:33:59.094119Z",
844 | "iopub.status.idle": "2021-11-07T11:33:59.097687Z",
845 | "shell.execute_reply": "2021-11-07T11:33:59.097168Z",
846 | "shell.execute_reply.started": "2021-11-07T11:23:57.153300Z"
847 | },
848 | "papermill": {
849 | "duration": 0.039993,
850 | "end_time": "2021-11-07T11:33:59.097803",
851 | "exception": false,
852 | "start_time": "2021-11-07T11:33:59.057810",
853 | "status": "completed"
854 | },
855 | "tags": []
856 | },
857 | "outputs": [],
858 | "source": [
859 | "# Build the TCN unit\n",
860 | "class TCNBlock(nn.Module):\n",
861 | "    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):\n",
862 | "        super().__init__()\n",
863 | "        self.bn1 = nn.BatchNorm1d(in_channels)\n",
864 | "        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)\n",
865 | "        self.bn2 = nn.BatchNorm1d(out_channels)\n",
866 | "\n",
867 | "        if in_channels == out_channels and stride == 1:\n",
868 | "            self.res = lambda x: x\n",
869 | "        else:\n",
870 | "            self.res = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride)\n",
871 | "\n",
872 | "    def forward(self, x):\n",
873 | "        # Reshape the input: (N,T,C,H,W) -> (N*H*W, C, T)\n",
874 | "        N, T, C, H, W = x.shape\n",
875 | "        x = x.permute(0, 3, 4, 2, 1).contiguous()\n",
876 | "        x = x.view(N*H*W, C, T)\n",
877 | "\n",
878 | "        # Residual branch\n",
879 | "        res = self.res(x)\n",
880 | "        res = self.bn2(res)\n",
881 | "\n",
882 | "        x = F.relu(self.bn1(x))\n",
883 | "        x = self.conv(x)\n",
884 | "        x = self.bn2(x)\n",
885 | "\n",
886 | "        x = x + res\n",
887 | "\n",
888 | "        # Reshape the output back to (N,T,C,H,W)\n",
889 | "        _, C_new, T_new = x.shape\n",
890 | "        x = x.view(N, H, W, C_new, T_new)\n",
891 | "        x = x.permute(0, 4, 3, 1, 2).contiguous()\n",
892 | "\n",
893 | "        return x"
894 | ]
895 | },
896 | {
897 | "cell_type": "markdown",
898 | "metadata": {},
899 | "source": [
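900 | "- **CNN unit**\n",
901 | "\n",
902 | "The CNN unit is structured like the TCN unit: it has a single convolutional layer and uses BatchNormalization to improve generalization. Like the TCN unit, it also includes a residual connection. Its structure is shown in the figure below:\n",
903 | "\n",
904 | "(Figure: CNN unit, Task4-CNN单元.png)"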
905 | ]
906 | },
907 | {
908 | "cell_type": "code",
909 | "execution_count": 21,
910 | "metadata": {
911 | "execution": {
912 | "iopub.execute_input": "2021-11-07T11:33:59.160481Z",
913 | "iopub.status.busy": "2021-11-07T11:33:59.158973Z",
914 | "iopub.status.idle": "2021-11-07T11:33:59.163197Z",
915 | "shell.execute_reply": "2021-11-07T11:33:59.163610Z",
916 | "shell.execute_reply.started": "2021-11-07T11:23:57.170593Z"
917 | },
918 | "papermill": {
919 | "duration": 0.038796,
920 | "end_time": "2021-11-07T11:33:59.163736",
921 | "exception": false,
922 | "start_time": "2021-11-07T11:33:59.124940",
923 | "status": "completed"
924 | },
925 | "tags": []
926 | },
927 | "outputs": [],
928 | "source": [
929 | "# Build the CNN unit\n",
930 | "class CNNBlock(nn.Module):\n",
931 | "    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):\n",
932 | "        super().__init__()\n",
933 | "        self.bn1 = nn.BatchNorm2d(in_channels)\n",
934 | "        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)\n",
935 | "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
936 | "\n",
937 | "        if (in_channels == out_channels) and (stride == 1):\n",
938 | "            self.res = lambda x: x\n",
939 | "        else:\n",
940 | "            self.res = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)\n",
941 | "\n",
942 | "    def forward(self, x):\n",
943 | "        # Reshape the input: (N,T,C,H,W) -> (N*T, C, H, W)\n",
944 | "        N, T, C, H, W = x.shape\n",
945 | "        x = x.view(N*T, C, H, W)\n",
946 | "\n",
947 | "        # Residual branch\n",
948 | "        res = self.res(x)\n",
949 | "        res = self.bn2(res)\n",
950 | "\n",
951 | "        x = F.relu(self.bn1(x))\n",
952 | "        x = self.conv(x)\n",
953 | "        x = self.bn2(x)\n",
954 | "\n",
955 | "        x = x + res\n",
956 | "\n",
957 | "        # Reshape the output back to (N,T,C,H,W)\n",
958 | "        _, C_new, H_new, W_new = x.shape\n",
959 | "        x = x.view(N, T, C_new, H_new, W_new)\n",
960 | "\n",
961 | "        return x"
962 | ]
963 | },
964 | {
965 | "cell_type": "markdown",
966 | "metadata": {},
967 | "source": [
968 | "- **TCNN layer**\n",
969 | "\n",
970 | "Connecting a TCN unit and a CNN unit in series forms one TCNN layer.\n",
971 | "\n",
972 | "(Figure: TCNN layer, Task4-TCNN层.png)"
973 | ]
974 | },
975 | {
976 | "cell_type": "code",
977 | "execution_count": 22,
978 | "metadata": {
979 | "execution": {
980 | "iopub.execute_input": "2021-11-07T11:33:59.223231Z",
981 | "iopub.status.busy": "2021-11-07T11:33:59.222402Z",
982 | "iopub.status.idle": "2021-11-07T11:33:59.224980Z",
983 | "shell.execute_reply": "2021-11-07T11:33:59.224529Z",
984 | "shell.execute_reply.started": "2021-11-07T11:23:57.182192Z"
985 | },
986 | "papermill": {
987 | "duration": 0.034509,
988 | "end_time": "2021-11-07T11:33:59.225082",
989 | "exception": false,
990 | "start_time": "2021-11-07T11:33:59.190573",
991 | "status": "completed"
992 | },
993 | "tags": []
994 | },
995 | "outputs": [],
996 | "source": [
997 | "class TCNNBlock(nn.Module):\n",
998 | "    def __init__(self, in_channels, out_channels, kernel_size, stride_tcn, stride_cnn, padding):\n",
999 | "        super().__init__()\n",
1000 | "        self.tcn = TCNBlock(in_channels, out_channels, kernel_size, stride_tcn, padding)\n",
1001 | "        self.cnn = CNNBlock(out_channels, out_channels, kernel_size, stride_cnn, padding)\n",
1002 | "\n",
1003 | "    def forward(self, x):\n",
1004 | "        x = self.tcn(x)\n",
1005 | "        x = self.cnn(x)\n",
1006 | "        return x"
1007 | ]
1008 | },
1009 | {
1010 | "cell_type": "markdown",
1011 | "metadata": {},
1012 | "source": [
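1013 | "- **TCNN+RNN model**\n",
1014 | "\n",
1015 | "The overall model structure is shown in the figure below:\n",
1016 | "\n",
1017 | "(Figure: TCNN+RNN model, Task4-TCNN+RNN模型.png)\n",
1018 | "\n",
1019 | "1. TCNN part\n",
1020 | "\n",
1021 | "The TCNN part resembles a conventional CNN: it is very regular and extracts richer feature representations by gradually increasing the number of channels. Note that the input data has the layout (N,T,H,W,C); to match the input format of the convolutional layers it must be converted to (N,T,C,H,W).\n",
1022 | "\n",
1023 | "2. GAP layer\n",
1024 | "\n",
1025 | "GAP stands for Global Average Pooling. It averages the feature map of each channel globally: if the output of the TCNN part has shape (N,T,C,H,W), the GAP layer averages all values in each channel's H×W feature map, so the output shape becomes (N,T,C). GAP first appeared in the paper Network in Network (https://arxiv.org/pdf/1312.4400.pdf ), where it replaces the fully connected layers of a conventional CNN, and many later experiments have shown that GAP can indeed improve CNN performance.\n",
1026 | "\n",
1027 | "Why can GAP replace a fully connected layer? In a conventional CNN, after several convolution and pooling layers a Flatten layer stretches the feature maps into a single vector, which is then fed to a fully connected layer. For one sample of shape (C,H,W), the flattened length is C×H×W; if the fully connected layer has U nodes, it has C×H×W×U weights, and such a large parameter count easily leads to overfitting. GAP, by contrast, introduces no new parameters, so it reduces the risk of overfitting, and having fewer parameters also speeds up training. In addition, a fully connected layer is a black box: it is hard to explain how class information propagates back to the convolutional layers. GAP is easier to interpret: the value of each channel represents the feature extracted by the stack of convolutions. For a more detailed discussion, see https://www.zhihu.com/question/373188099\n",
1028 | "\n",
1029 | "PyTorch has no built-in layer named GAP, so adaptive_avg_pool2d can be used instead. This function pools a feature map down to a given output shape, and setting the output_size parameter to (1,1) is equivalent to GAP. For details on its usage, see https://pytorch.org/docs/stable/generated/torch.nn.functional.adaptive_avg_pool2d.html?highlight=adaptive_avg_pool2d#torch.nn.functional.adaptive_avg_pool2d\n",
1030 | "\n",
1031 | "3. RNN part\n",
1032 | "\n",
1033 | "So far we have used time series of length 12, where each time step carries one month of information. Series at different time scales carry somewhat different information. For example, if a year of SST values is represented by a series of length 6, each time step stands for two months of SST, and the annual SST trend reflected by this series is not exactly the same as that of the length-12 series. To mine as much information as possible, this top-ranked solution uses a MaxPool layer to obtain sequences at three different time scales and an RNN layer to extract a feature representation of each sequence. RNNs are well suited to automatic feature extraction from linear sequences: for an input of shape (T,C1), an RNN layer with C2 nodes extracts a vector of length C2, and because the RNN passes information linearly from front to back, the extracted vector captures the dependencies within the sequence well.\n",
1034 | "\n",
1035 | "At this point each of the three time scales is represented by one vector. Concatenating these vectors and passing them through a fully connected layer yields the predicted sequence for 24 months."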
1036 | ]
1037 | },
1038 | {
1039 | "cell_type": "code",
1040 | "execution_count": 23,
1041 | "metadata": {
1042 | "execution": {
1043 | "iopub.execute_input": "2021-11-07T11:33:59.339242Z",
1044 | "iopub.status.busy": "2021-11-07T11:33:59.337656Z",
1045 | "iopub.status.idle": "2021-11-07T11:33:59.339842Z",
1046 | "shell.execute_reply": "2021-11-07T11:33:59.340257Z",
1047 | "shell.execute_reply.started": "2021-11-07T11:23:57.199137Z"
1048 | },
1049 | "papermill": {
1050 | "duration": 0.088367,
1051 | "end_time": "2021-11-07T11:33:59.340377",
1052 | "exception": false,
1053 | "start_time": "2021-11-07T11:33:59.252010",
1054 | "status": "completed"
1055 | },
1056 | "tags": []
1057 | },
1058 | "outputs": [],
1059 | "source": [
1060 | "# Build the model\n",
1061 | "class Model(nn.Module):\n",
1062 | "    def __init__(self):\n",
1063 | "        super().__init__()\n",
1064 | "        self.conv = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3)\n",
1065 | "        self.tcnn1 = TCNNBlock(64, 64, 3, 1, 1, 1)\n",
1066 | "        self.tcnn2 = TCNNBlock(64, 128, 3, 1, 2, 1)\n",
1067 | "        self.tcnn3 = TCNNBlock(128, 128, 3, 1, 1, 1)\n",
1068 | "        self.tcnn4 = TCNNBlock(128, 256, 3, 1, 2, 1)\n",
1069 | "        self.tcnn5 = TCNNBlock(256, 256, 3, 1, 1, 1)\n",
1070 | "        self.rnn = nn.RNN(256, 256, batch_first=True)\n",
1071 | "        self.maxpool = nn.MaxPool1d(2)\n",
1072 | "        self.fc = nn.Linear(256*3, 24)\n",
1073 | "\n",
1074 | "    def forward(self, x):\n",
1075 | "        # Reshape the input: (N,T,H,W,C) -> (N*T, C, H, W)\n",
1076 | "        N, T, H, W, C = x.shape\n",
1077 | "        x = x.permute(0, 1, 4, 2, 3).contiguous()\n",
1078 | "        x = x.view(N*T, C, H, W)\n",
1079 | "\n",
1080 | "        # Initial convolutional layer\n",
1081 | "        x = self.conv(x)\n",
1082 | "        _, C_new, H_new, W_new = x.shape\n",
1083 | "        x = x.view(N, T, C_new, H_new, W_new)\n",
1084 | "\n",
1085 | "        # TCNN part\n",
1086 | "        for i in range(3):\n",
1087 | "            x = self.tcnn1(x)\n",
1088 | "        x = self.tcnn2(x)\n",
1089 | "        for i in range(2):\n",
1090 | "            x = self.tcnn3(x)\n",
1091 | "        x = self.tcnn4(x)\n",
1092 | "        for i in range(2):\n",
1093 | "            x = self.tcnn5(x)\n",
1094 | "\n",
1095 | "        # Global average pooling\n",
1096 | "        x = F.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)  # (N,T,C,1,1) -> (N,T,C); batch dim kept even when N == 1\n",
1097 | "\n",
1098 | "        # RNN part: extract feature representations at three time scales of length T, T/2 and T/4; note the layout changes around the RNN output\n",
1099 | "        hidden_state = []\n",
1100 | "        for i in range(3):\n",
1101 | "            x, h = self.rnn(x)\n",
1102 | "            h = h.squeeze(0)  # (1,N,256) -> (N,256)\n",
1103 | "            hidden_state.append(h)\n",
1104 | "            x = self.maxpool(x.transpose(1, 2)).transpose(1, 2)\n",
1105 | "\n",
1106 | "        x = torch.cat(hidden_state, dim=1)\n",
1107 | "        x = self.fc(x)\n",
1108 | "\n",
1109 | "        return x"
1110 | ]
1111 | },
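The shape bookkeeping of `Model` can be checked without PyTorch using the standard convolution output-size formula. The sketch below is a rough trace under stated assumptions: the input is taken to be T = 12 months on a 24 × 72 grid (these sizes are our assumption for illustration; this section of the notebook does not state them), and `U` is a hypothetical hidden size used only to compare GAP against a Flatten plus fully connected head. The helper names `conv_out` and `pool_out` are ours.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one axis (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel):
    """Output length of MaxPool1d(kernel) with its default stride (= kernel)."""
    return (size - kernel) // kernel + 1

# Assumed input sizes (hypothetical): 12 time steps on a 24 x 72 grid.
T, H, W = 12, 24, 72

# Stem convolution: kernel 7, stride 2, padding 3.
H, W = conv_out(H, 7, 2, 3), conv_out(W, 7, 2, 3)

# CNN strides of the TCNN layers in the order forward() applies them:
# tcnn1 x3, tcnn2, tcnn3 x2, tcnn4, tcnn5 x2 (kernel 3, padding 1 throughout).
for stride_cnn in [1, 1, 1, 2, 1, 1, 2, 1, 1]:
    T = conv_out(T, 3, 1, 1)  # TCN unit: stride 1 leaves T unchanged
    H, W = conv_out(H, 3, stride_cnn, 1), conv_out(W, 3, stride_cnn, 1)

print((T, H, W))  # feature-map size entering global average pooling

# Sequence lengths seen by the RNN at the three time scales.
scales = [T]
for _ in range(2):
    scales.append(pool_out(scales[-1], 2))
print(scales)

# Weights a Flatten + Linear(U) head would need; GAP itself adds none.
C, U = 256, 256
fc_weights = C * H * W * U
print(fc_weights)
```

Under these assumptions only the two stride-2 CNN units and the stem shrink the spatial grid, while the time axis stays at 12 until the max-pooling in the RNN part halves it twice.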
1112 | {
1113 | "cell_type": "code",
1114 | "execution_count": 25,
1115 | "metadata": {
1116 | "execution": {
1117 | "iopub.execute_input": "2021-11-07T11:33:59.467079Z",
1118 | "iopub.status.busy": "2021-11-07T11:33:59.466474Z",
1119 | "iopub.status.idle": "2021-11-07T11:33:59.507567Z",
1120 | "shell.execute_reply": "2021-11-07T11:33:59.507973Z",
1121 | "shell.execute_reply.started": "2021-11-07T11:23:57.232949Z"
1122 | },
1123 | "papermill": {
1124 | "duration": 0.076245,
1125 | "end_time": "2021-11-07T11:33:59.508101",
1126 | "exception": false,
1127 | "start_time": "2021-11-07T11:33:59.431856",
1128 | "status": "completed"
1129 | },
1130 | "tags": []
1131 | },
1132 | "outputs": [
1133 | {
1134 | "name": "stdout",
1135 | "output_type": "stream",
1136 | "text": [
1137 | "Model(\n",
1138 | "  (conv): Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
1139 | "  (tcnn1): TCNNBlock(\n",
1140 | "    (tcn): TCNBlock(\n",
1141 | "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1142 | "      (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
1143 | "      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1144 | "    )\n",
1145 | "    (cnn): CNNBlock(\n",
1146 | "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1147 | "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1148 | "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1149 | "    )\n",
1150 | "  )\n",
1151 | "  (tcnn2): TCNNBlock(\n",
1152 | "    (tcn): TCNBlock(\n",
1153 | "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1154 | "      (conv): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n",
1155 | "      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1156 | "      (res): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
1157 | "    )\n",
1158 | "    (cnn): CNNBlock(\n",
1159 | "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1160 | "      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
1161 | "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1162 | "      (res): Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 2))\n",
1163 | "    )\n",
1164 | "  )\n",
1165 | "  (tcnn3): TCNNBlock(\n",
1166 | "    (tcn): TCNBlock(\n",
1167 | "      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1168 | "      (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n",
1169 | "      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1170 | "    )\n",
1171 | "    (cnn): CNNBlock(\n",
1172 | "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1173 | "      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1174 | "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1175 | "    )\n",
1176 | "  )\n",
1177 | "  (tcnn4): TCNNBlock(\n",
1178 | "    (tcn): TCNBlock(\n",
1179 | "      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1180 | "      (conv): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n",
1181 | "      (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1182 | "      (res): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n",
1183 | "    )\n",
1184 | "    (cnn): CNNBlock(\n",
1185 | "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1186 | "      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
1187 | "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1188 | "      (res): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2))\n",
1189 | "    )\n",
1190 | "  )\n",
1191 | "  (tcnn5): TCNNBlock(\n",
1192 | "    (tcn): TCNBlock(\n",
1193 | "      (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1194 | "      (conv): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n",
1195 | "      (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1196 | "    )\n",
1197 | "    (cnn): CNNBlock(\n",
1198 | "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1199 | "      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1200 | "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1201 | "    )\n",
1202 | "  )\n",
1203 | "  (rnn): RNN(256, 256, batch_first=True)\n",
1204 | "  (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
1205 | "  (fc): Linear(in_features=768, out_features=24, bias=True)\n",
1206 | ")\n"
1207 | ]
1208 | }
1209 | ],
1210 | "source": [
1211 | "model = Model()\n",
1212 | "print(model)"
1213 | ]
1214 | },
1215 | {
1216 | "cell_type": "markdown",
1217 | "metadata": {},
1218 | "source": [
1219 | "#### Model training"
1220 | ]
1221 | },
1222 | {
1223 | "cell_type": "code",
1224 | "execution_count": 26,
1225 | "metadata": {
1226 | "execution": {
1227 | "iopub.execute_input": "2021-11-07T11:33:59.569324Z",
1228 | "iopub.status.busy": "2021-11-07T11:33:59.568443Z",
1229 | "iopub.status.idle": "2021-11-07T11:33:59.570995Z",
1230 | "shell.execute_reply": "2021-11-07T11:33:59.570559Z",
1231 | "shell.execute_reply.started": "2021-11-07T11:23:59.196182Z"
1232 | },
1233 | "papermill": {
1234 | "duration": 0.035076,
1235 | "end_time": "2021-11-07T11:33:59.571104",
1236 | "exception": false,
1237 | "start_time": "2021-11-07T11:33:59.536028",
1238 | "status": "completed"
1239 | },
1240 | "tags": []
1241 | },
1242 | "outputs": [],
1243 | "source": [
1244 | "# Use RMSE as the loss function: per-month RMSE over the batch, summed over the 24 months\n",
1245 | "def RMSELoss(y_pred, y_true):\n",
1246 | "    loss = torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()\n",
1247 | "    return loss"
1248 | ]
1249 | },
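A small worked example may help clarify what `RMSELoss` computes: the squared error is averaged over the batch dimension for each output column, the square root is taken per column, and the columns are then summed. The pure-Python function below mirrors that arithmetic on nested lists; `rmse_loss` is our illustrative name, not a function from the notebook.

```python
import math

def rmse_loss(y_pred, y_true):
    """Sum over columns (lead months) of the RMSE computed across the batch."""
    n_rows, n_cols = len(y_pred), len(y_pred[0])
    total = 0.0
    for j in range(n_cols):
        mse_j = sum((y_pred[r][j] - y_true[r][j]) ** 2 for r in range(n_rows)) / n_rows
        total += math.sqrt(mse_j)
    return total

# Two samples, two output columns.
y_true = [[0.0, 0.0], [0.0, 0.0]]
y_pred = [[3.0, 1.0], [4.0, 1.0]]

loss = rmse_loss(y_pred, y_true)
print(loss)  # sqrt((9 + 16) / 2) + sqrt((1 + 1) / 2)
```

Because the per-column values are summed rather than averaged, the loss for 24 output months is roughly 24 times a typical single-month RMSE, which is worth keeping in mind when reading the training log.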
1250 | {
1251 | "cell_type": "code",
1252 | "execution_count": 27,
1253 | "metadata": {
1254 | "execution": {
1255 | "iopub.execute_input": "2021-11-07T11:33:59.635947Z",
1256 | "iopub.status.busy": "2021-11-07T11:33:59.635133Z",
1257 | "iopub.status.idle": "2021-11-07T11:41:54.987872Z",
1258 | "shell.execute_reply": "2021-11-07T11:41:54.987159Z",
1259 | "shell.execute_reply.started": "2021-11-07T11:24:00.777235Z"
1260 | },
1261 | "papermill": {
1262 | "duration": 475.387685,
1263 | "end_time": "2021-11-07T11:41:54.988052",
1264 | "exception": false,
1265 | "start_time": "2021-11-07T11:33:59.600367",
1266 | "status": "completed"
1267 | },
1268 | "tags": []
1269 | },
1270 | "outputs": [
1271 | {
1272 | "name": "stdout",
1273 | "output_type": "stream",
1274 | "text": [
1275 | "Epoch: 1/10\n"
1276 | ]
1277 | },
1278 | {
1279 | "name": "stderr",
1280 | "output_type": "stream",
1281 | "text": [
1282 | "100%|██████████| 100/100 [00:51<00:00, 1.95it/s]\n"
1283 | ]
1284 | },
1285 | {
1286 | "name": "stdout",
1287 | "output_type": "stream",
1288 | "text": [
1289 | "Training Loss: 18.099\n"
1290 | ]
1291 | },
1292 | {
1293 | "name": "stderr",
1294 | "output_type": "stream",
1295 | "text": [
1296 | "4it [00:00, 6.19it/s]\n"
1297 | ]
1298 | },
1299 | {
1300 | "name": "stdout",
1301 | "output_type": "stream",
1302 | "text": [
1303 | "Validation Loss: 16.756\n",
1304 | "Score: -4.320\n",
1305 | "Epoch: 2/10\n"
1306 | ]
1307 | },
1308 | {
1309 | "name": "stderr",
1310 | "output_type": "stream",
1311 | "text": [
1312 | "100%|██████████| 100/100 [00:45<00:00, 2.17it/s]\n"
1313 | ]
1314 | },
1315 | {
1316 | "name": "stdout",
1317 | "output_type": "stream",
1318 | "text": [
1319 | "Training Loss: 16.955\n"
1320 | ]
1321 | },
1322 | {
1323 | "name": "stderr",
1324 | "output_type": "stream",
1325 | "text": [
1326 | "4it [00:00, 6.41it/s]\n"
1327 | ]
1328 | },
1329 | {
1330 | "name": "stdout",
1331 | "output_type": "stream",
1332 | "text": [
1333 | "Validation Loss: 17.657\n",
1334 | "Score: -32.332\n",
1335 | "Epoch: 3/10\n"
1336 | ]
1337 | },
1338 | {
1339 | "name": "stderr",
1340 | "output_type": "stream",
1341 | "text": [
1342 | "100%|██████████| 100/100 [00:45<00:00, 2.18it/s]\n"
1343 | ]
1344 | },
1345 | {
1346 | "name": "stdout",
1347 | "output_type": "stream",
1348 | "text": [
1349 | "Training Loss: 16.639\n"
1350 | ]
1351 | },
1352 | {
1353 | "name": "stderr",
1354 | "output_type": "stream",
1355 | "text": [
1356 | "4it [00:00, 6.29it/s]\n"
1357 | ]
1358 | },
1359 | {
1360 | "name": "stdout",
1361 | "output_type": "stream",
1362 | "text": [
1363 | "Validation Loss: 19.156\n",
1364 | "Score: -25.483\n",
1365 | "Epoch: 4/10\n"
1366 | ]
1367 | },
1368 | {
1369 | "name": "stderr",
1370 | "output_type": "stream",
1371 | "text": [
1372 | "100%|██████████| 100/100 [00:45<00:00, 2.17it/s]\n"
1373 | ]
1374 | },
1375 | {
1376 | "name": "stdout",
1377 | "output_type": "stream",
1378 | "text": [
1379 | "Training Loss: 16.173\n"
1380 | ]
1381 | },
1382 | {
1383 | "name": "stderr",
1384 | "output_type": "stream",
1385 | "text": [
1386 | "4it [00:00, 6.29it/s]\n"
1387 | ]
1388 | },
1389 | {
1390 | "name": "stdout",
1391 | "output_type": "stream",
1392 | "text": [
1393 | "Validation Loss: 18.130\n",
1394 | "Score: -15.470\n",
1395 | "Epoch: 5/10\n"
1396 | ]
1397 | },
1398 | {
1399 | "name": "stderr",
1400 | "output_type": "stream",
1401 | "text": [
1402 | "100%|██████████| 100/100 [00:45<00:00, 2.17it/s]\n"
1403 | ]
1404 | },
1405 | {
1406 | "name": "stdout",
1407 | "output_type": "stream",
1408 | "text": [
1409 | "Training Loss: 15.818\n"
1410 | ]
1411 | },
1412 | {
1413 | "name": "stderr",
1414 | "output_type": "stream",
1415 | "text": [
1416 | "4it [00:00, 6.28it/s]\n"
1417 | ]
1418 | },
1419 | {
1420 | "name": "stdout",
1421 | "output_type": "stream",
1422 | "text": [
1423 | "Validation Loss: 17.367\n",
1424 | "Score: -14.745\n",
1425 | "Epoch: 6/10\n"
1426 | ]
1427 | },
1428 | {
1429 | "name": "stderr",
1430 | "output_type": "stream",
1431 | "text": [
1432 | "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
1433 | ]
1434 | },
1435 | {
1436 | "name": "stdout",
1437 | "output_type": "stream",
1438 | "text": [
1439 | "Training Loss: 15.464\n"
1440 | ]
1441 | },
1442 | {
1443 | "name": "stderr",
1444 | "output_type": "stream",
1445 | "text": [
1446 | "4it [00:00, 6.28it/s]\n"
1447 | ]
1448 | },
1449 | {
1450 | "name": "stdout",
1451 | "output_type": "stream",
1452 | "text": [
1453 | "Validation Loss: 18.289\n",
1454 | "Score: -4.441\n",
1455 | "Epoch: 7/10\n"
1456 | ]
1457 | },
1458 | {
1459 | "name": "stderr",
1460 | "output_type": "stream",
1461 | "text": [
1462 | "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
1463 | ]
1464 | },
1465 | {
1466 | "name": "stdout",
1467 | "output_type": "stream",
1468 | "text": [
1469 | "Training Loss: 15.175\n"
1470 | ]
1471 | },
1472 | {
1473 | "name": "stderr",
1474 | "output_type": "stream",
1475 | "text": [
1476 | "4it [00:00, 6.26it/s]\n"
1477 | ]
1478 | },
1479 | {
1480 | "name": "stdout",
1481 | "output_type": "stream",
1482 | "text": [
1483 | "Validation Loss: 18.604\n",
1484 | "Score: -21.144\n",
1485 | "Epoch: 8/10\n"
1486 | ]
1487 | },
1488 | {
1489 | "name": "stderr",
1490 | "output_type": "stream",
1491 | "text": [
1492 | "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
1493 | ]
1494 | },
1495 | {
1496 | "name": "stdout",
1497 | "output_type": "stream",
1498 | "text": [
1499 | "Training Loss: 15.004\n"
1500 | ]
1501 | },
1502 | {
1503 | "name": "stderr",
1504 | "output_type": "stream",
1505 | "text": [
1506 | "4it [00:00, 6.27it/s]\n"
1507 | ]
1508 | },
1509 | {
1510 | "name": "stdout",
1511 | "output_type": "stream",
1512 | "text": [
1513 | "Validation Loss: 18.593\n",
1514 | "Score: -27.508\n",
1515 | "Epoch: 9/10\n"
1516 | ]
1517 | },
1518 | {
1519 | "name": "stderr",
1520 | "output_type": "stream",
1521 | "text": [
1522 | "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
1523 | ]
1524 | },
1525 | {
1526 | "name": "stdout",
1527 | "output_type": "stream",
1528 | "text": [
1529 | "Training Loss: 14.578\n"
1530 | ]
1531 | },
1532 | {
1533 | "name": "stderr",
1534 | "output_type": "stream",
1535 | "text": [
1536 | "4it [00:00, 6.28it/s]\n"
1537 | ]
1538 | },
1539 | {
1540 | "name": "stdout",
1541 | "output_type": "stream",
1542 | "text": [
1543 | "Validation Loss: 18.264\n",
1544 | "Score: -19.113\n",
1545 | "Epoch: 10/10\n"
1546 | ]
1547 | },
1548 | {
1549 | "name": "stderr",
1550 | "output_type": "stream",
1551 | "text": [
1552 | "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
1553 | ]
1554 | },
1555 | {
1556 | "name": "stdout",
1557 | "output_type": "stream",
1558 | "text": [
1559 | "Training Loss: 14.330\n"
1560 | ]
1561 | },
1562 | {
1563 | "name": "stderr",
1564 | "output_type": "stream",
1565 | "text": [
1566 | "4it [00:00, 6.27it/s]\n"
1567 | ]
1568 | },
1569 | {
1570 | "name": "stdout",
1571 | "output_type": "stream",
1572 | "text": [
1573 | "Validation Loss: 17.739\n",
1574 | "Score: -18.628\n"
1575 | ]
1576 | }
1577 | ],
1578 | "source": [
1579 | "model_weights = './task04_model_weights.pth'\n",
1580 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
1581 | "model = Model().to(device)\n",
1582 | "criterion = RMSELoss\n",
1583 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
1584 | "epochs = 10\n",
1585 | "train_losses, valid_losses = [], []\n",
1586 | "scores = []\n",
1587 | "best_score = float('-inf')\n",
1588 | "preds = np.zeros((len(y_valid),24))\n",
1589 | "\n",
1590 | "for epoch in range(epochs):\n",
1591 | "    print('Epoch: {}/{}'.format(epoch+1, epochs))\n",
1592 | "\n",
1593 | "    # Training\n",
1594 | "    model.train()\n",
1595 | "    losses = 0\n",
1596 | "    for data, labels in tqdm(trainloader):\n",
1597 | "        data = data.to(device)\n",
1598 | "        labels = labels.to(device)\n",
1599 | "        optimizer.zero_grad()\n",
1600 | "        pred = model(data)\n",
1601 | "        loss = criterion(pred, labels)\n",
1602 | "        losses += loss.cpu().detach().numpy()\n",
1603 | "        loss.backward()\n",
1604 | "        optimizer.step()\n",
1605 | "    train_loss = losses / len(trainloader)\n",
1606 | "    train_losses.append(train_loss)\n",
1607 | "    print('Training Loss: {:.3f}'.format(train_loss))\n",
1608 | "\n",
1609 | "    # Validation\n",
1610 | "    model.eval()\n",
1611 | "    losses = 0\n",
1612 | "    with torch.no_grad():\n",
1613 | "        for i, data in tqdm(enumerate(validloader)):\n",
1614 | "            data, labels = data\n",
1615 | "            data = data.to(device)\n",
1616 | "            labels = labels.to(device)\n",
1617 | "            pred = model(data)\n",
1618 | "            loss = criterion(pred, labels)\n",
1619 | "            losses += loss.cpu().detach().numpy()\n",
1620 | "            preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n",
1621 | "    valid_loss = losses / len(validloader)\n",
1622 | "    valid_losses.append(valid_loss)\n",
1623 | "    print('Validation Loss: {:.3f}'.format(valid_loss))\n",
1624 | "    s = score(y_valid, preds)\n",
1625 | "    scores.append(s)\n",
1626 | "    print('Score: {:.3f}'.format(s))\n",
1627 | "\n",
1628 | "    # Save the best model weights\n",
1629 | "    if s > best_score:\n",
1630 | "        best_score = s\n",
1631 | "        checkpoint = {'best_score': s,\n",
1632 | "                      'state_dict': model.state_dict()}\n",
1633 | "        torch.save(checkpoint, model_weights)"
1634 | ]
1635 | },
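The validation loop writes each batch of predictions into `preds` by batch position and then scores the result against `y_valid`. That only lines up if the loader yields samples in dataset order. The toy sketch below, in pure Python with made-up data, illustrates the point: `collect_by_position` and the stand-in "model" `perfect` are hypothetical names, and a reversed index list stands in for a shuffled loader.

```python
def collect_by_position(order, batch_size, predict):
    """Mimic the validation loop: write batch i into slots [i*bs, (i+1)*bs)."""
    preds = [None] * len(order)
    for i in range(0, len(order), batch_size):
        batch = order[i:i + batch_size]  # sample indices yielded by the loader
        preds[i:i + len(batch)] = [predict(idx) for idx in batch]
    return preds

labels = [float(v) for v in range(10)]  # toy targets, one per sample

def perfect(idx):
    """A stand-in model that always returns the correct target."""
    return labels[idx]

in_order = list(range(10))
reordered = in_order[::-1]  # stands in for a shuffled loader

aligned = collect_by_position(in_order, 4, perfect)
misaligned = collect_by_position(reordered, 4, perfect)

print(aligned == labels)     # ordered loader: predictions match the labels
print(misaligned == labels)  # reordered loader: same values, wrong rows
```

Even a perfect model looks wrong when its predictions are stored in a different order from the reference labels, so a position-based buffer like `preds` should be paired with a loader that does not reorder the validation set.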
1636 | {
1637 | "cell_type": "code",
1638 | "execution_count": 28,
1639 | "metadata": {
1640 | "execution": {
1641 | "iopub.execute_input": "2021-11-07T11:41:55.621271Z",
1642 | "iopub.status.busy": "2021-11-07T11:41:55.620191Z",
1643 | "iopub.status.idle": "2021-11-07T11:41:55.623008Z",
1644 | "shell.execute_reply": "2021-11-07T11:41:55.623387Z",
1645 | "shell.execute_reply.started": "2021-11-07T11:31:56.277287Z"
1646 | },
1647 | "papermill": {
1648 | "duration": 0.31815,
1649 | "end_time": "2021-11-07T11:41:55.623547",
1650 | "exception": false,
1651 | "start_time": "2021-11-07T11:41:55.305397",
1652 | "status": "completed"
1653 | },
1654 | "tags": []
1655 | },
1656 | "outputs": [],
1657 | "source": [
1658 | "# Plot the training/validation curves\n",
1659 | "def training_vis(train_losses, valid_losses):\n",
1660 | "    # Plot the loss curves\n",
1661 | "    fig = plt.figure(figsize=(8,4))\n",
1662 | "    # subplot loss\n",
1663 | "    ax1 = fig.add_subplot(121)\n",
1664 | "    ax1.plot(train_losses, label='train_loss')\n",
1665 | "    ax1.plot(valid_losses, label='val_loss')\n",
1666 | "    ax1.set_xlabel('Epochs')\n",
1667 | "    ax1.set_ylabel('Loss')\n",
1668 | "    ax1.set_title('Loss on Training and Validation Data')\n",
1669 | "    ax1.legend()\n",
1670 | "    plt.tight_layout()"
1671 | ]
1672 | },
1673 | {
1674 | "cell_type": "code",
1675 | "execution_count": 29,
1676 | "metadata": {
1677 | "execution": {
1678 | "iopub.execute_input": "2021-11-07T11:41:56.271045Z",
1679 | "iopub.status.busy": "2021-11-07T11:41:56.270383Z",
1680 | "iopub.status.idle": "2021-11-07T11:41:56.549178Z",
1681 | "shell.execute_reply": "2021-11-07T11:41:56.549615Z",
1682 | "shell.execute_reply.started": "2021-11-07T11:31:56.286373Z"
1683 | },
1684 | "papermill": {
1685 | "duration": 0.611301,
1686 | "end_time": "2021-11-07T11:41:56.549765",
1687 | "exception": false,
1688 | "start_time": "2021-11-07T11:41:55.938464",
1689 | "status": "completed"
1690 | },
1691 | "tags": []
1692 | },
1693 | "outputs": [
1694 | {
1695 | "data": {
1696 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS8AAAEYCAYAAAANoXDNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA51ElEQVR4nO3dd3hU1dbA4d9KIyGEHnoJNbQQSkBAelGaYENERVRsoGLlil4/9d7rtVcsSBELIIrYBRWpAQUh9E7oRFoSeifJ/v7YgzdAAikzc2aS9T5PHmbOzJy9zsywZp99dhFjDEop5W8CnA5AKaXyQpOXUsovafJSSvklTV5KKb+kyUsp5Zc0eSml/JImr0JKRH4WkUHufq6TRGS7iHT1wH7nisjdrtu3isiMnDw3D+VUE5FjIhKY11gLk0KVvDz15fYW1xf73F+GiJzMdP/W3OzLGNPDGPOpu5/ri0RkhIjEZ7G9rIicEZFGOd2XMWaSMeYqN8V13vfRGLPTGFPMGJPujv1fUJYRkeOu70qqiMwSkf65eH1HEUlyd1z5UaiSl79zfbGLGWOKATuBazJtm3TueSIS5FyUPmki0EZEalyw/WZgtTFmjQMxOSHW9d2JBj4B3hOR55wNKe80eQEiUkRE3haR3a6/t0WkiOuxsiLyk4gcEpEDIjJfRAJcjz0pIn+JyFER2SgiXbLZfwkR+UxEkkVkh4g8k2kfd4jIAhF5XUQOisg2EemRy/g7ikiSK569wMciUsoVd7Jrvz+JSJVMr8l8KnTJGHL53BoiEu96T2aKyPsiMjGbuHMS439E5HfX/maISNlMjw90vZ+pIvLP7N4fY0wSMBsYeMFDtwOfXS6OC2K+Q0QWZLrfTUQ2iMhhEXkPkEyP1RKR2a74UkRkkoiUdD02AagG/OiqDf1DRKJcNaQg13MqicgPru/dZhG5J9O+nxeRKa7v1VERWSsicdm9Bxe8HynGmAnAEOApESnj2uedIrLetb+tInKfa3s48DNQSf5X068kIi1FZKHr/8YeEXlPREJyEoM7aPKy/gm0ApoAsUBL4BnXY48DSUAkUB54GjAiEg08CLQwxkQAVwPbs9n/u0AJoCbQAfuf5s5Mj18BbATKAq8CH4mIXLiTy6gAlAaqA/diP9uPXferASeB9y7x+tzEcKnnfg4sBsoAz3NxwsgsJzHegn2vygEhwBMAItIAGOXafyVXeVkmHJdPM8fi+vyauOLN7Xt1bh9lgW+w35WywBbgysxPAV5yxVcfqIp9TzDGDOT82vOrWRTxBfa7Vwm4EXhRRDpneryP6zklgR9yEvMFvgeCsN93gP1Ab6A49j1/S0SaGWOOAz2A3Zlq+ruBdOBR17G3BroAQ3MZQ94ZYwrNHza5dM1i+xagZ6b7VwPbXbf/jf2Qa1/wmtrYD7srEHyJMgOBM0CDTNvuA+a6bt8BbM70WFHAABVyeixAR1cZoZd4fhPgYKb7c4G7cxJDTp+L/Y+fBhTN9PhEYGIOP5+sYnwm0/2hwC+u288CX2R6LNz1Hlz0+WaK8wjQxnX/v8D3eXyvFrhu3w4syvQ8wSabu7PZ77XA8uy+j0CU670Mwia6dCAi0+MvAZ+4bj8PzMz0WAPg5CXeW8MF32HX9r3Ardm85jvg4UzfsaTLfH6PAN/m5LN2x5/WvKxKwI5M93e4tgG8BmwGZriq0iMAjDGbsR/W88B+EflCRCpxsbJAcBb7r5zp/t5zN4wxJ1w3i+XyGJKNMafO3RGRoiIy2nVadQSIB0pK9leychNDds+tBBzItA1gV3YB5zDGvZlun8gUU6XM+za2dpCaXVmumL4CbnfVEm8FPstFHFm5MAaT+b6IlHd9L/5y7Xci9vuQE+fey6OZtmX7vcG+N6GSi/ZOEQnGnlEccN3vISKLXKeph4Cel4pXROq6TrH3uo7vxUs93900eVm7sacM51RzbcMYc9QY87gxpia2mv6YuNq
2jDGfG2Paul5rgFey2HcKcDaL/f/l5mO4cHqQx7ENs1cYY4oD7V3bc3s6mht7gNIiUjTTtqqXeH5+YtyTed+uMstc5jWfAjcB3YAI4Md8xnFhDML5x/si9nOJce33tgv2eakpXXZj38uITNvc/b3pi60pLxbbxvs18DpQ3hhTEpieKd6sYh0FbADquI7vaTz7/TpPYUxewSISmukvCJgMPCMika52jGexv5KISG8Rqe36Yh7GVuUzRCRaRDq7PvRT2HaSjAsLM/ay9xTgvyISISLVgcfO7d+DIlwxHRKR0oDHryoZY3YACcDzIhIiIq2BazwU41Sgt4i0dTUS/5vLf5/nA4eAMdhTzjP5jGMa0FBErnd9j4ZhT5/PiQCOAYdFpDIw/ILX78O2g17EGLML+AN4yfU9bQwMxg3fGxEpLbZrzfvAK8aYVGx7YhEgGUgTexEmc5eQfUAZESlxwfEdAY6JSD3sBQCvKYzJazr2i3ru73ngBex/ulXAamCZaxtAHWAm9ku4EPjAGDMH+0G/jK1Z7cU2KD+VTZkPAceBrcACbCPxePce1kXeBsJc8S0CfvFweefcim28TcW+h18Cp7N57tvkMUZjzFrgAex7uQc4iG1vutRrDPZUsbrr33zFYYxJAfphvwep2O/K75me8i+gGfZHbxq2cT+zl7A/modE5IksihiAbQfbDXwLPGeMmZmT2LKxUkSOYZtB7gYeNcY86zqWo9jkOwX7Xt6CvQhw7lg3YH/kt7rirYS9eHILcBQYi/2svUZcDW1KeYSIfAlsMMb4bX8i5ZsKY81LeZCItBDbvylARLpj21W+czgsVQBpT2zlbhWwp0dlsKdxQ4wxy50NSRVEetqolPJLetqolPJLfnHaWLZsWRMVFeV0GEopL1u6dGmKMSYyq8f8InlFRUWRkJDgdBhKKS8TkR3ZPaanjUopv6TJSynllzR5KaX8kiYvpZRf0uSllPJLmryUUn5Jk5dSyi9p8vJlB7bC/g1OR6GUT9Lk5auMgcm3wMc94MQBp6NRyudo8vJVO36H5PVw8gDM+pfT0SjlczR5+aol4yC0JMQNhqWfQpIOj1IqM01evujoXlj/IzS9Dbr9CyIqwE+PQobbV4FXym95LHmJyHgR2S8iazJti3WtsLtaRH4UkeKeKt+vLf0UMtIg7i4oEgFXvwh7V8GSj5yOTCmf4cma1ydA9wu2jQNGGGNisAsKXLiaiko/C0s/hlpdoEwtu63hdVCzE8x+AY7uczY+pXyEx5KXMSYe12KWmdTFLugJ8Btwg6fK91sbp8PRPdDi7v9tE4Ger0PaSfjt/5yLTSkf4u02r7XYBRnALhmV7YKkInKviCSISEJycrJXgvMJS8ZBiapQ9+rzt5etDVc+DKu+hO0LnIlNKR/i7eR1FzBURJZiF6w8k90TjTFjjDFxxpi4yMgsJ1IseJI3wrZ4iLsTArJYab7tY1CyGkx73J5eKlWIeTV5GWM2GGOuMsY0xy5gucWb5fu8JR9BYAg0vT3rx0OKQo9XIXkDLPrAu7Ep5WO8mrxEpJzr3wDgGeBDb5bv004fg5WTocG1UOwSNc3oHhDdE+a+DIcvuUC0UgWaJ7tKTAYWAtEikiQig4EBIrIJ2IBdwvxjT5Xvd1ZPgdNHzm+oz073l+3woV+e8nxcSvkojy3AYYwZkM1D73iqTL9ljD1lrBADVVte/vmlqkP7J2D2fyBxJtTp6vkYlfIx2sPeF+z6E/atsbUukZy9ps1DUKY2TH8Czp7ybHxK+SBNXr5g8VgoUgJi+uX8NUFFbN+vg9vg97c9FppSvkqTl9OO7Yd130OTWyAkPHevrdUJGl4P89+0c38pVYho8nLass8g4yy0GJy311/9XwgMhp+ftG1nShUSmryclJ4GCR9DjQ5Qtk7e9lG8EnR6GhJnwIaf3BufUj5Mk5eTEn+FI0nQ8p787aflfVCuIfw8As4cd09sSvk4TV5OWjwWileGuj3yt5/
AIOj1hk2E8151T2xK+ThNXk5J2Qxb50DzO23yya/qraHJrbDwPV20QxUKmryckjAeAoKgWTbjGPOi278hpJjt+6WN96qA81gPe3UJZ07AiolQvw9ElHfffsPLQpdnYdpjsHoqNM5FvzGVc+lnYeUXsHe1vdIbEAgBwfbHKDAo0+0LH8vqftDFtzPfDwyG4lUgQOsZF9Lk5YQ1U+HU4fw31Gel+R2wfCL8+jTUvQpCS7i/jMLqXNKKfxUO7YSQCDAZdsrujLP2tidE1oPO/wf1euV8BEYhoMnL24yxDfXlGkC11u7ff0Cgbbwf2xnmvAg9XnF/GYVNepodOD/vFTi4HSo1hZ5vQJ1u5yeTjHOJzJXMMtJtwrvovmtbeubnZnP/1GFYPBq+vBUqx0HX56BGe8feCl+iycvbkhLsYhq93vDcr2jlZnbxjsVjbM/9irGeKaegy0i3p9/zXoEDW6BCYxjwBdTtnvVnFxAAASFAiHvjiLsLVn5up0H69Bqo1dk2D1Rq6t5y/IyeSHvbknH2dKNxf8+W0+X/IKy0nXU1w0OnMwXVuaT1/hXw7b0QXBT6T4L74u18at4+dQt0Xdh5aBlc9V/YvQLGdIQpt0NKondj8SGavLzpeCqs/QZib7ZLmnlSWCm46j+QtASWT/BsWQVFRgas+QY+aA1fD7aN5Td9ZpNW/d7OtzcFh0KbB+HhldDhSdg8yybY7x8slBNTavLypuWfQfqZvI9jzK3YAVCtDcx8ziZOlbWMDDs4/sMrYeqdNknd+DHc/zs06Ot7V/pCi9shYcNWQMt77aIsI5vBr/8sVJ+zj30qBVhGuu3bFdUOytX3Tpki0Ot1OHUEZj3vnTL9iTGw/icY3c6egqWfhRs+giF/QKPrfS9pXahYJPR4GR5aaqdTWvQBvBMLc1+B00edjs7jfPzTKUASf7OX171V6zqnfENoNcTOXrFriXfL9lXGwMafYXR7exXv7Em4bgw88CfE3Jj1yk2+rGQ1uPZ9GLIQanWEuS/CO01g0ShIO+10dB6jyctbloyDYhWgXm/vl91xBERUtJ1X09O8X76vMAY2zYCxnWDyzXbNgGtHwQOLIba//yWtC5WrB/0nwt2z7Y/WLyPg3eawfJKt+Rcwmry84cBW2DzTdiANDPZ++UUioPtLtotGwkfeL99pxtj3f1xX+LwfnEiFPu/Bgwm2K4k7xpb6kirNYdAPMPA7O+ri+6H2IsT6HwvUsDFNXt6QMB4kAJoPci6GBtdCzU4w+wU4us+5OLzJGNgyBz66CibeAMf2wTXvwINLodlAZ35IvKlWJ7hnDtw0ATDw5W0wrgtsned0ZG6hycvTzp60w3Xq97YTBzpFxM55n3YKZjzjXBzesi0ePu4BE66FI39BrzdtP6nmd0CQmzuR+jIRaNDHtof1fd/+cH3WBz7rC38tczq6fNHk5Wlrv4WTB3O2HqOnla0NVz5sh7psm+90NJ6RugU+6W17oh/cbhP2sOX2QklhSloXCgyCprfZK5NXv2QHlY/tBF8OhORNTkeXJ2L84Bw4Li7OJCQkOB1G3ozpZGc3feBP5zs5gp3R4oMrICgM7l9QsP5Dp5+1YzoP7YCOT9taVnCo01H5ptNHYeH78Md7cPa4bfvr9h8oWtrpyM4jIkuNMXFZPaY1L0/6aynsXpa79Rg9LaQo9HgNUjbafkEFyfw37EWJvu9Dq/s1cV1KkQh7FfrhldBqKKz80l7QSNnsdGQ5psnLk5aMh+Bwexnel0R3h+iedsDxoV1OR+Meu1dA/GsQcxPUv8bpaPxHeBm7AtWgH+HUIdugvy3e6ahyRJOXp5w4YOftanyTb86p1f1lezXu16ecjiT/0k7Dd0OgaFmdAiivqreGu2dBRAWYcB0s/dTpiC5Lk5enrJhkr+z5QkN9VkpVh/ZP2L4/ib85HU3+zH0Z9q+DPiN9rs3Gr5SuAYNn2KX4fhxmx0r6cOdWTV6ekJEBSz6ykw1WaOR0NNl
r8xCUqWPnvD970ulo8iYpAX5/G5rcBnWvdjoa/xdaAm6ZYpfTW/gefHGLz46T1OTlCVtmw8FtvlvrOieoiB24fXC77bzqb86ehG/vh4hK0P1Fp6MpOAKDoOertptJ4m8wvrtPto0WqOSVuO8oA8Ys4q9DDtciloyD8Ei7wIavq9kR4gbbX9lVXzkdTe7MfgFSE6Hvu77ZrujvWt4Dt35lJxQY29nWcn1IgUpeRYsEsXTnQd6c4WCnu4M7YNMv0GyQ//Sh6v6ynffrhwf9p9f1jj9sP6W4u+y0yMozaneBwb/ZLjYf97QzzPoIjyUvERkvIvtFZE2mbU1EZJGIrBCRBBFp6c4yK5cM4842UXyzPIl1u4+4c9c5t/Rj26cr7k5nys+LoBA7Y2h4JHxxq++PfTxzHL4baqeC6fYfp6Mp+MrVszNVVG5uZ5id+7JPDPD2ZM3rE6D7BdteBf5ljGkCPOu671ZDO9ameGgwr/ziwKrRaaftvFnRPaFEFe+Xnx/FIuHmz21fny9v8+15oGY+b9sUr/0AihRzOprCIbwM3P4dxN4Cc1+Cr+92/CKPx5KXMSYeOHDhZqC463YJYLe7yy1RNJgHO9Vm3qZkft+c4u7dX9ra7+x0K96ecNBdKja2CSFpsZ37ywd+XS+ydZ5dFemKIRDV1uloCpegIvb70fV524fxk96O1tK93eb1CPCaiOwCXgey7SEpIve6Ti0TkpOTc1XIwNbVqVwyjJd+Xk9Ghhf/Ay4ZB2VqQ42O3ivT3RpeB+2H25kw/hztdDTnO3XELjZRupZd+kt5nwi0fdROerh/ne2Rv3fN5V/nAd5OXkOAR40xVYFHgWxnxjPGjDHGxBlj4iIjI3NVSGhwIE9cXZc1fx3hx1Vur9xlbc9KW2OJG+z7c59fTsenIbqXXXV761yno/mfGc/AkSQ7+2lIUaejKdzqXwN3/mwXxx1/NWz8xeshePt/2SDgG9ftrwC3Nthn1je2Mg0qFue1XzdyOs0LvYSXjLMzNTQZ4PmyPC0gAK4fDWXrwpRBdiZYpyXOhGWfQusHodoVTkejACo1gXtm27ONyTfbGSq82NTg7eS1G+jgut0Z8NiKmQEBwlM965F08CQTF+30VDHWyUO2j1Tjfna9xIKgSAQM+NzenuxwL+uTh+CHhyCyHnT6p3NxqIsVr2RrYPWvgRn/hB8ftlMTeYEnu0pMBhYC0SKSJCKDgXuAN0RkJfAicK+nygdoVyeSdnXK8u7sRA6f9OAbunIypJ30/R71uVW6JvT7BFI2wTf3Obfy9i8j7BTO147SaW58UUhR6PcptHvc1o4nXm8nJvAwT15tHGCMqWiMCTbGVDHGfGSMWWCMaW6MiTXGXGGMWeqp8s95sns9Dp04y4fztnimgIwMe8pYpSVUjPVMGU6q1clOmbJxmr1E7m0bptsfh3aPQeVm3i9f5UxAgL2Ict1o2LnIK3OD+XnL8uU1qlyC65pWZvyCbew57IF+KdvmQermglfryuyK++3A5/hX7bTW3nLigD0NKR8D7f/hvXJV3sXeDLf/4JW5wQp88gJ4rFtdjMEzw4aWjIOiZeyy8AWVCPR+09YuvxsKe1Z5p9zpT9j5/68b5T9DrZTX5gYrFMmraumiDGpTna+XJbFxrxsbng8nwcbp0Oz2gt8WE1TE9u0JLWmHEB33cAfgtd/Bmq+hw5NQIcazZSn388LcYIUieQE80Kk2xYoEuXfY0NJP7KXh5n40jjE/IsrDzZPg+H6YcjuknfFMOceSbQ//ik1sh0jlnzw8N1ihSV4li4YwtFNtZm/Yz8ItqfnfYdoZWx2ue7WdlbSwqNzMrja943f45Un3798Y+OkR+yW/7sOCt5p1YZPV3GBuGlJUaJIXwB1toqhUItQ9w4bW/2BrIC3ucU9w/qRxP7v+Y8J4O2OsO63+Cjb8ZPtzlavv3n0r55ybG6x4JQgr6ZZdFqrkFRocyGNXRbMq6TDTVu/J386WfASlogr
vXFJdnoM6V8HP/4Dtv7tnn0f22Eb6Ki3tFNWqYKndxSawoCJu2V2hSl4A1zWtTL0KEbz260bOpOWx0+W+tbDzj4IxjjGvAgLhhnFQqgZMGWhn28wPY2zDbtoZ2xk1INA9caoCq9D9zwsMEEb0qMfOAyf4/M8dud9BRoYdsBwcbpdPL8xCS8CAyZCeZocQnTme930tnwiJM6Drc1C2tvtiVAVWoUteAB3qRtKmVhlGzt7M0VO5HDa0ZKydaeHqF3SZLYCydeDGj2DfGtsHLC8Dcw/tgl+egupt7ZUppXKgUCYvEeGpHvU5cPwMo+flYsaE5E3w27O2raewdI/IiTrdoNu/YN13EP967l5rjJ0732RA3/cK72m4yrVC+02JqVKCPrGVGLdgK3sPn7r8C9LPwrf3QnBR6POu7XWu/qfNMGjcH+a8ABum5fx1CeNtTfaq/9iOjUrlUKFNXgDDr44mPcPw9swcDBuKfx12L4feb9lhD+p8InDNO1CpKXxzL+xbd/nXHNgGM/4PanayqwAplQuFOnlVLV2Uga2imJKwi8R9l+j5m7QU4l+DxjdDw2u9Fp/fCQ6zi3iEhMMXAy49LUpGhp3SOSDQni5qTVblUqFOXgAPdq5NeMglhg2dOWFPFyMq2p7C6tKKV4L+k+DIbvjqDnslMiuLR8OOBXD1i/630pLyCYU+eZUOD+H+jrWYuX4/f27NYtjQb8/aKW+u/UBXZc6pqi2g99t2uqAZz1z8eMpmmPkve+GjsHc3UXlW6JMXwF1X1qBC8VBe+nkDJvOl/s2zbNeIVkOhZofsd6Au1vRW+779Ocr24TonIx2+G2J7WV8zUk8XVZ5p8gLCQgJ5rFtdVuw6xM9r9tqNJw7A9w9A2WhdZiuvuv3HNsb/9CjsWmy3/fGuXWWp52tQvKKz8Sm/psnL5YbmVahbvhiv/rKBs+kZdozd8WS4foxtiFa5FxgEN46H4pXtHGCbZ8Kc/0K93hDTz+nolJ/T5OVybtjQ9tQTLPx+tGsivBF2eSeVd0VL2yFEZ0/AxBvsqkS939bTRZVvmrwy6RRdju7VMmiy6j+kV4rTifDcpVx9O4g7JML2BSuWu0WElcqKzvSWiRjDq0GjCTJpfFp+BHfpRHjuE90Dntyukwsqt9GaV2ZLxlF893y+Lz+U1xLS2X8kB8OGVM5p4lJupMnrnJRE26erdjda93uCs+kZvD3LYwt6K6XySZMX2EHX39xrVwDq+x5RkcW4rVV1vlyyi837jzkdnVIqC5q8AOa/AbuXnTfo+qHOtQkLDuRVd642pJRyG01efy2Fea9CzE3Q8Lq/N5cpVoT7O9Rkxrp9JGy/xABjpZQjCnfyOnMCvrnP1rZ6vnbRw3e1rUG5iCK8OH39+cOGlFKOK9zJa+bzkJpoB11nsRxT0ZAgHu1Wl2U7D/HrWvesNaeUco/Cm7y2zLbTslwxBGp2zPZp/ZpXoXa5TMOGlFI+oXAmr5MH4TvXoOuuz13yqUGBATzZvR5bU47z5ZJdXgpQKXU5hTN5TXvCrnZ9/egcDbruWr8cLaJK8fbMRI6fzmZyPaWUV3kseYnIeBHZLyJrMm37UkRWuP62i8gKT5WfrTVfw5qp0OFJO996DogIT/WsT8qx04ydn4vVhpRSHuPJmtcnQPfMG4wx/Y0xTYwxTYCvgW88WP7FjuyGnx6DynHQ9rFcvbRZtVL0aFSBMfFbST562kMBKqVyymPJyxgTD2TZQUpEBLgJmOyp8rMIyC74kHYarhudp3F2w6+O5kxaBu/MysFqQ0opj3KqzasdsM8Yk+3gQRG5V0QSRCQhOTk5/yUuGQdbZtn1AfO4nHzNyGIMaFmNyYt3sTVZhw0p5SSnktcALlPrMsaMMcbEGWPiIiPzOf9TSqJdH7BWF2hxd752NaxLHUKDAnhx+gbStOuEUo7xevISkSDgeuBLrxSYngbf3mcXfOj7fr5n8IyMKMLQTrWZuX4fXd+cx7fLk0jP0N7
3SnlbjpKXiISLSIDrdl0R6SMiwXkssyuwwRiTlMfX5878N+z4xd5vuW3Bh6EdazH29jjCQoJ49MuVXPXWPH5YuVuTmFJelNOaVzwQKiKVgRnAQOzVxGyJyGRgIRAtIkkiMtj10M14q6H+r2Uw7xW72EOj6922WxGhW4PyTHuoLR/e1oyggACGTV5Oj3fimb56DxmaxJTyOMnJgGMRWWaMaSYiDwFhxphXRWSFq8uDx8XFxZmEhITcvejsSRjdHk4fg6F/QFgpzwQHZGQYpq/Zw9szE9m8/xj1KkTwaLe6XNWgPKILTSiVZyKy1BgTl9VjOa15iYi0Bm4Fprm2BbojOI+Z+TykbHINuvZc4gIICBB6N67Er4+0552bm3AmLYP7Jiyl97sLmLlun85IoZQH5DR5PQI8BXxrjFkrIjWBOR6LKr+2zIY/P4SW90GtTl4rNjBA6NukMjMebc8b/WI5eiqNuz9LoO/7vzNn435NYkq5UY5OG897gW24L2aMOeKZkC6Wq9PGkwfhgzYQEg73xUNIUc8Gdwln0zP4dtlfjJydSNLBkzStVpLHutWlbe2yejqpVA7k+7RRRD4XkeIiEg6sAdaJyHB3Buk204fDsX120LWDiQsgODCAm1pUZfbjHXnxuhj2HT7FwI8Wc9PohfyxJcXR2JTydzk9bWzgqmldC/wM1MBecfQtW2bD6q/soOvKzZ2O5m8hQQHcckU15gzvyH/6NmTngRPcMvZPbh6zkMXbdIpppfIip8kr2NWv61rgB2PMWcD3GnBqdIC+H0C7x52OJEtFggIZ2DqKecM78dw1DdiSfJybRi/ktnF/snSHJjGlciOnyWs0sB0IB+JFpDrgtTavHAsIhKa3+vzipqHBgdx5ZQ3ih3fimV71Wb/nCDeMWsig8YtZseuQ0+Ep5Rdy3WD/9wtFgowxXpmZL0/9vPzIiTNpfLZwB6PnbeHgibN0qVeOR7vVpVHlEk6HppSjLtVgn9NOqiWA54D2rk3zgH8bYw67LcpLKOjJ65xjp9P49I/tjInfyuGTZ+nWoDxPXBVNdIUIp0NTyhHu6KQ6HjiKnYPrJuwp48fuCU+dU6xIEA90qs38JzvxWLe6LNqayjXvLmDCwu3aR0ypC+S05nXRUCCfHx5UABw4fobHp6xgzsZkesZU4OUbGlM8NK/j4ZXyP+6oeZ0UkbaZdnglcNIdwanslQ4P4aNBLRjRox6/rt1H75ELWJV0yOmwlPIJOU1e9wPvuxbN2A68B9znsajU3wIChPs71GLKfa1IS8/ghlF/8Mnv2/Q0UhV6OUpexpiVxphYoDHQ2BjTFOjs0cjUeZpXL820Ye3oUDeS539cx5CJyzh88qzTYSnlmFzNpGqMOZJpTGPult9R+VYqPISxt8fxTK/6zFy/j14j52u/MFVo5WcaaB1Z7AAR4e52NZlyf2uMgX4f/sG4+Vv1NFIVOvlJXvq/xUHNqpVi+rB2dIwuxwvT1nPPZ0s5dOKM02Ep5TWXTF4iclREjmTxdxSo5KUYVTZKFA1mzMDmPNu7AfM27afXyAUs23nQ6bCU8opLJi9jTIQxpngWfxHGGN8eQFhIiAh3ta3B1PvbEBAAN324kDHxW3QefVXgObVuo3Kz2Kol+emhdnStX54Xp2/g7s8SOHhcTyNVwaXJqwApERbMqNua8a8+DVmQmELPkfN1qh1VYGnyKmBEhEFtovh6SBtCggK4afQiRs3V00hV8GjyKqBiqpTgx4fa0r1RBV75ZQN3fbqE1GOnnQ5LKbfR5FWAFQ8N5r0BTXnh2kb8sSWVXiMX6LTTqsDQ5FXAiQi3tarOt0PbEBYSyICxi3h/zmY9jVR+T5NXIdGwkj2N7BVTkdd+3cigjxeToqeRyo9p8ipEihUJ4p2bm/DS9TEs3naAnu/MZ+GWVKfDUipPNHkVMiLCgJbV+O6BKykWGsSt4xYxclYi6XoaqfyMJq9Cqn7F4vz4YFv6NqnMm79t4voPftc
+YcqvaPIqxMKLBPHmTbG8c3MT9h05zQ2jFvLQ5OUkHTzhdGhKXZYmr0JOROjbpDKzn+jAsC51mLF2L13emMcbMzZy/LRXVrZTKk80eSkAioYE8Vi3usx+oiPdG1Xg3dmb6fT6XKYuTdJuFconafJS56lcMox3bm7K10PaULFkGE98tZJrP/idJdu1PUz5Fo8lLxEZLyL7RWTNBdsfEpENIrJWRF71VPkqf5pXL8W3Q9rwdv8m7D9ymn4fLuSBz5dpe5jyGZ6seX0CdM+8QUQ6AX2BWGNMQ+B1D5av8ikgQLi2qW0Pe7hLHWat30fnN+bx+q/aHqac57HkZYyJBy481xgCvGyMOe16zn5Pla/cp2hIEI92q8vsxzvSs1EF3puzmY6vz+WrhF3aHqYc4+02r7pAOxH5U0TmiUiL7J4oIveKSIKIJCQnJ3sxRJWdSiXDePvmpnwztA2VS4YxfOoq+ryvg72VM7ydvIKA0kArYDgwRUSyXIXIGDPGGBNnjImLjIz0ZozqMppVK8W3Q9vwzs1NSD12hptGL+SBScvYdUDbw5T3eDt5JQHfGGsxkAGU9XIMyg3+7h/2eEce7VqX2Rv20+XNebz6ywaOaXuY8gJvJ6/vgE4AIlIXCAFSvByDcqOwkEAe7lqH2U90oFdMRT6Yu4VOr89liraHKQ/zZFeJycBCIFpEkkRkMDAeqOnqPvEFMMjoaqkFQsUSYbzVvwnfDm1DlVJh/GPqKq55bwF/btVZK5RniD/kjri4OJOQkOB0GCqHjDH8sHI3r/y8gd2HT9EzpgJP9ahP1dJFnQ5N+RkRWWqMicvqMV17UbndufawqxpUYOz8rYyau4WZ6/YzuF0NhnasRURosNMhqgJAhwcpjwkLCWRYlzrMeaIjvWMrMsrVHjZ58U6dP0zlmyYv5XEVSoTy5k1N+P6BK4kqE85T36ym18j5/L5Zr9WovNPkpbwmtmpJvrq/Ne/d0pSjp9K4ddyf3P3pErYmH3M6NOWHNHkprxIRejeuxKzHO/CP7tEs2nqAq96K598/ruPwibNOh6f8iCYv5YjQ4ECGdqzNnCc60i+uCp/8sY0Or8/hk9+3cTY9w+nwlB/Q5KUcFRlRhJeub8y0Ye1oWKk4z/+4ju5vxzN7wz78oRuPco4mL+UT6lcszsTBVzD29jgyDNz1SQK3j1/Mxr1HnQ5N+ShNXspniAjdGpTn10fa83+9G7By1yF6vBPPP79dTaoukKsuoMlL+ZyQoAAGt63BvOGduL11FF8s2UXH1+Yyet4WTqelOx2e8hGavJTPKhUewvN9GvLrI+1pUaM0L/28gW5vxvPz6j3aHqY0eSnfV7tcMcbf0YLP7mpJaHAAQyYto/+YRaxOOux0aMpBmryU32hfN5Lpw9rxwrWN2LL/GH3eX8DjU1ay78gpp0NTDtDkpfxKUGAAt7WqzpzhHbm3XU1+XLmbjq/NZeSsRE6e0fawwkSTl/JLxUODeapnfX57rD0doyN587dNdH5jLt8t/0snQSwkNHkpv1a9TDijbmvOl/e2okyxEB75cgXXffA7v67dq0msgNPJCFWBkZFh+HpZEu/MSiTp4ElqRYZzX/ta9G1aiSJBgU6Hp/LgUpMRavJSBU5aegbTVu/hw3lbWb/nCOWLF2Fw2xoMaFlNJ0L0M5q8VKFkjCE+MYUP525h4dZUIkKDGNiqOndcGUW5iFCnw1M5oMlLFXordx1idPwWfl6zl+DAAG5oVoV729ekRtlwp0NTl6DJSymXbSnHGRO/la+XJXE2PYMejSpwf4daNK5S0unQVBY0eSl1gf1HT/HJ79uZsGgHR0+l0bpmGe7vWIv2dcqSzSLuygGavJTKxtFTZ5m8eCcfLdjGviOnaVCxOPd1qEmvmIoEBWpPIqdp8lLqMk6npfP98t2Mjt/CluTjVC0dxj3tatKveVXCQrSbhVM0eSmVQxkZhpnr9/HhvC0s23mI0uEhDGodxe2tq1MqPMTp8Ao
dTV5K5ZIxhiXbD/LhvC3M3rCfsOBAbm5Zlbvb1aRyyTCnwys0dMVspXJJRGhZozQta5Rm496jjI7fwoSFO5iwcAd9YitxX4daRFeIcDrMQk1rXkrl0F+HTvLR/G18sWQnJ86kc0WN0nSIjqR9nUgaVCxOQIBepXQ3PW1Uyo0OnTjDhIU7mL5mL+v3HAGgTHgIbeuUpX2dSNrVKUu54tqD3x00eSnlIfuPnGLB5hTiNyUzPzGF1ONnAKhXIYL2dW2tLC6qFKHBesUyLzR5KeUFGRmGdXuOMD/RJrOEHQc4m24IDQ7gihplaFenLB3qRlK7XDHtCJtDmryUcsCJM2ks2ppK/KYU4hOT2Zp8HICKJUJpV6cs7epE0rZ2We2CcQmOJC8RGQ/0BvYbYxq5tj0P3AMku572tDFm+uX2pclLFQRJB0+wINEmsgWJKRw5lYYINK5cgnZ1ImlfN5Km1UoSrD37/+ZU8moPHAM+uyB5HTPGvJ6bfWWVvM6ePUtSUhKnTuniC/kVGhpKlSpVCA7Wua68JT3DsDLpEPNdtbIVuw6RnmEoViSI1rXK0L5OWdrXjaR6mcI964Uj/byMMfEiEuWp/SclJREREUFUVJS2H+SDMYbU1FSSkpKoUaOG0+EUGoEBQrNqpWhWrRQPd63D4ZNnWbglhXhXe9lv6/YBUK10UQa3rcHAVtW1K8YFnOik+qCI3A4kAI8bYw5m9SQRuRe4F6BatWoXPX7q1ClNXG4gIpQpU4bk5OTLP1l5TImwYLo3qkj3RhUxxrA99QTxm5KZtnoPz/2wll/X7uW1frHauz8Tb59cjwJqAU2APcAb2T3RGDPGGBNnjImLjIzM8jmauNxD30ffIiLUKBvOoDZRfHlvK166PoaVuw7R/a14vkrYpauFu3g1eRlj9hlj0o0xGcBYoKU3y1fK34gIA1pW45dH2lO/UnGGT13FPZ8tZf9Rbev1avISkYqZ7l4HrPFm+Ur5q6qli/LFPa14pld94hOTufqteKav3uN0WI7yWPISkcnAQiBaRJJEZDDwqoisFpFVQCfgUU+V72mHDh3igw8+yPXrevbsyaFDh3L9ujvuuIOpU6fm+nWq4AgIEO5uV5Ppw9pStXRRhk5axsNfLOfwibNOh+YIT15tHJDF5o88Uda/flzLut1H3LrPBpWK89w1DbN9/FzyGjp06Hnb09LSCArK/m2dPv2y3dqUuqTa5SL4ekgbRs3dwshZiSzamsorNzSmY3Q5p0PzKu0Nl0cjRoxgy5YtNGnShBYtWtCuXTv69OlDgwYNALj22mtp3rw5DRs2ZMyYMX+/LioqipSUFLZv3079+vW55557aNiwIVdddRUnT57MUdmzZs2iadOmxMTEcNddd3H69Om/Y2rQoAGNGzfmiSeeAOCrr76iUaNGxMbG0r59eze/C8opwYEBDOtSh+8euJLiocHc8fESnv52NcdPpzkdmvcYY3z+r3nz5uZC69atu2ibN23bts00bNjQGGPMnDlzTNGiRc3WrVv/fjw1NdUYY8yJEydMw4YNTUpKijHGmOrVq5vk5GSzbds2ExgYaJYvX26MMaZfv35mwoQJ2ZY3aNAg89VXX5mTJ0+aKlWqmI0bNxpjjBk4cKB56623TEpKiqlbt67JyMgwxhhz8OBBY4wxjRo1MklJSedty4rT76fKu5Nn0syL09aZqBE/mbavzDJ/bk11OiS3ARJMNnlBa15u0rJly/M6eY4cOZLY2FhatWrFrl27SExMvOg1NWrUoEmTJgA0b96c7du3X7acjRs3UqNGDerWrQvAoEGDiI+Pp0SJEoSGhjJ48GC++eYbihYtCsCVV17JHXfcwdixY0lPT8//gSqfExocyFM96zPlvtYIQv8xC/nvtHWcOluwP29NXm4SHv6/YRxz585l5syZLFy4kJUrV9K0adMshzEVKVLk79uBgYGkpeW9yh8UFMTixYu58cYb+emnn+jevTsAH374IS+88AK7du2
iefPmpKam5rkM5dtaRJXm54fbcUvLaoydv43e7y5gVdIhp8PyGE1eeRQREcHRo0ezfOzw4cOUKlWKokWLsmHDBhYtWuS2cqOjo9m+fTubN28GYMKECXTo0IFjx45x+PBhevbsyVtvvcXKlSsB2LJlC1dccQX//ve/iYyMZNeuXW6LRfme8CJB/Pe6GD69qyXHTqVx3Qd/8NZvmzibnuF0aG6nc9jnUZkyZbjyyitp1KgRYWFhlC9f/u/Hunfvzocffkj9+vWJjo6mVatWbis3NDSUjz/+mH79+pGWlkaLFi24//77OXDgAH379uXUqVMYY3jzzTcBGD58OImJiRhj6NKlC7GxsW6LRfmuDnUj+fWR9jz/41remZXI7A37efOmWOqULzjz7vvtfF7r16+nfv36DkVU8Oj7WXD9smYPT3+7hmOn0xh+VTR3ta1BoJ8M8r7UrBJ62qhUAde9UUV+faQ9HepG8t/p6xkwZhE7U084HVa+afLyMQ888ABNmjQ57+/jjz92Oizl5yIjijBmYHNe7xfL+j1H6P5OPJP+3OHXg7y1zcvHvP/++06HoAooEeHG5lVoXasM/5i6kn9+u4YZa/fxyg2NqVDC/1Y70pqXUoVM5ZJhTLjrCv7dtyF/bkvl6rfj+X7FX35XC9PkpVQhFBAg3N46ip8fbk+tyHAe/mIFXd+cx7j5WzngWr7N12nyUqoQq1E2nK/ub8Pr/WIpERbMC9PW0+rFWQybvJyFW1J9ujambV5KFXKBAbYt7MbmVdiw9whfLN7FN8uS+GHlbmqUDWdAy6rc0KwKZYoVufzOvEhrXl5SrFixbB/bvn07jRo18mI0SmWtXoXiPN+nIYv/2ZU3b4qlbLEQXpy+gVYvzeKBz5fx++YUMjJ8ozZWMGpeP4+Avavdu88KMdDjZffuUyk/ERocyPXNqnB9syok7jvK5MW7+GZ5EtNW7aF6maLc3KIaNzavQmSEc7UxrXnl0YgRI87r1vD888/zwgsv0KVLF5o1a0ZMTAzff/99rvd76tQp7rzzTmJiYmjatClz5swBYO3atbRs2ZImTZrQuHFjEhMTOX78OL169SI2NpZGjRrx5Zdfuu34lDqnTvkInr2mAYue6sI7NzehQvFQXvllA61fmsWQiUuJ35TsTG0su7lyfOnPF+fzWrZsmWnfvv3f9+vXr2927txpDh8+bIwxJjk52dSqVevv+bXCw8Oz3VfmucFef/11c+eddxpjjFm/fr2pWrWqOXnypHnwwQfNxIkTjTHGnD592pw4ccJMnTrV3H333X/v59ChQ3k+HqffT+VfEvcdNS/8tNY0+devpvqTdh6x92Ynmn2HT7q1HHQ+L/dr2rQp+/fvZ/fu3axcuZJSpUpRoUIFnn76aRo3bkzXrl3566+/2LdvX672u2DBAm677TYA6tWrR/Xq1dm0aROtW7fmxRdf5JVXXmHHjh2EhYURExPDb7/9xpNPPsn8+fMpUaKEJw5VqYvULleMf/ZqwKKnuzByQFOqlirKa79upPXLs7n3swTmbNxPuodrYwWjzcsh/fr1Y+rUqezdu5f+/fszadIkkpOTWbp0KcHBwURFRWU5j1de3HLLLVxxxRVMmzaNnj17Mnr0aDp37syyZcuYPn06zzzzDF26dOHZZ591S3lK5USRoED6xFaiT2wltqUc54slO5makMSMdfuoXDKM/i2qclNcVY/04NfklQ/9+/fnnnvuISUlhXnz5jFlyhTKlStHcHAwc+bMYceOHbneZ7t27Zg0aRKdO3dm06ZN7Ny5k+joaLZu3UrNmjUZNmwYO3fuZNWqVdSrV4/SpUtz2223UbJkScaNG+eBo1QqZ2qUDeepHvV5vFs0v63bxxdLdvLmb5t4e+YmOtcrx4CW1egYXc5tM1po8sqHhg0bcvToUSpXrkzFihW59dZbueaaa4iJiSEuLo569erlep9Dhw5lyJAhxMTEEBQUxCeffEKRIkWYMmUKEyZMIDg
4+O/T0yVLljB8+HACAgIIDg5m1KhRHjhKpXInJCiAXo0r0qtxRXakHufLJbuYkpDEzPUJVCwRyncPXEn54vmviel8XgrQ91N51tn0DGat30d8Ygr/vbYRIjmrfV1qPi+teSmlPC44MIDujSrSvVFFt+1Tk5cXrV69moEDB563rUiRIvz5558ORaSU//Lr5GWMyXH10xfExMSwYsUKp8O4iD80HSh1Ib/t5xUaGkpqqm+PevcHxhhSU1MJDfW/yehU4ea3Na8qVaqQlJREcnKy06H4vdDQUKpUqeJ0GErlit8mr+Dg4PNWqFZKFS5+e9qolCrcNHkppfySJi+llF/yix72IpIM5GagYFkgxUPheFtBOhYoWMdTkI4FfPN4qhtjIrN6wC+SV26JSEJ2Qwr8TUE6FihYx1OQjgX873j0tFEp5Zc0eSml/FJBTV5jnA7AjQrSsUDBOp6CdCzgZ8dTINu8lFIFX0GteSmlCjhNXkopv1SgkpeIdBeRjSKyWURGOB1PfohIVRGZIyLrRGStiDzsdEz5JSKBIrJcRH5yOpb8EpGSIjJVRDaIyHoRae10THklIo+6vmNrRGSyiPjFFCMFJnmJSCDwPtADaAAMEJEGzkaVL2nA48aYBkAr4AE/Px6Ah4H1TgfhJu8Avxhj6gGx+OlxiUhlYBgQZ4xpBAQCNzsbVc4UmOQFtAQ2G2O2GmPOAF8AfR2OKc+MMXuMMctct49i/3NUdjaqvBORKkAvwO+XOBKREkB74CMAY8wZY8whR4PKnyAgTESCgKLAbofjyZGClLwqA7sy3U/Cj/+zZyYiUUBTwJ/ni34b+AeQ4XAc7lADSAY+dp0GjxORcKeDygtjzF/A68BOYA9w2Bgzw9mocqYgJa8CSUSKAV8DjxhjjjgdT16ISG9gvzFmqdOxuEkQ0AwYZYxpChwH/LKNVURKYc9QagCVgHARuc3ZqHKmICWvv4Cqme5XcW3zWyISjE1ck4wx3zgdTz5cCfQRke3Y0/nOIjLR2ZDyJQlIMsacqwlPxSYzf9QV2GaMSTbGnAW+Ado4HFOOFKTktQSoIyI1RCQE2+j4g8Mx5ZnYlUU+AtYbY950Op78MMY8ZYypYoyJwn4us40xfvHrnhVjzF5gl4hEuzZ1AdY5GFJ+7ARaiUhR13euC35y8cFvp4G+kDEmTUQeBH7FXjEZb4xZ63BY+XElMBBYLSIrXNueNsZMdy4klclDwCTXD+VW4E6H48kTY8yfIjIVWIa9wr0cPxkmpMODlFJ+qSCdNiqlChFNXkopv6TJSynllzR5KaX8kiYvpZRf0uSlPEpE0kVkRaY/t/VEF5EoEVnjrv0p/1Jg+nkpn3XSGNPE6SBUwaM1L+UIEdkuIq+KyGoRWSwitV3bo0RktoisEpFZIlLNtb28iHwrIitdf+eGsASKyFjXfFQzRCTM9fxhrrnQVonIFw4dpvIgTV7K08IuOG3sn+mxw8aYGOA97KwTAO8CnxpjGgOTgJGu7SOBecaYWOw4wnOjJ+oA7xtjGgKHgBtc20cATV37ud8zh6acpD3slUeJyDFjTLEstm8HOhtjtroGoO81xpQRkRSgojHmrGv7HmNMWdeq6VWMMacz7SMK+M0YU8d1/0kg2Bjzgoj8AhwDvgO+M8Yc8/ChKi/Tmpdyksnmdm6cznQ7nf+14/bCzqzbDFjimmhPFSCavJST+mf6d6Hr9h/8bxriW4H5rtuzgCHw91z4JbLbqYgEAFWNMXOAJ4ESwEW1P+Xf9NdIeVpYplkxwM77fq67RCkRWYWtPQ1wbXsIO0PpcOxspedma3gYGCMig7E1rCHYmT+zEghMdCU4AUb6+TTNKgva5qUc4WrzijPGpDgdi/JPetqolPJLWvNSSvklrXkppfySJi+llF/S5KWU8kuavJRSfkmTl1LKL/0/rC1qDE44mlAAAAAASUVORK5CYII=\n",
1697 | "text/plain": [
1698 | ""
1699 | ]
1700 | },
1701 | "metadata": {
1702 | "needs_background": "light"
1703 | },
1704 | "output_type": "display_data"
1705 | }
1706 | ],
1707 | "source": [
1708 | "training_vis(train_losses, valid_losses)"
1709 | ]
1710 | },
1711 | {
1712 | "cell_type": "markdown",
1713 | "metadata": {},
1714 | "source": [
1715 | "#### Model evaluation\n",
1716 | "\n",
1717 | "Evaluate the model on the test set."
1718 | ]
1719 | },
1720 | {
1721 | "cell_type": "code",
1722 | "execution_count": 30,
1723 | "metadata": {
1724 | "execution": {
1725 | "iopub.execute_input": "2021-11-07T11:41:57.180629Z",
1726 | "iopub.status.busy": "2021-11-07T11:41:57.180063Z",
1727 | "iopub.status.idle": "2021-11-07T11:41:57.544721Z",
1728 | "shell.execute_reply": "2021-11-07T11:41:57.544233Z",
1729 | "shell.execute_reply.started": "2021-11-07T11:31:56.594261Z"
1730 | },
1731 | "papermill": {
1732 | "duration": 0.680346,
1733 | "end_time": "2021-11-07T11:41:57.544847",
1734 | "exception": false,
1735 | "start_time": "2021-11-07T11:41:56.864501",
1736 | "status": "completed"
1737 | },
1738 | "tags": []
1739 | },
1740 | "outputs": [
1741 | {
1742 | "data": {
1743 | "text/plain": [
1744 | ""
1745 | ]
1746 | },
1747 | "execution_count": 30,
1748 | "metadata": {},
1749 | "output_type": "execute_result"
1750 | }
1751 | ],
1752 | "source": [
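1753 | "# Load the best model weights (map_location='cpu' so loading also works on a machine without a GPU)\n",
1754 | "checkpoint = torch.load('../input/ai-earth-model-weights/task04_model_weights.pth', map_location='cpu')\n",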
1755 | "model = Model()\n",
1756 | "model.load_state_dict(checkpoint['state_dict'])"
1757 | ]
1758 | },
1759 | {
1760 | "cell_type": "code",
1761 | "execution_count": 31,
1762 | "metadata": {
1763 | "execution": {
1764 | "iopub.execute_input": "2021-11-07T11:41:58.181883Z",
1765 | "iopub.status.busy": "2021-11-07T11:41:58.181002Z",
1766 | "iopub.status.idle": "2021-11-07T11:41:58.182836Z",
1767 | "shell.execute_reply": "2021-11-07T11:41:58.183296Z",
1768 | "shell.execute_reply.started": "2021-11-07T11:31:57.001527Z"
1769 | },
1770 | "papermill": {
1771 | "duration": 0.323811,
1772 | "end_time": "2021-11-07T11:41:58.183429",
1773 | "exception": false,
1774 | "start_time": "2021-11-07T11:41:57.859618",
1775 | "status": "completed"
1776 | },
1777 | "tags": []
1778 | },
1779 | "outputs": [],
1780 | "source": [
1781 | "# Path to the test set\n",
1782 | "test_path = '../input/ai-earth-tests/'\n",
1783 | "# Path to the test set labels\n",
1784 | "test_label_path = '../input/ai-earth-tests-labels/'"
1785 | ]
1786 | },
1787 | {
1788 | "cell_type": "code",
1789 | "execution_count": 32,
1790 | "metadata": {
1791 | "execution": {
1792 | "iopub.execute_input": "2021-11-07T11:41:58.817618Z",
1793 | "iopub.status.busy": "2021-11-07T11:41:58.816950Z",
1794 | "iopub.status.idle": "2021-11-07T11:42:00.269201Z",
1795 | "shell.execute_reply": "2021-11-07T11:42:00.268643Z",
1796 | "shell.execute_reply.started": "2021-11-07T11:31:57.010291Z"
1797 | },
1798 | "papermill": {
1799 | "duration": 1.770946,
1800 | "end_time": "2021-11-07T11:42:00.269350",
1801 | "exception": false,
1802 | "start_time": "2021-11-07T11:41:58.498404",
1803 | "status": "completed"
1804 | },
1805 | "tags": []
1806 | },
1807 | "outputs": [],
1808 | "source": [
1809 | "import os\n",
1810 | "\n",
1811 | "# Read the test data and the corresponding labels (both directories use the same file names)\n",
1812 | "files = os.listdir(test_path)\n",
1813 | "X_test = []\n",
1814 | "y_test = []\n",
1815 | "for file in files:\n",
1816 | " X_test.append(np.load(test_path + file))\n",
1817 | " y_test.append(np.load(test_label_path + file))"
1818 | ]
1819 | },
1820 | {
1821 | "cell_type": "code",
1822 | "execution_count": 33,
1823 | "metadata": {
1824 | "execution": {
1825 | "iopub.execute_input": "2021-11-07T11:42:00.911046Z",
1826 | "iopub.status.busy": "2021-11-07T11:42:00.909973Z",
1827 | "iopub.status.idle": "2021-11-07T11:42:00.973831Z",
1828 | "shell.execute_reply": "2021-11-07T11:42:00.973361Z",
1829 | "shell.execute_reply.started": "2021-11-07T11:31:58.325122Z"
1830 | },
1831 | "papermill": {
1832 | "duration": 0.395072,
1833 | "end_time": "2021-11-07T11:42:00.973969",
1834 | "exception": false,
1835 | "start_time": "2021-11-07T11:42:00.578897",
1836 | "status": "completed"
1837 | },
1838 | "tags": []
1839 | },
1840 | "outputs": [
1841 | {
1842 | "data": {
1843 | "text/plain": [
1844 | "((103, 12, 24, 72, 4), (103, 24))"
1845 | ]
1846 | },
1847 | "execution_count": 33,
1848 | "metadata": {},
1849 | "output_type": "execute_result"
1850 | }
1851 | ],
1852 | "source": [
1853 | "X_test = np.array(X_test)\n",
1854 | "y_test = np.array(y_test)\n",
1855 | "X_test.shape, y_test.shape"
1856 | ]
1857 | },
1858 | {
1859 | "cell_type": "code",
1860 | "execution_count": 34,
1861 | "metadata": {
1862 | "execution": {
1863 | "iopub.execute_input": "2021-11-07T11:42:01.612810Z",
1864 | "iopub.status.busy": "2021-11-07T11:42:01.611682Z",
1865 | "iopub.status.idle": "2021-11-07T11:42:01.638035Z",
1866 | "shell.execute_reply": "2021-11-07T11:42:01.637526Z",
1867 | "shell.execute_reply.started": "2021-11-07T11:31:58.416006Z"
1868 | },
1869 | "papermill": {
1870 | "duration": 0.3441,
1871 | "end_time": "2021-11-07T11:42:01.638176",
1872 | "exception": false,
1873 | "start_time": "2021-11-07T11:42:01.294076",
1874 | "status": "completed"
1875 | },
1876 | "tags": []
1877 | },
1878 | "outputs": [],
1879 | "source": [
1880 | "testset = AIEarthDataset(X_test, y_test)\n",
1881 | "testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)"
1882 | ]
1883 | },
1884 | {
1885 | "cell_type": "code",
1886 | "execution_count": 35,
1887 | "metadata": {
1888 | "execution": {
1889 | "iopub.execute_input": "2021-11-07T11:42:02.286777Z",
1890 | "iopub.status.busy": "2021-11-07T11:42:02.285865Z",
1891 | "iopub.status.idle": "2021-11-07T11:42:02.621942Z",
1892 | "shell.execute_reply": "2021-11-07T11:42:02.622610Z",
1893 | "shell.execute_reply.started": "2021-11-07T11:31:58.447987Z"
1894 | },
1895 | "papermill": {
1896 | "duration": 0.666798,
1897 | "end_time": "2021-11-07T11:42:02.622817",
1898 | "exception": false,
1899 | "start_time": "2021-11-07T11:42:01.956019",
1900 | "status": "completed"
1901 | },
1902 | "tags": []
1903 | },
1904 | "outputs": [
1905 | {
1906 | "name": "stderr",
1907 | "output_type": "stream",
1908 | "text": [
1909 | "4it [00:00, 12.75it/s]"
1910 | ]
1911 | },
1912 | {
1913 | "name": "stdout",
1914 | "output_type": "stream",
1915 | "text": [
1916 | "Score: 20.274\n"
1917 | ]
1918 | },
1919 | {
1920 | "name": "stderr",
1921 | "output_type": "stream",
1922 | "text": [
1923 | "\n"
1924 | ]
1925 | }
1926 | ],
1927 | "source": [
1928 | "# Evaluate the model on the test set\n",
1929 | "model.eval()\n",
1930 | "model.to(device)\n",
1931 | "preds = np.zeros((len(y_test),24))\n",
1932 | "for i, data in tqdm(enumerate(testloader)):\n",
1933 | " data, labels = data\n",
1934 | " data = data.to(device)\n",
1935 | " labels = labels.to(device)\n",
1936 | " pred = model(data)\n",
1937 | " preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n",
1938 | "s = score(y_test, preds)\n",
1939 | "print('Score: {:.3f}'.format(s))"
1940 | ]
1941 | },
1942 | {
1943 | "cell_type": "markdown",
1944 | "metadata": {
1945 | "papermill": {
1946 | "duration": 0.311337,
1947 | "end_time": "2021-11-07T11:42:03.280397",
1948 | "exception": false,
1949 | "start_time": "2021-11-07T11:42:02.969060",
1950 | "status": "completed"
1951 | },
1952 | "tags": []
1953 | },
1954 | "source": [
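1955 | "## Summary\n",
1956 | "\n",
1957 | "- This solution takes full account of the small data volume and the small number of features: it convolves over time and space separately, alternately extracting temporal and spatial information, and uses a GAP layer to reduce the dimensionality of the extracted information, keeping the parameter count of each layer as low as possible while adding more layers to extract richer features.\n",
1958 | "- The solution recognizes that sequences at different time scales carry different information: it uses pooling layers to change the time scale, extracts information with RNNs, and combines the sequence information from three different time scales to obtain the final predicted sequence.\n",
1959 | "- Like the previous solution, this one uses a custom-designed model. The design takes full account of the dataset and the problem background and applies different network layers flexibly to specific problems. This approach requires a fairly deep understanding of what each kind of layer does, and the way the layers are used here is worth studying and borrowing from."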
1960 | ]
1961 | },
1962 | {
1963 | "cell_type": "markdown",
1964 | "metadata": {},
1965 | "source": [
1966 | "## References\n",
1967 | "\n",
1968 | "1. Top1 solution write-up: https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.6.561d482cp7CFlx&postId=210391"
1969 | ]
1970 | }
1971 | ],
1972 | "metadata": {
1973 | "kernelspec": {
1974 | "display_name": "Python 3",
1975 | "language": "python",
1976 | "name": "python3"
1977 | },
1978 | "language_info": {
1979 | "codemirror_mode": {
1980 | "name": "ipython",
1981 | "version": 3
1982 | },
1983 | "file_extension": ".py",
1984 | "mimetype": "text/x-python",
1985 | "name": "python",
1986 | "nbconvert_exporter": "python",
1987 | "pygments_lexer": "ipython3",
1988 | "version": "3.7.3"
1989 | },
1990 | "papermill": {
1991 | "default_parameters": {},
1992 | "duration": 571.585002,
1993 | "end_time": "2021-11-07T11:42:05.308202",
1994 | "environment_variables": {},
1995 | "exception": null,
1996 | "input_path": "__notebook__.ipynb",
1997 | "output_path": "__notebook__.ipynb",
1998 | "parameters": {},
1999 | "start_time": "2021-11-07T11:32:33.723200",
2000 | "version": "2.3.3"
2001 | }
2002 | },
2003 | "nbformat": 4,
2004 | "nbformat_minor": 5
2005 | }
2006 |
--------------------------------------------------------------------------------
/Task4/fig/Task4-CNN单元.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-CNN单元.png
--------------------------------------------------------------------------------
/Task4/fig/Task4-TCNN+RNN模型.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-TCNN+RNN模型.png
--------------------------------------------------------------------------------
/Task4/fig/Task4-TCNN层.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-TCNN层.png
--------------------------------------------------------------------------------
/Task4/fig/Task4-TCN单元.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-TCN单元.png
--------------------------------------------------------------------------------
/Task4/fig/Task4-扩张卷积.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-扩张卷积.png
--------------------------------------------------------------------------------
/Task4/fig/Task4-残差连接.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task4/fig/Task4-残差连接.png
--------------------------------------------------------------------------------
/Task5/Task5 模型建立之SA-ConvLSTM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
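7 | "# Datawhale Meteorological and Ocean Prediction - Task5: Model Building with SA-ConvLSTM\n",
8 | "\n",
9 | "In this task we study the modeling solution from the TOP team 吴先生的队伍 (Mr. Wu's Team), which uses the SA-ConvLSTM model.\n",
10 | "\n",
11 | "The previous two TOP solutions treat the competition problem as a multi-output task: they build a neural network that directly outputs 24 Nino3.4 predictions. The problem with this approach is that sequence problems are usually temporally dependent. A multi-output method effectively treats the 24 Nino3.4 predictions as completely independent, whereas in reality they are sequentially dependent, that is, each prediction is usually influenced by the prediction at the previous time step. This TOP solution therefore uses a Seq2Seq structure to account for the sequential dependence of the output predictions.\n",
12 | "\n",
13 | "A Seq2Seq structure consists of an Encoder and a Decoder. The Encoder encodes the input sequence into a vector, and the Decoder decodes that vector into a predicted sequence. The key to applying Seq2Seq to different sequence problems is the Cell used at each time step. As mentioned earlier, CNNs are typically used to mine spatial information and RNNs or LSTMs to mine temporal information; combining the two gives ConvLSTM, a classic model for spatiotemporal sequences. The SA-ConvLSTM model studied here improves on ConvLSTM by introducing a self-attention mechanism, which strengthens the model's ability to capture long-range spatial dependencies.\n",
14 | "\n",
15 | "Another difference from the previous two TOP solutions is that this solution does not predict the Nino3.4 index directly; instead, it predicts sst and derives the Nino3.4 index sequence from it."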
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "## Learning objectives\n",
23 | "1. Learn how the TOP solution builds its model"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
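30 | "## Contents\n",
31 | "1. Data processing\n",
32 | " - Data flattening\n",
33 | " - Missing-value filling\n",
34 | " - Dataset construction\n",
35 | "2. Model building\n",
36 | " - Evaluation function\n",
37 | " - Model construction\n",
38 | " - Model training\n",
39 | " - Model evaluation\n",
40 | "3. Summary"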
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": [
47 | "## Code example"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
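54 | "### Data processing\n",
55 | "The data processing in this TOP solution consists of three parts:\n",
56 | "1. Data flattening.\n",
57 | "2. Missing-value filling.\n",
58 | "3. Dataset construction."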
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 1,
64 | "metadata": {
65 | "execution": {
66 | "iopub.execute_input": "2021-11-29T03:04:34.698663Z",
67 | "iopub.status.busy": "2021-11-29T03:04:34.697133Z",
68 | "iopub.status.idle": "2021-11-29T03:04:37.035400Z",
69 | "shell.execute_reply": "2021-11-29T03:04:37.034767Z",
70 | "shell.execute_reply.started": "2021-11-29T01:02:51.883602Z"
71 | },
72 | "papermill": {
73 | "duration": 2.370278,
74 | "end_time": "2021-11-29T03:04:37.035673",
75 | "exception": false,
76 | "start_time": "2021-11-29T03:04:34.665395",
77 | "status": "completed"
78 | },
79 | "tags": []
80 | },
81 | "outputs": [],
82 | "source": [
83 | "import netCDF4 as nc\n",
84 | "import random\n",
85 | "import os\n",
86 | "from tqdm import tqdm\n",
87 | "import pandas as pd\n",
88 | "import numpy as np\n",
89 | "import math\n",
90 | "import matplotlib.pyplot as plt\n",
91 | "%matplotlib inline\n",
92 | "\n",
93 | "import torch\n",
94 | "from torch import nn, optim\n",
95 | "import torch.nn.functional as F\n",
96 | "from torch.utils.data import Dataset, DataLoader\n",
97 | "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
98 | "\n",
99 | "from sklearn.metrics import mean_squared_error"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 2,
105 | "metadata": {
106 | "execution": {
107 | "iopub.execute_input": "2021-11-29T03:04:37.102995Z",
108 | "iopub.status.busy": "2021-11-29T03:04:37.102144Z",
109 | "iopub.status.idle": "2021-11-29T03:04:37.107646Z",
110 | "shell.execute_reply": "2021-11-29T03:04:37.107161Z",
111 | "shell.execute_reply.started": "2021-11-29T01:02:54.06493Z"
112 | },
113 | "papermill": {
114 | "duration": 0.040737,
115 | "end_time": "2021-11-29T03:04:37.107761",
116 | "exception": false,
117 | "start_time": "2021-11-29T03:04:37.067024",
118 | "status": "completed"
119 | },
120 | "tags": []
121 | },
122 | "outputs": [],
123 | "source": [
124 | "# Fix the random seeds\n",
125 | "SEED = 22\n",
126 | "\n",
127 | "def seed_everything(seed=42):\n",
128 | " random.seed(seed)\n",
129 | " os.environ['PYTHONHASHSEED'] = str(seed)\n",
130 | " np.random.seed(seed)\n",
131 | " torch.manual_seed(seed)\n",
132 | " torch.cuda.manual_seed(seed)\n",
133 | " torch.backends.cudnn.deterministic = True\n",
134 | " \n",
135 | "seed_everything(SEED)"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 3,
141 | "metadata": {
142 | "execution": {
143 | "iopub.execute_input": "2021-11-29T03:04:37.222525Z",
144 | "iopub.status.busy": "2021-11-29T03:04:37.221100Z",
145 | "iopub.status.idle": "2021-11-29T03:04:37.225844Z",
146 | "shell.execute_reply": "2021-11-29T03:04:37.226442Z",
147 | "shell.execute_reply.started": "2021-11-29T01:02:54.074875Z"
148 | },
149 | "papermill": {
150 | "duration": 0.090198,
151 | "end_time": "2021-11-29T03:04:37.226602",
152 | "exception": false,
153 | "start_time": "2021-11-29T03:04:37.136404",
154 | "status": "completed"
155 | },
156 | "tags": []
157 | },
158 | "outputs": [
159 | {
160 | "name": "stdout",
161 | "output_type": "stream",
162 | "text": [
163 | "CUDA is available! Training on GPU ...\n"
164 | ]
165 | }
166 | ],
167 | "source": [
168 | "# Check whether CUDA is available\n",
169 | "train_on_gpu = torch.cuda.is_available()\n",
170 | "\n",
171 | "if not train_on_gpu:\n",
172 | " print('CUDA is not available. Training on CPU ...')\n",
173 | "else:\n",
174 | " print('CUDA is available! Training on GPU ...')"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 4,
180 | "metadata": {
181 | "execution": {
182 | "iopub.execute_input": "2021-11-29T03:04:37.353082Z",
183 | "iopub.status.busy": "2021-11-29T03:04:37.352143Z",
184 | "iopub.status.idle": "2021-11-29T03:04:37.432852Z",
185 | "shell.execute_reply": "2021-11-29T03:04:37.434332Z",
186 | "shell.execute_reply.started": "2021-11-28T10:13:13.644947Z"
187 | },
188 | "papermill": {
189 | "duration": 0.179146,
190 | "end_time": "2021-11-29T03:04:37.434792",
191 | "exception": false,
192 | "start_time": "2021-11-29T03:04:37.255646",
193 | "status": "completed"
194 | },
195 | "tags": []
196 | },
197 | "outputs": [],
198 | "source": [
199 | "# Load the data\n",
200 | "\n",
201 | "# Directory that holds the data\n",
202 | "path = '/kaggle/input/ninoprediction/'\n",
203 | "soda_train = nc.Dataset(path + 'SODA_train.nc')\n",
204 | "soda_label = nc.Dataset(path + 'SODA_label.nc')\n",
205 | "cmip_train = nc.Dataset(path + 'CMIP_train.nc')\n",
206 | "cmip_label = nc.Dataset(path + 'CMIP_label.nc')"
207 | ]
208 | },
209 | {
210 | "cell_type": "markdown",
211 | "metadata": {},
212 | "source": [
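213 | "#### Flattening the data\n",
214 | "The dataset is built with a sliding window. This solution uses only the sst feature, and only the grid columns at lon indices 19 to 66 (described by the author as lon in [90, 330]), probably to save compute."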
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 5,
220 | "metadata": {
221 | "execution": {
222 | "iopub.execute_input": "2021-11-29T03:04:37.548239Z",
223 | "iopub.status.busy": "2021-11-29T03:04:37.546737Z",
224 | "iopub.status.idle": "2021-11-29T03:04:37.551951Z",
225 | "shell.execute_reply": "2021-11-29T03:04:37.553081Z",
226 | "shell.execute_reply.started": "2021-11-27T13:38:32.620904Z"
227 | },
228 | "papermill": {
229 | "duration": 0.065069,
230 | "end_time": "2021-11-29T03:04:37.553274",
231 | "exception": false,
232 | "start_time": "2021-11-29T03:04:37.488205",
233 | "status": "completed"
234 | },
235 | "tags": []
236 | },
237 | "outputs": [],
238 | "source": [
239 | "def make_flatted(train_ds, label_ds, info, start_idx=0):\n",
240 | "    # Use only the sst feature\n",
241 | "    keys = ['sst']\n",
242 | "    label_key = 'nino'\n",
243 | "    # Number of years per model\n",
244 | "    years = info[1]\n",
245 | "    # Number of models\n",
246 | "    models = info[2]\n",
247 | "\n",
248 | "    train_list = []\n",
249 | "    label_list = []\n",
250 | "\n",
251 | "    # Concatenate the data that belong to the same model\n",
252 | "    for model_i in range(models):\n",
253 | "        blocks = []\n",
254 | "\n",
255 | "        # For each feature, take the first 12 months of every sample and concatenate them; keep only lon indices 19:67\n",
256 | "        for key in keys:\n",
257 | "            block = train_ds[key][start_idx + model_i * years: start_idx + (model_i + 1) * years, :12, :, 19: 67].reshape(-1, 24, 48, 1).data\n",
258 | "            blocks.append(block)\n",
259 | "\n",
260 | "        # Concatenate all features along the last axis\n",
261 | "        train_flatted = np.concatenate(blocks, axis=-1)\n",
262 | "\n",
263 | "        # Concatenate the labels of months 12-23 of every year, then append months 24-35 of the last year (together with months 12-23 of the last year they form the prediction target for the first 12 months of the last year)\n",
264 | "        label_flatted = np.concatenate([\n",
265 | "            label_ds[label_key][start_idx + model_i * years: start_idx + (model_i + 1) * years, 12: 24].reshape(-1).data,\n",
266 | "            label_ds[label_key][start_idx + (model_i + 1) * years - 1, 24: 36].reshape(-1).data\n",
267 | "        ], axis=0)\n",
268 | "\n",
269 | "        train_list.append(train_flatted)\n",
270 | "        label_list.append(label_flatted)\n",
271 | "\n",
272 | "    return train_list, label_list"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 6,
278 | "metadata": {
279 | "execution": {
280 | "iopub.execute_input": "2021-11-29T03:04:37.661954Z",
281 | "iopub.status.busy": "2021-11-29T03:04:37.660977Z",
282 | "iopub.status.idle": "2021-11-29T03:05:11.515409Z",
283 | "shell.execute_reply": "2021-11-29T03:05:11.515853Z",
284 | "shell.execute_reply.started": "2021-11-27T13:38:33.844185Z"
285 | },
286 | "papermill": {
287 | "duration": 33.912013,
288 | "end_time": "2021-11-29T03:05:11.516001",
289 | "exception": false,
290 | "start_time": "2021-11-29T03:04:37.603988",
291 | "status": "completed"
292 | },
293 | "tags": []
294 | },
295 | "outputs": [
296 | {
297 | "data": {
298 | "text/plain": [
299 | "((1, 1200, 24, 48, 1), (15, 1812, 24, 48, 1), (17, 1680, 24, 48, 1))"
300 | ]
301 | },
302 | "execution_count": 6,
303 | "metadata": {},
304 | "output_type": "execute_result"
305 | }
306 | ],
307 | "source": [
308 | "soda_info = ('soda', 100, 1)\n",
309 | "cmip6_info = ('cmip6', 151, 15)\n",
310 | "cmip5_info = ('cmip5', 140, 17)\n",
311 | "\n",
312 | "soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)\n",
313 | "cmip6_trains, cmip6_labels = make_flatted(cmip_train, cmip_label, cmip6_info)\n",
314 | "cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])\n",
315 | "\n",
316 | "# The flattened data has shape (models, sequence length, lat, lon, features), where sequence length = years * 12\n",
317 | "np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)"
318 | ]
319 | },
320 | {
321 | "cell_type": "markdown",
322 | "metadata": {},
323 | "source": [
324 | "#### Filling missing values\n",
325 | "Fill missing values with 0."
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": 8,
331 | "metadata": {
332 | "execution": {
333 | "iopub.execute_input": "2021-11-29T03:05:11.638562Z",
334 | "iopub.status.busy": "2021-11-29T03:05:11.637553Z",
335 | "iopub.status.idle": "2021-11-29T03:05:11.644302Z",
336 | "shell.execute_reply": "2021-11-29T03:05:11.644742Z",
337 | "shell.execute_reply.started": "2021-11-27T13:39:22.665855Z"
338 | },
339 | "papermill": {
340 | "duration": 0.040786,
341 | "end_time": "2021-11-29T03:05:11.644893",
342 | "exception": false,
343 | "start_time": "2021-11-29T03:05:11.604107",
344 | "status": "completed"
345 | },
346 | "tags": []
347 | },
348 | "outputs": [
349 | {
350 | "name": "stdout",
351 | "output_type": "stream",
352 | "text": [
353 | "Number of null in soda_trains after fillna: 0\n"
354 | ]
355 | }
356 | ],
357 | "source": [
358 | "# Fill the missing values in the SODA data\n",
359 | "soda_trains = np.array(soda_trains)\n",
360 | "soda_trains_nan = np.isnan(soda_trains)\n",
361 | "soda_trains[soda_trains_nan] = 0\n",
362 | "print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": 9,
368 | "metadata": {
369 | "execution": {
370 | "iopub.execute_input": "2021-11-29T03:05:11.709054Z",
371 | "iopub.status.busy": "2021-11-29T03:05:11.707767Z",
372 | "iopub.status.idle": "2021-11-29T03:05:11.862744Z",
373 | "shell.execute_reply": "2021-11-29T03:05:11.863294Z",
374 | "shell.execute_reply.started": "2021-11-27T13:39:24.110039Z"
375 | },
376 | "papermill": {
377 | "duration": 0.18937,
378 | "end_time": "2021-11-29T03:05:11.863480",
379 | "exception": false,
380 | "start_time": "2021-11-29T03:05:11.674110",
381 | "status": "completed"
382 | },
383 | "tags": []
384 | },
385 | "outputs": [
386 | {
387 | "name": "stdout",
388 | "output_type": "stream",
389 | "text": [
390 | "Number of null in cmip6_trains after fillna: 0\n"
391 | ]
392 | }
393 | ],
394 | "source": [
395 | "# Fill the missing values in the CMIP6 data\n",
396 | "cmip6_trains = np.array(cmip6_trains)\n",
397 | "cmip6_trains_nan = np.isnan(cmip6_trains)\n",
398 | "cmip6_trains[cmip6_trains_nan] = 0\n",
399 | "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))"
400 | ]
401 | },
402 | {
403 | "cell_type": "code",
404 | "execution_count": 10,
405 | "metadata": {
406 | "execution": {
407 | "iopub.execute_input": "2021-11-29T03:05:11.927752Z",
408 | "iopub.status.busy": "2021-11-29T03:05:11.925353Z",
409 | "iopub.status.idle": "2021-11-29T03:05:12.091117Z",
410 | "shell.execute_reply": "2021-11-29T03:05:12.091855Z",
411 | "shell.execute_reply.started": "2021-11-27T13:39:24.520724Z"
412 | },
413 | "papermill": {
414 | "duration": 0.197975,
415 | "end_time": "2021-11-29T03:05:12.092014",
416 | "exception": false,
417 | "start_time": "2021-11-29T03:05:11.894039",
418 | "status": "completed"
419 | },
420 | "tags": []
421 | },
422 | "outputs": [
423 | {
424 | "name": "stdout",
425 | "output_type": "stream",
426 | "text": [
427 | "Number of null in cmip5_trains after fillna: 0\n"
428 | ]
429 | }
430 | ],
431 | "source": [
432 | "# Fill the missing values in the CMIP5 data\n",
433 | "cmip5_trains = np.array(cmip5_trains)\n",
434 | "cmip5_trains_nan = np.isnan(cmip5_trains)\n",
435 | "cmip5_trains[cmip5_trains_nan] = 0\n",
436 | "print('Number of null in cmip5_trains after fillna:', np.sum(np.isnan(cmip5_trains)))"
437 | ]
438 | },
439 | {
440 | "cell_type": "markdown",
441 | "metadata": {},
442 | "source": [
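443 | "#### Building the dataset\n",
444 | "Build the training and validation sets. Note that each input sample has a sequence length of 38: the input sst sequence has length 12 and the output sst sequence has length 26. Because training uses teacher forcing (explained in detail when the model is built below), each input sample also carries the ground-truth values of the output sst sequence."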
445 | ]
446 | },
447 | {
448 | "cell_type": "code",
449 | "execution_count": 11,
450 | "metadata": {
451 | "execution": {
452 | "iopub.execute_input": "2021-11-29T03:05:12.165242Z",
453 | "iopub.status.busy": "2021-11-29T03:05:12.164045Z",
454 | "iopub.status.idle": "2021-11-29T03:05:12.480257Z",
455 | "shell.execute_reply": "2021-11-29T03:05:12.479767Z",
456 | "shell.execute_reply.started": "2021-11-27T13:39:25.418945Z"
457 | },
458 | "papermill": {
459 | "duration": 0.361254,
460 | "end_time": "2021-11-29T03:05:12.480405",
461 | "exception": false,
462 | "start_time": "2021-11-29T03:05:12.119151",
463 | "status": "completed"
464 | },
465 | "tags": []
466 | },
467 | "outputs": [],
468 | "source": [
469 | "# Build the training set\n",
470 | "\n",
471 | "X_train = []\n",
472 | "y_train = []\n",
473 | "# Draw 100 samples from each of the 17 CMIP5 models\n",
474 | "for model_i in range(17):\n",
475 | "    samples = np.random.choice(cmip5_trains.shape[1]-38, size=100)\n",
476 | "    for ind in samples:\n",
477 | "        X_train.append(cmip5_trains[model_i, ind: ind+38])\n",
478 | "        y_train.append(cmip5_labels[model_i][ind: ind+24])\n",
479 | "# Draw 100 samples from each of the 15 CMIP6 models\n",
480 | "for model_i in range(15):\n",
481 | "    samples = np.random.choice(cmip6_trains.shape[1]-38, size=100)\n",
482 | "    for ind in samples:\n",
483 | "        X_train.append(cmip6_trains[model_i, ind: ind+38])\n",
484 | "        y_train.append(cmip6_labels[model_i][ind: ind+24])\n",
485 | "X_train = np.array(X_train)\n",
486 | "y_train = np.array(y_train)"
487 | ]
488 | },
489 | {
490 | "cell_type": "code",
491 | "execution_count": 12,
492 | "metadata": {
493 | "execution": {
494 | "iopub.execute_input": "2021-11-29T03:05:12.541232Z",
495 | "iopub.status.busy": "2021-11-29T03:05:12.540676Z",
496 | "iopub.status.idle": "2021-11-29T03:05:12.548103Z",
497 | "shell.execute_reply": "2021-11-29T03:05:12.547520Z",
498 | "shell.execute_reply.started": "2021-11-27T13:39:26.341849Z"
499 | },
500 | "papermill": {
501 | "duration": 0.040262,
502 | "end_time": "2021-11-29T03:05:12.548224",
503 | "exception": false,
504 | "start_time": "2021-11-29T03:05:12.507962",
505 | "status": "completed"
506 | },
507 | "tags": []
508 | },
509 | "outputs": [],
510 | "source": [
511 | "# Build the validation set (100 samples drawn from SODA)\n",
512 | "\n",
513 | "X_valid = []\n",
514 | "y_valid = []\n",
515 | "samples = np.random.choice(soda_trains.shape[1]-38, size=100)\n",
516 | "for ind in samples:\n",
517 | "    X_valid.append(soda_trains[0, ind: ind+38])\n",
518 | "    y_valid.append(soda_labels[0][ind: ind+24])\n",
519 | "X_valid = np.array(X_valid)\n",
520 | "y_valid = np.array(y_valid)"
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": 13,
526 | "metadata": {
527 | "execution": {
528 | "iopub.execute_input": "2021-11-29T03:05:12.606407Z",
529 | "iopub.status.busy": "2021-11-29T03:05:12.605555Z",
530 | "iopub.status.idle": "2021-11-29T03:05:12.611580Z",
531 | "shell.execute_reply": "2021-11-29T03:05:12.611152Z",
532 | "shell.execute_reply.started": "2021-11-27T13:39:27.247585Z"
533 | },
534 | "papermill": {
535 | "duration": 0.036214,
536 | "end_time": "2021-11-29T03:05:12.611721",
537 | "exception": false,
538 | "start_time": "2021-11-29T03:05:12.575507",
539 | "status": "completed"
540 | },
541 | "tags": []
542 | },
543 | "outputs": [
544 | {
545 | "data": {
546 | "text/plain": [
547 | "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))"
548 | ]
549 | },
550 | "execution_count": 13,
551 | "metadata": {},
552 | "output_type": "execute_result"
553 | }
554 | ],
555 | "source": [
556 | "# Check the dataset shapes\n",
557 | "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": 15,
563 | "metadata": {
564 | "execution": {
565 | "iopub.execute_input": "2021-11-29T03:05:12.737322Z",
566 | "iopub.status.busy": "2021-11-29T03:05:12.736558Z",
567 | "iopub.status.idle": "2021-11-29T03:05:13.516712Z",
568 | "shell.execute_reply": "2021-11-29T03:05:13.517217Z",
569 | "shell.execute_reply.started": "2021-11-27T13:39:38.421657Z"
570 | },
571 | "papermill": {
572 | "duration": 0.812187,
573 | "end_time": "2021-11-29T03:05:13.517368",
574 | "exception": false,
575 | "start_time": "2021-11-29T03:05:12.705181",
576 | "status": "completed"
577 | },
578 | "tags": []
579 | },
580 | "outputs": [],
581 | "source": [
582 | "# Save the datasets\n",
583 | "np.save('X_train_sample.npy', X_train)\n",
584 | "np.save('y_train_sample.npy', y_train)\n",
585 | "np.save('X_valid_sample.npy', X_valid)\n",
586 | "np.save('y_valid_sample.npy', y_valid)"
587 | ]
588 | },
589 | {
590 | "cell_type": "markdown",
591 | "metadata": {},
592 | "source": [
593 | "### Building the model"
594 | ]
595 | },
596 | {
597 | "cell_type": "code",
598 | "execution_count": 16,
599 | "metadata": {
600 | "execution": {
601 | "iopub.execute_input": "2021-11-29T03:05:13.577516Z",
602 | "iopub.status.busy": "2021-11-29T03:05:13.576992Z",
603 | "iopub.status.idle": "2021-11-29T03:05:21.917657Z",
604 | "shell.execute_reply": "2021-11-29T03:05:21.918265Z",
605 | "shell.execute_reply.started": "2021-11-29T01:03:01.505192Z"
606 | },
607 | "papermill": {
608 | "duration": 8.372964,
609 | "end_time": "2021-11-29T03:05:21.918443",
610 | "exception": false,
611 | "start_time": "2021-11-29T03:05:13.545479",
612 | "status": "completed"
613 | },
614 | "tags": []
615 | },
616 | "outputs": [],
617 | "source": [
618 | "# Load the datasets\n",
619 | "X_train = np.load('../input/ai-earth-task05-samples/X_train_sample.npy')\n",
620 | "y_train = np.load('../input/ai-earth-task05-samples/y_train_sample.npy')\n",
621 | "X_valid = np.load('../input/ai-earth-task05-samples/X_valid_sample.npy')\n",
622 | "y_valid = np.load('../input/ai-earth-task05-samples/y_valid_sample.npy')"
623 | ]
624 | },
625 | {
626 | "cell_type": "code",
627 | "execution_count": 17,
628 | "metadata": {
629 | "execution": {
630 | "iopub.execute_input": "2021-11-29T03:05:21.983898Z",
631 | "iopub.status.busy": "2021-11-29T03:05:21.982953Z",
632 | "iopub.status.idle": "2021-11-29T03:05:21.986939Z",
633 | "shell.execute_reply": "2021-11-29T03:05:21.986453Z",
634 | "shell.execute_reply.started": "2021-11-29T01:03:11.548945Z"
635 | },
636 | "papermill": {
637 | "duration": 0.039398,
638 | "end_time": "2021-11-29T03:05:21.987066",
639 | "exception": false,
640 | "start_time": "2021-11-29T03:05:21.947668",
641 | "status": "completed"
642 | },
643 | "tags": []
644 | },
645 | "outputs": [
646 | {
647 | "data": {
648 | "text/plain": [
649 | "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))"
650 | ]
651 | },
652 | "execution_count": 17,
653 | "metadata": {},
654 | "output_type": "execute_result"
655 | }
656 | ],
657 | "source": [
658 | "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
659 | ]
660 | },
661 | {
662 | "cell_type": "code",
663 | "execution_count": 18,
664 | "metadata": {
665 | "execution": {
666 | "iopub.execute_input": "2021-11-29T03:05:22.341929Z",
667 | "iopub.status.busy": "2021-11-29T03:05:22.340932Z",
668 | "iopub.status.idle": "2021-11-29T03:05:22.346140Z",
669 | "shell.execute_reply": "2021-11-29T03:05:22.346878Z",
670 | "shell.execute_reply.started": "2021-11-29T01:03:11.560457Z"
671 | },
672 | "papermill": {
673 | "duration": 0.143838,
674 | "end_time": "2021-11-29T03:05:22.347113",
675 | "exception": false,
676 | "start_time": "2021-11-29T03:05:22.203275",
677 | "status": "completed"
678 | },
679 | "tags": []
680 | },
681 | "outputs": [],
682 | "source": [
683 | "# Build the data pipeline\n",
684 | "class AIEarthDataset(Dataset):\n",
685 | "    def __init__(self, data, label):\n",
686 | "        self.data = torch.tensor(data, dtype=torch.float32)\n",
687 | "        self.label = torch.tensor(label, dtype=torch.float32)\n",
688 | "\n",
689 | "    def __len__(self):\n",
690 | "        return len(self.label)\n",
691 | "\n",
692 | "    def __getitem__(self, idx):\n",
693 | "        return self.data[idx], self.label[idx]"
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "execution_count": 19,
699 | "metadata": {
700 | "execution": {
701 | "iopub.execute_input": "2021-11-29T03:05:22.583350Z",
702 | "iopub.status.busy": "2021-11-29T03:05:22.582298Z",
703 | "iopub.status.idle": "2021-11-29T03:05:23.243100Z",
704 | "shell.execute_reply": "2021-11-29T03:05:23.243851Z",
705 | "shell.execute_reply.started": "2021-11-29T01:03:23.691846Z"
706 | },
707 | "papermill": {
708 | "duration": 0.825537,
709 | "end_time": "2021-11-29T03:05:23.244098",
710 | "exception": false,
711 | "start_time": "2021-11-29T03:05:22.418561",
712 | "status": "completed"
713 | },
714 | "tags": []
715 | },
716 | "outputs": [],
717 | "source": [
718 | "batch_size = 2\n",
719 | "\n",
720 | "trainset = AIEarthDataset(X_train, y_train)\n",
721 | "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
722 | "\n",
723 | "validset = AIEarthDataset(X_valid, y_valid)\n",
724 | "validloader = DataLoader(validset, batch_size=batch_size, shuffle=True)"
725 | ]
726 | },
727 | {
728 | "cell_type": "markdown",
729 | "metadata": {},
730 | "source": [
731 | "#### Evaluation function"
732 | ]
733 | },
734 | {
735 | "cell_type": "code",
736 | "execution_count": 24,
737 | "metadata": {
738 | "execution": {
739 | "iopub.execute_input": "2021-11-29T03:05:23.655820Z",
740 | "iopub.status.busy": "2021-11-29T03:05:23.655241Z",
741 | "iopub.status.idle": "2021-11-29T03:05:23.658416Z",
742 | "shell.execute_reply": "2021-11-29T03:05:23.658859Z",
743 | "shell.execute_reply.started": "2021-11-29T01:03:26.481561Z"
744 | },
745 | "papermill": {
746 | "duration": 0.040887,
747 | "end_time": "2021-11-29T03:05:23.658990",
748 | "exception": false,
749 | "start_time": "2021-11-29T03:05:23.618103",
750 | "status": "completed"
751 | },
752 | "tags": []
753 | },
754 | "outputs": [],
755 | "source": [
756 | "def rmse(y_true, y_preds):\n",
757 | "    return np.sqrt(mean_squared_error(y_pred = y_preds, y_true = y_true))\n",
758 | "\n",
759 | "# Evaluation function\n",
760 | "def score(y_true, y_preds):\n",
761 | "    # Correlation skill score\n",
762 | "    accskill_score = 0\n",
763 | "    # RMSE accumulated over the 24 lead months\n",
764 | "    rmse_scores = 0\n",
765 | "    a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6\n",
766 | "    y_true_mean = np.mean(y_true, axis=0)\n",
767 | "    y_pred_mean = np.mean(y_preds, axis=0)\n",
768 | "    for i in range(24):\n",
769 | "        numerator = np.sum((y_true[:, i] - y_true_mean[i]) * (y_preds[:, i] - y_pred_mean[i]))\n",
770 | "        denominator = np.sqrt(np.sum((y_true[:, i] - y_true_mean[i])**2) * np.sum((y_preds[:, i] - y_pred_mean[i])**2))\n",
771 | "        cor_i = numerator / denominator\n",
772 | "        accskill_score += a[i] * np.log(i+1) * cor_i\n",
773 | "        rmse_score = rmse(y_true[:, i], y_preds[:, i])\n",
774 | "        rmse_scores += rmse_score\n",
775 | "    return 2/3.0 * accskill_score - rmse_scores"
776 | ]
777 | },
778 | {
779 | "cell_type": "markdown",
780 | "metadata": {
781 | "papermill": {
782 | "duration": 0.028556,
783 | "end_time": "2021-11-29T03:05:23.310560",
784 | "exception": false,
785 | "start_time": "2021-11-29T03:05:23.282004",
786 | "status": "completed"
787 | },
788 | "tags": []
789 | },
790 | "source": [
791 | "#### Model construction"
792 | ]
793 | },
794 | {
795 | "cell_type": "markdown",
796 | "metadata": {},
797 | "source": [
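798 | "Unlike the multi-output neural networks built in the previous two top solutions, this top solution uses a Seq2Seq structure. For this competition the input sequence length is 12 and the output sequence length is 26, and the solution stacks four hidden layers, so a basic Seq2Seq structure looks like the figure below:\n",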
799 | "\n",
800 | "[Figure: basic Seq2Seq structure (Task5-Seq2Seq基础结构.png)]\n",
801 | "\n",
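802 | "To apply the Seq2Seq structure to different problems, the key choice is which cell (recurrent unit) to use. The cell used in this top solution is SA-ConvLSTM (Self-Attention ConvLSTM), proposed by Tsinghua University; see the paper at https://ojs.aaai.org//index.php/AAAI/article/view/6819\n",
803 | "\n",
804 | "SA-ConvLSTM improves on ConvLSTM, the classic spatiotemporal sequence model proposed by Dr. Xingjian Shi. To capture temporal dependencies in spatial information, it adds a SAM module on top of ConvLSTM to memorize aggregated spatial features. See the ConvLSTM paper at https://arxiv.org/pdf/1506.04214.pdf\n",
805 | "\n",
806 | "1. The ConvLSTM model\n",
807 | "\n",
808 | "LSTM is a classic sequence model. Its three-gate structure performs well on tasks that require mining long-term temporal dependencies, and compared with a vanilla RNN it effectively mitigates the vanishing-gradient problem. For a single input sample, at each time step every LSTM gate is really a fully connected layer applied to the input vector. In this competition the input X has shape (N,T,H,W,C), so what a single sample feeds into the LSTM at each time step is spatial information of shape (H,W,C). Fully connected layers are not good at extracting features from this kind of spatial information, whereas convolutions greatly reduce the parameter count and, by stacking layers, progressively extract more complex features. It is therefore natural to replace the fully connected operations in LSTM with convolutions so that the model suits spatiotemporal sequence problems. This is exactly what ConvLSTM does, and practice has shown the approach to be very effective.\n",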
809 | "\n",
810 | "[Figure: comparison of the LSTM and ConvLSTM formulas (Task5-LSTM与ConvLSTM公式比较.png)]\n",
811 | "\n",
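812 | "2. The SAM module\n",
813 | "\n",
814 | "However, ConvLSTM has two problems:\n",
815 | "\n",
816 | "First, the receptive field of a convolutional layer is limited by the kernel size, so several convolutional layers must be stacked to enlarge the receptive field and discover global features. For example, if the first convolutional layer has a 3×3 kernel, each node in that layer only sees the input within a 3×3 spatial window. Adding a second 3×3 convolutional layer lets each node see 3×3 first-layer nodes, which, with stride 1 in the first layer, covers a 5×5 window of the input, so the second layer sees a larger spatial range than the first. The cost is a larger parameter count. For a plain CNN, adding a layer only adds one kernel's worth of parameters, but for ConvLSTM this becomes a heavy burden: the extra parameters raise the risk of overfitting while the gain in model performance is small.\n",
817 | "\n",
818 | "Second, the convolution only operates on the spatial information fed in at the current time step and ignores past spatial information, which makes it hard to mine temporal dependencies in the spatial information.\n",
819 | "\n",
820 | "Therefore, to capture both global and local spatial dependencies and to improve prediction on spatiotemporal sequence tasks that cover large spatial ranges and long time spans, SA-ConvLSTM introduces the SAM (self-attention memory) module on top of ConvLSTM.\n",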
821 | "\n",
822 | "[Figure: the SAM module (Task5-SAM模块.png)]\n",
823 | "\n",
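824 | "The SAM module introduces a new memory unit M that memorizes spatial information containing temporal dependencies. It takes as input the hidden state $H_t$ produced by ConvLSTM at the current time step and the memory $M_{t-1}$ of the previous time step. First, $H_t$ passes through self-attention to produce the feature $Z_h$; self-attention increases the weight of the parts of $H_t$ that are more strongly related to its other parts. At the same time, $H_t$ serves as the Query and attends over $M_{t-1}$ to produce the feature $Z_m$, which strengthens the weight of the parts of $M_{t-1}$ that $H_t$ depends on more strongly. Concatenating $Z_h$ and $Z_m$ gives their aggregated feature $Z$, which contains both the information of the current time step and the global spatiotemporal memory. Next, borrowing the gating structure of LSTM, $Z$ is used to update the hidden state and the memory unit, which yields the updated hidden state $\\hat{H_t}$ and the memory $M_t$ of the current time step. The SAM equations are:\n",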
825 | "\n",
826 | "$$\n",
827 | "\\begin{aligned}\n",
828 | "& i'_t = \\sigma (W_{m;zi} \\ast Z + W_{m;hi} \\ast H_t + b_{m;i}) \\\\\n",
829 | "& g'_t = \\tanh (W_{m;zg} \\ast Z + W_{m;hg} \\ast H_t + b_{m;g}) \\\\\n",
830 | "& M_t = (1 - i'_t) \\circ M_{t-1} + i'_t \\circ g'_t \\\\\n",
831 | "& o'_t = \\sigma (W_{m;zo} \\ast Z + W_{m;ho} \\ast H_t + b_{m;o}) \\\\\n",
832 | "& \\hat{H_t} = o'_t \\circ M_t\n",
833 | "\\end{aligned}\n",
834 | "$$\n",
835 | "\n",
836 | "For background on attention and self-attention, see the following links:\n",
837 | "\n",
838 | " - Attention mechanisms in deep learning: https://blog.csdn.net/malefactor/article/details/78767781\n",
839 | " - Mainstream attention methods today: https://www.zhihu.com/question/68482809\n",
840 | "\n",
841 | "3. The SA-ConvLSTM model\n",
842 | "\n",
843 | "Combining the two gives the SA-ConvLSTM model:\n",
844 | "\n",
845 | "[Figure: the SA-ConvLSTM model (Task5-SA-ConvLSTM模型.png)]"
846 | ]
847 | },
848 | {
849 | "cell_type": "code",
850 | "execution_count": 20,
851 | "metadata": {
852 | "execution": {
853 | "iopub.execute_input": "2021-11-29T03:05:23.372772Z",
854 | "iopub.status.busy": "2021-11-29T03:05:23.371873Z",
855 | "iopub.status.idle": "2021-11-29T03:05:23.373700Z",
856 | "shell.execute_reply": "2021-11-29T03:05:23.374122Z",
857 | "shell.execute_reply.started": "2021-11-29T01:03:24.585147Z"
858 | },
859 | "papermill": {
860 | "duration": 0.035787,
861 | "end_time": "2021-11-29T03:05:23.374254",
862 | "exception": false,
863 | "start_time": "2021-11-29T03:05:23.338467",
864 | "status": "completed"
865 | },
866 | "tags": []
867 | },
868 | "outputs": [],
869 | "source": [
870 | "# Attention mechanism\n",
871 | "def attn(query, key, value):\n",
872 | "    # query, key and value all have shape (N, C, H*W); let S = H*W\n",
873 | "    # Scaled dot-product scores: scores(i) = key(i)^T query / sqrt(C)\n",
874 | "    scores = torch.matmul(query.transpose(1, 2), key / math.sqrt(query.size(1))) # (N, S, S)\n",
875 | "    # Normalize the scores into attention weights\n",
876 | "    weights = F.softmax(scores, dim=-1)\n",
877 | "    output = torch.matmul(weights, value.transpose(1, 2)) # (N, S, C)\n",
878 | "    return output.transpose(1, 2) # (N, C, S)"
879 | ]
880 | },
881 | {
882 | "cell_type": "code",
883 | "execution_count": 21,
884 | "metadata": {
885 | "execution": {
886 | "iopub.execute_input": "2021-11-29T03:05:23.440765Z",
887 | "iopub.status.busy": "2021-11-29T03:05:23.440042Z",
888 | "iopub.status.idle": "2021-11-29T03:05:23.442191Z",
889 | "shell.execute_reply": "2021-11-29T03:05:23.442569Z",
890 | "shell.execute_reply.started": "2021-11-29T01:03:25.147999Z"
891 | },
892 | "papermill": {
893 | "duration": 0.041095,
894 | "end_time": "2021-11-29T03:05:23.442725",
895 | "exception": false,
896 | "start_time": "2021-11-29T03:05:23.401630",
897 | "status": "completed"
898 | },
899 | "tags": []
900 | },
901 | "outputs": [],
902 | "source": [
903 | "# SAM module\n",
904 | "class SAAttnMem(nn.Module):\n",
905 | "    def __init__(self, input_dim, d_model, kernel_size):\n",
906 | "        super().__init__()\n",
907 | "        pad = kernel_size[0] // 2, kernel_size[1] // 2\n",
908 | "        self.d_model = d_model\n",
909 | "        self.input_dim = input_dim\n",
910 | "        # 1x1 convolution implementing the fully connected operation Wh*Ht\n",
911 | "        self.conv_h = nn.Conv2d(input_dim, d_model*3, kernel_size=1)\n",
912 | "        # 1x1 convolution implementing the fully connected operation Wm*Mt-1\n",
913 | "        self.conv_m = nn.Conv2d(input_dim, d_model*2, kernel_size=1)\n",
914 | "        # 1x1 convolution implementing the fully connected operation Wz*[Zh,Zm]\n",
915 | "        self.conv_z = nn.Conv2d(d_model*2, d_model, kernel_size=1)\n",
916 | "        # Note that the output dimension must match the input dimension; both are input_dim\n",
917 | "        self.conv_output = nn.Conv2d(input_dim+d_model, input_dim*3, kernel_size=kernel_size, padding=pad)\n",
918 | "\n",
919 | "    def forward(self, h, m):\n",
920 | "        # self.conv_h(h) gives Wh*Ht; split it along dim=1 into chunks of size self.d_model, each of shape (N, d_model, H, W); the three chunks are Qh, Kh and Vh\n",
921 | "        hq, hk, hv = torch.split(self.conv_h(h), self.d_model, dim=1)\n",
922 | "        # Obtain Km and Vm in the same way\n",
923 | "        mk, mv = torch.split(self.conv_m(m), self.d_model, dim=1)\n",
924 | "        N, C, H, W = hq.size()\n",
925 | "        # Self-attention over Ht gives Zh\n",
926 | "        Zh = attn(hq.view(N, C, -1), hk.view(N, C, -1), hv.view(N, C, -1)) # (N, C, S), C=d_model\n",
927 | "        # Attention of Ht over Mt-1 gives Zm\n",
928 | "        Zm = attn(hq.view(N, C, -1), mk.view(N, C, -1), mv.view(N, C, -1)) # (N, C, S), C=d_model\n",
929 | "        # Concatenate Zh and Zm and apply the fully connected operation to get the aggregated feature Z\n",
930 | "        Z = self.conv_z(torch.cat([Zh.view(N, C, H, W), Zm.view(N, C, H, W)], dim=1)) # (N, C, H, W), C=d_model\n",
931 | "        # Compute i't, g't and o't\n",
932 | "        i, g, o = torch.split(self.conv_output(torch.cat([Z, h], dim=1)), self.input_dim, dim=1) # (N, C, H, W), C=input_dim\n",
933 | "        i = torch.sigmoid(i)\n",
934 | "        g = torch.tanh(g)\n",
935 | "        # Updated memory unit Mt\n",
936 | "        m_next = i * g + (1 - i) * m\n",
937 | "        # Updated hidden state Ht\n",
938 | "        h_next = torch.sigmoid(o) * m_next\n",
939 | "        return h_next, m_next"
940 | ]
941 | },
942 | {
943 | "cell_type": "code",
944 | "execution_count": 22,
945 | "metadata": {
946 | "execution": {
947 | "iopub.execute_input": "2021-11-29T03:05:23.509738Z",
948 | "iopub.status.busy": "2021-11-29T03:05:23.509080Z",
949 | "iopub.status.idle": "2021-11-29T03:05:23.512667Z",
950 | "shell.execute_reply": "2021-11-29T03:05:23.512182Z",
951 | "shell.execute_reply.started": "2021-11-29T01:03:25.667808Z"
952 | },
953 | "papermill": {
954 | "duration": 0.042616,
955 | "end_time": "2021-11-29T03:05:23.512781",
956 | "exception": false,
957 | "start_time": "2021-11-29T03:05:23.470165",
958 | "status": "completed"
959 | },
960 | "tags": []
961 | },
962 | "outputs": [],
963 | "source": [
964 | "# SA-ConvLSTM Cell\n",
965 | "class SAConvLSTMCell(nn.Module):\n",
966 | "    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n",
967 | "        super().__init__()\n",
968 | "        self.input_dim = input_dim\n",
969 | "        self.hidden_dim = hidden_dim\n",
970 | "        pad = kernel_size[0] // 2, kernel_size[1] // 2\n",
971 | "        # Convolution Wx*Xt + Wh*Ht-1\n",
972 | "        self.conv = nn.Conv2d(in_channels=input_dim+hidden_dim, out_channels=4*hidden_dim, kernel_size=kernel_size, padding=pad)\n",
973 | "        self.sa = SAAttnMem(input_dim=hidden_dim, d_model=d_attn, kernel_size=kernel_size)\n",
974 | "\n",
975 | "    def initialize(self, inputs):\n",
976 | "        device = inputs.device\n",
977 | "        N, _, H, W = inputs.size()\n",
978 | "        # Initialize the hidden state Ht\n",
979 | "        self.hidden_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n",
980 | "        # Initialize the LSTM cell state ct\n",
981 | "        self.cell_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n",
982 | "        # Initialize the SAM memory state Mt\n",
983 | "        self.memory_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n",
984 | "\n",
985 | "    def forward(self, inputs, first_step=False):\n",
986 | "        # At the first time step, initialize Ht, ct and Mt\n",
987 | "        if first_step:\n",
988 | "            self.initialize(inputs)\n",
989 | "\n",
990 | "        # ConvLSTM part\n",
991 | "        # Concatenate Xt with the previous hidden state Ht-1\n",
992 | "        combined = torch.cat([inputs, self.hidden_state], dim=1) # (N, C, H, W), C=input_dim+hidden_dim\n",
993 | "        # Apply the convolution\n",
994 | "        combined_conv = self.conv(combined)\n",
995 | "        # Split into the four gates it, ft, ot and gt\n",
996 | "        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)\n",
997 | "        i = torch.sigmoid(cc_i)\n",
998 | "        f = torch.sigmoid(cc_f)\n",
999 | "        o = torch.sigmoid(cc_o)\n",
1000 | "        g = torch.tanh(cc_g)\n",
1001 | "        # Cell state of the current time step: ct = ft*ct-1 + it*gt\n",
1002 | "        self.cell_state = f * self.cell_state + i * g\n",
1003 | "        # Hidden state of the current time step: Ht = ot*tanh(ct)\n",
1004 | "        self.hidden_state = o * torch.tanh(self.cell_state)\n",
1005 | "\n",
1006 | "        # SAM part: update Ht and Mt\n",
1007 | "        self.hidden_state, self.memory_state = self.sa(self.hidden_state, self.memory_state)\n",
1008 | "\n",
1009 | "        return self.hidden_state"
1010 | ]
1011 | },
1012 | {
1013 | "cell_type": "markdown",
1014 | "metadata": {},
1015 | "source": [
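1016 | "There are two training modes for Seq2Seq models. The first is free running, the traditional approach, which feeds the output of the previous time step $\\hat{y_{t-1}}$ as the input of the next time step. The problem is that early in training $\\hat{y_{t-1}}$ is far from the ground-truth label $y_{t-1}$, and using it as input pushes later outputs further and further from the desired predictions. This motivates the second training mode, teacher forcing.\n",
1017 | "\n",
1018 | "Teacher forcing feeds the ground-truth label $y_{t-1}$ directly as the input of the next time step, so the teacher (the ground truth) keeps the model from drifting. But a teacher cannot hold the student's hand forever and must gradually let the student learn independently, so we use scheduled sampling to control the probability of using the ground-truth label. Let ratio denote the scheduled sampling ratio. At the start of training ratio=1 and the model is fully guided by the teacher. As the number of training epochs grows, ratio decays (this solution uses linear decay: ratio decreases by decay_rate each time). At every time step a binary random number, 0 or 1, is drawn from a Bernoulli distribution with success probability ratio; when it is 1 the input is the ground-truth label $y_{t-1}$, otherwise the input is $\\hat{y_{t-1}}$."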
1019 | ]
1020 | },
1021 | {
1022 | "cell_type": "code",
1023 | "execution_count": 23,
1024 | "metadata": {
1025 | "execution": {
1026 | "iopub.execute_input": "2021-11-29T03:05:23.587567Z",
1027 | "iopub.status.busy": "2021-11-29T03:05:23.586781Z",
1028 | "iopub.status.idle": "2021-11-29T03:05:23.588776Z",
1029 | "shell.execute_reply": "2021-11-29T03:05:23.589156Z",
1030 | "shell.execute_reply.started": "2021-11-29T01:03:26.065997Z"
1031 | },
1032 | "papermill": {
1033 | "duration": 0.047514,
1034 | "end_time": "2021-11-29T03:05:23.589277",
1035 | "exception": false,
1036 | "start_time": "2021-11-29T03:05:23.541763",
1037 | "status": "completed"
1038 | },
1039 | "tags": []
1040 | },
1041 | "outputs": [],
1042 | "source": [
1043 | "# Build the SA-ConvLSTM model\n",
1044 | "class SAConvLSTM(nn.Module):\n",
1045 | "    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n",
1046 | "        super().__init__()\n",
1047 | "        self.input_dim = input_dim\n",
1048 | "        self.hidden_dim = hidden_dim\n",
1049 | "        self.num_layers = len(hidden_dim)\n",
1050 | "\n",
1051 | "        layers = []\n",
1052 | "        for i in range(self.num_layers):\n",
1053 | "            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]\n",
1054 | "            layers.append(SAConvLSTMCell(input_dim=cur_input_dim, hidden_dim=self.hidden_dim[i], d_attn = d_attn, kernel_size=kernel_size))\n",
1055 | "        self.layers = nn.ModuleList(layers)\n",
1056 | "\n",
1057 | "        self.conv_output = nn.Conv2d(self.hidden_dim[-1], 1, kernel_size=1)\n",
1058 | "\n",
1059 | "    def forward(self, input_x, device=torch.device('cuda:0'), input_frames=12, future_frames=26, output_frames=37, teacher_forcing=False, scheduled_sampling_ratio=0, train=True):\n",
1060 | "        # Convert the input X from shape (N, T, H, W, C) to (N, T, C, H, W)\n",
1061 | "        input_x = input_x.permute(0, 1, 4, 2, 3).contiguous()\n",
1062 | "\n",
1063 | "        # Teacher forcing is used only during training\n",
1064 | "        if train:\n",
1065 | "            if teacher_forcing and scheduled_sampling_ratio > 1e-6:\n",
1066 | "                teacher_forcing_mask = torch.bernoulli(scheduled_sampling_ratio * torch.ones(input_x.size(0), future_frames-1, 1, 1, 1))\n",
1067 | "            else:\n",
1068 | "                teacher_forcing = False\n",
1069 | "        else:\n",
1070 | "            teacher_forcing = False\n",
1071 | "\n",
1072 | "        total_steps = input_frames + future_frames - 1\n",
1073 | "        outputs = [None] * total_steps\n",
1074 | "\n",
1075 | "        # Loop over the time steps\n",
1076 | "        for t in range(total_steps):\n",
1077 | "            # During the first 12 months, use the input sample Xt of each month\n",
1078 | "            if t < input_frames:\n",
1079 | "                input_ = input_x[:, t].to(device)\n",
1080 | "            # Without teacher forcing, feed the prediction of the previous time step as the current input\n",
1081 | "            elif not teacher_forcing:\n",
1082 | "                input_ = outputs[t-1]\n",
1083 | "            # With teacher forcing, feed the ground truth with probability ratio, otherwise the previous prediction\n",
1084 | "            else:\n",
1085 | "                mask = teacher_forcing_mask[:, t-input_frames].float().to(device)\n",
1086 | "                input_ = input_x[:, t].to(device) * mask + outputs[t-1] * (1-mask)\n",
1087 | "            first_step = (t==0)\n",
1088 | "            input_ = input_.float()\n",
1089 | "\n",
1090 | "            # Pass the input of the current time step through the hidden layers\n",
1091 | "            for layer_idx in range(self.num_layers):\n",
1092 | "                input_ = self.layers[layer_idx](input_, first_step=first_step)\n",
1093 | "\n",
1094 | "            # Record the output of each time step (all steps in training, only the future steps otherwise)\n",
1095 | "            if train or (t >= (input_frames - 1)):\n",
1096 | "                outputs[t] = self.conv_output(input_)\n",
1097 | "\n",
1098 | "        outputs = [x for x in outputs if x is not None]\n",
1099 | "\n",
1100 | "        # Check the length of the output sequence\n",
1101 | "        if train:\n",
1102 | "            assert len(outputs) == output_frames\n",
1103 | "        else:\n",
1104 | "            assert len(outputs) == future_frames\n",
1105 | "\n",
1106 | "        # Predicted sst sequence\n",
1107 | "        outputs = torch.stack(outputs, dim=1)[:, :, 0] # (N, 37, H, W) in training, (N, 26, H, W) otherwise\n",
1108 | "        # Average the predicted sst over the Nino3.4 region, then take a 3-month moving average to get the predicted Nino3.4 index sequence\n",
1109 | "        nino_pred = outputs[:, -future_frames:, 10:13, 19:30].mean(dim=[2, 3]) # (N, 26)\n",
1110 | "        nino_pred = nino_pred.unfold(dimension=1, size=3, step=1).mean(dim=2) # (N, 24)\n",
1111 | "\n",
1112 | "        return nino_pred"
1113 | ]
1114 | },
1115 | {
1116 | "cell_type": "code",
1117 | "execution_count": 25,
1118 | "metadata": {
1119 | "execution": {
1120 | "iopub.execute_input": "2021-11-29T03:05:23.726291Z",
1121 | "iopub.status.busy": "2021-11-29T03:05:23.725688Z",
1122 | "iopub.status.idle": "2021-11-29T03:05:23.753509Z",
1123 | "shell.execute_reply": "2021-11-29T03:05:23.753976Z",
1124 | "shell.execute_reply.started": "2021-11-29T01:03:29.448921Z"
1125 | },
1126 | "papermill": {
1127 | "duration": 0.066105,
1128 | "end_time": "2021-11-29T03:05:23.754109",
1129 | "exception": false,
1130 | "start_time": "2021-11-29T03:05:23.688004",
1131 | "status": "completed"
1132 | },
1133 | "tags": []
1134 | },
1135 | "outputs": [
1136 | {
1137 | "name": "stdout",
1138 | "output_type": "stream",
1139 | "text": [
1140 | "SAConvLSTM(\n",
1141 | "  (layers): ModuleList(\n",
1142 | "    (0): SAConvLSTMCell(\n",
1143 | "      (conv): Conv2d(65, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1144 | "      (sa): SAAttnMem(\n",
1145 | "        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
1146 | "        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
1147 | "        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
1148 | "        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1149 | "      )\n",
1150 | "    )\n",
1151 | "    (1): SAConvLSTMCell(\n",
1152 | "      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1153 | "      (sa): SAAttnMem(\n",
1154 | "        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
1155 | "        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
1156 | "        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
1157 | "        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1158 | "      )\n",
1159 | "    )\n",
1160 | "    (2): SAConvLSTMCell(\n",
1161 | "      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1162 | "      (sa): SAAttnMem(\n",
1163 | "        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
1164 | "        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
1165 | "        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
1166 | "        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1167 | "      )\n",
1168 | "    )\n",
1169 | "    (3): SAConvLSTMCell(\n",
1170 | "      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1171 | "      (sa): SAAttnMem(\n",
1172 | "        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
1173 | "        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
1174 | "        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
1175 | "        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
1176 | "      )\n",
1177 | "    )\n",
1178 | "  )\n",
1179 | "  (conv_output): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))\n",
1180 | ")\n"
1181 | ]
1182 | }
1183 | ],
1184 | "source": [
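1185 | "# Number of input features (channels)\n",
1186 | "input_dim = 1\n",
1187 | "# Number of hidden channels in each layer\n",
1188 | "hidden_dim = (64, 64, 64, 64)\n",
1189 | "# Number of channels used by the attention mechanism\n",
1190 | "d_attn = 32\n",
1191 | "# Convolution kernel size\n",
1192 | "kernel_size = (3, 3)\n",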
1193 | "\n",
1194 | "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n",
1195 | "print(model)"
1196 | ]
1197 | },
1198 | {
1199 | "cell_type": "markdown",
1200 | "metadata": {},
1201 | "source": [
1202 | "#### Model training"
1203 | ]
1204 | },
1205 | {
1206 | "cell_type": "code",
1207 | "execution_count": 26,
1208 | "metadata": {
1209 | "execution": {
1210 | "iopub.execute_input": "2021-11-29T03:05:23.816671Z",
1211 | "iopub.status.busy": "2021-11-29T03:05:23.815927Z",
1212 | "iopub.status.idle": "2021-11-29T03:05:23.818479Z",
1213 | "shell.execute_reply": "2021-11-29T03:05:23.818058Z",
1214 | "shell.execute_reply.started": "2021-11-29T01:03:31.476806Z"
1215 | },
1216 | "papermill": {
1217 | "duration": 0.035723,
1218 | "end_time": "2021-11-29T03:05:23.818579",
1219 | "exception": false,
1220 | "start_time": "2021-11-29T03:05:23.782856",
1221 | "status": "completed"
1222 | },
1223 | "tags": []
1224 | },
1225 | "outputs": [],
1226 | "source": [
1227 | "# Use RMSE as the loss: the RMSE over the batch is computed per lead month, then summed over the lead months\n",
1228 | "def RMSELoss(y_pred, y_true):\n",
1229 | "    loss = torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()\n",
1230 | "    return loss"
1231 | ]
1232 | },
1233 | {
1234 | "cell_type": "code",
1235 | "execution_count": 27,
1236 | "metadata": {
1237 | "execution": {
1238 | "iopub.execute_input": "2021-11-29T03:05:23.893469Z",
1239 | "iopub.status.busy": "2021-11-29T03:05:23.892684Z",
1240 | "iopub.status.idle": "2021-11-29T04:55:28.956056Z",
1241 | "shell.execute_reply": "2021-11-29T04:55:28.956434Z"
1242 | },
1243 | "papermill": {
1244 | "duration": 6605.109145,
1245 | "end_time": "2021-11-29T04:55:28.956614",
1246 | "exception": false,
1247 | "start_time": "2021-11-29T03:05:23.847469",
1248 | "status": "completed"
1249 | },
1250 | "tags": []
1251 | },
1252 | "outputs": [
1253 | {
1254 | "name": "stdout",
1255 | "output_type": "stream",
1256 | "text": [
1257 | "Epoch: 1/5\n"
1258 | ]
1259 | },
1260 | {
1261 | "name": "stderr",
1262 | "output_type": "stream",
1263 | "text": [
1264 | "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n"
1265 | ]
1266 | },
1267 | {
1268 | "name": "stdout",
1269 | "output_type": "stream",
1270 | "text": [
1271 | "Training Loss: 3.289\n"
1272 | ]
1273 | },
1274 | {
1275 | "name": "stderr",
1276 | "output_type": "stream",
1277 | "text": [
1278 | "50it [00:11, 4.47it/s]\n"
1279 | ]
1280 | },
1281 | {
1282 | "name": "stdout",
1283 | "output_type": "stream",
1284 | "text": [
1285 | "Validation Loss: 44.009\n",
1286 | "Score: -43.458\n",
1287 | "Epoch: 2/5\n"
1288 | ]
1289 | },
1290 | {
1291 | "name": "stderr",
1292 | "output_type": "stream",
1293 | "text": [
1294 | "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n"
1295 | ]
1296 | },
1297 | {
1298 | "name": "stdout",
1299 | "output_type": "stream",
1300 | "text": [
1301 | "Training Loss: 3.084\n"
1302 | ]
1303 | },
1304 | {
1305 | "name": "stderr",
1306 | "output_type": "stream",
1307 | "text": [
1308 | "50it [00:11, 4.33it/s]\n"
1309 | ]
1310 | },
1311 | {
1312 | "name": "stdout",
1313 | "output_type": "stream",
1314 | "text": [
1315 | "Validation Loss: 25.011\n",
1316 | "Score: -19.966\n",
1317 | "Epoch: 3/5\n"
1318 | ]
1319 | },
1320 | {
1321 | "name": "stderr",
1322 | "output_type": "stream",
1323 | "text": [
1324 | "100%|██████████| 1600/1600 [21:46<00:00, 1.22it/s]\n"
1325 | ]
1326 | },
1327 | {
1328 | "name": "stdout",
1329 | "output_type": "stream",
1330 | "text": [
1331 | "Training Loss: 13.461\n"
1332 | ]
1333 | },
1334 | {
1335 | "name": "stderr",
1336 | "output_type": "stream",
1337 | "text": [
1338 | "50it [00:12, 4.16it/s]\n"
1339 | ]
1340 | },
1341 | {
1342 | "name": "stdout",
1343 | "output_type": "stream",
1344 | "text": [
1345 | "Validation Loss: 15.438\n",
1346 | "Score: -14.139\n",
1347 | "Epoch: 4/5\n"
1348 | ]
1349 | },
1350 | {
1351 | "name": "stderr",
1352 | "output_type": "stream",
1353 | "text": [
1354 | "100%|██████████| 1600/1600 [21:54<00:00, 1.22it/s]\n"
1355 | ]
1356 | },
1357 | {
1358 | "name": "stdout",
1359 | "output_type": "stream",
1360 | "text": [
1361 | "Training Loss: 17.627\n"
1362 | ]
1363 | },
1364 | {
1365 | "name": "stderr",
1366 | "output_type": "stream",
1367 | "text": [
1368 | "50it [00:12, 3.99it/s]\n"
1369 | ]
1370 | },
1371 | {
1372 | "name": "stdout",
1373 | "output_type": "stream",
1374 | "text": [
1375 | "Validation Loss: 15.389\n",
1376 | "Score: -22.500\n",
1377 | "Epoch: 5/5\n"
1378 | ]
1379 | },
1380 | {
1381 | "name": "stderr",
1382 | "output_type": "stream",
1383 | "text": [
1384 | "100%|██████████| 1600/1600 [21:55<00:00, 1.22it/s]\n"
1385 | ]
1386 | },
1387 | {
1388 | "name": "stdout",
1389 | "output_type": "stream",
1390 | "text": [
1391 | "Training Loss: 17.592\n"
1392 | ]
1393 | },
1394 | {
1395 | "name": "stderr",
1396 | "output_type": "stream",
1397 | "text": [
1398 | "50it [00:11, 4.48it/s]"
1399 | ]
1400 | },
1401 | {
1402 | "name": "stdout",
1403 | "output_type": "stream",
1404 | "text": [
1405 | "Validation Loss: 15.252\n",
1406 | "Score: -14.459\n"
1407 | ]
1408 | },
1409 | {
1410 | "name": "stderr",
1411 | "output_type": "stream",
1412 | "text": [
1413 | "\n"
1414 | ]
1415 | }
1416 | ],
1417 | "source": [
1418 | "model_weights = './task05_model_weights.pth'\n",
1419 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
1420 | "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size).to(device)\n",
1421 | "criterion = RMSELoss\n",
1422 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
1423 | "lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=0, verbose=True, min_lr=0.0001)\n",
1424 | "epochs = 5\n",
1425 | "ratio, decay_rate = 1, 8e-5\n",
1426 | "train_losses, valid_losses = [], []\n",
1427 | "scores = []\n",
1428 | "best_score = float('-inf')\n",
1429 | "preds = np.zeros((len(y_valid),24))\n",
1430 | "\n",
1431 | "for epoch in range(epochs):\n",
1432 | "    print('Epoch: {}/{}'.format(epoch+1, epochs))\n",
1433 | "\n",
1434 | "    # Training\n",
1435 | "    model.train()\n",
1436 | "    losses = 0\n",
1437 | "    for data, labels in tqdm(trainloader):\n",
1438 | "        data = data.to(device)\n",
1439 | "        labels = labels.to(device)\n",
1440 | "        optimizer.zero_grad()\n",
1441 | "        # Linearly decay the scheduled-sampling ratio after every batch\n",
1442 | "        ratio = max(ratio-decay_rate, 0)\n",
1443 | "        pred = model(data, teacher_forcing=True, scheduled_sampling_ratio=ratio, train=True)\n",
1444 | "        loss = criterion(pred, labels)\n",
1445 | "        losses += loss.cpu().detach().numpy()\n",
1446 | "        loss.backward()\n",
1447 | "        optimizer.step()\n",
1448 | "    train_loss = losses / len(trainloader)\n",
1449 | "    train_losses.append(train_loss)\n",
1450 | "    print('Training Loss: {:.3f}'.format(train_loss))\n",
1451 | "\n",
1452 | "    # Validation\n",
1453 | "    model.eval()\n",
1454 | "    losses = 0\n",
1455 | "    with torch.no_grad():\n",
1456 | "        for i, data in tqdm(enumerate(validloader)):\n",
1457 | "            data, labels = data\n",
1458 | "            data = data.to(device)\n",
1459 | "            labels = labels.to(device)\n",
1460 | "            pred = model(data, train=False)\n",
1461 | "            loss = criterion(pred, labels)\n",
1462 | "            losses += loss.cpu().detach().numpy()\n",
1463 | "            preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n",
1464 | "    valid_loss = losses / len(validloader)\n",
1465 | "    valid_losses.append(valid_loss)\n",
1466 | "    print('Validation Loss: {:.3f}'.format(valid_loss))\n",
1467 | "    s = score(y_valid, preds)\n",
1468 | "    scores.append(s)\n",
1469 | "    print('Score: {:.3f}'.format(s))\n",
1470 | "    lr_scheduler.step(s)  # lower the learning rate when the validation score stops improving\n",
1471 | "    # Save the weights of the best-scoring model\n",
1472 | "    if s > best_score:\n",
1473 | "        best_score = s\n",
1474 | "        checkpoint = {'best_score': s,\n",
1475 | "                      'state_dict': model.state_dict()}\n",
1476 | "        torch.save(checkpoint, model_weights)"
1477 | ]
1478 | },
1479 | {
1480 | "cell_type": "code",
1481 | "execution_count": 28,
1482 | "metadata": {
1483 | "execution": {
1484 | "iopub.execute_input": "2021-11-29T04:55:33.957872Z",
1485 | "iopub.status.busy": "2021-11-29T04:55:33.957066Z",
1486 | "iopub.status.idle": "2021-11-29T04:55:33.960119Z",
1487 | "shell.execute_reply": "2021-11-29T04:55:33.959684Z",
1488 | "shell.execute_reply.started": "2021-11-28T14:00:36.33194Z"
1489 | },
1490 | "papermill": {
1491 | "duration": 2.38263,
1492 | "end_time": "2021-11-29T04:55:33.960247",
1493 | "exception": false,
1494 | "start_time": "2021-11-29T04:55:31.577617",
1495 | "status": "completed"
1496 | },
1497 | "tags": []
1498 | },
1499 | "outputs": [],
1500 | "source": [
1501 | "# Plot the training and validation curves\n",
1502 | "def training_vis(train_losses, valid_losses):\n",
1503 | "    # Plot the loss curves\n",
1504 | "    fig = plt.figure(figsize=(8,4))\n",
1505 | "    # subplot loss\n",
1506 | "    ax1 = fig.add_subplot(121)\n",
1507 | "    ax1.plot(train_losses, label='train_loss')\n",
1508 | "    ax1.plot(valid_losses, label='val_loss')\n",
1509 | "    ax1.set_xlabel('Epochs')\n",
1510 | "    ax1.set_ylabel('Loss')\n",
1511 | "    ax1.set_title('Loss on Training and Validation Data')\n",
1512 | "    ax1.legend()\n",
1513 | "    plt.tight_layout()"
1514 | ]
1515 | },
1516 | {
1517 | "cell_type": "code",
1518 | "execution_count": 29,
1519 | "metadata": {
1520 | "execution": {
1521 | "iopub.execute_input": "2021-11-29T04:55:38.227343Z",
1522 | "iopub.status.busy": "2021-11-29T04:55:38.226636Z",
1523 | "iopub.status.idle": "2021-11-29T04:55:38.470256Z",
1524 | "shell.execute_reply": "2021-11-29T04:55:38.469252Z",
1525 | "shell.execute_reply.started": "2021-11-28T14:00:43.42651Z"
1526 | },
1527 | "papermill": {
1528 | "duration": 2.378943,
1529 | "end_time": "2021-11-29T04:55:38.470387",
1530 | "exception": false,
1531 | "start_time": "2021-11-29T04:55:36.091444",
1532 | "status": "completed"
1533 | },
1534 | "tags": []
1535 | },
1536 | "outputs": [
1537 | {
1538 | "data": {
1539 | "image/png": "[base64-encoded PNG data omitted; figure: 'Loss on Training and Validation Data', train_loss and val_loss plotted against Epochs]",
1540 | "text/plain": [
1541 | ""
1542 | ]
1543 | },
1544 | "metadata": {
1545 | "needs_background": "light"
1546 | },
1547 | "output_type": "display_data"
1548 | }
1549 | ],
1550 | "source": [
1551 | "training_vis(train_losses, valid_losses)"
1552 | ]
1553 | },
1554 | {
1555 | "cell_type": "markdown",
1556 | "metadata": {},
1557 | "source": [
1558 | "#### Model evaluation\n",
1559 | "\n",
1560 | "Evaluate the model on the test set."
1561 | ]
1562 | },
1563 | {
1564 | "cell_type": "code",
1565 | "execution_count": 31,
1566 | "metadata": {
1567 | "execution": {
1568 | "iopub.execute_input": "2021-11-29T04:55:47.340416Z",
1569 | "iopub.status.busy": "2021-11-29T04:55:47.339537Z",
1570 | "iopub.status.idle": "2021-11-29T04:55:47.606447Z",
1571 | "shell.execute_reply": "2021-11-29T04:55:47.607038Z",
1572 | "shell.execute_reply.started": "2021-11-28T14:01:44.127754Z"
1573 | },
1574 | "papermill": {
1575 | "duration": 2.453872,
1576 | "end_time": "2021-11-29T04:55:47.607210",
1577 | "exception": false,
1578 | "start_time": "2021-11-29T04:55:45.153338",
1579 | "status": "completed"
1580 | },
1581 | "tags": []
1582 | },
1583 | "outputs": [
1584 | {
1585 | "data": {
1586 | "text/plain": [
1587 | ""
1588 | ]
1589 | },
1590 | "execution_count": 31,
1591 | "metadata": {},
1592 | "output_type": "execute_result"
1593 | }
1594 | ],
1595 | "source": [
1596 | "# Load the best-scoring model (map_location lets GPU-saved weights load on a CPU-only machine)\n",
1597 | "checkpoint = torch.load('../input/ai-earth-model-weights/task05_model_weights.pth', map_location=device)\n",
1598 | "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n",
1599 | "model.load_state_dict(checkpoint['state_dict'])"
1600 | ]
1601 | },
1602 | {
1603 | "cell_type": "code",
1604 | "execution_count": 32,
1605 | "metadata": {
1606 | "execution": {
1607 | "iopub.execute_input": "2021-11-29T04:55:51.849996Z",
1608 | "iopub.status.busy": "2021-11-29T04:55:51.849073Z",
1609 | "iopub.status.idle": "2021-11-29T04:55:51.851492Z",
1610 | "shell.execute_reply": "2021-11-29T04:55:51.850969Z",
1611 | "shell.execute_reply.started": "2021-11-28T14:06:59.931318Z"
1612 | },
1613 | "papermill": {
1614 | "duration": 2.125413,
1615 | "end_time": "2021-11-29T04:55:51.851629",
1616 | "exception": false,
1617 | "start_time": "2021-11-29T04:55:49.726216",
1618 | "status": "completed"
1619 | },
1620 | "tags": []
1621 | },
1622 | "outputs": [],
1623 | "source": [
1624 | "# Path to the test set\n",
1625 | "test_path = '../input/ai-earth-tests/'\n",
1626 | "# Path to the test set labels\n",
1627 | "test_label_path = '../input/ai-earth-tests-labels/'"
1628 | ]
1629 | },
1630 | {
1631 | "cell_type": "code",
1632 | "execution_count": 33,
1633 | "metadata": {
1634 | "execution": {
1635 | "iopub.execute_input": "2021-11-29T04:55:56.429364Z",
1636 | "iopub.status.busy": "2021-11-29T04:55:56.428800Z",
1637 | "iopub.status.idle": "2021-11-29T04:55:58.115486Z",
1638 | "shell.execute_reply": "2021-11-29T04:55:58.115007Z",
1639 | "shell.execute_reply.started": "2021-11-28T14:07:13.415385Z"
1640 | },
1641 | "papermill": {
1642 | "duration": 4.135325,
1643 | "end_time": "2021-11-29T04:55:58.115667",
1644 | "exception": false,
1645 | "start_time": "2021-11-29T04:55:53.980342",
1646 | "status": "completed"
1647 | },
1648 | "tags": []
1649 | },
1650 | "outputs": [],
1651 | "source": [
1652 | "import os\n",
1653 | "\n",
1654 | "# Read the test data and the matching labels (each label file shares its name with the data file)\n",
1655 | "files = os.listdir(test_path)\n",
1656 | "X_test = []\n",
1657 | "y_test = []\n",
1658 | "for file in files:\n",
1659 | " X_test.append(np.load(test_path + file))\n",
1660 | " y_test.append(np.load(test_label_path + file))"
1661 | ]
1662 | },
1663 | {
1664 | "cell_type": "code",
1665 | "execution_count": 34,
1666 | "metadata": {
1667 | "execution": {
1668 | "iopub.execute_input": "2021-11-29T04:56:02.560786Z",
1669 | "iopub.status.busy": "2021-11-29T04:56:02.559461Z",
1670 | "iopub.status.idle": "2021-11-29T04:56:02.587431Z",
1671 | "shell.execute_reply": "2021-11-29T04:56:02.588024Z",
1672 | "shell.execute_reply.started": "2021-11-28T14:07:17.046359Z"
1673 | },
1674 | "papermill": {
1675 | "duration": 2.329175,
1676 | "end_time": "2021-11-29T04:56:02.588201",
1677 | "exception": false,
1678 | "start_time": "2021-11-29T04:56:00.259026",
1679 | "status": "completed"
1680 | },
1681 | "tags": []
1682 | },
1683 | "outputs": [
1684 | {
1685 | "data": {
1686 | "text/plain": [
1687 | "((103, 12, 24, 48, 1), (103, 24))"
1688 | ]
1689 | },
1690 | "execution_count": 34,
1691 | "metadata": {},
1692 | "output_type": "execute_result"
1693 | }
1694 | ],
1695 | "source": [
1696 | "X_test = np.array(X_test)[:, :, :, 19: 67, :1]\n",
1697 | "y_test = np.array(y_test)\n",
1698 | "X_test.shape, y_test.shape"
1699 | ]
1700 | },
1701 | {
1702 | "cell_type": "code",
1703 | "execution_count": 35,
1704 | "metadata": {
1705 | "execution": {
1706 | "iopub.execute_input": "2021-11-29T04:56:07.675344Z",
1707 | "iopub.status.busy": "2021-11-29T04:56:07.674488Z",
1708 | "iopub.status.idle": "2021-11-29T04:56:07.682481Z",
1709 | "shell.execute_reply": "2021-11-29T04:56:07.682895Z",
1710 | "shell.execute_reply.started": "2021-11-28T14:07:31.503452Z"
1711 | },
1712 | "papermill": {
1713 | "duration": 2.455352,
1714 | "end_time": "2021-11-29T04:56:07.683041",
1715 | "exception": false,
1716 | "start_time": "2021-11-29T04:56:05.227689",
1717 | "status": "completed"
1718 | },
1719 | "tags": []
1720 | },
1721 | "outputs": [],
1722 | "source": [
1723 | "testset = AIEarthDataset(X_test, y_test)\n",
1724 | "testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)"
1725 | ]
1726 | },
1727 | {
1728 | "cell_type": "code",
1729 | "execution_count": null,
1730 | "metadata": {},
1731 | "outputs": [],
1732 | "source": [
1733 | "# Evaluate the model on the test set\n",
1734 | "model.eval()\n",
1735 | "model.to(device)\n",
1736 | "preds = np.zeros((len(y_test),24))\n",
1737 | "# Inference only: disable gradient tracking to save memory\n",
1738 | "with torch.no_grad():\n",
1739 | "    for i, (data, _) in tqdm(enumerate(testloader)):\n",
1740 | "        data = data.to(device)\n",
1741 | "        pred = model(data, train=False)\n",
1742 | "        preds[i*batch_size:(i+1)*batch_size] = pred.cpu().numpy()\n",
1743 | "s = score(y_test, preds)\n",
1744 | "print('Score: {:.3f}'.format(s))"
1745 | ]
1746 | },
1747 | {
1748 | "cell_type": "markdown",
1749 | "metadata": {
1750 | "papermill": {
1751 | "duration": null,
1752 | "end_time": null,
1753 | "exception": null,
1754 | "start_time": null,
1755 | "status": "pending"
1756 | },
1757 | "tags": []
1758 | },
1759 | "source": [
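1760 | "## Summary\n",
1761 | "\n",
1762 | "This TOP solution did not design a model of its own; it used an existing model from the field of spatiotemporal sequence prediction. Another top team, “ailab”, also used an existing model, PredRNN++. For an overview of some of the classic models in spatiotemporal sequence prediction, see https://www.zhihu.com/column/c_1208033701705162752"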
1763 | ]
1764 | },
1765 | {
1766 | "cell_type": "markdown",
1767 | "metadata": {},
1768 | "source": [
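1769 | "## Exercise\n",
1770 | "\n",
1771 | "This TOP solution uses sst as the prediction target and computes the Nino3.4 index indirectly from it. If you have time to spare, try using the SA-ConvLSTM model to predict the Nino3.4 index directly."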
1772 | ]
1773 | },
1774 | {
1775 | "cell_type": "markdown",
1776 | "metadata": {},
1777 | "source": [
1778 | "## References\n",
1779 | "\n",
1780 | "1. Solution write-up from the team 吴先生的队伍 (Mr. Wu's team): https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.9.561d5330dF9lX1&postId=231465\n",
1781 | "2. Approach write-up from the ailab team: https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.15.561d5330dF9lX1&postId=210734"
1782 | ]
1783 | }
1784 | ],
1785 | "metadata": {
1786 | "kernelspec": {
1787 | "display_name": "Python 3",
1788 | "language": "python",
1789 | "name": "python3"
1790 | },
1791 | "language_info": {
1792 | "codemirror_mode": {
1793 | "name": "ipython",
1794 | "version": 3
1795 | },
1796 | "file_extension": ".py",
1797 | "mimetype": "text/x-python",
1798 | "name": "python",
1799 | "nbconvert_exporter": "python",
1800 | "pygments_lexer": "ipython3",
1801 | "version": "3.7.3"
1802 | },
1803 | "papermill": {
1804 | "default_parameters": {},
1805 | "duration": 6708.571081,
1806 | "end_time": "2021-11-29T04:56:15.789285",
1807 | "environment_variables": {},
1808 | "exception": true,
1809 | "input_path": "__notebook__.ipynb",
1810 | "output_path": "__notebook__.ipynb",
1811 | "parameters": {},
1812 | "start_time": "2021-11-29T03:04:27.218204",
1813 | "version": "2.3.3"
1814 | }
1815 | },
1816 | "nbformat": 4,
1817 | "nbformat_minor": 5
1818 | }
1819 |
--------------------------------------------------------------------------------
/Task5/fig/Task5-LSTM与ConvLSTM公式比较.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task5/fig/Task5-LSTM与ConvLSTM公式比较.png
--------------------------------------------------------------------------------
/Task5/fig/Task5-SA-ConvLSTM模型.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task5/fig/Task5-SA-ConvLSTM模型.png
--------------------------------------------------------------------------------
/Task5/fig/Task5-SAM模块.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task5/fig/Task5-SAM模块.png
--------------------------------------------------------------------------------
/Task5/fig/Task5-Seq2Seq基础结构.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datawhalechina/time-series-learning/bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6/Task5/fig/Task5-Seq2Seq基础结构.png
--------------------------------------------------------------------------------