├── .gitignore ├── LICENSE ├── NOTICE.txt ├── README.md ├── config.py ├── data ├── ETTh1.csv ├── ETTh2.csv └── ETTm1.csv ├── data_loader.py ├── ipc.py ├── metrics.py ├── model.py ├── run-ds.py ├── settings ├── ds_config_zero.json └── tuned │ ├── ts_full_m_h1_168.json │ ├── ts_full_m_h1_24.json │ ├── ts_full_m_h1_336.json │ ├── ts_full_m_h1_48.json │ ├── ts_full_m_h1_720.json │ ├── ts_full_m_h2_168.json │ ├── ts_full_m_h2_24.json │ ├── ts_full_m_h2_336.json │ ├── ts_full_m_h2_48.json │ ├── ts_full_m_h2_720.json │ ├── ts_full_m_m1_24.json │ ├── ts_full_m_m1_288.json │ ├── ts_full_m_m1_48.json │ ├── ts_full_m_m1_672.json │ ├── ts_full_m_m1_96.json │ ├── ts_full_u_h1_168.json │ ├── ts_full_u_h1_24.json │ ├── ts_full_u_h1_336.json │ ├── ts_full_u_h1_48.json │ ├── ts_full_u_h1_720.json │ ├── ts_full_u_h2_168.json │ ├── ts_full_u_h2_24.json │ ├── ts_full_u_h2_336.json │ ├── ts_full_u_h2_48.json │ ├── ts_full_u_h2_720.json │ ├── ts_full_u_m1_24.json │ ├── ts_full_u_m1_288.json │ ├── ts_full_u_m1_48.json │ ├── ts_full_u_m1_672.json │ ├── ts_full_u_m1_96.json │ ├── ts_query-selector_m_h1_168.json │ ├── ts_query-selector_m_h1_24.json │ ├── ts_query-selector_m_h1_336.json │ ├── ts_query-selector_m_h1_48.json │ ├── ts_query-selector_m_h1_720.json │ ├── ts_query-selector_m_h2_168.json │ ├── ts_query-selector_m_h2_24.json │ ├── ts_query-selector_m_h2_336.json │ ├── ts_query-selector_m_h2_48.json │ ├── ts_query-selector_m_h2_720.json │ ├── ts_query-selector_m_m1_24.json │ ├── ts_query-selector_m_m1_288.json │ ├── ts_query-selector_m_m1_48.json │ ├── ts_query-selector_m_m1_672.json │ ├── ts_query-selector_m_m1_96.json │ ├── ts_query-selector_u_h1_168.json │ ├── ts_query-selector_u_h1_24.json │ ├── ts_query-selector_u_h1_336.json │ ├── ts_query-selector_u_h1_48.json │ ├── ts_query-selector_u_h1_720.json │ ├── ts_query-selector_u_h2_168.json │ ├── ts_query-selector_u_h2_24.json │ ├── ts_query-selector_u_h2_336.json │ ├── ts_query-selector_u_h2_48.json │ ├── ts_query-selector_u_h2_720.json │ ├── ts_query-selector_u_m1_24.json │ ├── ts_query-selector_u_m1_288.json │ ├── ts_query-selector_u_m1_48.json │ ├── ts_query-selector_u_m1_672.json │ └── ts_query-selector_u_m1_96.json ├── train.py └── utils ├── timefeatures.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | /dumps/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- 1 | This project includes code from: 2 | https://github.com/zhouhaoyi/Informer2020 (Apache 2.0 License) 3 | 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Query Selector 2 | Here you can find code and data loaders for the paper https://arxiv.org/pdf/2107.08687v1.pdf . Query Selector is a novel approach to sparse attention Transformer algorithm that is especially suitable for long term time series forecasting 3 | 4 | # Depencency 5 | ``` 6 | Python 3.7.9 7 | deepspeed 0.4.0 8 | numpy 1.20.3 9 | pandas 1.2.4 10 | scipy 1.6.3 11 | tensorboardX 1.8 12 | torch 1.7.1 13 | torchaudio 0.7.2 14 | torchvision 0.8.2 15 | tqdm 4.61.0 16 | ``` 17 | 18 | # Results on ETT dataset 19 | ## Univariate 20 | | Data | Prediction len | Informer MSE | Informer MAE | Trans former MSE | Trans former MAE | Query Selector MSE | Query Selector MAE | MSE ratio | 21 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | 22 | | ETTh1 | 24 | 0.0980 | 0.2470 | 0.0548 | 0.1830 | **0.0436** | **0.1616** | **0.445** | 23 | | ETTh1 | 48 | 0.1580 | 0.3190 | 0.0740 | 0.2144 | **0.0721** | **0.2118** | **0.456** | 24 | | ETTh1 | 168 | 0.1830 | 0.3460 | 0.1049 | 0.2539 | **0.0935** | **0.2371** | **0.511** | 25 | | ETTh1 | 336 | 0.2220 | 0.3870 | 0.1541 | 0.3201 | **0.1267** | **0.2844** | **0.571** | 26 | | ETTh1 | 720 | 0.2690 | 0.4350 | 0.2501 | 0.4213 | **0.2136** | **0.3730** | **0.794** | 27 | | ETTh2 | 24 | 0.0930 | 0.2400 | 0.0999 | 0.2479 | **0.0843** | **0.2239** | **0.906** | 28 | | ETTh2 | 48 | 0.1550 | 0.3140 | 0.1218 | 0.2763 | **0.1117** | **0.2622** | **0.721** | 29 | | ETTh2 | 168 | 0.2320 | 0.3890 | 0.1974 | 0.3547 | **0.1753** | **0.3322** | **0.756** | 30 | | ETTh2 | 336 | 0.2630 | 0.4170 | 0.2191 | 0.3805 | **0.2088** | **0.3710** | **0.794** | 31 | | ETTh2 | 720 | 0.2770 | 0.4310 | 0.2853 | 0.4340 | **0.2585** | **0.4130** | **0.933** | 32 | | ETTm1 | 24 | 0.0300 | 0.1370 | 0.0143 | 0.0894 | **0.0139** | **0.0870** | **0.463** | 33 | | ETTm1 | 48 | 0.0690 | 0.2030 | **0.0328** | **0.1388** | 0.0342 | 0.1408 | **0.475** | 34 | | ETTm1 | 96 | 0.1940 | **0.2030** | **0.0695** | 0.2085 | 0.0702 | 0.2100 | **0.358** | 35 | | ETTm1 | 288 | 0.4010 | 0.5540 | **0.1316** | **0.2948** | 0.1548 | 0.3240 | **0.328** | 36 | | ETTm1 | 672 | 0.5120 | 0.6440 | **0.1728** | 0.3437 | 0.1735 | **0.3427** | **0.338** | 37 | 38 | ## Multivariate 39 | | Data | Prediction len | Informer MSE | Informer MAE | Trans former MSE | Trans former MAE | Query Selector MSE | Query Selector MAE | MSE ratio | 40 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | 41 | | ETTh1 | 24 | 0.5770 | 0.5490 | 0.4496 | 0.4788 | **0.4226** | **0.4627** | **0.732** | 42 | | ETTh1 | 48 | 0.6850 | 0.6250 | 0.4668 | 0.4968 | **0.4581** | **0.4878** | **0.669** | 43 | | ETTh1 | 168 | 0.9310 | 0.7520 | 0.7146 | 0.6325 | **0.6835** | **0.6088** | **0.734** | 44 | | ETTh1 | 336 | 1.1280 | 0.8730 | **0.8321** | 0.7041 | 0.8503 | **0.7039** | **0.738** | 45 | | ETTh1 | 720 | 1.2150 | 0.8960 | **1.1080** | **0.8399** | 1.1150 | 0.8428 | **0.912** | 46 | | ETTh2 | 24 | 0.7200 | 0.6650 | 0.4237 | 0.5013 | **0.4124** | **0.4864** | **0.573** | 47 | | ETTh2 | 48 | 1.4570 | 1.0010 | 1.5220 | 0.9488 | **1.4074** | **0.9317** | **0.966** | 48 | | ETTh2 | 168 | 3.4890 | 1.5150 | **1.6225** | **0.9726** | 1.7385 | 1.0125 | **0.465** | 49 | | ETTh2 | 336 | 2.7230 | 1.3400 | 2.6617 | 1.2189 | **2.3168** | **1.1859** | **0.851** | 50 | | ETTh2 | 720 | 3.4670 | 1.4730 | 3.1805 | 1.3668 | **3.0664** | **1.3084** | **0.884** | 51 | | ETTm1 | 24 | 0.3230 | **0.3690** | **0.3150** | 0.3886 | 0.3351 | 0.3875 | **0.975** | 52 | | ETTm1 | 48 | 0.4940 | 0.5030 | **0.4454** | **0.4620** | 0.4726 | 0.4702 | **0.902** | 53 | | ETTm1 | 96 | 0.6780 | 0.6140 | 0.4641 | **0.4823** | **0.4543** | 0.4831 | **0.670** | 54 | | ETTm1 | 288 | 1.0560 | 0.7860 | 0.6814 | 0.6312 | **0.6185** | **0.5991** | **0.586** | 55 | | ETTm1 | 672 | 1.1920 | 0.9260 | 1.1365 | 0.8572 | **1.1273** | **0.8412** | **0.946** | 56 | 57 | # Citation 58 | ``` 59 | @misc{klimek2021longterm, 60 | title={Long-term series forecasting with Query Selector -- efficient model of sparse attention}, 61 | author={Jacek Klimek and Jakub Klimek and Witold Kraskiewicz and Mateusz Topolewski}, 62 | year={2021}, 63 | eprint={2107.08687}, 64 | archivePrefix={arXiv}, 65 | primaryClass={cs.LG} 66 | } 67 | ``` 68 | # Contact 69 | If you have any questions please contact us by email - jacek.klimek@morai.eu 70 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import json 4 | 5 | from deepspeed import deepspeed 6 | 7 | 8 | class Config: 9 | def __init__(self, data='ETTh1',seq_len=720, pred_len=24, dec_seq_len=24, hidden_size=128, heads=3, batch_size=100, embedding_size=32, 10 | n_encoder_layers=3, encoder_attention='full', n_decoder_layers=3, decoder_attention='full', 11 | prediction_type='uni', dropout=0.1, fp16=True, 12 | iterations=10, exps=5, deepspeed= True, debug=False): 13 | self.data = data 14 | self.seq_len = seq_len 15 | self.pred_len = pred_len 16 | self.dec_seq_len = dec_seq_len 17 | self.hidden_size = hidden_size 18 | self.heads = heads 19 | self.n_encoder_layers = n_encoder_layers 20 | self.encoder_attention = encoder_attention 21 | self.n_decoder_layers = n_decoder_layers 22 | self.decoder_attention = decoder_attention 23 | self.batch_size = batch_size 24 | self.embedding_size = embedding_size 25 | self.prediction_type = prediction_type 26 | self.dropout = dropout 27 | self.fp16 = fp16 28 | self.deepspeed = deepspeed 29 | self.iterations = iterations 30 | self.exps = exps 31 | self.debug = debug 32 | 33 | def extend_argv(self): 34 | sys.argv.extend(['--data', str(self.data)]) 35 | sys.argv.extend(['--seq_len', str(self.seq_len)]) 36 | sys.argv.extend(['--pred_len', str(self.pred_len)]) 37 | sys.argv.extend(['--dec_seq_len', str(self.dec_seq_len)]) 38 | sys.argv.extend(["--hidden_size", str(self.hidden_size)]) 39 | sys.argv.extend(["--n_encoder_layers", str(self.n_encoder_layers)]) 40 | sys.argv.extend(["--n_decoder_layers", str(self.n_decoder_layers)]) 41 | sys.argv.extend(["--encoder_attention", str(self.encoder_attention)]) 42 | sys.argv.extend(["--decoder_attention", str(self.decoder_attention)]) 43 | sys.argv.extend(["--n_heads", str(self.heads)]) 44 | sys.argv.extend(["--batch_size", str(self.batch_size)]) 45 | sys.argv.extend(["--embedding_size", str(self.embedding_size)]) 46 | sys.argv.extend(["--iterations", str(self.iterations)]) 47 | sys.argv.extend(["--exps", str(self.exps)]) 48 | 49 | sys.argv.extend(["--dropout", str(self.dropout)]) 50 | if self.fp16: 51 | sys.argv.extend(["--fp16"]) 52 | 53 | if self.deepspeed: 54 | sys.argv.extend(["--deepspeed"]) 55 | 56 | if self.debug: 57 | sys.argv.extend(["--debug"]) 58 | 59 | if self.prediction_type == 'uni': 60 | sys.argv.extend(["--features", 'S']) 61 | sys.argv.extend(["--input_len", '1', "--output_len", "1"]) 62 | elif self.prediction_type == 'multiuni': 63 | sys.argv.extend(["--features", 'MS']) 64 | sys.argv.extend(["--input_len", '7', "--output_len", "1"]) 65 | elif self.prediction_type == 'multi': 66 | sys.argv.extend(["--features", 'M']) 67 | sys.argv.extend(["--input_len", '7', "--output_len", "7"]) 68 | else: 69 | raise NotImplemented 70 | 71 | def __str__(self): 72 | res = ':: ds-time-series config\n' 73 | res += ':::: train dataset: {}\n'.format(self.data) 74 | res += ':::: train input sequence len: {}\n'.format(self.seq_len) 75 | res += ':::: train prediction sequence len: {}\n'.format(self.pred_len) 76 | res += ':::: train decoder sequence len: {}\n'.format(self.dec_seq_len) 77 | res += ':::: train batch size: {}\n'.format(self.batch_size) 78 | res += ':::: train prediction type: {}\n'.format('univariate' if self.prediction_type == 'uni' 79 | else 'multiunivariate') 80 | res += ':::: train iterations: {}\n'.format(str(self.iterations)) 81 | res += ':::: train experiment number: {}\n'.format(self.exps) 82 | res += ':::: train using deepspeed: {}\n'.format(self.deepspeed) 83 | res += ':::: train using fp16: {}\n'.format(self.deepspeed) 84 | res += ':::: train recording: {}\n'.format(self.debug) 85 | 86 | res += ':::: model hidden size: {}\n'.format(self.hidden_size) 87 | res += ':::: model embedding size: {}\n'.format(self.embedding_size) 88 | res += ':::: model encoder layers: {}\n'.format(self.n_encoder_layers) 89 | res += ':::: model encoder attention: {}\n'.format(self.encoder_attention) 90 | res += ':::: model decoder layers: {}\n'.format(self.n_encoder_layers) 91 | res += ':::: model decoder attention: {}\n'.format(self.decoder_attention) 92 | res += ':::: model heads number: {}\n'.format(self.heads) 93 | res += ':::: model input dropout: {}\n'.format(self.dropout) 94 | 95 | return res 96 | 97 | @staticmethod 98 | def from_file(f): 99 | with open(f, 'r') as file: 100 | a = file.readlines() 101 | dict = json.loads(''.join(a)) 102 | return Config(**dict) 103 | 104 | def to_json(self): 105 | return json.dumps(self.__dict__, indent=2) 106 | 107 | 108 | def build_parser(): 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument('--data', type=str, choices=['ETTh1', 'ETTh2', 'ETTm1'], required=True) 111 | parser.add_argument('--input_len', type=int, required=True) 112 | parser.add_argument('--output_len', type=int, required=True) 113 | parser.add_argument('--seq_len', type=int, required=True) 114 | parser.add_argument('--dec_seq_len', type=int, required=True) 115 | parser.add_argument('--pred_len', type=int, required=True) 116 | parser.add_argument('--features', type=str, required=True) 117 | parser.add_argument('--target', default='OT', type=str) 118 | parser.add_argument('--iterations', type=int, required=True) 119 | parser.add_argument('--exps', type=int, required=True) 120 | 121 | parser.add_argument('--hidden_size', type=int, required=True) 122 | parser.add_argument('--n_heads', type=int, required=True) 123 | parser.add_argument('--n_encoder_layers', type=int, required=True) 124 | parser.add_argument('--encoder_attention', type=str, required=True) 125 | parser.add_argument('--n_decoder_layers', type=int, required=True) 126 | parser.add_argument('--decoder_attention', type=str, required=True) 127 | parser.add_argument('--batch_size', type=int, required=True) 128 | parser.add_argument('--embedding_size', type=int, required=True) 129 | parser.add_argument('--dropout', type=float, required=True) 130 | parser.add_argument('--fp16', action='store_true') 131 | 132 | parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False) 133 | parser.add_argument('--num-workers', 134 | type=int, 135 | default=2) 136 | parser.add_argument('--freq', type=str, default='h', 137 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h') 138 | 139 | parser.add_argument("--local_rank", type=int, default=0) 140 | parser.add_argument("--run_num", type=int, default=0) 141 | parser.add_argument('--debug', action='store_true') 142 | 143 | parser = deepspeed.add_config_arguments(parser) 144 | return parser -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # Original code from https://github.com/zhouhaoyi/Informer2020/ 2 | import os 3 | import numpy as np 4 | import pandas as pd 5 | 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | # from sklearn.preprocessing import StandardScaler 9 | 10 | from utils.tools import StandardScaler 11 | from utils.timefeatures import time_features 12 | 13 | import warnings 14 | 15 | warnings.filterwarnings('ignore') 16 | 17 | 18 | class Dataset_ETT_hour(Dataset): 19 | def __init__(self, root_path, flag='train', size=None, 20 | features='S', data_path='ETTh1.csv', 21 | target='OT', scale=True, inverse=False, timeenc=0, freq='h'): 22 | # size [seq_len, label_len, pred_len] 23 | # info 24 | if size == None: 25 | self.seq_len = 24 * 4 * 4 26 | self.label_len = 24 * 4 27 | self.pred_len = 24 * 4 28 | else: 29 | self.seq_len = size[0] 30 | self.label_len = size[1] 31 | self.pred_len = size[2] 32 | # init 33 | assert flag in ['train', 'test', 'val'] 34 | type_map = {'train': 0, 'val': 1, 'test': 2} 35 | self.set_type = type_map[flag] 36 | 37 | self.features = features 38 | self.target = target 39 | self.scale = scale 40 | self.inverse = inverse 41 | self.timeenc = timeenc 42 | self.freq = freq 43 | 44 | self.root_path = root_path 45 | self.data_path = data_path 46 | self.__read_data__() 47 | 48 | def __read_data__(self): 49 | self.scaler = StandardScaler() 50 | df_raw = pd.read_csv(os.path.join(self.root_path, 51 | self.data_path)) 52 | 53 | border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len] 54 | border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24] 55 | border1 = border1s[self.set_type] 56 | border2 = border2s[self.set_type] 57 | 58 | if self.features == 'M' or self.features == 'MS': 59 | cols_data = df_raw.columns[1:] 60 | target_index = list(cols_data).index(self.target) 61 | df_data = df_raw[cols_data] 62 | elif self.features == 'S': 63 | df_data = df_raw[[self.target]] 64 | 65 | if self.scale: 66 | train_data = df_data[border1s[0]:border2s[0]] 67 | self.scaler.fit(train_data.values) 68 | data = self.scaler.transform(df_data.values) 69 | else: 70 | data = df_data.values 71 | 72 | df_stamp = df_raw[['date']][border1:border2] 73 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 74 | data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq) 75 | 76 | self.data_x = data[border1:border2] 77 | if self.inverse: 78 | self.data_y = df_data.values[border1:border2] 79 | else: 80 | self.data_y = data[border1:border2] 81 | if self.features == "MS": 82 | self.data_y = self.data_y[:, [target_index]] 83 | 84 | self.data_stamp = data_stamp 85 | 86 | def __getitem__(self, index): 87 | s_begin = index 88 | s_end = s_begin + self.seq_len 89 | r_begin = s_end - self.label_len 90 | r_end = r_begin + self.label_len + self.pred_len 91 | 92 | seq_x = self.data_x[s_begin:s_end] 93 | seq_y = self.data_y[r_begin:r_end] 94 | seq_x_mark = self.data_stamp[s_begin:s_end] 95 | seq_y_mark = self.data_stamp[r_begin:r_end] 96 | 97 | return seq_x, seq_y, seq_x_mark, seq_y_mark 98 | 99 | def __len__(self): 100 | return len(self.data_x) - self.seq_len - self.pred_len + 1 101 | 102 | def inverse_transform(self, data): 103 | return self.scaler.inverse_transform(data) 104 | 105 | 106 | class Dataset_ETT_minute(Dataset): 107 | def __init__(self, root_path, flag='train', size=None, 108 | features='S', data_path='ETTm1.csv', 109 | target='OT', scale=True, inverse=False, timeenc=0, freq='t'): 110 | # size [seq_len, label_len, pred_len] 111 | # info 112 | if size == None: 113 | self.seq_len = 24 * 4 * 4 114 | self.label_len = 24 * 4 115 | self.pred_len = 24 * 4 116 | else: 117 | self.seq_len = size[0] 118 | self.label_len = size[1] 119 | self.pred_len = size[2] 120 | # init 121 | assert flag in ['train', 'test', 'val'] 122 | type_map = {'train': 0, 'val': 1, 'test': 2} 123 | self.set_type = type_map[flag] 124 | 125 | self.features = features 126 | self.target = target 127 | self.scale = scale 128 | self.inverse = inverse 129 | self.timeenc = timeenc 130 | self.freq = freq 131 | 132 | self.root_path = root_path 133 | self.data_path = data_path 134 | self.__read_data__() 135 | 136 | def __read_data__(self): 137 | self.scaler = StandardScaler() 138 | df_raw = pd.read_csv(os.path.join(self.root_path, 139 | self.data_path)) 140 | 141 | border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len] 142 | border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4] 143 | border1 = border1s[self.set_type] 144 | border2 = border2s[self.set_type] 145 | 146 | if self.features == 'M' or self.features == 'MS': 147 | cols_data = df_raw.columns[1:] 148 | target_index = list(cols_data).index(self.target) 149 | df_data = df_raw[cols_data] 150 | elif self.features == 'S': 151 | df_data = df_raw[[self.target]] 152 | 153 | if self.scale: 154 | train_data = df_data[border1s[0]:border2s[0]] 155 | self.scaler.fit(train_data.values) 156 | data = self.scaler.transform(df_data.values) 157 | else: 158 | data = df_data.values 159 | 160 | df_stamp = df_raw[['date']][border1:border2] 161 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 162 | data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq) 163 | 164 | self.data_x = data[border1:border2] 165 | if self.inverse: 166 | self.data_y = df_data.values[border1:border2] 167 | else: 168 | self.data_y = data[border1:border2] 169 | if self.features == "MS": 170 | self.data_y = self.data_y[:, [target_index]] 171 | self.data_stamp = data_stamp 172 | 173 | def __getitem__(self, index): 174 | s_begin = index 175 | s_end = s_begin + self.seq_len 176 | r_begin = s_end - self.label_len 177 | r_end = r_begin + self.label_len + self.pred_len 178 | 179 | seq_x = self.data_x[s_begin:s_end] 180 | seq_y = self.data_y[r_begin:r_end] 181 | seq_x_mark = self.data_stamp[s_begin:s_end] 182 | seq_y_mark = self.data_stamp[r_begin:r_end] 183 | 184 | return seq_x, seq_y, seq_x_mark, seq_y_mark 185 | 186 | def __len__(self): 187 | return len(self.data_x) - self.seq_len - self.pred_len + 1 188 | 189 | def inverse_transform(self, data): 190 | return self.scaler.inverse_transform(data) 191 | 192 | 193 | class Dataset_Custom(Dataset): 194 | def __init__(self, root_path, flag='train', size=None, 195 | features='S', data_path='ETTh1.csv', 196 | target='OT', scale=True, inverse=False, timeenc=0, freq='h'): 197 | # size [seq_len, label_len, pred_len] 198 | # info 199 | if size == None: 200 | self.seq_len = 24 * 4 * 4 201 | self.label_len = 24 * 4 202 | self.pred_len = 24 * 4 203 | else: 204 | self.seq_len = size[0] 205 | self.label_len = size[1] 206 | self.pred_len = size[2] 207 | # init 208 | assert flag in ['train', 'test', 'val'] 209 | type_map = {'train': 0, 'val': 1, 'test': 2} 210 | self.set_type = type_map[flag] 211 | 212 | self.features = features 213 | self.target = target 214 | self.scale = scale 215 | self.inverse = inverse 216 | self.timeenc = timeenc 217 | self.freq = freq 218 | 219 | self.root_path = root_path 220 | self.data_path = data_path 221 | self.__read_data__() 222 | 223 | def __read_data__(self): 224 | self.scaler = StandardScaler() 225 | df_raw = pd.read_csv(os.path.join(self.root_path, 226 | self.data_path)) 227 | ''' 228 | df_raw.columns: ['date', ...(other features), target feature] 229 | ''' 230 | cols = list(df_raw.columns); 231 | cols.remove(self.target); 232 | cols.remove('date') 233 | df_raw = df_raw[['date'] + cols + [self.target]] 234 | 235 | num_train = int(len(df_raw) * 0.7) 236 | num_test = int(len(df_raw) * 0.2) 237 | num_vali = len(df_raw) - num_train - num_test 238 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] 239 | border2s = [num_train, num_train + num_vali, len(df_raw)] 240 | border1 = border1s[self.set_type] 241 | border2 = border2s[self.set_type] 242 | 243 | if self.features == 'M' or self.features == 'MS': 244 | cols_data = df_raw.columns[1:] 245 | df_data = df_raw[cols_data] 246 | elif self.features == 'S': 247 | df_data = df_raw[[self.target]] 248 | 249 | if self.scale: 250 | train_data = df_data[border1s[0]:border2s[0]] 251 | self.scaler.fit(train_data.values) 252 | data = self.scaler.transform(df_data.values) 253 | else: 254 | data = df_data.values 255 | 256 | df_stamp = df_raw[['date']][border1:border2] 257 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 258 | data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq) 259 | 260 | self.data_x = data[border1:border2] 261 | if self.inverse: 262 | self.data_y = df_data.values[border1:border2] 263 | else: 264 | self.data_y = data[border1:border2] 265 | self.data_stamp = data_stamp 266 | 267 | def __getitem__(self, index): 268 | s_begin = index 269 | s_end = s_begin + self.seq_len 270 | r_begin = s_end - self.label_len 271 | r_end = r_begin + self.label_len + self.pred_len 272 | 273 | seq_x = self.data_x[s_begin:s_end] 274 | seq_y = self.data_y[r_begin:r_end] 275 | seq_x_mark = self.data_stamp[s_begin:s_end] 276 | seq_y_mark = self.data_stamp[r_begin:r_end] 277 | 278 | return seq_x, seq_y, seq_x_mark, seq_y_mark 279 | 280 | def __len__(self): 281 | return len(self.data_x) - self.seq_len - self.pred_len + 1 282 | 283 | def inverse_transform(self, data): 284 | return self.scaler.inverse_transform(data) 285 | 286 | 287 | class Dataset_Pred(Dataset): 288 | def __init__(self, root_path, flag='pred', size=None, 289 | features='S', data_path='ETTh1.csv', 290 | target='OT', scale=True, inverse=False, timeenc=0, freq='15min'): 291 | # size [seq_len, label_len, pred_len] 292 | # info 293 | if size == None: 294 | self.seq_len = 24 * 4 * 4 295 | self.label_len = 24 * 4 296 | self.pred_len = 24 * 4 297 | else: 298 | self.seq_len = size[0] 299 | self.label_len = size[1] 300 | self.pred_len = size[2] 301 | # init 302 | assert flag in ['pred'] 303 | 304 | self.features = features 305 | self.target = target 306 | self.scale = scale 307 | self.inverse = inverse 308 | self.timeenc = timeenc 309 | self.freq = freq 310 | 311 | self.root_path = root_path 312 | self.data_path = data_path 313 | self.__read_data__() 314 | 315 | def __read_data__(self): 316 | self.scaler = StandardScaler() 317 | df_raw = pd.read_csv(os.path.join(self.root_path, 318 | self.data_path)) 319 | ''' 320 | df_raw.columns: ['date', ...(other features), target feature] 321 | ''' 322 | cols = list(df_raw.columns); 323 | cols.remove(self.target); 324 | cols.remove('date') 325 | df_raw = df_raw[['date'] + cols + [self.target]] 326 | 327 | border1 = len(df_raw) - self.seq_len 328 | border2 = len(df_raw) 329 | 330 | if self.features == 'M' or self.features == 'MS': 331 | cols_data = df_raw.columns[1:] 332 | df_data = df_raw[cols_data] 333 | elif self.features == 'S': 334 | df_data = df_raw[[self.target]] 335 | 336 | if self.scale: 337 | self.scaler.fit(df_data.values) 338 | data = self.scaler.transform(df_data.values) 339 | else: 340 | data = df_data.values 341 | 342 | tmp_stamp = df_raw[['date']][border1:border2] 343 | tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date) 344 | pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq) 345 | 346 | df_stamp = pd.DataFrame(columns=['date']) 347 | df_stamp.date = list(tmp_stamp.date.values) + list(pred_dates[1:]) 348 | data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq[-1:]) 349 | 350 | self.data_x = data[border1:border2] 351 | if self.inverse: 352 | self.data_y = df_data.values[border1:border2] 353 | else: 354 | self.data_y = data[border1:border2] 355 | self.data_stamp = data_stamp 356 | 357 | def __getitem__(self, index): 358 | s_begin = index 359 | s_end = s_begin + self.seq_len 360 | r_begin = s_end - self.label_len 361 | r_end = r_begin + self.label_len + self.pred_len 362 | 363 | seq_x = self.data_x[s_begin:s_end] 364 | seq_y = self.data_y[r_begin:r_begin + self.label_len] 365 | seq_x_mark = self.data_stamp[s_begin:s_end] 366 | seq_y_mark = self.data_stamp[r_begin:r_end] 367 | 368 | return seq_x, seq_y, seq_x_mark, seq_y_mark 369 | 370 | def __len__(self): 371 | return len(self.data_x) - self.seq_len + 1 372 | 373 | def inverse_transform(self, data): 374 | return self.scaler.inverse_transform(data) 375 | -------------------------------------------------------------------------------- /ipc.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from statistics import mean 3 | 4 | PORT = 6666 5 | 6 | 7 | def print_results(results): 8 | for e in range(1, len(results) + 1): 9 | print('Iteration {:>2}| MSE {:>6.4f} | MAE {:>6.4f}'.format(e, results[e - 1][0], results[e - 1][1])) 10 | final_mse = float(mean([float(r[0]) for r in results])) 11 | final_mae = float(mean([float(r[1]) for r in results])) 12 | print('Mean | MSE {:>6.4f} | MAE {:>6.4f}'.format(final_mse, final_mae)) 13 | 14 | 15 | def resultServer(args, q=None): 16 | results = [] 17 | partials = [] 18 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 19 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 20 | s.bind(('localhost', PORT)) 21 | s.listen() 22 | while len(results) < args.exps: 23 | conn, addr = s.accept() 24 | with conn: 25 | print('Connected by', addr) 26 | data = conn.recv(4 * 1024) 27 | res = data.decode().split(';') 28 | if len(res) == 2: 29 | print('\033[94mReceived result:', data,'\033[0m' ) 30 | results.append([float(res[0]), float(res[1])]) 31 | print_results(results) 32 | else: 33 | print('\033[94mReceived training result:', data, '\033[0m') 34 | it, mse, mae = int(res[0]), float(res[1]), float(res[2]) 35 | if it > len (partials): 36 | partials.append([]) 37 | partials[it-1].append([mse, mae]) 38 | s.shutdown(1) 39 | s.close() 40 | for e in range(1, args.exps + 1): 41 | print('Iteration {:>2}| MSE {:>6.4f} | MAE {:>6.4f}'.format(e, results[e - 1][0], results[e - 1][1])) 42 | final_mse = float(mean([float(r[0]) for r in results])) 43 | final_mae = float(mean([float(r[1]) for r in results])) 44 | print('Mean | MSE {:>6.4f} | MAE {:>6.4f}'.format(final_mse, final_mae)) 45 | print(partials) 46 | if q: 47 | q.put({"mse" : final_mse, "mae": final_mae}) 48 | 49 | 50 | def sendResults(mse, mae): 51 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as c: 52 | c.connect(('localhost', PORT)) 53 | c.send("{:.5f};{:.5f}".format(float(mse), float(mae)).encode()) 54 | c.close() 55 | 56 | 57 | def sendPartials(it, mse, mae): 58 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as c: 59 | c.connect(('localhost', PORT)) 60 | c.send("{};{:.5f};{:.5f}".format(it, float(mse), float(mae)).encode()) 61 | c.close() -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # Original code from https://github.com/zhouhaoyi/Informer2020/blob/ 2 | import numpy as np 3 | 4 | def RSE(pred, true): 5 | return np.sqrt(np.sum((true-pred)**2)) / np.sqrt(np.sum((true-true.mean())**2)) 6 | 7 | def CORR(pred, true): 8 | u = ((true-true.mean(0))*(pred-pred.mean(0))).sum(0) 9 | d = np.sqrt(((true-true.mean(0))**2*(pred-pred.mean(0))**2).sum(0)) 10 | return (u/d).mean(-1) 11 | 12 | def MAE(pred, true): 13 | return np.mean(np.abs(pred-true)) 14 | 15 | def MSE(pred, true): 16 | return np.mean((pred-true)**2) 17 | 18 | def RMSE(pred, true): 19 | return np.sqrt(MSE(pred, true)) 20 | 21 | def MAPE(pred, true): 22 | return np.mean(np.abs((pred - true) / true)) 23 | 24 | def MSPE(pred, true): 25 | return np.mean(np.square((pred - true) / true)) 26 | 27 | def metric(pred, true): 28 | mae = MAE(pred, true) 29 | mse = MSE(pred, true) 30 | rmse = RMSE(pred, true) 31 | mape = MAPE(pred, true) 32 | mspe = MSPE(pred, true) 33 | 34 | return mae,mse,rmse,mape,mspe -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import time 2 | from multiprocessing import Pool 3 | 4 | import numpy 5 | import numpy as np 6 | import torch 7 | from torch import nn as nn 8 | import torch.nn.functional as F 9 | import math 10 | import matplotlib.pyplot as plt 11 | from scipy.stats import wasserstein_distance 12 | 13 | 14 | def a_norm(Q, K): 15 | m = torch.matmul(Q, K.transpose(2, 1)) 16 | m /= torch.sqrt(torch.tensor(Q.shape[-1]).float()) 17 | return torch.softmax(m, -1) 18 | 19 | 20 | def attention(Q, K, V): 21 | a = a_norm(Q, K) # (batch_size, dim_attn, seq_length) 22 | return torch.matmul(a, V) # (batch_size, seq_length, seq_length) 23 | 24 | 25 | class QuerySelector(nn.Module): 26 | def __init__(self, fraction=0.33): 27 | super(QuerySelector, self).__init__() 28 | self.fraction = fraction 29 | 30 | def forward(self, queries, keys, values): 31 | B, L_Q, D = queries.shape 32 | _, L_K, _ = keys.shape 33 | l_Q = int((1.0 - self.fraction) * L_Q) 34 | K_reduce = torch.mean(keys.topk(l_Q, dim=1).values, dim=1).unsqueeze(1) 35 | sqk = torch.matmul(K_reduce, queries.transpose(1,2)) 36 | indices = sqk.topk(l_Q, dim=-1).indices.squeeze(1) 37 | Q_sample = queries[torch.arange(B)[:, None], indices, :] # factor*ln(L_q) 38 | Q_K = torch.matmul(Q_sample, keys.transpose(-2, -1)) 39 | attn = torch.softmax(Q_K / math.sqrt(D), dim=-1) 40 | mean_values = values.mean(dim=-2) 41 | result = mean_values.unsqueeze(-2).expand(B, L_Q, mean_values.shape[-1]).clone() 42 | result[torch.arange(B)[:, None], indices, :] = torch.matmul(attn, values).type_as(result) 43 | return result, None 44 | 45 | def inference(self): 46 | pass # no parameters 47 | 48 | 49 | class InferenceModule(torch.nn.Module): 50 | def inference(self): 51 | for mod in self.modules(): 52 | if mod != self: 53 | mod.inference() 54 | 55 | 56 | 57 | class InferenceModuleList(torch.nn.ModuleList): 58 | def inference(self): 59 | for mod in self.modules(): 60 | if mod != self: 61 | mod.inference() 62 | 63 | 64 | class AttentionBlock(InferenceModule): 65 | def __init__(self, dim_val, dim_attn, debug=False, attn_type='full'): 66 | super(AttentionBlock, self).__init__() 67 | self.value = Value(dim_val, dim_val) 68 | self.key = Key(dim_val, dim_attn) 69 | self.query = Query(dim_val, dim_attn) 70 | self.debug = debug 71 | self.qk_record = None 72 | self.qkv_record = None 73 | self.n = 0 74 | if attn_type == "full": 75 | self.attentionLayer = None 76 | elif attn_type.startswith("query_selector"): 77 | args = {} 78 | if len(attn_type.split('_')) == 3: 79 | args['fraction'] = float(attn_type.split('_')[-1]) 80 | self.attentionLayer = QuerySelector(**args) 81 | else: 82 | raise Exception 83 | 84 | def forward(self, x, kv=None): 85 | if kv is None: 86 | if self.attentionLayer: 87 | qkv = self.attentionLayer(self.query(x), self.key(x), self.value(x))[0] 88 | else: 89 | qkv = attention(self.query(x), self.key(x), self.value(x)) 90 | return qkv 91 | return attention(self.query(x), self.key(kv), self.value(kv)) 92 | 93 | 94 | class MultiHeadAttentionBlock(InferenceModule): 95 | def __init__(self, dim_val, dim_attn, n_heads, attn_type): 96 | super(MultiHeadAttentionBlock, self).__init__() 97 | self.heads = [] 98 | for i in range(n_heads): 99 | self.heads.append(AttentionBlock(dim_val, dim_attn, attn_type=attn_type)) 100 | 101 | self.heads = InferenceModuleList(self.heads) 102 | self.fc = Linear(n_heads * dim_val, dim_val, bias=False) 103 | 104 | def forward(self, x, kv=None): 105 | a = [] 106 | for h in self.heads: 107 | a.append(h(x, kv=kv)) 108 | 109 | a = torch.stack(a, dim=-1) # combine heads 110 | a = a.flatten(start_dim=2) # flatten all head outputs 111 | 112 | x = self.fc(a) 113 | 114 | return x 115 | 116 | def record(self): 117 | for h in self.heads: 118 | h.record() 119 | 120 | 121 | class Value(InferenceModule): 122 | def __init__(self, dim_input, dim_val): 123 | super(Value, self).__init__() 124 | self.dim_val = dim_val 125 | self.fc1 = Linear(dim_input, dim_val, bias=False) 126 | 127 | def forward(self, x): 128 | return self.fc1(x) 129 | 130 | 131 | class Key(InferenceModule): 132 | def __init__(self, dim_input, dim_attn): 133 | super(Key, self).__init__() 134 | self.dim_attn = dim_attn 135 | self.fc1 = Linear(dim_input, dim_attn, bias=False) 136 | 137 | def forward(self, x): 138 | return self.fc1(x) 139 | 140 | 141 | class Query(InferenceModule): 142 | def __init__(self, dim_input, dim_attn): 143 | super(Query, self).__init__() 144 | self.dim_attn = dim_attn 145 | self.fc1 = Linear(dim_input, dim_attn, bias=False) 146 | 147 | def forward(self, x): 148 | return self.fc1(x) 149 | 150 | 151 | class PositionalEncoding(InferenceModule): 152 | def __init__(self, d_model, dropout=0.1, max_len=5000): 153 | super(PositionalEncoding, self).__init__() 154 | 155 | pe = torch.zeros(max_len, d_model) 156 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 157 | 158 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 159 | 160 | pe[:, 0::2] = torch.sin(position * div_term) 161 | pe[:, 1::2] = torch.cos(position * div_term) 162 | 163 | pe = pe.unsqueeze(0).transpose(0, 1) 164 | 165 | self.register_buffer('pe', pe) 166 | 167 | def forward(self, x): 168 | x = x + self.pe[:x.size(1), :].squeeze(1) 169 | return x 170 | 171 | 172 | class EncoderLayer(InferenceModule): 173 | def __init__(self, dim_val, dim_attn, n_heads=1, attn_type='full'): 174 | super(EncoderLayer, self).__init__() 175 | self.attn = MultiHeadAttentionBlock(dim_val, dim_attn, n_heads, attn_type=attn_type) 176 | 177 | self.fc1 = Linear(dim_val, dim_val) 178 | self.fc2 = Linear(dim_val, dim_val) 179 | 180 | self.norm1 = LayerNorm(dim_val) 181 | self.norm2 = LayerNorm(dim_val) 182 | 183 | def forward(self, x): 184 | a = self.attn(x) 185 | x = self.norm1(x + a) 186 | 187 | a = self.fc1(F.elu(self.fc2(x))) 188 | x = self.norm2(x + a) 189 | 190 | return x 191 | 192 | def record(self): 193 | self.attn.record() 194 | 195 | 196 | class DecoderLayer(InferenceModule): 197 | def __init__(self, dim_val, dim_attn, n_heads=1, attn_type='full'): 198 | super(DecoderLayer, self).__init__() 199 | self.attn1 = MultiHeadAttentionBlock(dim_val, dim_attn, n_heads, attn_type=attn_type) 200 | self.attn2 = MultiHeadAttentionBlock(dim_val, dim_attn, n_heads, attn_type=attn_type) 201 | 202 | self.fc1 = Linear(dim_val, dim_val) 203 | self.fc2 = Linear(dim_val, dim_val) 204 | 205 | self.norm1 = LayerNorm(dim_val) 206 | self.norm2 = LayerNorm(dim_val) 207 | self.norm3 = LayerNorm(dim_val) 208 | 209 | def forward(self, x, enc): 210 | a = self.attn1(x) 211 | x = self.norm1(a + x) 212 | 213 | a = self.attn2(x, kv=enc) 214 | x = self.norm2(a + x) 215 | 216 | a = self.fc1(F.elu(self.fc2(x))) 217 | x = self.norm3(x + a) 218 | return x 219 | 220 | def record(self): 221 | self.attn1.record() 222 | self.attn2.record() 223 | 224 | 225 | class Dropout(nn.Dropout): 226 | def forward(self, x=False): 227 | if self.training: 228 | return super(Dropout, self).forward(x) 229 | else: 230 | return x 231 | 232 | def inference(self): 233 | self.training = False 234 | 235 | 236 | class Linear(nn.Linear): 237 | def forward(self, x=False): 238 | if self.training: 239 | return super(Linear, self).forward(x) 240 | else: 241 | return F.linear(x, self.weight.data, self.bias.data if self.bias is not None else None) 242 | 243 | def inference(self): 244 | self.training = False 245 | 246 | 247 | class LayerNorm(nn.LayerNorm): 248 | def forward(self, x): 249 | if self.training: 250 | return super(LayerNorm, self).forward(x) 251 | else: 252 | return F.layer_norm(x, self.normalized_shape, self.weight.data, self.bias.data, self.eps) 253 | 254 | def inference(self): 255 | self.training = False 256 | 257 | 258 | class Transformer(InferenceModule): 259 | def __init__(self, dim_val, dim_attn, input_size, dec_seq_len, out_seq_len, n_decoder_layers=1, n_encoder_layers=1, 260 | enc_attn_type='full', dec_attn_type='full', n_heads=1, dropout=0.1, debug=False, output_len=1): 261 | super(Transformer, self).__init__() 262 | self.dec_seq_len = dec_seq_len 263 | self.output_len = output_len 264 | 265 | # Initiate encoder and Decoder layers 266 | self.encs = [] 267 | for i in range(n_encoder_layers): 268 | self.encs.append(EncoderLayer(dim_val, dim_attn, n_heads, attn_type=enc_attn_type)) 269 | self.encs = InferenceModuleList(self.encs) 270 | self.decs = [] 271 | for i in range(n_decoder_layers): 272 | self.decs.append(DecoderLayer(dim_val, dim_attn, n_heads, attn_type=dec_attn_type)) 273 | self.decs = InferenceModuleList(self.decs) 274 | self.pos = PositionalEncoding(dim_val) 275 | 276 | self.enc_dropout = Dropout(dropout) 277 | 278 | # Dense layers for managing network inputs and outputs 279 | self.enc_input_fc = Linear(input_size, dim_val) 280 | self.dec_input_fc = Linear(input_size, dim_val) 281 | self.out_fc = Linear(dec_seq_len * dim_val, out_seq_len*output_len) 282 | 283 | self.debug = debug 284 | 285 | def forward(self, x): 286 | # encoder 287 | e = self.encs[0](self.pos(self.enc_dropout(self.enc_input_fc(x)))) 288 | 289 | for enc in self.encs[1:]: 290 | e = enc(e) 291 | if self.debug: 292 | print('Encoder output size: {}'.format(e.shape)) 293 | # decoder 294 | decoded = self.dec_input_fc(x[:, -self.dec_seq_len:]) 295 | 296 | d = self.decs[0](decoded, e) 297 | for dec in self.decs[1:]: 298 | d = dec(d, e) 299 | 300 | # output 301 | x = self.out_fc(d.flatten(start_dim=1)) 302 | return torch.reshape(x, (x.shape[0], -1, self.output_len)) 303 | 304 | def record(self): 305 | self.debug = True 306 | for enc in self.encs: 307 | enc.record() 308 | for dec in self.decs: 309 | dec.record() 310 | 311 | 312 | -------------------------------------------------------------------------------- /run-ds.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from datetime import datetime 3 | from multiprocessing import Process 4 | 5 | from deepspeed.launcher import runner as ds_runer 6 | from torch.distributed import launch as pt_runner 7 | from config import Config 8 | import sys 9 | import socket 10 | 11 | from ipc import resultServer 12 | 13 | conf = Config.from_file('settings/tuned/ts_query-selector_u_h1_24.json') 14 | 15 | print(conf.to_json()) 16 | 17 | q = multiprocessing.Queue() 18 | p = Process(target=resultServer, args=[conf, q]) 19 | p.start() 20 | 21 | 22 | if conf.deepspeed: 23 | sys.argv.extend(['train.py', '--deepspeed_config', 'settings/ds_config_zero.json']) 24 | conf.extend_argv() 25 | setting_argv = sys.argv.copy() 26 | for run_num in range(conf.exps): 27 | sys.argv.clear() 28 | sys.argv.extend(setting_argv) 29 | sys.argv.extend(["--run_num", str(run_num + 1)]) 30 | ds_runer.main() 31 | else: 32 | conf.extend_argv() 33 | setting_argv = sys.argv.copy() 34 | for run_num in range(conf.exps): 35 | sys.argv.clear() 36 | sys.argv.extend(setting_argv) 37 | sys.argv.extend(["--run_num", str(run_num + 1)]) 38 | import train 39 | 40 | train.main() 41 | 42 | rfq = q.get() 43 | -------------------------------------------------------------------------------- /settings/ds_config_zero.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 5, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 10, 5 | "optimizer": { 6 | "type": "Adam", 7 | "params": { 8 | "lr": 0.00005, 9 | "weight_decay": 1e-2 10 | } 11 | }, 12 | "zero_optimization": { 13 | "stage": 2, 14 | "allgather_partitions": false, 15 | "cpu_offload": false 16 | }, 17 | "fp16": { 18 | "enabled": true, 19 | "loss_scale_window": 1000 20 | } 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h1_168.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 168, 4 | "pred_len": 168, 5 | "dec_seq_len": 168, 6 | "hidden_size": 144, 7 | "heads": 4, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 144, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h1_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 48, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 144, 7 | "heads": 2, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h1_336.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 168, 4 | "pred_len": 336, 5 | "dec_seq_len": 168, 6 | "hidden_size": 144, 7 | "heads": 4, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 96, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h1_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 96, 4 | "pred_len": 48, 5 | "dec_seq_len": 96, 6 | "hidden_size": 144, 7 | "heads": 4, 8 | "n_encoder_layers": 1, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 96, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h1_720.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 336, 4 | "pred_len": 720, 5 | "dec_seq_len": 336, 6 | "hidden_size": 144, 7 | "heads": 4, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 128, 13 | "embedding_size": 24, 14 | "prediction_type": "multi", 15 | "dropout": 0.05, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h2_168.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 336, 4 | "pred_len": 168, 5 | "dec_seq_len": 336, 6 | "hidden_size": 312, 7 | "heads": 3, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 64, 13 | "embedding_size": 112, 14 | "prediction_type": "multi", 15 | "dropout": 0, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 4, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h2_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 48, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 312, 7 | "heads": 3, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 144, 13 | "embedding_size": 48, 14 | "prediction_type": "multi", 15 | "dropout": 0, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 6, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h2_336.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 336, 4 | "pred_len": 336, 5 | "dec_seq_len": 336, 6 | "hidden_size": 128, 7 | "heads": 6, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 96, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 3, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h2_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 720, 4 | "pred_len": 48, 5 | "dec_seq_len": 48, 6 | "hidden_size": 144, 7 | "heads": 3, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 24, 14 | "prediction_type": "multi", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_h2_720.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 336, 4 | "pred_len": 720, 5 | "dec_seq_len": 336, 6 | "hidden_size": 128, 7 | "heads": 6, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 96, 14 | "prediction_type": "multi", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_m1_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 672, 4 | "pred_len": 24, 5 | "dec_seq_len": 96, 6 | "hidden_size": 128, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 64, 13 | "embedding_size": 24, 14 | "prediction_type": "multi", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_m1_288.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 384, 4 | "pred_len": 288, 5 | "dec_seq_len": 384, 6 | "hidden_size": 312, 7 | "heads": 4, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 64, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 1, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_m1_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 48, 5 | "dec_seq_len": 48, 6 | "hidden_size": 512, 7 | "heads": 3, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 144, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0.05, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_m1_672.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 672, 4 | "pred_len": 672, 5 | "dec_seq_len": 384, 6 | "hidden_size": 144, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 6, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_m_m1_96.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 96, 5 | "dec_seq_len": 96, 6 | "hidden_size": 512, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 144, 13 | "embedding_size": 24, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h1_168.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 720, 4 | "pred_len": 168, 5 | "dec_seq_len": 336, 6 | "hidden_size": 128, 7 | "heads": 4, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 100, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h1_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 720, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 128, 7 | "heads": 4, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 64, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h1_336.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 720, 4 | "pred_len": 336, 5 | "dec_seq_len": 720, 6 | "hidden_size": 400, 7 | "heads": 4, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 64, 13 | "embedding_size": 24, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 2, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h1_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 720, 4 | "pred_len": 48, 5 | "dec_seq_len": 48, 6 | "hidden_size": 128, 7 | "heads": 5, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 100, 13 | "embedding_size": 48, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h1_720.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 720, 4 | "pred_len": 720, 5 | "dec_seq_len": 720, 6 | "hidden_size": 128, 7 | "heads": 6, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 100, 13 | "embedding_size": 24, 14 | "prediction_type": "uni", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 2, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h2_168.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 720, 4 | "pred_len": 168, 5 | "dec_seq_len": 336, 6 | "hidden_size": 512, 7 | "heads": 3, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 96, 13 | "embedding_size": 24, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 4, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h2_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 48, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 128, 7 | "heads": 5, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 96, 13 | "embedding_size": 64, 14 | "prediction_type": "uni", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 3, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h2_336.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 720, 4 | "pred_len": 336, 5 | "dec_seq_len": 336, 6 | "hidden_size": 256, 7 | "heads": 6, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 100, 13 | "embedding_size": 16, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 3, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h2_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 720, 4 | "pred_len": 48, 5 | "dec_seq_len": 48, 6 | "hidden_size": 128, 7 | "heads": 3, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 100, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_h2_720.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 336, 4 | "pred_len": 720, 5 | "dec_seq_len": 336, 6 | "hidden_size": 378, 7 | "heads": 2, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 48, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.05, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_m1_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 128, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_m1_288.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 384, 4 | "pred_len": 288, 5 | "dec_seq_len": 384, 6 | "hidden_size": 144, 7 | "heads": 4, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 6, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_m1_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 48, 5 | "dec_seq_len": 48, 6 | "hidden_size": 128, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_m1_672.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 384, 4 | "pred_len": 672, 5 | "dec_seq_len": 384, 6 | "hidden_size": 128, 7 | "heads": 4, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 48, 14 | "prediction_type": "uni", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 6, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_full_u_m1_96.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 96, 5 | "dec_seq_len": 96, 6 | "hidden_size": 128, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "full", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h1_168.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 168, 4 | "pred_len": 168, 5 | "dec_seq_len": 168, 6 | "hidden_size": 144, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.8", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 128, 13 | "embedding_size": 24, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h1_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 48, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 96, 7 | "heads": 2, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.85", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 24, 14 | "prediction_type": "multi", 15 | "dropout": 0, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h1_336.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 168, 4 | "pred_len": 336, 5 | "dec_seq_len": 168, 6 | "hidden_size": 128, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.90", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 96, 13 | "embedding_size": 24, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h1_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 96, 4 | "pred_len": 48, 5 | "dec_seq_len": 96, 6 | "hidden_size": 96, 7 | "heads": 6, 8 | "n_encoder_layers": 1, 9 | "encoder_attention": "query_selector_0.90", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 96, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h1_720.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 336, 4 | "pred_len": 720, 5 | "dec_seq_len": 336, 6 | "hidden_size": 128, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.8", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 128, 13 | "embedding_size": 48, 14 | "prediction_type": "multi", 15 | "dropout": 0.05, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h2_168.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 336, 4 | "pred_len": 168, 5 | "dec_seq_len": 336, 6 | "hidden_size": 384, 7 | "heads": 6, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.8", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 64, 13 | "embedding_size": 96, 14 | "prediction_type": "multi", 15 | "dropout": 0, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 4, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h2_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 48, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 256, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.85", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 144, 13 | "embedding_size": 64, 14 | "prediction_type": "multi", 15 | "dropout": 0.05, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 6, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h2_336.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 336, 4 | "pred_len": 336, 5 | "dec_seq_len": 336, 6 | "hidden_size": 96, 7 | "heads": 6, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.75", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 64, 14 | "prediction_type": "multi", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 3, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h2_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 720, 4 | "pred_len": 48, 5 | "dec_seq_len": 48, 6 | "hidden_size": 128, 7 | "heads": 2, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.85", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_h2_720.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 336, 4 | "pred_len": 720, 5 | "dec_seq_len": 336, 6 | "hidden_size": 512, 7 | "heads": 5, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.90", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 64, 14 | "prediction_type": "multi", 15 | "dropout": 0.05, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_m1_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 128, 7 | "heads": 5, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.5", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 64, 13 | "embedding_size": 48, 14 | "prediction_type": "multi", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_m1_288.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 384, 4 | "pred_len": 288, 5 | "dec_seq_len": 384, 6 | "hidden_size": 312, 7 | "heads": 4, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.9", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 64, 13 | "embedding_size": 32, 14 | "prediction_type": "multi", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 1, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_m1_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 48, 5 | "dec_seq_len": 48, 6 | "hidden_size": 144, 7 | "heads": 5, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.75", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 144, 13 | "embedding_size": 96, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_m1_672.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 672, 4 | "pred_len": 672, 5 | "dec_seq_len": 384, 6 | "hidden_size": 256, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.90", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 24, 14 | "prediction_type": "multi", 15 | "dropout": 0.05, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 6, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_m_m1_96.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 96, 5 | "dec_seq_len": 96, 6 | "hidden_size": 96, 7 | "heads": 5, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.85", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 144, 13 | "embedding_size": 24, 14 | "prediction_type": "multi", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h1_168.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 720, 4 | "pred_len": 168, 5 | "dec_seq_len": 336, 6 | "hidden_size": 128, 7 | "heads": 4, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.1", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 48, 13 | "embedding_size": 48, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 3, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h1_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 720, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 312, 7 | "heads": 4, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.8", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 48, 13 | "embedding_size": 96, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h1_336.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 720, 4 | "pred_len": 336, 5 | "dec_seq_len": 720, 6 | "hidden_size": 128, 7 | "heads": 5, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.7", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 48, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 2, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h1_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 720, 4 | "pred_len": 48, 5 | "dec_seq_len": 168, 6 | "hidden_size": 378, 7 | "heads": 6, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.6", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 100, 13 | "embedding_size": 24, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h1_720.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh1", 3 | "seq_len": 840, 4 | "pred_len": 720, 5 | "dec_seq_len": 840, 6 | "hidden_size": 312, 7 | "heads": 3, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.7", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 24, 13 | "embedding_size": 64, 14 | "prediction_type": "uni", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 2, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h2_168.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 720, 4 | "pred_len": 168, 5 | "dec_seq_len": 336, 6 | "hidden_size": 256, 7 | "heads": 2, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.2", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 48, 13 | "embedding_size": 18, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 4, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h2_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 48, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 144, 7 | "heads": 3, 8 | "n_encoder_layers": 1, 9 | "encoder_attention": "query_selector_0.5", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 64, 13 | "embedding_size": 128, 14 | "prediction_type": "uni", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 3, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h2_336.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 720, 4 | "pred_len": 336, 5 | "dec_seq_len": 336, 6 | "hidden_size": 144, 7 | "heads": 5, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.95", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 100, 13 | "embedding_size": 48, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 3, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h2_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 720, 4 | "pred_len": 48, 5 | "dec_seq_len": 48, 6 | "hidden_size": 378, 7 | "heads": 2, 8 | "n_encoder_layers": 4, 9 | "encoder_attention": "query_selector_0.90", 10 | "n_decoder_layers": 3, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 5, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_h2_720.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTh2", 3 | "seq_len": 336, 4 | "pred_len": 720, 5 | "dec_seq_len": 336, 6 | "hidden_size": 256, 7 | "heads": 3, 8 | "n_encoder_layers": 3, 9 | "encoder_attention": "query_selector_0.8", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 48, 13 | "embedding_size": 96, 14 | "prediction_type": "uni", 15 | "dropout": 0.05, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 2, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_m1_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 24, 5 | "dec_seq_len": 48, 6 | "hidden_size": 128, 7 | "heads": 3, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.8", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 48, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_m1_288.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 384, 4 | "pred_len": 288, 5 | "dec_seq_len": 384, 6 | "hidden_size": 256, 7 | "heads": 5, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.75", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 24, 14 | "prediction_type": "uni", 15 | "dropout": 0.15, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 6, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_m1_48.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 48, 5 | "dec_seq_len": 48, 6 | "hidden_size": 144, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.85", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.1, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_m1_672.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 384, 4 | "pred_len": 672, 5 | "dec_seq_len": 384, 6 | "hidden_size": 144, 7 | "heads": 3, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.90", 10 | "n_decoder_layers": 2, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0.05, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 6, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /settings/tuned/ts_query-selector_u_m1_96.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": "ETTm1", 3 | "seq_len": 96, 4 | "pred_len": 96, 5 | "dec_seq_len": 96, 6 | "hidden_size": 128, 7 | "heads": 2, 8 | "n_encoder_layers": 2, 9 | "encoder_attention": "query_selector_0.8", 10 | "n_decoder_layers": 1, 11 | "decoder_attention": "full", 12 | "batch_size": 32, 13 | "embedding_size": 32, 14 | "prediction_type": "uni", 15 | "dropout": 0, 16 | "fp16": true, 17 | "deepspeed": true, 18 | "iterations": 7, 19 | "exps": 5, 20 | "debug": false 21 | } -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | from torch.optim import Adam 9 | from torch.utils.data import DataLoader 10 | 11 | from deepspeed import deepspeed 12 | import torch.nn as nn 13 | 14 | import ipc 15 | from config import build_parser 16 | from model import Transformer 17 | from data_loader import Dataset_ETT_hour, Dataset_ETT_minute 18 | from metrics import metric 19 | 20 | 21 | def get_model(args): 22 | return Transformer(args.embedding_size, args.hidden_size, args.input_len, args.dec_seq_len, args.pred_len, 23 | output_len=args.output_len, 24 | n_heads=args.n_heads, n_encoder_layers=args.n_encoder_layers, 25 | n_decoder_layers=args.n_decoder_layers, enc_attn_type=args.encoder_attention, 26 | dec_attn_type=args.decoder_attention, dropout=args.dropout) 27 | 28 | 29 | def get_params(mdl): 30 | return mdl.parameters() 31 | 32 | 33 | def _get_data(args, flag): 34 | if not args.data == 'ETTm1': 35 | Data = Dataset_ETT_hour 36 | else: 37 | Data = Dataset_ETT_minute 38 | # timeenc = 0 if args.embed != 'timeF' else 1 39 | 40 | if flag == 'test': 41 | shuffle_flag = False; 42 | drop_last = True; 43 | batch_size = 32 44 | freq = args.freq 45 | elif flag == 'pred': 46 | shuffle_flag = False; 47 | drop_last = False; 48 | batch_size = 1; 49 | # freq = args.detail_freq 50 | # Data = Dataset_Pred 51 | else: 52 | shuffle_flag = True; 53 | drop_last = True; 54 | batch_size = args.batch_size 55 | # freq = args.freq 56 | 57 | data_set = Data( 58 | root_path='data', 59 | data_path=args.data+'.csv', 60 | flag=flag, 61 | size=[args.seq_len, 0, args.pred_len], 62 | features=args.features, 63 | target=args.target, 64 | inverse=args.inverse, 65 | # timeenc=timeenc, 66 | # freq=freq 67 | ) 68 | print(flag, len(data_set)) 69 | data_loader = DataLoader( 70 | data_set, 71 | batch_size=batch_size, 72 | shuffle=shuffle_flag, 73 | num_workers=args.num_workers, 74 | drop_last=drop_last) 75 | 76 | return data_set, data_loader 77 | 78 | 79 | def run_metrics(caption, preds, trues): 80 | preds = np.array(preds) 81 | trues = np.array(trues) 82 | # print('test shape:', preds.shape, trues.shape) 83 | preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1]) 84 | trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1]) 85 | # print('test shape:', preds.shape, trues.shape) 86 | mae, mse, rmse, mape, mspe = metric(preds, trues) 87 | print('{} ; MSE: {}, MAE: {}'.format(caption, mse, mae)) 88 | return mse, mae 89 | 90 | 91 | def run_iteration(model, loader, args, training=True, message = ''): 92 | preds = [] 93 | trues = [] 94 | total_loss = 0 95 | elem_num = 0 96 | steps = 0 97 | target_device = 'cuda:{}'.format(args.local_rank) 98 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(loader): 99 | if not args.deepspeed: 100 | model.optim.zero_grad() 101 | 102 | batch = torch.tensor(batch_x, dtype=torch.float16 if args.fp16 else torch.float32, device=target_device) 103 | target = torch.tensor(batch_y, dtype=torch.float16 if args.fp16 else torch.float32, 104 | device=target_device) 105 | 106 | elem_num += len(batch) 107 | steps += 1 108 | 109 | result = model(batch) 110 | 111 | loss = nn.functional.mse_loss(result.squeeze(2), target.squeeze(2), reduction='mean') 112 | 113 | #pred = result.detach().cpu().unsqueeze(2).numpy() # .squeeze() 114 | pred = result.detach().cpu().numpy() # .squeeze() 115 | true = target.detach().cpu().numpy() # .squeeze() 116 | 117 | preds.append(pred) 118 | trues.append(true) 119 | 120 | unscaled_loss = loss.item() 121 | total_loss += unscaled_loss 122 | print("{} Loss at step {}: {}, mean for epoch: {}, mem_alloc: {}".format(message, steps, unscaled_loss, total_loss / steps,torch.cuda.max_memory_allocated())) 123 | 124 | if training: 125 | if args.deepspeed: 126 | model.backward(loss) 127 | model.step() 128 | else: 129 | loss.backward() 130 | model.optim.step() 131 | return preds, trues 132 | 133 | 134 | def preform_experiment(args): 135 | 136 | model = get_model(args) 137 | params = list(get_params(model)) 138 | print('Number of parameters: {}'.format(len(params))) 139 | for p in params: 140 | print(p.shape) 141 | 142 | if args.deepspeed: 143 | deepspeed_engine, optimizer, _, _ = deepspeed.initialize(args=args, 144 | model=model, 145 | model_parameters=params) 146 | else: 147 | model.to('cuda') 148 | model.optim = Adam(params, lr=0.001) 149 | 150 | train_data, train_loader = _get_data(args, flag='train') 151 | assert len(train_data.data_x[0]) == args.input_len, \ 152 | "Dataset contains input vectors of length {} while input_len is set to {}".format(len(train_data.data_x[0], args.input_len)) 153 | assert len(train_data.data_y[0]) == args.output_len, \ 154 | "Dataset contains output vectors of length {} while output_len is set to {}".format( 155 | len(train_data.data_y[0]), args.output_len) 156 | 157 | start = time.time() 158 | for iter in range(1, args.iterations + 1): 159 | preds, trues = run_iteration(deepspeed_engine if args.deepspeed else model , train_loader, args, training=True, message=' Run {:>3}, iteration: {:>3}: '.format(args.run_num, iter)) 160 | mse, mae = run_metrics("Loss after iteration {}".format(iter), preds, trues) 161 | if args.local_rank == 0: 162 | ipc.sendPartials(iter, mse, mae) 163 | print("Time per iteration {}, memory {}".format((time.time() - start)/iter, torch.cuda.memory_stats())) 164 | 165 | print(torch.cuda.max_memory_allocated()) 166 | 167 | if args.debug: 168 | model.record() 169 | 170 | 171 | test_data, test_loader = _get_data(args, flag='test') 172 | if deepspeed: 173 | model.inference() 174 | else: 175 | model.eval() 176 | # Model evaluation on validation data 177 | v_preds, v_trues = run_iteration(deepspeed_engine if args.deepspeed else model, test_loader, args, training=False, message="Validation set") 178 | mse, mae = run_metrics("Loss for validation set ", v_preds, v_trues) 179 | 180 | # Send results / plot models if debug option is on 181 | if args.local_rank == 0: 182 | ipc.sendResults(mse, mae) 183 | if args.debug: 184 | plot_model(args, model) 185 | 186 | def main(): 187 | parser = build_parser() 188 | args = parser.parse_args(None) 189 | preform_experiment(args) 190 | 191 | 192 | if __name__ == '__main__': 193 | main() 194 | 195 | 196 | -------------------------------------------------------------------------------- /utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | # Original code from https://github.com/zhouhaoyi/Informer2020/ 2 | from typing import List 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from pandas.tseries import offsets 7 | from pandas.tseries.frequencies import to_offset 8 | 9 | class TimeFeature: 10 | def __init__(self): 11 | pass 12 | 13 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 14 | pass 15 | 16 | def __repr__(self): 17 | return self.__class__.__name__ + "()" 18 | 19 | class SecondOfMinute(TimeFeature): 20 | """Minute of hour encoded as value between [-0.5, 0.5]""" 21 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 22 | return index.second / 59.0 - 0.5 23 | 24 | class MinuteOfHour(TimeFeature): 25 | """Minute of hour encoded as value between [-0.5, 0.5]""" 26 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 27 | return index.minute / 59.0 - 0.5 28 | 29 | class HourOfDay(TimeFeature): 30 | """Hour of day encoded as value between [-0.5, 0.5]""" 31 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 32 | return index.hour / 23.0 - 0.5 33 | 34 | class DayOfWeek(TimeFeature): 35 | """Hour of day encoded as value between [-0.5, 0.5]""" 36 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 37 | return index.dayofweek / 6.0 - 0.5 38 | 39 | class DayOfMonth(TimeFeature): 40 | """Day of month encoded as value between [-0.5, 0.5]""" 41 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 42 | return (index.day - 1) / 30.0 - 0.5 43 | 44 | class DayOfYear(TimeFeature): 45 | """Day of year encoded as value between [-0.5, 0.5]""" 46 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 47 | return (index.dayofyear - 1) / 365.0 - 0.5 48 | 49 | class MonthOfYear(TimeFeature): 50 | """Month of year encoded as value between [-0.5, 0.5]""" 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return (index.month - 1) / 11.0 - 0.5 53 | 54 | class WeekOfYear(TimeFeature): 55 | """Week of year encoded as value between [-0.5, 0.5]""" 56 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 57 | return (index.isocalendar().week - 1) / 52.0 - 0.5 58 | 59 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 60 | """ 61 | Returns a list of time features that will be appropriate for the given frequency string. 62 | Parameters 63 | ---------- 64 | freq_str 65 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 66 | """ 67 | 68 | features_by_offsets = { 69 | offsets.YearEnd: [], 70 | offsets.QuarterEnd: [MonthOfYear], 71 | offsets.MonthEnd: [MonthOfYear], 72 | offsets.Week: [DayOfMonth, WeekOfYear], 73 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 74 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 75 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 76 | offsets.Minute: [ 77 | MinuteOfHour, 78 | HourOfDay, 79 | DayOfWeek, 80 | DayOfMonth, 81 | DayOfYear, 82 | ], 83 | offsets.Second: [ 84 | SecondOfMinute, 85 | MinuteOfHour, 86 | HourOfDay, 87 | DayOfWeek, 88 | DayOfMonth, 89 | DayOfYear, 90 | ], 91 | } 92 | 93 | offset = to_offset(freq_str) 94 | 95 | for offset_type, feature_classes in features_by_offsets.items(): 96 | if isinstance(offset, offset_type): 97 | return [cls() for cls in feature_classes] 98 | 99 | supported_freq_msg = f""" 100 | Unsupported frequency {freq_str} 101 | The following frequencies are supported: 102 | Y - yearly 103 | alias: A 104 | M - monthly 105 | W - weekly 106 | D - daily 107 | B - business days 108 | H - hourly 109 | T - minutely 110 | alias: min 111 | S - secondly 112 | """ 113 | raise RuntimeError(supported_freq_msg) 114 | 115 | def time_features(dates, timeenc=1, freq='h'): 116 | if timeenc==0: 117 | dates['month'] = dates.date.apply(lambda row:row.month,1) 118 | dates['day'] = dates.date.apply(lambda row:row.day,1) 119 | dates['weekday'] = dates.date.apply(lambda row:row.weekday(),1) 120 | dates['hour'] = dates.date.apply(lambda row:row.hour,1) 121 | dates['minute'] = dates.date.apply(lambda row:row.minute,1) 122 | dates['minute'] = dates.minute.map(lambda x:x//15) 123 | freq_map = { 124 | 'y':[],'m':['month'],'w':['month'],'d':['month','day','weekday'], 125 | 'b':['month','day','weekday'],'h':['month','day','weekday','hour'], 126 | 't':['month','day','weekday','hour','minute'], 127 | } 128 | return dates[freq_map[freq.lower()]].values 129 | if timeenc==1: 130 | dates = pd.to_datetime(dates.date.values) 131 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]).transpose(1,0) -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | # Original code from https://github.com/zhouhaoyi/Informer2020/ 2 | import numpy as np 3 | import torch 4 | 5 | def adjust_learning_rate(optimizer, epoch, args): 6 | # lr = args.learning_rate * (0.2 ** (epoch // 2)) 7 | if args.lradj=='type1': 8 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch-1) // 1))} 9 | elif args.lradj=='type2': 10 | lr_adjust = { 11 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 12 | 10: 5e-7, 15: 1e-7, 20: 5e-8 13 | } 14 | if epoch in lr_adjust.keys(): 15 | lr = lr_adjust[epoch] 16 | for param_group in optimizer.param_groups: 17 | param_group['lr'] = lr 18 | print('Updating learning rate to {}'.format(lr)) 19 | 20 | class EarlyStopping: 21 | def __init__(self, patience=7, verbose=False, delta=0): 22 | self.patience = patience 23 | self.verbose = verbose 24 | self.counter = 0 25 | self.best_score = None 26 | self.early_stop = False 27 | self.val_loss_min = np.Inf 28 | self.delta = delta 29 | 30 | def __call__(self, val_loss, model, path): 31 | score = -val_loss 32 | if self.best_score is None: 33 | self.best_score = score 34 | self.save_checkpoint(val_loss, model, path) 35 | elif score < self.best_score + self.delta: 36 | self.counter += 1 37 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 38 | if self.counter >= self.patience: 39 | self.early_stop = True 40 | else: 41 | self.best_score = score 42 | self.save_checkpoint(val_loss, model, path) 43 | self.counter = 0 44 | 45 | def save_checkpoint(self, val_loss, model, path): 46 | if self.verbose: 47 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 48 | torch.save(model.state_dict(), path+'/'+'checkpoint.pth') 49 | self.val_loss_min = val_loss 50 | 51 | class dotdict(dict): 52 | """dot.notation access to dictionary attributes""" 53 | __getattr__ = dict.get 54 | __setattr__ = dict.__setitem__ 55 | __delattr__ = dict.__delitem__ 56 | 57 | class StandardScaler(): 58 | def __init__(self): 59 | self.mean = 0. 60 | self.std = 1. 61 | 62 | def fit(self, data): 63 | self.mean = data.mean(0) 64 | self.std = data.std(0) 65 | 66 | def transform(self, data): 67 | mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean 68 | std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std 69 | return (data - mean) / std 70 | 71 | def inverse_transform(self, data): 72 | mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean 73 | std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std 74 | return (data * std) + mean --------------------------------------------------------------------------------