├── .gitignore
├── Dockerfile
├── README.md
├── data_provider
│   ├── data_factory.py
│   └── data_loader.py
├── environment.yml
├── experiments
│   ├── exp_basic.py
│   └── exp_long_term_forecasting.py
├── figures
│   ├── Efficiency.jpg
│   ├── Framework.png
│   └── Long_term_forecast_results.jpg
├── layers
│   ├── Embed.py
│   ├── SWTAttention_Family.py
│   ├── StandardNorm.py
│   └── Transformer_Encoder.py
├── model
│   └── SimpleTM.py
├── run.py
├── scripts
│   └── multivariate_forecasting
│       ├── ECL
│       │   └── SimpleTM.sh
│       ├── ETT
│       │   ├── SimpleTM_h1.sh
│       │   ├── SimpleTM_h2.sh
│       │   ├── SimpleTM_m1.sh
│       │   └── SimpleTM_m2.sh
│       ├── PEMS
│       │   ├── SimpleTM_03.sh
│       │   ├── SimpleTM_04.sh
│       │   ├── SimpleTM_07.sh
│       │   └── SimpleTM_08.sh
│       ├── SolarEnergy
│       │   └── SimpleTM.sh
│       ├── Traffic
│       │   └── SimpleTM.sh
│       └── Weather
│           └── SimpleTM.sh
└── utils
    ├── masking.py
    ├── metrics.py
    ├── timefeatures.py
    └── tools.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.py[cod]

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM nvcr.io/nvidia/pytorch:22.10-py3

ENV PYTHONPATH=/workspace
ENV PYTHONUNBUFFERED=1

RUN pip install --no-cache-dir \
    einops==0.8.1 \
    matplotlib==3.7.0 \
    numpy==1.23.5 \
    scikit-learn==1.2.2 \
    scipy==1.10.1 \
    pandas==1.5.3 \
    reformer-pytorch==1.4.4 \
    PyWavelets

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SimpleTM

The repo is the official implementation for the paper: [[ICLR '25] SimpleTM: A Simple Baseline for Multivariate Time Series Forecasting](https://openreview.net/pdf?id=oANkBaVci5).

# Introduction

We propose SimpleTM, a simple yet effective architecture that uniquely integrates classical signal processing ideas with a slightly modified attention mechanism.

![SimpleTM framework](figures/Framework.png)
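The core signal-processing ingredient is a stationary wavelet decomposition. Below is a minimal, self-contained sketch of what `WaveletEmbedding` in [layers/SWTAttention_Family.py](./layers/SWTAttention_Family.py) computes; the `db2` wavelet, the two decomposition levels, and the toy input shape are illustrative choices, not the tuned settings:

```python
import torch
import torch.nn.functional as F
import pywt

# Each level convolves the running approximation with dilated low/high-pass
# filters under circular padding, following the recipe in WaveletEmbedding.
wavelet = pywt.Wavelet('db2')
h0 = torch.tensor(wavelet.dec_lo[::-1], dtype=torch.float32)[None, None, :]
h1 = torch.tensor(wavelet.dec_hi[::-1], dtype=torch.float32)[None, None, :]

x = torch.randn(1, 1, 96)                 # (batch, channel, time) toy series
approx, details, dilation = x, [], 1
for _ in range(2):                        # m = 2 decomposition levels
    k = h0.shape[-1]
    padding = dilation * (k - 1)
    padding_r = (k * dilation) // 2
    xp = F.pad(approx, (padding - padding_r, padding_r), "circular")
    details.append(F.conv1d(xp, h1, dilation=dilation))   # detail coefficients
    approx = F.conv1d(xp, h0, dilation=dilation)          # next approximation
    dilation *= 2
coeffs = torch.stack([approx] + details[::-1], dim=-2)    # (1, 1, m + 1, 96)
```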

We show that even a single-layer configuration can effectively capture intricate dependencies in multivariate time-series data, while maintaining minimal model complexity and parameter requirements. This streamlined construction achieves a performance profile surpassing (or on par with) most existing baselines across nearly all publicly available benchmarks.

Table 6: Complete results of the long-term forecasting task, with an input length of 96 for all tasks. The reported metrics include the averaged Mean Squared Error (MSE) and Mean Absolute Error (MAE) across four prediction horizons, where lower values indicate better model performance. Each cell reports MSE / MAE.

| Dataset | Horizon | SimpleTM (Ours) | TimeMixer (2024) | iTransformer (2024) | CrossGNN (2024) | RLinear (2023) | PatchTST (2023) | Crossformer (2023) | TiDE (2023) | TimesNet (2023) | DLinear (2023) | SCINet (2022) | FEDformer (2022) | Stationary (2022) | Autoformer (2021) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ETTm1 | 96 | 0.321/0.361 | 0.328/0.363 | 0.334/0.368 | 0.335/0.373 | 0.355/0.376 | 0.329/0.367 | 0.404/0.426 | 0.364/0.387 | 0.338/0.375 | 0.345/0.372 | 0.418/0.438 | 0.379/0.419 | 0.386/0.398 | 0.505/0.475 |
| ETTm1 | 192 | 0.360/0.380 | 0.364/0.384 | 0.377/0.391 | 0.372/0.390 | 0.391/0.392 | 0.367/0.385 | 0.450/0.451 | 0.398/0.404 | 0.374/0.387 | 0.380/0.389 | 0.439/0.450 | 0.426/0.441 | 0.459/0.444 | 0.553/0.496 |
| ETTm1 | 336 | 0.390/0.404 | 0.390/0.404 | 0.426/0.420 | 0.403/0.411 | 0.424/0.415 | 0.399/0.410 | 0.532/0.515 | 0.428/0.425 | 0.410/0.411 | 0.413/0.413 | 0.490/0.485 | 0.445/0.459 | 0.495/0.464 | 0.621/0.537 |
| ETTm1 | 720 | 0.454/0.438 | 0.458/0.445 | 0.491/0.459 | 0.461/0.442 | 0.487/0.450 | 0.454/0.439 | 0.666/0.589 | 0.487/0.461 | 0.478/0.450 | 0.474/0.453 | 0.595/0.550 | 0.543/0.490 | 0.585/0.516 | 0.671/0.561 |
| ETTm1 | Avg | 0.381/0.396 | 0.385/0.399 | 0.407/0.410 | 0.393/0.404 | 0.414/0.407 | 0.387/0.400 | 0.513/0.496 | 0.419/0.419 | 0.400/0.406 | 0.403/0.407 | 0.485/0.481 | 0.448/0.452 | 0.481/0.456 | 0.588/0.517 |
| ETTm2 | 96 | 0.173/0.257 | 0.176/0.259 | 0.180/0.264 | 0.176/0.266 | 0.182/0.265 | 0.175/0.259 | 0.287/0.366 | 0.207/0.305 | 0.187/0.267 | 0.193/0.292 | 0.286/0.377 | 0.203/0.287 | 0.192/0.274 | 0.255/0.339 |
| ETTm2 | 192 | 0.238/0.299 | 0.242/0.303 | 0.250/0.309 | 0.240/0.307 | 0.246/0.304 | 0.241/0.302 | 0.414/0.492 | 0.290/0.364 | 0.249/0.309 | 0.284/0.362 | 0.399/0.445 | 0.269/0.328 | 0.280/0.339 | 0.281/0.340 |
| ETTm2 | 336 | 0.296/0.338 | 0.304/0.342 | 0.311/0.348 | 0.304/0.345 | 0.307/0.342 | 0.305/0.343 | 0.597/0.542 | 0.377/0.422 | 0.321/0.351 | 0.369/0.427 | 0.637/0.591 | 0.325/0.366 | 0.334/0.361 | 0.339/0.372 |
| ETTm2 | 720 | 0.393/0.395 | 0.393/0.397 | 0.412/0.407 | 0.406/0.400 | 0.407/0.398 | 0.402/0.400 | 1.730/1.042 | 0.558/0.524 | 0.408/0.403 | 0.554/0.522 | 0.960/0.735 | 0.421/0.415 | 0.417/0.413 | 0.433/0.432 |
| ETTm2 | Avg | 0.275/0.322 | 0.278/0.325 | 0.288/0.332 | 0.282/0.330 | 0.286/0.327 | 0.281/0.326 | 0.757/0.610 | 0.358/0.404 | 0.291/0.333 | 0.350/0.401 | 0.571/0.537 | 0.305/0.349 | 0.306/0.347 | 0.327/0.371 |
| ETTh1 | 96 | 0.366/0.392 | 0.381/0.401 | 0.386/0.405 | 0.382/0.398 | 0.386/0.395 | 0.414/0.419 | 0.423/0.448 | 0.479/0.464 | 0.384/0.402 | 0.386/0.400 | 0.654/0.599 | 0.376/0.419 | 0.513/0.491 | 0.449/0.459 |
| ETTh1 | 192 | 0.422/0.421 | 0.440/0.433 | 0.441/0.436 | 0.427/0.425 | 0.437/0.424 | 0.460/0.445 | 0.471/0.474 | 0.525/0.492 | 0.436/0.429 | 0.437/0.432 | 0.719/0.631 | 0.420/0.448 | 0.534/0.504 | 0.500/0.482 |
| ETTh1 | 336 | 0.440/0.438 | 0.501/0.462 | 0.487/0.458 | 0.465/0.445 | 0.479/0.446 | 0.501/0.466 | 0.570/0.546 | 0.565/0.515 | 0.491/0.469 | 0.481/0.459 | 0.778/0.659 | 0.459/0.465 | 0.588/0.535 | 0.521/0.496 |
| ETTh1 | 720 | 0.463/0.462 | 0.501/0.482 | 0.503/0.491 | 0.472/0.468 | 0.481/0.470 | 0.500/0.488 | 0.653/0.621 | 0.594/0.558 | 0.521/0.500 | 0.519/0.516 | 0.836/0.699 | 0.506/0.507 | 0.643/0.616 | 0.514/0.512 |
| ETTh1 | Avg | 0.422/0.428 | 0.458/0.445 | 0.454/0.447 | 0.437/0.434 | 0.446/0.434 | 0.469/0.454 | 0.529/0.522 | 0.541/0.507 | 0.458/0.450 | 0.456/0.452 | 0.747/0.647 | 0.440/0.460 | 0.570/0.537 | 0.496/0.487 |
| ETTh2 | 96 | 0.281/0.338 | 0.292/0.343 | 0.297/0.349 | 0.309/0.359 | 0.288/0.338 | 0.302/0.348 | 0.745/0.584 | 0.400/0.440 | 0.340/0.374 | 0.333/0.387 | 0.707/0.621 | 0.358/0.397 | 0.476/0.458 | 0.346/0.388 |
| ETTh2 | 192 | 0.355/0.387 | 0.374/0.395 | 0.380/0.400 | 0.390/0.406 | 0.374/0.390 | 0.388/0.400 | 0.877/0.656 | 0.528/0.509 | 0.402/0.414 | 0.477/0.476 | 0.860/0.689 | 0.429/0.439 | 0.512/0.493 | 0.456/0.452 |
| ETTh2 | 336 | 0.365/0.401 | 0.428/0.433 | 0.428/0.432 | 0.426/0.444 | 0.415/0.426 | 0.426/0.433 | 1.043/0.731 | 0.643/0.571 | 0.452/0.452 | 0.594/0.541 | 1.000/0.744 | 0.496/0.487 | 0.552/0.551 | 0.482/0.486 |
| ETTh2 | 720 | 0.413/0.436 | 0.454/0.458 | 0.427/0.445 | 0.445/0.444 | 0.420/0.440 | 0.431/0.446 | 1.104/0.763 | 0.874/0.679 | 0.462/0.468 | 0.831/0.657 | 1.249/0.838 | 0.463/0.474 | 0.562/0.560 | 0.515/0.511 |
| ETTh2 | Avg | 0.353/0.391 | 0.384/0.407 | 0.383/0.407 | 0.393/0.413 | 0.374/0.398 | 0.387/0.407 | 0.942/0.684 | 0.611/0.550 | 0.414/0.427 | 0.559/0.515 | 0.954/0.723 | 0.437/0.449 | 0.526/0.516 | 0.450/0.459 |
| ECL | 96 | 0.141/0.235 | 0.153/0.244 | 0.148/0.240 | 0.173/0.275 | 0.201/0.281 | 0.181/0.270 | 0.219/0.314 | 0.237/0.329 | 0.168/0.272 | 0.197/0.282 | 0.247/0.345 | 0.193/0.308 | 0.169/0.273 | 0.201/0.317 |
| ECL | 192 | 0.151/0.247 | 0.166/0.256 | 0.162/0.253 | 0.195/0.288 | 0.201/0.283 | 0.188/0.274 | 0.231/0.322 | 0.236/0.330 | 0.184/0.289 | 0.196/0.285 | 0.257/0.355 | 0.201/0.315 | 0.182/0.286 | 0.222/0.334 |
| ECL | 336 | 0.173/0.267 | 0.184/0.275 | 0.178/0.269 | 0.206/0.300 | 0.215/0.298 | 0.204/0.293 | 0.246/0.337 | 0.249/0.344 | 0.198/0.300 | 0.209/0.301 | 0.269/0.369 | 0.214/0.329 | 0.200/0.304 | 0.231/0.338 |
| ECL | 720 | 0.201/0.293 | 0.226/0.313 | 0.225/0.317 | 0.231/0.335 | 0.257/0.331 | 0.246/0.324 | 0.280/0.363 | 0.284/0.373 | 0.220/0.320 | 0.245/0.333 | 0.299/0.390 | 0.246/0.355 | 0.222/0.321 | 0.254/0.361 |
| ECL | Avg | 0.166/0.260 | 0.182/0.272 | 0.178/0.270 | 0.201/0.300 | 0.219/0.298 | 0.205/0.290 | 0.244/0.334 | 0.251/0.344 | 0.192/0.295 | 0.212/0.300 | 0.268/0.365 | 0.214/0.327 | 0.193/0.296 | 0.227/0.338 |
| Weather | 96 | 0.162/0.207 | 0.165/0.212 | 0.174/0.214 | 0.159/0.218 | 0.192/0.232 | 0.177/0.218 | 0.158/0.230 | 0.202/0.261 | 0.172/0.220 | 0.196/0.255 | 0.221/0.306 | 0.217/0.296 | 0.173/0.223 | 0.266/0.336 |
| Weather | 192 | 0.208/0.248 | 0.209/0.253 | 0.221/0.254 | 0.211/0.266 | 0.240/0.271 | 0.225/0.259 | 0.206/0.277 | 0.242/0.298 | 0.219/0.261 | 0.237/0.296 | 0.261/0.340 | 0.276/0.336 | 0.245/0.285 | 0.307/0.367 |
| Weather | 336 | 0.263/0.290 | 0.264/0.293 | 0.278/0.296 | 0.267/0.310 | 0.292/0.307 | 0.278/0.297 | 0.272/0.335 | 0.287/0.335 | 0.280/0.306 | 0.283/0.335 | 0.309/0.378 | 0.339/0.380 | 0.321/0.338 | 0.359/0.395 |
| Weather | 720 | 0.340/0.341 | 0.342/0.345 | 0.358/0.347 | 0.352/0.362 | 0.364/0.353 | 0.354/0.348 | 0.398/0.418 | 0.351/0.386 | 0.365/0.359 | 0.345/0.381 | 0.377/0.427 | 0.403/0.428 | 0.414/0.410 | 0.419/0.428 |
| Weather | Avg | 0.243/0.271 | 0.245/0.276 | 0.258/0.278 | 0.247/0.289 | 0.272/0.291 | 0.259/0.281 | 0.259/0.315 | 0.271/0.320 | 0.259/0.287 | 0.265/0.317 | 0.292/0.363 | 0.309/0.360 | 0.288/0.314 | 0.338/0.382 |
| Traffic | 96 | 0.410/0.274 | 0.464/0.289 | 0.395/0.268 | 0.570/0.310 | 0.649/0.389 | 0.462/0.295 | 0.522/0.290 | 0.805/0.493 | 0.593/0.321 | 0.650/0.396 | 0.788/0.499 | 0.587/0.366 | 0.612/0.338 | 0.613/0.388 |
| Traffic | 192 | 0.430/0.280 | 0.477/0.292 | 0.417/0.276 | 0.577/0.321 | 0.601/0.366 | 0.466/0.296 | 0.530/0.293 | 0.756/0.474 | 0.617/0.336 | 0.598/0.370 | 0.789/0.505 | 0.604/0.373 | 0.613/0.340 | 0.616/0.382 |
| Traffic | 336 | 0.449/0.290 | 0.500/0.305 | 0.433/0.283 | 0.588/0.324 | 0.609/0.369 | 0.482/0.304 | 0.558/0.305 | 0.762/0.477 | 0.629/0.336 | 0.605/0.373 | 0.797/0.508 | 0.621/0.383 | 0.618/0.328 | 0.622/0.337 |
| Traffic | 720 | 0.486/0.309 | 0.548/0.313 | 0.467/0.302 | 0.597/0.337 | 0.647/0.387 | 0.514/0.322 | 0.589/0.328 | 0.719/0.449 | 0.640/0.350 | 0.645/0.394 | 0.841/0.523 | 0.626/0.382 | 0.653/0.355 | 0.660/0.408 |
| Traffic | Avg | 0.444/0.289 | 0.497/0.300 | 0.428/0.282 | 0.583/0.323 | 0.626/0.378 | 0.481/0.304 | 0.550/0.304 | 0.760/0.473 | 0.620/0.336 | 0.625/0.383 | 0.804/0.509 | 0.610/0.376 | 0.624/0.340 | 0.628/0.379 |
| Solar-Energy | 96 | 0.163/0.232 | 0.215/0.294 | 0.203/0.237 | 0.222/0.301 | 0.322/0.339 | 0.234/0.286 | 0.310/0.331 | 0.312/0.399 | 0.250/0.292 | 0.290/0.378 | 0.237/0.344 | 0.242/0.342 | 0.215/0.249 | 0.884/0.711 |
| Solar-Energy | 192 | 0.182/0.247 | 0.237/0.275 | 0.233/0.261 | 0.246/0.307 | 0.359/0.356 | 0.267/0.310 | 0.734/0.725 | 0.339/0.416 | 0.296/0.318 | 0.320/0.398 | 0.280/0.380 | 0.285/0.380 | 0.254/0.272 | 0.834/0.692 |
| Solar-Energy | 336 | 0.193/0.257 | 0.252/0.298 | 0.248/0.273 | 0.263/0.324 | 0.397/0.369 | 0.290/0.315 | 0.750/0.735 | 0.368/0.430 | 0.319/0.330 | 0.353/0.415 | 0.304/0.389 | 0.282/0.376 | 0.290/0.296 | 0.941/0.723 |
| Solar-Energy | 720 | 0.199/0.252 | 0.244/0.293 | 0.249/0.275 | 0.265/0.318 | 0.397/0.356 | 0.289/0.317 | 0.769/0.765 | 0.370/0.425 | 0.338/0.337 | 0.356/0.413 | 0.308/0.388 | 0.357/0.427 | 0.285/0.295 | 0.882/0.717 |
| Solar-Energy | Avg | 0.184/0.247 | 0.237/0.290 | 0.233/0.262 | 0.249/0.313 | 0.369/0.356 | 0.270/0.307 | 0.641/0.639 | 0.347/0.417 | 0.301/0.319 | 0.330/0.401 | 0.282/0.375 | 0.291/0.381 | 0.261/0.381 | 0.885/0.711 |
# Get Started

## 1. Download the Data

All datasets have been preprocessed and are ready for use. You can obtain them from their original sources:

- **ETT**: [https://github.com/zhouhaoyi/ETDataset/tree/main](https://github.com/zhouhaoyi/ETDataset/tree/main)
- **Traffic, Electricity, Weather**: [https://github.com/thuml/Autoformer](https://github.com/thuml/Autoformer?tab=readme-ov-file)
- **Solar**: [https://github.com/laiguokun/LSTNet](https://github.com/laiguokun/LSTNet)
- **PEMS**: [https://github.com/cure-lab/SCINet](https://github.com/cure-lab/SCINet?tab=readme-ov-file)

For convenience, we provide a comprehensive package containing all required datasets, available for download from [Google Drive](https://drive.google.com/file/d/1hTpUrhe1yEIGa9mCiGxM5rDyzlYKAnyx/view?usp=sharing). Place it under the [./dataset](./dataset/) folder.

## 2. Set Up Your Environment

Choose one of the following methods to set up your environment:

### Option A: Anaconda

Create and activate a Python environment using the provided configuration file [environment.yml](./environment.yml):

```bash
conda env create -f environment.yml -n SimpleTM
conda activate SimpleTM
```

### Option B: Docker

If you prefer Docker, build an image using the provided [Dockerfile](./Dockerfile):

```bash
docker build --tag simpletm:latest .
```

## 3. Train the Model

Experiment scripts for the benchmarks are provided in the [`scripts`](./scripts) directory. You can reproduce the experiment results as follows:

```bash
bash ./scripts/multivariate_forecasting/ETT/SimpleTM_h1.sh        # ETTh1
bash ./scripts/multivariate_forecasting/ECL/SimpleTM.sh           # Electricity
bash ./scripts/multivariate_forecasting/SolarEnergy/SimpleTM.sh   # Solar-Energy
bash ./scripts/multivariate_forecasting/Weather/SimpleTM.sh       # Weather
bash ./scripts/multivariate_forecasting/PEMS/SimpleTM_03.sh       # PEMS03
```

### Docker Users

If you're using Docker, run the scripts with the following command structure (example for ETTh1):

```bash
docker run --gpus all -it --rm --ipc=host \
    --user $(id -u):$(id -g) \
    -v "$(pwd)":/scratch --workdir /scratch -e HOME=/scratch \
    simpletm:latest \
    bash scripts/multivariate_forecasting/ETT/SimpleTM_h1.sh
```

# Model Efficiency

To provide an efficiency comparison, we evaluated our model against two of the most competitive baselines: the Transformer-based iTransformer and the linear-based TimeMixer. Our experimental setup used a consistent batch size of 256 across all models and measured four key metrics: total trainable parameters, inference time, GPU memory footprint, and peak memory usage during the backward pass. Results for all baseline models were compiled using PyTorch.

Note that our default experimental configuration does not employ compilation optimizations. To speed up training, enable the `--compile` flag in the scripts.

![Model efficiency comparison](figures/Efficiency.jpg)
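The numbers in Table 13 below can be approximated with standard PyTorch utilities. A minimal sketch follows; the parameter count mirrors [experiments/exp_basic.py](./experiments/exp_basic.py), while the timing loop, warm-up count, and the single-tensor `model(batch)` call are simplifying assumptions:

```python
import time
import numpy as np
import torch

def profile(model, batch, device='cuda', warmup=3, iters=10):
    model = model.to(device).eval()
    batch = batch.to(device)

    # Total trainable parameters (same computation as experiments/exp_basic.py).
    params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)

    torch.cuda.reset_peak_memory_stats(device)
    with torch.no_grad():
        for _ in range(warmup):           # warm-up passes excluded from timing
            model(batch)
        torch.cuda.synchronize(device)
        start = time.time()
        for _ in range(iters):
            model(batch)
        torch.cuda.synchronize(device)    # wait for queued kernels before stopping the clock
    return {
        'params': int(params),
        'inference_time_s': (time.time() - start) / iters,
        'peak_mem_MB': torch.cuda.max_memory_allocated(device) / (1024 ** 2),
    }
```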
Table 13: Comparison of model performance and resource utilization across different datasets. Metrics include Mean Squared Error (MSE), total parameter count, inference time (seconds), GPU memory footprint (MB), and peak memory usage (MB).

| Dataset | Model | MSE | Total Params | Inference Time (s) | GPU Mem Footprint (MB) | Peak Mem (MB) |
|---|---|---|---|---|---|---|
| Weather | SimpleTM | 0.162 | 13,472 | 0.0132 | 994 | 181.75 |
| Weather | TimeMixer | 0.164 | 104,433 | 0.0453 | 2,954 | 2,281.38 |
| Weather | iTransformer | 0.176 | 4,833,888 | 0.0222 | 1,596 | 847.62 |
| Solar | SimpleTM | 0.163 | 166,304 | 0.0455 | 2,048 | 1,181.56 |
| Solar | TimeMixer | 0.215 | 13,009,079 | 0.2644 | 7,576 | 6,632.40 |
| Solar | iTransformer | 0.203 | 3,255,904 | 0.0663 | 4,022 | 2,776.50 |
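The peak-memory column is measured around the backward pass. A sketch of that measurement, mirroring the commented-out bookkeeping inside the training loop of [experiments/exp_long_term_forecasting.py](./experiments/exp_long_term_forecasting.py); the single-input `model(batch_x)` call is a simplification, since the real model also takes time marks and a decoder input:

```python
import torch

def backward_peak_mb(model, criterion, batch_x, batch_y):
    """Peak GPU memory (MB) over one forward/backward step."""
    torch.cuda.reset_peak_memory_stats()        # start a fresh peak counter
    loss = criterion(model(batch_x), batch_y)   # simplified forward pass
    loss.backward()                             # activations + grads drive the peak
    return torch.cuda.max_memory_allocated() / (1024 ** 2)
```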
1937 | 1938 | 1939 | # Acknowledgement 1940 | 1941 | We appreciate the following GitHub repos a lot for their valuable code and efforts. 1942 | - Time-Series-Library (https://github.com/thuml/Time-Series-Library) 1943 | - iTransformer (https://github.com/thuml/iTransformer) 1944 | - TimeMixer (https://github.com/kwuking/TimeMixer) 1945 | - Autoformer (https://github.com/thuml/Autoformer) 1946 | 1947 | 1948 | # Citation 1949 | If you find this repo helpful, please cite our paper. 1950 | 1951 | ```bibtex 1952 | @inproceedings{ 1953 | chen2025simpletm, 1954 | title={Simple{TM}: A Simple Baseline for Multivariate Time Series Forecasting}, 1955 | author={Hui Chen and Viet Luong and Lopamudra Mukherjee and Vikas Singh}, 1956 | booktitle={The Thirteenth International Conference on Learning Representations}, 1957 | year={2025}, 1958 | url={https://openreview.net/forum?id=oANkBaVci5} 1959 | } 1960 | ``` -------------------------------------------------------------------------------- /data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Solar, Dataset_PEMS, \ 2 | Dataset_Pred 3 | from torch.utils.data import DataLoader 4 | 5 | data_dict = { 6 | 'ETTh1': Dataset_ETT_hour, 7 | 'ETTh2': Dataset_ETT_hour, 8 | 'ETTm1': Dataset_ETT_minute, 9 | 'ETTm2': Dataset_ETT_minute, 10 | 'Solar': Dataset_Solar, 11 | 'PEMS': Dataset_PEMS, 12 | 'custom': Dataset_Custom, 13 | } 14 | 15 | 16 | def data_provider(args, flag): 17 | Data = data_dict[args.data] 18 | timeenc = 0 if args.embed != 'timeF' else 1 19 | 20 | if flag == 'test': 21 | shuffle_flag = False 22 | drop_last = True 23 | batch_size = args.batch_size 24 | freq = args.freq 25 | elif flag == 'pred': 26 | shuffle_flag = False 27 | drop_last = False 28 | batch_size = 1 29 | freq = args.freq 30 | Data = Dataset_Pred 31 | else: 32 | shuffle_flag = True 33 | drop_last = True 34 | batch_size = args.batch_size 35 | freq = args.freq 36 | 37 | data_set = Data( 38 | root_path=args.root_path, 39 | data_path=args.data_path, 40 | flag=flag, 41 | size=[args.seq_len, args.label_len, args.pred_len], 42 | features=args.features, 43 | target=args.target, 44 | timeenc=timeenc, 45 | freq=freq, 46 | ) 47 | print(flag, len(data_set)) 48 | data_loader = DataLoader( 49 | data_set, 50 | batch_size=batch_size, 51 | shuffle=shuffle_flag, 52 | num_workers=args.num_workers, 53 | drop_last=drop_last) 54 | return data_set, data_loader 55 | -------------------------------------------------------------------------------- /data_provider/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | from sklearn.preprocessing import StandardScaler 7 | from utils.timefeatures import time_features 8 | import warnings 9 | 10 | warnings.filterwarnings('ignore') 11 | 12 | class Dataset_ETT_hour(Dataset): 13 | def __init__(self, root_path, flag='train', size=None, 14 | features='S', data_path='ETTh1.csv', 15 | target='OT', scale=True, timeenc=0, freq='h'): 16 | if size == None: 17 | self.seq_len = 24 * 4 * 4 18 | self.label_len = 24 * 4 19 | self.pred_len = 24 * 4 20 | else: 21 | self.seq_len = size[0] 22 | self.label_len = size[1] 23 | self.pred_len = size[2] 24 | assert flag in ['train', 'test', 'val'] 25 | type_map = {'train': 0, 'val': 1, 'test': 2} 26 | self.set_type = 
type_map[flag] 27 | 28 | self.features = features 29 | self.target = target 30 | self.scale = scale 31 | self.timeenc = timeenc 32 | self.freq = freq 33 | 34 | self.root_path = root_path 35 | self.data_path = data_path 36 | self.__read_data__() 37 | 38 | def __read_data__(self): 39 | self.scaler = StandardScaler() 40 | df_raw = pd.read_csv(os.path.join(self.root_path, 41 | self.data_path)) 42 | 43 | border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len] 44 | border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24] 45 | border1 = border1s[self.set_type] 46 | border2 = border2s[self.set_type] 47 | 48 | if self.features == 'M' or self.features == 'MS': 49 | cols_data = df_raw.columns[1:] 50 | df_data = df_raw[cols_data] 51 | elif self.features == 'S': 52 | df_data = df_raw[[self.target]] 53 | 54 | if self.scale: 55 | train_data = df_data[border1s[0]:border2s[0]] 56 | self.scaler.fit(train_data.values) 57 | data = self.scaler.transform(df_data.values) 58 | else: 59 | data = df_data.values 60 | 61 | df_stamp = df_raw[['date']][border1:border2] 62 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 63 | if self.timeenc == 0: 64 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 65 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 66 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 67 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 68 | data_stamp = df_stamp.drop(['date'], 1).values 69 | elif self.timeenc == 1: 70 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 71 | data_stamp = data_stamp.transpose(1, 0) 72 | 73 | self.data_x = data[border1:border2] 74 | self.data_y = data[border1:border2] 75 | self.data_stamp = data_stamp 76 | 77 | def __getitem__(self, index): 78 | s_begin = index 79 | s_end = s_begin + self.seq_len 80 | r_begin = s_end - self.label_len 81 | r_end = r_begin + self.label_len + self.pred_len 82 | 83 | seq_x = self.data_x[s_begin:s_end] 84 | seq_y = self.data_y[r_begin:r_end] 85 | seq_x_mark = self.data_stamp[s_begin:s_end] 86 | seq_y_mark = self.data_stamp[r_begin:r_end] 87 | 88 | return seq_x, seq_y, seq_x_mark, seq_y_mark 89 | 90 | def __len__(self): 91 | return len(self.data_x) - self.seq_len - self.pred_len + 1 92 | 93 | def inverse_transform(self, data): 94 | return self.scaler.inverse_transform(data) 95 | 96 | 97 | class Dataset_ETT_minute(Dataset): 98 | def __init__(self, root_path, flag='train', size=None, 99 | features='S', data_path='ETTm1.csv', 100 | target='OT', scale=True, timeenc=0, freq='t'): 101 | if size == None: 102 | self.seq_len = 24 * 4 * 4 103 | self.label_len = 24 * 4 104 | self.pred_len = 24 * 4 105 | else: 106 | self.seq_len = size[0] 107 | self.label_len = size[1] 108 | self.pred_len = size[2] 109 | assert flag in ['train', 'test', 'val'] 110 | type_map = {'train': 0, 'val': 1, 'test': 2} 111 | self.set_type = type_map[flag] 112 | 113 | self.features = features 114 | self.target = target 115 | self.scale = scale 116 | self.timeenc = timeenc 117 | self.freq = freq 118 | 119 | self.root_path = root_path 120 | self.data_path = data_path 121 | self.__read_data__() 122 | 123 | def __read_data__(self): 124 | self.scaler = StandardScaler() 125 | df_raw = pd.read_csv(os.path.join(self.root_path, 126 | self.data_path)) 127 | 128 | border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len] 129 | border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 
4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4] 130 | border1 = border1s[self.set_type] 131 | border2 = border2s[self.set_type] 132 | 133 | if self.features == 'M' or self.features == 'MS': 134 | cols_data = df_raw.columns[1:] 135 | df_data = df_raw[cols_data] 136 | elif self.features == 'S': 137 | df_data = df_raw[[self.target]] 138 | 139 | if self.scale: 140 | train_data = df_data[border1s[0]:border2s[0]] 141 | self.scaler.fit(train_data.values) 142 | data = self.scaler.transform(df_data.values) 143 | else: 144 | data = df_data.values 145 | 146 | df_stamp = df_raw[['date']][border1:border2] 147 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 148 | if self.timeenc == 0: 149 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 150 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 151 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 152 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 153 | df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1) 154 | df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15) 155 | data_stamp = df_stamp.drop(['date'], 1).values 156 | elif self.timeenc == 1: 157 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 158 | data_stamp = data_stamp.transpose(1, 0) 159 | 160 | self.data_x = data[border1:border2] 161 | self.data_y = data[border1:border2] 162 | self.data_stamp = data_stamp 163 | 164 | def __getitem__(self, index): 165 | s_begin = index 166 | s_end = s_begin + self.seq_len 167 | r_begin = s_end - self.label_len 168 | r_end = r_begin + self.label_len + self.pred_len 169 | 170 | seq_x = self.data_x[s_begin:s_end] 171 | seq_y = self.data_y[r_begin:r_end] 172 | seq_x_mark = self.data_stamp[s_begin:s_end] 173 | seq_y_mark = self.data_stamp[r_begin:r_end] 174 | 175 | return seq_x, seq_y, seq_x_mark, seq_y_mark 176 | 177 | def __len__(self): 178 | return len(self.data_x) - self.seq_len - self.pred_len + 1 179 | 180 | def inverse_transform(self, data): 181 | return self.scaler.inverse_transform(data) 182 | 183 | 184 | class Dataset_Custom(Dataset): 185 | def __init__(self, root_path, flag='train', size=None, 186 | features='S', data_path='ETTh1.csv', 187 | target='OT', scale=True, timeenc=0, freq='h'): 188 | if size == None: 189 | self.seq_len = 24 * 4 * 4 190 | self.label_len = 24 * 4 191 | self.pred_len = 24 * 4 192 | else: 193 | self.seq_len = size[0] 194 | self.label_len = size[1] 195 | self.pred_len = size[2] 196 | assert flag in ['train', 'test', 'val'] 197 | type_map = {'train': 0, 'val': 1, 'test': 2} 198 | self.set_type = type_map[flag] 199 | 200 | self.features = features 201 | self.target = target 202 | self.scale = scale 203 | self.timeenc = timeenc 204 | self.freq = freq 205 | 206 | self.root_path = root_path 207 | self.data_path = data_path 208 | self.__read_data__() 209 | 210 | def __read_data__(self): 211 | self.scaler = StandardScaler() 212 | df_raw = pd.read_csv(os.path.join(self.root_path, 213 | self.data_path)) 214 | cols = list(df_raw.columns) 215 | cols.remove(self.target) 216 | cols.remove('date') 217 | df_raw = df_raw[['date'] + cols + [self.target]] 218 | num_train = int(len(df_raw) * 0.7) 219 | num_test = int(len(df_raw) * 0.2) 220 | num_vali = len(df_raw) - num_train - num_test 221 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] 222 | border2s = [num_train, num_train + num_vali, len(df_raw)] 223 | border1 = border1s[self.set_type] 224 | border2 = border2s[self.set_type] 225 | 226 | if 
self.features == 'M' or self.features == 'MS': 227 | cols_data = df_raw.columns[1:] 228 | df_data = df_raw[cols_data] 229 | elif self.features == 'S': 230 | df_data = df_raw[[self.target]] 231 | 232 | if self.scale: 233 | train_data = df_data[border1s[0]:border2s[0]] 234 | self.scaler.fit(train_data.values) 235 | data = self.scaler.transform(df_data.values) 236 | else: 237 | data = df_data.values 238 | 239 | df_stamp = df_raw[['date']][border1:border2] 240 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 241 | if self.timeenc == 0: 242 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 243 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 244 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 245 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 246 | data_stamp = df_stamp.drop(['date'], 1).values 247 | elif self.timeenc == 1: 248 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 249 | data_stamp = data_stamp.transpose(1, 0) 250 | 251 | self.data_x = data[border1:border2] 252 | self.data_y = data[border1:border2] 253 | self.data_stamp = data_stamp 254 | 255 | def __getitem__(self, index): 256 | s_begin = index 257 | s_end = s_begin + self.seq_len 258 | r_begin = s_end - self.label_len 259 | r_end = r_begin + self.label_len + self.pred_len 260 | 261 | seq_x = self.data_x[s_begin:s_end] 262 | seq_y = self.data_y[r_begin:r_end] 263 | seq_x_mark = self.data_stamp[s_begin:s_end] 264 | seq_y_mark = self.data_stamp[r_begin:r_end] 265 | 266 | return seq_x, seq_y, seq_x_mark, seq_y_mark 267 | 268 | def __len__(self): 269 | return len(self.data_x) - self.seq_len - self.pred_len + 1 270 | 271 | def inverse_transform(self, data): 272 | return self.scaler.inverse_transform(data) 273 | 274 | 275 | class Dataset_PEMS(Dataset): 276 | def __init__(self, root_path, flag='train', size=None, 277 | features='S', data_path='ETTh1.csv', 278 | target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None): 279 | self.seq_len = size[0] 280 | self.label_len = size[1] 281 | self.pred_len = size[2] 282 | assert flag in ['train', 'test', 'val'] 283 | type_map = {'train': 0, 'val': 1, 'test': 2} 284 | self.set_type = type_map[flag] 285 | 286 | self.features = features 287 | self.target = target 288 | self.scale = scale 289 | self.timeenc = timeenc 290 | self.freq = freq 291 | 292 | self.root_path = root_path 293 | self.data_path = data_path 294 | self.__read_data__() 295 | 296 | def __read_data__(self): 297 | self.scaler = StandardScaler() 298 | data_file = os.path.join(self.root_path, self.data_path) 299 | print('data file:', data_file) 300 | data = np.load(data_file, allow_pickle=True) 301 | data = data['data'][:, :, 0] 302 | 303 | train_ratio = 0.6 304 | valid_ratio = 0.2 305 | train_data = data[:int(train_ratio * len(data))] 306 | valid_data = data[int(train_ratio * len(data)):int((train_ratio + valid_ratio) * len(data))] 307 | test_data = data[int((train_ratio + valid_ratio) * len(data)):] 308 | total_data = [train_data, valid_data, test_data] 309 | data = total_data[self.set_type] 310 | 311 | if self.scale: 312 | self.scaler.fit(data) 313 | data = self.scaler.transform(data) 314 | 315 | df = pd.DataFrame(data) 316 | df = df.fillna(method='ffill', limit=len(df)).fillna(method='bfill', limit=len(df)).values 317 | 318 | self.data_x = df 319 | self.data_y = df 320 | 321 | def __getitem__(self, index): 322 | if self.set_type == 2: 323 | s_begin = index * 12 324 | else: 325 | s_begin = index 326 | s_end = 
s_begin + self.seq_len 327 | r_begin = s_end - self.label_len 328 | r_end = r_begin + self.label_len + self.pred_len 329 | 330 | seq_x = self.data_x[s_begin:s_end] 331 | seq_y = self.data_y[r_begin:r_end] 332 | seq_x_mark = torch.zeros((seq_x.shape[0], 1)) 333 | seq_y_mark = torch.zeros((seq_y.shape[0], 1)) 334 | 335 | return seq_x, seq_y, seq_x_mark, seq_y_mark 336 | 337 | def __len__(self): 338 | if self.set_type == 2: 339 | return (len(self.data_x) - self.seq_len - self.pred_len + 1) // 12 340 | else: 341 | return len(self.data_x) - self.seq_len - self.pred_len + 1 342 | 343 | def inverse_transform(self, data): 344 | return self.scaler.inverse_transform(data) 345 | 346 | 347 | class Dataset_Solar(Dataset): 348 | def __init__(self, root_path, flag='train', size=None, 349 | features='S', data_path='ETTh1.csv', 350 | target='OT', scale=True, timeenc=0, freq='h'): 351 | self.seq_len = size[0] 352 | self.label_len = size[1] 353 | self.pred_len = size[2] 354 | assert flag in ['train', 'test', 'val'] 355 | type_map = {'train': 0, 'val': 1, 'test': 2} 356 | self.set_type = type_map[flag] 357 | 358 | self.features = features 359 | self.target = target 360 | self.scale = scale 361 | self.timeenc = timeenc 362 | self.freq = freq 363 | 364 | self.root_path = root_path 365 | self.data_path = data_path 366 | self.__read_data__() 367 | 368 | def __read_data__(self): 369 | self.scaler = StandardScaler() 370 | df_raw = [] 371 | with open(os.path.join(self.root_path, self.data_path), "r", encoding='utf-8') as f: 372 | for line in f.readlines(): 373 | line = line.strip('\n').split(',') 374 | data_line = np.stack([float(i) for i in line]) 375 | df_raw.append(data_line) 376 | df_raw = np.stack(df_raw, 0) 377 | df_raw = pd.DataFrame(df_raw) 378 | 379 | num_train = int(len(df_raw) * 0.7) 380 | num_test = int(len(df_raw) * 0.2) 381 | num_valid = int(len(df_raw) * 0.1) 382 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] 383 | border2s = [num_train, num_train + num_valid, len(df_raw)] 384 | border1 = border1s[self.set_type] 385 | border2 = border2s[self.set_type] 386 | 387 | df_data = df_raw.values 388 | 389 | if self.scale: 390 | train_data = df_data[border1s[0]:border2s[0]] 391 | self.scaler.fit(train_data) 392 | data = self.scaler.transform(df_data) 393 | else: 394 | data = df_data 395 | 396 | self.data_x = data[border1:border2] 397 | self.data_y = data[border1:border2] 398 | 399 | def __getitem__(self, index): 400 | s_begin = index 401 | s_end = s_begin + self.seq_len 402 | r_begin = s_end - self.label_len 403 | r_end = r_begin + self.label_len + self.pred_len 404 | 405 | seq_x = self.data_x[s_begin:s_end] 406 | seq_y = self.data_y[r_begin:r_end] 407 | seq_x_mark = torch.zeros((seq_x.shape[0], 1)) 408 | seq_y_mark = torch.zeros((seq_x.shape[0], 1)) 409 | 410 | return seq_x, seq_y, seq_x_mark, seq_y_mark 411 | 412 | def __len__(self): 413 | return len(self.data_x) - self.seq_len - self.pred_len + 1 414 | 415 | def inverse_transform(self, data): 416 | return self.scaler.inverse_transform(data) 417 | 418 | 419 | class Dataset_Pred(Dataset): 420 | def __init__(self, root_path, flag='pred', size=None, 421 | features='S', data_path='ETTh1.csv', 422 | target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None): 423 | if size == None: 424 | self.seq_len = 24 * 4 * 4 425 | self.label_len = 24 * 4 426 | self.pred_len = 24 * 4 427 | else: 428 | self.seq_len = size[0] 429 | self.label_len = size[1] 430 | self.pred_len = size[2] 431 | assert flag in ['pred'] 432 | 433 | 
self.features = features 434 | self.target = target 435 | self.scale = scale 436 | self.inverse = inverse 437 | self.timeenc = timeenc 438 | self.freq = freq 439 | self.cols = cols 440 | self.root_path = root_path 441 | self.data_path = data_path 442 | self.__read_data__() 443 | 444 | def __read_data__(self): 445 | self.scaler = StandardScaler() 446 | df_raw = pd.read_csv(os.path.join(self.root_path, 447 | self.data_path)) 448 | if self.cols: 449 | cols = self.cols.copy() 450 | cols.remove(self.target) 451 | else: 452 | cols = list(df_raw.columns) 453 | cols.remove(self.target) 454 | cols.remove('date') 455 | df_raw = df_raw[['date'] + cols + [self.target]] 456 | border1 = len(df_raw) - self.seq_len 457 | border2 = len(df_raw) 458 | 459 | if self.features == 'M' or self.features == 'MS': 460 | cols_data = df_raw.columns[1:] 461 | df_data = df_raw[cols_data] 462 | elif self.features == 'S': 463 | df_data = df_raw[[self.target]] 464 | 465 | if self.scale: 466 | self.scaler.fit(df_data.values) 467 | data = self.scaler.transform(df_data.values) 468 | else: 469 | data = df_data.values 470 | 471 | tmp_stamp = df_raw[['date']][border1:border2] 472 | tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date) 473 | pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq) 474 | 475 | df_stamp = pd.DataFrame(columns=['date']) 476 | df_stamp.date = list(tmp_stamp.date.values) + list(pred_dates[1:]) 477 | if self.timeenc == 0: 478 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 479 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 480 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 481 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 482 | df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1) 483 | df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15) 484 | data_stamp = df_stamp.drop(['date'], 1).values 485 | elif self.timeenc == 1: 486 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 487 | data_stamp = data_stamp.transpose(1, 0) 488 | 489 | self.data_x = data[border1:border2] 490 | if self.inverse: 491 | self.data_y = df_data.values[border1:border2] 492 | else: 493 | self.data_y = data[border1:border2] 494 | self.data_stamp = data_stamp 495 | 496 | def __getitem__(self, index): 497 | s_begin = index 498 | s_end = s_begin + self.seq_len 499 | r_begin = s_end - self.label_len 500 | r_end = r_begin + self.label_len + self.pred_len 501 | 502 | seq_x = self.data_x[s_begin:s_end] 503 | if self.inverse: 504 | seq_y = self.data_x[r_begin:r_begin + self.label_len] 505 | else: 506 | seq_y = self.data_y[r_begin:r_begin + self.label_len] 507 | seq_x_mark = self.data_stamp[s_begin:s_end] 508 | seq_y_mark = self.data_stamp[r_begin:r_end] 509 | 510 | return seq_x, seq_y, seq_x_mark, seq_y_mark 511 | 512 | def __len__(self): 513 | return len(self.data_x) - self.seq_len + 1 514 | 515 | def inverse_transform(self, data): 516 | return self.scaler.inverse_transform(data) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: SimpleTM 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.9 7 | - einops=0.8.1 8 | - matplotlib=3.7.0 9 | - numpy=1.23.5 10 | - scikit-learn=1.2.2 11 | - scipy=1.13.1 12 | - sympy=1.13.3 13 | - pandas=1.5.3 14 | - pip: 15 | - reformer-pytorch==1.4.4 16 | - PyWavelets 17 | 
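Before the experiment classes below, a quick orientation: `data_provider` in data_provider/data_factory.py is the single entry point the training code uses to build a dataset/loader pair. A minimal sketch of calling it directly; the field names are exactly those read by data_factory.py, while the values, the hypothetical `SimpleNamespace` stand-in for run.py's argparse object, and the ./dataset/ path are illustrative:

```python
from types import SimpleNamespace

from data_provider.data_factory import data_provider

# Hypothetical argument object with the fields data_factory.py actually reads;
# run.py normally supplies these via argparse.
args = SimpleNamespace(
    data='ETTh1', embed='timeF', batch_size=32, freq='h',
    root_path='./dataset/', data_path='ETTh1.csv',
    seq_len=96, label_len=48, pred_len=96,
    features='M', target='OT', num_workers=4,
)

train_set, train_loader = data_provider(args, flag='train')
seq_x, seq_y, seq_x_mark, seq_y_mark = next(iter(train_loader))
# seq_x: (batch, seq_len, n_vars); seq_y spans label_len + pred_len steps.
```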
-------------------------------------------------------------------------------- /experiments/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from model import SimpleTM 3 | import torch 4 | 5 | # Add this at the beginning of your training script 6 | import torch._dynamo as dynamo 7 | dynamo.config.suppress_errors = True 8 | 9 | import numpy as np 10 | 11 | class Exp_Basic(object): 12 | def __init__(self, args): 13 | self.args = args 14 | self.model_dict = { 15 | 'SimpleTM': SimpleTM, 16 | } 17 | self.device = self._acquire_device() 18 | self.model = self._build_model().to(self.device) 19 | 20 | if self.args.compile: 21 | self.model = torch.compile( 22 | self.model, 23 | ) 24 | 25 | # Count trainable parameters 26 | model_parameters = filter(lambda p: p.requires_grad, self.model.parameters()) 27 | param_count = sum([np.prod(p.size()) for p in model_parameters]) 28 | 29 | # Calculate memory usage in bytes and convert to megabytes 30 | memory_usage_bytes = param_count * 4 # 4 bytes per float32 parameter 31 | memory_usage_MB = memory_usage_bytes / (1024 ** 2) 32 | print(f"Total trainable parameters: {param_count}") 33 | print(f"Memory usage for trainable parameters: {memory_usage_MB:.2f} MB") 34 | 35 | # Measure static memory footprint 36 | print(f"Static memory footprint (allocated): {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB") 37 | print(f"Static memory footprint (reserved): {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB") 38 | 39 | 40 | def _build_model(self): 41 | raise NotImplementedError 42 | 43 | def _acquire_device(self): 44 | if self.args.use_gpu: 45 | os.environ["CUDA_VISIBLE_DEVICES"] = str( 46 | self.args.gpu) if not self.args.use_multi_gpu else self.args.devices 47 | device = torch.device('cuda:{}'.format(self.args.gpu)) 48 | print('Use GPU: cuda:{}'.format(self.args.gpu)) 49 | else: 50 | device = torch.device('cpu') 51 | print('Use CPU') 52 | return device 53 | 54 | def _get_data(self): 55 | pass 56 | 57 | def vali(self): 58 | pass 59 | 60 | def train(self): 61 | pass 62 | 63 | def test(self): 64 | pass -------------------------------------------------------------------------------- /experiments/exp_long_term_forecasting.py: -------------------------------------------------------------------------------- 1 | from torch.optim import lr_scheduler 2 | 3 | from data_provider.data_factory import data_provider 4 | from experiments.exp_basic import Exp_Basic 5 | from utils.tools import EarlyStopping, adjust_learning_rate, visual 6 | from utils.metrics import metric 7 | import torch 8 | import torch.nn as nn 9 | from torch import optim 10 | import os 11 | import time 12 | import warnings 13 | import numpy as np 14 | 15 | warnings.filterwarnings('ignore') 16 | 17 | torch.autograd.set_detect_anomaly(True) 18 | 19 | 20 | class Exp_Long_Term_Forecast(Exp_Basic): 21 | def __init__(self, args): 22 | super(Exp_Long_Term_Forecast, self).__init__(args) 23 | 24 | def _build_model(self): 25 | model = self.model_dict[self.args.model].Model(self.args).float() 26 | 27 | if self.args.use_multi_gpu and self.args.use_gpu: 28 | model = nn.DataParallel(model, device_ids=self.args.device_ids) 29 | return model 30 | 31 | def _get_data(self, flag): 32 | data_set, data_loader = data_provider(self.args, flag) 33 | return data_set, data_loader 34 | 35 | def _select_optimizer(self): 36 | model_optim = optim.AdamW(self.model.parameters(), lr=self.args.learning_rate) 37 | return model_optim 38 | 39 | def _select_criterion(self): 40 | if 
self.args.data == 'PEMS': 41 | criterion = nn.L1Loss() 42 | else: 43 | criterion = nn.MSELoss() 44 | return criterion 45 | 46 | def vali(self, vali_data, vali_loader, criterion): 47 | total_loss = [] 48 | self.model.eval() 49 | with torch.no_grad(): 50 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader): 51 | batch_x = batch_x.float().to(self.device) 52 | batch_y = batch_y.float() 53 | 54 | if 'PEMS' in self.args.data or 'Solar' in self.args.data: 55 | batch_x_mark = None 56 | batch_y_mark = None 57 | else: 58 | batch_x_mark = batch_x_mark.float().to(self.device) 59 | batch_y_mark = batch_y_mark.float().to(self.device) 60 | 61 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() 62 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) 63 | 64 | if self.args.use_amp: 65 | with torch.cuda.amp.autocast(): 66 | if self.args.output_attention: 67 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 68 | else: 69 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 70 | else: 71 | if self.args.output_attention: 72 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 73 | else: 74 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 75 | f_dim = -1 if self.args.features == 'MS' else 0 76 | outputs = outputs[:, -self.args.pred_len:, f_dim:] 77 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) 78 | 79 | pred = outputs.detach().cpu() 80 | true = batch_y.detach().cpu() 81 | 82 | if self.args.data == 'PEMS': 83 | B, T, C = pred.shape 84 | pred = pred.numpy() 85 | true = true.numpy() 86 | pred = vali_data.inverse_transform(pred.reshape(-1, C)).reshape(B, T, C) 87 | true = vali_data.inverse_transform(true.reshape(-1, C)).reshape(B, T, C) 88 | mae, mse, rmse, mape, mspe = metric(pred, true) 89 | total_loss.append(mae) 90 | else: 91 | loss = criterion(pred, true) 92 | total_loss.append(loss) 93 | 94 | total_loss = np.average(total_loss) 95 | self.model.train() 96 | return total_loss 97 | 98 | def train(self, setting): 99 | train_data, train_loader = self._get_data(flag='train') 100 | vali_data, vali_loader = self._get_data(flag='val') 101 | test_data, test_loader = self._get_data(flag='test') 102 | 103 | path = os.path.join(self.args.checkpoints, setting) 104 | if not os.path.exists(path): 105 | os.makedirs(path) 106 | 107 | time_now = time.time() 108 | 109 | train_steps = len(train_loader) 110 | early_stopping = EarlyStopping(patience=self.args.patience, verbose=True) 111 | 112 | model_optim = self._select_optimizer() 113 | criterion = self._select_criterion() 114 | 115 | if self.args.lradj == 'TST': 116 | scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim, 117 | steps_per_epoch=train_steps, 118 | pct_start=self.args.pct_start, 119 | epochs=self.args.train_epochs, 120 | max_lr=self.args.learning_rate) 121 | 122 | 123 | if self.args.use_amp: 124 | scaler = torch.cuda.amp.GradScaler() 125 | # # Efficiency: dynamic memory footprint 126 | # # Track dynamic memory usage over an epoch 127 | # torch.cuda.reset_peak_memory_stats() # Reset peak memory tracking 128 | 129 | for epoch in range(self.args.train_epochs): 130 | iter_count = 0 131 | train_loss = [] 132 | 133 | self.model.train() 134 | epoch_time = time.time() 135 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): 136 | iter_count += 1 137 | model_optim.zero_grad() 138 | batch_x = batch_x.float().to(self.device) 139 | 140 | batch_y = 
batch_y.float().to(self.device) 141 | if 'PEMS' in self.args.data or 'Solar' in self.args.data: 142 | batch_x_mark = None 143 | batch_y_mark = None 144 | else: 145 | batch_x_mark = batch_x_mark.float().to(self.device) 146 | batch_y_mark = batch_y_mark.float().to(self.device) 147 | 148 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() 149 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) 150 | 151 | if self.args.use_amp: 152 | with torch.cuda.amp.autocast(): 153 | if self.args.output_attention: 154 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 155 | else: 156 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 157 | 158 | f_dim = -1 if self.args.features == 'MS' else 0 159 | outputs = outputs[:, -self.args.pred_len:, f_dim:] 160 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) 161 | loss = criterion(outputs, batch_y) 162 | train_loss.append(loss.item()) 163 | else: 164 | if self.args.output_attention: 165 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 166 | else: 167 | outputs, attn = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 168 | 169 | f_dim = -1 if self.args.features == 'MS' else 0 170 | outputs = outputs[:, -self.args.pred_len:, f_dim:] 171 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) 172 | 173 | loss = criterion(outputs, batch_y) + self.args.l1_weight * attn[0] 174 | train_loss.append(loss.item()) 175 | 176 | if (i + 1) % 30 == 0: 177 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) 178 | speed = (time.time() - time_now) / iter_count 179 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i) 180 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) 181 | iter_count = 0 182 | time_now = time.time() 183 | 184 | if self.args.use_amp: 185 | scaler.scale(loss).backward() 186 | scaler.step(model_optim) 187 | scaler.update() 188 | else: 189 | loss.backward() 190 | model_optim.step() 191 | # # Efficiency: dynamic memory footprint 192 | # # Record current and peak memory usage after processing this batch 193 | # current_memory = torch.cuda.memory_allocated() 194 | # peak_memory = torch.cuda.max_memory_allocated() 195 | # print(f"Current memory: {current_memory / (1024 ** 2):.2f} MB, Peak memory: {peak_memory / (1024 ** 2):.2f} MB") 196 | 197 | if self.args.lradj == 'TST': 198 | adjust_learning_rate(model_optim, epoch + 1, self.args, scheduler, printout=False) 199 | scheduler.step() 200 | 201 | 202 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)) 203 | train_loss = np.average(train_loss) 204 | vali_loss = self.vali(vali_data, vali_loader, criterion) 205 | test_loss = self.vali(test_data, test_loader, criterion) 206 | 207 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format( 208 | epoch + 1, train_steps, train_loss, vali_loss, test_loss)) 209 | early_stopping(vali_loss, self.model, path) 210 | if early_stopping.early_stop: 211 | print("Early stopping") 212 | break 213 | 214 | if self.args.lradj != 'TST': 215 | adjust_learning_rate(model_optim, epoch + 1, self.args) 216 | else: 217 | adjust_learning_rate(model_optim, epoch + 1, self.args, scheduler) 218 | 219 | 220 | best_model_path = path + '/' + 'checkpoint.pth' 221 | self.model.load_state_dict(torch.load(best_model_path)) 222 | 223 | return self.model 224 | 225 | def test(self, setting, 
test=0): 226 | test_data, test_loader = self._get_data(flag='test') 227 | if test: 228 | print('loading model') 229 | self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth'))) 230 | 231 | preds = [] 232 | trues = [] 233 | folder_path = './checkpoints/' + setting + '/' 234 | if not os.path.exists(folder_path): 235 | os.makedirs(folder_path) 236 | 237 | self.model.eval() 238 | with torch.no_grad(): 239 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader): 240 | batch_x = batch_x.float().to(self.device) 241 | batch_y = batch_y.float().to(self.device) 242 | 243 | if 'PEMS' in self.args.data or 'Solar' in self.args.data: 244 | batch_x_mark = None 245 | batch_y_mark = None 246 | else: 247 | batch_x_mark = batch_x_mark.float().to(self.device) 248 | batch_y_mark = batch_y_mark.float().to(self.device) 249 | 250 | 251 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() 252 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) 253 | # encoder - decoder 254 | if self.args.use_amp: 255 | with torch.cuda.amp.autocast(): 256 | if self.args.output_attention: 257 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 258 | else: 259 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 260 | else: 261 | if self.args.output_attention: 262 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 263 | else: 264 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 265 | 266 | f_dim = -1 if self.args.features == 'MS' else 0 267 | outputs = outputs[:, -self.args.pred_len:, f_dim:] 268 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) 269 | outputs = outputs.detach().cpu().numpy() 270 | batch_y = batch_y.detach().cpu().numpy() 271 | 272 | pred = outputs 273 | true = batch_y 274 | 275 | preds.append(pred) 276 | trues.append(true) 277 | if i % 20 == 0: 278 | input = batch_x.detach().cpu().numpy() 279 | if test_data.scale and self.args.inverse: 280 | shape = input.shape 281 | input = test_data.inverse_transform(input.squeeze(0)).reshape(shape) 282 | gt = np.concatenate((input[0, :, -1], true[0, :, -1]), axis=0) 283 | pd = np.concatenate((input[0, :, -1], pred[0, :, -1]), axis=0) 284 | visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf')) 285 | 286 | preds = np.array(preds) 287 | trues = np.array(trues) 288 | print('test shape:', preds.shape, trues.shape) 289 | preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1]) 290 | trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1]) 291 | print('test shape:', preds.shape, trues.shape) 292 | 293 | if self.args.data == 'PEMS': 294 | B, T, C = preds.shape 295 | preds = test_data.inverse_transform(preds.reshape(-1, C)).reshape(B, T, C) 296 | trues = test_data.inverse_transform(trues.reshape(-1, C)).reshape(B, T, C) 297 | 298 | # result save 299 | folder_path = './checkpoints/' + setting + '/' 300 | if not os.path.exists(folder_path): 301 | os.makedirs(folder_path) 302 | 303 | mae, mse, rmse, mape, mspe = metric(preds, trues) 304 | print('mse:{}, mae:{}'.format(mse, mae)) 305 | print('rmse:{}, mape:{}, mspe:{}'.format(rmse, mape, mspe)) 306 | f = open("result_long_term_forecast.txt", 'a') 307 | f.write(setting + " \n") 308 | if self.args.data == 'PEMS': 309 | f.write('mae:{}, mape:{}, rmse:{}'.format(mae, mape, rmse)) 310 | else: 311 | f.write('mse:{}, mae:{}'.format(mse, mae)) 312 | f.write('\n') 313 | f.write('\n') 314 | f.close() 315 | 316 | 
317 | return 318 | 319 | 320 | def predict(self, setting, load=False): 321 | pred_data, pred_loader = self._get_data(flag='pred') 322 | 323 | if load: 324 | path = os.path.join(self.args.checkpoints, setting) 325 | best_model_path = path + '/' + 'checkpoint.pth' 326 | self.model.load_state_dict(torch.load(best_model_path)) 327 | 328 | preds = [] 329 | 330 | self.model.eval() 331 | with torch.no_grad(): 332 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(pred_loader): 333 | batch_x = batch_x.float().to(self.device) 334 | batch_y = batch_y.float() 335 | batch_x_mark = batch_x_mark.float().to(self.device) 336 | batch_y_mark = batch_y_mark.float().to(self.device) 337 | 338 | 339 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() 340 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) 341 | # encoder - decoder 342 | if self.args.use_amp: 343 | with torch.cuda.amp.autocast(): 344 | if self.args.output_attention: 345 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 346 | else: 347 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 348 | else: 349 | if self.args.output_attention: 350 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 351 | else: 352 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 353 | outputs = outputs.detach().cpu().numpy() 354 | if pred_data.scale and self.args.inverse: 355 | shape = outputs.shape 356 | outputs = pred_data.inverse_transform(outputs.squeeze(0)).reshape(shape) 357 | preds.append(outputs) 358 | 359 | preds = np.array(preds) 360 | preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1]) 361 | 362 | # result save 363 | folder_path = './results/' + setting + '/' 364 | if not os.path.exists(folder_path): 365 | os.makedirs(folder_path) 366 | 367 | np.save(folder_path + 'real_prediction.npy', preds) 368 | 369 | return -------------------------------------------------------------------------------- /figures/Efficiency.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsingh-group/SimpleTM/3c77d820837b726afb03c943235ea95bc924243d/figures/Efficiency.jpg -------------------------------------------------------------------------------- /figures/Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsingh-group/SimpleTM/3c77d820837b726afb03c943235ea95bc924243d/figures/Framework.png -------------------------------------------------------------------------------- /figures/Long_term_forecast_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vsingh-group/SimpleTM/3c77d820837b726afb03c943235ea95bc924243d/figures/Long_term_forecast_results.jpg -------------------------------------------------------------------------------- /layers/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DataEmbedding_inverted(nn.Module): 5 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 6 | super(DataEmbedding_inverted, self).__init__() 7 | self.value_embedding = nn.Linear(c_in, d_model) 8 | self.dropout = nn.Dropout(p=dropout) 9 | 10 | def forward(self, x, x_mark): 11 | x = x.permute(0, 2, 1) 12 | if x_mark is None: 13 | x = self.value_embedding(x) 14 | else: 15 | x = 
self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) 16 | return self.dropout(x) -------------------------------------------------------------------------------- /layers/SWTAttention_Family.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from math import sqrt 5 | import pywt 6 | 7 | 8 | class WaveletEmbedding(nn.Module): 9 | def __init__(self, d_channel=16, swt=True, requires_grad=False, wv='db2', m=2, 10 | kernel_size=None): 11 | super().__init__() 12 | 13 | self.swt = swt 14 | self.d_channel = d_channel 15 | self.m = m # Number of decomposition levels of detailed coefficients 16 | 17 | if kernel_size is None: 18 | self.wavelet = pywt.Wavelet(wv) 19 | if self.swt: 20 | h0 = torch.tensor(self.wavelet.dec_lo[::-1], dtype=torch.float32) 21 | h1 = torch.tensor(self.wavelet.dec_hi[::-1], dtype=torch.float32) 22 | else: 23 | h0 = torch.tensor(self.wavelet.rec_lo[::-1], dtype=torch.float32) 24 | h1 = torch.tensor(self.wavelet.rec_hi[::-1], dtype=torch.float32) 25 | self.h0 = nn.Parameter(torch.tile(h0[None, None, :], [self.d_channel, 1, 1]), requires_grad=requires_grad) 26 | self.h1 = nn.Parameter(torch.tile(h1[None, None, :], [self.d_channel, 1, 1]), requires_grad=requires_grad) 27 | self.kernel_size = self.h0.shape[-1] 28 | else: 29 | self.kernel_size = kernel_size 30 | self.h0 = nn.Parameter(torch.Tensor(self.d_channel, 1, self.kernel_size), requires_grad=requires_grad) 31 | self.h1 = nn.Parameter(torch.Tensor(self.d_channel, 1, self.kernel_size), requires_grad=requires_grad) 32 | nn.init.xavier_uniform_(self.h0) 33 | nn.init.xavier_uniform_(self.h1) 34 | 35 | with torch.no_grad(): 36 | self.h0.data = self.h0.data / torch.norm(self.h0.data, dim=-1, keepdim=True) 37 | self.h1.data = self.h1.data / torch.norm(self.h1.data, dim=-1, keepdim=True) 38 | 39 | 40 | def forward(self, x): 41 | if self.swt: 42 | coeffs = self.swt_decomposition(x, self.h0, self.h1, self.m, self.kernel_size) 43 | else: 44 | coeffs = self.swt_reconstruction(x, self.h0, self.h1, self.m, self.kernel_size) 45 | return coeffs 46 | 47 | def swt_decomposition(self, x, h0, h1, depth, kernel_size): 48 | approx_coeffs = x 49 | coeffs = [] 50 | dilation = 1 51 | for _ in range(depth): 52 | padding = dilation * (kernel_size - 1) 53 | padding_r = (kernel_size * dilation) // 2 54 | pad = (padding - padding_r, padding_r) 55 | approx_coeffs_pad = F.pad(approx_coeffs, pad, "circular") 56 | detail_coeff = F.conv1d(approx_coeffs_pad, h1, dilation=dilation, groups=x.shape[1]) 57 | approx_coeffs = F.conv1d(approx_coeffs_pad, h0, dilation=dilation, groups=x.shape[1]) 58 | coeffs.append(detail_coeff) 59 | dilation *= 2 60 | coeffs.append(approx_coeffs) 61 | 62 | return torch.stack(list(reversed(coeffs)), -2) 63 | 64 | def swt_reconstruction(self, coeffs, g0, g1, m, kernel_size): 65 | dilation = 2 ** (m - 1) 66 | approx_coeff = coeffs[:,:,0,:] 67 | detail_coeffs = coeffs[:,:,1:,:] 68 | 69 | for i in range(m): 70 | detail_coeff = detail_coeffs[:,:,i,:] 71 | padding = dilation * (kernel_size - 1) 72 | padding_l = (dilation * kernel_size) // 2 73 | pad = (padding_l, padding - padding_l) 74 | approx_coeff_pad = F.pad(approx_coeff, pad, "circular") 75 | detail_coeff_pad = F.pad(detail_coeff, pad, "circular") 76 | 77 | y = F.conv1d(approx_coeff_pad, g0, groups=approx_coeff.shape[1], dilation=dilation) + \ 78 | F.conv1d(detail_coeff_pad, g1, groups=detail_coeff.shape[1], dilation=dilation) 79 | approx_coeff = y / 2 80 | 
dilation //= 2 81 | 82 | return approx_coeff 83 | 84 | 85 | class GeomAttentionLayer(nn.Module): 86 | def __init__(self, attention, d_model, 87 | requires_grad=True, wv='db2', m=2, kernel_size=None, 88 | d_channel=None, geomattn_dropout=0.5,): 89 | super(GeomAttentionLayer, self).__init__() 90 | 91 | self.d_channel = d_channel 92 | self.inner_attention = attention 93 | 94 | self.swt = WaveletEmbedding(d_channel=self.d_channel, swt=True, requires_grad=requires_grad, wv=wv, m=m, kernel_size=kernel_size) 95 | self.query_projection = nn.Sequential( 96 | nn.Linear(d_model, d_model), 97 | nn.Dropout(geomattn_dropout) 98 | ) 99 | self.key_projection = nn.Sequential( 100 | nn.Linear(d_model, d_model), 101 | nn.Dropout(geomattn_dropout) 102 | ) 103 | self.value_projection = nn.Sequential( 104 | nn.Linear(d_model, d_model), 105 | nn.Dropout(geomattn_dropout) 106 | ) 107 | self.out_projection = nn.Sequential( 108 | nn.Linear(d_model, d_model), 109 | WaveletEmbedding(d_channel=self.d_channel, swt=False, requires_grad=requires_grad, wv=wv, m=m, kernel_size=kernel_size), 110 | ) 111 | 112 | def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None): 113 | queries = self.swt(queries) 114 | keys = self.swt(keys) 115 | values = self.swt(values) 116 | 117 | queries = self.query_projection(queries).permute(0,3,2,1) 118 | keys = self.key_projection(keys).permute(0,3,2,1) 119 | values = self.value_projection(values).permute(0,3,2,1) 120 | 121 | out, attn = self.inner_attention( 122 | queries, 123 | keys, 124 | values, 125 | ) 126 | 127 | out = self.out_projection(out.permute(0,3,2,1)) 128 | 129 | return out, attn 130 | 131 | 132 | class GeomAttention(nn.Module): 133 | def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1, 134 | output_attention=False, 135 | alpha=1.,): 136 | super(GeomAttention, self).__init__() 137 | self.scale = scale 138 | self.mask_flag = mask_flag 139 | self.output_attention = output_attention 140 | self.dropout = nn.Dropout(attention_dropout) 141 | 142 | self.alpha = alpha 143 | 144 | def forward(self, queries, keys, values, attn_mask=None): 145 | B, L, H, E = queries.shape 146 | _, S, _, _ = values.shape 147 | scale = self.scale or 1. 
/ sqrt(E) 148 | 149 | dot_product = torch.einsum("blhe,bshe->bhls", queries, keys) 150 | 151 | queries_norm2 = torch.sum(queries**2, dim=-1) 152 | keys_norm2 = torch.sum(keys**2, dim=-1) 153 | queries_norm2 = queries_norm2.permute(0, 2, 1).unsqueeze(-1) # (B, H, L, 1) 154 | keys_norm2 = keys_norm2.permute(0, 2, 1).unsqueeze(-2) # (B, H, 1, S) 155 | wedge_norm2 = queries_norm2 * keys_norm2 - dot_product ** 2 # (B, H, L, S) 156 | wedge_norm2 = F.relu(wedge_norm2) 157 | wedge_norm = torch.sqrt(wedge_norm2 + 1e-8) 158 | 159 | scores = (1 - self.alpha) * dot_product + self.alpha * wedge_norm 160 | scores = scores * scale 161 | 162 | if self.mask_flag: 163 | if attn_mask is None: 164 | attn_mask = torch.tril(torch.ones(L, S)).to(scores.device) 165 | scores.masked_fill_(attn_mask.unsqueeze(1).unsqueeze(2) == 0, float('-inf')) 166 | 167 | A = self.dropout(torch.softmax(scores, dim=-1)) 168 | 169 | V = torch.einsum("bhls,bshd->blhd", A, values) 170 | 171 | if self.output_attention: 172 | return V.contiguous() 173 | else: 174 | return (V.contiguous(), scores.abs().mean()) -------------------------------------------------------------------------------- /layers/StandardNorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Normalize(nn.Module): 5 | def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False): 6 | """ 7 | :param num_features: the number of features or channels 8 | :param eps: a value added for numerical stability 9 | :param affine: if True, RevIN has learnable affine parameters 10 | """ 11 | super(Normalize, self).__init__() 12 | self.num_features = num_features 13 | self.eps = eps 14 | self.affine = affine 15 | self.subtract_last = subtract_last 16 | self.non_norm = non_norm 17 | if self.affine: 18 | self._init_params() 19 | 20 | def forward(self, x, mode: str): 21 | if mode == 'norm': 22 | self._get_statistics(x) 23 | x = self._normalize(x) 24 | elif mode == 'denorm': 25 | x = self._denormalize(x) 26 | else: 27 | raise NotImplementedError 28 | return x 29 | 30 | def _init_params(self): 31 | self.affine_weight = nn.Parameter(torch.ones(self.num_features)) 32 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) 33 | 34 | def _get_statistics(self, x): 35 | dim2reduce = tuple(range(1, x.ndim - 1)) 36 | if self.subtract_last: 37 | self.last = x[:, -1, :].unsqueeze(1) 38 | else: 39 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() 40 | self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() 41 | 42 | def _normalize(self, x): 43 | if self.non_norm: 44 | return x 45 | if self.subtract_last: 46 | x = x - self.last 47 | else: 48 | x = x - self.mean 49 | x = x / self.stdev 50 | if self.affine: 51 | x = x * self.affine_weight 52 | x = x + self.affine_bias 53 | return x 54 | 55 | def _denormalize(self, x): 56 | if self.non_norm: 57 | return x 58 | if self.affine: 59 | x = x - self.affine_bias 60 | x = x / (self.affine_weight + self.eps * self.eps) 61 | x = x * self.stdev 62 | if self.subtract_last: 63 | x = x + self.last 64 | else: 65 | x = x + self.mean 66 | return x -------------------------------------------------------------------------------- /layers/Transformer_Encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class EncoderLayer(nn.Module): 6 | def __init__(self, 
attention, d_model, d_ff=None, dropout=0.1, activation="relu", dec_in=866): 7 | super(EncoderLayer, self).__init__() 8 | d_ff = d_ff or 4 * d_model 9 | self.attention = attention 10 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 11 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 12 | self.norm1 = nn.LayerNorm(d_model) 13 | self.norm2 = nn.LayerNorm(d_model) 14 | self.dropout = nn.Dropout(dropout) 15 | self.activation = F.relu if activation == "relu" else F.gelu 16 | 17 | def forward(self, x, attn_mask=None, tau=None, delta=None): 18 | new_x, attn = self.attention( 19 | x, x, x, 20 | attn_mask=attn_mask, 21 | tau=tau, delta=delta 22 | ) 23 | x = x + self.dropout(new_x) 24 | y = x = self.norm1(x) 25 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 26 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 27 | return self.norm2(x + y), attn 28 | 29 | 30 | class Encoder(nn.Module): 31 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 32 | super(Encoder, self).__init__() 33 | self.attn_layers = nn.ModuleList(attn_layers) 34 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 35 | self.norm = norm_layer 36 | 37 | def forward(self, x, attn_mask=None, tau=None, delta=None): 38 | # x [B, L, D] 39 | attns = [] 40 | if self.conv_layers is not None: 41 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): 42 | delta = delta if i == 0 else None 43 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 44 | x = conv_layer(x) 45 | attns.append(attn) 46 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None) 47 | attns.append(attn) 48 | else: 49 | for attn_layer in self.attn_layers: 50 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 51 | attns.append(attn) 52 | 53 | if self.norm is not None: 54 | x = self.norm(x) 55 | 56 | return x, attns -------------------------------------------------------------------------------- /model/SimpleTM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_Encoder import Encoder, EncoderLayer 5 | from layers.SWTAttention_Family import GeomAttentionLayer, GeomAttention 6 | from layers.Embed import DataEmbedding_inverted 7 | 8 | 9 | class Model(nn.Module): 10 | def __init__(self, configs): 11 | super(Model, self).__init__() 12 | self.seq_len = configs.seq_len 13 | self.pred_len = configs.pred_len 14 | self.output_attention = configs.output_attention 15 | self.use_norm = configs.use_norm 16 | self.geomattn_dropout = configs.geomattn_dropout 17 | self.alpha = configs.alpha 18 | self.kernel_size = configs.kernel_size 19 | 20 | enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, 21 | configs.embed, configs.freq, configs.dropout) 22 | self.enc_embedding = enc_embedding 23 | 24 | encoder = Encoder( 25 | [ 26 | EncoderLayer( 27 | GeomAttentionLayer( 28 | GeomAttention( 29 | False, configs.factor, attention_dropout=configs.dropout, 30 | output_attention=configs.output_attention, alpha=self.alpha 31 | ), 32 | configs.d_model, 33 | requires_grad=configs.requires_grad, 34 | wv=configs.wv, 35 | m=configs.m, 36 | d_channel=configs.dec_in, 37 | kernel_size=self.kernel_size, 38 | geomattn_dropout=self.geomattn_dropout 39 | ), 40 | configs.d_model, 41 | configs.d_ff, 42 | dropout=configs.dropout, 43 | activation=configs.activation, 44 | 
) for l in range(configs.e_layers) 45 | ], 46 | norm_layer=torch.nn.LayerNorm(configs.d_model) 47 | ) 48 | self.encoder = encoder 49 | 50 | projector = nn.Linear(configs.d_model, self.pred_len, bias=True) 51 | self.projector = projector 52 | 53 | 54 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 55 | if self.use_norm: 56 | means = x_enc.mean(1, keepdim=True).detach() 57 | x_enc = x_enc - means 58 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 59 | # x_enc /= stdev 60 | x_enc = x_enc / stdev 61 | 62 | _, _, N = x_enc.shape 63 | 64 | enc_embedding = self.enc_embedding 65 | encoder = self.encoder 66 | projector = self.projector 67 | # Linear Projection B L N -> B L' (pseudo temporal tokens) N 68 | enc_out = enc_embedding(x_enc, x_mark_enc) 69 | 70 | # SimpleTM Layer B L' N -> B L' N 71 | enc_out, attns = encoder(enc_out, attn_mask=None) 72 | 73 | # Output Projection B L' N -> B H (Horizon) N 74 | dec_out = projector(enc_out).permute(0, 2, 1)[:, :, :N] 75 | 76 | if self.use_norm: 77 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 78 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 79 | 80 | return dec_out, attns 81 | 82 | 83 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 84 | dec_out, attns = self.forecast(x_enc, None, None, None) 85 | return dec_out, attns -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from experiments.exp_long_term_forecasting import Exp_Long_Term_Forecast 4 | import random 5 | import numpy as np 6 | from model.SimpleTM import Model 7 | 8 | if __name__ == '__main__': 9 | 10 | parser = argparse.ArgumentParser(description='iTransformer') 11 | 12 | # basic config 13 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status') 14 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id') 15 | 16 | parser.add_argument('--model', type=str, required=True, default='SimpleTM', 17 | help='model name, options: [SimpleTM]') 18 | 19 | # data loader 20 | parser.add_argument('--data', type=str, required=True, default='custom', help='dataset type') 21 | parser.add_argument('--root_path', type=str, default='./data/electricity/', help='root path of the data file') 22 | parser.add_argument('--data_path', type=str, default='electricity.csv', help='data csv file') 23 | parser.add_argument('--features', type=str, default='M', 24 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate') 25 | parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task') 26 | parser.add_argument('--freq', type=str, default='h', 27 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h') 28 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints') 29 | 30 | # forecasting task 31 | parser.add_argument('--seq_len', type=int, default=96, help='input sequence length') 32 | parser.add_argument('--label_len', type=int, default=0, help='start token length') 33 | parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length') 34 | 35 
| # model define
36 | parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
37 | parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
38 | parser.add_argument('--c_out', type=int, default=7, help='output size')
39 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
40 | parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
41 | parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
42 | parser.add_argument('--factor', type=int, default=1, help='attn factor')
43 | parser.add_argument('--distil', action='store_false',
44 | help='whether to use distilling in encoder, using this argument means not using distilling',
45 | default=True)
46 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
47 | parser.add_argument('--geomattn_dropout', type=float, default=0.5, help='dropout rate of the projection layer in the geometric attention')
48 | parser.add_argument('--embed', type=str, default='timeF',
49 | help='time features encoding, options:[timeF, fixed, learned]')
50 | parser.add_argument('--activation', type=str, default='gelu', help='activation')
51 | parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')
52 |
53 | # optimization
54 | parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
55 | parser.add_argument('--itr', type=int, default=1, help='number of repeated experiment runs')
56 | parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
57 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
58 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
59 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
60 | parser.add_argument('--des', type=str, default='test', help='exp description')
61 | parser.add_argument('--loss', type=str, default='MSE', help='loss function')
62 | parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
63 | parser.add_argument('--pct_start', type=float, default=0.2, help='Warmup ratio for the learning rate scheduler')
64 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
65 |
66 | # GPU
67 | parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
68 | parser.add_argument('--gpu', type=int, default=0, help='gpu')
69 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
70 | parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multiple gpus')
71 |
72 | parser.add_argument('--exp_name', type=str, required=False, default='MTSF',
73 | help='experiment name, options:[MTSF, partial_train]')
74 | parser.add_argument('--channel_independence', type=bool, default=False, help='whether to use channel_independence mechanism')
75 | parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)
76 | parser.add_argument('--class_strategy', type=str, default='projection', help='projection/average/cls_token')
77 | parser.add_argument('--target_root_path', type=str, default='./data/electricity/', help='root path of the data file')
78 | parser.add_argument('--target_data_path', type=str, default='electricity.csv', help='data file')
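A caution on the boolean flags in this parser (`--use_gpu` and `--channel_independence` above, `--requires_grad`, `--compile`, and `--efficient_training` below): `type=bool` applies Python's `bool()` to the raw string, and any non-empty string, including the literal `"False"`, is truthy. A small, hypothetical converter (not part of this repo) that could be passed as `type=` instead:

```python
# Illustrative only: with type=bool, `--requires_grad False` still
# evaluates to True, because bool('False') is True.
def str2bool(v: str) -> bool:
    if v.lower() in ('1', 'true', 'yes', 'y'):
        return True
    if v.lower() in ('0', 'false', 'no', 'n'):
        return False
    raise ValueError(f'expected a boolean value, got {v!r}')
```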
79 | parser.add_argument('--efficient_training', type=bool, default=False, help='whether to use efficient_training (exp_name should be partial_train)') # See Figure 8 of our paper for details
80 | parser.add_argument('--use_norm', type=int, default=True, help='use norm and denorm')
81 | parser.add_argument('--partial_start_index', type=int, default=0, help='the start index of variates for partial training, '
82 | 'you can select [partial_start_index, min(enc_in + partial_start_index, N)]')
83 |
84 | # SimpleTM Arguments
85 | parser.add_argument('--requires_grad', type=bool, default=True, help='Set to True to enable learnable wavelets')
86 | parser.add_argument('--wv', type=str, default='db1', help='Wavelet filter type. Supports all wavelets available in PyWavelets')
87 | parser.add_argument('--m', type=int, default=3, help='Number of levels for the stationary wavelet transform')
88 | parser.add_argument('--kernel_size', default=None, help='Specify the length of randomly initialized wavelets (if not None)')
89 | parser.add_argument('--alpha', type=float, default=1, help='Weight of the inner product score in geometric attention')
90 | parser.add_argument('--l1_weight', type=float, default=5e-5, help='Weight of L1 loss')
91 | parser.add_argument('--d_model', type=int, default=32, help='Dimensionality of pseudo tokens')
92 | parser.add_argument('--d_ff', type=int, default=32, help='Dimensionality of the feedforward network')
93 | parser.add_argument('--e_layers', type=int, default=1, help='Number of SimpleTM layers')
94 | parser.add_argument('--compile', type=bool, default=False, help='Set to True to enable compilation, which can speed up training but may slightly affect results')
95 | parser.add_argument('--output_attention', action='store_true', help='Leave unset so the attention scores are returned and used in the L1 term of the training loss')
96 |
97 | parser.add_argument('--fix_seed', type=int, default=2025, help='random seed')
98 |
99 | args = parser.parse_args()
100 | args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
101 |
102 | fix_seed = args.fix_seed
103 | random.seed(fix_seed)
104 | torch.manual_seed(fix_seed)
105 | np.random.seed(fix_seed)
106 |
107 | if args.use_gpu and args.use_multi_gpu:
108 | args.devices = args.devices.replace(' ', '')
109 | device_ids = args.devices.split(',')
110 | args.device_ids = [int(id_) for id_ in device_ids]
111 | args.gpu = args.device_ids[0]
112 |
113 | print('Args in experiment:')
114 | print(args)
115 |
116 | Exp = Exp_Long_Term_Forecast
117 |
118 | if args.is_training:
119 | for ii in range(args.itr):
120 | # setting record of experiments
121 | setting = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
122 | args.model_id,
123 | args.data,
124 | args.seq_len,
125 | args.pred_len,
126 | args.d_model,
127 | args.d_ff,
128 | args.e_layers,
129 | args.wv,
130 | args.kernel_size,
131 | args.m,
132 | args.alpha,
133 | args.l1_weight,
134 | args.learning_rate,
135 | args.lradj,
136 | args.batch_size,
137 | args.fix_seed,
138 | args.use_norm,
139 | ii)
140 |
141 | exp = Exp(args) # set experiments
142 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
143 | exp.train(setting)
144 |
145 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
146 | exp.test(setting)
147 |
148 | if args.do_predict:
149 | print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
150 | exp.predict(setting, True)
151 |
152 | torch.cuda.empty_cache()
153 | else:
154 |
155 | ii = 0
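# Note: the test-only setting string below must rebuild, field for field, the
# setting used in the training branch above; exp.test() resolves its weights at
# ./checkpoints/<setting>/checkpoint.pth, so any mismatch means the trained
# checkpoint is never found.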
156 | setting = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
157 | args.model_id,
158 | args.data,
159 | args.seq_len,
160 | args.pred_len,
161 | args.d_model,
162 | args.d_ff,
163 | args.e_layers,
164 | args.wv,
165 | args.kernel_size,
166 | args.m,
167 | args.alpha,
168 | args.l1_weight,
169 | args.learning_rate,
170 | args.lradj,
171 | args.batch_size,
172 | args.fix_seed,
173 | args.use_norm,
174 | ii)
175 |
176 | exp = Exp(args) # set experiments
177 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
178 | exp.test(setting, test=1)
179 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ECL/SimpleTM.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 3 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL \
11 | --model "$model_name" \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 1 \
17 | --d_model 256 \
18 | --d_ff 1024 \
19 | --learning_rate 0.01 \
20 | --batch_size 256 \
21 | --fix_seed 2025 \
22 | --use_norm 1 \
23 | --wv "db1" \
24 | --m 3 \
25 | --enc_in 321 \
26 | --dec_in 321 \
27 | --c_out 321 \
28 | --des 'Exp' \
29 | --itr 3 \
30 | --alpha 0.0 \
31 | --l1_weight 0.0
32 |
33 | python -u run.py \
34 | --is_training 1 \
35 | --lradj 'TST' \
36 | --patience 3 \
37 | --root_path ./dataset/electricity/ \
38 | --data_path electricity.csv \
39 | --model_id ECL \
40 | --model "$model_name" \
41 | --data custom \
42 | --features M \
43 | --seq_len 96 \
44 | --pred_len 192 \
45 | --e_layers 1 \
46 | --d_model 256 \
47 | --d_ff 1024 \
48 | --learning_rate 0.006 \
49 | --batch_size 256 \
50 | --fix_seed 2025 \
51 | --use_norm 1 \
52 | --wv "db1" \
53 | --m 3 \
54 | --enc_in 321 \
55 | --dec_in 321 \
56 | --c_out 321 \
57 | --des 'Exp' \
58 | --itr 3 \
59 | --alpha 0.0 \
60 | --l1_weight 0.0
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --lradj 'TST' \
65 | --patience 3 \
66 | --root_path ./dataset/electricity/ \
67 | --data_path electricity.csv \
68 | --model_id ECL \
69 | --model "$model_name" \
70 | --data custom \
71 | --features M \
72 | --seq_len 96 \
73 | --pred_len 336 \
74 | --e_layers 1 \
75 | --d_model 256 \
76 | --d_ff 1024 \
77 | --learning_rate 0.006 \
78 | --batch_size 256 \
79 | --fix_seed 2025 \
80 | --use_norm 1 \
81 | --wv "db1" \
82 | --m 3 \
83 | --enc_in 321 \
84 | --dec_in 321 \
85 | --c_out 321 \
86 | --des 'Exp' \
87 | --itr 3 \
88 | --alpha 0.0 \
89 | --l1_weight 5e-5
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --lradj 'TST' \
94 | --patience 3 \
95 | --root_path ./dataset/electricity/ \
96 | --data_path electricity.csv \
97 | --model_id ECL \
98 | --model "$model_name" \
99 | --data custom \
100 | --features M \
101 | --seq_len 96 \
102 | --pred_len 720 \
103 | --e_layers 1 \
104 | --d_model 256 \
105 | --d_ff 1024 \
106 | --learning_rate 0.006 \
107 | --batch_size 256 \
108 | --fix_seed 2025 \
109 | --use_norm 1 \
110 | --wv "db1" \
111 | --m 3 \
112 | --enc_in 321 \
113 | --dec_in 321 \
114 | --c_out 321 \
115 | --des 'Exp' \
116 | --itr 3 \
117 | --alpha 0.0 \
118 | --l1_weight 5e-5
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ETT/SimpleTM_h1.sh:
-------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=SimpleTM 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --lradj TST \ 7 | --patience 3 \ 8 | --root_path ./dataset/ETT-small/ \ 9 | --data_path ETTh1.csv \ 10 | --model_id ETTh1 \ 11 | --model $model_name \ 12 | --data ETTh1 \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 1 \ 17 | --d_model 32 \ 18 | --d_ff 32 \ 19 | --learning_rate 0.02 \ 20 | --batch_size 256 \ 21 | --fix_seed 2025 \ 22 | --use_norm 1 \ 23 | --wv db1 \ 24 | --m 3 \ 25 | --enc_in 7 \ 26 | --dec_in 7 \ 27 | --c_out 7 \ 28 | --des Exp \ 29 | --itr 3 \ 30 | --alpha 0.3 \ 31 | --l1_weight 0.0005 \ 32 | 33 | python -u run.py \ 34 | --is_training 1 \ 35 | --lradj TST \ 36 | --patience 3 \ 37 | --root_path ./dataset/ETT-small/ \ 38 | --data_path ETTh1.csv \ 39 | --model_id ETTh1 \ 40 | --model $model_name \ 41 | --data ETTh1 \ 42 | --features M \ 43 | --seq_len 96 \ 44 | --pred_len 192 \ 45 | --e_layers 1 \ 46 | --d_model 32 \ 47 | --d_ff 32 \ 48 | --learning_rate 0.02 \ 49 | --batch_size 256 \ 50 | --fix_seed 2025 \ 51 | --use_norm 1 \ 52 | --wv db1 \ 53 | --m 3 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --des Exp \ 58 | --itr 3 \ 59 | --alpha 1.0 \ 60 | --l1_weight 5e-05 \ 61 | 62 | python -u run.py \ 63 | --is_training 1 \ 64 | --lradj TST \ 65 | --patience 3 \ 66 | --root_path ./dataset/ETT-small/ \ 67 | --data_path ETTh1.csv \ 68 | --model_id ETTh1 \ 69 | --model $model_name \ 70 | --data ETTh1 \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --pred_len 336 \ 74 | --e_layers 4 \ 75 | --d_model 64 \ 76 | --d_ff 64 \ 77 | --learning_rate 0.002 \ 78 | --batch_size 256 \ 79 | --fix_seed 2025 \ 80 | --use_norm 1 \ 81 | --wv db1 \ 82 | --m 3 \ 83 | --enc_in 7 \ 84 | --dec_in 7 \ 85 | --c_out 7 \ 86 | --des Exp \ 87 | --itr 3 \ 88 | --alpha 0.0 \ 89 | --l1_weight 0.0 \ 90 | 91 | python -u run.py \ 92 | --is_training 1 \ 93 | --lradj TST \ 94 | --patience 3 \ 95 | --root_path ./dataset/ETT-small/ \ 96 | --data_path ETTh1.csv \ 97 | --model_id ETTh1 \ 98 | --model $model_name \ 99 | --data ETTh1 \ 100 | --features M \ 101 | --seq_len 96 \ 102 | --pred_len 720 \ 103 | --e_layers 1 \ 104 | --d_model 32 \ 105 | --d_ff 32 \ 106 | --learning_rate 0.009 \ 107 | --batch_size 256 \ 108 | --fix_seed 2025 \ 109 | --use_norm 1 \ 110 | --wv db1 \ 111 | --m 1 \ 112 | --enc_in 7 \ 113 | --dec_in 7 \ 114 | --c_out 7 \ 115 | --des Exp \ 116 | --itr 3 \ 117 | --alpha 0.9 \ 118 | --l1_weight 0.0005 \ -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/ETT/SimpleTM_h2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=SimpleTM 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --lradj TST \ 7 | --patience 3 \ 8 | --root_path ./dataset/ETT-small/ \ 9 | --data_path ETTh2.csv \ 10 | --model_id ETTh2 \ 11 | --model $model_name \ 12 | --data ETTh2 \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 1 \ 17 | --d_model 32 \ 18 | --d_ff 32 \ 19 | --learning_rate 0.006 \ 20 | --batch_size 256 \ 21 | --fix_seed 2025 \ 22 | --use_norm 1 \ 23 | --wv bior3.1 \ 24 | --m 1 \ 25 | --enc_in 7 \ 26 | --dec_in 7 \ 27 | --c_out 7 \ 28 | --des Exp \ 29 | --itr 3 \ 30 | --alpha 0.1 \ 31 | --l1_weight 0.0005 \ 32 | 33 | python -u run.py \ 34 | --is_training 1 \ 35 | --lradj TST \ 36 | --patience 3 \ 37 | --root_path 
./dataset/ETT-small/ \ 38 | --data_path ETTh2.csv \ 39 | --model_id ETTh2 \ 40 | --model $model_name \ 41 | --data ETTh2 \ 42 | --features M \ 43 | --seq_len 96 \ 44 | --pred_len 192 \ 45 | --e_layers 1 \ 46 | --d_model 32 \ 47 | --d_ff 32 \ 48 | --learning_rate 0.006 \ 49 | --batch_size 256 \ 50 | --fix_seed 2025 \ 51 | --use_norm 1 \ 52 | --wv db1 \ 53 | --m 1 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --des Exp \ 58 | --itr 3 \ 59 | --alpha 0.1 \ 60 | --l1_weight 0.005 \ 61 | 62 | python -u run.py \ 63 | --is_training 1 \ 64 | --lradj TST \ 65 | --patience 3 \ 66 | --root_path ./dataset/ETT-small/ \ 67 | --data_path ETTh2.csv \ 68 | --model_id ETTh2 \ 69 | --model $model_name \ 70 | --data ETTh2 \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --pred_len 336 \ 74 | --e_layers 1 \ 75 | --d_model 32 \ 76 | --d_ff 32 \ 77 | --learning_rate 0.003 \ 78 | --batch_size 256 \ 79 | --fix_seed 2025 \ 80 | --use_norm 1 \ 81 | --wv db1 \ 82 | --m 1 \ 83 | --enc_in 7 \ 84 | --dec_in 7 \ 85 | --c_out 7 \ 86 | --des Exp \ 87 | --itr 3 \ 88 | --alpha 0.9 \ 89 | --l1_weight 0.0 \ 90 | 91 | python -u run.py \ 92 | --is_training 1 \ 93 | --lradj TST \ 94 | --patience 3 \ 95 | --root_path ./dataset/ETT-small/ \ 96 | --data_path ETTh2.csv \ 97 | --model_id ETTh2 \ 98 | --model $model_name \ 99 | --data ETTh2 \ 100 | --features M \ 101 | --seq_len 96 \ 102 | --pred_len 720 \ 103 | --e_layers 1 \ 104 | --d_model 32 \ 105 | --d_ff 32 \ 106 | --learning_rate 0.003 \ 107 | --batch_size 256 \ 108 | --fix_seed 2025 \ 109 | --use_norm 1 \ 110 | --wv db1 \ 111 | --m 1 \ 112 | --enc_in 7 \ 113 | --dec_in 7 \ 114 | --c_out 7 \ 115 | --des Exp \ 116 | --itr 3 \ 117 | --alpha 1.0 \ 118 | --l1_weight 5e-05 \ -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/ETT/SimpleTM_m1.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=SimpleTM 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --lradj 'TST' \ 7 | --patience 3 \ 8 | --root_path ./dataset/ETT-small/ \ 9 | --data_path ETTm1.csv \ 10 | --model_id ETTm1 \ 11 | --model "$model_name" \ 12 | --data ETTm1 \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 1 \ 17 | --d_model 32 \ 18 | --d_ff 32 \ 19 | --learning_rate 0.02 \ 20 | --batch_size 256 \ 21 | --fix_seed 2025 \ 22 | --use_norm 1 \ 23 | --wv 'db1' \ 24 | --m 3 \ 25 | --enc_in 7 \ 26 | --dec_in 7 \ 27 | --c_out 7 \ 28 | --des 'Exp' \ 29 | --itr 3 \ 30 | --alpha 0.1 \ 31 | --l1_weight 0.005 32 | 33 | python -u run.py \ 34 | --is_training 1 \ 35 | --lradj 'TST' \ 36 | --patience 3 \ 37 | --root_path ./dataset/ETT-small/ \ 38 | --data_path ETTm1.csv \ 39 | --model_id ETTm1 \ 40 | --model "$model_name" \ 41 | --data ETTm1 \ 42 | --features M \ 43 | --seq_len 96 \ 44 | --pred_len 192 \ 45 | --e_layers 1 \ 46 | --d_model 32 \ 47 | --d_ff 32 \ 48 | --learning_rate 0.02 \ 49 | --batch_size 256 \ 50 | --fix_seed 2025 \ 51 | --use_norm 1 \ 52 | --wv 'db1' \ 53 | --m 3 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --des 'Exp' \ 58 | --itr 3 \ 59 | --alpha 0.1 \ 60 | --l1_weight 0.005 61 | 62 | python -u run.py \ 63 | --is_training 1 \ 64 | --lradj 'TST' \ 65 | --patience 3 \ 66 | --root_path ./dataset/ETT-small/ \ 67 | --data_path ETTm1.csv \ 68 | --model_id ETTm1 \ 69 | --model "$model_name" \ 70 | --data ETTm1 \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --pred_len 336 \ 74 | --e_layers 1 \ 75 | --d_model 32 \ 76 | --d_ff 
32 \ 77 | --learning_rate 0.02 \ 78 | --batch_size 256 \ 79 | --fix_seed 2025 \ 80 | --use_norm 1 \ 81 | --wv 'db1' \ 82 | --m 1 \ 83 | --enc_in 7 \ 84 | --dec_in 7 \ 85 | --c_out 7 \ 86 | --des 'Exp' \ 87 | --itr 3 \ 88 | --alpha 0.1 \ 89 | --l1_weight 0.005 90 | 91 | python -u run.py \ 92 | --is_training 1 \ 93 | --lradj 'TST' \ 94 | --patience 3 \ 95 | --root_path ./dataset/ETT-small/ \ 96 | --data_path ETTm1.csv \ 97 | --model_id ETTm1 \ 98 | --model "$model_name" \ 99 | --data ETTm1 \ 100 | --features M \ 101 | --seq_len 96 \ 102 | --pred_len 720 \ 103 | --e_layers 1 \ 104 | --d_model 32 \ 105 | --d_ff 32 \ 106 | --learning_rate 0.02 \ 107 | --batch_size 256 \ 108 | --fix_seed 2025 \ 109 | --use_norm 1 \ 110 | --wv 'db1' \ 111 | --m 3 \ 112 | --enc_in 7 \ 113 | --dec_in 7 \ 114 | --c_out 7 \ 115 | --des 'Exp' \ 116 | --itr 3 \ 117 | --alpha 0.1 \ 118 | --l1_weight 0.005 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/ETT/SimpleTM_m2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=SimpleTM 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --lradj 'TST' \ 7 | --patience 3 \ 8 | --root_path ./dataset/ETT-small/ \ 9 | --data_path ETTm2.csv \ 10 | --model_id ETTm2 \ 11 | --model "$model_name" \ 12 | --data ETTm2 \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 1 \ 17 | --d_model 32 \ 18 | --d_ff 32 \ 19 | --learning_rate 0.006\ 20 | --batch_size 256 \ 21 | --fix_seed 2025 \ 22 | --use_norm 1 \ 23 | --wv "bior3.1" \ 24 | --m 3 \ 25 | --enc_in 7 \ 26 | --dec_in 7 \ 27 | --c_out 7 \ 28 | --des 'Exp' \ 29 | --itr 3 \ 30 | --alpha 0.3 \ 31 | --l1_weight 0.0005 32 | 33 | python -u run.py \ 34 | --is_training 1 \ 35 | --lradj 'TST' \ 36 | --patience 3 \ 37 | --root_path ./dataset/ETT-small/ \ 38 | --data_path ETTm2.csv \ 39 | --model_id ETTm2 \ 40 | --model "$model_name" \ 41 | --data ETTm2 \ 42 | --features M \ 43 | --seq_len 96 \ 44 | --pred_len 192 \ 45 | --e_layers 1 \ 46 | --d_model 32 \ 47 | --d_ff 32 \ 48 | --learning_rate 0.006\ 49 | --batch_size 256 \ 50 | --fix_seed 2025 \ 51 | --use_norm 1 \ 52 | --wv "bior3.1" \ 53 | --m 1 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --des 'Exp' \ 58 | --itr 3 \ 59 | --alpha 0.0 \ 60 | --l1_weight 0.005 61 | 62 | python -u run.py \ 63 | --is_training 1 \ 64 | --lradj 'TST' \ 65 | --patience 3 \ 66 | --root_path ./dataset/ETT-small/ \ 67 | --data_path ETTm2.csv \ 68 | --model_id ETTm2 \ 69 | --model "$model_name" \ 70 | --data ETTm2 \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --pred_len 336 \ 74 | --e_layers 1 \ 75 | --d_model 64 \ 76 | --d_ff 64 \ 77 | --learning_rate 0.006\ 78 | --batch_size 128 \ 79 | --fix_seed 2025 \ 80 | --use_norm 1 \ 81 | --wv "bior3.3" \ 82 | --m 1 \ 83 | --enc_in 7 \ 84 | --dec_in 7 \ 85 | --c_out 7 \ 86 | --des 'Exp' \ 87 | --itr 3 \ 88 | --alpha 0.6 \ 89 | --l1_weight 5e-5 90 | 91 | python -u run.py \ 92 | --is_training 1 \ 93 | --lradj 'TST' \ 94 | --patience 3 \ 95 | --root_path ./dataset/ETT-small/ \ 96 | --data_path ETTm2.csv \ 97 | --model_id ETTm2 \ 98 | --model "$model_name" \ 99 | --data ETTm2 \ 100 | --features M \ 101 | --seq_len 96 \ 102 | --pred_len 720 \ 103 | --e_layers 1 \ 104 | --d_model 96 \ 105 | --d_ff 96 \ 106 | --learning_rate 0.003\ 107 | --batch_size 256 \ 108 | --fix_seed 2025 \ 109 | --use_norm 1 \ 110 | --wv "db1" \ 111 | --m 3 \ 112 | --enc_in 7 \ 113 | --dec_in 7 \ 114 | --c_out 7 \ 115 | --des 
'Exp' \ 116 | --itr 3 \ 117 | --alpha 1.0 \ 118 | --l1_weight 0.0 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/PEMS/SimpleTM_03.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=SimpleTM 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --lradj 'TST' \ 7 | --patience 10 \ 8 | --train_epochs 20 \ 9 | --root_path ./dataset/PEMS/ \ 10 | --data_path PEMS03.npz \ 11 | --model_id PEMS03 \ 12 | --model "$model_name" \ 13 | --data PEMS \ 14 | --features M \ 15 | --seq_len 96 \ 16 | --pred_len 12 \ 17 | --e_layers 1 \ 18 | --d_model 256 \ 19 | --d_ff 512 \ 20 | --learning_rate 0.002 \ 21 | --batch_size 16 \ 22 | --fix_seed 2025 \ 23 | --use_norm 1 \ 24 | --wv 'bior3.1' \ 25 | --m 3 \ 26 | --enc_in 358 \ 27 | --dec_in 358 \ 28 | --c_out 358 \ 29 | --des 'Exp' \ 30 | --itr 3 \ 31 | --alpha 0.1 \ 32 | --use_norm 0 \ 33 | --l1_weight 0.005 34 | 35 | python -u run.py \ 36 | --is_training 1 \ 37 | --lradj 'TST' \ 38 | --patience 10 \ 39 | --train_epochs 20 \ 40 | --root_path ./dataset/PEMS/ \ 41 | --data_path PEMS03.npz \ 42 | --model_id PEMS03 \ 43 | --model "$model_name" \ 44 | --data PEMS \ 45 | --features M \ 46 | --seq_len 96 \ 47 | --pred_len 24 \ 48 | --e_layers 1 \ 49 | --d_model 256 \ 50 | --d_ff 512 \ 51 | --learning_rate 0.002 \ 52 | --batch_size 16 \ 53 | --fix_seed 2025 \ 54 | --use_norm 1 \ 55 | --wv 'bior3.1' \ 56 | --m 3 \ 57 | --enc_in 358 \ 58 | --dec_in 358 \ 59 | --c_out 358 \ 60 | --des 'Exp' \ 61 | --itr 3 \ 62 | --alpha 0.1 \ 63 | --use_norm 0 \ 64 | --l1_weight 0.005 65 | 66 | python -u run.py \ 67 | --is_training 1 \ 68 | --lradj 'TST' \ 69 | --patience 10 \ 70 | --train_epochs 20 \ 71 | --root_path ./dataset/PEMS/ \ 72 | --data_path PEMS03.npz \ 73 | --model_id PEMS03 \ 74 | --model "$model_name" \ 75 | --data PEMS \ 76 | --features M \ 77 | --seq_len 96 \ 78 | --pred_len 48 \ 79 | --e_layers 1 \ 80 | --d_model 256 \ 81 | --d_ff 1024 \ 82 | --learning_rate 0.002 \ 83 | --batch_size 16 \ 84 | --fix_seed 2025 \ 85 | --use_norm 0 \ 86 | --wv 'bior3.1' \ 87 | --m 3 \ 88 | --enc_in 358 \ 89 | --dec_in 358 \ 90 | --c_out 358 \ 91 | --des 'Exp' \ 92 | --itr 3 \ 93 | --alpha 0.1 \ 94 | --l1_weight 0.005 95 | 96 | python -u run.py \ 97 | --is_training 1 \ 98 | --lradj 'TST' \ 99 | --patience 10 \ 100 | --train_epochs 20 \ 101 | --root_path ./dataset/PEMS/ \ 102 | --data_path PEMS03.npz \ 103 | --model_id PEMS03 \ 104 | --model "$model_name" \ 105 | --data PEMS \ 106 | --features M \ 107 | --seq_len 96 \ 108 | --pred_len 96 \ 109 | --e_layers 1 \ 110 | --d_model 256 \ 111 | --d_ff 1024 \ 112 | --learning_rate 0.002 \ 113 | --batch_size 16 \ 114 | --fix_seed 2025 \ 115 | --use_norm 0 \ 116 | --wv 'bior3.1' \ 117 | --m 3 \ 118 | --enc_in 358 \ 119 | --dec_in 358 \ 120 | --c_out 358 \ 121 | --des 'Exp' \ 122 | --itr 3 \ 123 | --alpha 0.1 \ 124 | --l1_weight 0.005 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/PEMS/SimpleTM_04.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=SimpleTM 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --lradj 'TST' \ 7 | --patience 10 \ 8 | --train_epochs 20 \ 9 | --root_path ./dataset/PEMS/ \ 10 | --data_path PEMS04.npz \ 11 | --model_id PEMS04 \ 12 | --model "$model_name" \ 13 | --data PEMS \ 14 | --features M \ 15 | --seq_len 96 \ 16 | --pred_len 12 \ 
17 | --e_layers 2 \ 18 | --d_model 256 \ 19 | --d_ff 1024 \ 20 | --learning_rate 0.002 \ 21 | --batch_size 16 \ 22 | --fix_seed 2025 \ 23 | --use_norm 0 \ 24 | --wv 'bior3.1' \ 25 | --m 3 \ 26 | --enc_in 307 \ 27 | --dec_in 307 \ 28 | --c_out 307 \ 29 | --des 'Exp' \ 30 | --itr 3 \ 31 | --alpha 0.1 \ 32 | --l1_weight 5e-05 33 | 34 | python -u run.py \ 35 | --is_training 1 \ 36 | --lradj 'TST' \ 37 | --patience 10 \ 38 | --train_epochs 20 \ 39 | --root_path ./dataset/PEMS/ \ 40 | --data_path PEMS04.npz \ 41 | --model_id PEMS04 \ 42 | --model "$model_name" \ 43 | --data PEMS \ 44 | --features M \ 45 | --seq_len 96 \ 46 | --pred_len 24 \ 47 | --e_layers 1 \ 48 | --d_model 256 \ 49 | --d_ff 1024 \ 50 | --learning_rate 0.002 \ 51 | --batch_size 16 \ 52 | --fix_seed 2025 \ 53 | --use_norm 0 \ 54 | --wv 'bior3.1' \ 55 | --m 3 \ 56 | --enc_in 307 \ 57 | --dec_in 307 \ 58 | --c_out 307 \ 59 | --des 'Exp' \ 60 | --itr 3 \ 61 | --alpha 0.1 \ 62 | --l1_weight 5e-05 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --lradj 'TST' \ 67 | --patience 10 \ 68 | --train_epochs 20 \ 69 | --root_path ./dataset/PEMS/ \ 70 | --data_path PEMS04.npz \ 71 | --model_id PEMS04 \ 72 | --model "$model_name" \ 73 | --data PEMS \ 74 | --features M \ 75 | --seq_len 96 \ 76 | --pred_len 48 \ 77 | --e_layers 1 \ 78 | --d_model 256 \ 79 | --d_ff 1024 \ 80 | --learning_rate 0.002 \ 81 | --batch_size 16 \ 82 | --fix_seed 2025 \ 83 | --use_norm 0 \ 84 | --wv 'bior3.1' \ 85 | --m 3 \ 86 | --enc_in 307 \ 87 | --dec_in 307 \ 88 | --c_out 307 \ 89 | --des 'Exp' \ 90 | --itr 3 \ 91 | --alpha 0.1 \ 92 | --l1_weight 5e-05 93 | 94 | python -u run.py \ 95 | --is_training 1 \ 96 | --lradj 'TST' \ 97 | --patience 10 \ 98 | --train_epochs 20 \ 99 | --root_path ./dataset/PEMS/ \ 100 | --data_path PEMS04.npz \ 101 | --model_id PEMS04 \ 102 | --model "$model_name" \ 103 | --data PEMS \ 104 | --features M \ 105 | --seq_len 96 \ 106 | --pred_len 96 \ 107 | --e_layers 1 \ 108 | --d_model 256 \ 109 | --d_ff 1024 \ 110 | --learning_rate 0.002 \ 111 | --batch_size 16 \ 112 | --fix_seed 2025 \ 113 | --use_norm 0 \ 114 | --wv 'bior3.1' \ 115 | --m 3 \ 116 | --enc_in 307 \ 117 | --dec_in 307 \ 118 | --c_out 307 \ 119 | --des 'Exp' \ 120 | --itr 3 \ 121 | --alpha 0.1 \ 122 | --l1_weight 5e-05 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/PEMS/SimpleTM_07.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=SimpleTM 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --lradj 'TST' \ 7 | --patience 10 \ 8 | --train_epochs 20 \ 9 | --root_path ./dataset/PEMS/ \ 10 | --data_path PEMS07.npz \ 11 | --model_id PEMS07 \ 12 | --model "$model_name" \ 13 | --data PEMS \ 14 | --features M \ 15 | --seq_len 96 \ 16 | --pred_len 12 \ 17 | --e_layers 1 \ 18 | --d_model 256 \ 19 | --d_ff 512 \ 20 | --learning_rate 0.002 \ 21 | --batch_size 16 \ 22 | --fix_seed 2025 \ 23 | --use_norm 0 \ 24 | --wv 'db1' \ 25 | --m 3 \ 26 | --enc_in 883 \ 27 | --dec_in 883 \ 28 | --c_out 883 \ 29 | --des 'Exp' \ 30 | --itr 3 \ 31 | --alpha 0.1 \ 32 | --l1_weight 5e-05 33 | 34 | python -u run.py \ 35 | --is_training 1 \ 36 | --lradj 'TST' \ 37 | --patience 10 \ 38 | --train_epochs 20 \ 39 | --root_path ./dataset/PEMS/ \ 40 | --data_path PEMS07.npz \ 41 | --model_id PEMS07 \ 42 | --model "$model_name" \ 43 | --data PEMS \ 44 | --features M \ 45 | --seq_len 96 \ 46 | --pred_len 24 \ 47 | --e_layers 1 \ 48 | --d_model 256 \ 49 | --d_ff 
512 \ 50 | --learning_rate 0.002 \ 51 | --batch_size 16 \ 52 | --fix_seed 2025 \ 53 | --use_norm 0 \ 54 | --wv 'db1' \ 55 | --m 3 \ 56 | --enc_in 883 \ 57 | --dec_in 883 \ 58 | --c_out 883 \ 59 | --des 'Exp' \ 60 | --itr 3 \ 61 | --alpha 0.1 \ 62 | --l1_weight 5e-5 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --lradj 'TST' \ 67 | --patience 10 \ 68 | --train_epochs 20 \ 69 | --root_path ./dataset/PEMS/ \ 70 | --data_path PEMS07.npz \ 71 | --model_id PEMS07 \ 72 | --model "$model_name" \ 73 | --data PEMS \ 74 | --features M \ 75 | --seq_len 96 \ 76 | --pred_len 48 \ 77 | --e_layers 1 \ 78 | --d_model 256 \ 79 | --d_ff 512 \ 80 | --learning_rate 0.002 \ 81 | --batch_size 16 \ 82 | --fix_seed 2025 \ 83 | --use_norm 0 \ 84 | --wv 'db1' \ 85 | --m 3 \ 86 | --enc_in 883 \ 87 | --dec_in 883 \ 88 | --c_out 883 \ 89 | --des 'Exp' \ 90 | --itr 3 \ 91 | --alpha 0.1 \ 92 | --l1_weight 5e-05 93 | 94 | python -u run.py \ 95 | --is_training 1 \ 96 | --lradj 'TST' \ 97 | --patience 10 \ 98 | --train_epochs 20 \ 99 | --root_path ./dataset/PEMS/ \ 100 | --data_path PEMS07.npz \ 101 | --model_id PEMS07 \ 102 | --model "$model_name" \ 103 | --data PEMS \ 104 | --features M \ 105 | --seq_len 96 \ 106 | --pred_len 96 \ 107 | --e_layers 1 \ 108 | --d_model 256 \ 109 | --d_ff 512 \ 110 | --learning_rate 0.002 \ 111 | --batch_size 16 \ 112 | --fix_seed 2025 \ 113 | --use_norm 0 \ 114 | --wv 'db1' \ 115 | --m 3 \ 116 | --enc_in 883 \ 117 | --dec_in 883 \ 118 | --c_out 883 \ 119 | --des 'Exp' \ 120 | --itr 3 \ 121 | --alpha 0.1 \ 122 | --l1_weight 5e-5 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/PEMS/SimpleTM_08.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=SimpleTM 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --lradj 'TST' \ 7 | --patience 10 \ 8 | --train_epochs 20 \ 9 | --root_path ./dataset/PEMS/ \ 10 | --data_path PEMS08.npz \ 11 | --model_id PEMS08 \ 12 | --model "$model_name" \ 13 | --data PEMS \ 14 | --features M \ 15 | --seq_len 96 \ 16 | --pred_len 12 \ 17 | --e_layers 1 \ 18 | --d_model 256 \ 19 | --d_ff 512 \ 20 | --learning_rate 0.001 \ 21 | --batch_size 16 \ 22 | --fix_seed 2025 \ 23 | --use_norm 0 \ 24 | --wv 'db12' \ 25 | --m 3 \ 26 | --enc_in 170 \ 27 | --dec_in 170 \ 28 | --c_out 170 \ 29 | --des 'Exp' \ 30 | --itr 3 \ 31 | --alpha 0.0 \ 32 | --l1_weight 0.0 33 | 34 | python -u run.py \ 35 | --is_training 1 \ 36 | --lradj 'TST' \ 37 | --patience 10 \ 38 | --train_epochs 20 \ 39 | --root_path ./dataset/PEMS/ \ 40 | --data_path PEMS08.npz \ 41 | --model_id PEMS08 \ 42 | --model "$model_name" \ 43 | --data PEMS \ 44 | --features M \ 45 | --seq_len 96 \ 46 | --pred_len 24 \ 47 | --e_layers 1 \ 48 | --d_model 256 \ 49 | --d_ff 512 \ 50 | --learning_rate 0.001 \ 51 | --batch_size 16 \ 52 | --fix_seed 2025 \ 53 | --use_norm 0 \ 54 | --wv 'db12' \ 55 | --m 3 \ 56 | --enc_in 170 \ 57 | --dec_in 170 \ 58 | --c_out 170 \ 59 | --des 'Exp' \ 60 | --itr 3 \ 61 | --alpha 0.0 \ 62 | --l1_weight 0.0 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --lradj 'TST' \ 67 | --patience 10 \ 68 | --train_epochs 20 \ 69 | --root_path ./dataset/PEMS/ \ 70 | --data_path PEMS08.npz \ 71 | --model_id PEMS08 \ 72 | --model "$model_name" \ 73 | --data PEMS \ 74 | --features M \ 75 | --seq_len 96 \ 76 | --pred_len 48 \ 77 | --e_layers 1 \ 78 | --d_model 256 \ 79 | --d_ff 512 \ 80 | --learning_rate 0.001 \ 81 | --batch_size 16 \ 82 | --fix_seed 
2025 \ 83 | --use_norm 0 \ 84 | --wv 'db12' \ 85 | --m 3 \ 86 | --enc_in 170 \ 87 | --dec_in 170 \ 88 | --c_out 170 \ 89 | --des 'Exp' \ 90 | --itr 3 \ 91 | --alpha 0.0 \ 92 | --l1_weight 0.0 93 | 94 | python -u run.py \ 95 | --is_training 1 \ 96 | --lradj 'TST' \ 97 | --patience 10 \ 98 | --train_epochs 20 \ 99 | --root_path ./dataset/PEMS/ \ 100 | --data_path PEMS08.npz \ 101 | --model_id PEMS08 \ 102 | --model "$model_name" \ 103 | --data PEMS \ 104 | --features M \ 105 | --seq_len 96 \ 106 | --pred_len 96 \ 107 | --e_layers 1 \ 108 | --d_model 256 \ 109 | --d_ff 1024 \ 110 | --learning_rate 0.001 \ 111 | --batch_size 16 \ 112 | --fix_seed 2025 \ 113 | --use_norm 0 \ 114 | --wv 'db12' \ 115 | --m 3 \ 116 | --enc_in 170 \ 117 | --dec_in 170 \ 118 | --c_out 170 \ 119 | --des 'Exp' \ 120 | --itr 3 \ 121 | --alpha 0.0 \ 122 | --l1_weight 0.0 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/SolarEnergy/SimpleTM.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=SimpleTM 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --lradj 'TST' \ 7 | --patience 3 \ 8 | --root_path ./dataset/solar/ \ 9 | --data_path solar_AL.txt \ 10 | --model_id Solar \ 11 | --model "$model_name" \ 12 | --data Solar \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 1 \ 17 | --d_model 128 \ 18 | --d_ff 256 \ 19 | --learning_rate 0.01 \ 20 | --batch_size 256 \ 21 | --fix_seed 2025 \ 22 | --use_norm 0 \ 23 | --wv "db8" \ 24 | --m 3 \ 25 | --enc_in 137 \ 26 | --dec_in 137 \ 27 | --c_out 137 \ 28 | --des 'Exp' \ 29 | --itr 3 \ 30 | --use_norm 0 \ 31 | --alpha 0.0 \ 32 | --l1_weight 0.005 33 | 34 | python -u run.py \ 35 | --is_training 1 \ 36 | --lradj 'TST' \ 37 | --patience 3 \ 38 | --root_path ./dataset/solar/ \ 39 | --data_path solar_AL.txt \ 40 | --model_id Solar \ 41 | --model "$model_name" \ 42 | --data Solar \ 43 | --features M \ 44 | --seq_len 96 \ 45 | --pred_len 192 \ 46 | --e_layers 1 \ 47 | --d_model 128 \ 48 | --d_ff 256 \ 49 | --learning_rate 0.003 \ 50 | --batch_size 256 \ 51 | --fix_seed 2025 \ 52 | --use_norm 0 \ 53 | --wv "db8" \ 54 | --m 1 \ 55 | --enc_in 137 \ 56 | --dec_in 137 \ 57 | --c_out 137 \ 58 | --des 'Exp' \ 59 | --itr 2 \ 60 | --use_norm 0 \ 61 | --alpha 0.0 \ 62 | --l1_weight 0.005 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --lradj 'TST' \ 67 | --patience 3 \ 68 | --root_path ./dataset/solar/ \ 69 | --data_path solar_AL.txt \ 70 | --model_id Solar \ 71 | --model "$model_name" \ 72 | --data Solar \ 73 | --features M \ 74 | --seq_len 96 \ 75 | --pred_len 336 \ 76 | --e_layers 1 \ 77 | --d_model 128 \ 78 | --d_ff 256 \ 79 | --learning_rate 0.003 \ 80 | --batch_size 256 \ 81 | --fix_seed 2025 \ 82 | --use_norm 0 \ 83 | --wv "db8" \ 84 | --m 1 \ 85 | --enc_in 137 \ 86 | --dec_in 137 \ 87 | --c_out 137 \ 88 | --des 'Exp' \ 89 | --itr 2 \ 90 | --use_norm 0 \ 91 | --alpha 0.1 \ 92 | --l1_weight 0.005 93 | 94 | python -u run.py \ 95 | --is_training 1 \ 96 | --lradj 'TST' \ 97 | --patience 3 \ 98 | --root_path ./dataset/solar/ \ 99 | --data_path solar_AL.txt \ 100 | --model_id Solar \ 101 | --model "$model_name" \ 102 | --data Solar \ 103 | --features M \ 104 | --seq_len 96 \ 105 | --pred_len 720 \ 106 | --e_layers 1 \ 107 | --d_model 128 \ 108 | --d_ff 256 \ 109 | --learning_rate 0.009 \ 110 | --batch_size 256 \ 111 | --fix_seed 2025 \ 112 | --use_norm 0 \ 113 | --wv "db8" \ 114 | --m 1 \ 115 | --enc_in 137 \ 
116 | --dec_in 137 \
117 | --c_out 137 \
118 | --des 'Exp' \
119 | --itr 3 \
120 | --use_norm 0 \
121 | --alpha 0.0 \
122 | --l1_weight 0.005
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/Traffic/SimpleTM.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 3 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id Traffic \
11 | --model "$model_name" \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --d_model 512 \
18 | --d_ff 1024 \
19 | --learning_rate 0.003 \
20 | --batch_size 24 \
21 | --fix_seed 2025 \
22 | --use_norm 1 \
23 | --wv "db1" \
24 | --m 3 \
25 | --enc_in 862 \
26 | --dec_in 862 \
27 | --c_out 862 \
28 | --des 'Exp' \
29 | --itr 1 \
30 | --alpha 0.1 \
31 | --l1_weight 0.0 \
32 |
33 | python -u run.py \
34 | --is_training 1 \
35 | --lradj 'TST' \
36 | --patience 3 \
37 | --root_path ./dataset/traffic/ \
38 | --data_path traffic.csv \
39 | --model_id Traffic \
40 | --model "$model_name" \
41 | --data custom \
42 | --features M \
43 | --seq_len 96 \
44 | --pred_len 192 \
45 | --e_layers 1 \
46 | --d_model 1024 \
47 | --d_ff 2048 \
48 | --learning_rate 0.0005 \
49 | --batch_size 32 \
50 | --fix_seed 2025 \
51 | --use_norm 1 \
52 | --wv "db1" \
53 | --m 1 \
54 | --enc_in 862 \
55 | --dec_in 862 \
56 | --c_out 862 \
57 | --des 'Exp' \
58 | --itr 1 \
59 | --alpha 0.1 \
60 | --l1_weight 0.0 \
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --lradj 'TST' \
65 | --patience 3 \
66 | --root_path ./dataset/traffic/ \
67 | --data_path traffic.csv \
68 | --model_id Traffic \
69 | --model "$model_name" \
70 | --data custom \
71 | --features M \
72 | --seq_len 96 \
73 | --pred_len 336 \
74 | --e_layers 1 \
75 | --d_model 1024 \
76 | --d_ff 2048 \
77 | --learning_rate 0.0005 \
78 | --batch_size 32 \
79 | --fix_seed 2025 \
80 | --use_norm 1 \
81 | --wv "db1" \
82 | --m 1 \
83 | --enc_in 862 \
84 | --dec_in 862 \
85 | --c_out 862 \
86 | --des 'Exp' \
87 | --itr 1 \
88 | --alpha 0.1 \
89 | --l1_weight 0.0 \
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --lradj 'TST' \
94 | --patience 3 \
95 | --root_path ./dataset/traffic/ \
96 | --data_path traffic.csv \
97 | --model_id Traffic \
98 | --model "$model_name" \
99 | --data custom \
100 | --features M \
101 | --seq_len 96 \
102 | --pred_len 720 \
103 | --e_layers 1 \
104 | --d_model 1024 \
105 | --d_ff 2048 \
106 | --learning_rate 0.0005 \
107 | --batch_size 32 \
108 | --fix_seed 2025 \
109 | --use_norm 1 \
110 | --wv "db1" \
111 | --m 1 \
112 | --enc_in 862 \
113 | --dec_in 862 \
114 | --c_out 862 \
115 | --des 'Exp' \
116 | --itr 1 \
117 | --alpha 0.1 \
118 | --l1_weight 0.0 \
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/Weather/SimpleTM.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 3 \
8 | --root_path ./dataset/weather/ \
9 | --data_path weather.csv \
10 | --model_id Weather \
11 | --model "$model_name" \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 4 \
17 | --d_model 32 \
18 | --d_ff 32 \
19 | 
--learning_rate 0.01 \ 20 | --batch_size 256 \ 21 | --fix_seed 2025 \ 22 | --use_norm 1 \ 23 | --wv "db4" \ 24 | --m 1 \ 25 | --enc_in 21 \ 26 | --dec_in 21 \ 27 | --c_out 21 \ 28 | --des 'Exp' \ 29 | --itr 3 \ 30 | --alpha 0.3 \ 31 | --l1_weight 5e-05 32 | 33 | python -u run.py \ 34 | --is_training 1 \ 35 | --lradj 'TST' \ 36 | --patience 3 \ 37 | --root_path ./dataset/weather/ \ 38 | --data_path weather.csv \ 39 | --model_id Weather \ 40 | --model "$model_name" \ 41 | --data custom \ 42 | --features M \ 43 | --seq_len 96 \ 44 | --pred_len 192 \ 45 | --e_layers 4 \ 46 | --d_model 32 \ 47 | --d_ff 32 \ 48 | --learning_rate 0.009 \ 49 | --batch_size 256 \ 50 | --fix_seed 2025 \ 51 | --use_norm 1 \ 52 | --wv "db4" \ 53 | --m 1 \ 54 | --enc_in 21 \ 55 | --dec_in 21 \ 56 | --c_out 21 \ 57 | --des 'Exp' \ 58 | --itr 3 \ 59 | --alpha 0.3 \ 60 | --l1_weight 0.0 61 | 62 | python -u run.py \ 63 | --is_training 1 \ 64 | --lradj 'TST' \ 65 | --patience 3 \ 66 | --root_path ./dataset/weather/ \ 67 | --data_path weather.csv \ 68 | --model_id Weather \ 69 | --model "$model_name" \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --pred_len 336 \ 74 | --e_layers 1 \ 75 | --d_model 32 \ 76 | --d_ff 32 \ 77 | --learning_rate 0.009 \ 78 | --batch_size 256 \ 79 | --fix_seed 2025 \ 80 | --use_norm 1 \ 81 | --wv "db4" \ 82 | --m 3 \ 83 | --enc_in 21 \ 84 | --dec_in 21 \ 85 | --c_out 21 \ 86 | --des 'Exp' \ 87 | --itr 3 \ 88 | --alpha 1.0 \ 89 | --l1_weight 5e-05 90 | 91 | python -u run.py \ 92 | --is_training 1 \ 93 | --lradj 'TST' \ 94 | --patience 3 \ 95 | --root_path ./dataset/weather/ \ 96 | --data_path weather.csv \ 97 | --model_id Weather \ 98 | --model "$model_name" \ 99 | --data custom \ 100 | --features M \ 101 | --seq_len 96 \ 102 | --pred_len 720 \ 103 | --e_layers 1 \ 104 | --d_model 32 \ 105 | --d_ff 32 \ 106 | --learning_rate 0.02 \ 107 | --batch_size 256 \ 108 | --fix_seed 2025 \ 109 | --use_norm 1 \ 110 | --wv "db4" \ 111 | --m 1 \ 112 | --enc_in 21 \ 113 | --dec_in 21 \ 114 | --c_out 21 \ 115 | --des 'Exp' \ 116 | --itr 3 \ 117 | --alpha 0.9 \ 118 | --l1_weight 0.005 -------------------------------------------------------------------------------- /utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TriangularCausalMask(): 5 | def __init__(self, B, L, device="cpu"): 6 | mask_shape = [B, 1, L, L] 7 | with torch.no_grad(): 8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 9 | 10 | @property 11 | def mask(self): 12 | return self._mask 13 | 14 | 15 | class ProbMask(): 16 | def __init__(self, B, H, L, index, scores, device="cpu"): 17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 19 | indicator = _mask_ex[torch.arange(B)[:, None, None], 20 | torch.arange(H)[None, :, None], 21 | index, :].to(device) 22 | self._mask = indicator.view(scores.shape).to(device) 23 | 24 | @property 25 | def mask(self): 26 | return self._mask 27 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def RSE(pred, true): 5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 6 | 7 | 8 | def CORR(pred, true): 9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 10 | d = 
np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 11 | return (u / d).mean(-1) 12 | 13 | 14 | def MAE(pred, true): 15 | return np.mean(np.abs(pred - true)) 16 | 17 | 18 | def MSE(pred, true): 19 | return np.mean((pred - true) ** 2) 20 | 21 | 22 | def RMSE(pred, true): 23 | return np.sqrt(MSE(pred, true)) 24 | 25 | # MAPE variant adopted while troubleshooting PEMS (Nov 8); the original definition is kept below for reference. 26 | # def MAPE(pred, true): 27 | # return np.mean(np.abs((pred - true) / true)) 28 | def MAPE(pred, true): 29 | mape = np.abs((pred - true) / true) 30 | mape = np.where(mape > 5, 0, mape)  # discard ratios above 5: near-zero ground truth would blow MAPE up 31 | return np.mean(mape) 32 | 33 | 34 | def MSPE(pred, true): 35 | return np.mean(np.square((pred - true) / true)) 36 | 37 | 38 | def metric(pred, true): 39 | mae = MAE(pred, true) 40 | mse = MSE(pred, true) 41 | rmse = RMSE(pred, true) 42 | mape = MAPE(pred, true) 43 | mspe = MSPE(pred, true) 44 | 45 | return mae, mse, rmse, mape, mspe 46 | -------------------------------------------------------------------------------- /utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | # From: gluonts/src/gluonts/time_feature/_base.py 2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"). 5 | # You may not use this file except in compliance with the License. 6 | # A copy of the License is located at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # or in the "license" file accompanying this file. This file is distributed 11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 | # express or implied. See the License for the specific language governing 13 | # permissions and limitations under the License. 14 | 15 | from typing import List 16 | 17 | import numpy as np 18 | import pandas as pd 19 | from pandas.tseries import offsets 20 | from pandas.tseries.frequencies import to_offset 21 | 22 | 23 | class TimeFeature: 24 | def __init__(self): 25 | pass 26 | 27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 28 | pass 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ + "()" 32 | 33 | 34 | class SecondOfMinute(TimeFeature): 35 | """Second of minute encoded as value between [-0.5, 0.5]""" 36 | 37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 38 | return index.second / 59.0 - 0.5 39 | 40 | 41 | class MinuteOfHour(TimeFeature): 42 | """Minute of hour encoded as value between [-0.5, 0.5]""" 43 | 44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 45 | return index.minute / 59.0 - 0.5 46 | 47 | 48 | class HourOfDay(TimeFeature): 49 | """Hour of day encoded as value between [-0.5, 0.5]""" 50 | 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return index.hour / 23.0 - 0.5 53 | 54 | 55 | class DayOfWeek(TimeFeature): 56 | """Day of week encoded as value between [-0.5, 0.5]""" 57 | 58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 59 | return index.dayofweek / 6.0 - 0.5 60 | 61 | 62 | class DayOfMonth(TimeFeature): 63 | """Day of month encoded as value between [-0.5, 0.5]""" 64 | 65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 66 | return (index.day - 1) / 30.0 - 0.5 67 | 68 | 69 | class DayOfYear(TimeFeature): 70 | """Day of year encoded as value between [-0.5, 0.5]""" 71 | 72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 73 | return (index.dayofyear - 1) / 365.0 - 0.5 74 | 75 | 76 | class MonthOfYear(TimeFeature): 77 | """Month of year encoded as 
value between [-0.5, 0.5]""" 78 | 79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 80 | return (index.month - 1) / 11.0 - 0.5 81 | 82 | 83 | class WeekOfYear(TimeFeature): 84 | """Week of year encoded as value between [-0.5, 0.5]""" 85 | 86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 87 | return (index.isocalendar().week - 1) / 52.0 - 0.5 88 | 89 | 90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 91 | """ 92 | Returns a list of time features that will be appropriate for the given frequency string. 93 | Parameters 94 | ---------- 95 | freq_str 96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 97 | """ 98 | 99 | features_by_offsets = { 100 | offsets.YearEnd: [], 101 | offsets.QuarterEnd: [MonthOfYear], 102 | offsets.MonthEnd: [MonthOfYear], 103 | offsets.Week: [DayOfMonth, WeekOfYear], 104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 107 | offsets.Minute: [ 108 | MinuteOfHour, 109 | HourOfDay, 110 | DayOfWeek, 111 | DayOfMonth, 112 | DayOfYear, 113 | ], 114 | offsets.Second: [ 115 | SecondOfMinute, 116 | MinuteOfHour, 117 | HourOfDay, 118 | DayOfWeek, 119 | DayOfMonth, 120 | DayOfYear, 121 | ], 122 | } 123 | 124 | offset = to_offset(freq_str) 125 | 126 | for offset_type, feature_classes in features_by_offsets.items(): 127 | if isinstance(offset, offset_type): 128 | return [cls() for cls in feature_classes] 129 | 130 | supported_freq_msg = f""" 131 | Unsupported frequency {freq_str} 132 | The following frequencies are supported: 133 | Y - yearly 134 | alias: A 135 | M - monthly 136 | W - weekly 137 | D - daily 138 | B - business days 139 | H - hourly 140 | T - minutely 141 | alias: min 142 | S - secondly 143 | """ 144 | raise RuntimeError(supported_freq_msg) 145 | 146 | 147 | def time_features(dates, freq='h'): 148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 149 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | 8 | plt.switch_backend('agg') 9 | 10 | 11 | def adjust_learning_rate(optimizer, epoch, args, scheduler=None, printout=True): 12 | lr_adjust = {}  # default guards against a NameError for unrecognized args.lradj values 13 | if args.lradj == 'type1': 14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))} 15 | elif args.lradj == 'type2': 16 | lr_adjust = { 17 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 18 | 10: 5e-7, 15: 1e-7, 20: 5e-8 19 | } 20 | elif args.lradj == 'type3': 21 | lr_adjust = {epoch: args.learning_rate if epoch < 2 else args.learning_rate * (0.5 ** ((epoch - 1) // 1))} 22 | elif args.lradj == 'constant': 23 | lr_adjust = {epoch: args.learning_rate * 1} 24 | elif args.lradj == 'TST': 25 | lr_adjust = {epoch: scheduler.get_last_lr()[0]} 26 | if epoch in lr_adjust: 27 | lr = lr_adjust[epoch] 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = lr 30 | if printout: print('Updating learning rate to {}'.format(lr)) 31 | 32 | 33 | class EarlyStopping: 34 | def __init__(self, patience=7, verbose=False, delta=0): 35 | self.patience = patience 36 | self.verbose = verbose 37 | self.counter = 0 38 | self.best_score = None 39 | self.early_stop = False 40 | 
self.val_loss_min = np.inf 41 | self.delta = delta 42 | 43 | def __call__(self, val_loss, model, path): 44 | score = -val_loss 45 | if self.best_score is None: 46 | self.best_score = score 47 | self.save_checkpoint(val_loss, model, path) 48 | elif score < self.best_score + self.delta: 49 | self.counter += 1 50 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 51 | if self.counter >= self.patience: 52 | self.early_stop = True 53 | else: 54 | self.best_score = score 55 | self.save_checkpoint(val_loss, model, path) 56 | self.counter = 0 57 | 58 | def save_checkpoint(self, val_loss, model, path): 59 | if self.verbose: 60 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 61 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 62 | self.val_loss_min = val_loss 63 | 64 | 65 | class dotdict(dict): 66 | """dot.notation access to dictionary attributes""" 67 | __getattr__ = dict.get 68 | __setattr__ = dict.__setitem__ 69 | __delattr__ = dict.__delitem__ 70 | 71 | 72 | class StandardScaler(): 73 | def __init__(self, mean, std): 74 | self.mean = mean 75 | self.std = std 76 | 77 | def transform(self, data): 78 | return (data - self.mean) / self.std 79 | 80 | def inverse_transform(self, data): 81 | return (data * self.std) + self.mean 82 | 83 | 84 | def visual(true, preds=None, name='./pic/test.pdf'): 85 | """ 86 | Results visualization 87 | """ 88 | plt.figure() 89 | plt.plot(true, label='GroundTruth', linewidth=2) 90 | if preds is not None: 91 | plt.plot(preds, label='Prediction', linewidth=2) 92 | plt.legend() 93 | plt.savefig(name, bbox_inches='tight') 94 | 95 | 96 | def adjustment(gt, pred): 97 | anomaly_state = False 98 | for i in range(len(gt)): 99 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 100 | anomaly_state = True 101 | for j in range(i, 0, -1): 102 | if gt[j] == 0: 103 | break 104 | else: 105 | if pred[j] == 0: 106 | pred[j] = 1 107 | for j in range(i, len(gt)): 108 | if gt[j] == 0: 109 | break 110 | else: 111 | if pred[j] == 0: 112 | pred[j] = 1 113 | elif gt[i] == 0: 114 | anomaly_state = False 115 | if anomaly_state: 116 | pred[i] = 1 117 | return gt, pred 118 | 119 | 120 | def cal_accuracy(y_pred, y_true): 121 | return np.mean(y_pred == y_true) 122 | --------------------------------------------------------------------------------
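
Usage sketches (editorial additions, not files from the repository). First, a minimal look at utils/masking.py: TriangularCausalMask marks the strictly upper-triangular positions of an L-by-L score matrix, i.e. the future timesteps an attention layer should not attend to.

```python
import torch

from utils.masking import TriangularCausalMask

# Batch of 2, sequence length 4; the mask has shape [B, 1, L, L].
m = TriangularCausalMask(B=2, L=4)
print(m.mask[0, 0])
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])   True = position to be masked out
```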
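Next, utils/metrics.py on synthetic data. Note that this repository's MAPE deliberately zeroes out absolute percentage errors above 5 (i.e. 500%), a guard against near-zero ground-truth values such as those in the PEMS datasets; the tensor shapes below are illustrative only.

```python
import numpy as np

from utils.metrics import metric

rng = np.random.default_rng(2025)
true = rng.random((32, 96, 7)) + 0.5            # offset keeps targets away from zero
pred = true + 0.1 * rng.standard_normal(true.shape)

mae, mse, rmse, mape, mspe = metric(pred, true)
print(f"MAE {mae:.4f}  MSE {mse:.4f}  RMSE {rmse:.4f}  MAPE {mape:.4f}")
```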
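The frequency string passed to utils/timefeatures.py decides which TimeFeature columns are produced; for hourly data ('h', the default) the offset table selects HourOfDay, DayOfWeek, DayOfMonth, and DayOfYear, each scaled into [-0.5, 0.5].

```python
import pandas as pd

from utils.timefeatures import time_features

idx = pd.date_range("2024-01-01", periods=8, freq="h")
feats = time_features(idx, freq="h")
print(feats.shape)  # (4, 8): one row per selected feature, one column per timestamp
```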
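Finally, a sketch of how EarlyStopping and adjust_learning_rate from utils/tools.py typically cooperate in a training loop; train_one_epoch, evaluate, model, optimizer, vali_loader, args, and checkpoint_path are hypothetical stand-ins for what the experiment classes wire up.

```python
from utils.tools import EarlyStopping, adjust_learning_rate

early_stopping = EarlyStopping(patience=args.patience, verbose=True)
for epoch in range(args.train_epochs):
    train_one_epoch(model, optimizer)                 # hypothetical helper
    val_loss = evaluate(model, vali_loader)           # hypothetical helper
    early_stopping(val_loss, model, checkpoint_path)  # checkpoints on improvement
    if early_stopping.early_stop:
        print("Early stopping")
        break
    # Note: when args.lradj == 'TST', a scheduler instance must also be passed.
    adjust_learning_rate(optimizer, epoch + 1, args)
```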