├── .gitignore
├── Dockerfile
├── README.md
├── data_provider
│   ├── data_factory.py
│   └── data_loader.py
├── environment.yml
├── experiments
│   ├── exp_basic.py
│   └── exp_long_term_forecasting.py
├── figures
│   ├── Efficiency.jpg
│   ├── Framework.png
│   └── Long_term_forecast_results.jpg
├── layers
│   ├── Embed.py
│   ├── SWTAttention_Family.py
│   ├── StandardNorm.py
│   └── Transformer_Encoder.py
├── model
│   └── SimpleTM.py
├── run.py
├── scripts
│   └── multivariate_forecasting
│       ├── ECL
│       │   └── SimpleTM.sh
│       ├── ETT
│       │   ├── SimpleTM_h1.sh
│       │   ├── SimpleTM_h2.sh
│       │   ├── SimpleTM_m1.sh
│       │   └── SimpleTM_m2.sh
│       ├── PEMS
│       │   ├── SimpleTM_03.sh
│       │   ├── SimpleTM_04.sh
│       │   ├── SimpleTM_07.sh
│       │   └── SimpleTM_08.sh
│       ├── SolarEnergy
│       │   └── SimpleTM.sh
│       ├── Traffic
│       │   └── SimpleTM.sh
│       └── Weather
│           └── SimpleTM.sh
└── utils
    ├── masking.py
    ├── metrics.py
    ├── timefeatures.py
    └── tools.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:22.10-py3
2 |
3 | ENV PYTHONPATH=/workspace
4 | ENV PYTHONUNBUFFERED=1
5 |
6 | RUN pip install --no-cache-dir \
7 | einops==0.8.1 \
8 | matplotlib==3.7.0 \
9 | numpy==1.23.5 \
10 | scikit-learn==1.2.2 \
11 | scipy==1.10.1 \
12 | pandas==1.5.3 \
13 | reformer-pytorch==1.4.4 \
14 | PyWavelets
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimpleTM
2 | This repo is the official implementation for the paper: [[ICLR '25] SimpleTM: A Simple Baseline for Multivariate Time Series Forecasting](https://openreview.net/pdf?id=oANkBaVci5).
3 |
4 |
5 | # Introduction
6 | We propose SimpleTM, a simple yet effective architecture that uniquely integrates classical signal processing ideas with a slightly modified attention mechanism.
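
At the core is a wavelet-based tokenization (see `layers/SWTAttention_Family.py`); going by the layer name and the PyWavelets dependency, this builds on the stationary wavelet transform (SWT). Purely as an illustrative sketch of that transform, not the repo's own code:

```python
import numpy as np
import pywt  # PyWavelets, pinned in environment.yml and the Dockerfile

x = np.sin(np.linspace(0, 8 * np.pi, 96))     # toy series; length must be a multiple of 2**level
coeffs = pywt.swt(x, wavelet="db1", level=3)  # undecimated, shift-invariant decomposition
for cA, cD in coeffs:
    print(cA.shape, cD.shape)                 # every level keeps the original length of 96
```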
7 |
8 |
![Framework](./figures/Framework.png)
11 |
12 | We show that even a single-layer configuration can effectively capture intricate dependencies in multivariate time-series data, while maintaining minimal model complexity and parameter requirements. This streamlined construction achieves a performance profile surpassing (or on par with) most existing baselines across nearly all publicly available benchmarks.
13 |
![Long-term forecasting results](./figures/Long_term_forecast_results.jpg)
19 |
Table 6: Complete results of the long-term forecasting task, with an input length of 96 for all tasks. The reported metrics are Mean Squared Error (MSE) and Mean Absolute Error (MAE), where lower values indicate better model performance; the Avg rows average the four prediction horizons. Each cell shows MSE / MAE, and bold marks the values highlighted in the original table.

| Dataset | Horizon | SimpleTM (Ours) | TimeMixer (2024) | iTransformer (2024) | CrossGNN (2024) | RLinear (2023) | PatchTST (2023) | Crossformer (2023) | TiDE (2023) | TimesNet (2023) | DLinear (2023) | SCINet (2022) | FEDformer (2022) | Stationary (2022) | Autoformer (2021) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ETTm1 | 96 | **0.321** / **0.361** | **0.328** / **0.363** | 0.334 / 0.368 | 0.335 / 0.373 | 0.355 / 0.376 | 0.329 / 0.367 | 0.404 / 0.426 | 0.364 / 0.387 | 0.338 / 0.375 | 0.345 / 0.372 | 0.418 / 0.438 | 0.379 / 0.419 | 0.386 / 0.398 | 0.505 / 0.475 |
| | 192 | **0.360** / **0.380** | **0.364** / **0.384** | 0.377 / 0.391 | 0.372 / 0.390 | 0.391 / 0.392 | 0.367 / 0.385 | 0.450 / 0.451 | 0.398 / 0.404 | 0.374 / 0.387 | 0.380 / 0.389 | 0.439 / 0.450 | 0.426 / 0.441 | 0.459 / 0.444 | 0.553 / 0.496 |
| | 336 | **0.390** / **0.404** | **0.390** / **0.404** | 0.426 / 0.420 | 0.403 / 0.411 | 0.424 / 0.415 | 0.399 / 0.410 | 0.532 / 0.515 | 0.428 / 0.425 | 0.410 / 0.411 | 0.413 / 0.413 | 0.490 / 0.485 | 0.445 / 0.459 | 0.495 / 0.464 | 0.621 / 0.537 |
| | 720 | **0.454** / **0.438** | 0.458 / 0.445 | 0.491 / 0.459 | 0.461 / 0.442 | 0.487 / 0.450 | **0.454** / **0.439** | 0.666 / 0.589 | 0.487 / 0.461 | 0.478 / 0.450 | 0.474 / 0.453 | 0.595 / 0.550 | 0.543 / 0.490 | 0.585 / 0.516 | 0.671 / 0.561 |
| | Avg | **0.381** / **0.396** | **0.385** / **0.399** | 0.407 / 0.410 | 0.393 / 0.404 | 0.414 / 0.407 | 0.387 / 0.400 | 0.513 / 0.496 | 0.419 / 0.419 | 0.400 / 0.406 | 0.403 / 0.407 | 0.485 / 0.481 | 0.448 / 0.452 | 0.481 / 0.456 | 0.588 / 0.517 |
| ETTm2 | 96 | **0.173** / **0.257** | 0.176 / 0.259 | 0.180 / 0.264 | 0.176 / 0.266 | 0.182 / 0.265 | **0.175** / **0.259** | 0.287 / 0.366 | 0.207 / 0.305 | 0.187 / 0.267 | 0.193 / 0.292 | 0.286 / 0.377 | 0.203 / 0.287 | 0.192 / 0.274 | 0.255 / 0.339 |
| | 192 | **0.238** / **0.299** | 0.242 / 0.303 | 0.250 / 0.309 | **0.240** / 0.307 | 0.246 / 0.304 | 0.241 / **0.302** | 0.414 / 0.492 | 0.290 / 0.364 | 0.249 / 0.309 | 0.284 / 0.362 | 0.399 / 0.445 | 0.269 / 0.328 | 0.280 / 0.339 | 0.281 / 0.340 |
| | 336 | **0.296** / **0.338** | **0.304** / **0.342** | 0.311 / 0.348 | 0.304 / 0.345 | 0.307 / 0.342 | 0.305 / 0.343 | 0.597 / 0.542 | 0.377 / 0.422 | 0.321 / 0.351 | 0.369 / 0.427 | 0.637 / 0.591 | 0.325 / 0.366 | 0.334 / 0.361 | 0.339 / 0.372 |
| | 720 | **0.393** / **0.395** | **0.393** / **0.397** | 0.412 / 0.407 | 0.406 / 0.400 | 0.407 / 0.398 | 0.402 / 0.400 | 1.730 / 1.042 | 0.558 / 0.524 | 0.408 / 0.403 | 0.554 / 0.522 | 0.960 / 0.735 | 0.421 / 0.415 | 0.417 / 0.413 | 0.433 / 0.432 |
| | Avg | **0.275** / **0.322** | **0.278** / **0.325** | 0.288 / 0.332 | 0.282 / 0.330 | 0.286 / 0.327 | 0.281 / 0.326 | 0.757 / 0.610 | 0.358 / 0.404 | 0.291 / 0.333 | 0.350 / 0.401 | 0.571 / 0.537 | 0.305 / 0.349 | 0.306 / 0.347 | 0.327 / 0.371 |
| ETTh1 | 96 | **0.366** / **0.392** | 0.381 / 0.401 | 0.386 / 0.405 | 0.382 / 0.398 | 0.386 / **0.395** | 0.414 / 0.419 | 0.423 / 0.448 | 0.479 / 0.464 | 0.384 / 0.402 | 0.386 / 0.400 | 0.654 / 0.599 | **0.376** / 0.419 | 0.513 / 0.491 | 0.449 / 0.459 |
| | 192 | **0.422** / **0.421** | 0.440 / 0.433 | 0.441 / 0.436 | 0.427 / 0.425 | 0.437 / **0.424** | 0.460 / 0.445 | 0.471 / 0.474 | 0.525 / 0.492 | 0.436 / 0.429 | 0.437 / 0.432 | 0.719 / 0.631 | **0.420** / 0.448 | 0.534 / 0.504 | 0.500 / 0.482 |
| | 336 | **0.440** / **0.438** | 0.501 / 0.462 | 0.487 / 0.458 | 0.465 / **0.445** | 0.479 / 0.446 | 0.501 / 0.466 | 0.570 / 0.546 | 0.565 / 0.515 | 0.491 / 0.469 | 0.481 / 0.459 | 0.778 / 0.659 | **0.459** / 0.465 | 0.588 / 0.535 | 0.521 / 0.496 |
| | 720 | **0.463** / **0.462** | 0.501 / 0.482 | 0.503 / 0.491 | **0.472** / **0.468** | 0.481 / 0.470 | 0.500 / 0.488 | 0.653 / 0.621 | 0.594 / 0.558 | 0.521 / 0.500 | 0.519 / 0.516 | 0.836 / 0.699 | 0.506 / 0.507 | 0.643 / 0.616 | 0.514 / 0.512 |
| | Avg | **0.422** / **0.428** | 0.458 / 0.445 | 0.454 / 0.447 | 0.437 / 0.434 | 0.446 / **0.434** | 0.469 / 0.454 | 0.529 / 0.522 | 0.541 / 0.507 | 0.458 / 0.450 | 0.456 / 0.452 | 0.747 / 0.647 | **0.440** / 0.460 | 0.570 / 0.537 | 0.496 / 0.487 |
| ETTh2 | 96 | **0.281** / **0.338** | 0.292 / 0.343 | 0.297 / 0.349 | 0.309 / 0.359 | **0.288** / **0.338** | 0.302 / 0.348 | 0.745 / 0.584 | 0.400 / 0.440 | 0.340 / 0.374 | 0.333 / 0.387 | 0.707 / 0.621 | 0.358 / 0.397 | 0.476 / 0.458 | 0.346 / 0.388 |
| | 192 | **0.355** / **0.387** | 0.374 / 0.395 | 0.380 / 0.400 | 0.390 / 0.406 | **0.374** / **0.390** | 0.388 / 0.400 | 0.877 / 0.656 | 0.528 / 0.509 | 0.402 / 0.414 | 0.477 / 0.476 | 0.860 / 0.689 | 0.429 / 0.439 | 0.512 / 0.493 | 0.456 / 0.452 |
| | 336 | **0.365** / **0.401** | 0.428 / 0.433 | 0.428 / 0.432 | 0.426 / 0.444 | **0.415** / **0.426** | 0.426 / 0.433 | 1.043 / 0.731 | 0.643 / 0.571 | 0.452 / 0.452 | 0.594 / 0.541 | 1.000 / 0.744 | 0.496 / 0.487 | 0.552 / 0.551 | 0.482 / 0.486 |
| | 720 | **0.413** / **0.436** | 0.454 / 0.458 | 0.427 / 0.445 | 0.445 / 0.444 | **0.420** / **0.440** | 0.431 / 0.446 | 1.104 / 0.763 | 0.874 / 0.679 | 0.462 / 0.468 | 0.831 / 0.657 | 1.249 / 0.838 | 0.463 / 0.474 | 0.562 / 0.560 | 0.515 / 0.511 |
| | Avg | **0.353** / **0.391** | 0.384 / 0.407 | 0.383 / 0.407 | 0.393 / 0.413 | **0.374** / **0.398** | 0.387 / 0.407 | 0.942 / 0.684 | 0.611 / 0.550 | 0.414 / 0.427 | 0.559 / 0.515 | 0.954 / 0.723 | 0.437 / 0.449 | 0.526 / 0.516 | 0.450 / 0.459 |
| ECL | 96 | **0.141** / **0.235** | 0.153 / 0.244 | **0.148** / **0.240** | 0.173 / 0.275 | 0.201 / 0.281 | 0.181 / 0.270 | 0.219 / 0.314 | 0.237 / 0.329 | 0.168 / 0.272 | 0.197 / 0.282 | 0.247 / 0.345 | 0.193 / 0.308 | 0.169 / 0.273 | 0.201 / 0.317 |
| | 192 | **0.151** / **0.247** | 0.166 / 0.256 | **0.162** / **0.253** | 0.195 / 0.288 | 0.201 / 0.283 | 0.188 / 0.274 | 0.231 / 0.322 | 0.236 / 0.330 | 0.184 / 0.289 | 0.196 / 0.285 | 0.257 / 0.355 | 0.201 / 0.315 | 0.182 / 0.286 | 0.222 / 0.334 |
| | 336 | **0.173** / **0.267** | 0.184 / 0.275 | **0.178** / **0.269** | 0.206 / 0.300 | 0.215 / 0.298 | 0.204 / 0.293 | 0.246 / 0.337 | 0.249 / 0.344 | 0.198 / 0.300 | 0.209 / 0.301 | 0.269 / 0.369 | 0.214 / 0.329 | 0.200 / 0.304 | 0.231 / 0.338 |
| | 720 | **0.201** / **0.293** | 0.226 / **0.313** | 0.225 / 0.317 | 0.231 / 0.335 | 0.257 / 0.331 | 0.246 / 0.324 | 0.280 / 0.363 | 0.284 / 0.373 | **0.220** / 0.320 | 0.245 / 0.333 | 0.299 / 0.390 | 0.246 / 0.355 | 0.222 / 0.321 | 0.254 / 0.361 |
| | Avg | **0.166** / **0.260** | 0.182 / 0.272 | **0.178** / **0.270** | 0.201 / 0.300 | 0.219 / 0.298 | 0.205 / 0.290 | 0.244 / 0.334 | 0.251 / 0.344 | 0.192 / 0.295 | 0.212 / 0.300 | 0.268 / 0.365 | 0.214 / 0.327 | 0.193 / 0.296 | 0.227 / 0.338 |
| Weather | 96 | 0.162 / **0.207** | 0.165 / **0.212** | 0.174 / 0.214 | **0.159** / 0.218 | 0.192 / 0.232 | 0.177 / 0.218 | **0.158** / 0.230 | 0.202 / 0.261 | 0.172 / 0.220 | 0.196 / 0.255 | 0.221 / 0.306 | 0.217 / 0.296 | 0.173 / 0.223 | 0.266 / 0.336 |
| | 192 | **0.208** / **0.248** | 0.209 / **0.253** | 0.221 / 0.254 | 0.211 / 0.266 | 0.240 / 0.271 | 0.225 / 0.259 | **0.206** / 0.277 | 0.242 / 0.298 | 0.219 / 0.261 | 0.237 / 0.296 | 0.261 / 0.340 | 0.276 / 0.336 | 0.245 / 0.285 | 0.307 / 0.367 |
| | 336 | **0.263** / **0.290** | **0.264** / **0.293** | 0.278 / 0.296 | 0.267 / 0.310 | 0.292 / 0.307 | 0.278 / 0.297 | 0.272 / 0.335 | 0.287 / 0.335 | 0.280 / 0.306 | 0.283 / 0.335 | 0.309 / 0.378 | 0.339 / 0.380 | 0.321 / 0.338 | 0.359 / 0.395 |
| | 720 | **0.340** / **0.341** | **0.342** / **0.345** | 0.358 / 0.347 | 0.352 / 0.362 | 0.364 / 0.353 | 0.354 / 0.348 | 0.398 / 0.418 | 0.351 / 0.386 | 0.365 / 0.359 | 0.345 / 0.381 | 0.377 / 0.427 | 0.403 / 0.428 | 0.414 / 0.410 | 0.419 / 0.428 |
| | Avg | **0.243** / **0.271** | **0.245** / **0.276** | 0.258 / 0.278 | 0.247 / 0.289 | 0.272 / 0.291 | 0.259 / 0.281 | 0.259 / 0.315 | 0.271 / 0.320 | 0.259 / 0.287 | 0.265 / 0.317 | 0.292 / 0.363 | 0.309 / 0.360 | 0.288 / 0.314 | 0.338 / 0.382 |
| Traffic | 96 | **0.410** / **0.274** | 0.464 / 0.289 | **0.395** / **0.268** | 0.570 / 0.310 | 0.649 / 0.389 | 0.462 / 0.295 | 0.522 / 0.290 | 0.805 / 0.493 | 0.593 / 0.321 | 0.650 / 0.396 | 0.788 / 0.499 | 0.587 / 0.366 | 0.612 / 0.338 | 0.613 / 0.388 |
| | 192 | **0.430** / **0.280** | 0.477 / 0.292 | **0.417** / **0.276** | 0.577 / 0.321 | 0.601 / 0.366 | 0.466 / 0.296 | 0.530 / 0.293 | 0.756 / 0.474 | 0.617 / 0.336 | 0.598 / 0.370 | 0.789 / 0.505 | 0.604 / 0.373 | 0.613 / 0.340 | 0.616 / 0.382 |
| | 336 | **0.449** / **0.290** | 0.500 / 0.305 | **0.433** / **0.283** | 0.588 / 0.324 | 0.609 / 0.369 | 0.482 / 0.304 | 0.558 / 0.305 | 0.762 / 0.477 | 0.629 / 0.336 | 0.605 / 0.373 | 0.797 / 0.508 | 0.621 / 0.383 | 0.618 / 0.328 | 0.622 / 0.337 |
| | 720 | **0.486** / **0.309** | 0.548 / 0.313 | **0.467** / **0.302** | 0.597 / 0.337 | 0.647 / 0.387 | 0.514 / 0.322 | 0.589 / 0.328 | 0.719 / 0.449 | 0.640 / 0.350 | 0.645 / 0.394 | 0.841 / 0.523 | 0.626 / 0.382 | 0.653 / 0.355 | 0.660 / 0.408 |
| | Avg | **0.444** / **0.289** | 0.497 / 0.300 | **0.428** / **0.282** | 0.583 / 0.323 | 0.626 / 0.378 | 0.481 / 0.304 | 0.550 / 0.304 | 0.760 / 0.473 | 0.620 / 0.336 | 0.625 / 0.383 | 0.804 / 0.509 | 0.610 / 0.376 | 0.624 / 0.340 | 0.628 / 0.379 |
| Solar-Energy | 96 | **0.163** / **0.232** | 0.215 / 0.294 | **0.203** / **0.237** | 0.222 / 0.301 | 0.322 / 0.339 | 0.234 / 0.286 | 0.310 / 0.331 | 0.312 / 0.399 | 0.250 / 0.292 | 0.290 / 0.378 | 0.237 / 0.344 | 0.242 / 0.342 | 0.215 / 0.249 | 0.884 / 0.711 |
| | 192 | **0.182** / **0.247** | 0.237 / 0.275 | **0.233** / **0.261** | 0.246 / 0.307 | 0.359 / 0.356 | 0.267 / 0.310 | 0.734 / 0.725 | 0.339 / 0.416 | 0.296 / 0.318 | 0.320 / 0.398 | 0.280 / 0.380 | 0.285 / 0.380 | 0.254 / 0.272 | 0.834 / 0.692 |
| | 336 | **0.193** / **0.257** | 0.252 / 0.298 | **0.248** / **0.273** | 0.263 / 0.324 | 0.397 / 0.369 | 0.290 / 0.315 | 0.750 / 0.735 | 0.368 / 0.430 | 0.319 / 0.330 | 0.353 / 0.415 | 0.304 / 0.389 | 0.282 / 0.376 | 0.290 / 0.296 | 0.941 / 0.723 |
| | 720 | **0.199** / **0.252** | 0.244 / 0.293 | **0.249** / **0.275** | 0.265 / 0.318 | 0.397 / 0.356 | 0.289 / 0.317 | 0.769 / 0.765 | 0.370 / 0.425 | 0.338 / 0.337 | 0.356 / 0.413 | 0.308 / 0.388 | 0.357 / 0.427 | 0.285 / 0.295 | 0.882 / 0.717 |
| | Avg | **0.184** / **0.247** | 0.237 / 0.290 | **0.233** / **0.262** | 0.249 / 0.313 | 0.369 / 0.356 | 0.270 / 0.307 | 0.641 / 0.639 | 0.347 / 0.417 | 0.301 / 0.319 | 0.330 / 0.401 | 0.282 / 0.375 | 0.291 / 0.381 | 0.261 / 0.381 | 0.885 / 0.711 |
1800 |
1801 | # Get Started
1802 |
1803 | ## 1. Download the Data
1804 |
1805 | All datasets have been preprocessed and are ready for use. You can obtain them from their original sources:
1806 |
1807 | - **ETT**: [https://github.com/zhouhaoyi/ETDataset/tree/main](https://github.com/zhouhaoyi/ETDataset/tree/main)
1808 | - **Traffic, Electricity, Weather**: [https://github.com/thuml/Autoformer](https://github.com/thuml/Autoformer?tab=readme-ov-file)
1809 | - **Solar**: [https://github.com/laiguokun/LSTNet](https://github.com/laiguokun/LSTNet)
1810 | - **PEMS**: [https://github.com/cure-lab/SCINet](https://github.com/cure-lab/SCINet?tab=readme-ov-file)
1811 |
1812 | For convenience, we provide a comprehensive package containing all required datasets, available for download from [Google Drive](https://drive.google.com/file/d/1hTpUrhe1yEIGa9mCiGxM5rDyzlYKAnyx/view?usp=sharing). Place the extracted data under the [./dataset](./dataset/) folder.
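
As a sketch of the expected layout (the exact subfolder and file names are whatever each script passes via `--root_path`/`--data_path`, so treat the names below as an illustration):

```
dataset/
├── ETT-small/      # ETTh1.csv, ETTh2.csv, ETTm1.csv, ETTm2.csv
├── electricity/    # electricity.csv
├── traffic/        # traffic.csv
├── weather/        # weather.csv
├── Solar/          # comma-separated text, e.g. solar_AL.txt
└── PEMS/           # .npz archives read via np.load (key 'data'), e.g. PEMS03.npz
```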
1813 |
1814 | ## 2. Setup Your Environment
1815 |
1816 | Choose one of the following methods to set up your environment:
1817 |
1818 | ### Option A: Anaconda
1819 | Create and activate a Python environment using the provided configuration file [environment.yml](./environment.yml):
1820 |
1821 | ```bash
1822 | conda env create -f environment.yml -n SimpleTM
1823 | conda activate SimpleTM
1824 | ```
1825 |
1826 | ### Option B: Docker
1827 | If you prefer Docker, build an image using the provided [Dockerfile](./Dockerfile):
1828 |
1829 | ```bash
1830 | docker build --tag simpletm:latest .
1831 | ```
1832 |
1833 |
1834 | ## 3. Train the Model
1835 |
1836 | Experiment scripts for various benchmarks are provided in the [`scripts`](./scripts) directory. You can reproduce experiment results as follows:
1837 |
1838 | ```bash
1839 | bash ./scripts/multivariate_forecasting/ETT/SimpleTM_h1.sh # ETTh1
1840 | bash ./scripts/multivariate_forecasting/ECL/SimpleTM.sh # Electricity
1841 | bash ./scripts/multivariate_forecasting/SolarEnergy/SimpleTM.sh  # Solar-Energy
1842 | bash ./scripts/multivariate_forecasting/Weather/SimpleTM.sh  # Weather
1843 | bash ./scripts/multivariate_forecasting/PEMS/SimpleTM_03.sh  # PEMS03
1844 | ```
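
Each of these scripts is a thin wrapper around `run.py`. A hypothetical direct invocation looks like the following; the flag names mirror the argument fields used throughout the code, but check the scripts for the exact set and values:

```bash
python run.py \
  --model SimpleTM --data ETTh1 \
  --root_path ./dataset/ETT-small/ --data_path ETTh1.csv \
  --features M --seq_len 96 --label_len 48 --pred_len 96
```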
1845 |
1846 | ### Docker Users
1847 | If you're using Docker, run the scripts with the following command structure (example for ETTh1):
1848 |
1849 | ```bash
1850 | docker run --gpus all -it --rm --ipc=host \
1851 | --user $(id -u):$(id -g) \
1852 | -v "$(pwd)":/scratch --workdir /scratch -e HOME=/scratch \
1853 | simpletm:latest \
1854 | bash scripts/multivariate_forecasting/ETT/SimpleTM_h1.sh
1855 | ```
1856 |
1857 |
1858 | # Model Efficiency
1859 | To provide an efficiency comparison, we evaluated our model against two of the most competitive baselines: the transformer-based iTransformer and linear-based TimeMixer. Our experimental setup used a consistent batch size of 256 across all models and measured four key metrics: total trainable parameters, inference time, GPU memory footprint, and peak memory usage during the backward pass. Results for all baseline models were compiled using PyTorch.
1860 |
1861 | Please note that our default experimental configuration does not employ compilation optimizations. To speed things up, enable the `--compile` flag in the scripts.
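
The commented-out hooks in `experiments/exp_long_term_forecasting.py` (`torch.cuda.reset_peak_memory_stats` / `torch.cuda.max_memory_allocated`) show how the memory numbers can be tracked. A minimal sketch of the same idea, assuming a CUDA device and a `forward_fn` closure you define over your own model and batch:

```python
import time
import torch

def profile_forward(forward_fn, device="cuda"):
    # Reset the peak-memory counter, then time a single forward pass.
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)
    start = time.time()
    with torch.no_grad():
        forward_fn()  # e.g. lambda: model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
    torch.cuda.synchronize(device)
    return {
        "inference_time_s": time.time() - start,
        "peak_mem_MB": torch.cuda.max_memory_allocated(device) / 2**20,
    }
```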
1862 |
![Efficiency comparison](./figures/Efficiency.jpg)
1868 |
Table 13: Comparison of model performance and resource utilization across different datasets. Metrics include Mean Squared Error (MSE), total parameter count, inference time (seconds), GPU memory footprint (MB), and peak memory usage (MB).

| Dataset | Model | MSE | Total Params | Inference Time (s) | GPU Mem Footprint (MB) | Peak Mem (MB) |
|---|---|---|---|---|---|---|
| Weather | SimpleTM | 0.162 | 13,472 | 0.0132 | 994 | 181.75 |
| | TimeMixer | 0.164 | 104,433 | 0.0453 | 2,954 | 2,281.38 |
| | iTransformer | 0.176 | 4,833,888 | 0.0222 | 1,596 | 847.62 |
| Solar | SimpleTM | 0.163 | 166,304 | 0.0455 | 2,048 | 1,181.56 |
| | TimeMixer | 0.215 | 13,009,079 | 0.2644 | 7,576 | 6,632.40 |
| | iTransformer | 0.203 | 3,255,904 | 0.0663 | 4,022 | 2,776.50 |
1938 |
1939 | # Acknowledgement
1940 |
1941 | We gratefully acknowledge the following GitHub repositories for their valuable code and effort.
1942 | - Time-Series-Library (https://github.com/thuml/Time-Series-Library)
1943 | - iTransformer (https://github.com/thuml/iTransformer)
1944 | - TimeMixer (https://github.com/kwuking/TimeMixer)
1945 | - Autoformer (https://github.com/thuml/Autoformer)
1946 |
1947 |
1948 | # Citation
1949 | If you find this repo helpful, please cite our paper.
1950 |
1951 | ```bibtex
1952 | @inproceedings{
1953 | chen2025simpletm,
1954 | title={Simple{TM}: A Simple Baseline for Multivariate Time Series Forecasting},
1955 | author={Hui Chen and Viet Luong and Lopamudra Mukherjee and Vikas Singh},
1956 | booktitle={The Thirteenth International Conference on Learning Representations},
1957 | year={2025},
1958 | url={https://openreview.net/forum?id=oANkBaVci5}
1959 | }
1960 | ```
--------------------------------------------------------------------------------
/data_provider/data_factory.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Solar, Dataset_PEMS, \
2 | Dataset_Pred
3 | from torch.utils.data import DataLoader
4 |
5 | data_dict = {
6 | 'ETTh1': Dataset_ETT_hour,
7 | 'ETTh2': Dataset_ETT_hour,
8 | 'ETTm1': Dataset_ETT_minute,
9 | 'ETTm2': Dataset_ETT_minute,
10 | 'Solar': Dataset_Solar,
11 | 'PEMS': Dataset_PEMS,
12 | 'custom': Dataset_Custom,
13 | }
14 |
15 |
16 | def data_provider(args, flag):
17 | Data = data_dict[args.data]
18 | timeenc = 0 if args.embed != 'timeF' else 1
19 |
20 | if flag == 'test':
21 | shuffle_flag = False
22 | drop_last = True
23 | batch_size = args.batch_size
24 | freq = args.freq
25 | elif flag == 'pred':
26 | shuffle_flag = False
27 | drop_last = False
28 | batch_size = 1
29 | freq = args.freq
30 | Data = Dataset_Pred
31 | else:
32 | shuffle_flag = True
33 | drop_last = True
34 | batch_size = args.batch_size
35 | freq = args.freq
36 |
37 | data_set = Data(
38 | root_path=args.root_path,
39 | data_path=args.data_path,
40 | flag=flag,
41 | size=[args.seq_len, args.label_len, args.pred_len],
42 | features=args.features,
43 | target=args.target,
44 | timeenc=timeenc,
45 | freq=freq,
46 | )
47 | print(flag, len(data_set))
48 | data_loader = DataLoader(
49 | data_set,
50 | batch_size=batch_size,
51 | shuffle=shuffle_flag,
52 | num_workers=args.num_workers,
53 | drop_last=drop_last)
54 | return data_set, data_loader
55 |
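# Usage sketch (illustrative only; in the repo these fields are populated by the
# command-line arguments parsed in run.py, so treat the names and values below
# as an assumption rather than the canonical invocation):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       data='ETTh1', embed='timeF', freq='h', batch_size=32, num_workers=4,
#       root_path='./dataset/ETT-small/', data_path='ETTh1.csv',
#       seq_len=96, label_len=48, pred_len=96, features='M', target='OT',
#   )
#   train_set, train_loader = data_provider(args, flag='train')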
--------------------------------------------------------------------------------
/data_provider/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 | from torch.utils.data import Dataset, DataLoader
6 | from sklearn.preprocessing import StandardScaler
7 | from utils.timefeatures import time_features
8 | import warnings
9 |
10 | warnings.filterwarnings('ignore')
11 |
12 | class Dataset_ETT_hour(Dataset):
13 | def __init__(self, root_path, flag='train', size=None,
14 | features='S', data_path='ETTh1.csv',
15 | target='OT', scale=True, timeenc=0, freq='h'):
16 | if size == None:
17 | self.seq_len = 24 * 4 * 4
18 | self.label_len = 24 * 4
19 | self.pred_len = 24 * 4
20 | else:
21 | self.seq_len = size[0]
22 | self.label_len = size[1]
23 | self.pred_len = size[2]
24 | assert flag in ['train', 'test', 'val']
25 | type_map = {'train': 0, 'val': 1, 'test': 2}
26 | self.set_type = type_map[flag]
27 |
28 | self.features = features
29 | self.target = target
30 | self.scale = scale
31 | self.timeenc = timeenc
32 | self.freq = freq
33 |
34 | self.root_path = root_path
35 | self.data_path = data_path
36 | self.__read_data__()
37 |
38 | def __read_data__(self):
39 | self.scaler = StandardScaler()
40 | df_raw = pd.read_csv(os.path.join(self.root_path,
41 | self.data_path))
42 |
43 | border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
44 | border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
45 | border1 = border1s[self.set_type]
46 | border2 = border2s[self.set_type]
47 |
48 | if self.features == 'M' or self.features == 'MS':
49 | cols_data = df_raw.columns[1:]
50 | df_data = df_raw[cols_data]
51 | elif self.features == 'S':
52 | df_data = df_raw[[self.target]]
53 |
54 | if self.scale:
55 | train_data = df_data[border1s[0]:border2s[0]]
56 | self.scaler.fit(train_data.values)
57 | data = self.scaler.transform(df_data.values)
58 | else:
59 | data = df_data.values
60 |
61 | df_stamp = df_raw[['date']][border1:border2]
62 | df_stamp['date'] = pd.to_datetime(df_stamp.date)
63 | if self.timeenc == 0:
64 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
65 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
66 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
67 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
68 | data_stamp = df_stamp.drop(['date'], 1).values
69 | elif self.timeenc == 1:
70 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
71 | data_stamp = data_stamp.transpose(1, 0)
72 |
73 | self.data_x = data[border1:border2]
74 | self.data_y = data[border1:border2]
75 | self.data_stamp = data_stamp
76 |
77 | def __getitem__(self, index):
78 | s_begin = index
79 | s_end = s_begin + self.seq_len
80 | r_begin = s_end - self.label_len
81 | r_end = r_begin + self.label_len + self.pred_len
82 |
83 | seq_x = self.data_x[s_begin:s_end]
84 | seq_y = self.data_y[r_begin:r_end]
85 | seq_x_mark = self.data_stamp[s_begin:s_end]
86 | seq_y_mark = self.data_stamp[r_begin:r_end]
87 |
88 | return seq_x, seq_y, seq_x_mark, seq_y_mark
89 |
90 | def __len__(self):
91 | return len(self.data_x) - self.seq_len - self.pred_len + 1
92 |
93 | def inverse_transform(self, data):
94 | return self.scaler.inverse_transform(data)
95 |
96 |
97 | class Dataset_ETT_minute(Dataset):
98 | def __init__(self, root_path, flag='train', size=None,
99 | features='S', data_path='ETTm1.csv',
100 | target='OT', scale=True, timeenc=0, freq='t'):
101 | if size == None:
102 | self.seq_len = 24 * 4 * 4
103 | self.label_len = 24 * 4
104 | self.pred_len = 24 * 4
105 | else:
106 | self.seq_len = size[0]
107 | self.label_len = size[1]
108 | self.pred_len = size[2]
109 | assert flag in ['train', 'test', 'val']
110 | type_map = {'train': 0, 'val': 1, 'test': 2}
111 | self.set_type = type_map[flag]
112 |
113 | self.features = features
114 | self.target = target
115 | self.scale = scale
116 | self.timeenc = timeenc
117 | self.freq = freq
118 |
119 | self.root_path = root_path
120 | self.data_path = data_path
121 | self.__read_data__()
122 |
123 | def __read_data__(self):
124 | self.scaler = StandardScaler()
125 | df_raw = pd.read_csv(os.path.join(self.root_path,
126 | self.data_path))
127 |
128 | border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
129 | border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
130 | border1 = border1s[self.set_type]
131 | border2 = border2s[self.set_type]
132 |
133 | if self.features == 'M' or self.features == 'MS':
134 | cols_data = df_raw.columns[1:]
135 | df_data = df_raw[cols_data]
136 | elif self.features == 'S':
137 | df_data = df_raw[[self.target]]
138 |
139 | if self.scale:
140 | train_data = df_data[border1s[0]:border2s[0]]
141 | self.scaler.fit(train_data.values)
142 | data = self.scaler.transform(df_data.values)
143 | else:
144 | data = df_data.values
145 |
146 | df_stamp = df_raw[['date']][border1:border2]
147 | df_stamp['date'] = pd.to_datetime(df_stamp.date)
148 | if self.timeenc == 0:
149 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
150 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
151 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
152 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
153 | df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1)
154 | df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)
155 | data_stamp = df_stamp.drop(['date'], 1).values
156 | elif self.timeenc == 1:
157 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
158 | data_stamp = data_stamp.transpose(1, 0)
159 |
160 | self.data_x = data[border1:border2]
161 | self.data_y = data[border1:border2]
162 | self.data_stamp = data_stamp
163 |
164 | def __getitem__(self, index):
165 | s_begin = index
166 | s_end = s_begin + self.seq_len
167 | r_begin = s_end - self.label_len
168 | r_end = r_begin + self.label_len + self.pred_len
169 |
170 | seq_x = self.data_x[s_begin:s_end]
171 | seq_y = self.data_y[r_begin:r_end]
172 | seq_x_mark = self.data_stamp[s_begin:s_end]
173 | seq_y_mark = self.data_stamp[r_begin:r_end]
174 |
175 | return seq_x, seq_y, seq_x_mark, seq_y_mark
176 |
177 | def __len__(self):
178 | return len(self.data_x) - self.seq_len - self.pred_len + 1
179 |
180 | def inverse_transform(self, data):
181 | return self.scaler.inverse_transform(data)
182 |
183 |
184 | class Dataset_Custom(Dataset):
185 | def __init__(self, root_path, flag='train', size=None,
186 | features='S', data_path='ETTh1.csv',
187 | target='OT', scale=True, timeenc=0, freq='h'):
188 | if size == None:
189 | self.seq_len = 24 * 4 * 4
190 | self.label_len = 24 * 4
191 | self.pred_len = 24 * 4
192 | else:
193 | self.seq_len = size[0]
194 | self.label_len = size[1]
195 | self.pred_len = size[2]
196 | assert flag in ['train', 'test', 'val']
197 | type_map = {'train': 0, 'val': 1, 'test': 2}
198 | self.set_type = type_map[flag]
199 |
200 | self.features = features
201 | self.target = target
202 | self.scale = scale
203 | self.timeenc = timeenc
204 | self.freq = freq
205 |
206 | self.root_path = root_path
207 | self.data_path = data_path
208 | self.__read_data__()
209 |
210 | def __read_data__(self):
211 | self.scaler = StandardScaler()
212 | df_raw = pd.read_csv(os.path.join(self.root_path,
213 | self.data_path))
214 | cols = list(df_raw.columns)
215 | cols.remove(self.target)
216 | cols.remove('date')
217 | df_raw = df_raw[['date'] + cols + [self.target]]
218 | num_train = int(len(df_raw) * 0.7)
219 | num_test = int(len(df_raw) * 0.2)
220 | num_vali = len(df_raw) - num_train - num_test
221 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
222 | border2s = [num_train, num_train + num_vali, len(df_raw)]
223 | border1 = border1s[self.set_type]
224 | border2 = border2s[self.set_type]
225 |
226 | if self.features == 'M' or self.features == 'MS':
227 | cols_data = df_raw.columns[1:]
228 | df_data = df_raw[cols_data]
229 | elif self.features == 'S':
230 | df_data = df_raw[[self.target]]
231 |
232 | if self.scale:
233 | train_data = df_data[border1s[0]:border2s[0]]
234 | self.scaler.fit(train_data.values)
235 | data = self.scaler.transform(df_data.values)
236 | else:
237 | data = df_data.values
238 |
239 | df_stamp = df_raw[['date']][border1:border2]
240 | df_stamp['date'] = pd.to_datetime(df_stamp.date)
241 | if self.timeenc == 0:
242 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
243 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
244 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
245 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
246 | data_stamp = df_stamp.drop(['date'], 1).values
247 | elif self.timeenc == 1:
248 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
249 | data_stamp = data_stamp.transpose(1, 0)
250 |
251 | self.data_x = data[border1:border2]
252 | self.data_y = data[border1:border2]
253 | self.data_stamp = data_stamp
254 |
255 | def __getitem__(self, index):
256 | s_begin = index
257 | s_end = s_begin + self.seq_len
258 | r_begin = s_end - self.label_len
259 | r_end = r_begin + self.label_len + self.pred_len
260 |
261 | seq_x = self.data_x[s_begin:s_end]
262 | seq_y = self.data_y[r_begin:r_end]
263 | seq_x_mark = self.data_stamp[s_begin:s_end]
264 | seq_y_mark = self.data_stamp[r_begin:r_end]
265 |
266 | return seq_x, seq_y, seq_x_mark, seq_y_mark
267 |
268 | def __len__(self):
269 | return len(self.data_x) - self.seq_len - self.pred_len + 1
270 |
271 | def inverse_transform(self, data):
272 | return self.scaler.inverse_transform(data)
273 |
274 |
275 | class Dataset_PEMS(Dataset):
276 | def __init__(self, root_path, flag='train', size=None,
277 | features='S', data_path='ETTh1.csv',
278 | target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
279 | self.seq_len = size[0]
280 | self.label_len = size[1]
281 | self.pred_len = size[2]
282 | assert flag in ['train', 'test', 'val']
283 | type_map = {'train': 0, 'val': 1, 'test': 2}
284 | self.set_type = type_map[flag]
285 |
286 | self.features = features
287 | self.target = target
288 | self.scale = scale
289 | self.timeenc = timeenc
290 | self.freq = freq
291 |
292 | self.root_path = root_path
293 | self.data_path = data_path
294 | self.__read_data__()
295 |
296 | def __read_data__(self):
297 | self.scaler = StandardScaler()
298 | data_file = os.path.join(self.root_path, self.data_path)
299 | print('data file:', data_file)
300 | data = np.load(data_file, allow_pickle=True)
301 | data = data['data'][:, :, 0]
302 |
303 | train_ratio = 0.6
304 | valid_ratio = 0.2
305 | train_data = data[:int(train_ratio * len(data))]
306 | valid_data = data[int(train_ratio * len(data)):int((train_ratio + valid_ratio) * len(data))]
307 | test_data = data[int((train_ratio + valid_ratio) * len(data)):]
308 | total_data = [train_data, valid_data, test_data]
309 | data = total_data[self.set_type]
310 |
311 | if self.scale:
312 | self.scaler.fit(data)
313 | data = self.scaler.transform(data)
314 |
315 | df = pd.DataFrame(data)
316 | df = df.fillna(method='ffill', limit=len(df)).fillna(method='bfill', limit=len(df)).values
317 |
318 | self.data_x = df
319 | self.data_y = df
320 |
321 | def __getitem__(self, index):
322 | if self.set_type == 2:
323 | s_begin = index * 12
324 | else:
325 | s_begin = index
326 | s_end = s_begin + self.seq_len
327 | r_begin = s_end - self.label_len
328 | r_end = r_begin + self.label_len + self.pred_len
329 |
330 | seq_x = self.data_x[s_begin:s_end]
331 | seq_y = self.data_y[r_begin:r_end]
332 | seq_x_mark = torch.zeros((seq_x.shape[0], 1))
333 | seq_y_mark = torch.zeros((seq_y.shape[0], 1))
334 |
335 | return seq_x, seq_y, seq_x_mark, seq_y_mark
336 |
337 | def __len__(self):
338 | if self.set_type == 2:
339 | return (len(self.data_x) - self.seq_len - self.pred_len + 1) // 12
340 | else:
341 | return len(self.data_x) - self.seq_len - self.pred_len + 1
342 |
343 | def inverse_transform(self, data):
344 | return self.scaler.inverse_transform(data)
345 |
346 |
347 | class Dataset_Solar(Dataset):
348 | def __init__(self, root_path, flag='train', size=None,
349 | features='S', data_path='ETTh1.csv',
350 | target='OT', scale=True, timeenc=0, freq='h'):
351 | self.seq_len = size[0]
352 | self.label_len = size[1]
353 | self.pred_len = size[2]
354 | assert flag in ['train', 'test', 'val']
355 | type_map = {'train': 0, 'val': 1, 'test': 2}
356 | self.set_type = type_map[flag]
357 |
358 | self.features = features
359 | self.target = target
360 | self.scale = scale
361 | self.timeenc = timeenc
362 | self.freq = freq
363 |
364 | self.root_path = root_path
365 | self.data_path = data_path
366 | self.__read_data__()
367 |
368 | def __read_data__(self):
369 | self.scaler = StandardScaler()
370 | df_raw = []
371 | with open(os.path.join(self.root_path, self.data_path), "r", encoding='utf-8') as f:
372 | for line in f.readlines():
373 | line = line.strip('\n').split(',')
374 | data_line = np.stack([float(i) for i in line])
375 | df_raw.append(data_line)
376 | df_raw = np.stack(df_raw, 0)
377 | df_raw = pd.DataFrame(df_raw)
378 |
379 | num_train = int(len(df_raw) * 0.7)
380 | num_test = int(len(df_raw) * 0.2)
381 | num_valid = int(len(df_raw) * 0.1)
382 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
383 | border2s = [num_train, num_train + num_valid, len(df_raw)]
384 | border1 = border1s[self.set_type]
385 | border2 = border2s[self.set_type]
386 |
387 | df_data = df_raw.values
388 |
389 | if self.scale:
390 | train_data = df_data[border1s[0]:border2s[0]]
391 | self.scaler.fit(train_data)
392 | data = self.scaler.transform(df_data)
393 | else:
394 | data = df_data
395 |
396 | self.data_x = data[border1:border2]
397 | self.data_y = data[border1:border2]
398 |
399 | def __getitem__(self, index):
400 | s_begin = index
401 | s_end = s_begin + self.seq_len
402 | r_begin = s_end - self.label_len
403 | r_end = r_begin + self.label_len + self.pred_len
404 |
405 | seq_x = self.data_x[s_begin:s_end]
406 | seq_y = self.data_y[r_begin:r_end]
407 | seq_x_mark = torch.zeros((seq_x.shape[0], 1))
408 |         seq_y_mark = torch.zeros((seq_y.shape[0], 1))
409 |
410 | return seq_x, seq_y, seq_x_mark, seq_y_mark
411 |
412 | def __len__(self):
413 | return len(self.data_x) - self.seq_len - self.pred_len + 1
414 |
415 | def inverse_transform(self, data):
416 | return self.scaler.inverse_transform(data)
417 |
418 |
419 | class Dataset_Pred(Dataset):
420 | def __init__(self, root_path, flag='pred', size=None,
421 | features='S', data_path='ETTh1.csv',
422 | target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None):
423 | if size == None:
424 | self.seq_len = 24 * 4 * 4
425 | self.label_len = 24 * 4
426 | self.pred_len = 24 * 4
427 | else:
428 | self.seq_len = size[0]
429 | self.label_len = size[1]
430 | self.pred_len = size[2]
431 | assert flag in ['pred']
432 |
433 | self.features = features
434 | self.target = target
435 | self.scale = scale
436 | self.inverse = inverse
437 | self.timeenc = timeenc
438 | self.freq = freq
439 | self.cols = cols
440 | self.root_path = root_path
441 | self.data_path = data_path
442 | self.__read_data__()
443 |
444 | def __read_data__(self):
445 | self.scaler = StandardScaler()
446 | df_raw = pd.read_csv(os.path.join(self.root_path,
447 | self.data_path))
448 | if self.cols:
449 | cols = self.cols.copy()
450 | cols.remove(self.target)
451 | else:
452 | cols = list(df_raw.columns)
453 | cols.remove(self.target)
454 | cols.remove('date')
455 | df_raw = df_raw[['date'] + cols + [self.target]]
456 | border1 = len(df_raw) - self.seq_len
457 | border2 = len(df_raw)
458 |
459 | if self.features == 'M' or self.features == 'MS':
460 | cols_data = df_raw.columns[1:]
461 | df_data = df_raw[cols_data]
462 | elif self.features == 'S':
463 | df_data = df_raw[[self.target]]
464 |
465 | if self.scale:
466 | self.scaler.fit(df_data.values)
467 | data = self.scaler.transform(df_data.values)
468 | else:
469 | data = df_data.values
470 |
471 | tmp_stamp = df_raw[['date']][border1:border2]
472 | tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date)
473 | pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq)
474 |
475 | df_stamp = pd.DataFrame(columns=['date'])
476 | df_stamp.date = list(tmp_stamp.date.values) + list(pred_dates[1:])
477 | if self.timeenc == 0:
478 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
479 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
480 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
481 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
482 | df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1)
483 | df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)
484 | data_stamp = df_stamp.drop(['date'], 1).values
485 | elif self.timeenc == 1:
486 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
487 | data_stamp = data_stamp.transpose(1, 0)
488 |
489 | self.data_x = data[border1:border2]
490 | if self.inverse:
491 | self.data_y = df_data.values[border1:border2]
492 | else:
493 | self.data_y = data[border1:border2]
494 | self.data_stamp = data_stamp
495 |
496 | def __getitem__(self, index):
497 | s_begin = index
498 | s_end = s_begin + self.seq_len
499 | r_begin = s_end - self.label_len
500 | r_end = r_begin + self.label_len + self.pred_len
501 |
502 | seq_x = self.data_x[s_begin:s_end]
503 | if self.inverse:
504 | seq_y = self.data_x[r_begin:r_begin + self.label_len]
505 | else:
506 | seq_y = self.data_y[r_begin:r_begin + self.label_len]
507 | seq_x_mark = self.data_stamp[s_begin:s_end]
508 | seq_y_mark = self.data_stamp[r_begin:r_end]
509 |
510 | return seq_x, seq_y, seq_x_mark, seq_y_mark
511 |
512 | def __len__(self):
513 | return len(self.data_x) - self.seq_len + 1
514 |
515 | def inverse_transform(self, data):
516 | return self.scaler.inverse_transform(data)
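
# Windowing sketch (shared by the loaders above): for a start index i,
#   seq_x = data[i : i + seq_len]
#   seq_y = data[i + seq_len - label_len : i + seq_len + pred_len]
# i.e. the target overlaps the last `label_len` input steps and then extends
# `pred_len` steps into the future, so each loader exposes
#   len(dataset) = len(data) - seq_len - pred_len + 1
# valid windows (Dataset_Pred drops the pred_len term because the future
# values are unknown at prediction time).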
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: SimpleTM
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python=3.9
7 | - einops=0.8.1
8 | - matplotlib=3.7.0
9 | - numpy=1.23.5
10 | - scikit-learn=1.2.2
11 | - scipy=1.13.1
12 | - sympy=1.13.3
13 | - pandas=1.5.3
14 | - pip:
15 | - reformer-pytorch==1.4.4
16 | - PyWavelets
17 |
--------------------------------------------------------------------------------
/experiments/exp_basic.py:
--------------------------------------------------------------------------------
1 | import os
2 | from model import SimpleTM
3 | import torch
4 |
5 | # Suppress TorchDynamo errors so torch.compile falls back to eager execution instead of crashing
6 | import torch._dynamo as dynamo
7 | dynamo.config.suppress_errors = True
8 |
9 | import numpy as np
10 |
11 | class Exp_Basic(object):
12 | def __init__(self, args):
13 | self.args = args
14 | self.model_dict = {
15 | 'SimpleTM': SimpleTM,
16 | }
17 | self.device = self._acquire_device()
18 | self.model = self._build_model().to(self.device)
19 |
20 | if self.args.compile:
21 | self.model = torch.compile(
22 | self.model,
23 | )
24 |
25 | # Count trainable parameters
26 | model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
27 | param_count = sum([np.prod(p.size()) for p in model_parameters])
28 |
29 | # Calculate memory usage in bytes and convert to megabytes
30 | memory_usage_bytes = param_count * 4 # 4 bytes per float32 parameter
31 | memory_usage_MB = memory_usage_bytes / (1024 ** 2)
32 | print(f"Total trainable parameters: {param_count}")
33 | print(f"Memory usage for trainable parameters: {memory_usage_MB:.2f} MB")
34 |
35 |         # Static memory footprint (only meaningful when a CUDA device is in use)
36 |         if self.args.use_gpu and torch.cuda.is_available():
37 |             print(f"Static memory footprint: allocated {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB, reserved {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB")
38 |
39 |
40 | def _build_model(self):
41 | raise NotImplementedError
42 |
43 | def _acquire_device(self):
44 | if self.args.use_gpu:
45 | os.environ["CUDA_VISIBLE_DEVICES"] = str(
46 | self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
47 | device = torch.device('cuda:{}'.format(self.args.gpu))
48 | print('Use GPU: cuda:{}'.format(self.args.gpu))
49 | else:
50 | device = torch.device('cpu')
51 | print('Use CPU')
52 | return device
53 |
54 | def _get_data(self):
55 | pass
56 |
57 | def vali(self):
58 | pass
59 |
60 | def train(self):
61 | pass
62 |
63 | def test(self):
64 | pass
--------------------------------------------------------------------------------
/experiments/exp_long_term_forecasting.py:
--------------------------------------------------------------------------------
1 | from torch.optim import lr_scheduler
2 |
3 | from data_provider.data_factory import data_provider
4 | from experiments.exp_basic import Exp_Basic
5 | from utils.tools import EarlyStopping, adjust_learning_rate, visual
6 | from utils.metrics import metric
7 | import torch
8 | import torch.nn as nn
9 | from torch import optim
10 | import os
11 | import time
12 | import warnings
13 | import numpy as np
14 |
15 | warnings.filterwarnings('ignore')
16 |
17 | torch.autograd.set_detect_anomaly(True)  # helpful for debugging NaNs/inf gradients, but adds overhead; disable for timing runs
18 |
19 |
20 | class Exp_Long_Term_Forecast(Exp_Basic):
21 | def __init__(self, args):
22 | super(Exp_Long_Term_Forecast, self).__init__(args)
23 |
24 | def _build_model(self):
25 | model = self.model_dict[self.args.model].Model(self.args).float()
26 |
27 | if self.args.use_multi_gpu and self.args.use_gpu:
28 | model = nn.DataParallel(model, device_ids=self.args.device_ids)
29 | return model
30 |
31 | def _get_data(self, flag):
32 | data_set, data_loader = data_provider(self.args, flag)
33 | return data_set, data_loader
34 |
35 | def _select_optimizer(self):
36 | model_optim = optim.AdamW(self.model.parameters(), lr=self.args.learning_rate)
37 | return model_optim
38 |
39 | def _select_criterion(self):
40 | if self.args.data == 'PEMS':
41 | criterion = nn.L1Loss()
42 | else:
43 | criterion = nn.MSELoss()
44 | return criterion
45 |
46 | def vali(self, vali_data, vali_loader, criterion):
47 | total_loss = []
48 | self.model.eval()
49 | with torch.no_grad():
50 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
51 | batch_x = batch_x.float().to(self.device)
52 | batch_y = batch_y.float()
53 |
54 | if 'PEMS' in self.args.data or 'Solar' in self.args.data:
55 | batch_x_mark = None
56 | batch_y_mark = None
57 | else:
58 | batch_x_mark = batch_x_mark.float().to(self.device)
59 | batch_y_mark = batch_y_mark.float().to(self.device)
60 |
61 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
62 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
63 |
64 | if self.args.use_amp:
65 | with torch.cuda.amp.autocast():
66 | if self.args.output_attention:
67 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
68 | else:
69 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
70 | else:
71 | if self.args.output_attention:
72 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
73 | else:
74 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
75 | f_dim = -1 if self.args.features == 'MS' else 0
76 | outputs = outputs[:, -self.args.pred_len:, f_dim:]
77 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
78 |
79 | pred = outputs.detach().cpu()
80 | true = batch_y.detach().cpu()
81 |
82 | if self.args.data == 'PEMS':
83 | B, T, C = pred.shape
84 | pred = pred.numpy()
85 | true = true.numpy()
86 | pred = vali_data.inverse_transform(pred.reshape(-1, C)).reshape(B, T, C)
87 | true = vali_data.inverse_transform(true.reshape(-1, C)).reshape(B, T, C)
88 | mae, mse, rmse, mape, mspe = metric(pred, true)
89 | total_loss.append(mae)
90 | else:
91 | loss = criterion(pred, true)
92 | total_loss.append(loss)
93 |
94 | total_loss = np.average(total_loss)
95 | self.model.train()
96 | return total_loss
97 |
98 | def train(self, setting):
99 | train_data, train_loader = self._get_data(flag='train')
100 | vali_data, vali_loader = self._get_data(flag='val')
101 | test_data, test_loader = self._get_data(flag='test')
102 |
103 | path = os.path.join(self.args.checkpoints, setting)
104 | if not os.path.exists(path):
105 | os.makedirs(path)
106 |
107 | time_now = time.time()
108 |
109 | train_steps = len(train_loader)
110 | early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
111 |
112 | model_optim = self._select_optimizer()
113 | criterion = self._select_criterion()
114 |
115 | if self.args.lradj == 'TST':
116 | scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim,
117 | steps_per_epoch=train_steps,
118 | pct_start=self.args.pct_start,
119 | epochs=self.args.train_epochs,
120 | max_lr=self.args.learning_rate)
121 |
122 |
123 | if self.args.use_amp:
124 | scaler = torch.cuda.amp.GradScaler()
125 | # # Efficiency: dynamic memory footprint
126 | # # Track dynamic memory usage over an epoch
127 | # torch.cuda.reset_peak_memory_stats() # Reset peak memory tracking
128 |
129 | for epoch in range(self.args.train_epochs):
130 | iter_count = 0
131 | train_loss = []
132 |
133 | self.model.train()
134 | epoch_time = time.time()
135 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
136 | iter_count += 1
137 | model_optim.zero_grad()
138 | batch_x = batch_x.float().to(self.device)
139 |
140 | batch_y = batch_y.float().to(self.device)
141 | if 'PEMS' in self.args.data or 'Solar' in self.args.data:
142 | batch_x_mark = None
143 | batch_y_mark = None
144 | else:
145 | batch_x_mark = batch_x_mark.float().to(self.device)
146 | batch_y_mark = batch_y_mark.float().to(self.device)
147 |
148 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
149 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
150 |
151 | if self.args.use_amp:
152 | with torch.cuda.amp.autocast():
153 | if self.args.output_attention:
154 |                             outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
155 | else:
156 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
157 |
158 | f_dim = -1 if self.args.features == 'MS' else 0
159 | outputs = outputs[:, -self.args.pred_len:, f_dim:]
160 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
161 | loss = criterion(outputs, batch_y)
162 | train_loss.append(loss.item())
163 | else:
164 | if self.args.output_attention:
165 |                         outputs, attn = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)  # attn is needed for the L1 term below
166 | else:
167 | outputs, attn = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
168 |
169 | f_dim = -1 if self.args.features == 'MS' else 0
170 | outputs = outputs[:, -self.args.pred_len:, f_dim:]
171 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
172 |
173 |                     loss = criterion(outputs, batch_y) + self.args.l1_weight * attn[0]  # MSE plus an L1 penalty on the first layer's mean attention score
174 | train_loss.append(loss.item())
175 |
176 | if (i + 1) % 30 == 0:
177 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
178 | speed = (time.time() - time_now) / iter_count
179 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
180 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
181 | iter_count = 0
182 | time_now = time.time()
183 |
184 | if self.args.use_amp:
185 | scaler.scale(loss).backward()
186 | scaler.step(model_optim)
187 | scaler.update()
188 | else:
189 | loss.backward()
190 | model_optim.step()
191 | # # Efficiency: dynamic memory footprint
192 | # # Record current and peak memory usage after processing this batch
193 | # current_memory = torch.cuda.memory_allocated()
194 | # peak_memory = torch.cuda.max_memory_allocated()
195 | # print(f"Current memory: {current_memory / (1024 ** 2):.2f} MB, Peak memory: {peak_memory / (1024 ** 2):.2f} MB")
196 |
197 | if self.args.lradj == 'TST':
198 | adjust_learning_rate(model_optim, epoch + 1, self.args, scheduler, printout=False)
199 | scheduler.step()
200 |
201 |
202 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
203 | train_loss = np.average(train_loss)
204 | vali_loss = self.vali(vali_data, vali_loader, criterion)
205 | test_loss = self.vali(test_data, test_loader, criterion)
206 |
207 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
208 | epoch + 1, train_steps, train_loss, vali_loss, test_loss))
209 | early_stopping(vali_loss, self.model, path)
210 | if early_stopping.early_stop:
211 | print("Early stopping")
212 | break
213 |
214 | if self.args.lradj != 'TST':
215 | adjust_learning_rate(model_optim, epoch + 1, self.args)
216 | else:
217 | adjust_learning_rate(model_optim, epoch + 1, self.args, scheduler)
218 |
219 |
220 | best_model_path = path + '/' + 'checkpoint.pth'
221 | self.model.load_state_dict(torch.load(best_model_path))
222 |
223 | return self.model
224 |
225 | def test(self, setting, test=0):
226 | test_data, test_loader = self._get_data(flag='test')
227 | if test:
228 | print('loading model')
229 | self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))
230 |
231 | preds = []
232 | trues = []
233 | folder_path = './checkpoints/' + setting + '/'
234 | if not os.path.exists(folder_path):
235 | os.makedirs(folder_path)
236 |
237 | self.model.eval()
238 | with torch.no_grad():
239 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
240 | batch_x = batch_x.float().to(self.device)
241 | batch_y = batch_y.float().to(self.device)
242 |
243 | if 'PEMS' in self.args.data or 'Solar' in self.args.data:
244 | batch_x_mark = None
245 | batch_y_mark = None
246 | else:
247 | batch_x_mark = batch_x_mark.float().to(self.device)
248 | batch_y_mark = batch_y_mark.float().to(self.device)
249 |
250 |
251 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
252 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
253 | # encoder - decoder
254 | if self.args.use_amp:
255 | with torch.cuda.amp.autocast():
256 | if self.args.output_attention:
257 |                             outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
258 | else:
259 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
260 | else:
261 | if self.args.output_attention:
262 |                         outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
263 | else:
264 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
265 |
266 | f_dim = -1 if self.args.features == 'MS' else 0
267 | outputs = outputs[:, -self.args.pred_len:, f_dim:]
268 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
269 | outputs = outputs.detach().cpu().numpy()
270 | batch_y = batch_y.detach().cpu().numpy()
271 |
272 | pred = outputs
273 | true = batch_y
274 |
275 | preds.append(pred)
276 | trues.append(true)
277 | if i % 20 == 0:
278 | input = batch_x.detach().cpu().numpy()
279 | if test_data.scale and self.args.inverse:
280 | shape = input.shape
281 | input = test_data.inverse_transform(input.squeeze(0)).reshape(shape)
282 | gt = np.concatenate((input[0, :, -1], true[0, :, -1]), axis=0)
283 | pd = np.concatenate((input[0, :, -1], pred[0, :, -1]), axis=0)
284 | visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))
285 |
286 | preds = np.array(preds)
287 | trues = np.array(trues)
288 | print('test shape:', preds.shape, trues.shape)
289 | preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
290 | trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
291 | print('test shape:', preds.shape, trues.shape)
292 |
293 | if self.args.data == 'PEMS':
294 | B, T, C = preds.shape
295 | preds = test_data.inverse_transform(preds.reshape(-1, C)).reshape(B, T, C)
296 | trues = test_data.inverse_transform(trues.reshape(-1, C)).reshape(B, T, C)
297 |
298 | # result save
299 | folder_path = './checkpoints/' + setting + '/'
300 | if not os.path.exists(folder_path):
301 | os.makedirs(folder_path)
302 |
303 | mae, mse, rmse, mape, mspe = metric(preds, trues)
304 | print('mse:{}, mae:{}'.format(mse, mae))
305 | print('rmse:{}, mape:{}, mspe:{}'.format(rmse, mape, mspe))
306 | f = open("result_long_term_forecast.txt", 'a')
307 | f.write(setting + " \n")
308 | if self.args.data == 'PEMS':
309 | f.write('mae:{}, mape:{}, rmse:{}'.format(mae, mape, rmse))
310 | else:
311 | f.write('mse:{}, mae:{}'.format(mse, mae))
312 | f.write('\n')
313 | f.write('\n')
314 | f.close()
315 |
316 |
317 | return
318 |
319 |
320 | def predict(self, setting, load=False):
321 | pred_data, pred_loader = self._get_data(flag='pred')
322 |
323 | if load:
324 | path = os.path.join(self.args.checkpoints, setting)
325 | best_model_path = path + '/' + 'checkpoint.pth'
326 | self.model.load_state_dict(torch.load(best_model_path))
327 |
328 | preds = []
329 |
330 | self.model.eval()
331 | with torch.no_grad():
332 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(pred_loader):
333 | batch_x = batch_x.float().to(self.device)
334 | batch_y = batch_y.float()
335 | batch_x_mark = batch_x_mark.float().to(self.device)
336 | batch_y_mark = batch_y_mark.float().to(self.device)
337 |
338 |
339 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
340 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
341 | # encoder - decoder
342 | if self.args.use_amp:
343 | with torch.cuda.amp.autocast():
344 | if self.args.output_attention:
345 |                             outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
346 | else:
347 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
348 | else:
349 | if self.args.output_attention:
350 |                         outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
351 | else:
352 | outputs, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
353 | outputs = outputs.detach().cpu().numpy()
354 | if pred_data.scale and self.args.inverse:
355 | shape = outputs.shape
356 | outputs = pred_data.inverse_transform(outputs.squeeze(0)).reshape(shape)
357 | preds.append(outputs)
358 |
359 | preds = np.array(preds)
360 | preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
361 |
362 | # result save
363 | folder_path = './results/' + setting + '/'
364 | if not os.path.exists(folder_path):
365 | os.makedirs(folder_path)
366 |
367 | np.save(folder_path + 'real_prediction.npy', preds)
368 |
369 | return
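370 | 
371 | # Note (illustrative, not part of the original file): with --output_attention
372 | # left unset, GeomAttention returns the mean absolute (pre-softmax) attention
373 | # score of each layer, so the non-AMP training objective above is
374 | #     criterion(outputs, batch_y) + args.l1_weight * attn[0]
375 | # i.e. MSE plus an L1-style sparsity penalty on the first layer's scores.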
--------------------------------------------------------------------------------
/figures/Efficiency.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsingh-group/SimpleTM/3c77d820837b726afb03c943235ea95bc924243d/figures/Efficiency.jpg
--------------------------------------------------------------------------------
/figures/Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsingh-group/SimpleTM/3c77d820837b726afb03c943235ea95bc924243d/figures/Framework.png
--------------------------------------------------------------------------------
/figures/Long_term_forecast_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsingh-group/SimpleTM/3c77d820837b726afb03c943235ea95bc924243d/figures/Long_term_forecast_results.jpg
--------------------------------------------------------------------------------
/layers/Embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class DataEmbedding_inverted(nn.Module):
5 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
6 | super(DataEmbedding_inverted, self).__init__()
7 | self.value_embedding = nn.Linear(c_in, d_model)
8 | self.dropout = nn.Dropout(p=dropout)
9 |
10 | def forward(self, x, x_mark):
11 |         x = x.permute(0, 2, 1)  # (B, L, N) -> (B, N, L): each variate becomes a token
12 | if x_mark is None:
13 | x = self.value_embedding(x)
14 | else:
15 |             x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))  # time-mark covariates become extra tokens
16 | return self.dropout(x)
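17 | 
18 | 
19 | if __name__ == '__main__':
20 |     # Minimal shape check (illustrative sketch, not part of the original module):
21 |     # with seq_len L as c_in, each of the N variates (plus any time-mark
22 |     # covariates) becomes one d_model-dimensional token.
23 |     B, L, N, F_mark, d_model = 4, 96, 7, 4, 32
24 |     emb = DataEmbedding_inverted(c_in=L, d_model=d_model)
25 |     x = torch.randn(B, L, N)
26 |     x_mark = torch.randn(B, L, F_mark)
27 |     print(emb(x, None).shape)    # torch.Size([4, 7, 32])
28 |     print(emb(x, x_mark).shape)  # torch.Size([4, 11, 32])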
--------------------------------------------------------------------------------
/layers/SWTAttention_Family.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from math import sqrt
5 | import pywt
6 |
7 |
8 | class WaveletEmbedding(nn.Module):
9 | def __init__(self, d_channel=16, swt=True, requires_grad=False, wv='db2', m=2,
10 | kernel_size=None):
11 | super().__init__()
12 |
13 | self.swt = swt
14 | self.d_channel = d_channel
15 | self.m = m # Number of decomposition levels of detailed coefficients
16 |
17 | if kernel_size is None:
18 | self.wavelet = pywt.Wavelet(wv)
19 | if self.swt:
20 | h0 = torch.tensor(self.wavelet.dec_lo[::-1], dtype=torch.float32)
21 | h1 = torch.tensor(self.wavelet.dec_hi[::-1], dtype=torch.float32)
22 | else:
23 | h0 = torch.tensor(self.wavelet.rec_lo[::-1], dtype=torch.float32)
24 | h1 = torch.tensor(self.wavelet.rec_hi[::-1], dtype=torch.float32)
25 | self.h0 = nn.Parameter(torch.tile(h0[None, None, :], [self.d_channel, 1, 1]), requires_grad=requires_grad)
26 | self.h1 = nn.Parameter(torch.tile(h1[None, None, :], [self.d_channel, 1, 1]), requires_grad=requires_grad)
27 | self.kernel_size = self.h0.shape[-1]
28 | else:
29 | self.kernel_size = kernel_size
30 | self.h0 = nn.Parameter(torch.Tensor(self.d_channel, 1, self.kernel_size), requires_grad=requires_grad)
31 | self.h1 = nn.Parameter(torch.Tensor(self.d_channel, 1, self.kernel_size), requires_grad=requires_grad)
32 | nn.init.xavier_uniform_(self.h0)
33 | nn.init.xavier_uniform_(self.h1)
34 |
35 | with torch.no_grad():
36 | self.h0.data = self.h0.data / torch.norm(self.h0.data, dim=-1, keepdim=True)
37 | self.h1.data = self.h1.data / torch.norm(self.h1.data, dim=-1, keepdim=True)
38 |
39 |
40 | def forward(self, x):
41 | if self.swt:
42 | coeffs = self.swt_decomposition(x, self.h0, self.h1, self.m, self.kernel_size)
43 | else:
44 | coeffs = self.swt_reconstruction(x, self.h0, self.h1, self.m, self.kernel_size)
45 | return coeffs
46 |
47 | def swt_decomposition(self, x, h0, h1, depth, kernel_size):
48 | approx_coeffs = x
49 | coeffs = []
50 | dilation = 1
51 | for _ in range(depth):
52 | padding = dilation * (kernel_size - 1)
53 | padding_r = (kernel_size * dilation) // 2
54 | pad = (padding - padding_r, padding_r)
55 | approx_coeffs_pad = F.pad(approx_coeffs, pad, "circular")
56 | detail_coeff = F.conv1d(approx_coeffs_pad, h1, dilation=dilation, groups=x.shape[1])
57 | approx_coeffs = F.conv1d(approx_coeffs_pad, h0, dilation=dilation, groups=x.shape[1])
58 | coeffs.append(detail_coeff)
59 | dilation *= 2
60 | coeffs.append(approx_coeffs)
61 |
62 | return torch.stack(list(reversed(coeffs)), -2)
63 |
64 | def swt_reconstruction(self, coeffs, g0, g1, m, kernel_size):
65 | dilation = 2 ** (m - 1)
66 | approx_coeff = coeffs[:,:,0,:]
67 | detail_coeffs = coeffs[:,:,1:,:]
68 |
69 | for i in range(m):
70 | detail_coeff = detail_coeffs[:,:,i,:]
71 | padding = dilation * (kernel_size - 1)
72 | padding_l = (dilation * kernel_size) // 2
73 | pad = (padding_l, padding - padding_l)
74 | approx_coeff_pad = F.pad(approx_coeff, pad, "circular")
75 | detail_coeff_pad = F.pad(detail_coeff, pad, "circular")
76 |
77 | y = F.conv1d(approx_coeff_pad, g0, groups=approx_coeff.shape[1], dilation=dilation) + \
78 | F.conv1d(detail_coeff_pad, g1, groups=detail_coeff.shape[1], dilation=dilation)
79 | approx_coeff = y / 2
80 | dilation //= 2
81 |
82 | return approx_coeff
83 |
84 |
85 | class GeomAttentionLayer(nn.Module):
86 | def __init__(self, attention, d_model,
87 | requires_grad=True, wv='db2', m=2, kernel_size=None,
88 | d_channel=None, geomattn_dropout=0.5,):
89 | super(GeomAttentionLayer, self).__init__()
90 |
91 | self.d_channel = d_channel
92 | self.inner_attention = attention
93 |
94 | self.swt = WaveletEmbedding(d_channel=self.d_channel, swt=True, requires_grad=requires_grad, wv=wv, m=m, kernel_size=kernel_size)
95 | self.query_projection = nn.Sequential(
96 | nn.Linear(d_model, d_model),
97 | nn.Dropout(geomattn_dropout)
98 | )
99 | self.key_projection = nn.Sequential(
100 | nn.Linear(d_model, d_model),
101 | nn.Dropout(geomattn_dropout)
102 | )
103 | self.value_projection = nn.Sequential(
104 | nn.Linear(d_model, d_model),
105 | nn.Dropout(geomattn_dropout)
106 | )
107 | self.out_projection = nn.Sequential(
108 | nn.Linear(d_model, d_model),
109 | WaveletEmbedding(d_channel=self.d_channel, swt=False, requires_grad=requires_grad, wv=wv, m=m, kernel_size=kernel_size),
110 | )
111 |
112 | def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
113 | queries = self.swt(queries)
114 | keys = self.swt(keys)
115 | values = self.swt(values)
116 |
117 | queries = self.query_projection(queries).permute(0,3,2,1)
118 | keys = self.key_projection(keys).permute(0,3,2,1)
119 | values = self.value_projection(values).permute(0,3,2,1)
120 |
121 | out, attn = self.inner_attention(
122 | queries,
123 | keys,
124 | values,
125 | )
126 |
127 | out = self.out_projection(out.permute(0,3,2,1))
128 |
129 | return out, attn
130 |
131 |
132 | class GeomAttention(nn.Module):
133 | def __init__(self, mask_flag=False, factor=5, scale=None, attention_dropout=0.1,
134 | output_attention=False,
135 | alpha=1.,):
136 | super(GeomAttention, self).__init__()
137 | self.scale = scale
138 | self.mask_flag = mask_flag
139 | self.output_attention = output_attention
140 | self.dropout = nn.Dropout(attention_dropout)
141 |
142 | self.alpha = alpha
143 |
144 | def forward(self, queries, keys, values, attn_mask=None):
145 | B, L, H, E = queries.shape
146 | _, S, _, _ = values.shape
147 | scale = self.scale or 1. / sqrt(E)
148 |
149 | dot_product = torch.einsum("blhe,bshe->bhls", queries, keys)
150 |
151 | queries_norm2 = torch.sum(queries**2, dim=-1)
152 | keys_norm2 = torch.sum(keys**2, dim=-1)
153 | queries_norm2 = queries_norm2.permute(0, 2, 1).unsqueeze(-1) # (B, H, L, 1)
154 | keys_norm2 = keys_norm2.permute(0, 2, 1).unsqueeze(-2) # (B, H, 1, S)
155 | wedge_norm2 = queries_norm2 * keys_norm2 - dot_product ** 2 # (B, H, L, S)
156 | wedge_norm2 = F.relu(wedge_norm2)
157 | wedge_norm = torch.sqrt(wedge_norm2 + 1e-8)
158 |
159 | scores = (1 - self.alpha) * dot_product + self.alpha * wedge_norm
160 | scores = scores * scale
161 |
162 | if self.mask_flag:
163 | if attn_mask is None:
164 | attn_mask = torch.tril(torch.ones(L, S)).to(scores.device)
165 |             scores.masked_fill_(attn_mask.unsqueeze(0).unsqueeze(0) == 0, float('-inf'))  # broadcast the (L, S) mask as (1, 1, L, S) over (B, H, L, S)
166 |
167 | A = self.dropout(torch.softmax(scores, dim=-1))
168 |
169 | V = torch.einsum("bhls,bshd->blhd", A, values)
170 |
171 | if self.output_attention:
172 |             return (V.contiguous(), A)  # keep the (values, attention) tuple the caller unpacks
173 | else:
174 | return (V.contiguous(), scores.abs().mean())
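175 | 
176 | 
177 | if __name__ == '__main__':
178 |     # Illustrative smoke test (a sketch, not part of the original file).
179 |     # (1) SWT analysis followed by synthesis should approximately recover the
180 |     #     input for an orthogonal wavelet such as db2.
181 |     torch.manual_seed(0)
182 |     x = torch.randn(2, 7, 96)  # (batch, channels, length)
183 |     dec = WaveletEmbedding(d_channel=7, swt=True, wv='db2', m=2)
184 |     rec = WaveletEmbedding(d_channel=7, swt=False, wv='db2', m=2)
185 |     coeffs = dec(x)            # (batch, channels, m + 1, length)
186 |     print('round-trip error:', (rec(coeffs) - x).abs().max().item())
187 |     # (2) GeomAttention blends the dot product with the wedge-product norm.
188 |     attn = GeomAttention(alpha=0.5)
189 |     q = torch.randn(2, 8, 4, 16)  # (B, L, H, E)
190 |     v, score = attn(q, q, q)
191 |     print(v.shape, score.item())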
--------------------------------------------------------------------------------
/layers/StandardNorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Normalize(nn.Module):
5 | def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
6 | """
7 | :param num_features: the number of features or channels
8 | :param eps: a value added for numerical stability
9 | :param affine: if True, RevIN has learnable affine parameters
10 | """
11 | super(Normalize, self).__init__()
12 | self.num_features = num_features
13 | self.eps = eps
14 | self.affine = affine
15 | self.subtract_last = subtract_last
16 | self.non_norm = non_norm
17 | if self.affine:
18 | self._init_params()
19 |
20 | def forward(self, x, mode: str):
21 | if mode == 'norm':
22 | self._get_statistics(x)
23 | x = self._normalize(x)
24 | elif mode == 'denorm':
25 | x = self._denormalize(x)
26 | else:
27 | raise NotImplementedError
28 | return x
29 |
30 | def _init_params(self):
31 | self.affine_weight = nn.Parameter(torch.ones(self.num_features))
32 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
33 |
34 | def _get_statistics(self, x):
35 | dim2reduce = tuple(range(1, x.ndim - 1))
36 | if self.subtract_last:
37 | self.last = x[:, -1, :].unsqueeze(1)
38 | else:
39 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
40 |         self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()  # computed unconditionally: _normalize/_denormalize always use it
41 |
42 | def _normalize(self, x):
43 | if self.non_norm:
44 | return x
45 | if self.subtract_last:
46 | x = x - self.last
47 | else:
48 | x = x - self.mean
49 | x = x / self.stdev
50 | if self.affine:
51 | x = x * self.affine_weight
52 | x = x + self.affine_bias
53 | return x
54 |
55 | def _denormalize(self, x):
56 | if self.non_norm:
57 | return x
58 | if self.affine:
59 | x = x - self.affine_bias
60 | x = x / (self.affine_weight + self.eps * self.eps)
61 | x = x * self.stdev
62 | if self.subtract_last:
63 | x = x + self.last
64 | else:
65 | x = x + self.mean
66 | return x
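67 | 
68 | 
69 | if __name__ == '__main__':
70 |     # Round-trip sanity check (illustrative sketch, not part of the original
71 |     # module): 'norm' followed by 'denorm' should recover the input up to
72 |     # floating-point error.
73 |     x = torch.randn(4, 96, 7)  # (batch, time, features)
74 |     revin = Normalize(num_features=7)
75 |     y = revin(x, 'norm')
76 |     x_rec = revin(y, 'denorm')
77 |     print('max error:', (x - x_rec).abs().max().item())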
--------------------------------------------------------------------------------
/layers/Transformer_Encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class EncoderLayer(nn.Module):
6 |     def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
7 | super(EncoderLayer, self).__init__()
8 | d_ff = d_ff or 4 * d_model
9 | self.attention = attention
10 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
11 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
12 | self.norm1 = nn.LayerNorm(d_model)
13 | self.norm2 = nn.LayerNorm(d_model)
14 | self.dropout = nn.Dropout(dropout)
15 | self.activation = F.relu if activation == "relu" else F.gelu
16 |
17 | def forward(self, x, attn_mask=None, tau=None, delta=None):
18 | new_x, attn = self.attention(
19 | x, x, x,
20 | attn_mask=attn_mask,
21 | tau=tau, delta=delta
22 | )
23 | x = x + self.dropout(new_x)
24 | y = x = self.norm1(x)
25 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
26 | y = self.dropout(self.conv2(y).transpose(-1, 1))
27 | return self.norm2(x + y), attn
28 |
29 |
30 | class Encoder(nn.Module):
31 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
32 | super(Encoder, self).__init__()
33 | self.attn_layers = nn.ModuleList(attn_layers)
34 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
35 | self.norm = norm_layer
36 |
37 | def forward(self, x, attn_mask=None, tau=None, delta=None):
38 | # x [B, L, D]
39 | attns = []
40 | if self.conv_layers is not None:
41 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
42 | delta = delta if i == 0 else None
43 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
44 | x = conv_layer(x)
45 | attns.append(attn)
46 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
47 | attns.append(attn)
48 | else:
49 | for attn_layer in self.attn_layers:
50 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
51 | attns.append(attn)
52 |
53 | if self.norm is not None:
54 | x = self.norm(x)
55 |
56 | return x, attns
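57 | 
58 | 
59 | if __name__ == '__main__':
60 |     # Shape-only smoke test (illustrative sketch, not part of the original
61 |     # file): drive the encoder with a trivial attention stub, since the real
62 |     # GeomAttentionLayer lives in layers/SWTAttention_Family.py.
63 |     import torch
64 | 
65 |     class _IdentityAttention(nn.Module):
66 |         def forward(self, q, k, v, attn_mask=None, tau=None, delta=None):
67 |             return v, None
68 | 
69 |     enc = Encoder([EncoderLayer(_IdentityAttention(), d_model=32, d_ff=64)],
70 |                   norm_layer=nn.LayerNorm(32))
71 |     x = torch.randn(8, 7, 32)  # (batch, tokens, d_model)
72 |     y, attns = enc(x)
73 |     print(y.shape)  # torch.Size([8, 7, 32])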
--------------------------------------------------------------------------------
/model/SimpleTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_Encoder import Encoder, EncoderLayer
5 | from layers.SWTAttention_Family import GeomAttentionLayer, GeomAttention
6 | from layers.Embed import DataEmbedding_inverted
7 |
8 |
9 | class Model(nn.Module):
10 | def __init__(self, configs):
11 | super(Model, self).__init__()
12 | self.seq_len = configs.seq_len
13 | self.pred_len = configs.pred_len
14 | self.output_attention = configs.output_attention
15 | self.use_norm = configs.use_norm
16 | self.geomattn_dropout = configs.geomattn_dropout
17 | self.alpha = configs.alpha
18 | self.kernel_size = configs.kernel_size
19 |
20 | enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model,
21 | configs.embed, configs.freq, configs.dropout)
22 | self.enc_embedding = enc_embedding
23 |
24 | encoder = Encoder(
25 | [
26 | EncoderLayer(
27 | GeomAttentionLayer(
28 | GeomAttention(
29 | False, configs.factor, attention_dropout=configs.dropout,
30 | output_attention=configs.output_attention, alpha=self.alpha
31 | ),
32 | configs.d_model,
33 | requires_grad=configs.requires_grad,
34 | wv=configs.wv,
35 | m=configs.m,
36 | d_channel=configs.dec_in,
37 | kernel_size=self.kernel_size,
38 | geomattn_dropout=self.geomattn_dropout
39 | ),
40 | configs.d_model,
41 | configs.d_ff,
42 | dropout=configs.dropout,
43 | activation=configs.activation,
44 | ) for l in range(configs.e_layers)
45 | ],
46 | norm_layer=torch.nn.LayerNorm(configs.d_model)
47 | )
48 | self.encoder = encoder
49 |
50 | projector = nn.Linear(configs.d_model, self.pred_len, bias=True)
51 | self.projector = projector
52 |
53 |
54 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
55 | if self.use_norm:
56 | means = x_enc.mean(1, keepdim=True).detach()
57 | x_enc = x_enc - means
58 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
59 | # x_enc /= stdev
60 | x_enc = x_enc / stdev
61 |
62 | _, _, N = x_enc.shape
63 |
64 | enc_embedding = self.enc_embedding
65 | encoder = self.encoder
66 | projector = self.projector
67 | # Linear Projection B L N -> B L' (pseudo temporal tokens) N
68 | enc_out = enc_embedding(x_enc, x_mark_enc)
69 |
70 | # SimpleTM Layer B L' N -> B L' N
71 | enc_out, attns = encoder(enc_out, attn_mask=None)
72 |
73 | # Output Projection B L' N -> B H (Horizon) N
74 | dec_out = projector(enc_out).permute(0, 2, 1)[:, :, :N]
75 |
76 | if self.use_norm:
77 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
78 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
79 |
80 | return dec_out, attns
81 |
82 |
83 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
84 |         dec_out, attns = self.forecast(x_enc, None, None, None)  # time marks and decoder inputs are not used by this model
85 | return dec_out, attns
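86 | 
87 | 
88 | if __name__ == '__main__':
89 |     # Construction sketch (illustrative, not part of the original file): the
90 |     # config fields the model actually reads, wired through a SimpleNamespace;
91 |     # the values mirror the ETT scripts.
92 |     from types import SimpleNamespace
93 |     cfg = SimpleNamespace(
94 |         seq_len=96, pred_len=96, output_attention=False, use_norm=1,
95 |         geomattn_dropout=0.5, alpha=1.0, kernel_size=None,
96 |         d_model=32, d_ff=32, e_layers=1, embed='timeF', freq='h',
97 |         dropout=0.1, factor=1, requires_grad=True, wv='db1', m=3,
98 |         dec_in=7, activation='gelu',
99 |     )
100 |     model = Model(cfg)
101 |     x = torch.randn(4, cfg.seq_len, cfg.dec_in)
102 |     out, attns = model(x, None, None, None)
103 |     print(out.shape)  # torch.Size([4, 96, 7])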
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from experiments.exp_long_term_forecasting import Exp_Long_Term_Forecast
4 | import random
5 | import numpy as np
6 | from model.SimpleTM import Model
7 |
8 | if __name__ == '__main__':
9 |
10 |     parser = argparse.ArgumentParser(description='SimpleTM')
11 |
12 | # basic config
13 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status')
14 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
15 |
16 | parser.add_argument('--model', type=str, required=True, default='SimpleTM',
17 | help='model name, options: [SimpleTM]')
18 |
19 | # data loader
20 | parser.add_argument('--data', type=str, required=True, default='custom', help='dataset type')
21 | parser.add_argument('--root_path', type=str, default='./data/electricity/', help='root path of the data file')
22 | parser.add_argument('--data_path', type=str, default='electricity.csv', help='data csv file')
23 | parser.add_argument('--features', type=str, default='M',
24 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
25 | parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
26 | parser.add_argument('--freq', type=str, default='h',
27 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
28 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
29 |
30 | # forecasting task
31 | parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
32 | parser.add_argument('--label_len', type=int, default=0, help='start token length')
33 | parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
34 |
35 | # model define
36 | parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
37 | parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
38 | parser.add_argument('--c_out', type=int, default=7, help='output size')
39 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
40 | parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
41 | parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
42 | parser.add_argument('--factor', type=int, default=1, help='attn factor')
43 | parser.add_argument('--distil', action='store_false',
44 | help='whether to use distilling in encoder, using this argument means not using distilling',
45 | default=True)
46 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
47 | parser.add_argument('--geomattn_dropout', type=float, default=0.5, help='dropout rate of the projection layer in the geometric attention')
48 | parser.add_argument('--embed', type=str, default='timeF',
49 | help='time features encoding, options:[timeF, fixed, learned]')
50 | parser.add_argument('--activation', type=str, default='gelu', help='activation')
51 | parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')
52 |
53 | # optimization
54 | parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
55 | parser.add_argument('--itr', type=int, default=1, help='experiments times')
56 | parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
57 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
58 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
59 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
60 | parser.add_argument('--des', type=str, default='test', help='exp description')
61 | parser.add_argument('--loss', type=str, default='MSE', help='loss function')
62 | parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
63 | parser.add_argument('--pct_start', type=float, default=0.2, help='Warmup ratio for the learning rate scheduler')
64 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
65 |
66 | # GPU
67 | parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
68 | parser.add_argument('--gpu', type=int, default=0, help='gpu')
69 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
70 |     parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multiple gpus')
71 |
72 | parser.add_argument('--exp_name', type=str, required=False, default='MTSF',
73 |                         help='experiment name, options:[MTSF, partial_train]')
74 | parser.add_argument('--channel_independence', type=bool, default=False, help='whether to use channel_independence mechanism')
75 | parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)
76 | parser.add_argument('--class_strategy', type=str, default='projection', help='projection/average/cls_token')
77 | parser.add_argument('--target_root_path', type=str, default='./data/electricity/', help='root path of the data file')
78 | parser.add_argument('--target_data_path', type=str, default='electricity.csv', help='data file')
79 |     parser.add_argument('--efficient_training', type=bool, default=False, help='whether to use efficient_training (exp_name should be partial_train)')  # See Figure 8 of our paper for details
80 |     parser.add_argument('--use_norm', type=int, default=1, help='use norm and denorm')
81 | parser.add_argument('--partial_start_index', type=int, default=0, help='the start index of variates for partial training, '
82 | 'you can select [partial_start_index, min(enc_in + partial_start_index, N)]')
83 |
84 | # SimpleTM Arguments
85 |     parser.add_argument('--requires_grad', type=lambda s: s.lower() in ('true', '1'), default=True, help='Set to True to enable learnable wavelets')  # plain type=bool would parse any non-empty string as True
86 | parser.add_argument('--wv', type=str, default='db1', help='Wavelet filter type. Supports all wavelets available in PyTorch Wavelets')
87 | parser.add_argument('--m', type=int, default=3, help='Number of levels for the stationary wavelet transform')
88 | parser.add_argument('--kernel_size', default=None, help='Specify the length of randomly initialized wavelets (if not None)')
89 | parser.add_argument('--alpha', type=float, default=1, help='Weight of the inner product score in geometric attention')
90 | parser.add_argument('--l1_weight', type=float, default=5e-5, help='Weight of L1 loss')
91 | parser.add_argument('--d_model', type=int, default=32, help='Dimensionality of pseudo tokens')
92 | parser.add_argument('--d_ff', type=int, default=32, help='Dimensionality of the feedforward network')
93 | parser.add_argument('--e_layers', type=int, default=1, help='Number of SimpleTM layers')
94 |     parser.add_argument('--compile', type=lambda s: s.lower() in ('true', '1'), default=False, help='Set to True to enable compilation, which can accelerate speed but may slightly impact performance')
95 |     parser.add_argument('--output_attention', action='store_true', help='output attention weights; leave unset so the encoder returns the mean attention scores used by the L1 training loss')
96 |
97 |     parser.add_argument('--fix_seed', type=int, default=2025, help='random seed')
98 |
99 | args = parser.parse_args()
100 | args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
101 |
102 | fix_seed = args.fix_seed
103 | random.seed(fix_seed)
104 | torch.manual_seed(fix_seed)
105 | np.random.seed(fix_seed)
106 |
107 | if args.use_gpu and args.use_multi_gpu:
108 | args.devices = args.devices.replace(' ', '')
109 | device_ids = args.devices.split(',')
110 | args.device_ids = [int(id_) for id_ in device_ids]
111 | args.gpu = args.device_ids[0]
112 |
113 | print('Args in experiment:')
114 | print(args)
115 |
116 | Exp = Exp_Long_Term_Forecast
117 |
118 | if args.is_training:
119 | for ii in range(args.itr):
120 | # setting record of experiments
121 |             setting = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(  # 18 fields: the 17-slot template silently dropped the run index ii
122 | args.model_id,
123 | args.data,
124 | args.seq_len,
125 | args.pred_len,
126 | args.d_model,
127 | args.d_ff,
128 | args.e_layers,
129 | args.wv,
130 | args.kernel_size,
131 | args.m,
132 | args.alpha,
133 | args.l1_weight,
134 | args.learning_rate,
135 | args.lradj,
136 | args.batch_size,
137 | args.fix_seed,
138 | args.use_norm,
139 | ii)
140 |
141 | exp = Exp(args) # set experiments
142 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
143 | exp.train(setting)
144 |
145 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
146 | exp.test(setting)
147 |
148 | if args.do_predict:
149 | print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
150 | exp.predict(setting, True)
151 |
152 | torch.cuda.empty_cache()
153 | else:
154 |
155 | ii = 0
156 |         setting = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(  # mirror the training-time setting string (model_id was missing here)
157 |             args.model_id, args.data,
158 | args.seq_len,
159 | args.pred_len,
160 | args.d_model,
161 | args.d_ff,
162 | args.e_layers,
163 | args.wv,
164 | args.kernel_size,
165 | args.m,
166 | args.alpha,
167 | args.l1_weight,
168 | args.learning_rate,
169 | args.lradj,
170 | args.batch_size,
171 | args.fix_seed,
172 | args.use_norm,
173 | ii)
174 |
175 | exp = Exp(args) # set experiments
176 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
177 | exp.test(setting, test=1)
178 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ECL/SimpleTM.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 3 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL \
11 | --model "$model_name" \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 1 \
17 | --d_model 256 \
18 | --d_ff 1024 \
19 | --learning_rate 0.01 \
20 | --batch_size 256 \
21 | --fix_seed 2025 \
22 | --use_norm 1 \
23 | --wv "db1" \
24 | --m 3 \
25 | --enc_in 321 \
26 | --dec_in 321 \
27 | --c_out 321 \
28 | --des 'Exp' \
29 | --itr 3 \
30 | --alpha 0.0 \
31 | --l1_weight 0.0
32 |
33 | python -u run.py \
34 | --is_training 1 \
35 | --lradj 'TST' \
36 | --patience 3 \
37 | --root_path ./dataset/electricity/ \
38 | --data_path electricity.csv \
39 | --model_id ECL \
40 | --model "$model_name" \
41 | --data custom \
42 | --features M \
43 | --seq_len 96 \
44 | --pred_len 192 \
45 | --e_layers 1 \
46 | --d_model 256 \
47 | --d_ff 1024 \
48 | --learning_rate 0.006 \
49 | --batch_size 256 \
50 | --fix_seed 2025 \
51 | --use_norm 1 \
52 | --wv "db1" \
53 | --m 3 \
54 | --enc_in 321 \
55 | --dec_in 321 \
56 | --c_out 321 \
57 | --des 'Exp' \
58 | --itr 3 \
59 | --alpha 0.0 \
60 | --l1_weight 0.0
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --lradj 'TST' \
65 | --patience 3 \
66 | --root_path ./dataset/electricity/ \
67 | --data_path electricity.csv \
68 | --model_id ECL \
69 | --model "$model_name" \
70 | --data custom \
71 | --features M \
72 | --seq_len 96 \
73 | --pred_len 336 \
74 | --e_layers 1 \
75 | --d_model 256 \
76 | --d_ff 1024 \
77 | --learning_rate 0.006 \
78 | --batch_size 256 \
79 | --fix_seed 2025 \
80 | --use_norm 1 \
81 | --wv "db1" \
82 | --m 3 \
83 | --enc_in 321 \
84 | --dec_in 321 \
85 | --c_out 321 \
86 | --des 'Exp' \
87 | --itr 3 \
88 | --alpha 0.0 \
89 | --l1_weight 5e-5
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --lradj 'TST' \
94 | --patience 3 \
95 | --root_path ./dataset/electricity/ \
96 | --data_path electricity.csv \
97 | --model_id ECL \
98 | --model "$model_name" \
99 | --data custom \
100 | --features M \
101 | --seq_len 96 \
102 | --pred_len 720 \
103 | --e_layers 1 \
104 | --d_model 256 \
105 | --d_ff 1024 \
106 | --learning_rate 0.006 \
107 | --batch_size 256 \
108 | --fix_seed 2025 \
109 | --use_norm 1 \
110 | --wv "db1" \
111 | --m 3 \
112 | --enc_in 321 \
113 | --dec_in 321 \
114 | --c_out 321 \
115 | --des 'Exp' \
116 | --itr 3 \
117 | --alpha 0.0 \
118 | --l1_weight 5e-5
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ETT/SimpleTM_h1.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj TST \
7 | --patience 3 \
8 | --root_path ./dataset/ETT-small/ \
9 | --data_path ETTh1.csv \
10 | --model_id ETTh1 \
11 | --model $model_name \
12 | --data ETTh1 \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 1 \
17 | --d_model 32 \
18 | --d_ff 32 \
19 | --learning_rate 0.02 \
20 | --batch_size 256 \
21 | --fix_seed 2025 \
22 | --use_norm 1 \
23 | --wv db1 \
24 | --m 3 \
25 | --enc_in 7 \
26 | --dec_in 7 \
27 | --c_out 7 \
28 | --des Exp \
29 | --itr 3 \
30 | --alpha 0.3 \
31 | --l1_weight 0.0005 \
32 |
33 | python -u run.py \
34 | --is_training 1 \
35 | --lradj TST \
36 | --patience 3 \
37 | --root_path ./dataset/ETT-small/ \
38 | --data_path ETTh1.csv \
39 | --model_id ETTh1 \
40 | --model $model_name \
41 | --data ETTh1 \
42 | --features M \
43 | --seq_len 96 \
44 | --pred_len 192 \
45 | --e_layers 1 \
46 | --d_model 32 \
47 | --d_ff 32 \
48 | --learning_rate 0.02 \
49 | --batch_size 256 \
50 | --fix_seed 2025 \
51 | --use_norm 1 \
52 | --wv db1 \
53 | --m 3 \
54 | --enc_in 7 \
55 | --dec_in 7 \
56 | --c_out 7 \
57 | --des Exp \
58 | --itr 3 \
59 | --alpha 1.0 \
60 | --l1_weight 5e-05 \
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --lradj TST \
65 | --patience 3 \
66 | --root_path ./dataset/ETT-small/ \
67 | --data_path ETTh1.csv \
68 | --model_id ETTh1 \
69 | --model $model_name \
70 | --data ETTh1 \
71 | --features M \
72 | --seq_len 96 \
73 | --pred_len 336 \
74 | --e_layers 4 \
75 | --d_model 64 \
76 | --d_ff 64 \
77 | --learning_rate 0.002 \
78 | --batch_size 256 \
79 | --fix_seed 2025 \
80 | --use_norm 1 \
81 | --wv db1 \
82 | --m 3 \
83 | --enc_in 7 \
84 | --dec_in 7 \
85 | --c_out 7 \
86 | --des Exp \
87 | --itr 3 \
88 | --alpha 0.0 \
89 | --l1_weight 0.0 \
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --lradj TST \
94 | --patience 3 \
95 | --root_path ./dataset/ETT-small/ \
96 | --data_path ETTh1.csv \
97 | --model_id ETTh1 \
98 | --model $model_name \
99 | --data ETTh1 \
100 | --features M \
101 | --seq_len 96 \
102 | --pred_len 720 \
103 | --e_layers 1 \
104 | --d_model 32 \
105 | --d_ff 32 \
106 | --learning_rate 0.009 \
107 | --batch_size 256 \
108 | --fix_seed 2025 \
109 | --use_norm 1 \
110 | --wv db1 \
111 | --m 1 \
112 | --enc_in 7 \
113 | --dec_in 7 \
114 | --c_out 7 \
115 | --des Exp \
116 | --itr 3 \
117 | --alpha 0.9 \
118 | --l1_weight 0.0005 \
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ETT/SimpleTM_h2.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj TST \
7 | --patience 3 \
8 | --root_path ./dataset/ETT-small/ \
9 | --data_path ETTh2.csv \
10 | --model_id ETTh2 \
11 | --model $model_name \
12 | --data ETTh2 \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 1 \
17 | --d_model 32 \
18 | --d_ff 32 \
19 | --learning_rate 0.006 \
20 | --batch_size 256 \
21 | --fix_seed 2025 \
22 | --use_norm 1 \
23 | --wv bior3.1 \
24 | --m 1 \
25 | --enc_in 7 \
26 | --dec_in 7 \
27 | --c_out 7 \
28 | --des Exp \
29 | --itr 3 \
30 | --alpha 0.1 \
31 | --l1_weight 0.0005 \
32 |
33 | python -u run.py \
34 | --is_training 1 \
35 | --lradj TST \
36 | --patience 3 \
37 | --root_path ./dataset/ETT-small/ \
38 | --data_path ETTh2.csv \
39 | --model_id ETTh2 \
40 | --model $model_name \
41 | --data ETTh2 \
42 | --features M \
43 | --seq_len 96 \
44 | --pred_len 192 \
45 | --e_layers 1 \
46 | --d_model 32 \
47 | --d_ff 32 \
48 | --learning_rate 0.006 \
49 | --batch_size 256 \
50 | --fix_seed 2025 \
51 | --use_norm 1 \
52 | --wv db1 \
53 | --m 1 \
54 | --enc_in 7 \
55 | --dec_in 7 \
56 | --c_out 7 \
57 | --des Exp \
58 | --itr 3 \
59 | --alpha 0.1 \
60 | --l1_weight 0.005 \
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --lradj TST \
65 | --patience 3 \
66 | --root_path ./dataset/ETT-small/ \
67 | --data_path ETTh2.csv \
68 | --model_id ETTh2 \
69 | --model $model_name \
70 | --data ETTh2 \
71 | --features M \
72 | --seq_len 96 \
73 | --pred_len 336 \
74 | --e_layers 1 \
75 | --d_model 32 \
76 | --d_ff 32 \
77 | --learning_rate 0.003 \
78 | --batch_size 256 \
79 | --fix_seed 2025 \
80 | --use_norm 1 \
81 | --wv db1 \
82 | --m 1 \
83 | --enc_in 7 \
84 | --dec_in 7 \
85 | --c_out 7 \
86 | --des Exp \
87 | --itr 3 \
88 | --alpha 0.9 \
89 | --l1_weight 0.0 \
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --lradj TST \
94 | --patience 3 \
95 | --root_path ./dataset/ETT-small/ \
96 | --data_path ETTh2.csv \
97 | --model_id ETTh2 \
98 | --model $model_name \
99 | --data ETTh2 \
100 | --features M \
101 | --seq_len 96 \
102 | --pred_len 720 \
103 | --e_layers 1 \
104 | --d_model 32 \
105 | --d_ff 32 \
106 | --learning_rate 0.003 \
107 | --batch_size 256 \
108 | --fix_seed 2025 \
109 | --use_norm 1 \
110 | --wv db1 \
111 | --m 1 \
112 | --enc_in 7 \
113 | --dec_in 7 \
114 | --c_out 7 \
115 | --des Exp \
116 | --itr 3 \
117 | --alpha 1.0 \
118 | --l1_weight 5e-05 \
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ETT/SimpleTM_m1.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 3 \
8 | --root_path ./dataset/ETT-small/ \
9 | --data_path ETTm1.csv \
10 | --model_id ETTm1 \
11 | --model "$model_name" \
12 | --data ETTm1 \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 1 \
17 | --d_model 32 \
18 | --d_ff 32 \
19 | --learning_rate 0.02 \
20 | --batch_size 256 \
21 | --fix_seed 2025 \
22 | --use_norm 1 \
23 | --wv 'db1' \
24 | --m 3 \
25 | --enc_in 7 \
26 | --dec_in 7 \
27 | --c_out 7 \
28 | --des 'Exp' \
29 | --itr 3 \
30 | --alpha 0.1 \
31 | --l1_weight 0.005
32 |
33 | python -u run.py \
34 | --is_training 1 \
35 | --lradj 'TST' \
36 | --patience 3 \
37 | --root_path ./dataset/ETT-small/ \
38 | --data_path ETTm1.csv \
39 | --model_id ETTm1 \
40 | --model "$model_name" \
41 | --data ETTm1 \
42 | --features M \
43 | --seq_len 96 \
44 | --pred_len 192 \
45 | --e_layers 1 \
46 | --d_model 32 \
47 | --d_ff 32 \
48 | --learning_rate 0.02 \
49 | --batch_size 256 \
50 | --fix_seed 2025 \
51 | --use_norm 1 \
52 | --wv 'db1' \
53 | --m 3 \
54 | --enc_in 7 \
55 | --dec_in 7 \
56 | --c_out 7 \
57 | --des 'Exp' \
58 | --itr 3 \
59 | --alpha 0.1 \
60 | --l1_weight 0.005
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --lradj 'TST' \
65 | --patience 3 \
66 | --root_path ./dataset/ETT-small/ \
67 | --data_path ETTm1.csv \
68 | --model_id ETTm1 \
69 | --model "$model_name" \
70 | --data ETTm1 \
71 | --features M \
72 | --seq_len 96 \
73 | --pred_len 336 \
74 | --e_layers 1 \
75 | --d_model 32 \
76 | --d_ff 32 \
77 | --learning_rate 0.02 \
78 | --batch_size 256 \
79 | --fix_seed 2025 \
80 | --use_norm 1 \
81 | --wv 'db1' \
82 | --m 1 \
83 | --enc_in 7 \
84 | --dec_in 7 \
85 | --c_out 7 \
86 | --des 'Exp' \
87 | --itr 3 \
88 | --alpha 0.1 \
89 | --l1_weight 0.005
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --lradj 'TST' \
94 | --patience 3 \
95 | --root_path ./dataset/ETT-small/ \
96 | --data_path ETTm1.csv \
97 | --model_id ETTm1 \
98 | --model "$model_name" \
99 | --data ETTm1 \
100 | --features M \
101 | --seq_len 96 \
102 | --pred_len 720 \
103 | --e_layers 1 \
104 | --d_model 32 \
105 | --d_ff 32 \
106 | --learning_rate 0.02 \
107 | --batch_size 256 \
108 | --fix_seed 2025 \
109 | --use_norm 1 \
110 | --wv 'db1' \
111 | --m 3 \
112 | --enc_in 7 \
113 | --dec_in 7 \
114 | --c_out 7 \
115 | --des 'Exp' \
116 | --itr 3 \
117 | --alpha 0.1 \
118 | --l1_weight 0.005
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ETT/SimpleTM_m2.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 3 \
8 | --root_path ./dataset/ETT-small/ \
9 | --data_path ETTm2.csv \
10 | --model_id ETTm2 \
11 | --model "$model_name" \
12 | --data ETTm2 \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 1 \
17 | --d_model 32 \
18 | --d_ff 32 \
19 |   --learning_rate 0.006 \
20 | --batch_size 256 \
21 | --fix_seed 2025 \
22 | --use_norm 1 \
23 | --wv "bior3.1" \
24 | --m 3 \
25 | --enc_in 7 \
26 | --dec_in 7 \
27 | --c_out 7 \
28 | --des 'Exp' \
29 | --itr 3 \
30 | --alpha 0.3 \
31 | --l1_weight 0.0005
32 |
33 | python -u run.py \
34 | --is_training 1 \
35 | --lradj 'TST' \
36 | --patience 3 \
37 | --root_path ./dataset/ETT-small/ \
38 | --data_path ETTm2.csv \
39 | --model_id ETTm2 \
40 | --model "$model_name" \
41 | --data ETTm2 \
42 | --features M \
43 | --seq_len 96 \
44 | --pred_len 192 \
45 | --e_layers 1 \
46 | --d_model 32 \
47 | --d_ff 32 \
48 |   --learning_rate 0.006 \
49 | --batch_size 256 \
50 | --fix_seed 2025 \
51 | --use_norm 1 \
52 | --wv "bior3.1" \
53 | --m 1 \
54 | --enc_in 7 \
55 | --dec_in 7 \
56 | --c_out 7 \
57 | --des 'Exp' \
58 | --itr 3 \
59 | --alpha 0.0 \
60 | --l1_weight 0.005
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --lradj 'TST' \
65 | --patience 3 \
66 | --root_path ./dataset/ETT-small/ \
67 | --data_path ETTm2.csv \
68 | --model_id ETTm2 \
69 | --model "$model_name" \
70 | --data ETTm2 \
71 | --features M \
72 | --seq_len 96 \
73 | --pred_len 336 \
74 | --e_layers 1 \
75 | --d_model 64 \
76 | --d_ff 64 \
77 |   --learning_rate 0.006 \
78 | --batch_size 128 \
79 | --fix_seed 2025 \
80 | --use_norm 1 \
81 | --wv "bior3.3" \
82 | --m 1 \
83 | --enc_in 7 \
84 | --dec_in 7 \
85 | --c_out 7 \
86 | --des 'Exp' \
87 | --itr 3 \
88 | --alpha 0.6 \
89 | --l1_weight 5e-5
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --lradj 'TST' \
94 | --patience 3 \
95 | --root_path ./dataset/ETT-small/ \
96 | --data_path ETTm2.csv \
97 | --model_id ETTm2 \
98 | --model "$model_name" \
99 | --data ETTm2 \
100 | --features M \
101 | --seq_len 96 \
102 | --pred_len 720 \
103 | --e_layers 1 \
104 | --d_model 96 \
105 | --d_ff 96 \
106 |   --learning_rate 0.003 \
107 | --batch_size 256 \
108 | --fix_seed 2025 \
109 | --use_norm 1 \
110 | --wv "db1" \
111 | --m 3 \
112 | --enc_in 7 \
113 | --dec_in 7 \
114 | --c_out 7 \
115 | --des 'Exp' \
116 | --itr 3 \
117 | --alpha 1.0 \
118 | --l1_weight 0.0
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/PEMS/SimpleTM_03.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 10 \
8 | --train_epochs 20 \
9 | --root_path ./dataset/PEMS/ \
10 | --data_path PEMS03.npz \
11 | --model_id PEMS03 \
12 | --model "$model_name" \
13 | --data PEMS \
14 | --features M \
15 | --seq_len 96 \
16 | --pred_len 12 \
17 | --e_layers 1 \
18 | --d_model 256 \
19 | --d_ff 512 \
20 | --learning_rate 0.002 \
21 | --batch_size 16 \
22 | --fix_seed 2025 \
23 |   --use_norm 0 \
24 | --wv 'bior3.1' \
25 | --m 3 \
26 | --enc_in 358 \
27 | --dec_in 358 \
28 | --c_out 358 \
29 | --des 'Exp' \
30 | --itr 3 \
31 | --alpha 0.1 \
32 |   --l1_weight 0.005  # duplicate --use_norm flag removed; the later value (0) was the effective one
33 | 
34 |
35 | python -u run.py \
36 | --is_training 1 \
37 | --lradj 'TST' \
38 | --patience 10 \
39 | --train_epochs 20 \
40 | --root_path ./dataset/PEMS/ \
41 | --data_path PEMS03.npz \
42 | --model_id PEMS03 \
43 | --model "$model_name" \
44 | --data PEMS \
45 | --features M \
46 | --seq_len 96 \
47 | --pred_len 24 \
48 | --e_layers 1 \
49 | --d_model 256 \
50 | --d_ff 512 \
51 | --learning_rate 0.002 \
52 | --batch_size 16 \
53 | --fix_seed 2025 \
54 |   --use_norm 0 \
55 | --wv 'bior3.1' \
56 | --m 3 \
57 | --enc_in 358 \
58 | --dec_in 358 \
59 | --c_out 358 \
60 | --des 'Exp' \
61 | --itr 3 \
62 | --alpha 0.1 \
63 |   --l1_weight 0.005  # duplicate --use_norm flag removed; the later value (0) was the effective one
64 | 
65 |
66 | python -u run.py \
67 | --is_training 1 \
68 | --lradj 'TST' \
69 | --patience 10 \
70 | --train_epochs 20 \
71 | --root_path ./dataset/PEMS/ \
72 | --data_path PEMS03.npz \
73 | --model_id PEMS03 \
74 | --model "$model_name" \
75 | --data PEMS \
76 | --features M \
77 | --seq_len 96 \
78 | --pred_len 48 \
79 | --e_layers 1 \
80 | --d_model 256 \
81 | --d_ff 1024 \
82 | --learning_rate 0.002 \
83 | --batch_size 16 \
84 | --fix_seed 2025 \
85 | --use_norm 0 \
86 | --wv 'bior3.1' \
87 | --m 3 \
88 | --enc_in 358 \
89 | --dec_in 358 \
90 | --c_out 358 \
91 | --des 'Exp' \
92 | --itr 3 \
93 | --alpha 0.1 \
94 | --l1_weight 0.005
95 |
96 | python -u run.py \
97 | --is_training 1 \
98 | --lradj 'TST' \
99 | --patience 10 \
100 | --train_epochs 20 \
101 | --root_path ./dataset/PEMS/ \
102 | --data_path PEMS03.npz \
103 | --model_id PEMS03 \
104 | --model "$model_name" \
105 | --data PEMS \
106 | --features M \
107 | --seq_len 96 \
108 | --pred_len 96 \
109 | --e_layers 1 \
110 | --d_model 256 \
111 | --d_ff 1024 \
112 | --learning_rate 0.002 \
113 | --batch_size 16 \
114 | --fix_seed 2025 \
115 | --use_norm 0 \
116 | --wv 'bior3.1' \
117 | --m 3 \
118 | --enc_in 358 \
119 | --dec_in 358 \
120 | --c_out 358 \
121 | --des 'Exp' \
122 | --itr 3 \
123 | --alpha 0.1 \
124 | --l1_weight 0.005
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/PEMS/SimpleTM_04.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 10 \
8 | --train_epochs 20 \
9 | --root_path ./dataset/PEMS/ \
10 | --data_path PEMS04.npz \
11 | --model_id PEMS04 \
12 | --model "$model_name" \
13 | --data PEMS \
14 | --features M \
15 | --seq_len 96 \
16 | --pred_len 12 \
17 | --e_layers 2 \
18 | --d_model 256 \
19 | --d_ff 1024 \
20 | --learning_rate 0.002 \
21 | --batch_size 16 \
22 | --fix_seed 2025 \
23 | --use_norm 0 \
24 | --wv 'bior3.1' \
25 | --m 3 \
26 | --enc_in 307 \
27 | --dec_in 307 \
28 | --c_out 307 \
29 | --des 'Exp' \
30 | --itr 3 \
31 | --alpha 0.1 \
32 | --l1_weight 5e-05
33 |
34 | python -u run.py \
35 | --is_training 1 \
36 | --lradj 'TST' \
37 | --patience 10 \
38 | --train_epochs 20 \
39 | --root_path ./dataset/PEMS/ \
40 | --data_path PEMS04.npz \
41 | --model_id PEMS04 \
42 | --model "$model_name" \
43 | --data PEMS \
44 | --features M \
45 | --seq_len 96 \
46 | --pred_len 24 \
47 | --e_layers 1 \
48 | --d_model 256 \
49 | --d_ff 1024 \
50 | --learning_rate 0.002 \
51 | --batch_size 16 \
52 | --fix_seed 2025 \
53 | --use_norm 0 \
54 | --wv 'bior3.1' \
55 | --m 3 \
56 | --enc_in 307 \
57 | --dec_in 307 \
58 | --c_out 307 \
59 | --des 'Exp' \
60 | --itr 3 \
61 | --alpha 0.1 \
62 | --l1_weight 5e-05
63 |
64 | python -u run.py \
65 | --is_training 1 \
66 | --lradj 'TST' \
67 | --patience 10 \
68 | --train_epochs 20 \
69 | --root_path ./dataset/PEMS/ \
70 | --data_path PEMS04.npz \
71 | --model_id PEMS04 \
72 | --model "$model_name" \
73 | --data PEMS \
74 | --features M \
75 | --seq_len 96 \
76 | --pred_len 48 \
77 | --e_layers 1 \
78 | --d_model 256 \
79 | --d_ff 1024 \
80 | --learning_rate 0.002 \
81 | --batch_size 16 \
82 | --fix_seed 2025 \
83 | --use_norm 0 \
84 | --wv 'bior3.1' \
85 | --m 3 \
86 | --enc_in 307 \
87 | --dec_in 307 \
88 | --c_out 307 \
89 | --des 'Exp' \
90 | --itr 3 \
91 | --alpha 0.1 \
92 | --l1_weight 5e-05
93 |
94 | python -u run.py \
95 | --is_training 1 \
96 | --lradj 'TST' \
97 | --patience 10 \
98 | --train_epochs 20 \
99 | --root_path ./dataset/PEMS/ \
100 | --data_path PEMS04.npz \
101 | --model_id PEMS04 \
102 | --model "$model_name" \
103 | --data PEMS \
104 | --features M \
105 | --seq_len 96 \
106 | --pred_len 96 \
107 | --e_layers 1 \
108 | --d_model 256 \
109 | --d_ff 1024 \
110 | --learning_rate 0.002 \
111 | --batch_size 16 \
112 | --fix_seed 2025 \
113 | --use_norm 0 \
114 | --wv 'bior3.1' \
115 | --m 3 \
116 | --enc_in 307 \
117 | --dec_in 307 \
118 | --c_out 307 \
119 | --des 'Exp' \
120 | --itr 3 \
121 | --alpha 0.1 \
122 | --l1_weight 5e-05
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/PEMS/SimpleTM_07.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 10 \
8 | --train_epochs 20 \
9 | --root_path ./dataset/PEMS/ \
10 | --data_path PEMS07.npz \
11 | --model_id PEMS07 \
12 | --model "$model_name" \
13 | --data PEMS \
14 | --features M \
15 | --seq_len 96 \
16 | --pred_len 12 \
17 | --e_layers 1 \
18 | --d_model 256 \
19 | --d_ff 512 \
20 | --learning_rate 0.002 \
21 | --batch_size 16 \
22 | --fix_seed 2025 \
23 | --use_norm 0 \
24 | --wv 'db1' \
25 | --m 3 \
26 | --enc_in 883 \
27 | --dec_in 883 \
28 | --c_out 883 \
29 | --des 'Exp' \
30 | --itr 3 \
31 | --alpha 0.1 \
32 | --l1_weight 5e-05
33 |
34 | python -u run.py \
35 | --is_training 1 \
36 | --lradj 'TST' \
37 | --patience 10 \
38 | --train_epochs 20 \
39 | --root_path ./dataset/PEMS/ \
40 | --data_path PEMS07.npz \
41 | --model_id PEMS07 \
42 | --model "$model_name" \
43 | --data PEMS \
44 | --features M \
45 | --seq_len 96 \
46 | --pred_len 24 \
47 | --e_layers 1 \
48 | --d_model 256 \
49 | --d_ff 512 \
50 | --learning_rate 0.002 \
51 | --batch_size 16 \
52 | --fix_seed 2025 \
53 | --use_norm 0 \
54 | --wv 'db1' \
55 | --m 3 \
56 | --enc_in 883 \
57 | --dec_in 883 \
58 | --c_out 883 \
59 | --des 'Exp' \
60 | --itr 3 \
61 | --alpha 0.1 \
62 | --l1_weight 5e-5
63 |
64 | python -u run.py \
65 | --is_training 1 \
66 | --lradj 'TST' \
67 | --patience 10 \
68 | --train_epochs 20 \
69 | --root_path ./dataset/PEMS/ \
70 | --data_path PEMS07.npz \
71 | --model_id PEMS07 \
72 | --model "$model_name" \
73 | --data PEMS \
74 | --features M \
75 | --seq_len 96 \
76 | --pred_len 48 \
77 | --e_layers 1 \
78 | --d_model 256 \
79 | --d_ff 512 \
80 | --learning_rate 0.002 \
81 | --batch_size 16 \
82 | --fix_seed 2025 \
83 | --use_norm 0 \
84 | --wv 'db1' \
85 | --m 3 \
86 | --enc_in 883 \
87 | --dec_in 883 \
88 | --c_out 883 \
89 | --des 'Exp' \
90 | --itr 3 \
91 | --alpha 0.1 \
92 | --l1_weight 5e-05
93 |
94 | python -u run.py \
95 | --is_training 1 \
96 | --lradj 'TST' \
97 | --patience 10 \
98 | --train_epochs 20 \
99 | --root_path ./dataset/PEMS/ \
100 | --data_path PEMS07.npz \
101 | --model_id PEMS07 \
102 | --model "$model_name" \
103 | --data PEMS \
104 | --features M \
105 | --seq_len 96 \
106 | --pred_len 96 \
107 | --e_layers 1 \
108 | --d_model 256 \
109 | --d_ff 512 \
110 | --learning_rate 0.002 \
111 | --batch_size 16 \
112 | --fix_seed 2025 \
113 | --use_norm 0 \
114 | --wv 'db1' \
115 | --m 3 \
116 | --enc_in 883 \
117 | --dec_in 883 \
118 | --c_out 883 \
119 | --des 'Exp' \
120 | --itr 3 \
121 | --alpha 0.1 \
122 | --l1_weight 5e-5
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/PEMS/SimpleTM_08.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 10 \
8 | --train_epochs 20 \
9 | --root_path ./dataset/PEMS/ \
10 | --data_path PEMS08.npz \
11 | --model_id PEMS08 \
12 | --model "$model_name" \
13 | --data PEMS \
14 | --features M \
15 | --seq_len 96 \
16 | --pred_len 12 \
17 | --e_layers 1 \
18 | --d_model 256 \
19 | --d_ff 512 \
20 | --learning_rate 0.001 \
21 | --batch_size 16 \
22 | --fix_seed 2025 \
23 | --use_norm 0 \
24 | --wv 'db12' \
25 | --m 3 \
26 | --enc_in 170 \
27 | --dec_in 170 \
28 | --c_out 170 \
29 | --des 'Exp' \
30 | --itr 3 \
31 | --alpha 0.0 \
32 | --l1_weight 0.0
33 |
34 | python -u run.py \
35 | --is_training 1 \
36 | --lradj 'TST' \
37 | --patience 10 \
38 | --train_epochs 20 \
39 | --root_path ./dataset/PEMS/ \
40 | --data_path PEMS08.npz \
41 | --model_id PEMS08 \
42 | --model "$model_name" \
43 | --data PEMS \
44 | --features M \
45 | --seq_len 96 \
46 | --pred_len 24 \
47 | --e_layers 1 \
48 | --d_model 256 \
49 | --d_ff 512 \
50 | --learning_rate 0.001 \
51 | --batch_size 16 \
52 | --fix_seed 2025 \
53 | --use_norm 0 \
54 | --wv 'db12' \
55 | --m 3 \
56 | --enc_in 170 \
57 | --dec_in 170 \
58 | --c_out 170 \
59 | --des 'Exp' \
60 | --itr 3 \
61 | --alpha 0.0 \
62 | --l1_weight 0.0
63 |
64 | python -u run.py \
65 | --is_training 1 \
66 | --lradj 'TST' \
67 | --patience 10 \
68 | --train_epochs 20 \
69 | --root_path ./dataset/PEMS/ \
70 | --data_path PEMS08.npz \
71 | --model_id PEMS08 \
72 | --model "$model_name" \
73 | --data PEMS \
74 | --features M \
75 | --seq_len 96 \
76 | --pred_len 48 \
77 | --e_layers 1 \
78 | --d_model 256 \
79 | --d_ff 512 \
80 | --learning_rate 0.001 \
81 | --batch_size 16 \
82 | --fix_seed 2025 \
83 | --use_norm 0 \
84 | --wv 'db12' \
85 | --m 3 \
86 | --enc_in 170 \
87 | --dec_in 170 \
88 | --c_out 170 \
89 | --des 'Exp' \
90 | --itr 3 \
91 | --alpha 0.0 \
92 | --l1_weight 0.0
93 |
94 | python -u run.py \
95 | --is_training 1 \
96 | --lradj 'TST' \
97 | --patience 10 \
98 | --train_epochs 20 \
99 | --root_path ./dataset/PEMS/ \
100 | --data_path PEMS08.npz \
101 | --model_id PEMS08 \
102 | --model "$model_name" \
103 | --data PEMS \
104 | --features M \
105 | --seq_len 96 \
106 | --pred_len 96 \
107 | --e_layers 1 \
108 | --d_model 256 \
109 | --d_ff 1024 \
110 | --learning_rate 0.001 \
111 | --batch_size 16 \
112 | --fix_seed 2025 \
113 | --use_norm 0 \
114 | --wv 'db12' \
115 | --m 3 \
116 | --enc_in 170 \
117 | --dec_in 170 \
118 | --c_out 170 \
119 | --des 'Exp' \
120 | --itr 3 \
121 | --alpha 0.0 \
122 | --l1_weight 0.0
--------------------------------------------------------------------------------
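A note on the `--wv` flag used throughout these scripts: it names a wavelet ('db1', 'db12', 'bior3.1', ...), and the repository pins PyWavelets in its Dockerfile. A minimal sketch, under the assumption that `run.py` forwards the name to PyWavelets, to confirm the names used in the PEMS scripts resolve and to inspect their filter properties:

```python
# Sanity-check the wavelet names passed via --wv in the PEMS scripts.
# Assumption: run.py forwards the name to PyWavelets (pinned in the
# Dockerfile); pywt.Wavelet raises ValueError on an unknown name.
import pywt

for name in ["db1", "db12", "bior3.1"]:
    w = pywt.Wavelet(name)
    print(f"{name}: dec_len={w.dec_len}, orthogonal={w.orthogonal}")
```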
/scripts/multivariate_forecasting/SolarEnergy/SimpleTM.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 3 \
8 | --root_path ./dataset/solar/ \
9 | --data_path solar_AL.txt \
10 | --model_id Solar \
11 | --model "$model_name" \
12 | --data Solar \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 1 \
17 | --d_model 128 \
18 | --d_ff 256 \
19 | --learning_rate 0.01 \
20 | --batch_size 256 \
21 | --fix_seed 2025 \
22 | --use_norm 0 \
23 | --wv "db8" \
24 | --m 3 \
25 | --enc_in 137 \
26 | --dec_in 137 \
27 | --c_out 137 \
28 | --des 'Exp' \
29 | --itr 3 \
31 | --alpha 0.0 \
32 | --l1_weight 0.005
33 |
34 | python -u run.py \
35 | --is_training 1 \
36 | --lradj 'TST' \
37 | --patience 3 \
38 | --root_path ./dataset/solar/ \
39 | --data_path solar_AL.txt \
40 | --model_id Solar \
41 | --model "$model_name" \
42 | --data Solar \
43 | --features M \
44 | --seq_len 96 \
45 | --pred_len 192 \
46 | --e_layers 1 \
47 | --d_model 128 \
48 | --d_ff 256 \
49 | --learning_rate 0.003 \
50 | --batch_size 256 \
51 | --fix_seed 2025 \
52 | --use_norm 0 \
53 | --wv "db8" \
54 | --m 1 \
55 | --enc_in 137 \
56 | --dec_in 137 \
57 | --c_out 137 \
58 | --des 'Exp' \
59 | --itr 2 \
61 | --alpha 0.0 \
62 | --l1_weight 0.005
63 |
64 | python -u run.py \
65 | --is_training 1 \
66 | --lradj 'TST' \
67 | --patience 3 \
68 | --root_path ./dataset/solar/ \
69 | --data_path solar_AL.txt \
70 | --model_id Solar \
71 | --model "$model_name" \
72 | --data Solar \
73 | --features M \
74 | --seq_len 96 \
75 | --pred_len 336 \
76 | --e_layers 1 \
77 | --d_model 128 \
78 | --d_ff 256 \
79 | --learning_rate 0.003 \
80 | --batch_size 256 \
81 | --fix_seed 2025 \
82 | --use_norm 0 \
83 | --wv "db8" \
84 | --m 1 \
85 | --enc_in 137 \
86 | --dec_in 137 \
87 | --c_out 137 \
88 | --des 'Exp' \
89 | --itr 2 \
91 | --alpha 0.1 \
92 | --l1_weight 0.005
93 |
94 | python -u run.py \
95 | --is_training 1 \
96 | --lradj 'TST' \
97 | --patience 3 \
98 | --root_path ./dataset/solar/ \
99 | --data_path solar_AL.txt \
100 | --model_id Solar \
101 | --model "$model_name" \
102 | --data Solar \
103 | --features M \
104 | --seq_len 96 \
105 | --pred_len 720 \
106 | --e_layers 1 \
107 | --d_model 128 \
108 | --d_ff 256 \
109 | --learning_rate 0.009 \
110 | --batch_size 256 \
111 | --fix_seed 2025 \
112 | --use_norm 0 \
113 | --wv "db8" \
114 | --m 1 \
115 | --enc_in 137 \
116 | --dec_in 137 \
117 | --c_out 137 \
118 | --des 'Exp' \
119 | --itr 3 \
121 | --alpha 0.0 \
122 | --l1_weight 0.005
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/Traffic/SimpleTM.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 3 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id Traffic \
11 | --model "$model_name" \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --d_model 512 \
18 | --d_ff 1024 \
19 | --learning_rate 0.003 \
20 | --batch_size 24 \
21 | --fix_seed 2025 \
22 | --use_norm 1 \
23 | --wv "db1" \
24 | --m 3 \
25 | --enc_in 862 \
26 | --dec_in 862 \
27 | --c_out 862 \
28 | --des 'Exp' \
29 | --itr 1 \
30 |   --alpha 0.1 \
31 |   --l1_weight 0.0
32 |
33 | python -u run.py \
34 | --is_training 1 \
35 | --lradj 'TST' \
36 | --patience 3 \
37 | --root_path ./dataset/traffic/ \
38 | --data_path traffic.csv \
39 | --model_id Traffic \
40 | --model "$model_name" \
41 | --data custom \
42 | --features M \
43 | --seq_len 96 \
44 | --pred_len 192 \
45 | --e_layers 1 \
46 | --d_model 1024 \
47 | --d_ff 2048 \
48 | --learning_rate 0.0005 \
49 | --batch_size 32 \
50 | --fix_seed 2025 \
51 | --use_norm 1 \
52 | --wv "db1" \
53 | --m 1 \
54 | --enc_in 862 \
55 | --dec_in 862 \
56 | --c_out 862 \
57 | --des 'Exp' \
58 | --itr 1 \
59 |   --alpha 0.1 \
60 |   --l1_weight 0.0
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --lradj 'TST' \
65 | --patience 3 \
66 | --root_path ./dataset/traffic/ \
67 | --data_path traffic.csv \
68 | --model_id Traffic \
69 | --model "$model_name" \
70 | --data custom \
71 | --features M \
72 | --seq_len 96 \
73 | --pred_len 336 \
74 | --e_layers 1 \
75 | --d_model 1024 \
76 | --d_ff 2048 \
77 | --learning_rate 0.0005 \
78 | --batch_size 32 \
79 | --fix_seed 2025 \
80 | --use_norm 1 \
81 | --wv "db1" \
82 | --m 1 \
83 | --enc_in 862 \
84 | --dec_in 862 \
85 | --c_out 862 \
86 | --des 'Exp' \
87 | --itr 1 \
88 |   --alpha 0.1 \
89 |   --l1_weight 0.0
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --lradj 'TST' \
94 | --patience 3 \
95 | --root_path ./dataset/traffic/ \
96 | --data_path traffic.csv \
97 | --model_id Traffic \
98 | --model "$model_name" \
99 | --data custom \
100 | --features M \
101 | --seq_len 96 \
102 | --pred_len 720 \
103 | --e_layers 1 \
104 | --d_model 1024 \
105 | --d_ff 2048 \
106 | --learning_rate 0.0005 \
107 | --batch_size 32 \
108 | --fix_seed 2025 \
109 | --use_norm 1 \
110 | --wv "db1" \
111 | --m 1 \
112 | --enc_in 862 \
113 | --dec_in 862 \
114 | --c_out 862 \
115 | --des 'Exp' \
116 | --itr 1 \
117 |   --alpha 0.1 \
118 |   --l1_weight 0.0
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/Weather/SimpleTM.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | model_name=SimpleTM
3 |
4 | python -u run.py \
5 | --is_training 1 \
6 | --lradj 'TST' \
7 | --patience 3 \
8 | --root_path ./dataset/weather/ \
9 | --data_path weather.csv \
10 | --model_id Weather \
11 | --model "$model_name" \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 4 \
17 | --d_model 32 \
18 | --d_ff 32 \
19 | --learning_rate 0.01 \
20 | --batch_size 256 \
21 | --fix_seed 2025 \
22 | --use_norm 1 \
23 | --wv "db4" \
24 | --m 1 \
25 | --enc_in 21 \
26 | --dec_in 21 \
27 | --c_out 21 \
28 | --des 'Exp' \
29 | --itr 3 \
30 | --alpha 0.3 \
31 | --l1_weight 5e-05
32 |
33 | python -u run.py \
34 | --is_training 1 \
35 | --lradj 'TST' \
36 | --patience 3 \
37 | --root_path ./dataset/weather/ \
38 | --data_path weather.csv \
39 | --model_id Weather \
40 | --model "$model_name" \
41 | --data custom \
42 | --features M \
43 | --seq_len 96 \
44 | --pred_len 192 \
45 | --e_layers 4 \
46 | --d_model 32 \
47 | --d_ff 32 \
48 | --learning_rate 0.009 \
49 | --batch_size 256 \
50 | --fix_seed 2025 \
51 | --use_norm 1 \
52 | --wv "db4" \
53 | --m 1 \
54 | --enc_in 21 \
55 | --dec_in 21 \
56 | --c_out 21 \
57 | --des 'Exp' \
58 | --itr 3 \
59 | --alpha 0.3 \
60 | --l1_weight 0.0
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --lradj 'TST' \
65 | --patience 3 \
66 | --root_path ./dataset/weather/ \
67 | --data_path weather.csv \
68 | --model_id Weather \
69 | --model "$model_name" \
70 | --data custom \
71 | --features M \
72 | --seq_len 96 \
73 | --pred_len 336 \
74 | --e_layers 1 \
75 | --d_model 32 \
76 | --d_ff 32 \
77 | --learning_rate 0.009 \
78 | --batch_size 256 \
79 | --fix_seed 2025 \
80 | --use_norm 1 \
81 | --wv "db4" \
82 | --m 3 \
83 | --enc_in 21 \
84 | --dec_in 21 \
85 | --c_out 21 \
86 | --des 'Exp' \
87 | --itr 3 \
88 | --alpha 1.0 \
89 | --l1_weight 5e-05
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --lradj 'TST' \
94 | --patience 3 \
95 | --root_path ./dataset/weather/ \
96 | --data_path weather.csv \
97 | --model_id Weather \
98 | --model "$model_name" \
99 | --data custom \
100 | --features M \
101 | --seq_len 96 \
102 | --pred_len 720 \
103 | --e_layers 1 \
104 | --d_model 32 \
105 | --d_ff 32 \
106 | --learning_rate 0.02 \
107 | --batch_size 256 \
108 | --fix_seed 2025 \
109 | --use_norm 1 \
110 | --wv "db4" \
111 | --m 1 \
112 | --enc_in 21 \
113 | --dec_in 21 \
114 | --c_out 21 \
115 | --des 'Exp' \
116 | --itr 3 \
117 | --alpha 0.9 \
118 | --l1_weight 0.005
--------------------------------------------------------------------------------
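Each script above follows the same pattern: one `run.py` invocation per prediction horizon, varying only the dataset paths, channel counts (`--enc_in`/`--dec_in`/`--c_out`), wavelet (`--wv`), and regularization weights (`--alpha`, `--l1_weight`). As a convenience, a hypothetical launcher that runs every script in sequence from the repository root might look like this:

```python
# Hypothetical helper: run every forecasting script sequentially.
# Assumes execution from the repository root.
import glob
import subprocess

scripts = sorted(glob.glob("scripts/multivariate_forecasting/**/*.sh",
                           recursive=True))
for script in scripts:
    print(f"==> {script}")
    subprocess.run(["bash", script], check=True)
```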
/utils/masking.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TriangularCausalMask():
5 | def __init__(self, B, L, device="cpu"):
6 | mask_shape = [B, 1, L, L]
7 | with torch.no_grad():
8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
9 |
10 | @property
11 | def mask(self):
12 | return self._mask
13 |
14 |
15 | class ProbMask():
16 | def __init__(self, B, H, L, index, scores, device="cpu"):
17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
19 | indicator = _mask_ex[torch.arange(B)[:, None, None],
20 | torch.arange(H)[None, :, None],
21 | index, :].to(device)
22 | self._mask = indicator.view(scores.shape).to(device)
23 |
24 | @property
25 | def mask(self):
26 | return self._mask
27 |
--------------------------------------------------------------------------------
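For context, `TriangularCausalMask` builds the standard strictly-upper-triangular boolean mask used to block attention to future positions. A small illustrative sketch (the `scores` tensor here is a toy stand-in for real attention logits):

```python
# Toy demonstration of TriangularCausalMask on attention scores.
import torch
from utils.masking import TriangularCausalMask

B, H, L = 2, 4, 5
scores = torch.randn(B, H, L, L)      # stand-in attention logits

mask = TriangularCausalMask(B, L)
print(mask.mask[0, 0].int())          # 1s strictly above the diagonal

# True entries (future positions) are suppressed before the softmax;
# the [B, 1, L, L] mask broadcasts across the H heads.
scores = scores.masked_fill(mask.mask, float("-inf"))
attn = torch.softmax(scores, dim=-1)
```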
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def RSE(pred, true):
5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
6 |
7 |
8 | def CORR(pred, true):
9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
11 | return (u / d).mean(-1)
12 |
13 |
14 | def MAE(pred, true):
15 | return np.mean(np.abs(pred - true))
16 |
17 |
18 | def MSE(pred, true):
19 | return np.mean((pred - true) ** 2)
20 |
21 |
22 | def RMSE(pred, true):
23 | return np.sqrt(MSE(pred, true))
24 |
25 | # MAPE with outlier clipping: absolute percentage errors above 5
26 | # (typically caused by near-zero ground-truth values, e.g. in the
27 | # PEMS datasets) are zeroed out so they cannot dominate the mean.
28 | def MAPE(pred, true):
29 | mape = np.abs((pred - true) / true)
30 | mape = np.where(mape > 5, 0, mape)
31 | return np.mean(mape)
32 |
33 |
34 | def MSPE(pred, true):
35 | return np.mean(np.square((pred - true) / true))
36 |
37 |
38 | def metric(pred, true):
39 | mae = MAE(pred, true)
40 | mse = MSE(pred, true)
41 | rmse = RMSE(pred, true)
42 | mape = MAPE(pred, true)
43 | mspe = MSPE(pred, true)
44 |
45 | return mae, mse, rmse, mape, mspe
46 |
--------------------------------------------------------------------------------
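A quick sanity check of `metric` on a synthetic example (note that `MAPE` and `MSPE` divide by the ground truth, so `true` is kept away from zero here):

```python
# Evaluate all five metrics on a small synthetic example.
import numpy as np
from utils.metrics import metric

true = np.array([[1.0, 2.0], [3.0, 4.0]])
pred = true + 0.1                     # uniform +0.1 error

mae, mse, rmse, mape, mspe = metric(pred, true)
print(f"MAE={mae:.4f} MSE={mse:.4f} RMSE={rmse:.4f} "
      f"MAPE={mape:.4f} MSPE={mspe:.4f}")
```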
/utils/timefeatures.py:
--------------------------------------------------------------------------------
1 | # From: gluonts/src/gluonts/time_feature/_base.py
2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License").
5 | # You may not use this file except in compliance with the License.
6 | # A copy of the License is located at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # or in the "license" file accompanying this file. This file is distributed
11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12 | # express or implied. See the License for the specific language governing
13 | # permissions and limitations under the License.
14 |
15 | from typing import List
16 |
17 | import numpy as np
18 | import pandas as pd
19 | from pandas.tseries import offsets
20 | from pandas.tseries.frequencies import to_offset
21 |
22 |
23 | class TimeFeature:
24 | def __init__(self):
25 | pass
26 |
27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
28 | pass
29 |
30 | def __repr__(self):
31 | return self.__class__.__name__ + "()"
32 |
33 |
34 | class SecondOfMinute(TimeFeature):
35 |     """Second of minute encoded as value between [-0.5, 0.5]"""
36 |
37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
38 | return index.second / 59.0 - 0.5
39 |
40 |
41 | class MinuteOfHour(TimeFeature):
42 | """Minute of hour encoded as value between [-0.5, 0.5]"""
43 |
44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
45 | return index.minute / 59.0 - 0.5
46 |
47 |
48 | class HourOfDay(TimeFeature):
49 | """Hour of day encoded as value between [-0.5, 0.5]"""
50 |
51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
52 | return index.hour / 23.0 - 0.5
53 |
54 |
55 | class DayOfWeek(TimeFeature):
56 |     """Day of week encoded as value between [-0.5, 0.5]"""
57 |
58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
59 | return index.dayofweek / 6.0 - 0.5
60 |
61 |
62 | class DayOfMonth(TimeFeature):
63 | """Day of month encoded as value between [-0.5, 0.5]"""
64 |
65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
66 | return (index.day - 1) / 30.0 - 0.5
67 |
68 |
69 | class DayOfYear(TimeFeature):
70 | """Day of year encoded as value between [-0.5, 0.5]"""
71 |
72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
73 | return (index.dayofyear - 1) / 365.0 - 0.5
74 |
75 |
76 | class MonthOfYear(TimeFeature):
77 | """Month of year encoded as value between [-0.5, 0.5]"""
78 |
79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
80 | return (index.month - 1) / 11.0 - 0.5
81 |
82 |
83 | class WeekOfYear(TimeFeature):
84 | """Week of year encoded as value between [-0.5, 0.5]"""
85 |
86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
87 | return (index.isocalendar().week - 1) / 52.0 - 0.5
88 |
89 |
90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
91 | """
92 | Returns a list of time features that will be appropriate for the given frequency string.
93 | Parameters
94 | ----------
95 | freq_str
96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
97 | """
98 |
99 | features_by_offsets = {
100 | offsets.YearEnd: [],
101 | offsets.QuarterEnd: [MonthOfYear],
102 | offsets.MonthEnd: [MonthOfYear],
103 | offsets.Week: [DayOfMonth, WeekOfYear],
104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
107 | offsets.Minute: [
108 | MinuteOfHour,
109 | HourOfDay,
110 | DayOfWeek,
111 | DayOfMonth,
112 | DayOfYear,
113 | ],
114 | offsets.Second: [
115 | SecondOfMinute,
116 | MinuteOfHour,
117 | HourOfDay,
118 | DayOfWeek,
119 | DayOfMonth,
120 | DayOfYear,
121 | ],
122 | }
123 |
124 | offset = to_offset(freq_str)
125 |
126 | for offset_type, feature_classes in features_by_offsets.items():
127 | if isinstance(offset, offset_type):
128 | return [cls() for cls in feature_classes]
129 |
130 | supported_freq_msg = f"""
131 | Unsupported frequency {freq_str}
132 | The following frequencies are supported:
133 | Y - yearly
134 | alias: A
135 | M - monthly
136 | W - weekly
137 | D - daily
138 | B - business days
139 | H - hourly
140 | T - minutely
141 | alias: min
142 | S - secondly
143 | """
144 | raise RuntimeError(supported_freq_msg)
145 |
146 |
147 | def time_features(dates, freq='h'):
148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
149 |
--------------------------------------------------------------------------------
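`time_features` stacks one row per feature, so an hourly frequency yields a `(4, N)` array (HourOfDay, DayOfWeek, DayOfMonth, DayOfYear), each scaled to roughly [-0.5, 0.5]. A minimal example:

```python
# Encode an hourly DatetimeIndex into normalized calendar features.
import pandas as pd
from utils.timefeatures import time_features

dates = pd.date_range("2025-01-01", periods=6, freq="h")
feats = time_features(dates, freq="h")
print(feats.shape)  # (4, 6): HourOfDay, DayOfWeek, DayOfMonth, DayOfYear
print(feats)
```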
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 |
8 | plt.switch_backend('agg')
9 |
10 |
11 | def adjust_learning_rate(optimizer, epoch, args, scheduler=None, printout=True):
12 | # lr = args.learning_rate * (0.2 ** (epoch // 2))
13 | if args.lradj == 'type1':
14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
15 | elif args.lradj == 'type2':
16 | lr_adjust = {
17 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
18 | 10: 5e-7, 15: 1e-7, 20: 5e-8
19 | }
20 | elif args.lradj == 'type3':
21 | lr_adjust = {epoch: args.learning_rate if epoch < 2 else args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
22 | elif args.lradj == 'constant':
23 | lr_adjust = {epoch: args.learning_rate * 1}
24 | elif args.lradj == 'TST':
25 | lr_adjust = {epoch: scheduler.get_last_lr()[0]}
26 | if epoch in lr_adjust.keys():
27 | lr = lr_adjust[epoch]
28 | for param_group in optimizer.param_groups:
29 | param_group['lr'] = lr
30 | if printout: print('Updating learning rate to {}'.format(lr))
31 |
32 |
33 | class EarlyStopping:
34 | def __init__(self, patience=7, verbose=False, delta=0):
35 | self.patience = patience
36 | self.verbose = verbose
37 | self.counter = 0
38 | self.best_score = None
39 | self.early_stop = False
40 |         self.val_loss_min = np.inf
41 | self.delta = delta
42 |
43 | def __call__(self, val_loss, model, path):
44 | score = -val_loss
45 | if self.best_score is None:
46 | self.best_score = score
47 | self.save_checkpoint(val_loss, model, path)
48 | elif score < self.best_score + self.delta:
49 | self.counter += 1
50 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
51 | if self.counter >= self.patience:
52 | self.early_stop = True
53 | else:
54 | self.best_score = score
55 | self.save_checkpoint(val_loss, model, path)
56 | self.counter = 0
57 |
58 | def save_checkpoint(self, val_loss, model, path):
59 | if self.verbose:
60 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
61 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
62 | self.val_loss_min = val_loss
63 |
64 |
65 | class dotdict(dict):
66 | """dot.notation access to dictionary attributes"""
67 | __getattr__ = dict.get
68 | __setattr__ = dict.__setitem__
69 | __delattr__ = dict.__delitem__
70 |
71 |
72 | class StandardScaler():
73 | def __init__(self, mean, std):
74 | self.mean = mean
75 | self.std = std
76 |
77 | def transform(self, data):
78 | return (data - self.mean) / self.std
79 |
80 | def inverse_transform(self, data):
81 | return (data * self.std) + self.mean
82 |
83 |
84 | def visual(true, preds=None, name='./pic/test.pdf'):
85 | """
86 | Results visualization
87 | """
88 | plt.figure()
89 | plt.plot(true, label='GroundTruth', linewidth=2)
90 | if preds is not None:
91 | plt.plot(preds, label='Prediction', linewidth=2)
92 | plt.legend()
93 | plt.savefig(name, bbox_inches='tight')
94 |
95 |
96 | def adjustment(gt, pred):
97 | anomaly_state = False
98 | for i in range(len(gt)):
99 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
100 | anomaly_state = True
101 | for j in range(i, 0, -1):
102 | if gt[j] == 0:
103 | break
104 | else:
105 | if pred[j] == 0:
106 | pred[j] = 1
107 | for j in range(i, len(gt)):
108 | if gt[j] == 0:
109 | break
110 | else:
111 | if pred[j] == 0:
112 | pred[j] = 1
113 | elif gt[i] == 0:
114 | anomaly_state = False
115 | if anomaly_state:
116 | pred[i] = 1
117 | return gt, pred
118 |
119 |
120 | def cal_accuracy(y_pred, y_true):
121 | return np.mean(y_pred == y_true)
122 |
--------------------------------------------------------------------------------
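A minimal sketch of how `EarlyStopping` and `adjust_learning_rate` are wired into a training loop. The model, loss values, and `args` object here are hypothetical stand-ins; the real loop lives in `experiments/exp_long_term_forecasting.py`:

```python
# Hypothetical skeleton showing the intended per-epoch call pattern of
# EarlyStopping and adjust_learning_rate from utils/tools.py.
import os

import torch
from utils.tools import EarlyStopping, adjust_learning_rate, dotdict

args = dotdict(learning_rate=1e-3, lradj="type1", patience=3)
os.makedirs("./checkpoints", exist_ok=True)

model = torch.nn.Linear(96, 96)       # stand-in for SimpleTM
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
early_stopping = EarlyStopping(patience=args.patience, verbose=True)

for epoch in range(1, 11):
    # ... forward/backward passes over the train loader go here ...
    val_loss = 1.0 / epoch            # placeholder validation loss
    early_stopping(val_loss, model, path="./checkpoints")
    if early_stopping.early_stop:
        print("Early stopping")
        break
    adjust_learning_rate(optimizer, epoch, args)
```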