├── .gitignore
├── PURE-GUIv2.0
│   ├── app_dad.py
│   ├── model_files
│   │   ├── AirQuality_MTSIR3-GAN.pth
│   │   ├── AirQuality_SSGAN.pth
│   │   └── AirQuality_TimesNet.pth
│   ├── pages
│   │   ├── MTSIR3-GAN_logs.txt
│   │   ├── SSGAN_logs.txt
│   │   ├── TimesNet_logs.txt
│   │   ├── data_analysis_dad.py
│   │   ├── model_visualization.py
│   │   └── time_imputation_dad.py
│   └── uploaded_files
│       ├── ETTh1.csv
│       ├── ETTh2.csv
│       ├── ETTm1.csv
│       ├── ETTm2.csv
│       ├── pm25_ground.csv
│       ├── pm25_missing.csv
│       └── weather.csv
├── R3GAN
│   ├── R3GAN
│   │   ├── FusedOperators.py
│   │   ├── Networks.py
│   │   ├── Resamplers.py
│   │   └── Trainer.py
│   ├── calc_metrics.py
│   ├── dataset_tool.py
│   ├── dnnlib
│   │   ├── __init__.py
│   │   └── util.py
│   ├── gen_timeseries.py
│   ├── legacy.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── frechet_inception_distance.py
│   │   ├── inception_score.py
│   │   ├── kernel_inception_distance.py
│   │   ├── metric_main.py
│   │   ├── metric_utils.py
│   │   └── precision_recall.py
│   ├── process_air_quality.py
│   ├── torch_utils
│   │   ├── __init__.py
│   │   ├── custom_ops.py
│   │   ├── misc.py
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── bias_act.cpp
│   │   │   ├── bias_act.cu
│   │   │   ├── bias_act.h
│   │   │   ├── bias_act.py
│   │   │   ├── conv2d_gradfix.py
│   │   │   ├── conv2d_resample.py
│   │   │   ├── fma.py
│   │   │   ├── grid_sample_gradfix.py
│   │   │   ├── upfirdn2d.cpp
│   │   │   ├── upfirdn2d.cu
│   │   │   ├── upfirdn2d.h
│   │   │   └── upfirdn2d.py
│   │   ├── persistence.py
│   │   └── training_stats.py
│   ├── train.py
│   └── training
│       ├── __init__.py
│       ├── augment.py
│       ├── dataset.py
│       ├── loss.py
│       ├── networks.py
│       └── training_loop.py
├── README.md
├── SSGAN
│   ├── data_loader.py
│   ├── main.py
│   ├── models
│   │   ├── Based_on_BRITS.py
│   │   ├── __init__.py
│   │   ├── brits.py
│   │   ├── brits_i.py
│   │   ├── classifier.py
│   │   ├── discriminator.py
│   │   ├── discriminator2.py
│   │   ├── gru_d.py
│   │   ├── m_rnn.py
│   │   ├── rits.py
│   │   └── rits_i.py
│   ├── preprocess.py
│   └── utils.py
├── TimesNet
│   ├── data_provider
│   │   ├── __init__.py
│   │   ├── data_factory.py
│   │   ├── data_loader.py
│   │   ├── m4.py
│   │   └── uea.py
│   ├── exp
│   │   ├── __init__.py
│   │   ├── exp_anomaly_detection.py
│   │   ├── exp_basic.py
│   │   ├── exp_classification.py
│   │   ├── exp_imputation.py
│   │   ├── exp_long_term_forecasting.py
│   │   └── exp_short_term_forecasting.py
│   ├── layers
│   │   ├── AutoCorrelation.py
│   │   ├── Autoformer_EncDec.py
│   │   ├── Conv_Blocks.py
│   │   ├── Crossformer_EncDec.py
│   │   ├── DWT_Decomposition.py
│   │   ├── ETSformer_EncDec.py
│   │   ├── Embed.py
│   │   ├── FourierCorrelation.py
│   │   ├── MultiWaveletCorrelation.py
│   │   ├── Pyraformer_EncDec.py
│   │   ├── SelfAttention_Family.py
│   │   ├── StandardNorm.py
│   │   ├── Transformer_EncDec.py
│   │   └── __init__.py
│   ├── models
│   │   ├── Autoformer.py
│   │   ├── Crossformer.py
│   │   ├── DLinear.py
│   │   ├── ETSformer.py
│   │   ├── FEDformer.py
│   │   ├── FiLM.py
│   │   ├── FreTS.py
│   │   ├── GAN.py
│   │   ├── Informer.py
│   │   ├── Koopa.py
│   │   ├── LightTS.py
│   │   ├── MICN.py
│   │   ├── Mamba.py
│   │   ├── MambaSimple.py
│   │   ├── MultiPatchFormer.py
│   │   ├── Nonstationary_Transformer.py
│   │   ├── PAttn.py
│   │   ├── PatchTST.py
│   │   ├── Pyraformer.py
│   │   ├── Reformer.py
│   │   ├── SCINet.py
│   │   ├── SegRNN.py
│   │   ├── TSMixer.py
│   │   ├── TemporalFusionTransformer.py
│   │   ├── TiDE.py
│   │   ├── TimeMixer.py
│   │   ├── TimeXer.py
│   │   ├── TimesNet.py
│   │   ├── Transformer.py
│   │   ├── WPMixer.py
│   │   ├── __init__.py
│   │   └── iTransformer.py
│   ├── run.py
│   └── utils
│       ├── ADFtest.py
│       ├── __init__.py
│       ├── augmentation.py
│       ├── dtw.py
│       ├── dtw_metric.py
│       ├── losses.py
│       ├── m4_summary.py
│       ├── masking.py
│       ├── metrics.py
│       ├── print_args.py
│       ├── timefeatures.py
│       └── tools.py
└── datasets
    ├── AirQuality
    │   ├── pm25_ground.txt
    │   └── pm25_missing.txt
    ├── PSM
    │   ├── test.csv
    │   └── test_label.csv
    └── PhysioNet
        └── link.txt

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file.  For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

--------------------------------------------------------------------------------
/PURE-GUIv2.0/app_dad.py:
--------------------------------------------------------------------------------
import os
import dash

from dash import dcc, html, Input, Output, State, callback
import dash_bootstrap_components as dbc
from pages.time_imputation_dad import time_imputation_layout
from pages.data_analysis_dad import data_analysis_layout
from pages.model_visualization import model_visualization_layout

# Initialize the app
app = dash.Dash(__name__,
                external_stylesheets=[dbc.themes.BOOTSTRAP],
                suppress_callback_exceptions=True)

logo_src = 'logo.png'

# Sidebar navigation
sidebar = html.Div(
    [
        # html.Img(src=logo_src, style={"width": "100%", "max-width": "100%", "height": "auto", "margin-bottom": "1rem"}),
        html.H3("Time-Series Analysis System", className="display-5", style={"font-size": "1.5rem"}),
        html.Hr(),
        dbc.Nav(
            [
                dbc.NavLink("Data Analysis", href="/data-analysis", active="exact"),
                dbc.NavLink("Data Imputation", href="/time-imputation", active="exact"),
                dbc.NavLink("Model Visualization", href="/model-visualization", active="exact"),
            ],
            vertical=True,
            pills=True,
        ),
    ],
    style={
        "position": "fixed",
        "top": 0,
        "left": 0,
        "bottom": 0,
        "width": "16rem",
        "padding": "2rem",
        "background-color": "#f8f9fa",
    },
)

# Main layout
content = html.Div(
    id="page-content",
    style={"marginLeft": "18rem", "marginRight": "2rem", "padding": "2rem 1rem"}
)

app.layout = html.Div([
    dcc.Location(id="url", refresh=False),
    sidebar,
    content,
    dcc.Store(id="stored-filenames", data=[]),
    dcc.Store(id="upload-progress-data", data=[]),
    dcc.Store(id="selected-dataset", data=None),
    dcc.Store(id="selected-features", data=[]),
    html.Div(id="dropdown-cell-placeholder"),
    dcc.Store(id="feature-rows", data=[]),
    dcc.Store(id="visualization-graph-store", data=None)  # new store for the visualization figure data
])


# Callback: page routing
@callback(
    Output("page-content", "children"),
    Input("url", "pathname")
)
def render_page_content(pathname):
    if pathname == "/data-analysis":
        return data_analysis_layout
    elif pathname == "/time-imputation":
        return time_imputation_layout
    elif pathname == "/model-visualization":
        return model_visualization_layout
    return html.P("Choose a page from the sidebar.")


if __name__ == "__main__":
    app.run(debug=True)

--------------------------------------------------------------------------------
/PURE-GUIv2.0/model_files/AirQuality_MTSIR3-GAN.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/PURE-GUIv2.0/model_files/AirQuality_MTSIR3-GAN.pth

--------------------------------------------------------------------------------
/PURE-GUIv2.0/model_files/AirQuality_SSGAN.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/PURE-GUIv2.0/model_files/AirQuality_SSGAN.pth

--------------------------------------------------------------------------------
/PURE-GUIv2.0/model_files/AirQuality_TimesNet.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/PURE-GUIv2.0/model_files/AirQuality_TimesNet.pth

--------------------------------------------------------------------------------
/PURE-GUIv2.0/pages/MTSIR3-GAN_logs.txt:
--------------------------------------------------------------------------------
train 6036
val 878
test 1752
iters: 100, epoch: 1 | loss: 0.3263778
speed: 0.0391s/iter; left time: 143.8878s
iters: 200, epoch: 1 | loss: 0.2931210
speed: 0.0321s/iter; left time: 114.9809s
iters: 300, epoch: 1 | loss: 0.2862863
speed: 0.0318s/iter; left time: 110.7554s
Epoch: 1 cost time: 12.83705472946167
Epoch: 1, Steps: 378 | Train Loss: 0.2890600 Vali Loss: 0.3686774 Test Loss: 0.2053866
Validation loss decreased (inf --> 0.368677). Saving model ...
Updating learning rate to 0.001
iters: 100, epoch: 2 | loss: 0.3064845
speed: 0.0906s/iter; left time: 299.2311s
iters: 200, epoch: 2 | loss: 0.1976164
speed: 0.0320s/iter; left time: 102.4564s
iters: 300, epoch: 2 | loss: 0.2021343
speed: 0.0312s/iter; left time: 96.8817s
Epoch: 2 cost time: 12.748377084732056
Epoch: 2, Steps: 378 | Train Loss: 0.2506826 Vali Loss: 0.3456713 Test Loss: 0.1920307
Validation loss decreased (0.368677 --> 0.345671). Saving model ...
Updating learning rate to 0.0005
iters: 100, epoch: 3 | loss: 0.2295276
speed: 0.0891s/iter; left time: 260.7093s
iters: 200, epoch: 3 | loss: 0.2636221
speed: 0.0314s/iter; left time: 88.7160s
iters: 300, epoch: 3 | loss: 0.2418462
speed: 0.0318s/iter; left time: 86.7329s
Epoch: 3 cost time: 12.669057846069336
Epoch: 3, Steps: 378 | Train Loss: 0.2375194 Vali Loss: 0.3373488 Test Loss: 0.1852830
Validation loss decreased (0.345671 --> 0.337349). Saving model ...
Updating learning rate to 0.00025
iters: 100, epoch: 4 | loss: 0.3084089
speed: 0.0884s/iter; left time: 225.1855s
iters: 200, epoch: 4 | loss: 0.1659867
speed: 0.0312s/iter; left time: 76.4345s
iters: 300, epoch: 4 | loss: 0.1870451
speed: 0.0313s/iter; left time: 73.3677s
Epoch: 4 cost time: 12.525406122207642
Epoch: 4, Steps: 378 | Train Loss: 0.2337164 Vali Loss: 0.3215748 Test Loss: 0.1767931
Validation loss decreased (0.337349 --> 0.321575). Saving model ...
Updating learning rate to 0.000125
iters: 100, epoch: 5 | loss: 0.2564480
speed: 0.0888s/iter; left time: 192.5103s
iters: 200, epoch: 5 | loss: 0.3722028
speed: 0.0318s/iter; left time: 65.8318s
iters: 300, epoch: 5 | loss: 0.2867151
speed: 0.0313s/iter; left time: 61.5871s
Epoch: 5 cost time: 12.631093978881836
Epoch: 5, Steps: 378 | Train Loss: 0.2305659 Vali Loss: 0.3179046 Test Loss: 0.1738094
Validation loss decreased (0.321575 --> 0.317905). Saving model ...
Updating learning rate to 6.25e-05
iters: 100, epoch: 6 | loss: 0.3136266
speed: 0.0879s/iter; left time: 157.3523s
iters: 200, epoch: 6 | loss: 0.2246323
speed: 0.0316s/iter; left time: 53.4495s
iters: 300, epoch: 6 | loss: 0.1334873
speed: 0.0316s/iter; left time: 50.2439s
Epoch: 6 cost time: 12.640253782272339
Epoch: 6, Steps: 378 | Train Loss: 0.2287082 Vali Loss: 0.3149282 Test Loss: 0.1735218
Validation loss decreased (0.317905 --> 0.314928). Saving model ...
Updating learning rate to 3.125e-05
iters: 100, epoch: 7 | loss: 0.1998388
speed: 0.0911s/iter; left time: 128.7028s
iters: 200, epoch: 7 | loss: 0.2398724
speed: 0.0313s/iter; left time: 41.0628s
iters: 300, epoch: 7 | loss: 0.2636862
speed: 0.0311s/iter; left time: 37.6671s
Epoch: 7 cost time: 12.714547157287598
Epoch: 7, Steps: 378 | Train Loss: 0.2282870 Vali Loss: 0.3191752 Test Loss: 0.1744985
EarlyStopping counter: 1 out of 3
Updating learning rate to 1.5625e-05
iters: 100, epoch: 8 | loss: 0.3208034
speed: 0.0910s/iter; left time: 94.1682s
iters: 200, epoch: 8 | loss: 0.2196216
speed: 0.0316s/iter; left time: 29.5120s
iters: 300, epoch: 8 | loss: 0.1461131
speed: 0.0319s/iter; left time: 26.6590s
Epoch: 8 cost time: 12.838458061218262
Epoch: 8, Steps: 378 | Train Loss: 0.2278287 Vali Loss: 0.3182404 Test Loss: 0.1742399
EarlyStopping counter: 2 out of 3
Updating learning rate to 7.8125e-06
iters: 100, epoch: 9 | loss: 0.1967115
speed: 0.0887s/iter; left time: 58.2586s
iters: 200, epoch: 9 | loss: 0.1657656
speed: 0.0310s/iter; left time: 17.2508s
iters: 300, epoch: 9 | loss: 0.2217258
speed: 0.0310s/iter; left time: 14.1520s
Epoch: 9 cost time: 12.512340545654297
Epoch: 9, Steps: 378 | Train Loss: 0.2276424 Vali Loss: 0.3152142 Test Loss: 0.1725200
EarlyStopping counter: 3 out of 3
Early stopping
Updating learning rate to 1.953125e-06
>>>>>>>testing : imputation_AirQuality_mask_0.25_MTSIR3-GAN_custom_ftM_sl96_ll0_pl0_dm64_nh8_el2_dl1_df64_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
test 1752
test shape: (1752, 96, 36) (1752, 96, 36)
mse:0.33367881178855896, mae:0.20536541938781738
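The log above ends in a single summary line with the final test metrics. A minimal sketch, not part of the repository, of pulling that final mse/mae pair out of such a log with the standard library (the helper name, regex, and log path are illustrative assumptions):

import re

# Hypothetical helper: extract the last "mse:<float>, mae:<float>" pair
# from a training log like MTSIR3-GAN_logs.txt above.
def read_final_metrics(log_path: str) -> dict:
    pattern = re.compile(r"mse:([0-9.eE+-]+),\s*mae:([0-9.eE+-]+)")
    metrics = {}
    with open(log_path) as f:
        for line in f:
            m = pattern.search(line)
            if m:  # keep the last match in the file
                metrics = {"mse": float(m.group(1)), "mae": float(m.group(2))}
    return metrics

# read_final_metrics("pages/MTSIR3-GAN_logs.txt")
# -> {'mse': 0.33367881178855896, 'mae': 0.20536541938781738}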
--------------------------------------------------------------------------------
/PURE-GUIv2.0/pages/model_visualization.py:
--------------------------------------------------------------------------------
import os

import dash
from dash import dcc, html, Input, Output, State, callback, no_update
from dash.exceptions import PreventUpdate
import dash_bootstrap_components as dbc
import base64
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

MODEL_DIR = "model_files"
if os.path.exists(MODEL_DIR):
    model_files = [f for f in os.listdir(MODEL_DIR) if f.endswith('.pth') or f.endswith('.pkl')]
else:
    model_files = []

feature_dropdown = dcc.Dropdown(
    id='feature-dropdown',
    options=[{'label': f'Feature {i}', 'value': i} for i in range(36)],
    value=0,
    clearable=False
)

model_visualization_layout = html.Div([
    dbc.Row([
        html.H4("Available Models"),
        dcc.Dropdown(id="model-selector", options=model_files, placeholder="trained models"),
    ], style={"border-right": "1px solid #ccc"}),
    dbc.Row([
        html.Div([
            feature_dropdown,
            html.H4("Visualized Results"),
            dcc.Graph(
                id='visual',
            )
        ], style={"border": "1px solid #ccc", "padding": "20px", "flex": 1}),
    ])
])


@callback(
    Output('visual', 'figure'),
    Input('feature-dropdown', 'value'),
    Input('model-selector', 'value'),
    prevent_initial_call=True
)
def update_visualization(feature_index, model_name):
    if not model_name:
        return dash.no_update
    # e.g. "AirQuality_MTSIR3-GAN.pth" -> "MTSIR3-GAN"
    model_name = model_name.split('_')[1].split('.')[0]
    try:
        true_data = np.load(f'model_results/{model_name}/true.npy')
        pred_data = np.load(f'model_results/{model_name}/pred.npy')

        t = np.arange(true_data.shape[0])  # time axis

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=t,
            y=np.mean(true_data, axis=1)[:, feature_index],
            name='True Data',
            line=dict(color='blue')
        ))
        fig.add_trace(go.Scatter(
            x=t,
            y=np.mean(pred_data, axis=1)[:, feature_index],
            name='Imputed Data',
            line=dict(color='red')
        ))

        fig.update_layout(
            title=f"Comparison for Feature {feature_index}",
            xaxis_title="Time Step",
            yaxis_title="Value",
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
        )

        return fig

    except Exception:
        # Result files are missing or unreadable: fall back to an empty figure.
        return go.Figure(data=[go.Scatter(x=[], y=[], mode='lines')])

--------------------------------------------------------------------------------
/R3GAN/R3GAN/FusedOperators.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math
from torch_utils.ops import bias_act

class BiasedActivationReference(nn.Module):
    Gain = math.sqrt(2 / (1 + 0.2 ** 2))
    Function = nn.LeakyReLU(0.2)

    def __init__(self, InputUnits):
        super(BiasedActivationReference, self).__init__()

        self.Bias = nn.Parameter(torch.empty(InputUnits))
        self.Bias.data.zero_()

    def forward(self, x):
        y = x + self.Bias.to(x.dtype).view(1, -1, 1, 1) if len(x.shape) > 2 else x + self.Bias.to(x.dtype).view(1, -1)
        return BiasedActivationReference.Function(y)

class BiasedActivationCUDA(nn.Module):
    Gain = math.sqrt(2 / (1 + 0.2 ** 2))
    Function = 'lrelu'

    def __init__(self, InputUnits):
        super(BiasedActivationCUDA, self).__init__()

        self.Bias = nn.Parameter(torch.empty(InputUnits))
        self.Bias.data.zero_()

    def forward(self, x):
        return bias_act.bias_act(x, self.Bias.to(x.dtype), act=BiasedActivationCUDA.Function, gain=1)

BiasedActivation = BiasedActivationCUDA

--------------------------------------------------------------------------------
/R3GAN/R3GAN/Resamplers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy
from torch_utils.ops import upfirdn2d

def CreateLowpassKernel(Weights, Inplace):
    Kernel = numpy.array([Weights]) if Inplace else numpy.convolve(Weights, [1, 1]).reshape(1, -1)
    Kernel = torch.Tensor(Kernel.T @ Kernel)
    return Kernel / torch.sum(Kernel)

class InterpolativeUpsamplerReference(nn.Module):
    def __init__(self, Filter):
        super(InterpolativeUpsamplerReference, self).__init__()

        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=False))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        Kernel = 4 * self.Kernel.view(1, 1, self.Kernel.shape[0], self.Kernel.shape[1]).to(x.dtype)
        y = nn.functional.conv_transpose2d(x.view(x.shape[0] * x.shape[1], 1, x.shape[2], x.shape[3]), Kernel, stride=2, padding=self.FilterRadius)

        return y.view(x.shape[0], x.shape[1], y.shape[2], y.shape[3])

class InterpolativeDownsamplerReference(nn.Module):
    def __init__(self, Filter):
        super(InterpolativeDownsamplerReference, self).__init__()

        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=False))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        Kernel = self.Kernel.view(1, 1, self.Kernel.shape[0], self.Kernel.shape[1]).to(x.dtype)
        y = nn.functional.conv2d(x.view(x.shape[0] * x.shape[1], 1, x.shape[2], x.shape[3]), Kernel, stride=2, padding=self.FilterRadius)

        return y.view(x.shape[0], x.shape[1], y.shape[2], y.shape[3])

class InplaceUpsamplerReference(nn.Module):
    def __init__(self, Filter):
        super(InplaceUpsamplerReference, self).__init__()

        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=True))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        Kernel = self.Kernel.view(1, 1, self.Kernel.shape[0], self.Kernel.shape[1]).to(x.dtype)
        x = nn.functional.pixel_shuffle(x, 2)

        return nn.functional.conv2d(x.view(x.shape[0] * x.shape[1], 1, x.shape[2], x.shape[3]), Kernel, stride=1, padding=self.FilterRadius).view(*x.shape)

class InplaceDownsamplerReference(nn.Module):
    def __init__(self, Filter):
        super(InplaceDownsamplerReference, self).__init__()

        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=True))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        Kernel = self.Kernel.view(1, 1, self.Kernel.shape[0], self.Kernel.shape[1]).to(x.dtype)
        y = nn.functional.conv2d(x.view(x.shape[0] * x.shape[1], 1, x.shape[2], x.shape[3]), Kernel, stride=1, padding=self.FilterRadius).view(*x.shape)

        return nn.functional.pixel_unshuffle(y, 2)

class InterpolativeUpsamplerCUDA(nn.Module):
    def __init__(self, Filter):
        super(InterpolativeUpsamplerCUDA, self).__init__()

        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=False))

    def forward(self, x):
        return upfirdn2d.upsample2d(x, self.Kernel)

class InterpolativeDownsamplerCUDA(nn.Module):
    def __init__(self, Filter):
        super(InterpolativeDownsamplerCUDA, self).__init__()

        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=False))

    def forward(self, x):
        return upfirdn2d.downsample2d(x, self.Kernel)

class InplaceUpsamplerCUDA(nn.Module):
    def __init__(self, Filter):
        super(InplaceUpsamplerCUDA, self).__init__()

        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=True))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        return upfirdn2d.upfirdn2d(nn.functional.pixel_shuffle(x, 2), self.Kernel, padding=self.FilterRadius)

class InplaceDownsamplerCUDA(nn.Module):
    def __init__(self, Filter):
        super(InplaceDownsamplerCUDA, self).__init__()

        self.register_buffer('Kernel', CreateLowpassKernel(Filter, Inplace=True))
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        return nn.functional.pixel_unshuffle(upfirdn2d.upfirdn2d(x, self.Kernel, padding=self.FilterRadius), 2)

InterpolativeUpsampler = InterpolativeUpsamplerCUDA
InterpolativeDownsampler = InterpolativeDownsamplerCUDA
InplaceUpsampler = InplaceUpsamplerCUDA
InplaceDownsampler = InplaceDownsamplerCUDA

--------------------------------------------------------------------------------
/R3GAN/R3GAN/Trainer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class AdversarialTraining:
    def __init__(self, Generator, Discriminator):
        self.Generator = Generator
        self.Discriminator = Discriminator

    @staticmethod
    def ZeroCenteredGradientPenalty(Samples, Critics):
        Gradient, = torch.autograd.grad(outputs=Critics.sum(), inputs=Samples, create_graph=True)
        return Gradient.square().sum([1, 2, 3])

    def AccumulateGeneratorGradients(self, Noise, RealSamples, Conditions, Scale=1, Preprocessor=lambda x: x):
        FakeSamples = self.Generator(Noise, Conditions)
        RealSamples = RealSamples.detach()

        FakeLogits = self.Discriminator(Preprocessor(FakeSamples), Conditions)
        RealLogits = self.Discriminator(Preprocessor(RealSamples), Conditions)

        RelativisticLogits = FakeLogits - RealLogits
        AdversarialLoss = nn.functional.softplus(-RelativisticLogits)

        (Scale * AdversarialLoss.mean()).backward()

        return [x.detach() for x in [AdversarialLoss, RelativisticLogits]]

    def AccumulateDiscriminatorGradients(self, Noise, RealSamples, Conditions, Gamma, Scale=1, Preprocessor=lambda x: x):
        RealSamples = RealSamples.detach().requires_grad_(True)
        FakeSamples = self.Generator(Noise, Conditions).detach().requires_grad_(True)

        RealLogits = self.Discriminator(Preprocessor(RealSamples), Conditions)
        FakeLogits = self.Discriminator(Preprocessor(FakeSamples), Conditions)

        R1Penalty = AdversarialTraining.ZeroCenteredGradientPenalty(RealSamples, RealLogits)
        R2Penalty = AdversarialTraining.ZeroCenteredGradientPenalty(FakeSamples, FakeLogits)

        RelativisticLogits = RealLogits - FakeLogits
        AdversarialLoss = nn.functional.softplus(-RelativisticLogits)

        DiscriminatorLoss = AdversarialLoss + (Gamma / 2) * (R1Penalty + R2Penalty)
        (Scale * DiscriminatorLoss.mean()).backward()

        return [x.detach() for x in [AdversarialLoss, RelativisticLogits, R1Penalty, R2Penalty]]
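Trainer.py implements the R3GAN objective: a relativistic adversarial loss, softplus applied to the negated logit difference, plus zero-centered gradient penalties R1 (on real samples) and R2 (on fake samples), scaled by Gamma/2. A minimal sketch of one discriminator step and one generator step; the toy networks, shapes, and import path are illustrative assumptions, not the repository's actual models:

import torch
import torch.nn as nn

from R3GAN.Trainer import AdversarialTraining  # assumes the outer R3GAN/ directory is on sys.path

class ToyGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 1 * 8 * 8)
    def forward(self, z, c):  # conditions are ignored in this toy model
        return self.net(z).view(-1, 1, 8, 8)

class ToyDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(1 * 8 * 8, 1)
    def forward(self, x, c):
        return self.net(x.flatten(1)).squeeze(1)  # one logit per sample

G, D = ToyGenerator(), ToyDiscriminator()
trainer = AdversarialTraining(G, D)
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)

noise = torch.randn(4, 16)
real = torch.randn(4, 1, 8, 8)   # penalties sum over dims [1, 2, 3], so samples are NCHW
cond = torch.zeros(4, 0)         # unconditional

# Discriminator step: relativistic loss + (Gamma / 2) * (R1 + R2)
opt_d.zero_grad()
trainer.AccumulateDiscriminatorGradients(noise, real, cond, Gamma=1.0)
opt_d.step()

# Generator step: softplus(-(fake_logits - real_logits)); the D gradients
# produced here are discarded by the next opt_d.zero_grad().
opt_g.zero_grad()
trainer.AccumulateGeneratorGradients(noise, real, cond)
opt_g.step()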
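Similarly, CreateLowpassKernel in Resamplers.py above builds a separable 2-D low-pass kernel as the outer product of a 1-D filter (first convolved with [1, 1] in the non-inplace case) and normalizes it to unit sum. A quick sketch checking those properties; the function body is copied from Resamplers.py and the example filter [1, 3, 3, 1] is an assumption:

import numpy
import torch

def CreateLowpassKernel(Weights, Inplace):
    # Same construction as R3GAN/R3GAN/Resamplers.py, duplicated here so the
    # check runs without the repo's torch_utils extension modules.
    Kernel = numpy.array([Weights]) if Inplace else numpy.convolve(Weights, [1, 1]).reshape(1, -1)
    Kernel = torch.Tensor(Kernel.T @ Kernel)
    return Kernel / torch.sum(Kernel)

k = CreateLowpassKernel([1, 3, 3, 1], Inplace=False)
print(k.shape)                 # torch.Size([5, 5]): [1,3,3,1] * [1,1] -> length-5 filter
print(float(k.sum()))          # ~1.0: normalized to unit DC gain
print(torch.allclose(k, k.T))  # True: separable outer product is symmetric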
--------------------------------------------------------------------------------
/R3GAN/dnnlib/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path

--------------------------------------------------------------------------------
/R3GAN/gen_timeseries.py:
--------------------------------------------------------------------------------
import os
import re
from typing import List, Optional, Union
import click
import dnnlib
import numpy as np
import PIL.Image
import torch
import legacy
import json
import random

#----------------------------------------------------------------------------

def parse_range(s: Union[str, List]) -> List[int]:
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', default='physioNet.pkl', required=True)
@click.option('--num_images', type=int, help='Number of images to generate', default=1000, required=True)
@click.option('--outdir', help='Where to save the output images', type=str, default='out_physioNet', required=True, metavar='DIR')
def generate_images(
    network_pkl: str,
    num_images: int,
    outdir: str
):
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    os.makedirs(outdir, exist_ok=True)

    seeds = [random.randint(0, 1000000) for _ in range(num_images)]

    # Generate images
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, num_images))
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        label = torch.zeros([1, G.c_dim], device=device)  # unconditional generation: labels are all zeros
        img = G(z, label)
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        img_np = img[0].cpu().numpy()

        img_pil = PIL.Image.fromarray(img_np, 'RGB')

        img_resized = img_pil.resize((35, 35), PIL.Image.BICUBIC)

        img_resized_np = np.array(img_resized)

        img_list = img_resized_np.tolist()

        json_filename = f'{outdir}/seed{seed:04d}.json'
        with open(json_filename, 'w') as f:
            json.dump(img_list, f)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    generate_images()

--------------------------------------------------------------------------------
/R3GAN/metrics/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty
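gen_timeseries.py above serializes each generated sample as a 35x35 RGB array written to seedNNNN.json via tolist(). A minimal sketch of loading one back; the specific file name and the row-per-time-step reading of the array are assumptions, not repository API:

import json
import numpy as np

# Load one generated sample written by gen_timeseries.py.
with open("out_physioNet/seed0042.json") as f:
    img = np.array(json.load(f), dtype=np.uint8)

print(img.shape)  # (35, 35, 3): the resized RGB grid encoding the series

# One possible decoding (an assumption): treat each row of the red plane
# as one time step across 35 channels, rescaled to [0, 1].
series = img[:, :, 0].astype(np.float32) / 255.0
print(series.shape)  # (35, 35)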
--------------------------------------------------------------------------------
/R3GAN/metrics/frechet_inception_distance.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Frechet Inception Distance (FID) from the paper
"GANs trained by a two time-scale update rule converge to a local Nash
equilibrium". Matches the original implementation by Heusel et al. at
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""

import numpy as np
import scipy.linalg
from . import metric_utils

#----------------------------------------------------------------------------

def compute_fid(opts, max_real, num_gen):
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True)  # Return raw features before the softmax layer.

    mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()

    mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()

    if opts.rank != 0:
        return float('nan')

    m = np.square(mu_gen - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)  # pylint: disable=no-member
    fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
    return float(fid)

#----------------------------------------------------------------------------

--------------------------------------------------------------------------------
/R3GAN/metrics/inception_score.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Inception Score (IS) from the paper "Improved techniques for training
GANs". Matches the original implementation by Salimans et al. at
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""

import numpy as np
from . import metric_utils

#----------------------------------------------------------------------------

def compute_is(opts, num_gen, num_splits):
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(no_output_bias=True)  # Match the original implementation by not applying bias in the softmax layer.

    gen_probs = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan'), float('nan')

    scores = []
    for i in range(num_splits):
        part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        kl = np.mean(np.sum(kl, axis=1))
        scores.append(np.exp(kl))
    return float(np.mean(scores)), float(np.std(scores))

#----------------------------------------------------------------------------

--------------------------------------------------------------------------------
/R3GAN/metrics/kernel_inception_distance.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
GANs". Matches the original implementation by Binkowski et al. at
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""

import numpy as np
from . import metric_utils

#----------------------------------------------------------------------------

def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True)  # Return raw features before the softmax layer.

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan')

    n = real_features.shape[1]
    m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
    t = 0
    for _subset_idx in range(num_subsets):
        x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
        y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    kid = t / num_subsets / m
    return float(kid)

#----------------------------------------------------------------------------

--------------------------------------------------------------------------------
/R3GAN/metrics/metric_main.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Main API for computing and reporting quality metrics."""

import os
import time
import json
import torch
import dnnlib

from . import metric_utils
from . import frechet_inception_distance
from . import kernel_inception_distance
from . import precision_recall
from . import inception_score

#----------------------------------------------------------------------------

_metric_dict = dict()  # name => fn

def register_metric(fn):
    assert callable(fn)
    _metric_dict[fn.__name__] = fn
    return fn

def is_valid_metric(metric):
    return metric in _metric_dict

def list_valid_metrics():
    return list(_metric_dict.keys())

#----------------------------------------------------------------------------

def calc_metric(metric, **kwargs):  # See metric_utils.MetricOptions for the full list of arguments.
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate.
    start_time = time.time()
    results = _metric_dict[metric](opts)
    total_time = time.time() - start_time

    # Broadcast results.
    for key, value in list(results.items()):
        if opts.num_gpus > 1:
            value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=value, src=0)
            value = float(value.cpu())
        results[key] = value

    # Decorate with metadata.
57 | return dnnlib.EasyDict( 58 | results = dnnlib.EasyDict(results), 59 | metric = metric, 60 | total_time = total_time, 61 | total_time_str = dnnlib.util.format_time(total_time), 62 | num_gpus = opts.num_gpus, 63 | ) 64 | 65 | #---------------------------------------------------------------------------- 66 | 67 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 68 | metric = result_dict['metric'] 69 | assert is_valid_metric(metric) 70 | if run_dir is not None and snapshot_pkl is not None: 71 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 72 | 73 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 74 | print(jsonl_line) 75 | if run_dir is not None and os.path.isdir(run_dir): 76 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 77 | f.write(jsonl_line + '\n') 78 | 79 | #---------------------------------------------------------------------------- 80 | # Recommended metrics. 81 | 82 | @register_metric 83 | def fid50k_full(opts): 84 | opts.dataset_kwargs.update(max_size=None, xflip=False) 85 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 86 | return dict(fid50k_full=fid) 87 | 88 | @register_metric 89 | def kid50k_full(opts): 90 | opts.dataset_kwargs.update(max_size=None, xflip=False) 91 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 92 | return dict(kid50k_full=kid) 93 | 94 | @register_metric 95 | def pr50k3_full(opts): 96 | opts.dataset_kwargs.update(max_size=None, xflip=False) 97 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 98 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 99 | 100 | #---------------------------------------------------------------------------- 101 | # Legacy metrics. 102 | 103 | @register_metric 104 | def fid50k(opts): 105 | opts.dataset_kwargs.update(max_size=None) 106 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 107 | return dict(fid50k=fid) 108 | 109 | @register_metric 110 | def kid50k(opts): 111 | opts.dataset_kwargs.update(max_size=None) 112 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 113 | return dict(kid50k=kid) 114 | 115 | @register_metric 116 | def pr50k3(opts): 117 | opts.dataset_kwargs.update(max_size=None) 118 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 119 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 120 | 121 | @register_metric 122 | def is50k(opts): 123 | opts.dataset_kwargs.update(max_size=None, xflip=False) 124 | mean, std = inception_score.compute_is(opts, num_gen=500, num_splits=10) 125 | return dict(is50k_mean=mean, is50k_std=std) 126 | 127 | #---------------------------------------------------------------------------- 128 | -------------------------------------------------------------------------------- /R3GAN/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . import metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch in manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /R3GAN/torch_utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /R3GAN/torch_utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/__pycache__/custom_ops.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/__pycache__/custom_ops.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/__pycache__/custom_ops.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/__pycache__/custom_ops.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/__pycache__/misc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/__pycache__/misc.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/__pycache__/misc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/__pycache__/misc.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/__pycache__/persistence.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/__pycache__/persistence.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/__pycache__/training_stats.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/__pycache__/training_stats.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/__pycache__/training_stats.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/__pycache__/training_stats.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/bias_act.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/bias_act.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc 
-------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/upfirdn2d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/upfirdn2d.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/util/Half.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments.
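// (Fused op summary: adds the per-channel bias b along dimension `dim`, applies the
// activation selected by `act` with the given alpha/gain/clamp, and, when grad > 0,
// evaluates the corresponding gradient instead; see bias_act.py for the Python-side
// documentation of these modes.)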
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | from pkg_resources import parse_version 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 
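# (Callers such as the training loop can opt in with
#  `grid_sample_gradfix.enabled = True` once custom ops are known to work in the
#  current environment; grid_sample() below falls back to the eager
#  torch.nn.functional.grid_sample otherwise.)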
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op, _ = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | if _use_pytorch_1_11_api: 62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) 63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) 64 | else: 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/util/Half.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments.
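// (Op summary: zero-stuff upsample by (upx, upy), convolve with the 2D FIR filter f,
// then downsample by (downx, downy), honoring the explicit pad0/pad1 margins; this is
// the shared resampling primitive of the StyleGAN code base.)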
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size. 
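// (A negative spec.tileOutW selects the generic "large" kernel with a 4x32 thread
// block; otherwise the specialized tiled kernel chosen above dictates the launch
// geometry through tileOutW/tileOutH.)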
76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /R3GAN/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /R3GAN/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /R3GAN/training/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/training/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/training/__pycache__/augment.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/augment.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/training/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/training/__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/training/__pycache__/loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/loss.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/training/__pycache__/networks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/networks.cpython-310.pyc -------------------------------------------------------------------------------- /R3GAN/training/__pycache__/networks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/networks.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/training/__pycache__/training_loop.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/training_loop.cpython-310.pyc 
-------------------------------------------------------------------------------- /R3GAN/training/__pycache__/training_loop.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/R3GAN/training/__pycache__/training_loop.cpython-39.pyc -------------------------------------------------------------------------------- /R3GAN/training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Loss functions.""" 10 | 11 | from torch_utils import training_stats 12 | from R3GAN.Trainer import AdversarialTraining 13 | import torch 14 | 15 | #---------------------------------------------------------------------------- 16 | 17 | class R3GANLoss: 18 | def __init__(self, G, D, augment_pipe=None): 19 | self.trainer = AdversarialTraining(G, D) 20 | if augment_pipe is not None: 21 | self.preprocessor = lambda x: augment_pipe(x.to(torch.float32)).to(x.dtype) 22 | else: 23 | self.preprocessor = lambda x: x 24 | 25 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gamma, gain): 26 | # G 27 | if phase == 'G': 28 | AdversarialLoss, RelativisticLogits = self.trainer.AccumulateGeneratorGradients(gen_z, real_img, real_c, gain, self.preprocessor) 29 | 30 | training_stats.report('Loss/scores/fake', RelativisticLogits) 31 | training_stats.report('Loss/signs/fake', RelativisticLogits.sign()) 32 | training_stats.report('Loss/G/loss', AdversarialLoss) 33 | 34 | # D 35 | if phase == 'D': 36 | AdversarialLoss, RelativisticLogits, R1Penalty, R2Penalty = self.trainer.AccumulateDiscriminatorGradients(gen_z, real_img, real_c, gamma, gain, self.preprocessor) 37 | 38 | training_stats.report('Loss/scores/real', RelativisticLogits) 39 | training_stats.report('Loss/signs/real', RelativisticLogits.sign()) 40 | training_stats.report('Loss/D/loss', AdversarialLoss) 41 | training_stats.report('Loss/r1_penalty', R1Penalty) 42 | training_stats.report('Loss/r2_penalty', R2Penalty) 43 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /R3GAN/training/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | import R3GAN.Networks 5 | 6 | class Generator(nn.Module): 7 | def __init__(self, *args, **kw): 8 | super(Generator, self).__init__() 9 | 10 | config = copy.deepcopy(kw) 11 | del config['FP16Stages'] 12 | del config['c_dim'] 13 | del config['img_resolution'] 14 | 15 | if kw['c_dim'] != 0: 16 | config['ConditionDimension'] = kw['c_dim'] 17 | 18 | self.Model = R3GAN.Networks.Generator(*args, **config) 19 | self.z_dim = kw['NoiseDimension'] 20 | self.c_dim = kw['c_dim'] 21 | self.img_resolution = kw['img_resolution'] 22 | 23 | for x in kw['FP16Stages']: 24 | self.Model.MainLayers[x].DataType = torch.bfloat16 25 | 26 | def forward(self, x, c): 27 | return self.Model(x, c) 28 | 29 | 
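# The Discriminator wrapper below mirrors Generator: it strips the trainer-facing
# keys (FP16Stages, c_dim, img_resolution) from the config, forwards the rest to
# R3GAN.Networks.Discriminator, and switches the listed stages to bfloat16.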
class Discriminator(nn.Module): 30 | def __init__(self, *args, **kw): 31 | super(Discriminator, self).__init__() 32 | 33 | config = copy.deepcopy(kw) 34 | del config['FP16Stages'] 35 | del config['c_dim'] 36 | del config['img_resolution'] 37 | 38 | if kw['c_dim'] != 0: 39 | config['ConditionDimension'] = kw['c_dim'] 40 | 41 | self.Model = R3GAN.Networks.Discriminator(*args, **config) 42 | 43 | for x in kw['FP16Stages']: 44 | self.Model.MainLayers[x].DataType = torch.bfloat16 45 | 46 | def forward(self, x, c): 47 | return self.Model(x, c) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MTSIR3-GAN 2 | 3 | This repository contains the code and resources for the thesis project "Exploring Generative Adversarial Networks for Multivariate Time Series Data Imputation". 4 | 5 | **Core Contribution:** 6 | 7 | This project introduces **MTSIR3-GAN**, a novel approach for Multivariate Time Series Imputation (MTSI). It adapts the modern, principled R3GAN architecture [NeurIPS 2024], originally designed for image generation, to the unique challenges of temporal data. 8 | 9 | **Key Features:** 10 | 11 | * **Model:** Implements MTSIR3-GAN, leveraging R3GAN's stable training objective (RpGAN + R1 + R2) and modernized convolutional backbone. 12 | * **Innovation:** Features a bespoke time-series patching strategy that enables R3GAN's application to sequential data. 13 | * **Evaluation:** Provides comprehensive experimental results demonstrating MTSIR3-GAN's competitive performance against baselines (TimesNet, SSGAN) on standard benchmark datasets (PhysioNet Challenge 2012, Beijing Air Quality, PSM). 14 | * **Robustness:** Handles complex temporal dependencies and remains robust to data anomalies. 15 | * **GUI:** Includes an interactive Dash-based interface for demonstrating the imputation system. 16 | 17 | **Goal:** 18 | 19 | To address the critical problem of missing data in multivariate time series by developing and validating a stable, high-performing GAN-based imputation method. 20 | 21 | **Keywords:** 22 | 23 | Multivariate Time Series, Data Imputation, Generative Adversarial Networks (GANs), Generative Models, Deep Learning, PhysioNet, Air Quality, PSM.
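**Training Objective (Sketch):**

For readers who want the gist of the objective without reading `R3GAN/Trainer.py`, the sketch below shows one way to write the RpGAN loss with R1 and R2 zero-centered gradient penalties in PyTorch. Function names, tensor shapes, and the `gamma` weighting are illustrative assumptions, not this repository's exact API.

```python
import torch
import torch.nn.functional as F

def discriminator_loss(D, real, fake, gamma):
    # Leaf copies so gradient penalties can be taken w.r.t. the inputs.
    real = real.detach().requires_grad_(True)
    fake = fake.detach().requires_grad_(True)
    logits_real, logits_fake = D(real), D(fake)

    # RpGAN: each real sample is scored relative to a fake one through a
    # single softplus term per pair.
    adversarial = F.softplus(logits_fake - logits_real).mean()

    # R1 / R2: zero-centered gradient penalties on real and fake samples.
    grad_real = torch.autograd.grad(logits_real.sum(), real, create_graph=True)[0]
    grad_fake = torch.autograd.grad(logits_fake.sum(), fake, create_graph=True)[0]
    r1 = grad_real.square().flatten(1).sum(dim=1).mean()
    r2 = grad_fake.square().flatten(1).sum(dim=1).mean()
    return adversarial + 0.5 * gamma * (r1 + r2)

def generator_loss(D, real, fake):
    # The generator plays the symmetric relativistic game.
    return F.softplus(D(real) - D(fake)).mean()
```

Penalizing gradients on both real (R1) and fake (R2) samples is what the R3GAN paper argues makes the relativistic objective locally convergent, which is the stability property MTSIR3-GAN inherits.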
24 | -------------------------------------------------------------------------------- /SSGAN/data_loader.py: -------------------------------------------------------------------------------- 1 | import ujson as json 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | choose = 0 7 | missing_rate = 50 8 | dataset = 'AirQuality' 9 | dimension = 36 10 | 11 | 12 | class MySet(Dataset): 13 | def __init__(self): 14 | super(MySet, self).__init__() 15 | self.content = open('./json/json').readlines() 16 | indices = np.arange(len(self.content)) 17 | val_indices = np.random.choice(indices, len(self.content) // 5) 18 | self.val_indices = set(val_indices.tolist()) 19 | def __len__(self): 20 | return len(self.content) 21 | def __getitem__(self, idx): 22 | rec = json.loads(self.content[idx]) 23 | if idx in self.val_indices: 24 | rec['is_train'] = 0 25 | else: 26 | rec['is_train'] = 1 27 | return rec 28 | 29 | class MyTrainSet(Dataset): 30 | def __init__(self): 31 | super(MyTrainSet, self).__init__() 32 | self.content = open('./json/'+dataset+'/'+str(missing_rate)+'_train.json').readlines() 33 | indices = np.arange(len(self.content)) 34 | val_indices = np.random.choice(indices, len(self.content) // 5) 35 | self.val_indices = set(val_indices.tolist()) 36 | def __len__(self): 37 | return len(self.content) 38 | def __getitem__(self, idx): 39 | rec = json.loads(self.content[idx]) 40 | return rec 41 | 42 | class MyTestSet(Dataset): 43 | def __init__(self): 44 | super(MyTestSet, self).__init__() 45 | self.content = open('./json/'+dataset+'/'+str(missing_rate)+'_test.json').readlines() 46 | indices = np.arange(len(self.content)) 47 | val_indices = np.random.choice(indices, len(self.content) // 5) 48 | self.val_indices = set(val_indices.tolist()) 49 | 50 | def __len__(self): 51 | return len(self.content) 52 | 53 | def __getitem__(self, idx): 54 | rec = json.loads(self.content[idx]) 55 | return rec 56 | 57 | def collate_fn(recs): 58 | forward = list(map(lambda x: x['forward'], recs)) 59 | # backward = list(map(lambda x: x['backward'], recs)) 60 | 61 | def to_tensor_dict(recs): 62 | values = torch.FloatTensor( 63 | list(map(lambda r: list(map(lambda x: x['values'], r)), recs))) 64 | masks = torch.FloatTensor( 65 | list(map(lambda r: list(map(lambda x: x['masks'], r)), recs))) 66 | deltas = torch.FloatTensor( 67 | list(map(lambda r: list(map(lambda x: x['deltas'], r)), recs))) 68 | forwards = torch.FloatTensor( 69 | list(map(lambda r: list(map(lambda x: x['forwards'], r)), recs))) 70 | 71 | evals = torch.FloatTensor( 72 | list(map(lambda r: list(map(lambda x: x['evals'], r)), recs))) 73 | eval_masks = torch.FloatTensor( 74 | list(map(lambda r: list(map(lambda x: x['eval_masks'], r)), recs))) 75 | 76 | return { 77 | 'values': values.permute(0,2,1), 78 | 'forwards': forwards.permute(0,2,1), 79 | 'masks': masks.permute(0,2,1), 80 | 'deltas': deltas.permute(0,2,1), 81 | 'evals': evals.permute(0,2,1), 82 | 'eval_masks': eval_masks.permute(0,2,1) 83 | } 84 | # ret_dict = {'forward': to_tensor_dict(forward), 'backward': to_tensor_dict(backward)} 85 | ret_dict = {'forward': to_tensor_dict(forward)} 86 | ret_dict['labels'] = torch.FloatTensor( 87 | list(map(lambda x: x['label'], recs))) 88 | ret_dict['is_train'] = torch.FloatTensor( 89 | list(map(lambda x: x['is_train'], recs))) 90 | return ret_dict 91 | 92 | 93 | def get_loader(batch_size=64, shuffle=True): 94 | data_set = MySet() 95 | data_iter = DataLoader(dataset=data_set, 96 | batch_size=batch_size, 97 | num_workers=1, 98 | 
shuffle=shuffle, 99 | pin_memory=True, 100 | collate_fn=collate_fn 101 | ) 102 | return data_iter 103 | 104 | def get_train_loader(batch_size=32, shuffle=True): 105 | data_set = MyTrainSet() 106 | data_iter = DataLoader(dataset=data_set, 107 | batch_size=batch_size, 108 | num_workers=1, 109 | shuffle=shuffle, 110 | pin_memory=True, 111 | collate_fn=collate_fn 112 | ) 113 | return data_iter 114 | 115 | def get_test_loader(batch_size=32, shuffle=False): 116 | data_set = MyTestSet() 117 | data_iter = DataLoader(dataset=data_set, 118 | batch_size=batch_size, 119 | num_workers=1, 120 | shuffle=shuffle, 121 | pin_memory=True, 122 | collate_fn=collate_fn 123 | ) 124 | 125 | return data_iter 126 | -------------------------------------------------------------------------------- /SSGAN/models/Based_on_BRITS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import models.rits as rits 5 | 6 | class Generator(nn.Module): 7 | def __init__(self, rnn_hid_size, impute_weight, label_weight): 8 | super(Generator, self).__init__() 9 | 10 | self.rnn_hid_size = 108 11 | self.impute_weight = torch.tensor(0.3) 12 | self.label_weight = torch.tensor(1.0) 13 | 14 | self.build() 15 | 16 | def build(self): 17 | self.rits_f = rits.Model(self.rnn_hid_size, self.impute_weight, self.label_weight) 18 | self.rits_b = rits.Model(self.rnn_hid_size, self.impute_weight, self.label_weight) 19 | 20 | def forward(self, data): 21 | ret_f = self.rits_f(data, 'forward') 22 | ret_b = self.reverse(self.rits_b(data, 'backward')) 23 | 24 | ret = self.merge_ret(ret_f, ret_b) 25 | 26 | return ret 27 | 28 | def merge_ret(self, ret_f, ret_b): 29 | loss_f = ret_f['loss'] 30 | loss_b = ret_b['loss'] 31 | loss_c = self.get_consistency_loss(ret_f['imputations'], ret_b['imputations']) 32 | 33 | loss = loss_f + loss_b + loss_c 34 | 35 | predictions = (ret_f['predictions'] + ret_b['predictions']) / 2 36 | imputations = (ret_f['imputations'] + ret_b['imputations']) / 2 37 | 38 | ret_f['loss'] = loss 39 | ret_f['predictions'] = predictions 40 | ret_f['imputations'] = imputations 41 | 42 | return ret_f 43 | 44 | def get_consistency_loss(self, pred_f, pred_b): 45 | loss = torch.abs(pred_f - pred_b).mean() * 1e-1 46 | return loss 47 | 48 | def reverse(self, ret): 49 | def reverse_tensor(tensor_): 50 | if tensor_.dim() <= 1: 51 | return tensor_ 52 | indices = range(tensor_.size()[1])[::-1] 53 | indices = Variable(torch.LongTensor(indices), requires_grad = False) 54 | 55 | if torch.cuda.is_available(): 56 | indices = indices.cuda() 57 | 58 | return tensor_.index_select(1, indices) 59 | 60 | for key in ret: 61 | ret[key] = reverse_tensor(ret[key]) 62 | 63 | return ret 64 | 65 | def run_on_batch(self, data): 66 | ret = self(data) 67 | 68 | # if optimizer is not None: 69 | # optimizer.zero_grad() 70 | # ret['loss'].backward() 71 | # optimizer.step() 72 | 73 | return ret 74 | -------------------------------------------------------------------------------- /SSGAN/models/__init__.py: -------------------------------------------------------------------------------- 1 | # import rits_i, brits_i, rits, brits, gru_d, m_rnn 2 | from models.brits import * 3 | from models.Based_on_BRITS import * 4 | from models.discriminator import * 5 | from models.classifier import * 6 | from models.brits_i import * 7 | from models.rits import * 8 | from models.rits_i import * 9 | from models.gru_d import * 10 | from models.m_rnn import * 11 | 
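# With the star imports above, callers can pull everything from one place,
# e.g. `from models import Generator, Model` instead of importing each
# submodule explicitly (assuming those names are defined in Based_on_BRITS.py
# and brits.py respectively, as they are in this repository).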
-------------------------------------------------------------------------------- /SSGAN/models/__pycache__/Based_on_BRITS.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/Based_on_BRITS.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/Based_on_BRITS.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/Based_on_BRITS.cpython-39.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/brits.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/brits.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/brits.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/brits.cpython-39.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/brits_i.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/brits_i.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/brits_i.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/brits_i.cpython-39.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/classifier.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/classifier.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/classifier.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/classifier.cpython-39.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/discriminator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/discriminator.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/discriminator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/discriminator.cpython-39.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/gru_d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/gru_d.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/gru_d.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/gru_d.cpython-39.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/m_rnn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/m_rnn.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/m_rnn.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/m_rnn.cpython-39.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/rits.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/rits.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/rits.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/rits.cpython-39.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/rits_i.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/rits_i.cpython-310.pyc -------------------------------------------------------------------------------- /SSGAN/models/__pycache__/rits_i.cpython-39.pyc: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/SSGAN/models/__pycache__/rits_i.cpython-39.pyc
--------------------------------------------------------------------------------
/SSGAN/models/brits.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import models.rits as rits
5 |
6 | SEQ_LEN = 72
7 | RNN_HID_SIZE = 64
8 |
9 |
10 | class Model(nn.Module):
11 |     def __init__(self, rnn_hid_size, impute_weight, label_weight):
12 |         super(Model, self).__init__()
13 |
14 |         self.rnn_hid_size = rnn_hid_size
15 |         self.impute_weight = impute_weight
16 |         self.label_weight = label_weight
17 |
18 |         self.build()
19 |
20 |     def build(self):
21 |         self.rits_f = rits.Model(self.rnn_hid_size, self.impute_weight, self.label_weight)
22 |         self.rits_b = rits.Model(self.rnn_hid_size, self.impute_weight, self.label_weight)
23 |
24 |     def forward(self, data):
25 |         ret_f = self.rits_f(data, 'forward')
26 |         ret_b = self.reverse(self.rits_b(data, 'backward'))
27 |
28 |         ret = self.merge_ret(ret_f, ret_b)
29 |
30 |         return ret
31 |
32 |     def merge_ret(self, ret_f, ret_b):
33 |         loss_f = ret_f['loss']
34 |         loss_b = ret_b['loss']
35 |         loss_c = self.get_consistency_loss(ret_f['imputations'], ret_b['imputations'])
36 |
37 |         loss = loss_f + loss_b + loss_c
38 |
39 |         predictions = (ret_f['predictions'] + ret_b['predictions']) / 2
40 |         imputations = (ret_f['imputations'] + ret_b['imputations']) / 2
41 |
42 |         ret_f['loss'] = loss
43 |         ret_f['predictions'] = predictions
44 |         ret_f['imputations'] = imputations
45 |
46 |         return ret_f
47 |
48 |     def get_consistency_loss(self, pred_f, pred_b):
49 |         loss = torch.abs(pred_f - pred_b).mean() * 1e-1
50 |         return loss
51 |
52 |     def reverse(self, ret):
53 |         def reverse_tensor(tensor_):
54 |             if tensor_.dim() <= 1:
55 |                 return tensor_
56 |             indices = range(tensor_.size()[1])[::-1]
57 |             indices = Variable(torch.LongTensor(indices), requires_grad = False)
58 |
59 |             if torch.cuda.is_available():
60 |                 indices = indices.cuda()
61 |
62 |             return tensor_.index_select(1, indices)
63 |
64 |         for key in ret:
65 |             ret[key] = reverse_tensor(ret[key])
66 |
67 |         return ret
68 |
69 |     def run_on_batch(self, data, optimizer, epoch=None):
70 |         ret = self(data)
71 |
72 |         if optimizer is not None:
73 |             optimizer.zero_grad()
74 |             ret['loss'].backward()
75 |             optimizer.step()
76 |
77 |         return ret
78 |
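A quick numeric sketch of the consistency term used in merge_ret above: the forward and backward passes each impute the series, and their mean absolute disagreement, scaled by 0.1, is added to the loss. The tensor values below are made up purely for illustration.

import torch

pred_f = torch.tensor([[0.2, 0.5], [0.1, 0.9]])
pred_b = torch.tensor([[0.4, 0.5], [0.1, 0.5]])
loss_c = torch.abs(pred_f - pred_b).mean() * 1e-1   # mean |diff| = 0.15, scaled by 0.1
print(loss_c)  # tensor(0.0150)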
--------------------------------------------------------------------------------
/SSGAN/models/brits_i.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 |
6 | from torch.autograd import Variable
7 | from torch.nn.parameter import Parameter
8 |
9 | import math
10 | import utils
11 | import argparse
12 | import data_loader
13 |
14 | # import rits_i
15 | import models.rits_i as rits_i
16 | from sklearn import metrics
17 |
18 | # from ipdb import set_trace
19 |
20 | SEQ_LEN = 30
21 |
22 |
23 | class Model(nn.Module):
24 |     def __init__(self, rnn_hid_size, impute_weight, label_weight):
25 |         super(Model, self).__init__()
26 |
27 |         self.rnn_hid_size = rnn_hid_size
28 |         self.impute_weight = impute_weight
29 |         self.label_weight = label_weight
30 |
31 |         self.build()
32 |
33 |     def build(self):
34 |         self.rits_f = rits_i.Model(self.rnn_hid_size, self.impute_weight, self.label_weight)
35 |         self.rits_b = rits_i.Model(self.rnn_hid_size, self.impute_weight, self.label_weight)
36 |
37 |     def forward(self, data):
38 |         ret_f = self.rits_f(data, 'forward')
39 |         ret_b = self.reverse(self.rits_b(data, 'backward'))
40 |
41 |         ret = self.merge_ret(ret_f, ret_b)
42 |
43 |         return ret
44 |
45 |     def merge_ret(self, ret_f, ret_b):
46 |         loss_f = ret_f['loss']
47 |         loss_b = ret_b['loss']
48 |         loss_c = self.get_consistency_loss(ret_f['imputations'], ret_b['imputations'])
49 |
50 |         loss = loss_f + loss_b + loss_c
51 |
52 |         predictions = (ret_f['predictions'] + ret_b['predictions']) / 2
53 |         imputations = (ret_f['imputations'] + ret_b['imputations']) / 2
54 |
55 |         ret_f['loss'] = loss
56 |         ret_f['predictions'] = predictions
57 |         ret_f['imputations'] = imputations
58 |
59 |         return ret_f
60 |
61 |     def get_consistency_loss(self, pred_f, pred_b):
62 |         loss = torch.abs(pred_f - pred_b).mean() * 1e-1
63 |         return loss
64 |
65 |     def reverse(self, ret):
66 |         def reverse_tensor(tensor_):
67 |             if tensor_.dim() <= 1:
68 |                 return tensor_
69 |             indices = range(tensor_.size()[1])[::-1]
70 |             indices = Variable(torch.LongTensor(indices), requires_grad = False)
71 |
72 |             if torch.cuda.is_available():
73 |                 indices = indices.cuda()
74 |
75 |             return tensor_.index_select(1, indices)
76 |
77 |         for key in ret:
78 |             ret[key] = reverse_tensor(ret[key])
79 |
80 |         return ret
81 |
82 |     def run_on_batch(self, data, optimizer, epoch=None):
83 |         ret = self(data)
84 |
85 |         if optimizer is not None:
86 |             optimizer.zero_grad()
87 |             ret['loss'].backward()
88 |             optimizer.step()
89 |
90 |         return ret
91 |
92 |
--------------------------------------------------------------------------------
/SSGAN/models/rits_i.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 |
6 | from torch.autograd import Variable
7 | from torch.nn.parameter import Parameter
8 |
9 | import math
10 | import utils
11 | import argparse
12 | import data_loader
13 |
14 | # from ipdb import set_trace
15 | from sklearn import metrics
16 |
17 | SEQ_LEN = 30
18 | INPUT_SIZE = 1
19 |
20 | def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True, reduce=True):
21 |     if not (target.size() == input.size()):
22 |         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
23 |
24 |     max_val = (-input).clamp(min=0)
25 |     loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
26 |
27 |     if weight is not None:
28 |         loss = loss * weight
29 |
30 |     if not reduce:
31 |         return loss
32 |     elif size_average:
33 |         return loss.mean()
34 |     else:
35 |         return loss.sum()
36 |
37 |
38 | class TemporalDecay(nn.Module):
39 |     def __init__(self, input_size, rnn_hid_size):
40 |         super(TemporalDecay, self).__init__()
41 |         self.rnn_hid_size = rnn_hid_size
42 |         self.build(input_size)
43 |
44 |     def build(self, input_size):
45 |         self.W = Parameter(torch.Tensor(self.rnn_hid_size, input_size))
46 |         self.b = Parameter(torch.Tensor(self.rnn_hid_size))
47 |         self.reset_parameters()
48 |
49 |     def reset_parameters(self):
50 |         stdv = 1. / math.sqrt(self.W.size(0))
51 |         self.W.data.uniform_(-stdv, stdv)
52 |         if self.b is not None:
53 |             self.b.data.uniform_(-stdv, stdv)
54 |
55 |     def forward(self, d):
56 |         gamma = F.relu(F.linear(d, self.W, self.b))
57 |         gamma = torch.exp(-gamma)
58 |         return gamma
59 |
60 | class Model(nn.Module):
61 |     def __init__(self, rnn_hid_size, impute_weight, label_weight):
62 |         super(Model, self).__init__()
63 |
64 |         self.rnn_hid_size = rnn_hid_size
65 |         self.impute_weight = impute_weight
66 |         self.label_weight = label_weight
67 |
68 |         self.build()
69 |
70 |     def build(self):
71 |         self.rnn_cell = nn.LSTMCell(INPUT_SIZE * 2, self.rnn_hid_size)
72 |
73 |         self.regression = nn.Linear(self.rnn_hid_size, INPUT_SIZE)
74 |         self.temp_decay = TemporalDecay(input_size = INPUT_SIZE, rnn_hid_size = self.rnn_hid_size)
75 |
76 |         self.out = nn.Linear(self.rnn_hid_size, 1)
77 |
78 |     def forward(self, data, direct):
79 |         # Original sequence with SEQ_LEN (30) time steps
80 |         values = data[direct]['values']
81 |         masks = data[direct]['masks']
82 |         deltas = data[direct]['deltas']
83 |
84 |         evals = data[direct]['evals']
85 |         eval_masks = data[direct]['eval_masks']
86 |
87 |         labels = data['labels'].view(-1, 1)
88 |         is_train = data['is_train'].view(-1, 1)
89 |
90 |         h = Variable(torch.zeros((values.size()[0], self.rnn_hid_size)))
91 |         c = Variable(torch.zeros((values.size()[0], self.rnn_hid_size)))
92 |
93 |         if torch.cuda.is_available():
94 |             h, c = h.cuda(), c.cuda()
95 |
96 |         x_loss = 0.0
97 |         y_loss = 0.0
98 |
99 |         imputations = []
100 |
101 |         for t in range(SEQ_LEN):
102 |             x = values[:, t, :]
103 |             m = masks[:, t, :]
104 |             d = deltas[:, t, :]
105 |
106 |             gamma = self.temp_decay(d)
107 |             h = h * gamma
108 |             x_h = self.regression(h)
109 |
110 |             x_c = m * x + (1 - m) * x_h
111 |
112 |             x_loss += torch.sum(torch.abs(x - x_h) * m) / (torch.sum(m) + 1e-5)
113 |
114 |             inputs = torch.cat([x_c, m], dim = 1)
115 |
116 |             h, c = self.rnn_cell(inputs, (h, c))
117 |
118 |             imputations.append(x_c.unsqueeze(dim = 1))
119 |
120 |         imputations = torch.cat(imputations, dim = 1)
121 |
122 |         y_h = self.out(h)
123 |         y_loss = binary_cross_entropy_with_logits(y_h, labels, reduce = False)
124 |
125 |         # only use training labels
126 |         y_loss = torch.sum(y_loss * is_train) / (torch.sum(is_train) + 1e-5)
127 |
128 |         y_h = torch.sigmoid(y_h)  # F.sigmoid is deprecated in newer PyTorch
129 |
130 |         return {'loss': x_loss * self.impute_weight + y_loss * self.label_weight, 'predictions': y_h,\
131 |                 'imputations': imputations, 'labels': labels, 'is_train': is_train,\
132 |                 'evals': evals, 'eval_masks': eval_masks}
133 |
134 |     def run_on_batch(self, data, optimizer, epoch = None):
135 |         ret = self(data, direct = 'forward')
136 |
137 |         if optimizer is not None:
138 |             optimizer.zero_grad()
139 |             ret['loss'].backward()
140 |             optimizer.step()
141 |
142 |         return ret
143 |
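The temporal decay above follows the BRITS formulation, gamma_t = exp(-max(0, W d_t + b)): the longer a feature has been unobserved (larger delta d_t), the more the hidden state is shrunk before the next regression. A minimal end-to-end sketch of the bidirectional wrapper, with random tensors standing in for a real batch from the SSGAN data loader (shapes are assumptions; run from the SSGAN directory so utils/data_loader resolve):

import torch
import models.brits_i as brits_i

B, T, D = 8, 30, 1  # batch size, SEQ_LEN, INPUT_SIZE

def direction(B, T, D):
    return {'values': torch.randn(B, T, D), 'masks': torch.ones(B, T, D),
            'deltas': torch.zeros(B, T, D), 'evals': torch.randn(B, T, D),
            'eval_masks': torch.zeros(B, T, D)}

data = {'forward': direction(B, T, D), 'backward': direction(B, T, D),
        'labels': torch.zeros(B), 'is_train': torch.ones(B)}

model = brits_i.Model(rnn_hid_size=64, impute_weight=0.3, label_weight=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ret = model.run_on_batch(data, optimizer)        # one optimization step
print(ret['loss'].item(), ret['imputations'].shape)  # imputations: (B, T, D)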
--------------------------------------------------------------------------------
/SSGAN/preprocess.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import json
3 | from datetime import datetime
4 |
5 | def calculate_deltas(df):
6 |     """Compute the delta (time-gap) matrix."""
7 |     deltas = pd.DataFrame(0, index=df.index, columns=df.columns[1:])  # first row stays all zeros
8 |     for i in range(1, len(df)):
9 |         for col in df.columns[1:]:
10 |             prev_mask = 1 if pd.notna(df.at[i-1, col]) else 0
11 |             if prev_mask == 1:
12 |                 deltas.at[i, col] = 1  # assumes a 1-hour sampling interval
13 |             else:
14 |                 deltas.at[i, col] = deltas.at[i-1, col] + 1
15 |     return deltas
16 |
17 | def process_data(input_file, output_file):
18 |     df = pd.read_csv(input_file)
19 |     samples = []
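A quick way to sanity-check the converter (a toy CSV invented for illustration; real inputs are files such as pm25_ground.csv):

import numpy as np
import pandas as pd
from preprocess import process_data

rng = pd.date_range('2014-05-01', periods=24, freq='h')
toy = pd.DataFrame({'datetime': rng,
                    'pm25': np.where(np.arange(24) % 5 == 0, np.nan, 50.0)})
toy.to_csv('toy_input.csv', index=False)
process_data('toy_input.csv', 'toy_output.json')  # one JSON sample per 24-row window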
20 |     deltas = calculate_deltas(df)  # compute the delta matrix
21 |
22 |     for i in range(0, len(df), 24):
23 |         sample_df = df.iloc[i:i+24]
24 |         delta_sample = deltas.iloc[i:i+24]
25 |         forward = []
26 |
27 |         for idx, (row, delta_row) in enumerate(zip(sample_df.iterrows(), delta_sample.iterrows())):
28 |             _, data_row = row
29 |             _, delta_data = delta_row
30 |             evals = []
31 |             masks = []
32 |             values = []
33 |             eval_masks = []
34 |             forwards = []
35 |
36 |             for col in data_row.index[1:]:  # skip the datetime column
37 |                 val = data_row[col]
38 |                 if pd.notna(val):
39 |                     masks.append(1)
40 |                     values.append(float(val))
41 |                     evals.append(float(val))
42 |                     eval_masks.append(1)
43 |                     forwards.append(float(val))
44 |                 else:
45 |                     masks.append(0)
46 |                     values.append(0.0)
47 |                     evals.append(0.0)
48 |                     eval_masks.append(0)
49 |                     forwards.append(0.0)
50 |
51 |             forward_entry = {
52 |                 'evals': evals,
53 |                 'deltas': [int(x) for x in delta_data],  # cast to plain ints so json.dumps can serialize them
54 |                 'forwards': forwards,
55 |                 'masks': masks,
56 |                 'values': values,
57 |                 'eval_masks': eval_masks
58 |             }
59 |             forward.append(forward_entry)
60 |
61 |         sample_json = {
62 |             'forward': forward,
63 |             'label': 0,  # adjust to the actual labels if available
64 |             'is_train': 0.0  # adjust to the actual train/test split
65 |         }
66 |         samples.append(sample_json)
67 |
68 |     with open(output_file, 'w') as f:
69 |         for sample in samples:
70 |             f.write(json.dumps(sample, ensure_ascii=False) + '\n')
71 |
72 | if __name__ == "__main__":
73 |     input_file = r'D:\Project\GAN for ts imputation\Generative-Semi-supervised-Learning-for-Multivariate-Time-Series-Imputation-main\datasets\AirQuality\pm25_ground.txt'  # replace with the actual input file (raw string so the backslashes survive)
74 |     output_file = 'output.json'
75 |     process_data(input_file, output_file)
--------------------------------------------------------------------------------
/SSGAN/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 |
4 | def to_var(var):
5 |     if torch.is_tensor(var):
6 |         var = Variable(var)
7 |         if torch.cuda.is_available():
8 |             var = var.cuda()
9 |         return var
10 |     if isinstance(var, int) or isinstance(var, float) or isinstance(var, str):
11 |         return var
12 |     if isinstance(var, dict):
13 |         for key in var:
14 |             var[key] = to_var(var[key])
15 |         return var
16 |     if isinstance(var, list):
17 |         var = [to_var(x) for x in var]  # materialize the list (a bare map() is a lazy iterator in Python 3)
18 |         return var
19 |
20 | def stop_gradient(x):
21 |     if isinstance(x, float):
22 |         return x
23 |     if isinstance(x, tuple):
24 |         return tuple(map(lambda y: Variable(y.data), x))
25 |     return Variable(x.data)
26 |
27 | def zero_var(sz):
28 |     x = Variable(torch.zeros(sz))
29 |     if torch.cuda.is_available():
30 |         x = x.cuda()
31 |     return x
32 |
--------------------------------------------------------------------------------
/TimesNet/data_provider/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/TimesNet/data_provider/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/data_provider/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/TimesNet/data_provider/__pycache__/data_factory.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/data_provider/__pycache__/data_factory.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/data_provider/__pycache__/data_loader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/data_provider/__pycache__/data_loader.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/data_provider/__pycache__/m4.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/data_provider/__pycache__/m4.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/data_provider/__pycache__/uea.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/data_provider/__pycache__/uea.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \ 2 | MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader 3 | from data_provider.uea import collate_fn 4 | from torch.utils.data import DataLoader 5 | 6 | data_dict = { 7 | 'ETTh1': Dataset_ETT_hour, 8 | 'ETTh2': Dataset_ETT_hour, 9 | 'ETTm1': Dataset_ETT_minute, 10 | 'ETTm2': Dataset_ETT_minute, 11 | 'custom': Dataset_Custom, 12 | 'm4': Dataset_M4, 13 | 'PSM': PSMSegLoader, 14 | 'MSL': MSLSegLoader, 15 | 'SMAP': SMAPSegLoader, 16 | 'SMD': SMDSegLoader, 17 | 'SWAT': SWATSegLoader, 18 | 'UEA': UEAloader 19 | } 20 | 21 | 22 | def data_provider(args, flag): 23 | Data = data_dict[args.data] 24 | timeenc = 0 if args.embed != 'timeF' else 1 25 | 26 | shuffle_flag = False if (flag == 'test' or flag == 'TEST') else True 27 | drop_last = False 28 | batch_size = args.batch_size 29 | freq = args.freq 30 | 31 | if args.task_name == 'anomaly_detection': 32 | drop_last = False 33 | data_set = Data( 34 | args = args, 35 | root_path=args.root_path, 36 | win_size=args.seq_len, 37 | flag=flag, 38 | ) 39 | print(flag, len(data_set)) 40 | data_loader = DataLoader( 41 | data_set, 42 | batch_size=batch_size, 43 | shuffle=shuffle_flag, 44 | num_workers=args.num_workers, 45 | drop_last=drop_last) 46 | return data_set, data_loader 47 | elif args.task_name == 'classification': 48 | drop_last = False 49 | data_set = Data( 50 | args = args, 51 | root_path=args.root_path, 52 | flag=flag, 53 | ) 54 | 55 | data_loader = DataLoader( 56 | data_set, 57 | batch_size=batch_size, 58 | shuffle=shuffle_flag, 59 | num_workers=args.num_workers, 60 | drop_last=drop_last, 61 | collate_fn=lambda x: collate_fn(x, max_len=args.seq_len) 62 | ) 63 | return data_set, data_loader 64 | else: 65 | if args.data == 'm4': 66 | drop_last = False 67 | data_set = Data( 68 | args = args, 69 | root_path=args.root_path, 70 | data_path=args.data_path, 71 | flag=flag, 72 | size=[args.seq_len, args.label_len, args.pred_len], 73 | 
features=args.features,
74 |                 target=args.target,
75 |                 timeenc=timeenc,
76 |                 freq=freq,
77 |                 seasonal_patterns=args.seasonal_patterns
78 |             )
79 |         print(flag, len(data_set))
80 |         data_loader = DataLoader(
81 |             data_set,
82 |             batch_size=batch_size,
83 |             shuffle=shuffle_flag,
84 |             num_workers=args.num_workers,
85 |             drop_last=drop_last)
86 |         return data_set, data_loader
87 |
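A hypothetical sketch of how data_provider is typically driven. The args namespace mirrors the argparse flags the factory reads; the field values below are illustrative rather than the project's defaults, and the dataset classes may read further fields from args:

from types import SimpleNamespace
from data_provider.data_factory import data_provider

args = SimpleNamespace(
    task_name='imputation', data='ETTh1', root_path='./uploaded_files/',
    data_path='ETTh1.csv', features='M', target='OT', embed='timeF',
    freq='h', seq_len=96, label_len=48, pred_len=0, batch_size=32,
    num_workers=0, seasonal_patterns=None)

train_set, train_loader = data_provider(args, flag='train')  # shuffled; 'test' disables shuffling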
51 | """ 52 | 53 | def progress(count, block_size, total_size): 54 | progress_pct = float(count * block_size) / float(total_size) * 100.0 55 | sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct)) 56 | sys.stdout.flush() 57 | 58 | if not os.path.isfile(file_path): 59 | opener = request.build_opener() 60 | opener.addheaders = [('User-agent', 'Mozilla/5.0')] 61 | request.install_opener(opener) 62 | pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True) 63 | f, _ = request.urlretrieve(url, file_path, progress) 64 | sys.stdout.write('\n') 65 | sys.stdout.flush() 66 | file_info = os.stat(f) 67 | logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.') 68 | else: 69 | file_info = os.stat(file_path) 70 | logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.') 71 | 72 | 73 | @dataclass() 74 | class M4Dataset: 75 | ids: np.ndarray 76 | groups: np.ndarray 77 | frequencies: np.ndarray 78 | horizons: np.ndarray 79 | values: np.ndarray 80 | 81 | @staticmethod 82 | def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset': 83 | """ 84 | Load cached dataset. 85 | 86 | :param training: Load training part if training is True, test part otherwise. 87 | """ 88 | info_file = os.path.join(dataset_file, 'M4-info.csv') 89 | train_cache_file = os.path.join(dataset_file, 'training.npz') 90 | test_cache_file = os.path.join(dataset_file, 'test.npz') 91 | m4_info = pd.read_csv(info_file) 92 | return M4Dataset(ids=m4_info.M4id.values, 93 | groups=m4_info.SP.values, 94 | frequencies=m4_info.Frequency.values, 95 | horizons=m4_info.Horizon.values, 96 | values=np.load( 97 | train_cache_file if training else test_cache_file, 98 | allow_pickle=True)) 99 | 100 | 101 | @dataclass() 102 | class M4Meta: 103 | seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly'] 104 | horizons = [6, 8, 18, 13, 14, 48] 105 | frequencies = [1, 4, 12, 1, 1, 24] 106 | horizons_map = { 107 | 'Yearly': 6, 108 | 'Quarterly': 8, 109 | 'Monthly': 18, 110 | 'Weekly': 13, 111 | 'Daily': 14, 112 | 'Hourly': 48 113 | } # different predict length 114 | frequency_map = { 115 | 'Yearly': 1, 116 | 'Quarterly': 4, 117 | 'Monthly': 12, 118 | 'Weekly': 1, 119 | 'Daily': 1, 120 | 'Hourly': 24 121 | } 122 | history_size = { 123 | 'Yearly': 1.5, 124 | 'Quarterly': 1.5, 125 | 'Monthly': 1.5, 126 | 'Weekly': 10, 127 | 'Daily': 10, 128 | 'Hourly': 10 129 | } # from interpretable.gin 130 | 131 | 132 | def load_m4_info() -> pd.DataFrame: 133 | """ 134 | Load M4Info file. 135 | 136 | :return: Pandas DataFrame of M4Info. 
137 | """ 138 | return pd.read_csv(INFO_FILE_PATH) 139 | -------------------------------------------------------------------------------- /TimesNet/exp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/exp/__init__.py -------------------------------------------------------------------------------- /TimesNet/exp/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/exp/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/exp/__pycache__/exp_anomaly_detection.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/exp/__pycache__/exp_anomaly_detection.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/exp/__pycache__/exp_basic.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/exp/__pycache__/exp_basic.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/exp/__pycache__/exp_classification.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/exp/__pycache__/exp_classification.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/exp/__pycache__/exp_imputation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/exp/__pycache__/exp_imputation.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/exp/__pycache__/exp_long_term_forecasting.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/exp/__pycache__/exp_long_term_forecasting.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/exp/__pycache__/exp_short_term_forecasting.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/exp/__pycache__/exp_short_term_forecasting.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/exp/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \ 4 | Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \ 5 | Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, 
5 |     Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, TemporalFusionTransformer, SCINet, PAttn, TimeXer, \
6 |     WPMixer, MultiPatchFormer
7 | from models import GAN
8 |
9 |
10 | class Exp_Basic(object):
11 |     def __init__(self, args):
12 |         self.args = args
13 |         self.model_dict = {
14 |             'TimesNet': TimesNet,
15 |             'Autoformer': Autoformer,
16 |             'Transformer': Transformer,
17 |             'Nonstationary_Transformer': Nonstationary_Transformer,
18 |             'DLinear': DLinear,
19 |             'FEDformer': FEDformer,
20 |             'Informer': Informer,
21 |             'LightTS': LightTS,
22 |             'Reformer': Reformer,
23 |             'ETSformer': ETSformer,
24 |             'PatchTST': PatchTST,
25 |             'Pyraformer': Pyraformer,
26 |             'MICN': MICN,
27 |             'Crossformer': Crossformer,
28 |             'FiLM': FiLM,
29 |             'iTransformer': iTransformer,
30 |             'Koopa': Koopa,
31 |             'TiDE': TiDE,
32 |             'FreTS': FreTS,
33 |             'MambaSimple': MambaSimple,
34 |             'TimeMixer': TimeMixer,
35 |             'TSMixer': TSMixer,
36 |             'SegRNN': SegRNN,
37 |             'TemporalFusionTransformer': TemporalFusionTransformer,
38 |             "SCINet": SCINet,
39 |             'PAttn': PAttn,
40 |             'TimeXer': TimeXer,
41 |             'WPMixer': WPMixer,
42 |             'MultiPatchFormer': MultiPatchFormer,
43 |             'GAN': GAN
44 |         }
45 |         if args.model == 'Mamba':
46 |             print('Please make sure you have successfully installed mamba_ssm')
47 |             from models import Mamba
48 |             self.model_dict['Mamba'] = Mamba
49 |
50 |         self.device = self._acquire_device()
51 |         self.model = self._build_model().to(self.device)
52 |
53 |     def _build_model(self):
54 |         raise NotImplementedError
55 |
56 |
57 |     def _acquire_device(self):
58 |         if self.args.use_gpu and self.args.gpu_type == 'cuda':
59 |             os.environ["CUDA_VISIBLE_DEVICES"] = str(
60 |                 self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
61 |             device = torch.device('cuda:{}'.format(self.args.gpu))
62 |             print('Use GPU: cuda:{}'.format(self.args.gpu))
63 |         elif self.args.use_gpu and self.args.gpu_type == 'mps':
64 |             device = torch.device('mps')
65 |             print('Use GPU: mps')
66 |         else:
67 |             device = torch.device('cpu')
68 |             print('Use CPU')
69 |         return device
70 |
71 |     def _get_data(self):
72 |         pass
73 |
74 |     def vali(self):
75 |         pass
76 |
77 |     def train(self):
78 |         pass
79 |
80 |     def test(self):
81 |         pass
82 |
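Concrete experiment classes are expected to fill in _build_model; a minimal sketch of the usual pattern (the imputation experiment in this repo presumably looks very similar):

class Exp_Imputation(Exp_Basic):
    def _build_model(self):
        # look up the class selected by --model and instantiate it with the args
        model = self.model_dict[self.args.model].Model(self.args).float()
        return model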
--------------------------------------------------------------------------------
/TimesNet/layers/Conv_Blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Inception_Block_V1(nn.Module):
6 |     def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
7 |         super(Inception_Block_V1, self).__init__()
8 |         self.in_channels = in_channels
9 |         self.out_channels = out_channels
10 |         self.num_kernels = num_kernels
11 |         kernels = []
12 |         for i in range(self.num_kernels):
13 |             kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i))
14 |         self.kernels = nn.ModuleList(kernels)
15 |         if init_weight:
16 |             self._initialize_weights()
17 |
18 |     def _initialize_weights(self):
19 |         for m in self.modules():
20 |             if isinstance(m, nn.Conv2d):
21 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
22 |                 if m.bias is not None:
23 |                     nn.init.constant_(m.bias, 0)
24 |
25 |     def forward(self, x):
26 |         res_list = []
27 |         for i in range(self.num_kernels):
28 |             res_list.append(self.kernels[i](x))
29 |         res = torch.stack(res_list, dim=-1).mean(-1)
30 |         return res
31 |
32 |
33 | class Inception_Block_V2(nn.Module):
34 |     def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
35 |         super(Inception_Block_V2, self).__init__()
36 |         self.in_channels = in_channels
37 |         self.out_channels = out_channels
38 |         self.num_kernels = num_kernels
39 |         kernels = []
40 |         for i in range(self.num_kernels // 2):
41 |             kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1]))
42 |             kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0]))
43 |         kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
44 |         self.kernels = nn.ModuleList(kernels)
45 |         if init_weight:
46 |             self._initialize_weights()
47 |
48 |     def _initialize_weights(self):
49 |         for m in self.modules():
50 |             if isinstance(m, nn.Conv2d):
51 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
52 |                 if m.bias is not None:
53 |                     nn.init.constant_(m.bias, 0)
54 |
55 |     def forward(self, x):
56 |         res_list = []
57 |         for i in range(self.num_kernels // 2 * 2 + 1):
58 |             res_list.append(self.kernels[i](x))
59 |         res = torch.stack(res_list, dim=-1).mean(-1)
60 |         return res
61 |
--------------------------------------------------------------------------------
/TimesNet/layers/Crossformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange, repeat
4 | from layers.SelfAttention_Family import TwoStageAttentionLayer
5 |
6 |
7 | class SegMerging(nn.Module):
8 |     def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
9 |         super().__init__()
10 |         self.d_model = d_model
11 |         self.win_size = win_size
12 |         self.linear_trans = nn.Linear(win_size * d_model, d_model)
13 |         self.norm = norm_layer(win_size * d_model)
14 |
15 |     def forward(self, x):
16 |         batch_size, ts_d, seg_num, d_model = x.shape
17 |         pad_num = seg_num % self.win_size
18 |         if pad_num != 0:
19 |             pad_num = self.win_size - pad_num
20 |             x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2)
21 |
22 |         seg_to_merge = []
23 |         for i in range(self.win_size):
24 |             seg_to_merge.append(x[:, :, i::self.win_size, :])
25 |         x = torch.cat(seg_to_merge, -1)
26 |
27 |         x = self.norm(x)
28 |         x = self.linear_trans(x)
29 |
30 |         return x
31 |
32 |
33 | class scale_block(nn.Module):
34 |     def __init__(self, configs, win_size, d_model, n_heads, d_ff, depth, dropout, \
35 |                  seg_num=10, factor=10):
36 |         super(scale_block, self).__init__()
37 |
38 |         if win_size > 1:
39 |             self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm)
40 |         else:
41 |             self.merge_layer = None
42 |
43 |         self.encode_layers = nn.ModuleList()
44 |
45 |         for i in range(depth):
46 |             self.encode_layers.append(TwoStageAttentionLayer(configs, seg_num, factor, d_model, n_heads, \
47 |                                                              d_ff, dropout))
48 |
49 |     def forward(self, x, attn_mask=None, tau=None, delta=None):
50 |         _, ts_dim, _, _ = x.shape
51 |
52 |         if self.merge_layer is not None:
53 |             x = self.merge_layer(x)
54 |
55 |         for layer in self.encode_layers:
56 |             x = layer(x)
57 |
58 |         return x, None
59 |
60 |
61 | class Encoder(nn.Module):
62 |     def __init__(self, attn_layers):
63 |         super(Encoder, self).__init__()
64 |         self.encode_blocks = nn.ModuleList(attn_layers)
65 |
66 |     def forward(self, x):
67 |         encode_x = []
68 |         encode_x.append(x)
69 |
70 |         for block in self.encode_blocks:
71 |             x, attns = block(x)
72 |             encode_x.append(x)
73 |
74 |         return encode_x, None
75 |
76 |
77 | class DecoderLayer(nn.Module):
78 |     def __init__(self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1):
79 |         super(DecoderLayer, self).__init__()
80 |         self.self_attention = self_attention
81 |         self.cross_attention 
= cross_attention 82 | self.norm1 = nn.LayerNorm(d_model) 83 | self.norm2 = nn.LayerNorm(d_model) 84 | self.dropout = nn.Dropout(dropout) 85 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model), 86 | nn.GELU(), 87 | nn.Linear(d_model, d_model)) 88 | self.linear_pred = nn.Linear(d_model, seg_len) 89 | 90 | def forward(self, x, cross): 91 | batch = x.shape[0] 92 | x = self.self_attention(x) 93 | x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model') 94 | 95 | cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model') 96 | tmp, attn = self.cross_attention(x, cross, cross, None, None, None,) 97 | x = x + self.dropout(tmp) 98 | y = x = self.norm1(x) 99 | y = self.MLP1(y) 100 | dec_output = self.norm2(x + y) 101 | 102 | dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b=batch) 103 | layer_predict = self.linear_pred(dec_output) 104 | layer_predict = rearrange(layer_predict, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len') 105 | 106 | return dec_output, layer_predict 107 | 108 | 109 | class Decoder(nn.Module): 110 | def __init__(self, layers): 111 | super(Decoder, self).__init__() 112 | self.decode_layers = nn.ModuleList(layers) 113 | 114 | 115 | def forward(self, x, cross): 116 | final_predict = None 117 | i = 0 118 | 119 | ts_d = x.shape[1] 120 | for layer in self.decode_layers: 121 | cross_enc = cross[i] 122 | x, layer_predict = layer(x, cross_enc) 123 | if final_predict is None: 124 | final_predict = layer_predict 125 | else: 126 | final_predict = final_predict + layer_predict 127 | i += 1 128 | 129 | final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d=ts_d) 130 | 131 | return final_predict 132 | -------------------------------------------------------------------------------- /TimesNet/layers/StandardNorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Normalize(nn.Module): 6 | def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False): 7 | """ 8 | :param num_features: the number of features or channels 9 | :param eps: a value added for numerical stability 10 | :param affine: if True, RevIN has learnable affine parameters 11 | """ 12 | super(Normalize, self).__init__() 13 | self.num_features = num_features 14 | self.eps = eps 15 | self.affine = affine 16 | self.subtract_last = subtract_last 17 | self.non_norm = non_norm 18 | if self.affine: 19 | self._init_params() 20 | 21 | def forward(self, x, mode: str): 22 | if mode == 'norm': 23 | self._get_statistics(x) 24 | x = self._normalize(x) 25 | elif mode == 'denorm': 26 | x = self._denormalize(x) 27 | else: 28 | raise NotImplementedError 29 | return x 30 | 31 | def _init_params(self): 32 | # initialize RevIN params: (C,) 33 | self.affine_weight = nn.Parameter(torch.ones(self.num_features)) 34 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) 35 | 36 | def _get_statistics(self, x): 37 | dim2reduce = tuple(range(1, x.ndim - 1)) 38 | if self.subtract_last: 39 | self.last = x[:, -1, :].unsqueeze(1) 40 | else: 41 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() 42 | self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() 43 | 44 | def _normalize(self, x): 45 | if self.non_norm: 46 | return x 47 | if self.subtract_last: 48 | x = x - self.last 49 | else: 50 | x = x - self.mean 51 
| x = x / self.stdev 52 | if self.affine: 53 | x = x * self.affine_weight 54 | x = x + self.affine_bias 55 | return x 56 | 57 | def _denormalize(self, x): 58 | if self.non_norm: 59 | return x 60 | if self.affine: 61 | x = x - self.affine_bias 62 | x = x / (self.affine_weight + self.eps * self.eps) 63 | x = x * self.stdev 64 | if self.subtract_last: 65 | x = x + self.last 66 | else: 67 | x = x + self.mean 68 | return x 69 | -------------------------------------------------------------------------------- /TimesNet/layers/Transformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvLayer(nn.Module): 7 | def __init__(self, c_in): 8 | super(ConvLayer, self).__init__() 9 | self.downConv = nn.Conv1d(in_channels=c_in, 10 | out_channels=c_in, 11 | kernel_size=3, 12 | padding=2, 13 | padding_mode='circular') 14 | self.norm = nn.BatchNorm1d(c_in) 15 | self.activation = nn.ELU() 16 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 17 | 18 | def forward(self, x): 19 | x = self.downConv(x.permute(0, 2, 1)) 20 | x = self.norm(x) 21 | x = self.activation(x) 22 | x = self.maxPool(x) 23 | x = x.transpose(1, 2) 24 | return x 25 | 26 | 27 | class EncoderLayer(nn.Module): 28 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 29 | super(EncoderLayer, self).__init__() 30 | d_ff = d_ff or 4 * d_model 31 | self.attention = attention 32 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 33 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 34 | self.norm1 = nn.LayerNorm(d_model) 35 | self.norm2 = nn.LayerNorm(d_model) 36 | self.dropout = nn.Dropout(dropout) 37 | self.activation = F.relu if activation == "relu" else F.gelu 38 | 39 | def forward(self, x, attn_mask=None, tau=None, delta=None): 40 | new_x, attn = self.attention( 41 | x, x, x, 42 | attn_mask=attn_mask, 43 | tau=tau, delta=delta 44 | ) 45 | x = x + self.dropout(new_x) 46 | 47 | y = x = self.norm1(x) 48 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 49 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 50 | 51 | return self.norm2(x + y), attn 52 | 53 | 54 | class Encoder(nn.Module): 55 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 56 | super(Encoder, self).__init__() 57 | self.attn_layers = nn.ModuleList(attn_layers) 58 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 59 | self.norm = norm_layer 60 | 61 | def forward(self, x, attn_mask=None, tau=None, delta=None): 62 | # x [B, L, D] 63 | attns = [] 64 | if self.conv_layers is not None: 65 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): 66 | delta = delta if i == 0 else None 67 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 68 | x = conv_layer(x) 69 | attns.append(attn) 70 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None) 71 | attns.append(attn) 72 | else: 73 | for attn_layer in self.attn_layers: 74 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 75 | attns.append(attn) 76 | 77 | if self.norm is not None: 78 | x = self.norm(x) 79 | 80 | return x, attns 81 | 82 | 83 | class DecoderLayer(nn.Module): 84 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 85 | dropout=0.1, activation="relu"): 86 | super(DecoderLayer, self).__init__() 87 | d_ff = d_ff or 4 * d_model 88 | 
self.self_attention = self_attention 89 | self.cross_attention = cross_attention 90 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 91 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 92 | self.norm1 = nn.LayerNorm(d_model) 93 | self.norm2 = nn.LayerNorm(d_model) 94 | self.norm3 = nn.LayerNorm(d_model) 95 | self.dropout = nn.Dropout(dropout) 96 | self.activation = F.relu if activation == "relu" else F.gelu 97 | 98 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 99 | x = x + self.dropout(self.self_attention( 100 | x, x, x, 101 | attn_mask=x_mask, 102 | tau=tau, delta=None 103 | )[0]) 104 | x = self.norm1(x) 105 | 106 | x = x + self.dropout(self.cross_attention( 107 | x, cross, cross, 108 | attn_mask=cross_mask, 109 | tau=tau, delta=delta 110 | )[0]) 111 | 112 | y = x = self.norm2(x) 113 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 114 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 115 | 116 | return self.norm3(x + y) 117 | 118 | 119 | class Decoder(nn.Module): 120 | def __init__(self, layers, norm_layer=None, projection=None): 121 | super(Decoder, self).__init__() 122 | self.layers = nn.ModuleList(layers) 123 | self.norm = norm_layer 124 | self.projection = projection 125 | 126 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 127 | for layer in self.layers: 128 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) 129 | 130 | if self.norm is not None: 131 | x = self.norm(x) 132 | 133 | if self.projection is not None: 134 | x = self.projection(x) 135 | return x 136 | -------------------------------------------------------------------------------- /TimesNet/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__init__.py -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/AutoCorrelation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/AutoCorrelation.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/Autoformer_EncDec.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/Autoformer_EncDec.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/Conv_Blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/Conv_Blocks.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/Crossformer_EncDec.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/Crossformer_EncDec.cpython-310.pyc 
-------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/DWT_Decomposition.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/DWT_Decomposition.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/ETSformer_EncDec.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/ETSformer_EncDec.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/Embed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/Embed.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/FourierCorrelation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/FourierCorrelation.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/MultiWaveletCorrelation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/MultiWaveletCorrelation.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/Pyraformer_EncDec.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/Pyraformer_EncDec.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/SelfAttention_Family.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/SelfAttention_Family.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/StandardNorm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/StandardNorm.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/Transformer_EncDec.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/Transformer_EncDec.cpython-310.pyc 
-------------------------------------------------------------------------------- /TimesNet/layers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/layers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /TimesNet/models/DLinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Autoformer_EncDec import series_decomp 5 | 6 | 7 | class Model(nn.Module): 8 | """ 9 | Paper link: https://arxiv.org/pdf/2205.13504.pdf 10 | """ 11 | 12 | def __init__(self, configs, individual=False): 13 | """ 14 | individual: Bool, whether shared model among different variates. 15 | """ 16 | super(Model, self).__init__() 17 | self.task_name = configs.task_name 18 | self.seq_len = configs.seq_len 19 | if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation': 20 | self.pred_len = configs.seq_len 21 | else: 22 | self.pred_len = configs.pred_len 23 | # Series decomposition block from Autoformer 24 | self.decompsition = series_decomp(configs.moving_avg) 25 | self.individual = individual 26 | self.channels = configs.enc_in 27 | 28 | if self.individual: 29 | self.Linear_Seasonal = nn.ModuleList() 30 | self.Linear_Trend = nn.ModuleList() 31 | 32 | for i in range(self.channels): 33 | self.Linear_Seasonal.append( 34 | nn.Linear(self.seq_len, self.pred_len)) 35 | self.Linear_Trend.append( 36 | nn.Linear(self.seq_len, self.pred_len)) 37 | 38 | self.Linear_Seasonal[i].weight = nn.Parameter( 39 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 40 | self.Linear_Trend[i].weight = nn.Parameter( 41 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 42 | else: 43 | self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len) 44 | self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len) 45 | 46 | self.Linear_Seasonal.weight = nn.Parameter( 47 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 48 | self.Linear_Trend.weight = nn.Parameter( 49 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 50 | 51 | if self.task_name == 'classification': 52 | self.projection = nn.Linear( 53 | configs.enc_in * configs.seq_len, configs.num_class) 54 | 55 | def encoder(self, x): 56 | seasonal_init, trend_init = self.decompsition(x) 57 | seasonal_init, trend_init = seasonal_init.permute( 58 | 0, 2, 1), trend_init.permute(0, 2, 1) 59 | if self.individual: 60 | seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len], 61 | dtype=seasonal_init.dtype).to(seasonal_init.device) 62 | trend_output = torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len], 63 | dtype=trend_init.dtype).to(trend_init.device) 64 | for i in range(self.channels): 65 | seasonal_output[:, i, :] = self.Linear_Seasonal[i]( 66 | seasonal_init[:, i, :]) 67 | trend_output[:, i, :] = self.Linear_Trend[i]( 68 | trend_init[:, i, :]) 69 | else: 70 | seasonal_output = self.Linear_Seasonal(seasonal_init) 71 | trend_output = self.Linear_Trend(trend_init) 72 | x = seasonal_output + trend_output 73 | return x.permute(0, 2, 1) 74 | 75 | def forecast(self, x_enc): 76 | # Encoder 77 | return self.encoder(x_enc) 78 | 79 | def imputation(self, x_enc): 80 | # Encoder 81 
| return self.encoder(x_enc) 82 | 83 | def anomaly_detection(self, x_enc): 84 | # Encoder 85 | return self.encoder(x_enc) 86 | 87 | def classification(self, x_enc): 88 | # Encoder 89 | enc_out = self.encoder(x_enc) 90 | # Output 91 | # (batch_size, seq_length * d_model) 92 | output = enc_out.reshape(enc_out.shape[0], -1) 93 | # (batch_size, num_classes) 94 | output = self.projection(output) 95 | return output 96 | 97 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 98 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 99 | dec_out = self.forecast(x_enc) 100 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 101 | if self.task_name == 'imputation': 102 | dec_out = self.imputation(x_enc) 103 | return dec_out # [B, L, D] 104 | if self.task_name == 'anomaly_detection': 105 | dec_out = self.anomaly_detection(x_enc) 106 | return dec_out # [B, L, D] 107 | if self.task_name == 'classification': 108 | dec_out = self.classification(x_enc) 109 | return dec_out # [B, N] 110 | return None 111 | -------------------------------------------------------------------------------- /TimesNet/models/ETSformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers.Embed import DataEmbedding 4 | from layers.ETSformer_EncDec import EncoderLayer, Encoder, DecoderLayer, Decoder, Transform 5 | 6 | 7 | class Model(nn.Module): 8 | """ 9 | Paper link: https://arxiv.org/abs/2202.01381 10 | """ 11 | 12 | def __init__(self, configs): 13 | super(Model, self).__init__() 14 | self.task_name = configs.task_name 15 | self.seq_len = configs.seq_len 16 | self.label_len = configs.label_len 17 | if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation': 18 | self.pred_len = configs.seq_len 19 | else: 20 | self.pred_len = configs.pred_len 21 | 22 | assert configs.e_layers == configs.d_layers, "Encoder and decoder layers must be equal" 23 | 24 | # Embedding 25 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 26 | configs.dropout) 27 | 28 | # Encoder 29 | self.encoder = Encoder( 30 | [ 31 | EncoderLayer( 32 | configs.d_model, configs.n_heads, configs.enc_in, configs.seq_len, self.pred_len, configs.top_k, 33 | dim_feedforward=configs.d_ff, 34 | dropout=configs.dropout, 35 | activation=configs.activation, 36 | ) for _ in range(configs.e_layers) 37 | ] 38 | ) 39 | # Decoder 40 | self.decoder = Decoder( 41 | [ 42 | DecoderLayer( 43 | configs.d_model, configs.n_heads, configs.c_out, self.pred_len, 44 | dropout=configs.dropout, 45 | ) for _ in range(configs.d_layers) 46 | ], 47 | ) 48 | self.transform = Transform(sigma=0.2) 49 | 50 | if self.task_name == 'classification': 51 | self.act = torch.nn.functional.gelu 52 | self.dropout = nn.Dropout(configs.dropout) 53 | self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class) 54 | 55 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 56 | with torch.no_grad(): 57 | if self.training: 58 | x_enc = self.transform.transform(x_enc) 59 | res = self.enc_embedding(x_enc, x_mark_enc) 60 | level, growths, seasons = self.encoder(res, x_enc, attn_mask=None) 61 | 62 | growth, season = self.decoder(growths, seasons) 63 | preds = level[:, -1:] + growth + season 64 | return preds 65 | 66 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 67 | res = self.enc_embedding(x_enc, x_mark_enc) 68 | level, growths, seasons 
= self.encoder(res, x_enc, attn_mask=None)
69 |         growth, season = self.decoder(growths, seasons)
70 |         preds = level[:, -1:] + growth + season
71 |         return preds
72 |
73 |     def anomaly_detection(self, x_enc):
74 |         res = self.enc_embedding(x_enc, None)
75 |         level, growths, seasons = self.encoder(res, x_enc, attn_mask=None)
76 |         growth, season = self.decoder(growths, seasons)
77 |         preds = level[:, -1:] + growth + season
78 |         return preds
79 |
80 |     def classification(self, x_enc, x_mark_enc):
81 |         res = self.enc_embedding(x_enc, None)
82 |         _, growths, seasons = self.encoder(res, x_enc, attn_mask=None)
83 |
84 |         growths = torch.sum(torch.stack(growths, 0), 0)[:, :self.seq_len, :]
85 |         seasons = torch.sum(torch.stack(seasons, 0), 0)[:, :self.seq_len, :]
86 |
87 |         enc_out = growths + seasons
88 |         output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
89 |         output = self.dropout(output)
90 |
91 |         # Output
92 |         output = output * x_mark_enc.unsqueeze(-1)  # zero-out padding embeddings
93 |         output = output.reshape(output.shape[0], -1)  # (batch_size, seq_length * d_model)
94 |         output = self.projection(output)  # (batch_size, num_classes)
95 |         return output
96 |
97 |     def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
98 |         if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
99 |             dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
100 |             return dec_out[:, -self.pred_len:, :]  # [B, L, D]
101 |         if self.task_name == 'imputation':
102 |             dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
103 |             return dec_out  # [B, L, D]
104 |         if self.task_name == 'anomaly_detection':
105 |             dec_out = self.anomaly_detection(x_enc)
106 |             return dec_out  # [B, L, D]
107 |         if self.task_name == 'classification':
108 |             dec_out = self.classification(x_enc, x_mark_enc)
109 |             return dec_out  # [B, N]
110 |         return None
111 |
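A shape-level sketch of the additive composition every ETSformer head above uses: the prediction is the last-step level plus the decoded growth and season terms, broadcast over the time dimension (all tensors here are random stand-ins):

import torch
level = torch.randn(4, 1, 7)     # (batch, 1, channels): last-step level, as in level[:, -1:]
growth = torch.randn(4, 24, 7)   # decoded growth over the horizon
season = torch.randn(4, 24, 7)   # decoded seasonal component
preds = level + growth + season  # broadcasts (4, 1, 7) across the 24 steps
print(preds.shape)               # torch.Size([4, 24, 7])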
35 |         self.ib2 = nn.Parameter(self.scale * torch.randn(self.embed_size))
36 | 
37 |         self.fc = nn.Sequential(
38 |             nn.Linear(self.seq_len * self.embed_size, self.hidden_size),
39 |             nn.LeakyReLU(),
40 |             nn.Linear(self.hidden_size, self.pred_len)
41 |         )
42 | 
43 |     # dimension extension
44 |     def tokenEmb(self, x):
45 |         # x: [Batch, Input length, Channel]
46 |         x = x.permute(0, 2, 1)
47 |         x = x.unsqueeze(3)
48 |         # N*T*1 x 1*D = N*T*D
49 |         y = self.embeddings
50 |         return x * y
51 | 
52 |     # frequency temporal learner
53 |     def MLP_temporal(self, x, B, N, L):
54 |         # [B, N, T, D]
55 |         x = torch.fft.rfft(x, dim=2, norm='ortho')  # FFT on L dimension
56 |         y = self.FreMLP(B, N, L, x, self.r2, self.i2, self.rb2, self.ib2)
57 |         x = torch.fft.irfft(y, n=self.seq_len, dim=2, norm="ortho")
58 |         return x
59 | 
60 |     # frequency channel learner
61 |     def MLP_channel(self, x, B, N, L):
62 |         # [B, N, T, D]
63 |         x = x.permute(0, 2, 1, 3)
64 |         # [B, T, N, D]
65 |         x = torch.fft.rfft(x, dim=2, norm='ortho')  # FFT on N dimension
66 |         y = self.FreMLP(B, L, N, x, self.r1, self.i1, self.rb1, self.ib1)
67 |         x = torch.fft.irfft(y, n=self.feature_size, dim=2, norm="ortho")
68 |         x = x.permute(0, 2, 1, 3)
69 |         # [B, N, T, D]
70 |         return x
71 | 
72 |     # frequency-domain MLPs
73 |     # dimension: FFT along the dimension, r: the real part of weights, i: the imaginary part of weights
74 |     # rb: the real part of bias, ib: the imaginary part of bias
75 |     def FreMLP(self, B, nd, dimension, x, r, i, rb, ib):
76 |         o1_real = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size],
77 |                               device=x.device)
78 |         o1_imag = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size],
79 |                               device=x.device)
80 | 
81 |         o1_real = F.relu(
82 |             torch.einsum('bijd,dd->bijd', x.real, r) - \
83 |             torch.einsum('bijd,dd->bijd', x.imag, i) + \
84 |             rb
85 |         )
86 | 
87 |         o1_imag = F.relu(
88 |             torch.einsum('bijd,dd->bijd', x.imag, r) + \
89 |             torch.einsum('bijd,dd->bijd', x.real, i) + \
90 |             ib
91 |         )
92 | 
93 |         y = torch.stack([o1_real, o1_imag], dim=-1)
94 |         y = F.softshrink(y, lambd=self.sparsity_threshold)
95 |         y = torch.view_as_complex(y)
96 |         return y
97 | 
98 |     def forecast(self, x_enc):
99 |         # x: [Batch, Input length, Channel]
100 |         B, T, N = x_enc.shape
101 |         # embedding x: [B, N, T, D]
102 |         x = self.tokenEmb(x_enc)
103 |         bias = x
104 |         # [B, N, T, D]
105 |         if self.channel_independence == '0':
106 |             x = self.MLP_channel(x, B, N, T)
107 |         # [B, N, T, D]
108 |         x = self.MLP_temporal(x, B, N, T)
109 |         x = x + bias
110 |         x = self.fc(x.reshape(B, N, -1)).permute(0, 2, 1)
111 |         return x
112 | 
113 |     def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
114 |         if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
115 |             dec_out = self.forecast(x_enc)
116 |             return dec_out[:, -self.pred_len:, :]  # [B, L, D]
117 |         else:
118 |             raise ValueError('Only forecast tasks are implemented for this model')
-------------------------------------------------------------------------------- /TimesNet/models/Mamba.py: --------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | from mamba_ssm import Mamba
8 | 
9 | from layers.Embed import DataEmbedding
10 | 
11 | class Model(nn.Module):
12 | 
13 |     def __init__(self, configs):
14 |         super(Model, self).__init__()
15 |         self.task_name = configs.task_name
16 |         self.pred_len = configs.pred_len
17 | 
18 |         self.d_inner = configs.d_model * 
configs.expand 19 | self.dt_rank = math.ceil(configs.d_model / 16) # TODO implement "auto" 20 | 21 | self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout) 22 | 23 | self.mamba = Mamba( 24 | d_model = configs.d_model, 25 | d_state = configs.d_ff, 26 | d_conv = configs.d_conv, 27 | expand = configs.expand, 28 | ) 29 | 30 | self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False) 31 | 32 | def forecast(self, x_enc, x_mark_enc): 33 | mean_enc = x_enc.mean(1, keepdim=True).detach() 34 | x_enc = x_enc - mean_enc 35 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() 36 | x_enc = x_enc / std_enc 37 | 38 | x = self.embedding(x_enc, x_mark_enc) 39 | x = self.mamba(x) 40 | x_out = self.out_layer(x) 41 | 42 | x_out = x_out * std_enc + mean_enc 43 | return x_out 44 | 45 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 46 | if self.task_name in ['short_term_forecast', 'long_term_forecast']: 47 | x_out = self.forecast(x_enc, x_mark_enc) 48 | return x_out[:, -self.pred_len:, :] 49 | 50 | # other tasks not implemented -------------------------------------------------------------------------------- /TimesNet/models/PAttn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers.Transformer_EncDec import Encoder, EncoderLayer 4 | from layers.SelfAttention_Family import FullAttention, AttentionLayer 5 | from einops import rearrange 6 | 7 | 8 | class Model(nn.Module): 9 | """ 10 | Paper link: https://arxiv.org/abs/2406.16964 11 | """ 12 | def __init__(self, configs, patch_len=16, stride=8): 13 | super().__init__() 14 | self.seq_len = configs.seq_len 15 | self.pred_len = configs.pred_len 16 | self.patch_size = patch_len 17 | self.stride = stride 18 | 19 | self.d_model = configs.d_model 20 | 21 | self.patch_num = (configs.seq_len - self.patch_size) // self.stride + 2 22 | self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride)) 23 | self.in_layer = nn.Linear(self.patch_size, self.d_model) 24 | self.encoder = Encoder( 25 | [ 26 | EncoderLayer( 27 | AttentionLayer( 28 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 29 | output_attention=False), configs.d_model, configs.n_heads), 30 | configs.d_model, 31 | configs.d_ff, 32 | dropout=configs.dropout, 33 | activation=configs.activation 34 | ) for l in range(1) 35 | ], 36 | norm_layer=nn.LayerNorm(configs.d_model) 37 | ) 38 | self.out_layer = nn.Linear(self.d_model * self.patch_num, configs.pred_len) 39 | 40 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 41 | means = x_enc.mean(1, keepdim=True).detach() 42 | x_enc = x_enc - means 43 | stdev = torch.sqrt( 44 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 45 | x_enc /= stdev 46 | 47 | B, _, C = x_enc.shape 48 | x_enc = x_enc.permute(0, 2, 1) 49 | x_enc = self.padding_patch_layer(x_enc) 50 | x_enc = x_enc.unfold(dimension=-1, size=self.patch_size, step=self.stride) 51 | enc_out = self.in_layer(x_enc) 52 | enc_out = rearrange(enc_out, 'b c m l -> (b c) m l') 53 | dec_out, _ = self.encoder(enc_out) 54 | dec_out = rearrange(dec_out, '(b c) m l -> b c (m l)' , b=B , c=C) 55 | dec_out = self.out_layer(dec_out) 56 | dec_out = dec_out.permute(0, 2, 1) 57 | 58 | dec_out = dec_out * \ 59 | (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 60 | dec_out = dec_out + \ 61 | (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 62 | return 
dec_out -------------------------------------------------------------------------------- /TimesNet/models/Pyraformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers.Pyraformer_EncDec import Encoder 4 | 5 | 6 | class Model(nn.Module): 7 | """ 8 | Pyraformer: Pyramidal attention to reduce complexity 9 | Paper link: https://openreview.net/pdf?id=0EXmFzUn5I 10 | """ 11 | 12 | def __init__(self, configs, window_size=[4,4], inner_size=5): 13 | """ 14 | window_size: list, the downsample window size in pyramidal attention. 15 | inner_size: int, the size of neighbour attention 16 | """ 17 | super().__init__() 18 | self.task_name = configs.task_name 19 | self.pred_len = configs.pred_len 20 | self.d_model = configs.d_model 21 | 22 | if self.task_name == 'short_term_forecast': 23 | window_size = [2,2] 24 | self.encoder = Encoder(configs, window_size, inner_size) 25 | 26 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 27 | self.projection = nn.Linear( 28 | (len(window_size)+1)*self.d_model, self.pred_len * configs.enc_in) 29 | elif self.task_name == 'imputation' or self.task_name == 'anomaly_detection': 30 | self.projection = nn.Linear( 31 | (len(window_size)+1)*self.d_model, configs.enc_in, bias=True) 32 | elif self.task_name == 'classification': 33 | self.act = torch.nn.functional.gelu 34 | self.dropout = nn.Dropout(configs.dropout) 35 | self.projection = nn.Linear( 36 | (len(window_size)+1)*self.d_model * configs.seq_len, configs.num_class) 37 | 38 | def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 39 | enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :] 40 | dec_out = self.projection(enc_out).view( 41 | enc_out.size(0), self.pred_len, -1) 42 | return dec_out 43 | 44 | def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 45 | # Normalization 46 | mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E 47 | x_enc = x_enc - mean_enc 48 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E 49 | x_enc = x_enc / std_enc 50 | 51 | enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :] 52 | dec_out = self.projection(enc_out).view( 53 | enc_out.size(0), self.pred_len, -1) 54 | 55 | dec_out = dec_out * std_enc + mean_enc 56 | return dec_out 57 | 58 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 59 | enc_out = self.encoder(x_enc, x_mark_enc) 60 | dec_out = self.projection(enc_out) 61 | return dec_out 62 | 63 | def anomaly_detection(self, x_enc, x_mark_enc): 64 | enc_out = self.encoder(x_enc, x_mark_enc) 65 | dec_out = self.projection(enc_out) 66 | return dec_out 67 | 68 | def classification(self, x_enc, x_mark_enc): 69 | # enc 70 | enc_out = self.encoder(x_enc, x_mark_enc=None) 71 | 72 | # Output 73 | # the output transformer encoder/decoder embeddings don't include non-linearity 74 | output = self.act(enc_out) 75 | output = self.dropout(output) 76 | # zero-out padding embeddings 77 | output = output * x_mark_enc.unsqueeze(-1) 78 | # (batch_size, seq_length * d_model) 79 | output = output.reshape(output.shape[0], -1) 80 | output = self.projection(output) # (batch_size, num_classes) 81 | 82 | return output 83 | 84 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 85 | if self.task_name == 'long_term_forecast': 86 | dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 87 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 88 
| if self.task_name == 'short_term_forecast':
89 |             dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
90 |             return dec_out[:, -self.pred_len:, :]  # [B, L, D]
91 |         if self.task_name == 'imputation':
92 |             dec_out = self.imputation(
93 |                 x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
94 |             return dec_out  # [B, L, D]
95 |         if self.task_name == 'anomaly_detection':
96 |             dec_out = self.anomaly_detection(x_enc, x_mark_enc)
97 |             return dec_out  # [B, L, D]
98 |         if self.task_name == 'classification':
99 |             dec_out = self.classification(x_enc, x_mark_enc)
100 |             return dec_out  # [B, N]
101 |         return None
102 | 
-------------------------------------------------------------------------------- /TimesNet/models/SegRNN.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Autoformer_EncDec import series_decomp
5 | 
6 | 
7 | class Model(nn.Module):
8 |     """
9 |     Paper link: https://arxiv.org/abs/2308.11200
10 |     """
11 | 
12 |     def __init__(self, configs):
13 |         super(Model, self).__init__()
14 | 
15 |         # get parameters
16 |         self.seq_len = configs.seq_len
17 |         self.enc_in = configs.enc_in
18 |         self.d_model = configs.d_model
19 |         self.dropout = configs.dropout
20 | 
21 |         self.task_name = configs.task_name
22 |         if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
23 |             self.pred_len = configs.seq_len
24 |         else:
25 |             self.pred_len = configs.pred_len
26 | 
27 |         self.seg_len = configs.seg_len
28 |         self.seg_num_x = self.seq_len // self.seg_len
29 |         self.seg_num_y = self.pred_len // self.seg_len
30 | 
31 |         # building model
32 |         self.valueEmbedding = nn.Sequential(
33 |             nn.Linear(self.seg_len, self.d_model),
34 |             nn.ReLU()
35 |         )
36 |         self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,
37 |                           batch_first=True, bidirectional=False)
38 |         self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2))
39 |         self.channel_emb = nn.Parameter(torch.randn(self.enc_in, self.d_model // 2))
40 | 
41 |         self.predict = nn.Sequential(
42 |             nn.Dropout(self.dropout),
43 |             nn.Linear(self.d_model, self.seg_len)
44 |         )
45 | 
46 |         if self.task_name == 'classification':
47 |             self.act = F.gelu
48 |             self.dropout = nn.Dropout(configs.dropout)
49 |             self.projection = nn.Linear(
50 |                 configs.enc_in * configs.seq_len, configs.num_class)
51 | 
52 |     def encoder(self, x):
53 |         # b:batch_size c:channel_size s:seq_len
54 |         # d:d_model w:seg_len n:seg_num_x m:seg_num_y
55 |         batch_size = x.size(0)
56 | 
57 |         # normalization and permute    b,s,c -> b,c,s
58 |         seq_last = x[:, -1:, :].detach()
59 |         x = (x - seq_last).permute(0, 2, 1)  # b,c,s
60 | 
61 |         # segment and embedding    b,c,s -> bc,n,w -> bc,n,d
62 |         x = self.valueEmbedding(x.reshape(-1, self.seg_num_x, self.seg_len))
63 | 
64 |         # encoding
65 |         _, hn = self.rnn(x)  # bc,n,d  1,bc,d
66 | 
67 |         # m,d//2 -> 1,m,d//2 -> c,m,d//2
68 |         # c,d//2 -> c,1,d//2 -> c,m,d//2
69 |         # c,m,d -> cm,1,d -> bcm, 1, d
70 |         pos_emb = torch.cat([
71 |             self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
72 |             self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
73 |         ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1)
74 | 
75 |         _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model))  # bcm,1,d  1,bcm,d
76 | 
77 |         # 1,bcm,d -> 1,bcm,w -> b,c,s
78 |         y = self.predict(hy).view(-1, self.enc_in, self.pred_len)
79 | 
80 |         # permute and denorm
81 |         y = y.permute(0, 2, 1) + seq_last
82 |         return y
83 | 
84 |     def forecast(self, x_enc):
85 |         # Encoder
86 |         return self.encoder(x_enc)
87 | 
88 |     def imputation(self, x_enc):
89 |         # Encoder
90 |         return self.encoder(x_enc)
91 | 
92 |     def anomaly_detection(self, x_enc):
93 |         # Encoder
94 |         return self.encoder(x_enc)
95 | 
96 |     def classification(self, x_enc):
97 |         # Encoder
98 |         enc_out = self.encoder(x_enc)
99 |         # Output
100 |         # (batch_size, seq_length * d_model)
101 |         output = enc_out.reshape(enc_out.shape[0], -1)
102 |         # (batch_size, num_classes)
103 |         output = self.projection(output)
104 |         return output
105 | 
106 |     def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
107 |         if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
108 |             dec_out = self.forecast(x_enc)
109 |             return dec_out[:, -self.pred_len:, :]  # [B, L, D]
110 |         if self.task_name == 'imputation':
111 |             dec_out = self.imputation(x_enc)
112 |             return dec_out  # [B, L, D]
113 |         if self.task_name == 'anomaly_detection':
114 |             dec_out = self.anomaly_detection(x_enc)
115 |             return dec_out  # [B, L, D]
116 |         if self.task_name == 'classification':
117 |             dec_out = self.classification(x_enc)
118 |             return dec_out  # [B, N]
119 |         return None
120 | 
-------------------------------------------------------------------------------- /TimesNet/models/TSMixer.py: --------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class ResBlock(nn.Module):
5 |     def __init__(self, configs):
6 |         super(ResBlock, self).__init__()
7 | 
8 |         self.temporal = nn.Sequential(
9 |             nn.Linear(configs.seq_len, configs.d_model),
10 |             nn.ReLU(),
11 |             nn.Linear(configs.d_model, configs.seq_len),
12 |             nn.Dropout(configs.dropout)
13 |         )
14 | 
15 |         self.channel = nn.Sequential(
16 |             nn.Linear(configs.enc_in, configs.d_model),
17 |             nn.ReLU(),
18 |             nn.Linear(configs.d_model, configs.enc_in),
19 |             nn.Dropout(configs.dropout)
20 |         )
21 | 
22 |     def forward(self, x):
23 |         # x: [B, L, D]
24 |         x = x + self.temporal(x.transpose(1, 2)).transpose(1, 2)
25 |         x = x + self.channel(x)
26 | 
27 |         return x
28 | 
29 | 
30 | class Model(nn.Module):
31 |     def __init__(self, configs):
32 |         super(Model, self).__init__()
33 |         self.task_name = configs.task_name
34 |         self.layer = configs.e_layers
35 |         self.model = nn.ModuleList([ResBlock(configs)
36 |                                     for _ in range(configs.e_layers)])
37 |         self.pred_len = configs.pred_len
38 |         self.projection = nn.Linear(configs.seq_len, configs.pred_len)
39 | 
40 |     def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
41 | 
42 |         # x: [B, L, D]
43 |         for i in range(self.layer):
44 |             x_enc = self.model[i](x_enc)
45 |         enc_out = self.projection(x_enc.transpose(1, 2)).transpose(1, 2)
46 | 
47 |         return enc_out
48 | 
49 |     def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
50 |         if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
51 |             dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
52 |             return dec_out[:, -self.pred_len:, :]  # [B, L, D]
53 |         else:
54 |             raise ValueError('Only forecast tasks are implemented for this model')
55 | 
-------------------------------------------------------------------------------- /TimesNet/models/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/models/__init__.py
--------------------------------------------------------------------------------
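All of the model files above share the same contract: a `Model(configs)` class whose `forward(x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None)` dispatches on `configs.task_name`. The sketch below drives one of them (TSMixer) end to end; the hyperparameter values are illustrative assumptions, not settings taken from this repository, and it assumes the `TimesNet/` directory is the working directory so `models.TSMixer` is importable.

from types import SimpleNamespace
import torch
from models.TSMixer import Model  # any model above follows the same interface

# Illustrative config (assumed values, not from this repo's scripts)
configs = SimpleNamespace(task_name='long_term_forecast', seq_len=96, pred_len=24,
                          enc_in=7, d_model=64, e_layers=2, dropout=0.1)

model = Model(configs)
x_enc = torch.randn(32, configs.seq_len, configs.enc_in)  # [B, L, D] input window
out = model(x_enc, None, None, None)                      # time marks are unused by TSMixer
print(out.shape)                                          # torch.Size([32, 24, 7])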
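PAttn's `patch_num = (seq_len - patch_size) // stride + 2` can look off by one at first glance; the extra `+ 1` comes from `ReplicationPad1d((0, stride))`, which extends the series by one stride before `unfold`. A small standalone check (shapes are assumed for illustration, this is not repo code):

import torch
import torch.nn as nn

seq_len, patch_len, stride = 96, 16, 8
x = torch.randn(2, 7, seq_len)                    # [B, C, L]
x = nn.ReplicationPad1d((0, stride))(x)           # pad the end by one stride -> L + 8
patches = x.unfold(dimension=-1, size=patch_len, step=stride)
print(patches.shape)                              # torch.Size([2, 7, 12, 16])
print((seq_len - patch_len) // stride + 2)        # 12, matching PAttn.patch_num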
/TimesNet/utils/ADFtest.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os
4 | from statsmodels.tsa.stattools import adfuller
5 | from arch.unitroot import ADF
6 | 
7 | def calculate_ADF(root_path,data_path):
8 |     df_raw = pd.read_csv(os.path.join(root_path,data_path))
9 |     cols = list(df_raw.columns)
10 |     cols.remove('date')
11 |     df_raw = df_raw[cols]
12 |     adf_list = []
13 |     for i in cols:
14 |         df_data = df_raw[i]
15 |         adf = adfuller(df_data, maxlag = 1)
16 |         print(adf)
17 |         adf_list.append(adf)
18 |     return np.array(adf_list)
19 | 
20 | def calculate_target_ADF(root_path,data_path,target='OT'):
21 |     df_raw = pd.read_csv(os.path.join(root_path,data_path))
22 |     target_cols = target.split(',')
23 |     # df_data = df_raw[target]
24 |     df_raw = df_raw[target_cols]
25 |     adf_list = []
26 |     for i in target_cols:
27 |         df_data = df_raw[i]
28 |         adf = adfuller(df_data, maxlag = 1)
29 |         # print(adf)
30 |         adf_list.append(adf)
31 |     return np.array(adf_list)
32 | 
33 | def archADF(root_path, data_path):
34 |     df = pd.read_csv(os.path.join(root_path,data_path))
35 |     cols = df.columns[1:]
36 |     stats = 0
37 |     for target_col in cols:
38 |         series = df[target_col].values
39 |         adf = ADF(series)
40 |         stat = adf.stat
41 |         stats += stat
42 |     return stats/len(cols)
43 | 
44 | if __name__ == '__main__':
45 | 
46 |     # * Exchange - result: -1.902402344564288 | report: -1.889
47 |     ADFmetric = archADF(root_path="./dataset/exchange_rate/",data_path="exchange_rate.csv")
48 |     print("Exchange ADF metric", ADFmetric)
49 | 
50 |     # * Illness - result: -5.33416661870624 | report: -5.406
51 |     ADFmetric = archADF(root_path="./dataset/illness/",data_path="national_illness.csv")
52 |     print("Illness ADF metric", ADFmetric)
53 | 
54 |     # * ETTm2 - result: -5.663628743471695 | report: -6.225
55 |     ADFmetric = archADF(root_path="./dataset/ETT-small/",data_path="ETTm2.csv")
56 |     print("ETTm2 ADF metric", ADFmetric)
57 | 
58 |     # * Electricity - result: -8.44485821939281 | report: -8.483
59 |     ADFmetric = archADF(root_path="./dataset/electricity/",data_path="electricity.csv")
60 |     print("Electricity ADF metric", ADFmetric)
61 | 
62 |     # * Traffic - result: -15.020978067839014 | report: -15.046
63 |     ADFmetric = archADF(root_path="./dataset/traffic/",data_path="traffic.csv")
64 |     print("Traffic ADF metric", ADFmetric)
65 | 
66 |     # * Weather - result: -26.681433085204866 | report: -26.661
67 |     ADFmetric = archADF(root_path="./dataset/weather/",data_path="weather.csv")
68 |     print("Weather ADF metric", ADFmetric)
69 | 
70 | 
71 |     # print(ADFmetric)
72 | 
73 |     # mean_ADFmetric = ADFmetric[:,0].mean()
74 |     # print(mean_ADFmetric)
-------------------------------------------------------------------------------- /TimesNet/utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/universeplayer/MTSIR3-GAN/a5b91b3c8f36d2f6a3a3d6332b4c161366cfe3b1/TimesNet/utils/__init__.py
--------------------------------------------------------------------------------
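`archADF` above averages the Augmented Dickey-Fuller test statistic over every column after the first (assumed to be the date column); the more negative the average, the more stationary the dataset, which is how the `result`/`report` numbers in `__main__` were obtained. A hedged usage sketch (the CSV path mirrors the `__main__` block and is an assumption; any date-indexed numeric CSV works):

from utils.ADFtest import archADF  # assumes TimesNet/ is the working directory

avg_stat = archADF(root_path="./dataset/weather/", data_path="weather.csv")
print(f"Weather: average ADF statistic across channels = {avg_stat:.3f}")  # approx. -26.68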
/TimesNet/utils/losses.py: --------------------------------------------------------------------------------
1 | # This source code is provided for the purposes of scientific reproducibility
2 | # under the following limited license from Element AI Inc. The code is an
3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
4 | # expansion analysis for interpretable time series forecasting,
5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is
6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0
7 | # International license (CC BY-NC 4.0):
8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
9 | # for the benefit of third parties or internally in production) requires an
10 | # explicit license. The subject-matter of the N-BEATS model and associated
11 | # materials are the property of Element AI Inc. and may be subject to patent
12 | # protection. No license to patents is granted hereunder (whether express or
13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved.
14 | 
15 | """
16 | Loss functions for PyTorch.
17 | """
18 | 
19 | import torch as t
20 | import torch.nn as nn
21 | import numpy as np
22 | import pdb
23 | 
24 | 
25 | def divide_no_nan(a, b):
26 |     """
27 |     a/b where any resulting NaN or Inf is replaced by 0.
28 |     """
29 |     result = a / b
30 |     result[result != result] = .0
31 |     result[result == np.inf] = .0
32 |     return result
33 | 
34 | 
35 | class mape_loss(nn.Module):
36 |     def __init__(self):
37 |         super(mape_loss, self).__init__()
38 | 
39 |     def forward(self, insample: t.Tensor, freq: int,
40 |                 forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
41 |         """
42 |         MAPE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error
43 | 
44 |         :param forecast: Forecast values. Shape: batch, time
45 |         :param target: Target values. Shape: batch, time
46 |         :param mask: 0/1 mask. Shape: batch, time
47 |         :return: Loss value
48 |         """
49 |         weights = divide_no_nan(mask, target)
50 |         return t.mean(t.abs((forecast - target) * weights))
51 | 
52 | 
53 | class smape_loss(nn.Module):
54 |     def __init__(self):
55 |         super(smape_loss, self).__init__()
56 | 
57 |     def forward(self, insample: t.Tensor, freq: int,
58 |                 forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
59 |         """
60 |         sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Makridakis 1993)
61 | 
62 |         :param forecast: Forecast values. Shape: batch, time
63 |         :param target: Target values. Shape: batch, time
64 |         :param mask: 0/1 mask. 
Shape: batch, time 65 | :return: Loss value 66 | """ 67 | return 200 * t.mean(divide_no_nan(t.abs(forecast - target), 68 | t.abs(forecast.data) + t.abs(target.data)) * mask) 69 | 70 | 71 | class mase_loss(nn.Module): 72 | def __init__(self): 73 | super(mase_loss, self).__init__() 74 | 75 | def forward(self, insample: t.Tensor, freq: int, 76 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 77 | """ 78 | MASE loss as defined in "Scaled Errors" https://robjhyndman.com/papers/mase.pdf 79 | 80 | :param insample: Insample values. Shape: batch, time_i 81 | :param freq: Frequency value 82 | :param forecast: Forecast values. Shape: batch, time_o 83 | :param target: Target values. Shape: batch, time_o 84 | :param mask: 0/1 mask. Shape: batch, time_o 85 | :return: Loss value 86 | """ 87 | masep = t.mean(t.abs(insample[:, freq:] - insample[:, :-freq]), dim=1) 88 | masked_masep_inv = divide_no_nan(mask, masep[:, None]) 89 | return t.mean(t.abs(target - forecast) * masked_masep_inv) 90 | -------------------------------------------------------------------------------- /TimesNet/utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TriangularCausalMask(): 5 | def __init__(self, B, L, device="cpu"): 6 | mask_shape = [B, 1, L, L] 7 | with torch.no_grad(): 8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 9 | 10 | @property 11 | def mask(self): 12 | return self._mask 13 | 14 | 15 | class ProbMask(): 16 | def __init__(self, B, H, L, index, scores, device="cpu"): 17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 19 | indicator = _mask_ex[torch.arange(B)[:, None, None], 20 | torch.arange(H)[None, :, None], 21 | index, :].to(device) 22 | self._mask = indicator.view(scores.shape).to(device) 23 | 24 | @property 25 | def mask(self): 26 | return self._mask 27 | -------------------------------------------------------------------------------- /TimesNet/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def RSE(pred, true): 5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 6 | 7 | 8 | def CORR(pred, true): 9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 11 | return (u / d).mean(-1) 12 | 13 | 14 | def MAE(pred, true): 15 | return np.mean(np.abs(true - pred)) 16 | 17 | 18 | def MSE(pred, true): 19 | return np.mean((true - pred) ** 2) 20 | 21 | 22 | def RMSE(pred, true): 23 | return np.sqrt(MSE(pred, true)) 24 | 25 | 26 | def MAPE(pred, true): 27 | return np.mean(np.abs((true - pred) / true)) 28 | 29 | 30 | def MSPE(pred, true): 31 | return np.mean(np.square((true - pred) / true)) 32 | 33 | 34 | def metric(pred, true): 35 | mae = MAE(pred, true) 36 | mse = MSE(pred, true) 37 | rmse = RMSE(pred, true) 38 | mape = MAPE(pred, true) 39 | mspe = MSPE(pred, true) 40 | 41 | return mae, mse, rmse, mape, mspe 42 | -------------------------------------------------------------------------------- /TimesNet/utils/print_args.py: -------------------------------------------------------------------------------- 1 | def print_args(args): 2 | print("\033[1m" + "Basic Config" + "\033[0m") 3 | print(f' {"Task Name:":<20}{args.task_name:<20}{"Is Training:":<20}{args.is_training:<20}') 
4 | print(f' {"Model ID:":<20}{args.model_id:<20}{"Model:":<20}{args.model:<20}') 5 | print() 6 | 7 | print("\033[1m" + "Data Loader" + "\033[0m") 8 | print(f' {"Data:":<20}{args.data:<20}{"Root Path:":<20}{args.root_path:<20}') 9 | print(f' {"Data Path:":<20}{args.data_path:<20}{"Features:":<20}{args.features:<20}') 10 | print(f' {"Target:":<20}{args.target:<20}{"Freq:":<20}{args.freq:<20}') 11 | print(f' {"Checkpoints:":<20}{args.checkpoints:<20}') 12 | print() 13 | 14 | if args.task_name in ['long_term_forecast', 'short_term_forecast']: 15 | print("\033[1m" + "Forecasting Task" + "\033[0m") 16 | print(f' {"Seq Len:":<20}{args.seq_len:<20}{"Label Len:":<20}{args.label_len:<20}') 17 | print(f' {"Pred Len:":<20}{args.pred_len:<20}{"Seasonal Patterns:":<20}{args.seasonal_patterns:<20}') 18 | print(f' {"Inverse:":<20}{args.inverse:<20}') 19 | print() 20 | 21 | if args.task_name == 'imputation': 22 | print("\033[1m" + "Imputation Task" + "\033[0m") 23 | print(f' {"Mask Rate:":<20}{args.mask_rate:<20}') 24 | print() 25 | 26 | if args.task_name == 'anomaly_detection': 27 | print("\033[1m" + "Anomaly Detection Task" + "\033[0m") 28 | print(f' {"Anomaly Ratio:":<20}{args.anomaly_ratio:<20}') 29 | print() 30 | 31 | print("\033[1m" + "Model Parameters" + "\033[0m") 32 | print(f' {"Top k:":<20}{args.top_k:<20}{"Num Kernels:":<20}{args.num_kernels:<20}') 33 | print(f' {"Enc In:":<20}{args.enc_in:<20}{"Dec In:":<20}{args.dec_in:<20}') 34 | print(f' {"C Out:":<20}{args.c_out:<20}{"d model:":<20}{args.d_model:<20}') 35 | print(f' {"n heads:":<20}{args.n_heads:<20}{"e layers:":<20}{args.e_layers:<20}') 36 | print(f' {"d layers:":<20}{args.d_layers:<20}{"d FF:":<20}{args.d_ff:<20}') 37 | print(f' {"Moving Avg:":<20}{args.moving_avg:<20}{"Factor:":<20}{args.factor:<20}') 38 | print(f' {"Distil:":<20}{args.distil:<20}{"Dropout:":<20}{args.dropout:<20}') 39 | print(f' {"Embed:":<20}{args.embed:<20}{"Activation:":<20}{args.activation:<20}') 40 | print() 41 | 42 | print("\033[1m" + "Run Parameters" + "\033[0m") 43 | print(f' {"Num Workers:":<20}{args.num_workers:<20}{"Itr:":<20}{args.itr:<20}') 44 | print(f' {"Train Epochs:":<20}{args.train_epochs:<20}{"Batch Size:":<20}{args.batch_size:<20}') 45 | print(f' {"Patience:":<20}{args.patience:<20}{"Learning Rate:":<20}{args.learning_rate:<20}') 46 | print(f' {"Des:":<20}{args.des:<20}{"Loss:":<20}{args.loss:<20}') 47 | print(f' {"Lradj:":<20}{args.lradj:<20}{"Use Amp:":<20}{args.use_amp:<20}') 48 | print() 49 | 50 | print("\033[1m" + "GPU" + "\033[0m") 51 | print(f' {"Use GPU:":<20}{args.use_gpu:<20}{"GPU:":<20}{args.gpu:<20}') 52 | print(f' {"Use Multi GPU:":<20}{args.use_multi_gpu:<20}{"Devices:":<20}{args.devices:<20}') 53 | print() 54 | 55 | print("\033[1m" + "De-stationary Projector Params" + "\033[0m") 56 | p_hidden_dims_str = ', '.join(map(str, args.p_hidden_dims)) 57 | print(f' {"P Hidden Dims:":<20}{p_hidden_dims_str:<20}{"P Hidden Layers:":<20}{args.p_hidden_layers:<20}') 58 | print() 59 | -------------------------------------------------------------------------------- /TimesNet/utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | # From: gluonts/src/gluonts/time_feature/_base.py 2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"). 5 | # You may not use this file except in compliance with the License. 
6 | # A copy of the License is located at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # or in the "license" file accompanying this file. This file is distributed
11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12 | # express or implied. See the License for the specific language governing
13 | # permissions and limitations under the License.
14 | 
15 | from typing import List
16 | 
17 | import numpy as np
18 | import pandas as pd
19 | from pandas.tseries import offsets
20 | from pandas.tseries.frequencies import to_offset
21 | 
22 | 
23 | class TimeFeature:
24 |     def __init__(self):
25 |         pass
26 | 
27 |     def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
28 |         pass
29 | 
30 |     def __repr__(self):
31 |         return self.__class__.__name__ + "()"
32 | 
33 | 
34 | class SecondOfMinute(TimeFeature):
35 |     """Second of minute encoded as value between [-0.5, 0.5]"""
36 | 
37 |     def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
38 |         return index.second / 59.0 - 0.5
39 | 
40 | 
41 | class MinuteOfHour(TimeFeature):
42 |     """Minute of hour encoded as value between [-0.5, 0.5]"""
43 | 
44 |     def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
45 |         return index.minute / 59.0 - 0.5
46 | 
47 | 
48 | class HourOfDay(TimeFeature):
49 |     """Hour of day encoded as value between [-0.5, 0.5]"""
50 | 
51 |     def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
52 |         return index.hour / 23.0 - 0.5
53 | 
54 | 
55 | class DayOfWeek(TimeFeature):
56 |     """Day of week encoded as value between [-0.5, 0.5]"""
57 | 
58 |     def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
59 |         return index.dayofweek / 6.0 - 0.5
60 | 
61 | 
62 | class DayOfMonth(TimeFeature):
63 |     """Day of month encoded as value between [-0.5, 0.5]"""
64 | 
65 |     def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
66 |         return (index.day - 1) / 30.0 - 0.5
67 | 
68 | 
69 | class DayOfYear(TimeFeature):
70 |     """Day of year encoded as value between [-0.5, 0.5]"""
71 | 
72 |     def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
73 |         return (index.dayofyear - 1) / 365.0 - 0.5
74 | 
75 | 
76 | class MonthOfYear(TimeFeature):
77 |     """Month of year encoded as value between [-0.5, 0.5]"""
78 | 
79 |     def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
80 |         return (index.month - 1) / 11.0 - 0.5
81 | 
82 | 
83 | class WeekOfYear(TimeFeature):
84 |     """Week of year encoded as value between [-0.5, 0.5]"""
85 | 
86 |     def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
87 |         return (index.isocalendar().week - 1) / 52.0 - 0.5
88 | 
89 | 
90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
91 |     """
92 |     Returns a list of time features that will be appropriate for the given frequency string.
93 |     Parameters
94 |     ----------
95 |     freq_str
96 |         Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 
97 | """ 98 | 99 | features_by_offsets = { 100 | offsets.YearEnd: [], 101 | offsets.QuarterEnd: [MonthOfYear], 102 | offsets.MonthEnd: [MonthOfYear], 103 | offsets.Week: [DayOfMonth, WeekOfYear], 104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 107 | offsets.Minute: [ 108 | MinuteOfHour, 109 | HourOfDay, 110 | DayOfWeek, 111 | DayOfMonth, 112 | DayOfYear, 113 | ], 114 | offsets.Second: [ 115 | SecondOfMinute, 116 | MinuteOfHour, 117 | HourOfDay, 118 | DayOfWeek, 119 | DayOfMonth, 120 | DayOfYear, 121 | ], 122 | } 123 | 124 | offset = to_offset(freq_str) 125 | 126 | for offset_type, feature_classes in features_by_offsets.items(): 127 | if isinstance(offset, offset_type): 128 | return [cls() for cls in feature_classes] 129 | 130 | supported_freq_msg = f""" 131 | Unsupported frequency {freq_str} 132 | The following frequencies are supported: 133 | Y - yearly 134 | alias: A 135 | M - monthly 136 | W - weekly 137 | D - daily 138 | B - business days 139 | H - hourly 140 | T - minutely 141 | alias: min 142 | S - secondly 143 | """ 144 | raise RuntimeError(supported_freq_msg) 145 | 146 | 147 | def time_features(dates, freq='h'): 148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 149 | -------------------------------------------------------------------------------- /TimesNet/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import math 8 | 9 | plt.switch_backend('agg') 10 | 11 | 12 | def adjust_learning_rate(optimizer, epoch, args): 13 | # lr = args.learning_rate * (0.2 ** (epoch // 2)) 14 | if args.lradj == 'type1': 15 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))} 16 | elif args.lradj == 'type2': 17 | lr_adjust = { 18 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 19 | 10: 5e-7, 15: 1e-7, 20: 5e-8 20 | } 21 | elif args.lradj == 'type3': 22 | lr_adjust = {epoch: args.learning_rate if epoch < 3 else args.learning_rate * (0.9 ** ((epoch - 3) // 1))} 23 | elif args.lradj == "cosine": 24 | lr_adjust = {epoch: args.learning_rate /2 * (1 + math.cos(epoch / args.train_epochs * math.pi))} 25 | if epoch in lr_adjust.keys(): 26 | lr = lr_adjust[epoch] 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = lr 29 | print('Updating learning rate to {}'.format(lr)) 30 | 31 | 32 | class EarlyStopping: 33 | def __init__(self, patience=7, verbose=False, delta=0): 34 | self.patience = patience 35 | self.verbose = verbose 36 | self.counter = 0 37 | self.best_score = None 38 | self.early_stop = False 39 | self.val_loss_min = np.inf 40 | self.delta = delta 41 | 42 | def __call__(self, val_loss, model, path): 43 | score = -val_loss 44 | if self.best_score is None: 45 | self.best_score = score 46 | self.save_checkpoint(val_loss, model, path) 47 | elif score < self.best_score + self.delta: 48 | self.counter += 1 49 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 50 | if self.counter >= self.patience: 51 | self.early_stop = True 52 | else: 53 | self.best_score = score 54 | self.save_checkpoint(val_loss, model, path) 55 | self.counter = 0 56 | 57 | def save_checkpoint(self, val_loss, model, path): 58 | if self.verbose: 59 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 60 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 61 | self.val_loss_min = val_loss 62 | 63 | 64 | class dotdict(dict): 65 | """dot.notation access to dictionary attributes""" 66 | __getattr__ = dict.get 67 | __setattr__ = dict.__setitem__ 68 | __delattr__ = dict.__delitem__ 69 | 70 | 71 | class StandardScaler(): 72 | def __init__(self, mean, std): 73 | self.mean = mean 74 | self.std = std 75 | 76 | def transform(self, data): 77 | return (data - self.mean) / self.std 78 | 79 | def inverse_transform(self, data): 80 | return (data * self.std) + self.mean 81 | 82 | 83 | def visual(true, preds=None, name='./pic/test.pdf'): 84 | """ 85 | Results visualization 86 | """ 87 | plt.figure() 88 | if preds is not None: 89 | plt.plot(preds, label='Prediction', linewidth=2) 90 | plt.plot(true, label='GroundTruth', linewidth=2) 91 | plt.legend() 92 | plt.savefig(name, bbox_inches='tight') 93 | 94 | 95 | def adjustment(gt, pred): 96 | anomaly_state = False 97 | for i in range(len(gt)): 98 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 99 | anomaly_state = True 100 | for j in range(i, 0, -1): 101 | if gt[j] == 0: 102 | break 103 | else: 104 | if pred[j] == 0: 105 | pred[j] = 1 106 | for j in range(i, len(gt)): 107 | if gt[j] == 0: 108 | break 109 | else: 110 | if pred[j] == 0: 111 | pred[j] = 1 112 | elif gt[i] == 0: 113 | anomaly_state = False 114 | if anomaly_state: 115 | pred[i] = 1 116 | return gt, pred 117 | 118 | 119 | def cal_accuracy(y_pred, y_true): 120 | return np.mean(y_pred == y_true) 121 | -------------------------------------------------------------------------------- /datasets/PhysioNet/link.txt: -------------------------------------------------------------------------------- 1 | https://github.com/johnweichow/PhysioNet-2012-Challenge/tree/master --------------------------------------------------------------------------------
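The `adjustment` function in `utils/tools.py` above implements the point-adjust protocol that is standard in time-series anomaly-detection evaluation: once any point inside a ground-truth anomaly segment is predicted, the whole segment is marked as detected. A small worked example (hypothetical labels, not repo data; assumes TimesNet/ is the working directory):

from utils.tools import adjustment

gt   = [0, 1, 1, 1, 0, 0, 1, 1, 0]   # two true anomaly segments
pred = [0, 0, 1, 0, 0, 0, 0, 0, 0]   # only one point of the first segment detected
gt, pred = adjustment(gt, pred)
print(pred)  # [0, 1, 1, 1, 0, 0, 0, 0, 0] -- first segment fully credited, second still missed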