├── .gitignore
├── LICENSE
├── README.md
├── README_en.md
├── datasets
│   ├── __init__.py
│   └── data_loader.py
├── img
│   └── informer.png
├── models
│   ├── __init__.py
│   ├── autoformer.py
│   └── informer.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | /results
3 | /output
4 | /outputs
5 | /checkpoints
6 | .ipynb_checkpoints
7 | __pycache__
8 | .vscode
9 | .idea
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 HFAiLab
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Former Models for Long-Term Series Forecasting (LTSF)
2 | 
3 | 简体中文 | [English](README_en.md)
4 | 
5 | This project implements **distributed-training versions** of [*Informer*](https://github.com/zhouhaoyi/Informer2020) and [*Autoformer*](https://github.com/thuml/Autoformer) in PyTorch on the High-Flyer Yinghuo HPC cluster; both are representative recent *transformer*-family models for long-term time-series forecasting.
6 | + [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting (AAAI 2021)](https://ojs.aaai.org/index.php/AAAI/article/view/17325)
7 | + [Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting (NeurIPS 2021)](https://arxiv.org/abs/2106.13008)
8 | 
9 | ![Informer](./img/informer.png)
10 | 
11 | 
12 | ## Requirements
13 | 
14 | - [hfai](https://doc.hfai.high-flyer.cn/index.html)
15 | - torch >=1.8
16 | 
17 | 
18 | ## Training
19 | The raw data comes from the [Autoformer open-source repository](https://github.com/thuml/Autoformer) and has been integrated into the `hfai.datasets` dataset repository, covering `ETTh1`, `ETTh2`, `ETTm1`, `ETTm2`, `exchange_rate`, `electricity`, `national_illness` and `traffic`. See the [hfai documentation](#) for usage.
20 | 
21 | 1. Train Informer
22 | 
23 | Submit the task to the Yinghuo cluster:
24 | ```shell
25 | hfai python train.py --ds ETTh1 --model informer -- -n 1 -p 30
26 | ```
27 | Run locally:
28 | ```shell
29 | python train.py --ds ETTh1 --model informer
30 | ```
31 | 
32 | 2. Train Autoformer
33 | 
34 | Submit the task to the Yinghuo cluster:
35 | ```shell
36 | hfai python train.py --ds ETTh1 --model autoformer -- -n 1 -p 30
37 | ```
38 | Run locally:
39 | ```shell
40 | python train.py --ds ETTh1 --model autoformer
41 | ```
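
Each dataset is consumed through `datasets.get_dataloader`, which also returns the fitted scaler and the inferred input/output dimensions. A minimal sketch of the call (it assumes a `torch.distributed` process group is already initialized, since the loader wraps the dataset in a `DistributedSampler`):

```python
from datasets import get_dataloader

# signature mirrors datasets/data_loader.py
loader, scaler, enc_dim, dec_dim, out_dim = get_dataloader(
    "ETTh1", seq_len=96, label_len=48, pred_len=48,
    features="S", batch_size=64, num_workers=8, mode="train",
)
for batch_x, batch_y, batch_x_mark, batch_y_mark in loader:
    # batch_x: [B, seq_len, enc_dim]; batch_y: [B, label_len + pred_len, dec_dim]
    break
```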
42 | 
43 | 
44 | ## References
45 | + [Informer](https://github.com/zhouhaoyi/Informer2020)
46 | + [Autoformer](https://github.com/thuml/Autoformer)
47 | 
48 | 
49 | ## Citation
50 | 
51 | ```bibtex
52 | @inproceedings{haoyietal-informer-2021,
53 |   author    = {Haoyi Zhou and
54 |                Shanghang Zhang and
55 |                Jieqi Peng and
56 |                Shuai Zhang and
57 |                Jianxin Li and
58 |                Hui Xiong and
59 |                Wancai Zhang},
60 |   title     = {Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting},
61 |   booktitle = {The Thirty-Fifth {AAAI} Conference on Artificial Intelligence, {AAAI} 2021, Virtual Conference},
62 |   volume    = {35},
63 |   number    = {12},
64 |   pages     = {11106--11115},
65 |   publisher = {{AAAI} Press},
66 |   year      = {2021},
67 | }
68 | ```
69 | 
70 | ```bibtex
71 | @inproceedings{wu2021autoformer,
72 |   title={Autoformer: Decomposition Transformers with {Auto-Correlation} for Long-Term Series Forecasting},
73 |   author={Haixu Wu and Jiehui Xu and Jianmin Wang and Mingsheng Long},
74 |   booktitle={Advances in Neural Information Processing Systems},
75 |   year={2021}
76 | }
77 | ```
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | # Former Models for Long-Term Series Forecasting (LTSF)
2 | 
3 | English | [简体中文](README.md)
4 | 
5 | This is a distributed-training implementation of former models ([*Informer*](https://github.com/zhouhaoyi/Informer2020) and [*Autoformer*](https://github.com/thuml/Autoformer)),
6 | which are designed for Long-Term Series Forecasting (LTSF).
7 | + [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting (AAAI 2021)](https://ojs.aaai.org/index.php/AAAI/article/view/17325)
8 | + [Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting (NeurIPS 2021)](https://arxiv.org/abs/2106.13008)
9 | 
10 | ![Informer](./img/informer.png)
11 | 
12 | ## Requirements
13 | 
14 | - hfai (to be released soon)
15 | - torch >=1.8
16 | 
17 | 
18 | ## Training
19 | The raw data is from [Github:Autoformer](https://github.com/thuml/Autoformer), which is integrated into the dataset repository `hfai.datasets`, including `ETTh1`, `ETTh2`, `ETTm1`, `ETTm2`, `exchange_rate`, `electricity`, `national_illness` and `traffic`.
20 | 
21 | 1. Train Informer
22 | 
23 | Submit the task to Yinghuo HPC:
24 | ```shell
25 | hfai python train.py --ds ETTh1 --model informer -- -n 1 -p 30
26 | ```
27 | Run locally:
28 | ```shell
29 | python train.py --ds ETTh1 --model informer
30 | ```
31 | 
32 | 2. Train Autoformer
33 | 
34 | Submit the task to Yinghuo HPC:
35 | ```shell
36 | hfai python train.py --ds ETTh1 --model autoformer -- -n 1 -p 30
37 | ```
38 | Run locally:
39 | ```shell
40 | python train.py --ds ETTh1 --model autoformer
41 | ```
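
Beyond the CLI entry points above, the models can be exercised directly. A minimal sketch of the Informer interface (the attention-mask helpers call `.cuda()`, so a GPU is required; the hyper-parameters mirror the `informer` configuration in `train.py`):

```python
import torch
from models import Informer

model = Informer(enc_in=1, dec_in=1, c_out=1, seq_len=96, label_len=48, out_len=48,
                 e_layers=2, d_layers=1, d_ff=2048, factor=3, dropout=0.05,
                 embed="timeF").cuda()

x, x_mark = torch.zeros(2, 96, 1).cuda(), torch.zeros(2, 96, 4).cuda()  # values + 4 "timeF" features (freq="h")
y, y_mark = torch.zeros(2, 96, 1).cuda(), torch.zeros(2, 96, 4).cuda()  # label_len + pred_len = 96 decoder steps
pred = model(x, x_mark, y, y_mark)  # -> [2, 48, 1], the pred_len horizon
```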
42 | 
43 | 
44 | ## References
45 | + [Informer](https://github.com/zhouhaoyi/Informer2020)
46 | + [Autoformer](https://github.com/thuml/Autoformer)
47 | 
48 | 
49 | ## Citation
50 | 
51 | ```bibtex
52 | @inproceedings{haoyietal-informer-2021,
53 |   author    = {Haoyi Zhou and
54 |                Shanghang Zhang and
55 |                Jieqi Peng and
56 |                Shuai Zhang and
57 |                Jianxin Li and
58 |                Hui Xiong and
59 |                Wancai Zhang},
60 |   title     = {Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting},
61 |   booktitle = {The Thirty-Fifth {AAAI} Conference on Artificial Intelligence, {AAAI} 2021, Virtual Conference},
62 |   volume    = {35},
63 |   number    = {12},
64 |   pages     = {11106--11115},
65 |   publisher = {{AAAI} Press},
66 |   year      = {2021},
67 | }
68 | ```
69 | 
70 | ```bibtex
71 | @inproceedings{wu2021autoformer,
72 |   title={Autoformer: Decomposition Transformers with {Auto-Correlation} for Long-Term Series Forecasting},
73 |   author={Haixu Wu and Jiehui Xu and Jianmin Wang and Mingsheng Long},
74 |   booktitle={Advances in Neural Information Processing Systems},
75 |   year={2021}
76 | }
77 | ```
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from .data_loader import get_dataloader
--------------------------------------------------------------------------------
/datasets/data_loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.distributed import DistributedSampler
2 | from ffrecord.torch import DataLoader
3 | from hfai.datasets import LTSF
4 | 
5 | 
6 | def get_dataloader(data_name: str, seq_len: int, label_len: int, pred_len: int, features: str, batch_size: int, num_workers: int=8, mode: str='train'):
7 |     assert data_name in ['ETTh1', 'ETTh2', 'ETTm1', 'ETTm2',
8 |                          'exchange_rate', 'electricity',
9 |                          'national_illness', 'traffic']
10 |     assert mode in ['train', 'val']
11 | 
12 |     data = LTSF(
13 |         data_name,
14 |         split=mode,
15 |         seq_len=seq_len,
16 |         label_len=label_len,
17 |         pred_len=pred_len,
18 |         features=features,
19 |     )
20 |     datasampler = DistributedSampler(data, shuffle=True)
21 |     dataloader = DataLoader(
22 |         data, batch_size=batch_size, sampler=datasampler, num_workers=num_workers, pin_memory=True
23 |     )
24 | 
25 |     x, y, x_mark, y_mark = data[[0]][0]
26 |     encoder_dim = x.shape[-1]
27 |     decoder_dim = y.shape[-1]
28 |     output_dim = decoder_dim if features != 'MS' else 1
29 | 
30 |     return dataloader, data.get_scaler(), encoder_dim, decoder_dim, output_dim
--------------------------------------------------------------------------------
/img/informer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HFAiLab/LTSF-formers/c7daaef53cf5318c5775ebc0d0a13edec682b20d/img/informer.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from .informer import Informer, InformerStack
3 | from .autoformer import Autoformer
--------------------------------------------------------------------------------
/models/autoformer.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class PositionalEmbedding(nn.Module): 8 | def __init__(self, d_model, max_len=5000): 9 | super(PositionalEmbedding, self).__init__() 10 | # Compute the positional encodings once in log space. 11 | pe = torch.zeros(max_len, d_model).float() 12 | pe.require_grad = False 13 | 14 | position = torch.arange(0, max_len).float().unsqueeze(1) 15 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 16 | 17 | pe[:, 0::2] = torch.sin(position * div_term) 18 | pe[:, 1::2] = torch.cos(position * div_term) 19 | 20 | pe = pe.unsqueeze(0) 21 | self.register_buffer("pe", pe) 22 | 23 | def forward(self, x): 24 | return self.pe[:, : x.size(1)] 25 | 26 | 27 | class TokenEmbedding(nn.Module): 28 | def __init__(self, c_in, d_model): 29 | super(TokenEmbedding, self).__init__() 30 | padding = 1 if torch.__version__ >= "1.5.0" else 2 31 | self.tokenConv = nn.Conv1d( 32 | in_channels=c_in, out_channels=d_model, kernel_size=3, padding=padding, padding_mode="circular", bias=False 33 | ) 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv1d): 36 | nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="leaky_relu") 37 | 38 | def forward(self, x): 39 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 40 | return x 41 | 42 | 43 | class FixedEmbedding(nn.Module): 44 | def __init__(self, c_in, d_model): 45 | super(FixedEmbedding, self).__init__() 46 | 47 | w = torch.zeros(c_in, d_model).float() 48 | w.require_grad = False 49 | 50 | position = torch.arange(0, c_in).float().unsqueeze(1) 51 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 52 | 53 | w[:, 0::2] = torch.sin(position * div_term) 54 | w[:, 1::2] = torch.cos(position * div_term) 55 | 56 | self.emb = nn.Embedding(c_in, d_model) 57 | self.emb.weight = nn.Parameter(w, requires_grad=False) 58 | 59 | def forward(self, x): 60 | return self.emb(x).detach() 61 | 62 | 63 | class TemporalEmbedding(nn.Module): 64 | def __init__(self, d_model, embed_type="fixed", freq="h"): 65 | super(TemporalEmbedding, self).__init__() 66 | 67 | minute_size = 4 68 | hour_size = 24 69 | weekday_size = 7 70 | day_size = 32 71 | month_size = 13 72 | 73 | Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding 74 | if freq == "t": 75 | self.minute_embed = Embed(minute_size, d_model) 76 | self.hour_embed = Embed(hour_size, d_model) 77 | self.weekday_embed = Embed(weekday_size, d_model) 78 | self.day_embed = Embed(day_size, d_model) 79 | self.month_embed = Embed(month_size, d_model) 80 | 81 | def forward(self, x): 82 | x = x.long() 83 | 84 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0 85 | hour_x = self.hour_embed(x[:, :, 3]) 86 | weekday_x = self.weekday_embed(x[:, :, 2]) 87 | day_x = self.day_embed(x[:, :, 1]) 88 | month_x = self.month_embed(x[:, :, 0]) 89 | 90 | return hour_x + weekday_x + day_x + month_x + minute_x 91 | 92 | 93 | class TimeFeatureEmbedding(nn.Module): 94 | def __init__(self, d_model, embed_type="timeF", freq="h"): 95 | super(TimeFeatureEmbedding, self).__init__() 96 | 97 | freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} 98 | d_inp = freq_map[freq] 99 | self.embed = nn.Linear(d_inp, d_model, bias=False) 100 | 101 | def forward(self, x): 102 | return self.embed(x) 103 | 104 | 105 | class DataEmbedding(nn.Module): 106 | def 
__init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): 107 | super(DataEmbedding, self).__init__() 108 | 109 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 110 | self.position_embedding = PositionalEmbedding(d_model=d_model) 111 | self.temporal_embedding = ( 112 | TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 113 | if embed_type != "timeF" 114 | else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 115 | ) 116 | self.dropout = nn.Dropout(p=dropout) 117 | 118 | def forward(self, x, x_mark): 119 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x) 120 | return self.dropout(x) 121 | 122 | 123 | class DataEmbedding_wo_pos(nn.Module): 124 | def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): 125 | super(DataEmbedding_wo_pos, self).__init__() 126 | 127 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 128 | self.position_embedding = PositionalEmbedding(d_model=d_model) 129 | self.temporal_embedding = ( 130 | TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 131 | if embed_type != "timeF" 132 | else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 133 | ) 134 | self.dropout = nn.Dropout(p=dropout) 135 | 136 | def forward(self, x, x_mark): 137 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) 138 | return self.dropout(x) 139 | 140 | 141 | class my_Layernorm(nn.Module): 142 | """ 143 | Special designed layernorm for the seasonal part 144 | """ 145 | 146 | def __init__(self, channels): 147 | super(my_Layernorm, self).__init__() 148 | self.layernorm = nn.LayerNorm(channels) 149 | 150 | def forward(self, x): 151 | x_hat = self.layernorm(x) 152 | bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) 153 | return x_hat - bias 154 | 155 | 156 | class moving_avg(nn.Module): 157 | """ 158 | Moving average block to highlight the trend of time series 159 | """ 160 | 161 | def __init__(self, kernel_size, stride): 162 | super(moving_avg, self).__init__() 163 | self.kernel_size = kernel_size 164 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 165 | 166 | def forward(self, x): 167 | # padding on the both ends of time series 168 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) 169 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 170 | x = torch.cat([front, x, end], dim=1) 171 | x = self.avg(x.permute(0, 2, 1)) 172 | x = x.permute(0, 2, 1) 173 | return x 174 | 175 | 176 | class series_decomp(nn.Module): 177 | """ 178 | Series decomposition block 179 | """ 180 | 181 | def __init__(self, kernel_size): 182 | super(series_decomp, self).__init__() 183 | self.moving_avg = moving_avg(kernel_size, stride=1) 184 | 185 | def forward(self, x): 186 | moving_mean = self.moving_avg(x) 187 | res = x - moving_mean 188 | return res, moving_mean 189 | 190 | 191 | class EncoderLayer(nn.Module): 192 | """ 193 | Autoformer encoder layer with the progressive decomposition architecture 194 | """ 195 | 196 | def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"): 197 | super(EncoderLayer, self).__init__() 198 | d_ff = d_ff or 4 * d_model 199 | self.attention = attention 200 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 201 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 202 | self.decomp1 = 
series_decomp(moving_avg) 203 | self.decomp2 = series_decomp(moving_avg) 204 | self.dropout = nn.Dropout(dropout) 205 | self.activation = F.relu if activation == "relu" else F.gelu 206 | 207 | def forward(self, x, attn_mask=None): 208 | new_x, attn = self.attention(x, x, x, attn_mask=attn_mask) 209 | x = x + self.dropout(new_x) 210 | x, _ = self.decomp1(x) 211 | y = x 212 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 213 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 214 | res, _ = self.decomp2(x + y) 215 | return res, attn 216 | 217 | 218 | class Encoder(nn.Module): 219 | """ 220 | Autoformer encoder 221 | """ 222 | 223 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 224 | super(Encoder, self).__init__() 225 | self.attn_layers = nn.ModuleList(attn_layers) 226 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 227 | self.norm = norm_layer 228 | 229 | def forward(self, x, attn_mask=None): 230 | attns = [] 231 | if self.conv_layers is not None: 232 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 233 | x, attn = attn_layer(x, attn_mask=attn_mask) 234 | x = conv_layer(x) 235 | attns.append(attn) 236 | x, attn = self.attn_layers[-1](x) 237 | attns.append(attn) 238 | else: 239 | for attn_layer in self.attn_layers: 240 | x, attn = attn_layer(x, attn_mask=attn_mask) 241 | attns.append(attn) 242 | 243 | if self.norm is not None: 244 | x = self.norm(x) 245 | 246 | return x, attns 247 | 248 | 249 | class DecoderLayer(nn.Module): 250 | """ 251 | Autoformer decoder layer with the progressive decomposition architecture 252 | """ 253 | 254 | def __init__( 255 | self, self_attention, cross_attention, d_model, c_out, d_ff=None, moving_avg=25, dropout=0.1, activation="relu" 256 | ): 257 | super(DecoderLayer, self).__init__() 258 | d_ff = d_ff or 4 * d_model 259 | self.self_attention = self_attention 260 | self.cross_attention = cross_attention 261 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 262 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 263 | self.decomp1 = series_decomp(moving_avg) 264 | self.decomp2 = series_decomp(moving_avg) 265 | self.decomp3 = series_decomp(moving_avg) 266 | self.dropout = nn.Dropout(dropout) 267 | self.projection = nn.Conv1d( 268 | in_channels=d_model, 269 | out_channels=c_out, 270 | kernel_size=3, 271 | stride=1, 272 | padding=1, 273 | padding_mode="circular", 274 | bias=False, 275 | ) 276 | self.activation = F.relu if activation == "relu" else F.gelu 277 | 278 | def forward(self, x, cross, x_mask=None, cross_mask=None): 279 | x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0]) 280 | x, trend1 = self.decomp1(x) 281 | x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]) 282 | x, trend2 = self.decomp2(x) 283 | y = x 284 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 285 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 286 | x, trend3 = self.decomp3(x + y) 287 | 288 | residual_trend = trend1 + trend2 + trend3 289 | residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) 290 | return x, residual_trend 291 | 292 | 293 | class Decoder(nn.Module): 294 | """ 295 | Autoformer encoder 296 | """ 297 | 298 | def __init__(self, layers, norm_layer=None, projection=None): 299 | super(Decoder, self).__init__() 300 | self.layers = nn.ModuleList(layers) 301 | self.norm = norm_layer 302 | 
self.projection = projection 303 | 304 | def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): 305 | for layer in self.layers: 306 | x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 307 | trend = trend + residual_trend 308 | 309 | if self.norm is not None: 310 | x = self.norm(x) 311 | 312 | if self.projection is not None: 313 | x = self.projection(x) 314 | return x, trend 315 | 316 | 317 | class AutoCorrelation(nn.Module): 318 | """ 319 | AutoCorrelation Mechanism with the following two phases: 320 | (1) period-based dependencies discovery 321 | (2) time delay aggregation 322 | This block can replace the self-attention family mechanism seamlessly. 323 | """ 324 | 325 | def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False): 326 | super(AutoCorrelation, self).__init__() 327 | self.factor = factor 328 | self.scale = scale 329 | self.mask_flag = mask_flag 330 | self.output_attention = output_attention 331 | self.dropout = nn.Dropout(attention_dropout) 332 | 333 | def time_delay_agg_training(self, values, corr): 334 | """ 335 | SpeedUp version of Autocorrelation (a batch-normalization style design) 336 | This is for the training phase. 337 | """ 338 | head = values.shape[1] 339 | channel = values.shape[2] 340 | length = values.shape[3] 341 | # find top k 342 | top_k = int(self.factor * math.log(length)) 343 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 344 | index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] 345 | weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) 346 | # update corr 347 | tmp_corr = torch.softmax(weights, dim=-1) 348 | # aggregation 349 | tmp_values = values 350 | delays_agg = torch.zeros_like(values).float() 351 | for i in range(top_k): 352 | pattern = torch.roll(tmp_values, -int(index[i]), -1) 353 | delays_agg = delays_agg + pattern * ( 354 | tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) 355 | ) 356 | return delays_agg 357 | 358 | def time_delay_agg_inference(self, values, corr): 359 | """ 360 | SpeedUp version of Autocorrelation (a batch-normalization style design) 361 | This is for the inference phase. 
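        Unlike the training version, which shares batch-averaged top-k delay indices, this variant selects the top-k delays per sample and gathers them with torch.gather.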
362 | """ 363 | batch = values.shape[0] 364 | head = values.shape[1] 365 | channel = values.shape[2] 366 | length = values.shape[3] 367 | # index init 368 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 369 | # find top k 370 | top_k = int(self.factor * math.log(length)) 371 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 372 | weights = torch.topk(mean_value, top_k, dim=-1)[0] 373 | delay = torch.topk(mean_value, top_k, dim=-1)[1] 374 | # update corr 375 | tmp_corr = torch.softmax(weights, dim=-1) 376 | # aggregation 377 | tmp_values = values.repeat(1, 1, 1, 2) 378 | delays_agg = torch.zeros_like(values).float() 379 | for i in range(top_k): 380 | tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) 381 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 382 | delays_agg = delays_agg + pattern * ( 383 | tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) 384 | ) 385 | return delays_agg 386 | 387 | def time_delay_agg_full(self, values, corr): 388 | """ 389 | Standard version of Autocorrelation 390 | """ 391 | batch = values.shape[0] 392 | head = values.shape[1] 393 | channel = values.shape[2] 394 | length = values.shape[3] 395 | # index init 396 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 397 | # find top k 398 | top_k = int(self.factor * math.log(length)) 399 | weights = torch.topk(corr, top_k, dim=-1)[0] 400 | delay = torch.topk(corr, top_k, dim=-1)[1] 401 | # update corr 402 | tmp_corr = torch.softmax(weights, dim=-1) 403 | # aggregation 404 | tmp_values = values.repeat(1, 1, 1, 2) 405 | delays_agg = torch.zeros_like(values).float() 406 | for i in range(top_k): 407 | tmp_delay = init_index + delay[..., i].unsqueeze(-1) 408 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 409 | delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1)) 410 | return delays_agg 411 | 412 | def forward(self, queries, keys, values, attn_mask): 413 | B, L, H, E = queries.shape 414 | _, S, _, D = values.shape 415 | if L > S: 416 | zeros = torch.zeros_like(queries[:, : (L - S), :]).float() 417 | values = torch.cat([values, zeros], dim=1) 418 | keys = torch.cat([keys, zeros], dim=1) 419 | else: 420 | values = values[:, :L, :, :] 421 | keys = keys[:, :L, :, :] 422 | 423 | # period-based dependencies 424 | q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) 425 | k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) 426 | res = q_fft * torch.conj(k_fft) 427 | corr = torch.fft.irfft(res, dim=-1) 428 | 429 | # time delay agg 430 | if self.training: 431 | V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 432 | else: 433 | V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 434 | 435 | if self.output_attention: 436 | return (V.contiguous(), corr.permute(0, 3, 1, 2)) 437 | else: 438 | return (V.contiguous(), None) 439 | 440 | 441 | class AutoCorrelationLayer(nn.Module): 442 | def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None): 443 | super(AutoCorrelationLayer, self).__init__() 444 | 445 | d_keys = d_keys or (d_model // n_heads) 446 | d_values = d_values or (d_model // n_heads) 447 | 448 | self.inner_correlation = correlation 449 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 450 
| self.key_projection = nn.Linear(d_model, d_keys * n_heads) 451 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 452 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 453 | self.n_heads = n_heads 454 | 455 | def forward(self, queries, keys, values, attn_mask): 456 | B, L, _ = queries.shape 457 | _, S, _ = keys.shape 458 | H = self.n_heads 459 | 460 | queries = self.query_projection(queries).view(B, L, H, -1) 461 | keys = self.key_projection(keys).view(B, S, H, -1) 462 | values = self.value_projection(values).view(B, S, H, -1) 463 | 464 | out, attn = self.inner_correlation(queries, keys, values, attn_mask) 465 | out = out.view(B, L, -1) 466 | 467 | return self.out_projection(out), attn 468 | 469 | 470 | class Autoformer(nn.Module): 471 | """ 472 | Autoformer is the first method to achieve the series-wise connection, 473 | with inherent O(LlogL) complexity 474 | """ 475 | 476 | def __init__( 477 | self, 478 | enc_in, 479 | dec_in, 480 | c_out, 481 | seq_len, 482 | label_len, 483 | out_len, 484 | factor=5, 485 | d_model=512, 486 | n_heads=8, 487 | e_layers=3, 488 | d_layers=2, 489 | d_ff=512, 490 | moving_avg=25, 491 | dropout=0.0, 492 | embed="fixed", 493 | freq="h", 494 | activation="gelu", 495 | output_attention=False, 496 | ): 497 | super(Autoformer, self).__init__() 498 | self.seq_len = seq_len 499 | self.label_len = label_len 500 | self.pred_len = out_len 501 | self.output_attention = output_attention 502 | 503 | # Decomp 504 | kernel_size = moving_avg 505 | self.decomp = series_decomp(kernel_size) 506 | 507 | # Embedding 508 | # The series-wise connection inherently contains the sequential information. 509 | # Thus, we can discard the position embedding of transformers. 510 | self.enc_embedding = DataEmbedding_wo_pos(enc_in, d_model, embed, freq, dropout) 511 | self.dec_embedding = DataEmbedding_wo_pos(dec_in, d_model, embed, freq, dropout) 512 | 513 | # Encoder 514 | self.encoder = Encoder( 515 | [ 516 | EncoderLayer( 517 | AutoCorrelationLayer( 518 | AutoCorrelation(False, factor, attention_dropout=dropout, output_attention=output_attention), 519 | d_model, 520 | n_heads, 521 | ), 522 | d_model, 523 | d_ff, 524 | moving_avg=moving_avg, 525 | dropout=dropout, 526 | activation=activation, 527 | ) 528 | for l in range(e_layers) 529 | ], 530 | norm_layer=my_Layernorm(d_model), 531 | ) 532 | # Decoder 533 | self.decoder = Decoder( 534 | [ 535 | DecoderLayer( 536 | AutoCorrelationLayer( 537 | AutoCorrelation(True, factor, attention_dropout=dropout, output_attention=False), 538 | d_model, 539 | n_heads, 540 | ), 541 | AutoCorrelationLayer( 542 | AutoCorrelation(False, factor, attention_dropout=dropout, output_attention=False), 543 | d_model, 544 | n_heads, 545 | ), 546 | d_model, 547 | c_out, 548 | d_ff, 549 | moving_avg=moving_avg, 550 | dropout=dropout, 551 | activation=activation, 552 | ) 553 | for l in range(d_layers) 554 | ], 555 | norm_layer=my_Layernorm(d_model), 556 | projection=nn.Linear(d_model, c_out, bias=True), 557 | ) 558 | 559 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 560 | # decomp init 561 | mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) 562 | zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device) 563 | seasonal_init, trend_init = self.decomp(x_enc) 564 | # decoder input 565 | trend_init = torch.cat([trend_init[:, -self.label_len :, :], mean], dim=1) 566 | seasonal_init = torch.cat([seasonal_init[:, 
-self.label_len :, :], zeros], dim=1) 567 | # enc 568 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 569 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 570 | # dec 571 | dec_out = self.dec_embedding(seasonal_init, x_mark_dec) 572 | seasonal_part, trend_part = self.decoder( 573 | dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, trend=trend_init 574 | ) 575 | # final 576 | dec_out = trend_part + seasonal_part 577 | 578 | if self.output_attention: 579 | return dec_out[:, -self.pred_len :, :], attns 580 | else: 581 | return dec_out[:, -self.pred_len :, :] # [B, L, D] -------------------------------------------------------------------------------- /models/informer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | # Attention Model: TriangularCausalMask, ProbMask, FullAttention, ProbAttention, AttentionLayer 9 | class TriangularCausalMask: 10 | def __init__(self, B, L): 11 | mask_shape = [B, 1, L, L] 12 | with torch.no_grad(): 13 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool).cuda(), diagonal=1) 14 | 15 | @property 16 | def mask(self): 17 | return self._mask 18 | 19 | 20 | class ProbMask: 21 | def __init__(self, B, H, L, index, scores): 22 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).cuda().triu(1) 23 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 24 | indicator = _mask_ex[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] 25 | self._mask = indicator.view(scores.shape) 26 | 27 | @property 28 | def mask(self): 29 | return self._mask 30 | 31 | 32 | class FullAttention(nn.Module): 33 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 34 | super(FullAttention, self).__init__() 35 | self.scale = scale 36 | self.mask_flag = mask_flag 37 | self.output_attention = output_attention 38 | self.dropout = nn.Dropout(attention_dropout) 39 | 40 | def forward(self, queries, keys, values, attn_mask): 41 | B, L, H, E = queries.shape 42 | _, S, _, D = values.shape 43 | scale = self.scale or 1.0 / math.sqrt(E) 44 | 45 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) 46 | if self.mask_flag: 47 | if attn_mask is None: 48 | attn_mask = TriangularCausalMask(B, L) 49 | 50 | scores.masked_fill_(attn_mask.mask, -np.inf) 51 | 52 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 53 | V = torch.einsum("bhls,bshd->blhd", A, values) 54 | 55 | if self.output_attention: 56 | return V.contiguous(), A 57 | else: 58 | return V.contiguous(), None 59 | 60 | 61 | class ProbAttention(nn.Module): 62 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 63 | super(ProbAttention, self).__init__() 64 | self.factor = factor 65 | self.scale = scale 66 | self.mask_flag = mask_flag 67 | self.output_attention = output_attention 68 | self.dropout = nn.Dropout(attention_dropout) 69 | 70 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) 71 | # Q [B, H, L, D] 72 | B, H, L_K, E = K.shape 73 | _, _, L_Q, _ = Q.shape 74 | 75 | # calculate the sampled Q_K 76 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 77 | index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q 78 | K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] 79 | Q_K_sample = torch.matmul(Q.unsqueeze(-2), 
K_sample.transpose(-2, -1)).squeeze(-2) 80 | 81 | # find the Top_k query with sparisty measurement 82 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) 83 | M_top = M.topk(n_top, sorted=False)[1] 84 | 85 | # use the reduced Q to calculate Q_K 86 | Q_reduce = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :] # factor*ln(L_q) 87 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k 88 | 89 | return Q_K, M_top 90 | 91 | def _get_initial_context(self, V, L_Q): 92 | B, H, L_V, D = V.shape 93 | if not self.mask_flag: 94 | # V_sum = V.sum(dim=-2) 95 | V_sum = V.mean(dim=-2) 96 | contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() 97 | else: # use mask 98 | assert L_Q == L_V # requires that L_Q == L_V, i.e. for self-attention only 99 | contex = V.cumsum(dim=-2) 100 | return contex 101 | 102 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 103 | B, H, L_V, D = V.shape 104 | 105 | if self.mask_flag: 106 | attn_mask = ProbMask(B, H, L_Q, index, scores) 107 | scores.masked_fill_(attn_mask.mask, -np.inf) 108 | 109 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 110 | 111 | context_in[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = torch.matmul( 112 | attn, V 113 | ).type_as(context_in) 114 | if self.output_attention: 115 | attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn) 116 | attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn 117 | return context_in, attns 118 | else: 119 | return context_in, None 120 | 121 | def forward(self, queries, keys, values, attn_mask): 122 | B, L_Q, H, D = queries.shape 123 | _, L_K, _, _ = keys.shape 124 | 125 | queries = queries.transpose(2, 1) 126 | keys = keys.transpose(2, 1) 127 | values = values.transpose(2, 1) 128 | 129 | U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item() # c*ln(L_k) 130 | u = self.factor * np.ceil(np.log(L_Q)).astype("int").item() # c*ln(L_q) 131 | 132 | U_part = U_part if U_part < L_K else L_K 133 | u = u if u < L_Q else L_Q 134 | 135 | scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) 136 | 137 | # add scale factor 138 | scale = self.scale or 1.0 / math.sqrt(D) 139 | if scale is not None: 140 | scores_top = scores_top * scale 141 | # get the context 142 | context = self._get_initial_context(values, L_Q) 143 | # update the context with selected top_k queries 144 | context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask) 145 | 146 | return context.transpose(2, 1).contiguous(), attn 147 | 148 | 149 | class AttentionLayer(nn.Module): 150 | def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None, mix=False): 151 | super(AttentionLayer, self).__init__() 152 | 153 | d_keys = d_keys or (d_model // n_heads) 154 | d_values = d_values or (d_model // n_heads) 155 | 156 | self.inner_attention = attention 157 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 158 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 159 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 160 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 161 | self.n_heads = n_heads 162 | self.mix = mix 163 | 164 | def forward(self, queries, keys, values, attn_mask): 165 | B, L, _ = queries.shape 166 | _, S, _ = keys.shape 167 | H = self.n_heads 168 | 169 | queries = self.query_projection(queries).view(B, L, H, -1) 170 | keys = self.key_projection(keys).view(B, S, 
H, -1) 171 | values = self.value_projection(values).view(B, S, H, -1) 172 | 173 | out, attn = self.inner_attention(queries, keys, values, attn_mask) 174 | if self.mix: 175 | out = out.transpose(2, 1).contiguous() 176 | out = out.view(B, L, -1) 177 | 178 | return self.out_projection(out), attn 179 | 180 | 181 | # Embedding Model: PositionalEmbedding, TokenEmbedding, FixedEmbedding, TemporalEmbedding, TimeFeatureEmbedding, DataEmbedding 182 | 183 | 184 | class PositionalEmbedding(nn.Module): 185 | def __init__(self, d_model, max_len=5000): 186 | super(PositionalEmbedding, self).__init__() 187 | # Compute the positional encodings once in log space. 188 | pe = torch.zeros(max_len, d_model).float() 189 | pe.require_grad = False 190 | 191 | position = torch.arange(0, max_len).float().unsqueeze(1) 192 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 193 | 194 | pe[:, 0::2] = torch.sin(position * div_term) 195 | pe[:, 1::2] = torch.cos(position * div_term) 196 | 197 | pe = pe.unsqueeze(0) 198 | self.register_buffer("pe", pe) 199 | 200 | def forward(self, x): 201 | return self.pe[:, : x.size(1)] 202 | 203 | 204 | class TokenEmbedding(nn.Module): 205 | def __init__(self, c_in, d_model): 206 | super(TokenEmbedding, self).__init__() 207 | padding = 1 if torch.__version__ >= "1.5.0" else 2 208 | self.tokenConv = nn.Conv1d( 209 | in_channels=c_in, out_channels=d_model, kernel_size=3, padding=padding, padding_mode="circular" 210 | ) 211 | for m in self.modules(): 212 | if isinstance(m, nn.Conv1d): 213 | nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="leaky_relu") 214 | 215 | def forward(self, x): 216 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 217 | return x 218 | 219 | 220 | class FixedEmbedding(nn.Module): 221 | def __init__(self, c_in, d_model): 222 | super(FixedEmbedding, self).__init__() 223 | 224 | w = torch.zeros(c_in, d_model).float() 225 | w.require_grad = False 226 | 227 | position = torch.arange(0, c_in).float().unsqueeze(1) 228 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 229 | 230 | w[:, 0::2] = torch.sin(position * div_term) 231 | w[:, 1::2] = torch.cos(position * div_term) 232 | 233 | self.emb = nn.Embedding(c_in, d_model) 234 | self.emb.weight = nn.Parameter(w, requires_grad=False) 235 | 236 | def forward(self, x): 237 | return self.emb(x).detach() 238 | 239 | 240 | class TemporalEmbedding(nn.Module): 241 | def __init__(self, d_model, embed_type="fixed", freq="h"): 242 | super(TemporalEmbedding, self).__init__() 243 | 244 | minute_size = 4 245 | hour_size = 24 246 | weekday_size = 7 247 | day_size = 32 248 | month_size = 13 249 | 250 | Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding 251 | if freq == "t": 252 | self.minute_embed = Embed(minute_size, d_model) 253 | self.hour_embed = Embed(hour_size, d_model) 254 | self.weekday_embed = Embed(weekday_size, d_model) 255 | self.day_embed = Embed(day_size, d_model) 256 | self.month_embed = Embed(month_size, d_model) 257 | 258 | def forward(self, x): 259 | x = x.long() 260 | 261 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0 262 | hour_x = self.hour_embed(x[:, :, 3]) 263 | weekday_x = self.weekday_embed(x[:, :, 2]) 264 | day_x = self.day_embed(x[:, :, 1]) 265 | month_x = self.month_embed(x[:, :, 0]) 266 | 267 | return hour_x + weekday_x + day_x + month_x + minute_x 268 | 269 | 270 | class TimeFeatureEmbedding(nn.Module): 271 | def __init__(self, d_model, embed_type="timeF", 
freq="h"): 272 | super(TimeFeatureEmbedding, self).__init__() 273 | 274 | freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} 275 | d_inp = freq_map[freq] 276 | self.embed = nn.Linear(d_inp, d_model) 277 | 278 | def forward(self, x): 279 | return self.embed(x) 280 | 281 | 282 | class DataEmbedding(nn.Module): 283 | def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): 284 | super(DataEmbedding, self).__init__() 285 | 286 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 287 | self.position_embedding = PositionalEmbedding(d_model=d_model) 288 | if embed_type != "timeF": 289 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 290 | else: 291 | self.temporal_embedding = TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 292 | 293 | self.dropout = nn.Dropout(p=dropout) 294 | 295 | def forward(self, x, x_mark): 296 | x = self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark) 297 | 298 | return self.dropout(x) 299 | 300 | 301 | # Encoder Layer: ConvLayer, EncoderLayer, Encoder, EncoderStack 302 | 303 | 304 | class ConvLayer(nn.Module): 305 | def __init__(self, c_in): 306 | super(ConvLayer, self).__init__() 307 | padding = 1 if torch.__version__ >= "1.5.0" else 2 308 | self.downConv = nn.Conv1d( 309 | in_channels=c_in, out_channels=c_in, kernel_size=3, padding=padding, padding_mode="circular" 310 | ) 311 | self.norm = nn.BatchNorm1d(c_in) 312 | self.activation = nn.ELU() 313 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 314 | 315 | def forward(self, x): 316 | x = self.downConv(x.permute(0, 2, 1)) 317 | x = self.norm(x) 318 | x = self.activation(x) 319 | x = self.maxPool(x) 320 | x = x.transpose(1, 2) 321 | return x 322 | 323 | 324 | class EncoderLayer(nn.Module): 325 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 326 | super(EncoderLayer, self).__init__() 327 | d_ff = d_ff or 4 * d_model 328 | self.attention = attention 329 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 330 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 331 | self.norm1 = nn.LayerNorm(d_model) 332 | self.norm2 = nn.LayerNorm(d_model) 333 | self.dropout = nn.Dropout(dropout) 334 | self.activation = F.relu if activation == "relu" else F.gelu 335 | 336 | def forward(self, x, attn_mask=None): 337 | # x [B, L, D] 338 | # x = x + self.dropout(self.attention( 339 | # x, x, x, 340 | # attn_mask = attn_mask 341 | # )) 342 | new_x, attn = self.attention(x, x, x, attn_mask=attn_mask) 343 | x = x + self.dropout(new_x) 344 | 345 | y = x = self.norm1(x) 346 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 347 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 348 | 349 | return self.norm2(x + y), attn 350 | 351 | 352 | class Encoder(nn.Module): 353 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 354 | super(Encoder, self).__init__() 355 | self.attn_layers = nn.ModuleList(attn_layers) 356 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 357 | self.norm = norm_layer 358 | 359 | def forward(self, x, attn_mask=None): 360 | # x [B, L, D] 361 | attns = [] 362 | if self.conv_layers is not None: 363 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 364 | x, attn = attn_layer(x, attn_mask=attn_mask) 365 | x = conv_layer(x) 366 | attns.append(attn) 367 | x, attn = 
self.attn_layers[-1](x, attn_mask=attn_mask) 368 | attns.append(attn) 369 | else: 370 | for attn_layer in self.attn_layers: 371 | x, attn = attn_layer(x, attn_mask=attn_mask) 372 | attns.append(attn) 373 | 374 | if self.norm is not None: 375 | x = self.norm(x) 376 | 377 | return x, attns 378 | 379 | 380 | class EncoderStack(nn.Module): 381 | def __init__(self, encoders, inp_lens): 382 | super(EncoderStack, self).__init__() 383 | self.encoders = nn.ModuleList(encoders) 384 | self.inp_lens = inp_lens 385 | 386 | def forward(self, x, attn_mask=None): 387 | # x [B, L, D] 388 | x_stack = [] 389 | attns = [] 390 | for i_len, encoder in zip(self.inp_lens, self.encoders): 391 | inp_len = x.shape[1] // (2 ** i_len) 392 | x_s, attn = encoder(x[:, -inp_len:, :]) 393 | x_stack.append(x_s) 394 | attns.append(attn) 395 | x_stack = torch.cat(x_stack, -2) 396 | 397 | return x_stack, attns 398 | 399 | 400 | # Decoder Layer: DecoderLayer, Decoder 401 | 402 | 403 | class DecoderLayer(nn.Module): 404 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 405 | super(DecoderLayer, self).__init__() 406 | d_ff = d_ff or 4 * d_model 407 | self.self_attention = self_attention 408 | self.cross_attention = cross_attention 409 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 410 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 411 | self.norm1 = nn.LayerNorm(d_model) 412 | self.norm2 = nn.LayerNorm(d_model) 413 | self.norm3 = nn.LayerNorm(d_model) 414 | self.dropout = nn.Dropout(dropout) 415 | self.activation = F.relu if activation == "relu" else F.gelu 416 | 417 | def forward(self, x, cross, x_mask=None, cross_mask=None): 418 | x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0]) 419 | x = self.norm1(x) 420 | 421 | x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]) 422 | 423 | y = x = self.norm2(x) 424 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 425 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 426 | 427 | return self.norm3(x + y) 428 | 429 | 430 | class Decoder(nn.Module): 431 | def __init__(self, layers, norm_layer=None): 432 | super(Decoder, self).__init__() 433 | self.layers = nn.ModuleList(layers) 434 | self.norm = norm_layer 435 | 436 | def forward(self, x, cross, x_mask=None, cross_mask=None): 437 | for layer in self.layers: 438 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 439 | 440 | if self.norm is not None: 441 | x = self.norm(x) 442 | 443 | return x 444 | 445 | 446 | class Informer(nn.Module): 447 | """ 448 | Implementation of AAAI 2021 best paper "Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting" 449 | 450 | Args: 451 | enc_in (int): encoder input size 452 | dec_in (int): decoder input size 453 | c_out (int): output size 454 | seq_len (int): input sequence length of Informer encode 455 | label_len (int): start token length of Informer decoder 456 | out_len (int): prediction sequence length 457 | factor (int): probsparse attn factor (default is 5) 458 | d_model (int): dimension of model (default is 512) 459 | n_heads (int): num of heads (default is 8) 460 | e_layers (int): num of encoder layers (default is 3) 461 | d_layers (int): num of decoder layers (default is 2) 462 | d_ff (int): dimension of fcn (default is 512) 463 | dropout (float): dropout (default is 0.0) 464 | attn (string): attention used in encoder, options:[prob, full] (default is prob) 465 | embed 
(string): time features encoding, options:[timeF, fixed, learned] (default is fixed) 466 | freq (string): freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h (default is h) 467 | activation (string): activation (default is gelu) 468 | output_attention (bool): whether to output attention in ecoder (default is False) 469 | distil (bool): whether to use distilling in encoder, using this argument means not using distilling (default is True) 470 | mix (bool): whether to use mix attention in generative decode (default is True) 471 | 472 | Predefined models: 473 | 474 | informer_univariate: 475 | enc_in = 1 476 | dec_in = 1 477 | c_out = 1 478 | seq_len = 720 479 | label_len = 168 480 | out_len = 24 481 | e_layers = 2 482 | d_layers = 1 483 | d_ff = 2048 484 | dropout = 0.05 485 | embed = 'timeF' 486 | 487 | informer_multivariate: 488 | enc_in = 7 489 | dec_in = 7 490 | c_out = 7 491 | seq_len = 48 492 | label_len = 48 493 | out_len = 24 494 | e_layers = 2 495 | d_layers = 1 496 | d_ff = 2048 497 | dropout = 0.05 498 | embed = 'timeF' 499 | 500 | Example: 501 | 502 | .. code-block:: python 503 | 504 | x, x_mask = torch.zeros(32, 720, 1).cuda(), torch.zeros(32, 720, 4).cuda() 505 | y, y_mask = torch.zeros(32, 168, 1).cuda(), torch.zeros(32, 168, 4).cuda() 506 | 507 | model = hfai.models.informer_univariate().cuda() 508 | # or 509 | # model = hfai.models.Informer( 510 | # enc_in=1, 511 | # dec_in=1, 512 | # c_out=1, 513 | # seq_len=720, 514 | # label_len=168, 515 | # out_len=24, 516 | # e_layers=2, 517 | # d_layers=1, 518 | # d_ff=2048, 519 | # dropout=0.05, 520 | # embed='timeF', 521 | # ) 522 | 523 | # forward 524 | logits = model(x, x_mask, y, y_mask) 525 | 526 | """ 527 | 528 | def __init__( 529 | self, 530 | enc_in, 531 | dec_in, 532 | c_out, 533 | seq_len, 534 | label_len, 535 | out_len, 536 | factor=5, 537 | d_model=512, 538 | n_heads=8, 539 | e_layers=3, 540 | d_layers=2, 541 | d_ff=512, 542 | dropout=0.0, 543 | attn="prob", 544 | embed="fixed", 545 | freq="h", 546 | activation="gelu", 547 | output_attention=False, 548 | distil=True, 549 | mix=True, 550 | ): 551 | super(Informer, self).__init__() 552 | self.seq_len = seq_len 553 | self.label_len = label_len 554 | self.pred_len = out_len 555 | self.attn = attn 556 | self.output_attention = output_attention 557 | 558 | # Encoding 559 | self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout) 560 | self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout) 561 | # Attention 562 | Attn = ProbAttention if attn == "prob" else FullAttention 563 | # Encoder 564 | self.encoder = Encoder( 565 | [ 566 | EncoderLayer( 567 | AttentionLayer( 568 | Attn(False, factor, attention_dropout=dropout, output_attention=output_attention), 569 | d_model, 570 | n_heads, 571 | mix=False, 572 | ), 573 | d_model, 574 | d_ff, 575 | dropout=dropout, 576 | activation=activation, 577 | ) 578 | for l in range(e_layers) 579 | ], 580 | [ConvLayer(d_model) for l in range(e_layers - 1)] if distil else None, 581 | norm_layer=torch.nn.LayerNorm(d_model), 582 | ) 583 | # Decoder 584 | self.decoder = Decoder( 585 | [ 586 | DecoderLayer( 587 | AttentionLayer( 588 | Attn(True, factor, attention_dropout=dropout, output_attention=False), d_model, n_heads, mix=mix 589 | ), 590 | AttentionLayer( 591 | FullAttention(False, factor, attention_dropout=dropout, output_attention=False), 592 | d_model, 593 | n_heads, 594 | mix=False, 595 | 
), 596 | d_model, 597 | d_ff, 598 | dropout=dropout, 599 | activation=activation, 600 | ) 601 | for l in range(d_layers) 602 | ], 603 | norm_layer=torch.nn.LayerNorm(d_model), 604 | ) 605 | # self.end_conv1 = nn.Conv1d(in_channels=label_len+out_len, out_channels=out_len, kernel_size=1, bias=True) 606 | # self.end_conv2 = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=1, bias=True) 607 | self.projection = nn.Linear(d_model, c_out, bias=True) 608 | 609 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 610 | """ 611 | Args: 612 | x_enc (Tensor): input time series, size ``[B, L, D]`` 613 | x_mark_enc (Tensor): input time-series features, size ``[B, L, D]`` 614 | x_dec (Tensor): target time series, size ``[B, L, D]`` 615 | x_mark_dec (Tensor): target time-series features, size ``[B, L, D]`` 616 | enc_self_mask (Tensor): encode mask, size ``[B, D]`` 617 | dec_self_mask (Tensor): decode mask, size ``[B, D]`` 618 | dec_enc_mask (Tensor): encode-decode mask, size ``[B, D]`` 619 | 620 | Returns: 621 | pred (Tensor): prediction, size ``[B, L, D]`` 622 | """ 623 | 624 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 625 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 626 | 627 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 628 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) 629 | dec_out = self.projection(dec_out) 630 | 631 | # dec_out = self.end_conv1(dec_out) 632 | # dec_out = self.end_conv2(dec_out.transpose(2,1)).transpose(1,2) 633 | if self.output_attention: 634 | return dec_out[:, -self.pred_len :, :], attns 635 | else: 636 | return dec_out[:, -self.pred_len :, :] 637 | 638 | 639 | class InformerStack(nn.Module): 640 | """ 641 | Implementation of AAAI 2021 best paper "Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting" 642 | 643 | Args: 644 | enc_in (int): encoder input size 645 | dec_in (int): decoder input size 646 | c_out (int): output size 647 | seq_len (int): input sequence length of Informer encode 648 | label_len (int): start token length of Informer decoder 649 | out_len (int): prediction sequence length 650 | factor (int): probsparse attn factor (default is 5) 651 | d_model (int): dimension of model (default is 512) 652 | n_heads (int): num of heads (default is 8) 653 | e_layers (list): num of stack encoder layers (default is [3,2,1]) 654 | d_layers (int): num of decoder layers (default is 2) 655 | d_ff (int): dimension of fcn (default is 512) 656 | dropout (float): dropout (default is 0.0) 657 | attn (string): attention used in encoder, options:[prob, full] (default is prob) 658 | embed (string): time features encoding, options:[timeF, fixed, learned] (default is fixed) 659 | freq (string): freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h (default is h) 660 | activation (string): activation (default is gelu) 661 | output_attention (bool): whether to output attention in ecoder (default is False) 662 | distil (bool): whether to use distilling in encoder, using this argument means not using distilling (default is True) 663 | mix (bool): whether to use mix attention in generative decode (default is True) 664 | 665 | Predefined models: 666 | 667 | informer_stack_univariate: 668 | enc_in = 1 669 | dec_in = 1 670 | c_out = 1 671 | seq_len = 720 672 | label_len = 168 673 | out_len = 24 674 | 
e_layers = 2 675 | d_layers = 1 676 | d_ff = 2048 677 | dropout = 0.05 678 | embed = 'timeF' 679 | 680 | informer_stack_multivariate: 681 | enc_in = 7 682 | dec_in = 7 683 | c_out = 7 684 | seq_len = 48 685 | label_len = 48 686 | out_len = 24 687 | e_layers = 2 688 | d_layers = 1 689 | d_ff = 2048 690 | dropout = 0.05 691 | embed = 'timeF' 692 | 693 | Example: 694 | 695 | .. code-block:: python 696 | 697 | x, x_mask = torch.zeros(32, 720, 1).cuda(), torch.zeros(32, 720, 4).cuda() 698 | y, y_mask = torch.zeros(32, 168, 1).cuda(), torch.zeros(32, 168, 4).cuda() 699 | 700 | model = hfai.models.informer_stack_univariate().cuda() 701 | # or 702 | # model = hfai.models.InformerStack( 703 | # enc_in=1, 704 | # dec_in=1, 705 | # c_out=1, 706 | # seq_len=720, 707 | # label_len=168, 708 | # out_len=24, 709 | # e_layers=2, 710 | # d_layers=1, 711 | # d_ff=2048, 712 | # dropout=0.05, 713 | # embed='timeF', 714 | # ) 715 | 716 | # forward 717 | logits = model(x, x_mask, y, y_mask) 718 | 719 | """ 720 | 721 | def __init__( 722 | self, 723 | enc_in, 724 | dec_in, 725 | c_out, 726 | seq_len, 727 | label_len, 728 | out_len, 729 | factor=5, 730 | d_model=512, 731 | n_heads=8, 732 | e_layers=[3, 2, 1], 733 | d_layers=2, 734 | d_ff=512, 735 | dropout=0.0, 736 | attn="prob", 737 | embed="fixed", 738 | freq="h", 739 | activation="gelu", 740 | output_attention=False, 741 | distil=True, 742 | mix=True, 743 | ): 744 | super(InformerStack, self).__init__() 745 | self.seq_len = seq_len 746 | self.label_len = label_len 747 | self.pred_len = out_len 748 | self.attn = attn 749 | self.output_attention = output_attention 750 | 751 | # Encoding 752 | self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout) 753 | self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout) 754 | # Attention 755 | Attn = ProbAttention if attn == "prob" else FullAttention 756 | # Encoder 757 | 758 | inp_lens = list(range(len(e_layers))) # [0,1,2,...] 
you can customize here 759 | encoders = [ 760 | Encoder( 761 | [ 762 | EncoderLayer( 763 | AttentionLayer( 764 | Attn(False, factor, attention_dropout=dropout, output_attention=output_attention), 765 | d_model, 766 | n_heads, 767 | mix=False, 768 | ), 769 | d_model, 770 | d_ff, 771 | dropout=dropout, 772 | activation=activation, 773 | ) 774 | for l in range(el) 775 | ], 776 | [ConvLayer(d_model) for l in range(el - 1)] if distil else None, 777 | norm_layer=torch.nn.LayerNorm(d_model), 778 | ) 779 | for el in e_layers 780 | ] 781 | self.encoder = EncoderStack(encoders, inp_lens) 782 | # Decoder 783 | self.decoder = Decoder( 784 | [ 785 | DecoderLayer( 786 | AttentionLayer( 787 | Attn(True, factor, attention_dropout=dropout, output_attention=False), d_model, n_heads, mix=mix 788 | ), 789 | AttentionLayer( 790 | FullAttention(False, factor, attention_dropout=dropout, output_attention=False), 791 | d_model, 792 | n_heads, 793 | mix=False, 794 | ), 795 | d_model, 796 | d_ff, 797 | dropout=dropout, 798 | activation=activation, 799 | ) 800 | for l in range(d_layers) 801 | ], 802 | norm_layer=torch.nn.LayerNorm(d_model), 803 | ) 804 | # self.end_conv1 = nn.Conv1d(in_channels=label_len+out_len, out_channels=out_len, kernel_size=1, bias=True) 805 | # self.end_conv2 = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=1, bias=True) 806 | self.projection = nn.Linear(d_model, c_out, bias=True) 807 | 808 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 809 | """ 810 | Args: 811 | x_enc (Tensor): input time series, size ``[B, L, D]`` 812 | x_mark_enc (Tensor): input time-series features, size ``[B, L, D]`` 813 | x_dec (Tensor): target time series, size ``[B, L, D]`` 814 | x_mark_dec (Tensor): target time-series features, size ``[B, L, D]`` 815 | enc_self_mask (Tensor): encode mask, size ``[B, D]`` 816 | dec_self_mask (Tensor): decode mask, size ``[B, D]`` 817 | dec_enc_mask (Tensor): encode-decode mask, size ``[B, D]`` 818 | 819 | Returns: 820 | pred (Tensor): prediction, size ``[B, L, D]`` 821 | """ 822 | 823 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 824 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 825 | 826 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 827 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) 828 | dec_out = self.projection(dec_out) 829 | 830 | # dec_out = self.end_conv1(dec_out) 831 | # dec_out = self.end_conv2(dec_out.transpose(2,1)).transpose(1,2) 832 | if self.output_attention: 833 | return dec_out[:, -self.pred_len :, :], attns 834 | else: 835 | return dec_out[:, -self.pred_len :, :] # [B, L, D] -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hf_env 2 | hf_env.set_env("202111") 3 | 4 | import os 5 | import time 6 | import argparse 7 | import numpy as np 8 | from pathlib import Path 9 | import torch 10 | from torch.nn.parallel import DistributedDataParallel 11 | 12 | import hfai 13 | import hfai.nccl.distributed as dist 14 | from torch.multiprocessing import Process 15 | hfai.client.bind_hf_except_hook(Process) 16 | 17 | from datasets import get_dataloader 18 | from models import Informer, InformerStack, Autoformer 19 | 20 | 21 | ########################################### 22 | # CONFIG 23 | ########################################### 24 | 25 | parser = argparse.ArgumentParser(description="Train LTSF 
Formers") 26 | parser.add_argument("--ds", type=str, default="ETTh1", help="dataset name") 27 | parser.add_argument("--model", type=str, default="informer", help="former model") 28 | parser.add_argument("--epochs", type=int, default=100, help="training epoch") 29 | parser.add_argument("--bs", type=int, default=64, help="batch size") 30 | parser.add_argument("--n_workers", type=int, default=8, help="num of workers") 31 | parser.add_argument("--lr", type=float, default=0.0001, help="learning rate") 32 | parser.add_argument("--seq_len", type=int, default=96, help="sequence length") 33 | parser.add_argument("--label_len", type=int, default=48, help="label sequence length") 34 | parser.add_argument("--pred_len", type=int, default=48, help="prediction length") 35 | parser.add_argument("--feature", type=str, default='S', help="prediction mode") 36 | args = parser.parse_args() 37 | 38 | # 超参数设置 39 | epochs = args.epochs 40 | batch_size = args.bs 41 | num_workers = args.n_workers 42 | lr = args.lr 43 | data_name = args.ds 44 | seq_len = args.seq_len 45 | label_len = args.label_len 46 | pred_len = args.pred_len 47 | features = args.feature 48 | model_name = args.model 49 | 50 | save_path = Path(f"output/{data_name}/{model_name}") 51 | save_path.mkdir(exist_ok=True, parents=True) 52 | 53 | best_mse = np.inf 54 | 55 | 56 | def process_one_batch(model, standard_scaler, batch_x, batch_y, batch_x_mark, batch_y_mark): 57 | x = batch_x.float().cuda(non_blocking=True) 58 | x_mark = batch_x_mark.float().cuda(non_blocking=True) 59 | y_mark = batch_y_mark.float().cuda(non_blocking=True) 60 | 61 | # decoder input 62 | dec_inp = torch.zeros([batch_y.shape[0], pred_len, batch_y.shape[-1]]).float() 63 | dec_inp = torch.cat([batch_y[:, : label_len, :], dec_inp], dim=1).float().cuda(non_blocking=True) 64 | 65 | # encoder - decoder 66 | outputs = model(x, x_mark, dec_inp, y_mark) 67 | y_pred = standard_scaler.inverse_transform(outputs) 68 | f_dim = -1 if features == "MS" else 0 69 | y_true = batch_y[:, pred_len:, f_dim:].float().cuda(non_blocking=True) 70 | 71 | return y_pred, y_true 72 | 73 | 74 | def train_one_epoch(epoch, start_step, model, train_loader, optimizer, criterion, standard_scaler): 75 | model.train() 76 | for step, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): 77 | if step < start_step: 78 | continue 79 | 80 | optimizer.zero_grad() 81 | y_pred, y_true = process_one_batch(model, standard_scaler, batch_x, batch_y, batch_x_mark, batch_y_mark) 82 | loss = criterion(y_pred, y_true) 83 | loss.backward() 84 | optimizer.step() 85 | 86 | # 收到打断信号,保存模型,设置当前执行的状态信息 87 | rank = dist.get_rank() 88 | if rank == 0 and hfai.receive_suspend_command(): 89 | state = { 90 | "model": model.module.state_dict(), 91 | "optimizer": optimizer.state_dict(), 92 | "epoch": epoch, 93 | "step": step, 94 | "best_mse": best_mse, 95 | } 96 | torch.save(state, save_path / "latest.tar") 97 | time.sleep(5) 98 | hfai.go_suspend() 99 | 100 | 101 | def eval(model, criterion, eval_loader, standard_scaler): 102 | loss, total = torch.zeros(2).cuda() 103 | 104 | model.eval() 105 | with torch.no_grad(): 106 | for _, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(eval_loader): 107 | y_pred, y_true = process_one_batch(model, standard_scaler, batch_x, batch_y, batch_x_mark, batch_y_mark) 108 | loss += criterion(y_pred, y_true) 109 | total += y_true.size(0) 110 | 111 | for x in [loss, total]: 112 | dist.reduce(x, 0) 113 | 114 | loss_val = 0 115 | if dist.get_rank() == 0: 116 | loss_val = loss.item() / 
119 | 
120 | def fit(model, optimizer, criterion, train_loader, val_loader, standard_scaler=None):
121 | 
122 |     global best_mse
123 |     if standard_scaler is None:
124 |         raise RuntimeError("The standard scaler is None.")
125 | 
126 |     rank = dist.get_rank()
127 | 
128 |     # resume from a checkpoint if one exists
129 |     start_epoch, start_step = 0, 0
130 |     if Path(save_path / "latest.tar").exists():
131 |         ckpt = torch.load(save_path / "latest.tar", map_location="cpu")
132 |         model.module.load_state_dict(ckpt["model"])
133 |         optimizer.load_state_dict(ckpt["optimizer"])
134 |         start_epoch = ckpt["epoch"]
135 |         start_step = ckpt["step"]
136 |         best_mse = ckpt["best_mse"]
137 | 
138 |     # train and validate
139 |     for epoch in range(start_epoch, epochs):
140 |         train_loader.sampler.set_epoch(epoch)
141 | 
142 |         train_one_epoch(epoch, start_step if epoch == start_epoch else 0, model, train_loader, optimizer, criterion, standard_scaler)  # skip steps only in the resumed epoch
143 |         train_loss = evaluate(model, criterion, train_loader, standard_scaler)
144 |         val_loss = evaluate(model, criterion, val_loader, standard_scaler)
145 | 
146 |         # log and keep the best model
147 |         if rank == 0:
148 |             print(f"Epoch: {epoch}, train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f}")
149 | 
150 |             if val_loss < best_mse:
151 |                 best_mse = val_loss
152 |                 print(f"New Best MSE: {best_mse:.4f}!")
153 |                 torch.save(model.module.state_dict(), save_path / "best.pt")
154 | 
155 |         torch.cuda.empty_cache()
156 | 
157 | 
158 | def main(local_rank):
159 | 
160 |     # multi-node communication setup
161 |     ip = os.environ.get("MASTER_ADDR", '127.0.0.1')
162 |     port = os.environ.get("MASTER_PORT", '8899')
163 |     hosts = int(os.environ.get("WORLD_SIZE", '1'))  # number of nodes
164 |     rank = int(os.environ.get("RANK", '0'))  # index of the current node
165 |     gpus = torch.cuda.device_count()  # GPUs per node
166 | 
167 |     # world_size is the total number of GPUs; rank below is the global index of this GPU
168 |     dist.init_process_group(
169 |         backend="nccl", init_method=f"tcp://{ip}:{port}", world_size=hosts * gpus, rank=rank * gpus + local_rank
170 |     )
171 |     torch.cuda.set_device(local_rank)
172 | 
173 |     train_loader, standard_scaler, encoder_dim, decoder_dim, output_dim = get_dataloader(data_name, seq_len, label_len, pred_len, features, batch_size, num_workers, mode='train')
174 |     val_loader, _, _, _, _ = get_dataloader(data_name, seq_len, label_len, pred_len, features, batch_size, num_workers, mode='val')
175 | 
176 |     if model_name == 'autoformer':
177 |         model = Autoformer(
178 |             enc_in=encoder_dim,
179 |             dec_in=decoder_dim,
180 |             c_out=output_dim,
181 |             seq_len=seq_len,
182 |             label_len=label_len,
183 |             out_len=pred_len,
184 |             e_layers=2,
185 |             d_layers=1,
186 |             d_ff=2048,
187 |             factor=3,
188 |             dropout=0.05,
189 |             embed="timeF"
190 |         )
191 |     elif model_name == 'informer':
192 |         model = Informer(
193 |             enc_in=encoder_dim,
194 |             dec_in=decoder_dim,
195 |             c_out=output_dim,
196 |             seq_len=seq_len,
197 |             label_len=label_len,
198 |             out_len=pred_len,
199 |             e_layers=2,
200 |             d_layers=1,
201 |             d_ff=2048,
202 |             factor=3,
203 |             dropout=0.05,
204 |             embed="timeF"
205 |         )
206 |     else:
207 |         raise KeyError(f"model '{model_name}' is not implemented")
208 | 
209 |     model = DistributedDataParallel(model.cuda(), device_ids=[local_rank])
210 |     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
211 |     criterion = torch.nn.MSELoss()
212 | 
213 |     fit(model, optimizer, criterion, train_loader, val_loader, standard_scaler)
214 | 
215 | 
216 | if __name__ == "__main__":
217 |     ngpus = torch.cuda.device_count()
218 |     torch.multiprocessing.spawn(main, args=(), nprocs=ngpus)
219 | 
--------------------------------------------------------------------------------
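For clarity, the rank arithmetic in `main` assigns every GPU a unique global rank: with `hosts` nodes of `gpus` GPUs each, process `local_rank` on node `rank` gets global rank `rank * gpus + local_rank`. A quick check of the mapping, using an assumed cluster of 2 nodes with 8 GPUs each:

```python
# Verify that the global-rank formula used in main() enumerates
# [0, world_size) exactly once; node/GPU counts here are illustrative.
hosts, gpus = 2, 8
world_size = hosts * gpus

global_ranks = [rank * gpus + local_rank
                for rank in range(hosts)        # node index (env RANK)
                for local_rank in range(gpus)]  # GPU index on that node

assert sorted(global_ranks) == list(range(world_size))  # unique and contiguous
```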