├── LICENSE
├── README.md
└── time_series data prediction with gru and lstm
    ├── GRU.PY
    ├── LSTM.PY
    ├── __pycache__
    │   ├── GRU.cpython-38.pyc
    │   ├── LSTM.cpython-38.pyc
    │   ├── data_preparation.cpython-38.pyc
    │   └── train.cpython-38.pyc
    ├── data_preparation.py
    ├── gru.pt
    ├── lstm.pt
    ├── test.PY
    └── train.PY

/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### [PyTorch] Time-Series Prediction with GRU and LSTM

**1. Results:**

(figure: prediction vs. original data, image-20220327192956318)

(figure: prediction vs. original data, image-20220327193009703)

The blue curve is the original dataset: 1,000 points of a sine function with added noise, of which the first 80% form the training set.

The orange curve is the network's prediction. The first 80% of the points took part in training; the last 20% did not. Judging by the shape, the result is quite good.

**2. Preparing the dataset:**

The dataset-preparation code follows. (Since the code is organized into modules, the first line of each snippet notes which module it belongs to.)

![image-20220326230851125](https://gitee.com/Ejemplarr/drawing-bed/raw/master/img/image-20220326230851125.png)

First, generate the 1,000 raw data points:

```python
'''data_preparation module'''

# import the required libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

T = 1000
x = torch.arange(1, T + 1, dtype=torch.float32)
y = torch.sin(0.01 * x) + torch.normal(0, 0.1, (T,))  # add Gaussian noise (mean 0, std 0.1) to every y
plt.plot(x, y)
plt.show()
```

Output:

(figure: plot of the noisy sine curve, image-20220326230902165)

The next part, turning the series into examples, is where the most care is needed.

Because we are imitating time-series prediction, the dataset has to reflect the temporal structure of the series. For example, we can use eight consecutive values to predict the value that follows them: the first example then has features [y0, y1, y2, y3, y4, y5, y6, y7] and target [y8], the second has features [y1, y2, y3, y4, y5, y6, y7, y8] and target [y9], and so on until the target is [y999]. (Taking our current dataset as the example: 1,000 points, indexed from 0, so the last point is y999.)

We could equally use a window of length 8 to predict the two values that follow it. The first example would then have features [y0, y1, y2, y3, y4, y5, y6, y7] and target [y8, y9], the second [y2, y3, y4, y5, y6, y7, y8, y9] and target [y10, y11], and so on until the final target is [y998, y999]. The first scheme yields 992 examples in total and the second 496; the short check below confirms both counts.
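The counts follow directly from the loop bound in `data_prediction_to_f_and_t` (defined in the next block): the number of examples is `(len(data) - lengths - targets) // targets + 1`. A quick check with a throwaway helper (not part of any module):

```python
'''not part of any module; verification only'''

def num_examples(n, lengths, targets):
    # same loop bound as data_prediction_to_f_and_t
    return (n - lengths - targets) // targets + 1

print(num_examples(1000, 8, 1))  # 992 -> predict 1 value from windows of 8
print(num_examples(1000, 8, 2))  # 496 -> predict 2 values from windows of 8
```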
In fact, a window of any length can be used to predict a following sequence of any length, though accuracy suffers as the prediction horizon grows. The code in this post implements exactly that: `lengths`, the length of the window used for prediction, and `targets`, the length of the sequence to be predicted, can both be set freely.

```python
'''data_preparation module'''

'''
lengths : the length of the sub-sequence used for prediction
targets : the length of the sequence to be predicted
e.g. lengths = 8, targets = 1 means eight values predict one value
'''
lengths = 8
targets = 1

def data_prediction_to_f_and_t(data, num_features, num_targets):
    '''
    The key function for splitting the data: num_features is the length of the
    input window, num_targets is the length of the sequence to be predicted.
    '''
    features, target = [], []
    for i in range(((len(data)-num_features-num_targets)//num_targets) + 1):
        f = data[i*num_targets:i*num_targets+num_features]
        t = data[i*num_targets+num_features:i*num_targets+num_features+num_targets]
        features.append(list(f))
        target.append(list(t))

    return np.array(features), np.array(target)

# Step 1: build the dataset
dataset_features, dataset_target = data_prediction_to_f_and_t(y, lengths, targets)  # call the function defined above
print(dataset_features.shape)
print(dataset_target.shape)
>>> (992, 8)
(992, 1)  # matches the description above; the shapes are correct
```

If the mechanics are still hard to see, try the function on a simpler sequence:

```python
'''not part of any module; for testing only'''

data = torch.arange(0, T, dtype=torch.float32)  # data is 0, 1, 2, ..., 999
dataset_features, dataset_target = data_prediction_to_f_and_t(data, lengths, targets)  # lengths=8, targets=1
print(dataset_features)
print(dataset_target)
```

Output:

(figures: printed contents of dataset_features and dataset_target, images 20220326230929137, 20220326222352452, 20220326222220648)

The output matches the discussion above; if you are curious, change the values of `lengths` and `targets` and watch the effect.

Next comes the train/test split; again we define a function and then call it:

```python
'''data_preparation module'''

def dataset_split_4sets(data_features, data_target, ratio=0.8):
    '''
    Purpose: separate the features and targets of the training and test sets.
    ratio: the fraction of the data used for training.
    '''
    split_index = int(ratio*len(data_features))
    train_features = data_features[:split_index]
    train_target = data_target[:split_index]
    test_features = data_features[split_index:]
    test_target = data_target[split_index:]
    return train_features, train_target, test_features, test_target


# Step 2: split the dataset into a training set and a test set
train_features, train_target, test_features, test_target = dataset_split_4sets(dataset_features, dataset_target)
```

Next, wrap the data in a subclass of `Dataset`. We do this because the data will eventually be packed into a `DataLoader`, which makes mini-batching and shuffling convenient; `DataLoader` is the class PyTorch uses to feed data to a model during training. If this is unfamiliar, an introduction to PyTorch's `Dataset` and `DataLoader` is worth reading first; a small usage sketch also follows the next code block.

```python
'''data_preparation module'''

class dataset_to_Dataset(Dataset):
    '''
    Wraps the prepared data in a Dataset subclass so it can later be handed to a DataLoader.
    Note: data_features and data_target passed to the constructor must be numpy arrays.
    '''
    def __init__(self, data_features, data_target):
        self.len = len(data_features)
        self.features = torch.from_numpy(data_features)
        self.target = torch.from_numpy(data_target)

    def __getitem__(self, index):
        return self.features[index], self.target[index]

    def __len__(self):
        return self.len


# Step 3: convert the split data into a Dataset
train_set = dataset_to_Dataset(data_features=train_features, data_target=train_target)
```
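Here is the promised sketch (again, not part of any module) of how the `Dataset` subclass plugs into a `DataLoader`; the batch size of 16 matches what the train module uses later:

```python
'''not part of any module; illustration only'''

loader = DataLoader(dataset=train_set, batch_size=16, shuffle=True, drop_last=True)
xb, yb = next(iter(loader))
print(xb.shape)  # torch.Size([16, 8]) -> (batch_size, lengths)
print(yb.shape)  # torch.Size([16, 1]) -> (batch_size, targets)
# the train module later reshapes xb to (batch_size, lengths, 1) before it enters
# the GRU/LSTM, because the recurrent layers expect a trailing feature dimension
```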
Finally, collecting all of the above, here is the complete data_preparation module. (Whatever can be written as a function is written as a function, which makes the code easy to call and reuse.)

```python
'''data_preparation complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:11
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import matplotlib.pyplot as plt

'''
lengths : the length of the sub-sequence used for prediction
targets : the length of the sequence to be predicted
e.g. lengths = 8, targets = 1 means eight values predict one value
'''
lengths = 8
targets = 1

def data_start():
    T = 1000
    x = torch.arange(1, T + 1, dtype=torch.float32)
    y = torch.sin(0.01 * x) + torch.normal(0, 0.1, (T,))  # add Gaussian noise (mean 0, std 0.1) to every y
    return x, y

def data_prediction_to_f_and_t(data, num_features, num_targets):
    '''
    Builds the dataset of (features, target) examples.
    '''
    features, target = [], []
    for i in range(((len(data)-num_features-num_targets)//num_targets) + 1):
        f = data[i*num_targets:i*num_targets+num_features]
        t = data[i*num_targets+num_features:i*num_targets+num_features+num_targets]
        features.append(list(f))
        target.append(list(t))

    return np.array(features), np.array(target)

class dataset_to_Dataset(Dataset):
    '''
    Wraps the prepared data in a Dataset subclass so it can later be handed to a DataLoader.
    Note: data_features and data_target passed to the constructor must be numpy arrays.
    '''
    def __init__(self, data_features, data_target):
        self.len = len(data_features)
        self.features = torch.from_numpy(data_features)
        self.target = torch.from_numpy(data_target)

    def __getitem__(self, index):
        return self.features[index], self.target[index]

    def __len__(self):
        return self.len

def dataset_split_4sets(data_features, data_target, ratio=0.8):
    '''
    Purpose: separate the features and targets of the training and test sets.
    ratio: the fraction of the data used for training.
    '''
    split_index = int(ratio*len(data_features))
    train_features = data_features[:split_index]
    train_target = data_target[:split_index]
    test_features = data_features[split_index:]
    test_target = data_target[split_index:]
    return train_features, train_target, test_features, test_target
```

**3. Building the GRU and LSTM networks:**

```python
'''GRU complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:09
import torch
import torch.nn as nn
from data_preparation import targets
'''
GRU:
For the specifics of any network architecture, the official documentation is the best place to learn:

https://pytorch.org/docs/master/generated/torch.nn.GRU.html#torch.nn.GRU

The docs spell out the shapes of a module's inputs and outputs very clearly. My own habit,
for reference only: after studying the basic theory, open the docs, read through all the
tensor shapes and the practical meaning of every parameter, then run a demo on a simple
dataset. For the theory behind the GRU itself, Mu Li's "Dive into Deep Learning" video
series is a good starting point.
'''
'''
Parameters: the docs show that apart from the two defined below, every parameter has a
default value, so for the simplest GRU these two are all you need to set. Among the rest,
dropout guards against overfitting, bidirectional toggles a bidirectional network, and so
on. We do additionally set batch_first=True, because our data format, like most, puts
batch_size first.
'''
INPUT_SIZE = 1    # the number of expected features in the input x; each element of the window is a single number, so 1
HIDDEN_SIZE = 64  # the number of features in the hidden state h
# h0 = torch.zeros([])  # h0 has the same shape as hn: (D * num_layers, batch_size, hidden_size),
# where D = 2 if bidirectional=True otherwise 1, and num_layers is the number of GRU layers.
# If h0 is not defined here, forward() can simply pass None, which defaults to all zeros.

# define our model class
class GRU(nn.Module):
    def __init__(self):
        super(GRU, self).__init__()
        self.gru = nn.GRU(
            input_size=INPUT_SIZE,    # parameter defined above
            hidden_size=HIDDEN_SIZE,  # parameter defined above
            batch_first=True,         # explained above
        )
        self.mlp = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 32),  # a linear head is needed because, per the docs, the GRU output is (batch_size, seq_len, hidden_size)
            nn.LeakyReLU(),              # size these fully connected layers to suit your own output;
            nn.Linear(32, 16),           # we ultimately need (batch_size, output_size); for single-value prediction output_size is 1,
            nn.LeakyReLU(),              # and here we set it equal to targets,
            nn.Linear(16, targets)       # which the previous module already defined, so the output is (batch_size, targets)
        )

    def forward(self, input):
        output, h_n = self.gru(input, None)  # output: (batch_size, seq_len, hidden_size); h0 can simply be None
        # print(output.shape)
        output = output[:, -1, :]  # keep the last time step -> (batch_size, hidden_size)
        output = self.mlp(output)  # pass through the MLP head -> (batch_size, output_size)
        return output
```

```python
'''LSTM complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:09
import torch
import torch.nn as nn
from data_preparation import targets


INPUT_SIZE = 1    # the number of expected features in the input x
HIDDEN_SIZE = 64  # the number of features in the hidden state h

'''
In code, the only difference between the GRU and the LSTM is swapping nn.GRU for nn.LSTM.
'''

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(
            input_size=INPUT_SIZE,    # parameter defined above
            hidden_size=HIDDEN_SIZE,  # parameter defined above
            batch_first=True,         # explained above
        )
        self.mlp = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 32),  # same linear head as in the GRU: the LSTM output is also (batch_size, seq_len, hidden_size)
            nn.LeakyReLU(),
            nn.Linear(32, 16),
            nn.LeakyReLU(),
            nn.Linear(16, targets)       # final output: (batch_size, targets)
        )

    def forward(self, input):
        output, (h_n, c_n) = self.lstm(input, None)  # output: (batch_size, seq_len, hidden_size); the initial state can simply be None
        # print(output.shape)
        output = output[:, -1, :]  # keep the last time step -> (batch_size, hidden_size)
        output = self.mlp(output)  # pass through the MLP head -> (batch_size, output_size)
        return output
```
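Before training, it is worth pushing a dummy batch through both models to confirm the shapes described in the comments above. A quick sanity check (not part of any module), assuming the default lengths = 8 and targets = 1:

```python
'''not part of any module; shape sanity check'''

import torch
from GRU import GRU
from LSTM import LSTM

dummy = torch.randn(16, 8, 1)   # (batch_size, seq_len, input_size) with batch_first=True
print(GRU()(dummy).shape)   # torch.Size([16, 1]) -> (batch_size, targets)
print(LSTM()(dummy).shape)  # torch.Size([16, 1])
```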
**4. Defining the training functions:**

```python
'''train complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:10
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from GRU import GRU
from LSTM import LSTM
from data_preparation import data_start, data_prediction_to_f_and_t, dataset_to_Dataset, dataset_split_4sets, lengths, targets

'''
Steps:
  import the data
  define the tunable hyperparameters
  instantiate the networks
  define the optimizers
  move the data to the GPU
  define the loss function
  start training
'''

# tunable hyperparameters
BATCH_SIZE = 16
EPOCH = 100
LEARN_RATE = 1e-3


# import the data
x, y = data_start()
dataset_features, dataset_target = data_prediction_to_f_and_t(y, lengths, targets)
train_features, train_target, test_features, test_target = dataset_split_4sets(dataset_features, dataset_target)
train_set = dataset_to_Dataset(data_features=train_features, data_target=train_target)

train_set_iter = DataLoader(dataset=train_set,       # wrap the data in a DataLoader
                            batch_size=BATCH_SIZE,
                            shuffle=True,            # shuffle the examples between epochs
                            drop_last=True)          # drop_last=True discards the leftover examples that do not fill a final batch

# device selection
device = ('cuda' if torch.cuda.is_available() else 'cpu')

# instantiate the networks
net_gru = GRU().to(device)
net_lstm = LSTM().to(device)

# define the optimizers
optim_gru = optim.Adam(params=net_gru.parameters(), lr=LEARN_RATE)
optim_lstm = optim.Adam(params=net_lstm.parameters(), lr=LEARN_RATE)

# define the loss function
loss_func = nn.MSELoss()

# define the training functions
def train_for_gru(data, device, loss_func, net, optim, Epoch):
    for epoch in range(Epoch):
        loss_print = []
        for batch_idx, (x, y) in enumerate(data):
            x = x.reshape([BATCH_SIZE, lengths, 1])
            x = x.to(device)
            y = y.reshape((len(y), targets))
            y = y.to(device)
            y_pred = net(x)
            loss = loss_func(y_pred, y)
            loss_print.append(loss.item())
            # the three standard steps
            # zero the network's gradients
            net.zero_grad()
            # backpropagate the loss
            loss.backward()
            # update with the optimizer
            optim.step()
        print('GRU:loss:', sum(loss_print)/len(data))

def train_for_lstm(data, device, loss_func, net, optim, Epoch):
    # identical to train_for_gru apart from the printed label
    for epoch in range(Epoch):
        loss_print = []
        for batch_idx, (x, y) in enumerate(data):
            x = x.reshape([BATCH_SIZE, lengths, 1])
            x = x.to(device)
            y = y.reshape((len(y), targets))
            y = y.to(device)
            y_pred = net(x)
            loss = loss_func(y_pred, y)
            loss_print.append(loss.item())
            # the three standard steps
            net.zero_grad()
            loss.backward()
            optim.step()
        print('LSTM:loss:', sum(loss_print)/len(data))


def main():
    start = time.perf_counter()
    train_for_gru(train_set_iter, device, loss_func, net_gru, optim_gru, EPOCH)
    train_for_lstm(train_set_iter, device, loss_func, net_lstm, optim_lstm, EPOCH)
    end = time.perf_counter()
    print('Training time: {:.2f}s'.format(end-start))
    # save the models
    torch.save(net_gru.state_dict(), 'gru.pt')
    torch.save(net_lstm.state_dict(), 'lstm.pt')

if __name__ == '__main__':
    main()
```
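One thing the modules never do is score the held-out 20% numerically: `test_features` and `test_target` are created but the test module below only plots. A minimal sketch of a held-out MSE check, reusing names from the train module and meant to run after training, e.g. at the end of `main()`; the helper `test_mse` is hypothetical and not part of the repo:

```python
'''not part of any module; illustration only (hypothetical helper)'''

def test_mse(net, features, target):
    # features/target are the numpy arrays returned by dataset_split_4sets
    net.eval()
    with torch.no_grad():
        xs = torch.from_numpy(features).reshape(len(features), lengths, 1).to(device)
        ys = torch.from_numpy(target).reshape(len(target), targets).to(device)
        return nn.MSELoss()(net(xs), ys).item()

print('GRU test MSE:', test_mse(net_gru, test_features, test_target))
print('LSTM test MSE:', test_mse(net_lstm, test_features, test_target))
```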
**5. Defining the test functions:**

```python
'''test complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:10
from train import device
from data_preparation import lengths, targets
from train import x, y, dataset_features  # imported from train so the underlying data stays the same
from GRU import GRU
from LSTM import LSTM

import torch
import matplotlib.pyplot as plt

# load the saved networks
net_gru = GRU().to(device)
net_gru.load_state_dict(torch.load('gru.pt'))
net_lstm = LSTM().to(device)
net_lstm.load_state_dict(torch.load('lstm.pt'))

# define the test functions
def test_for_gru(dataset_features):
    dataset_features = dataset_features.reshape([len(dataset_features), lengths, 1])
    y_pred = net_gru(torch.from_numpy(dataset_features).to(device))
    y_pred = y_pred_to_numpy(y_pred)
    y_pred = y_pred.reshape(y_pred.size, 1)
    plt.plot(x, y)
    plt.plot(x[lengths:y_pred.size+lengths], y_pred)
    plt.legend(('data', 'data_pred:{}'.format(targets)), loc='upper right')
    plt.title('GRU')
    plt.show()

def test_for_lstm(dataset_features):
    dataset_features = dataset_features.reshape([len(dataset_features), lengths, 1])
    y_pred = net_lstm(torch.from_numpy(dataset_features).to(device))
    y_pred = y_pred_to_numpy(y_pred)
    y_pred = y_pred.reshape(y_pred.size, 1)
    plt.plot(x, y)
    plt.plot(x[lengths:y_pred.size+lengths], y_pred)
    plt.legend(('data', 'data_pred:{}'.format(targets)), loc='upper right')
    plt.title('LSTM')
    plt.show()

def y_pred_to_numpy(y_pred):
    '''
    :param y_pred: the network's output
    :return: a numpy array
    '''
    y_pred = y_pred.detach().cpu().numpy()
    return y_pred

if __name__ == '__main__':
    test_for_gru(dataset_features)
    test_for_lstm(dataset_features)
```

**6. Summary:**

Usage: create five .py files and copy the five complete modules above into them. data_preparation.py, GRU.py and LSTM.py only define functions and classes and are pulled in by import, so it is enough to run train.py (which trains both models and saves gru.pt and lstm.pt) and then test.py.

We used a GRU and an LSTM to predict values of the constructed dataset, and the results are quite good.

Thanks for reading, and feel free to get in touch!

--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/GRU.PY:
--------------------------------------------------------------------------------
'''GRU complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:09
import torch
import torch.nn as nn
from data_preparation import targets
'''
GRU:
For the specifics of any network architecture, the official documentation is the best place to learn:

https://pytorch.org/docs/master/generated/torch.nn.GRU.html#torch.nn.GRU

The docs spell out the shapes of a module's inputs and outputs very clearly. My own habit,
for reference only: after studying the basic theory, open the docs, read through all the
tensor shapes and the practical meaning of every parameter, then run a demo on a simple
dataset. For the theory behind the GRU itself, Mu Li's "Dive into Deep Learning" video
series is a good starting point.
'''
'''
Parameters: the docs show that apart from the two defined below, every parameter has a
default value, so for the simplest GRU these two are all you need to set. Among the rest,
dropout guards against overfitting, bidirectional toggles a bidirectional network, and so
on. We do additionally set batch_first=True, because our data format, like most, puts
batch_size first.
'''
INPUT_SIZE = 1    # the number of expected features in the input x; each element of the window is a single number, so 1
HIDDEN_SIZE = 64  # the number of features in the hidden state h
# h0 = torch.zeros([])  # h0 has the same shape as hn: (D * num_layers, batch_size, hidden_size),
# where D = 2 if bidirectional=True otherwise 1, and num_layers is the number of GRU layers.
# If h0 is not defined here, forward() can simply pass None, which defaults to all zeros.

# define our model class
class GRU(nn.Module):
    def __init__(self):
        super(GRU, self).__init__()
        self.gru = nn.GRU(
            input_size=INPUT_SIZE,    # parameter defined above
            hidden_size=HIDDEN_SIZE,  # parameter defined above
            batch_first=True,         # explained above
        )
        self.mlp = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 32),  # a linear head is needed because, per the docs, the GRU output is (batch_size, seq_len, hidden_size)
            nn.LeakyReLU(),              # size these fully connected layers to suit your own output;
            nn.Linear(32, 16),           # we ultimately need (batch_size, output_size); for single-value prediction output_size is 1,
            nn.LeakyReLU(),              # and here we set it equal to targets,
            nn.Linear(16, targets)       # which the previous module already defined, so the output is (batch_size, targets)
        )

    def forward(self, input):
        output, h_n = self.gru(input, None)  # output: (batch_size, seq_len, hidden_size); h0 can simply be None
        # print(output.shape)
        output = output[:, -1, :]  # keep the last time step -> (batch_size, hidden_size)
        output = self.mlp(output)  # pass through the MLP head -> (batch_size, output_size)
        return output

--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/LSTM.PY:
--------------------------------------------------------------------------------
'''LSTM complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:09
import torch
import torch.nn as nn
from data_preparation import targets


INPUT_SIZE = 1    # the number of expected features in the input x
HIDDEN_SIZE = 64  # the number of features in the hidden state h

'''
In code, the only difference between the GRU and the LSTM is swapping nn.GRU for nn.LSTM.
'''

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(
            input_size=INPUT_SIZE,    # parameter defined above
            hidden_size=HIDDEN_SIZE,  # parameter defined above
            batch_first=True,         # explained above
        )
        self.mlp = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 32),  # same linear head as in the GRU: the LSTM output is also (batch_size, seq_len, hidden_size)
            nn.LeakyReLU(),
            nn.Linear(32, 16),
            nn.LeakyReLU(),
            nn.Linear(16, targets)       # final output: (batch_size, targets)
        )

    def forward(self, input):
        output, (h_n, c_n) = self.lstm(input, None)  # output: (batch_size, seq_len, hidden_size); the initial state can simply be None
        # print(output.shape)
        output = output[:, -1, :]  # keep the last time step -> (batch_size, hidden_size)
        output = self.mlp(output)  # pass through the MLP head -> (batch_size, output_size)
        return output
--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/__pycache__/GRU.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rssevenyu/pytorch-time_series_data-prediction-with-gru-and-lstm/fb61ff4100aaa83b63b86e2a4278d1368e98a13a/time_series data prediction with gru and lstm/__pycache__/GRU.cpython-38.pyc
--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/__pycache__/LSTM.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rssevenyu/pytorch-time_series_data-prediction-with-gru-and-lstm/fb61ff4100aaa83b63b86e2a4278d1368e98a13a/time_series data prediction with gru and lstm/__pycache__/LSTM.cpython-38.pyc
--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/__pycache__/data_preparation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rssevenyu/pytorch-time_series_data-prediction-with-gru-and-lstm/fb61ff4100aaa83b63b86e2a4278d1368e98a13a/time_series data prediction with gru and lstm/__pycache__/data_preparation.cpython-38.pyc
--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/__pycache__/train.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rssevenyu/pytorch-time_series_data-prediction-with-gru-and-lstm/fb61ff4100aaa83b63b86e2a4278d1368e98a13a/time_series data prediction with gru and lstm/__pycache__/train.cpython-38.pyc
--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/data_preparation.py:
--------------------------------------------------------------------------------
# Author: Cy
# Written: 2022/3/26 19:53

'''data_preparation complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:11
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import matplotlib.pyplot as plt

'''
lengths : the length of the sub-sequence used for prediction
targets : the length of the sequence to be predicted
e.g. lengths = 8, targets = 1 means eight values predict one value
'''
lengths = 8
targets = 1

def data_start():
    T = 1000
    x = torch.arange(1, T + 1, dtype=torch.float32)
    y = torch.sin(0.01 * x) + torch.normal(0, 0.1, (T,))  # add Gaussian noise (mean 0, std 0.1) to every y
    return x, y

def data_prediction_to_f_and_t(data, num_features, num_targets):
    '''
    Builds the dataset of (features, target) examples.
    '''
    features, target = [], []
    for i in range(((len(data)-num_features-num_targets)//num_targets) + 1):
        f = data[i*num_targets:i*num_targets+num_features]
        t = data[i*num_targets+num_features:i*num_targets+num_features+num_targets]
        features.append(list(f))
        target.append(list(t))

    return np.array(features), np.array(target)

class dataset_to_Dataset(Dataset):
    '''
    Wraps the prepared data in a Dataset subclass so it can later be handed to a DataLoader.
    Note: data_features and data_target passed to the constructor must be numpy arrays.
    '''
    def __init__(self, data_features, data_target):
        self.len = len(data_features)
        self.features = torch.from_numpy(data_features)
        self.target = torch.from_numpy(data_target)

    def __getitem__(self, index):
        return self.features[index], self.target[index]

    def __len__(self):
        return self.len

def dataset_split_4sets(data_features, data_target, ratio=0.8):
    '''
    Purpose: separate the features and targets of the training and test sets.
    ratio: the fraction of the data used for training.
    '''
    split_index = int(ratio*len(data_features))
    train_features = data_features[:split_index]
    train_target = data_target[:split_index]
    test_features = data_features[split_index:]
    test_target = data_target[split_index:]
    return train_features, train_target, test_features, test_target
--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/gru.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rssevenyu/pytorch-time_series_data-prediction-with-gru-and-lstm/fb61ff4100aaa83b63b86e2a4278d1368e98a13a/time_series data prediction with gru and lstm/gru.pt
--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/lstm.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rssevenyu/pytorch-time_series_data-prediction-with-gru-and-lstm/fb61ff4100aaa83b63b86e2a4278d1368e98a13a/time_series data prediction with gru and lstm/lstm.pt
--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/test.PY:
--------------------------------------------------------------------------------
'''test complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:10
from train import device
from data_preparation import lengths, targets
from train import x, y, dataset_features  # imported from train so the underlying data stays the same
from GRU import GRU
from LSTM import LSTM

import torch
import matplotlib.pyplot as plt

# load the saved networks
net_gru = GRU().to(device)
net_gru.load_state_dict(torch.load('gru.pt'))
net_lstm = LSTM().to(device)
net_lstm.load_state_dict(torch.load('lstm.pt'))

# define the test functions
def test_for_gru(dataset_features):
    dataset_features = dataset_features.reshape([len(dataset_features), lengths, 1])
    y_pred = net_gru(torch.from_numpy(dataset_features).to(device))
    y_pred = y_pred_to_numpy(y_pred)
    y_pred = y_pred.reshape(y_pred.size, 1)
    plt.plot(x, y)
    plt.plot(x[lengths:y_pred.size+lengths], y_pred)
    plt.legend(('data', 'data_pred:{}'.format(targets)), loc='upper right')
    plt.title('GRU')
    plt.show()

def test_for_lstm(dataset_features):
    dataset_features = dataset_features.reshape([len(dataset_features), lengths, 1])
    y_pred = net_lstm(torch.from_numpy(dataset_features).to(device))
    y_pred = y_pred_to_numpy(y_pred)
    y_pred = y_pred.reshape(y_pred.size, 1)
    plt.plot(x, y)
    plt.plot(x[lengths:y_pred.size+lengths], y_pred)
    plt.legend(('data', 'data_pred:{}'.format(targets)), loc='upper right')
    plt.title('LSTM')
    plt.show()

def y_pred_to_numpy(y_pred):
    '''
    :param y_pred: the network's output
    :return: a numpy array
    '''
    y_pred = y_pred.detach().cpu().numpy()
    return y_pred

if __name__ == '__main__':
    plt.plot(x, y)
    plt.show()
    test_for_gru(dataset_features)
    test_for_lstm(dataset_features)

--------------------------------------------------------------------------------
/time_series data prediction with gru and lstm/train.PY:
--------------------------------------------------------------------------------
'''train complete module'''

# Author: Ejemplarr
# Written: 2022/3/24 22:10
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from GRU import GRU
from LSTM import LSTM
from data_preparation import data_start, data_prediction_to_f_and_t, dataset_to_Dataset, dataset_split_4sets, lengths, targets

'''
Steps:
  import the data
  define the tunable hyperparameters
  instantiate the networks
  define the optimizers
  move the data to the GPU
  define the loss function
  start training
'''

# tunable hyperparameters
BATCH_SIZE = 16
EPOCH = 10
LEARN_RATE = 1e-3


# import the data
x, y = data_start()
dataset_features, dataset_target = data_prediction_to_f_and_t(y, lengths, targets)
train_features, train_target, test_features, test_target = dataset_split_4sets(dataset_features, dataset_target)
train_set = dataset_to_Dataset(data_features=train_features, data_target=train_target)

train_set_iter = DataLoader(dataset=train_set,       # wrap the data in a DataLoader
                            batch_size=BATCH_SIZE,
                            shuffle=True,            # shuffle the examples between epochs
                            drop_last=True)          # drop_last=True discards the leftover examples that do not fill a final batch

# device selection
device = ('cuda' if torch.cuda.is_available() else 'cpu')

# instantiate the networks
net_gru = GRU().to(device)
net_lstm = LSTM().to(device)

# define the optimizers
optim_gru = optim.Adam(params=net_gru.parameters(), lr=LEARN_RATE)
optim_lstm = optim.Adam(params=net_lstm.parameters(), lr=LEARN_RATE)

# define the loss function
loss_func = nn.MSELoss()

# define the training functions
def train_for_gru(data, device, loss_func, net, optim, Epoch):
    for epoch in range(Epoch):
        loss_print = []
        for batch_idx, (x, y) in enumerate(data):
            x = x.reshape([BATCH_SIZE, lengths, 1])
            x = x.to(device)
            y = y.reshape((len(y), targets))
            y = y.to(device)
            y_pred = net(x)
            loss = loss_func(y_pred, y)
            loss_print.append(loss.item())
            # the three standard steps
            # zero the network's gradients
            net.zero_grad()
            # backpropagate the loss
            loss.backward()
            # update with the optimizer
            optim.step()
        print('GRU:loss:', sum(loss_print)/len(data))

def train_for_lstm(data, device, loss_func, net, optim, Epoch):
    # identical to train_for_gru apart from the printed label
    for epoch in range(Epoch):
        loss_print = []
        for batch_idx, (x, y) in enumerate(data):
            x = x.reshape([BATCH_SIZE, lengths, 1])
            x = x.to(device)
            y = y.reshape((len(y), targets))
            y = y.to(device)
            y_pred = net(x)
            loss = loss_func(y_pred, y)
            loss_print.append(loss.item())
            # the three standard steps
            net.zero_grad()
            loss.backward()
            optim.step()
        print('LSTM:loss:', sum(loss_print)/len(data))


def main():
    start = time.perf_counter()
    train_for_gru(train_set_iter, device, loss_func, net_gru, optim_gru, EPOCH)
    train_for_lstm(train_set_iter, device, loss_func, net_lstm, optim_lstm, EPOCH)
    end = time.perf_counter()
    print('Training time: {:.2f}s'.format(end-start))
    # save the models
    torch.save(net_gru.state_dict(), 'gru.pt')
    torch.save(net_lstm.state_dict(), 'lstm.pt')

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------