├── .github
│   └── workflows
│       └── release.yml
├── .gitignore
├── .python-version
├── LICENSE
├── README.md
├── Sorted_Dataset_Info.csv
├── TimeGPT_experiments.py
├── __pycache__
│   ├── dataset.cpython-311.pyc
│   ├── model.cpython-311.pyc
│   └── utils.cpython-311.pyc
├── comp.py
├── config
│   ├── chronos.json
│   ├── chronos_dataset.json
│   ├── chronosbolt.json
│   ├── lptm.json
│   ├── moirai.json
│   ├── moment_base.json
│   ├── moment_classification.json
│   ├── moment_detection.json
│   ├── moment_forecast.json
│   ├── moment_imputation.json
│   ├── timemoe.json
│   ├── timesfm.json
│   └── tinytimemixer.json
├── create_leaderboard.py
├── data
│   ├── data
│   │   ├── 198_UCR_Anomaly_tiltAPB2_50000_124159_124985.out
│   │   ├── ECG5000_TEST.csv
│   │   ├── ECG5000_TEST.ts
│   │   ├── ECG5000_TRAIN.csv
│   │   ├── ECG5000_TRAIN.ts
│   │   └── ETTh1.csv
│   └── dataset
│       └── timesfm_covid_pivot.csv
├── download_data.py
├── example
│   ├── chronos.ipynb
│   ├── chronosbolt.ipynb
│   ├── colab
│   │   ├── .gitkeep
│   │   ├── LPTM.ipynb
│   │   └── timesfm.ipynb
│   ├── lptm.ipynb
│   ├── lptm_zero.ipynb
│   ├── moirai.ipynb
│   ├── moment_anomaly_detection.ipynb
│   ├── moment_classification.ipynb
│   ├── moment_forecasting.ipynb
│   ├── moment_imputation.ipynb
│   ├── timemoe.ipynb
│   ├── timesfm.ipynb
│   └── tinytimemixer.ipynb
├── leaderboard.py
├── leaderboard
│   ├── chronos.csv
│   ├── chronosbolt.csv
│   ├── lptm.csv
│   ├── moirai.csv
│   ├── moment.csv
│   ├── monash_chronosbolt.csv
│   ├── monash_lptm.csv
│   ├── monash_moment.csv
│   ├── monash_timesfm.csv
│   ├── monash_ttm.csv
│   ├── timesfm.csv
│   └── ttm.csv
├── leaderboard_monash.py
├── plot_timeGPT.py
├── pyproject.toml
├── src
│   ├── __init__.py
│   ├── samay
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── metric.py
│   │   ├── model.py
│   │   ├── models
│   │   │   ├── Time_MoE
│   │   │   │   └── time_moe
│   │   │   │       ├── __init__.py
│   │   │   │       ├── datasets
│   │   │   │       │   ├── __init__.py
│   │   │   │       │   ├── benchmark_dataset.py
│   │   │   │       │   ├── binary_dataset.py
│   │   │   │       │   ├── general_dataset.py
│   │   │   │       │   ├── time_moe_dataset.py
│   │   │   │       │   ├── time_moe_window_dataset.py
│   │   │   │       │   └── ts_dataset.py
│   │   │   │       ├── models
│   │   │   │       │   ├── __init__.py
│   │   │   │       │   ├── configuration_time_moe.py
│   │   │   │       │   ├── modeling_time_moe.py
│   │   │   │       │   └── ts_generation_mixin.py
│   │   │   │       ├── runner.py
│   │   │   │       ├── trainer
│   │   │   │       │   ├── __init__.py
│   │   │   │       │   └── hf_trainer.py
│   │   │   │       └── utils
│   │   │   │           ├── __init__.py
│   │   │   │           ├── dist_util.py
│   │   │   │           └── log_util.py
│   │   │   ├── TinyTimeMixer
│   │   │   │   ├── __init__.py
│   │   │   │   ├── models
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── tinytimemixer
│   │   │   │   │       ├── README.md
│   │   │   │   │       ├── __init__.py
│   │   │   │   │       ├── configuration_tinytimemixer.py
│   │   │   │   │       ├── modeling_tinytimemixer.py
│   │   │   │   │       └── utils
│   │   │   │   │           ├── __init__.py
│   │   │   │   │           ├── ttm_args.py
│   │   │   │   │           └── ttm_image.webp
│   │   │   │   ├── resources
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── data_config
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── electricity.yaml
│   │   │   │   │   │   ├── etth1.yaml
│   │   │   │   │   │   ├── etth2.yaml
│   │   │   │   │   │   ├── ettm1.yaml
│   │   │   │   │   │   ├── ettm2.yaml
│   │   │   │   │   │   ├── exchange.yaml
│   │   │   │   │   │   ├── solar.yaml
│   │   │   │   │   │   ├── traffic.yaml
│   │   │   │   │   │   ├── weather.yaml
│   │   │   │   │   │   └── zafnoo.yaml
│   │   │   │   │   └── model_paths_config
│   │   │   │   │       ├── __init__.py
│   │   │   │   │       └── ttm.yaml
│   │   │   │   ├── toolkit
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── callbacks.py
│   │   │   │   │   ├── data_handling.py
│   │   │   │   │   ├── dataset.py
│   │   │   │   │   ├── get_model.py
│   │   │   │   │   ├── lr_finder.py
│   │   │   │   │   ├── recursive_predictor.py
│   │   │   │   │   ├── time_series_forecasting_pipeline.py
│   │   │   │   │   ├── time_series_preprocessor.py
│   │   │   │   │   ├── util.py
│   │   │   │   │   └── visualization.py
│   │   │   │   └── version.py
│   │   │   ├── __init__.py
│   │   │   ├── chronosforecasting
│   │   │   │   ├── chronos
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── base.py
│   │   │   │   │   ├── chronos.py
│   │   │   │   │   ├── chronos_bolt.py
│   │   │   │   │   └── utils.py
│   │   │   │   └── scripts
│   │   │   │       ├── __init__.py
│   │   │   │       ├── evaluate.py
│   │   │   │       ├── finetune.py
│   │   │   │       └── jsonlogger.py
│   │   │   ├── lptm
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model
│   │   │   │   │   ├── backbone.py
│   │   │   │   │   ├── layers.py
│   │   │   │   │   └── masktrain.py
│   │   │   │   ├── segment
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── scoring.py
│   │   │   │   │   └── selection.py
│   │   │   │   └── utils.py
│   │   │   ├── moment
│   │   │   │   ├── LICENSE
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── assets
│   │   │   │   │   ├── MOMENT Logo.png
│   │   │   │   │   ├── autonlab_logo.png
│   │   │   │   │   ├── cmu_logo.png
│   │   │   │   │   ├── moment_architecture.png
│   │   │   │   │   └── moment_comparison .png
│   │   │   │   ├── data
│   │   │   │   │   ├── 198_UCR_Anomaly_tiltAPB2_50000_124159_124985.out
│   │   │   │   │   ├── ECG5000_TEST.ts
│   │   │   │   │   ├── ECG5000_TRAIN.ts
│   │   │   │   │   └── ETTh1.csv
│   │   │   │   ├── momentfm
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── common.py
│   │   │   │   │   ├── data
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── anomaly_detection_dataset.py
│   │   │   │   │   │   ├── classification_dataset.py
│   │   │   │   │   │   ├── informer_dataset.py
│   │   │   │   │   │   ├── ptbxl_classification_dataset.py
│   │   │   │   │   │   └── synthetic_data.py
│   │   │   │   │   ├── dataclass
│   │   │   │   │   │   └── base.py
│   │   │   │   │   ├── models
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── layers
│   │   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   │   ├── embed.py
│   │   │   │   │   │   │   └── revin.py
│   │   │   │   │   │   ├── moment.py
│   │   │   │   │   │   └── statistical_classifiers.py
│   │   │   │   │   └── utils
│   │   │   │   │       ├── __init__.py
│   │   │   │   │       ├── anomaly_detection_metrics.py
│   │   │   │   │       ├── data.py
│   │   │   │   │       ├── forecasting_metrics.py
│   │   │   │   │       ├── masking.py
│   │   │   │   │       └── utils.py
│   │   │   │   ├── pyproject.toml
│   │   │   │   ├── requirements.txt
│   │   │   │   ├── setup.py
│   │   │   │   └── tutorials
│   │   │   │       ├── anomaly_detection.ipynb
│   │   │   │       ├── classification.ipynb
│   │   │   │       ├── finetune_demo
│   │   │   │       │   ├── classification.py
│   │   │   │       │   ├── classification.sh
│   │   │   │       │   └── ds.yaml
│   │   │   │       ├── forecasting.ipynb
│   │   │   │       ├── imputation.ipynb
│   │   │   │       ├── ptbxl_classification.ipynb
│   │   │   │       └── representation_learning.ipynb
│   │   │   └── timesfm
│   │   │       ├── __init__.py
│   │   │       ├── adapter
│   │   │       │   ├── __init__.py
│   │   │       │   ├── dora_layers.py
│   │   │       │   ├── lora_layers.py
│   │   │       │   └── utils.py
│   │   │       ├── setup.py
│   │   │       └── timesfm
│   │   │           ├── __init__.py
│   │   │           ├── data_loader.py
│   │   │           ├── pytorch_patched_decoder.py
│   │   │           ├── time_features.py
│   │   │           ├── timesfm_base.py
│   │   │           ├── timesfm_torch.py
│   │   │           └── xreg_lib.py
│   │   ├── moirai_utils.py
│   │   ├── py.typed
│   │   ├── trial.py
│   │   ├── utils.py
│   │   └── visualization.py
│   └── uni2ts
│       ├── __about__.py
│       ├── __init__.py
│       ├── callbacks
│       │   ├── HuggingFaceCheckpoint.py
│       │   └── __init__.py
│       ├── cli
│       │   ├── __init__.py
│       │   ├── conf
│       │   │   ├── eval
│       │   │   │   ├── data
│       │   │   │   │   ├── etth1_test.yaml
│       │   │   │   │   ├── etth1_val.yaml
│       │   │   │   │   ├── gluonts_test.yaml
│       │   │   │   │   ├── gluonts_val.yaml
│       │   │   │   │   ├── lsf_test.yaml
│       │   │   │   │   ├── lsf_val.yaml
│       │   │   │   │   └── monash.yaml
│       │   │   │   ├── default.yaml
│       │   │   │   └── model
│       │   │   │       ├── moirai_1.0_R_base.yaml
│       │   │   │       ├── moirai_1.0_R_large.yaml
│       │   │   │       ├── moirai_1.0_R_small.yaml
│       │   │   │       ├── moirai_1.1_R_base.yaml
│       │   │   │       ├── moirai_1.1_R_large.yaml
│       │   │   │       ├── moirai_1.1_R_small.yaml
│       │   │   │       ├── moirai_lightning_ckpt.yaml
│       │   │   │       ├── moirai_moe_1.0_R_base.yaml
│       │   │   │       └── moirai_moe_1.0_R_small.yaml
│       │   │   ├── finetune
│       │   │   │   ├── data
│       │   │   │   │   └── etth1.yaml
│       │   │   │   ├── default.yaml
│       │   │   │   ├── model
│       │   │   │   │   ├── moirai_1.0_R_base.yaml
│       │   │   │   │   ├── moirai_1.0_R_large.yaml
│       │   │   │   │   ├── moirai_1.0_R_small.yaml
│       │   │   │   │   ├── moirai_1.1_R_base.yaml
│       │   │   │   │   ├── moirai_1.1_R_large.yaml
│       │   │   │   │   ├── moirai_1.1_R_small.yaml
│       │   │   │   │   ├── moirai_base.yaml
│       │   │   │   │   ├── moirai_large.yaml
│       │   │   │   │   ├── moirai_moe_1.0_R_small.yaml
│       │   │   │   │   └── moirai_small.yaml
│       │   │   │   └── val_data
│       │   │   │       ├── etth1.yaml
│       │   │   │       └── etth1_multi.yaml
│       │   │   └── pretrain
│       │   │       ├── data
│       │   │       │   ├── buildings_900k.yaml
│       │   │       │   ├── buildings_bench.yaml
│       │   │       │   ├── cloudops_tsf.yaml
│       │   │       │   ├── cmip6.yaml
│       │   │       │   ├── era5.yaml
│       │   │       │   ├── gluonts.yaml
│       │   │       │   ├── largest.yaml
│       │   │       │   ├── lib_city.yaml
│       │   │       │   ├── lotsa_v1_unweighted.yaml
│       │   │       │   ├── lotsa_v1_weighted.yaml
│       │   │       │   ├── others.yaml
│       │   │       │   ├── proenfo.yaml
│       │   │       │   └── subseasonal.yaml
│       │   │       ├── default.yaml
│       │   │       └── model
│       │   │           ├── moirai_base.yaml
│       │   │           ├── moirai_large.yaml
│       │   │           └── moirai_small.yaml
│       │   ├── eval.py
│       │   └── train.py
│       ├── common
│       │   ├── __init__.py
│       │   ├── core.py
│       │   ├── env.py
│       │   ├── hydra_util.py
│       │   ├── sampler.py
│       │   ├── torch_util.py
│       │   └── typing.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── builder
│       │   │   ├── __init__.py
│       │   │   ├── _base.py
│       │   │   ├── lotsa_v1
│       │   │   │   ├── README.md
│       │   │   │   ├── __init__.py
│       │   │   │   ├── __main__.py
│       │   │   │   ├── _base.py
│       │   │   │   ├── buildings_bench.py
│       │   │   │   ├── cloudops_tsf.py
│       │   │   │   ├── cmip6.py
│       │   │   │   ├── era5.py
│       │   │   │   ├── gluonts.py
│       │   │   │   ├── largest.py
│       │   │   │   ├── lib_city.py
│       │   │   │   ├── others.py
│       │   │   │   ├── proenfo.py
│       │   │   │   └── subseasonal.py
│       │   │   └── simple.py
│       │   ├── dataset.py
│       │   ├── indexer
│       │   │   ├── __init__.py
│       │   │   ├── _base.py
│       │   │   └── hf_dataset_indexer.py
│       │   └── loader.py
│       ├── distribution
│       │   ├── __init__.py
│       │   ├── _base.py
│       │   ├── laplace.py
│       │   ├── log_normal.py
│       │   ├── mixture.py
│       │   ├── negative_binomial.py
│       │   ├── normal.py
│       │   ├── pareto.py
│       │   └── student_t.py
│       ├── eval_util
│       │   ├── __init__.py
│       │   ├── _hf_dataset.py
│       │   ├── _lsf_dataset.py
│       │   ├── _pf_dataset.py
│       │   ├── data.py
│       │   ├── evaluation.py
│       │   ├── metrics.py
│       │   └── plot.py
│       ├── loss
│       │   ├── __init__.py
│       │   └── packed
│       │       ├── __init__.py
│       │       ├── _base.py
│       │       ├── distribution.py
│       │       ├── normalized.py
│       │       ├── percentage_error.py
│       │       └── point.py
│       ├── model
│       │   ├── moirai
│       │   │   ├── __init__.py
│       │   │   ├── finetune.py
│       │   │   ├── forecast.py
│       │   │   ├── module.py
│       │   │   └── pretrain.py
│       │   └── moirai_moe
│       │       ├── __init__.py
│       │       ├── forecast.py
│       │       └── module.py
│       ├── module
│       │   ├── __init__.py
│       │   ├── attention.py
│       │   ├── ffn.py
│       │   ├── norm.py
│       │   ├── packed_scaler.py
│       │   ├── position
│       │   │   ├── __init__.py
│       │   │   ├── additive.py
│       │   │   ├── attn_bias.py
│       │   │   └── attn_projection.py
│       │   ├── transformer.py
│       │   └── ts_embed.py
│       ├── optim
│       │   ├── __init__.py
│       │   └── lr_scheduler.py
│       └── transform
│           ├── __init__.py
│           ├── _base.py
│           ├── _mixin.py
│           ├── crop.py
│           ├── feature.py
│           ├── field.py
│           ├── imputation.py
│           ├── pad.py
│           ├── patch.py
│           ├── resample.py
│           ├── reshape.py
│           └── task.py
├── transform_ILI.py
├── transform_gifteval.py
├── transform_monash.py
└── uv.lock
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Build and Release Wheel
2 |
3 | on:
4 |   push:
5 |     tags:
6 |       - "v*.*.*"
7 |
8 | jobs:
9 |   build:
10 |     runs-on: ubuntu-latest
11 |
12 |     steps:
13 |       - name: Checkout code
14 |         uses: actions/checkout@v4
15 |
16 |       - name: Install uv
17 |         uses: astral-sh/setup-uv@v5
18 |         with:
19 |           # Install a specific version of uv.
20 |           version: "0.6.1"
21 |           enable-cache: true
22 |
23 |       - name: "Set up Python"
24 |         uses: actions/setup-python@v5
25 |         with:
26 |           python-version-file: "pyproject.toml"
27 |
28 |       - name: Install dependencies
29 |         run: uv sync --all-extras --dev
30 |
31 |       - name: Build wheel
32 |         run: uv build --wheel
33 |
34 |       - name: Upload wheel to release
35 |         uses: ncipollo/release-action@v1
36 |         with:
37 |           artifacts: dist/*.whl
38 |           token: ${{ secrets.GITHUB_TOKEN }}
39 |           tag: ${{ github.ref }}
40 |           name: Release ${{ github.ref }}
41 |           body: |
42 |             This is an automated release of the wheel.
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | *.pyc
3 | *.pyo
4 | *.pyd
5 | __pycache__
6 | *.so
7 | *.egg
8 | *.egg-info
9 | dist
10 | build
11 | .eggs
12 | *.log
13 | *.pot
14 | *.mo
15 | *.swp
16 | *.swo
17 | *.DS_Store
18 | *.coverage
19 | .coverage.*
20 | .cache
21 | .pytest_cache
22 | .tox
23 | htmlcov
24 | .ipynb_checkpoints
25 | .mypy_cache
26 | .pyre/
27 | .pytype/
28 | *.png
29 | .vscode/
30 |
31 |
32 | /data
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
2 |
--------------------------------------------------------------------------------
/Sorted_Dataset_Info.csv:
--------------------------------------------------------------------------------
1 | ,Dataset,Domain,Frequency
2 | 0,m4_yearly,Econ/Fin,year
3 | 5,m4_quarterly,Econ/Fin,quarter
4 | 15,m4_monthly,Econ/Fin,month
5 | 14,m4_daily,Econ/Fin,day
6 | 13,m4_weekly,Econ/Fin,week
7 | 47,m4_hourly,Econ/Fin,hour
8 | 71,solar,Energy,10T
9 | 72,solar,Energy,hour
10 | 27,solar,Energy,day
11 | 80,electricity,Energy,15T
12 | 92,solar,Energy,10T
13 | 81,electricity,Energy,hour
14 | 83,ett1,Energy,hour
15 | 50,solar,Energy,hour
16 | 49,solar,Energy,10T
17 | 36,electricity,Energy,15T
18 | 37,electricity,Energy,hour
19 | 38,ett1,Energy,15T
20 | 39,ett1,Energy,hour
21 | 40,ett2,Energy,15T
22 | 41,ett2,Energy,hour
23 | 84,ett2,Energy,15T
24 | 82,ett1,Energy,15T
25 | 93,solar,Energy,hour
26 | 60,electricity,Energy,hour
27 | 61,ett1,Energy,15T
28 | 1,electricity,Energy,week
29 | 2,ett1,Energy,week
30 | 3,ett2,Energy,week
31 | 7,solar,Energy,week
32 | 64,ett2,Energy,hour
33 | 59,electricity,Energy,15T
34 | 63,ett2,Energy,15T
35 | 85,ett2,Energy,hour
36 | 17,electricity,Energy,day
37 | 18,ett1,Energy,day
38 | 19,ett2,Energy,day
39 | 62,ett1,Energy,hour
40 | 16,covid_deaths,Healthcare,day
41 | 8,us_births,Healthcare,week
42 | 29,us_births,Healthcare,day
43 | 12,us_births,Healthcare,month
44 | 10,hospital,Healthcare,month
45 | 65,jena_weather,Nature,10T
46 | 66,jena_weather,Nature,hour
47 | 88,kdd_cup_2018,Nature,hour
48 | 44,kdd_cup_2018,Nature,hour
49 | 43,jena_weather,Nature,hour
50 | 42,jena_weather,Nature,10T
51 | 87,jena_weather,Nature,hour
52 | 67,kdd_cup_2018,Nature,hour
53 | 6,saugeen,Nature,week
54 | 28,temperature_rain,Nature,day
55 | 11,saugeen,Nature,month
56 | 26,saugeen,Nature,day
57 | 86,jena_weather,Nature,10T
58 | 21,jena_weather,Nature,day
59 | 22,kdd_cup_2018,Nature,day
60 | 4,hierarchical_sales,Sales,week
61 | 9,car_parts,Sales,month
62 | 20,hierarchical_sales,Sales,day
63 | 25,restaurant,Sales,day
64 | 70,m_dense,Transport,hour
65 | 73,sz_taxi,Transport,15T
66 | 91,m_dense,Transport,hour
67 | 90,loop_seattle,Transport,hour
68 | 89,loop_seattle,Transport,5T
69 | 94,sz_taxi,Transport,15T
70 | 69,loop_seattle,Transport,hour
71 | 48,m_dense,Transport,hour
72 | 51,sz_taxi,Transport,15T
73 | 23,loop_seattle,Transport,day
74 | 24,m_dense,Transport,day
75 | 45,loop_seattle,Transport,5T
76 | 46,loop_seattle,Transport,hour
77 | 52,sz_taxi,Transport,hour
78 | 68,loop_seattle,Transport,5T
79 | 30,bitbrains_fast_storage,Web/CloudOps,5T
80 | 31,bitbrains_fast_storage,Web/CloudOps,hour
81 | 32,bitbrains_rnd,Web/CloudOps,5T
82 | 33,bitbrains_rnd,Web/CloudOps,hour
83 | 34,bizitobs_l2c,Web/CloudOps,5T
84 | 35,bizitobs_l2c,Web/CloudOps,hour
85 | 95,bizitobs_application,Web/CloudOps,10S
86 | 58,bizitobs_l2c,Web/CloudOps,hour
87 | 54,bizitobs_service,Web/CloudOps,10S
88 | 79,bizitobs_l2c,Web/CloudOps,hour
89 | 78,bizitobs_l2c,Web/CloudOps,5T
90 | 77,bitbrains_rnd,Web/CloudOps,5T
91 | 76,bitbrains_fast_storage,Web/CloudOps,5T
92 | 75,bizitobs_service,Web/CloudOps,10S
93 | 74,bizitobs_application,Web/CloudOps,10S
94 | 55,bitbrains_fast_storage,Web/CloudOps,5T
95 | 56,bitbrains_rnd,Web/CloudOps,5T
96 | 57,bizitobs_l2c,Web/CloudOps,5T
97 | 53,bizitobs_application,Web/CloudOps,10S
98 | 96,bizitobs_service,Web/CloudOps,10S
99 |
--------------------------------------------------------------------------------
/TimeGPT_experiments.py:
--------------------------------------------------------------------------------
1 | from nixtla import NixtlaClient
2 | import pandas as pd
3 | import numpy as np
4 | from datasets import load_from_disk
5 |
6 |
7 | if __name__ == "__main__":
8 |     API_KEY = ["nixak-Dcl3rmoqOEqgaNK1jd30zNLN5vhoc34loGaljdTgARJBzHeJNZuSKDwWd7azFsUGvTBoB6qjgNIp5J4k",
9 |                "nixak-lgkeiACnUJx7jbslOQwP4qgYByjbKkmHG2iCiLjo2Ymy7B8tEEXyo26JFzKToXDWwAvK3i8u98uxnDph",
10 |                "nixak-AsJr2gE9btKpfOCp654eKQX47ALZwWdoArI5gZNN5LQMIGDeO1SFeEsSpceB2fFdMKYtmxomU47vg8N4",
11 |                "nixak-IlppDwy73vQjqbEzxMGrMx0PGYcn358jCrzEBYxv1OGjtcvSy3YIayC5wmNsFtDnOTXCF8vnwvKsEWU1"]
12 |     for horizon in [1, 2, 3, 4]:
13 |         client = NixtlaClient(
14 |             api_key=API_KEY[horizon-1],
15 |         )
16 |         client.validate_api_key()
17 |         df = pd.read_csv("data/Flu_USA/Flu_USA.csv")
18 |         context = 40
19 |         chunks = [df.iloc[i:i+context] for i in range(0, len(df), context)]
20 |         maes = []
21 |
22 |         for i, chunk in enumerate(chunks):
23 |             start_idx = chunk.index[-1] + 1
24 |             end_idx = start_idx + horizon
25 |             if end_idx > len(df):
26 |                 print("End of data")
27 |                 break
28 |             if len(chunk) < context:
29 |                 print("End of data")
30 |                 break
31 |             df_horizon = df.iloc[start_idx:end_idx]
32 |             forecast = client.forecast(chunk, h=horizon, target_col="% WEIGHTED ILI", time_col="date", model='timegpt-1')
33 |             pred = forecast["TimeGPT"].values
34 |             true = df_horizon["% WEIGHTED ILI"].values
35 |             mae = np.abs(pred - true).mean()  # absolute error; without abs() this is mean error, not MAE
36 |             maes.append(mae)
37 |
38 |         mae = np.mean(maes)
39 |         print(f"Horizon {horizon}: MAE {mae}")
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/__pycache__/dataset.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/__pycache__/dataset.cpython-311.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/__pycache__/model.cpython-311.pyc
-------------------------------------------------------------------------------- /__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /config/chronos.json: -------------------------------------------------------------------------------- 1 | { 2 | "repo": "amazon/chronos-t5-small", 3 | "config": { 4 | "num_layers": 6 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /config/chronos_dataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "tokenizer_class": "MeanScaleUniformBins", 3 | "tokenizer_kwargs": {"low_limit": -15.0, "high_limit": 15.0}, 4 | "n_tokens": 4096, 5 | "n_special_tokens": 2, 6 | "pad_token_id": 0, 7 | "eos_token_id": 1, 8 | "use_eos_token": true, 9 | "model_type": "seq2seq", 10 | "context_length": 512, 11 | "prediction_length": 64, 12 | "num_samples": 20, 13 | "temperature": 1.0, 14 | "top_k": 50, 15 | "top_p": 1.0 16 | } -------------------------------------------------------------------------------- /config/chronosbolt.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "autogluon/chronos-bolt-small", 3 | "architectures": [ 4 | "ChronosBoltModelForForecasting" 5 | ], 6 | "chronos_config": { 7 | "context_length": 2048, 8 | "input_patch_size": 16, 9 | "input_patch_stride": 16, 10 | "prediction_length": 64, 11 | "quantiles": [ 12 | 0.1, 13 | 0.2, 14 | 0.3, 15 | 0.4, 16 | 0.5, 17 | 0.6, 18 | 0.7, 19 | 0.8, 20 | 0.9 21 | ], 22 | "use_reg_token": true 23 | }, 24 | "chronos_pipeline_class": "ChronosBoltPipeline", 25 | "classifier_dropout": 0.0, 26 | "d_ff": 2048, 27 | "d_kv": 64, 28 | "d_model": 512, 29 | "decoder_start_token_id": 0, 30 | "dense_act_fn": "relu", 31 | "dropout_rate": 0.1, 32 | "eos_token_id": 1, 33 | "feed_forward_proj": "relu", 34 | "initializer_factor": 0.05, 35 | "is_encoder_decoder": true, 36 | "is_gated_act": false, 37 | "layer_norm_epsilon": 1e-06, 38 | "model_type": "t5", 39 | "n_positions": 512, 40 | "num_decoder_layers": 6, 41 | "num_heads": 8, 42 | "num_layers": 6, 43 | "pad_token_id": 0, 44 | "reg_token_id": 1, 45 | "relative_attention_max_distance": 128, 46 | "relative_attention_num_buckets": 32, 47 | "torch_dtype": "float32", 48 | "transformers_version": "4.39.3", 49 | "use_cache": true, 50 | "vocab_size": 2 51 | } 52 | 53 | -------------------------------------------------------------------------------- /config/lptm.json: -------------------------------------------------------------------------------- 1 | { 2 | "config": { 3 | "task_name": "forecasting2", 4 | "forecast_horizon": 192, 5 | "head_dropout": 0.1, 6 | "weight_decay": 0, 7 | "freeze_encoder": true, 8 | "freeze_embedder": true, 9 | "freeze_head": false 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /config/moirai.json: -------------------------------------------------------------------------------- 1 | { 2 | "repo": "Salesforce/moirai-moe-1.0-R-small", 3 | "config": { 4 | "context_len": 128, 5 | "horizon_len": 64, 6 | "num_layers": 100, 7 | "model_type": "moirai-moe", 8 | "model_size": "small" 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /config/moment_base.json: 
-------------------------------------------------------------------------------- 1 | {"task_name": "reconstruction", 2 | "model_name": "MOMENT", 3 | "transformer_type": "encoder_only", 4 | "d_model": null, 5 | "seq_len": 512, 6 | "patch_len": 8, 7 | "patch_stride_len": 8, 8 | "device": "cpu", 9 | "transformer_backbone": "google/flan-t5-large", 10 | "model_kwargs": {}, 11 | "t5_config": { 12 | "architectures": [ 13 | "T5ForConditionalGeneration" 14 | ], 15 | "d_ff": 2816, 16 | "d_kv": 64, 17 | "d_model": 1024, 18 | "decoder_start_token_id": 0, 19 | "dropout_rate": 0.1, 20 | "eos_token_id": 1, 21 | "feed_forward_proj": "gated-gelu", 22 | "initializer_factor": 1.0, 23 | "is_encoder_decoder": true, 24 | "layer_norm_epsilon": 1e-06, 25 | "model_type": "t5", 26 | "n_positions": 512, 27 | "num_decoder_layers": 24, 28 | "num_heads": 16, 29 | "num_layers": 24, 30 | "output_past": true, 31 | "pad_token_id": 0, 32 | "relative_attention_max_distance": 128, 33 | "relative_attention_num_buckets": 32, 34 | "tie_word_embeddings": false, 35 | "use_cache": true, 36 | "vocab_size": 32128 37 | } 38 | } -------------------------------------------------------------------------------- /config/moment_classification.json: -------------------------------------------------------------------------------- 1 | { 2 | "repo": "AutonLab/MOMENT-1-large", 3 | "config": { 4 | "task_name": "classification", 5 | "n_channels": 1, 6 | "num_class": 5 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /config/moment_detection.json: -------------------------------------------------------------------------------- 1 | { 2 | "repo": "AutonLab/MOMENT-1-large", 3 | "config": { 4 | "task_name": "reconstruction" 5 | } 6 | } -------------------------------------------------------------------------------- /config/moment_forecast.json: -------------------------------------------------------------------------------- 1 | { 2 | "repo": "AutonLab/MOMENT-1-large", 3 | "config": { 4 | "task_name": "forecasting", 5 | "forecast_horizon": 192, 6 | "head_dropout": 0.1, 7 | "weight_decay": 0, 8 | "freeze_encoder": true, 9 | "freeze_embedder": true, 10 | "freeze_head": false 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /config/moment_imputation.json: -------------------------------------------------------------------------------- 1 | { 2 | "repo": "AutonLab/MOMENT-1-large", 3 | "config": { 4 | "task_name": "reconstruction" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /config/timemoe.json: -------------------------------------------------------------------------------- 1 | { 2 | "repo": "Maple728/TimeMoE-50M", 3 | "config": { 4 | "input_size": 1, 5 | "hidden_size": 4096, 6 | "intermediate_size": 22016, 7 | "horizon_lengths": 1, 8 | "num_hidden_layers": 32, 9 | "num_attention_heads": 32, 10 | "num_key_value_heads": null, 11 | "hidden_act": "silu", 12 | "num_experts_per_tok": 2, 13 | "num_experts": 1, 14 | "max_position_embeddings": 32768, 15 | "initializer_range": 0.02, 16 | "rms_norm_eps": 1e-6, 17 | "use_cache": true, 18 | "use_dense": false, 19 | "rope_theta": 10000, 20 | "attention_dropout": 0.0, 21 | "apply_aux_loss": true, 22 | "router_aux_loss_factor": 0.02, 23 | "tie_word_embeddings": false 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /config/timesfm.json: -------------------------------------------------------------------------------- 1 | { 2 
| "repo": "google/timesfm-2.0-500m-pytorch", 3 | "config": { 4 | "context_len": 512, 5 | "horizon_len": 96, 6 | "backend": "gpu", 7 | "per_core_batch_size": 32, 8 | "input_patch_len": 32, 9 | "output_patch_len": 128, 10 | "use_positional_embedding": false, 11 | "num_layers": 50, 12 | "quantiles": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /config/tinytimemixer.json: -------------------------------------------------------------------------------- 1 | { 2 | "repo": "ibm-granite/granite-timeseries-ttm-r2", 3 | "config": { 4 | "context_len": 512, 5 | "horizon_len": 96 6 | } 7 | } -------------------------------------------------------------------------------- /download_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from datasets import Dataset, load_dataset 4 | 5 | if __name__ == "__main__": 6 | save_dir = "data/monash" 7 | if not os.path.exists(save_dir): 8 | os.makedirs(save_dir) 9 | dataset_names = [ 10 | "weather", 11 | "tourism_yearly", 12 | "tourism_quarterly", 13 | "tourism_monthly", 14 | "cif_2016", 15 | "london_smart_meters", 16 | "australian_electricity_demand", 17 | "wind_farms_minutely", 18 | "bitcoin", 19 | "pedestrian_counts", 20 | "vehicle_trips", 21 | "kdd_cup_2018", 22 | "nn5_daily", 23 | "nn5_weekly", 24 | "kaggle_web_traffic", 25 | "kaggle_web_traffic_weekly", 26 | "solar_10_minutes", 27 | "solar_weekly", 28 | "car_parts", 29 | "fred_md", 30 | "traffic_hourly", 31 | "traffic_weekly", 32 | "hospital", 33 | "covid_deaths", 34 | "sunspot", 35 | "saugeenday", 36 | "us_births", 37 | "solar_4_seconds", 38 | "wind_4_seconds", 39 | "rideshare", 40 | "oikolab_weather", 41 | "temperature_rain", 42 | ] 43 | for dataset_name in dataset_names: 44 | dataset: Dataset = load_dataset("monash_tsf", dataset_name) # type: ignore 45 | dataset.save_to_disk(f"{save_dir}/{dataset_name}") 46 | print(f"Downloaded {dataset_name} dataset") 47 | -------------------------------------------------------------------------------- /example/colab/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/example/colab/.gitkeep -------------------------------------------------------------------------------- /plot_timeGPT.py: -------------------------------------------------------------------------------- 1 | from nixtla import NixtlaClient 2 | import pandas as pd 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | 6 | 7 | if __name__ == "__main__": 8 | API_KEY = "nixak-Dcl3rmoqOEqgaNK1jd30zNLN5vhoc34loGaljdTgARJBzHeJNZuSKDwWd7azFsUGvTBoB6qjgNIp5J4k" 9 | client = NixtlaClient( 10 | api_key=API_KEY, 11 | ) 12 | client.validate_api_key() 13 | df = pd.read_csv("data/Flu_USA/Flu_USA.csv") 14 | context = 40 15 | chunks = [df.iloc[i:i+context] for i in range(0, len(df), context)] 16 | chunk_ten = chunks[9] 17 | horizon = 4 18 | start_idx = chunk_ten.index[-1] + 1 19 | end_idx = start_idx + horizon 20 | df_horizon = df.iloc[start_idx:end_idx] 21 | forecast = client.forecast(chunk_ten, h=horizon, target_col="% WEIGHTED ILI", time_col="date", model='timegpt-1') 22 | pred = forecast["TimeGPT"].values 23 | true = df_horizon["% WEIGHTED ILI"].values 24 | history = chunk_ten["% WEIGHTED ILI"].values 25 | # save history, pred, true to a txt file 26 | np.savetxt("data/plot_TimeGPT.txt", history) 27 | 
np.savetxt("data/plot_TimeGPT.txt", pred) 28 | np.savetxt("data/plot_TimeGPT.txt", true) 29 | plt.plot(range(len(history)), history, label="History") 30 | plt.plot(range(len(history), len(history)+horizon), true, label="True") 31 | plt.plot(range(len(history), len(history)+horizon), pred, label="Pred") 32 | plt.legend() 33 | plt.savefig("data/plot_TimeGPT.png") 34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "samay" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | authors = [{ name = "kage08", email = "harshavardhan864.hk@gmail.com" }] 7 | requires-python = ">=3.11" 8 | dependencies = [ 9 | "absl-py>=2.1.0", 10 | "datasets>=3.2.0", 11 | "einshape>=1.0", 12 | "gluonts>=0.16.0", 13 | "huggingface-hub>=0.26.2", 14 | "matplotlib>=3.10.0", 15 | "numpy>=2.1.3", 16 | "pandas>=2.2.3", 17 | "scikit-learn>=1.5.2", 18 | "torch>=2.5.1", 19 | "transformers>=4.47.0", 20 | "typer-config>=1.4.2", 21 | "typer>=0.15.1", 22 | "utilsforecast>=0.2.7", 23 | "datasets>=3.2.0", 24 | "chronos-forecasting>=1.4.1", 25 | "tensorboardx>=2.6.2.2", 26 | "einops>=0.8.1", 27 | "hydra-core>=1.3.2", 28 | "jax>=0.5.3", 29 | "jaxtyping>=0.3.0", 30 | "torchvision>=0.20.1", 31 | "lightning>=2.5.1", 32 | ] 33 | 34 | [build-system] 35 | requires = ["hatchling"] 36 | build-backend = "hatchling.build" 37 | 38 | [dependency-groups] 39 | dev = [ 40 | "jupyter>=1.1.1", 41 | "pre-commit>=4.0.1", 42 | "pytest>=8.3.3", 43 | "ruff>=0.8.1", 44 | "wandb>=0.18.5", 45 | ] 46 | 47 | [tool.hatch.build.targets.wheel] 48 | packages = ["src/samay", "src/uni2ts"] 49 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/__init__.py -------------------------------------------------------------------------------- /src/samay/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/__init__.py -------------------------------------------------------------------------------- /src/samay/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def MSE(y_true: np.ndarray, y_pred: np.ndarray): 5 | """Mean squared error""" 6 | return np.mean((y_true - y_pred) ** 2) 7 | 8 | 9 | def MAE(y_true: np.ndarray, y_pred: np.ndarray): 10 | """Mean absolute error""" 11 | return np.mean(np.abs(y_true - y_pred)) 12 | 13 | 14 | def MASE(y_true: np.ndarray, y_pred: np.ndarray, freq: str = "h"): 15 | """Mean absolute scaled error""" 16 | DEFAULT_SEASONALITIES = { 17 | "S": 3600, # 1 hour 18 | "s": 3600, # 1 hour 19 | "T": 1440, # 1 day 20 | "min": 1440, # 1 day 21 | "H": 24, # 1 day 22 | "h": 24, # 1 day 23 | "D": 1, # 1 day 24 | "W": 1, # 1 week 25 | "M": 12, 26 | "ME": 12, 27 | "B": 5, 28 | "Q": 4, 29 | "QE": 4, 30 | } 31 | # seasonality = DEFAULT_SEASONALITIES[freq] 32 | y_t = y_true[:, :, 1:] - y_true[:, :, :-1] 33 | return np.mean(np.abs(y_true - y_pred) / (np.mean(np.abs(y_t)) + 1e-5)) 34 | 35 | 36 | def MAPE(y_true: np.ndarray, y_pred: np.ndarray): 37 | """Mean absolute percentage error""" 38 | return np.mean(np.abs(y_true - y_pred) / 
39 |
40 |
41 | def RMSE(y_true: np.ndarray, y_pred: np.ndarray):
42 |     """Root mean squared error"""
43 |     return np.sqrt(MSE(y_true, y_pred))
44 |
45 |
46 | def NRMSE(y_true: np.ndarray, y_pred: np.ndarray):
47 |     """Normalized root mean squared error"""
48 |     return RMSE(y_true, y_pred) / (np.max(y_true) - np.min(y_true) + 1e-5)
49 |
50 |
51 | def SMAPE(y_true: np.ndarray, y_pred: np.ndarray):
52 |     """Symmetric mean absolute percentage error"""
53 |     return np.mean(
54 |         2.0 * np.abs(y_true - y_pred) / (np.abs(y_true) + np.abs(y_pred) + 1e-5)
55 |     )
56 |
57 |
58 | def MSIS(y_true: np.ndarray, y_pred: np.ndarray, alpha: float = 0.05):
59 |     """Mean scaled interval score"""
60 |     q1 = np.percentile(y_true, 100 * alpha / 2)
61 |     q2 = np.percentile(y_true, 100 * (1 - alpha / 2))
62 |     denominator = q2 - q1
63 |     penalties = 2 * ((y_true < q1) * (q1 - y_pred) + (y_true > q2) * (y_pred - q2))
64 |     return np.mean(np.abs(y_true - y_pred) / (denominator + 1e-5)) + np.mean(
65 |         penalties / (denominator + 1e-5)
66 |     )
67 |
68 |
69 | def ND(y_true: np.ndarray, y_pred: np.ndarray):
70 |     """Normalized deviation"""
71 |     return np.mean(np.abs(y_true - y_pred)) / (np.mean(y_true) + 1e-5)
72 |
73 |
74 | def MWSQ(y_true: np.ndarray, y_pred: np.ndarray, quantiles: np.ndarray):
75 |     """Mean weighted squared quantile loss"""
76 |
77 |     def quantile_loss(y_true, y_pred, q):
78 |         return np.maximum(q * (y_true - y_pred), (q - 1) * (y_true - y_pred)).mean()
79 |
80 |     return np.mean([quantile_loss(y_true, y_pred, q) for q in quantiles])
81 |
82 |
83 | def CRPS(y_true: np.ndarray, y_pred: np.ndarray, quantiles: np.ndarray):
84 |     """Continuous ranked probability score"""
85 |     crps = np.mean(
86 |         (y_pred - y_true) ** 2 * np.abs(quantiles - (y_true <= y_pred).astype(float))
87 |     )
88 |     return crps
89 |
--------------------------------------------------------------------------------
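The metrics above share one convention: NumPy arrays in, scalar out. A minimal usage sketch — assuming the package is installed so that samay.metric is importable, and (per MASE's slicing) arrays shaped (batch, channels, time); the data here is synthetic:

import numpy as np
from samay.metric import MAE, MAPE, MASE, RMSE

rng = np.random.default_rng(0)
y_true = rng.normal(loc=10.0, size=(8, 3, 96))              # (batch, channels, horizon)
y_pred = y_true + rng.normal(scale=0.1, size=y_true.shape)  # a near-perfect forecast

print("MAE :", MAE(y_true, y_pred))
print("MASE:", MASE(y_true, y_pred, freq="h"))  # "h" maps to seasonality 24
print("MAPE:", MAPE(y_true, y_pred))
print("RMSE:", RMSE(y_true, y_pred))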
/src/samay/models/Time_MoE/time_moe/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 _*-
3 |
--------------------------------------------------------------------------------
/src/samay/models/Time_MoE/time_moe/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 _*-
3 | from .binary_dataset import BinaryDataset
4 | from .general_dataset import GeneralDataset
5 |
--------------------------------------------------------------------------------
/src/samay/models/Time_MoE/time_moe/datasets/general_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 _*-
3 | import json
4 | import os
5 | import pickle
6 | import gzip
7 | import yaml
8 | import numpy as np
9 |
10 | from .ts_dataset import TimeSeriesDataset
11 |
12 |
13 | class GeneralDataset(TimeSeriesDataset):
14 |     def __init__(self, data_path):
15 |         self.data = read_file_by_extension(data_path)
16 |         self.num_tokens = None
17 |
18 |     def __len__(self):
19 |         return len(self.data)
20 |
21 |     def __getitem__(self, seq_idx):
22 |         seq = self.data[seq_idx]
23 |         if isinstance(seq, dict):
24 |             seq = seq['sequence']
25 |         return seq
26 |
27 |     def get_num_tokens(self):
28 |         if self.num_tokens is None:
29 |             self.num_tokens = sum([len(seq) for seq in self])
30 |         return self.num_tokens
31 |
32 |     def get_sequence_length_by_idx(self, seq_idx):
33 |         seq = self[seq_idx]
34 |         return len(seq)
35 |
36 |     @staticmethod
37 |     def is_valid_path(data_path):
38 |         if os.path.exists(data_path) and os.path.isfile(data_path):
39 |             # check the full extension: splitting on '.' yields 'gz', so 'npy.gz' could never match
40 |             if data_path.endswith('.npy.gz'):
41 |                 return True
42 |             suffix = data_path.split('.')[-1]
43 |             if suffix in ('json', 'jsonl', 'npy', 'pkl'):
44 |                 return True
45 |             else:
46 |                 return False
47 |         else:
48 |             return False
49 |
50 |
51 | def read_file_by_extension(fn):
52 |     if fn.endswith('.json'):
53 |         with open(fn, encoding='utf-8') as file:
54 |             data = json.load(file)
55 |     elif fn.endswith('.jsonl'):
56 |         data = read_jsonl_to_list(fn)
57 |     elif fn.endswith('.yaml'):
58 |         data = load_yaml_file(fn)
59 |     elif fn.endswith('.npy'):
60 |         data = np.load(fn, allow_pickle=True)
61 |     elif fn.endswith('.npz'):
62 |         data = np.load(fn, allow_pickle=True)
63 |     elif fn.endswith('.npy.gz'):
64 |         with gzip.GzipFile(fn, 'r') as file:
65 |             data = np.load(file, allow_pickle=True)
66 |     elif fn.endswith('.pkl') or fn.endswith('.pickle'):
67 |         data = load_pkl_obj(fn)
68 |     else:
69 |         raise RuntimeError(f'Unknown file extension: {fn}')
70 |     return data
71 |
72 |
73 | def read_jsonl_to_list(jsonl_fn):
74 |     with open(jsonl_fn, 'r', encoding='utf-8') as file:
75 |         return [json.loads(line) for line in file.readlines()]
76 |
77 |
78 | def load_yaml_file(fn):
79 |     if isinstance(fn, str):
80 |         with open(fn, 'r', encoding="utf-8") as f:
81 |             config = yaml.safe_load(f)
82 |         return config
83 |     else:
84 |         return fn
85 |
86 |
87 | def load_pkl_obj(fn):
88 |     out_list = []
89 |     with open(fn, 'rb') as f:
90 |         while True:
91 |             try:
92 |                 data = pickle.load(f)
93 |                 out_list.append(data)
94 |             except EOFError:
95 |                 break
96 |     if len(out_list) == 0:
97 |         return None
98 |     elif len(out_list) == 1:
99 |         return out_list[0]
100 |     else:
101 |         return out_list
102 |
--------------------------------------------------------------------------------
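A small sketch of how GeneralDataset consumes data — the toy file name is hypothetical, and since the vendored Time_MoE folder has no top-level __init__.py, the sketch assumes it is added to sys.path; JSONL rows may be bare lists or {"sequence": [...]} dicts, as __getitem__ above shows:

import sys
sys.path.append("src/samay/models/Time_MoE")  # path inside this repo; adjust as needed

import json
from time_moe.datasets.general_dataset import GeneralDataset

# Hypothetical toy file: one series per line, in both shapes __getitem__ accepts.
with open("toy_series.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps([1.0, 2.0, 3.0]) + "\n")
    f.write(json.dumps({"sequence": [4.0, 5.0, 6.0, 7.0]}) + "\n")

ds = GeneralDataset("toy_series.jsonl")
print(len(ds))              # 2 sequences
print(ds[1])                # [4.0, 5.0, 6.0, 7.0]
print(ds.get_num_tokens())  # 7 observations in total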
/src/samay/models/Time_MoE/time_moe/datasets/ts_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 _*-
3 | from abc import abstractmethod
4 |
5 |
6 | class TimeSeriesDataset:
7 |     @abstractmethod
8 |     def __len__(self):
9 |         pass
10 |
11 |     @abstractmethod
12 |     def __getitem__(self, seq_idx):
13 |         pass
14 |
15 |     @abstractmethod
16 |     def get_num_tokens(self):
17 |         pass
18 |
19 |     @abstractmethod
20 |     def get_sequence_length_by_idx(self, seq_idx):
21 |         pass
22 |
23 |     @staticmethod
24 |     def is_valid_path(data_path):
25 |         return True
26 |
27 |     def __iter__(self):
28 |         n_seqs = len(self)
29 |         for i in range(n_seqs):
30 |             yield self[i]
--------------------------------------------------------------------------------
/src/samay/models/Time_MoE/time_moe/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/Time_MoE/time_moe/models/__init__.py
--------------------------------------------------------------------------------
/src/samay/models/Time_MoE/time_moe/models/configuration_time_moe.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from transformers import PretrainedConfig
3 |
4 |
5 | class TimeMoeConfig(PretrainedConfig):
6 |     model_type = "time_moe"
7 |     keys_to_ignore_at_inference = ["past_key_values"]
8 |
9 |     def __init__(
10 |         self,
11 |         input_size: int = 1,
12 |         hidden_size: int = 4096,
13 |         intermediate_size: int = 22016,
14 |         horizon_lengths: List[int] = 1,
15 |         num_hidden_layers: int = 32,
16 |         num_attention_heads: int = 32,
17 |         num_key_value_heads: int = None,
18 |         hidden_act: str = "silu",
19 |         num_experts_per_tok: int = 2,
20 |         num_experts: int = 1,
21 |         max_position_embeddings: int = 32768,
22 |         initializer_range: float = 0.02,
23 |         rms_norm_eps: float = 1e-6,
24 |         use_cache: bool = True,
25 |         use_dense: bool = False,
26 |         rope_theta: int = 10000,
27 |         attention_dropout: float = 0.0,
28 |         apply_aux_loss: bool = True,
29 |         router_aux_loss_factor: float = 0.02,
30 |         tie_word_embeddings: bool = False,
31 |         **kwargs,
32 |     ):
33 |         self.input_size = input_size
34 |         self.hidden_size = hidden_size
35 |         self.intermediate_size = intermediate_size
36 |         self.max_position_embeddings = max_position_embeddings
37 |         self.num_hidden_layers = num_hidden_layers
38 |         self.num_attention_heads = num_attention_heads
39 |
40 |         if num_key_value_heads is None:
41 |             num_key_value_heads = num_attention_heads
42 |
43 |         self.num_key_value_heads = num_key_value_heads
44 |         self.hidden_act = hidden_act
45 |         if isinstance(horizon_lengths, int):
46 |             horizon_lengths = [horizon_lengths]
47 |         self.horizon_lengths = horizon_lengths  # Predict horizon length for each prediction.
48 |         self.num_experts_per_tok = num_experts_per_tok
49 |         self.num_experts = num_experts
50 |         self.initializer_range = initializer_range
51 |         self.rms_norm_eps = rms_norm_eps
52 |         self.use_cache = use_cache
53 |         self.use_dense = use_dense
54 |         self.rope_theta = rope_theta
55 |         self.attention_dropout = attention_dropout
56 |         self.apply_aux_loss = apply_aux_loss
57 |         self.router_aux_loss_factor = router_aux_loss_factor
58 |
59 |         assert self.use_dense ^ self.apply_aux_loss, 'use_dense and apply_aux_loss are mutually exclusive: exactly one of them must be True.'
60 |
61 |         kwargs.pop('tie_word_embeddings', None)
62 |         super().__init__(
63 |             tie_word_embeddings=tie_word_embeddings,
64 |             **kwargs,
65 |         )
66 |
--------------------------------------------------------------------------------
/src/samay/models/Time_MoE/time_moe/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/Time_MoE/time_moe/trainer/__init__.py
--------------------------------------------------------------------------------
/src/samay/models/Time_MoE/time_moe/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/Time_MoE/time_moe/utils/__init__.py
--------------------------------------------------------------------------------
/src/samay/models/Time_MoE/time_moe/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 _*-
3 | import os
4 | import torch.distributed as dist
5 |
6 |
7 | def is_master_process():
8 |     rank = os.getenv('RANK')
9 |     if (rank is None or rank == '0') and is_local_rank_0():
10 |         return True
11 |     else:
12 |         return False
13 |
14 |
15 | def is_local_rank_0():
16 |     local_rank = os.getenv('LOCAL_RANK')
17 |     if local_rank is None or local_rank == '0':
18 |         return True
19 |     else:
20 |         return False
21 |
22 |
23 | def get_local_world_size():
24 |     import torch
25 |     local_world_size = os.getenv('LOCAL_WORLD_SIZE')
26 |     if local_world_size is None:
27 |         num_gpus = torch.cuda.device_count()
28 |         local_world_size = num_gpus or 1
29 |     else:
30 |         local_world_size = 
int(local_world_size) 31 | return local_world_size 32 | 33 | 34 | def get_world_size(): 35 | try: 36 | world_size = dist.get_world_size() 37 | return world_size 38 | except Exception: 39 | pass 40 | world_size = os.getenv('WORLD_SIZE') 41 | if world_size is None: 42 | world_size = 1 43 | else: 44 | world_size = int(world_size) 45 | return world_size 46 | -------------------------------------------------------------------------------- /src/samay/models/Time_MoE/time_moe/utils/log_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 _*- 3 | import os 4 | import logging 5 | import sys 6 | import typing 7 | 8 | # -------- log setting --------- 9 | DEFAULT_LOGGER = "time_moe_logger" 10 | 11 | DEFAULT_FORMATTER = logging.Formatter( 12 | '%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s] - %(levelname)s: %(message)s' 13 | ) 14 | 15 | _ch = logging.StreamHandler(stream=sys.stdout) 16 | _ch.setFormatter(DEFAULT_FORMATTER) 17 | 18 | _DEFAULT_HANDLERS = [_ch] 19 | 20 | _LOGGER_CACHE = {} # type: typing.Dict[str, logging.Logger] 21 | 22 | 23 | def is_local_rank_0(): 24 | local_rank = os.getenv('LOCAL_RANK') 25 | if local_rank is None or local_rank == '0': 26 | return True 27 | else: 28 | return False 29 | 30 | 31 | def get_logger(name, level="INFO", handlers=None, update=False): 32 | if name in _LOGGER_CACHE and not update: 33 | return _LOGGER_CACHE[name] 34 | logger = logging.getLogger(name) 35 | logger.setLevel(level) 36 | logger.handlers = handlers or _DEFAULT_HANDLERS 37 | logger.propagate = False 38 | return logger 39 | 40 | 41 | def log_in_local_rank_0(*msg, type='info', used_logger=None): 42 | msg = ' '.join([str(s) for s in msg]) 43 | if used_logger is None: 44 | used_logger = logger 45 | 46 | if is_local_rank_0(): 47 | if type == 'warn' or type == 'warning': 48 | used_logger.warning(msg) 49 | elif type == 'error': 50 | used_logger.error(msg) 51 | else: 52 | used_logger.info(msg) 53 | 54 | 55 | # -------------------------- Singleton Object -------------------------- 56 | logger = get_logger(DEFAULT_LOGGER) 57 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | 4 | import logging 5 | import os 6 | from pathlib import Path 7 | from typing import TYPE_CHECKING 8 | 9 | # Check the dependencies satisfy the minimal versions required. 
10 | from transformers.utils import _LazyModule 11 | 12 | from .version import __version__, __version_tuple__ 13 | 14 | 15 | TSFM_PYTHON_LOGGING_LEVEL = os.getenv("TSFM_PYTHON_LOGGING_LEVEL", "INFO") 16 | 17 | LevelNamesMapping = { 18 | "INFO": logging.INFO, 19 | "WARN": logging.WARN, 20 | "WARNING": logging.WARNING, 21 | "ERROR": logging.ERROR, 22 | "CRITICAL": logging.CRITICAL, 23 | "DEBUG": logging.DEBUG, 24 | "FATAL": logging.FATAL, 25 | } 26 | 27 | TSFM_PYTHON_LOGGING_LEVEL = ( 28 | logging.getLevelNamesMapping()[TSFM_PYTHON_LOGGING_LEVEL] 29 | if hasattr(logging, "getLevelNamesMapping") 30 | else LevelNamesMapping[TSFM_PYTHON_LOGGING_LEVEL] 31 | ) 32 | TSFM_PYTHON_LOGGING_FORMAT = os.getenv( 33 | "TSFM_PYTHON_LOGGING_FORMAT", 34 | "%(levelname)s:p-%(process)d:t-%(thread)d:%(filename)s:%(funcName)s:%(message)s", 35 | ) 36 | 37 | logging.basicConfig( 38 | format=TSFM_PYTHON_LOGGING_FORMAT, 39 | level=TSFM_PYTHON_LOGGING_LEVEL, 40 | ) 41 | 42 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 43 | 44 | # Base objects, independent of any specific backend 45 | _import_structure = { 46 | "models": [], 47 | "models.tinytimemixer": ["TINYTIMEMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TinyTimeMixerConfig"], 48 | "toolkit": [ 49 | "TimeSeriesPreprocessor", 50 | "TimeSeriesForecastingPipeline", 51 | "ForecastDFDataset", 52 | "PretrainDFDataset", 53 | "RegressionDFDataset", 54 | "get_datasets", 55 | "load_dataset", 56 | "TrackingCallback", 57 | "count_parameters", 58 | ], 59 | } 60 | 61 | 62 | # PyTorch-backed objects 63 | _import_structure["models.tinytimemixer"].extend( 64 | [ 65 | "TINYTIMEMIXER_PRETRAINED_MODEL_ARCHIVE_LIST", 66 | "TinyTimeMixerPreTrainedModel", 67 | "TinyTimeMixerModel", 68 | "TinyTimeMixerForMaskedPrediction", 69 | "TinyTimeMixerForPrediction", 70 | ] 71 | ) 72 | 73 | # Direct imports for type-checking 74 | if TYPE_CHECKING: 75 | from .models.tinytimemixer import ( 76 | TINYTIMEMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP, 77 | TINYTIMEMIXER_PRETRAINED_MODEL_ARCHIVE_LIST, 78 | TinyTimeMixerConfig, 79 | TinyTimeMixerForMaskedPrediction, 80 | TinyTimeMixerForPrediction, 81 | TinyTimeMixerModel, 82 | TinyTimeMixerPreTrainedModel, 83 | ) 84 | from .toolkit import ( 85 | ForecastDFDataset, 86 | PretrainDFDataset, 87 | RegressionDFDataset, 88 | TimeSeriesForecastingPipeline, 89 | TimeSeriesPreprocessor, 90 | TrackingCallback, 91 | count_parameters, 92 | get_datasets, 93 | load_dataset, 94 | ) 95 | else: 96 | # Standard 97 | import sys 98 | 99 | sys.modules[__name__] = _LazyModule( 100 | __name__, 101 | globals()["__file__"], 102 | _import_structure, 103 | module_spec=__spec__, 104 | extra_objects={"__version__": __version__}, 105 | ) 106 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | from . 
import tinytimemixer 4 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/models/tinytimemixer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | 4 | from typing import TYPE_CHECKING 5 | 6 | # rely on isort to merge the imports 7 | from transformers.utils import ( 8 | OptionalDependencyNotAvailable, 9 | _LazyModule, 10 | is_torch_available, 11 | ) 12 | 13 | 14 | _import_structure = { 15 | "configuration_tinytimemixer": [ 16 | "TINYTIMEMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP", 17 | "TinyTimeMixerConfig", 18 | ], 19 | } 20 | 21 | try: 22 | if not is_torch_available(): 23 | raise OptionalDependencyNotAvailable() 24 | except OptionalDependencyNotAvailable: 25 | pass 26 | else: 27 | _import_structure["modeling_tinytimemixer"] = [ 28 | "TINYTIMEMIXER_PRETRAINED_MODEL_ARCHIVE_LIST", 29 | "TinyTimeMixerModel", 30 | "TinyTimeMixerForPrediction", 31 | "TinyTimeMixerForMaskedPrediction", 32 | ] 33 | 34 | _import_structure["utils_tinytimemixer"] = [ 35 | "get_freq_mapping", 36 | "get_freq_token", 37 | "get_freq_vocab_size", 38 | ] 39 | 40 | 41 | if TYPE_CHECKING: 42 | from .configuration_tinytimemixer import ( 43 | TINYTIMEMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP, 44 | TinyTimeMixerConfig, 45 | ) 46 | 47 | try: 48 | if not is_torch_available(): 49 | raise OptionalDependencyNotAvailable() 50 | except OptionalDependencyNotAvailable: 51 | pass 52 | else: 53 | from .modeling_tinytimemixer import ( 54 | TINYTIMEMIXER_PRETRAINED_MODEL_ARCHIVE_LIST, 55 | TinyTimeMixerForMaskedPrediction, 56 | TinyTimeMixerForPrediction, 57 | TinyTimeMixerModel, 58 | ) 59 | from .utils_tinytimemixer import ( 60 | get_freq_mapping, 61 | get_freq_token, 62 | get_freq_vocab_size, 63 | ) 64 | 65 | else: 66 | import sys 67 | 68 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 69 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/models/tinytimemixer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | from tsfm_public.models.tinytimemixer.utils.ttm_args import get_ttm_args 4 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/models/tinytimemixer/utils/ttm_image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/TinyTimeMixer/models/tinytimemixer/utils/ttm_image.webp -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/electricity.yaml: 
-------------------------------------------------------------------------------- 1 | data_file: electricity.csv 2 | data_path: electricity/ 3 | id_columns: [] 4 | timestamp_column: date 5 | target_columns: [] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 1h 11 | 12 | scale: 13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 | train: 0.7 20 | test: 0.2 21 | 22 | 23 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/etth1.yaml: -------------------------------------------------------------------------------- 1 | data_file: ETTh1.csv 2 | data_path: ETT-small/ 3 | id_columns: [] 4 | timestamp_column: date 5 | target_columns: [] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 1h 11 | 12 | scale: 13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 | train: 20 | - 0 21 | - 8640 # = 12 * 30 * 24 22 | valid: 23 | - 8640 # = 12 * 30 * 24 24 | - 11520 # = 12 * 30 * 24 + 4 * 30 * 24 25 | test: 26 | - 11520 # = 12 * 30 * 24 + 4 * 30 * 24 27 | - 14400 # = 12 * 30 * 24 + 8 * 30 * 24 28 | 29 | 30 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/etth2.yaml: -------------------------------------------------------------------------------- 1 | data_file: ETTh2.csv 2 | data_path: ETT-small/ 3 | id_columns: [] 4 | timestamp_column: date 5 | target_columns: [] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 1h 11 | 12 | scale: 13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 | train: 20 | - 0 21 | - 8640 # = 12 * 30 * 24 22 | valid: 23 | - 8640 # = 12 * 30 * 24 24 | - 11520 # = 12 * 30 * 24 + 4 * 30 * 24 25 | test: 26 | - 11520 # = 12 * 30 * 24 + 4 * 30 * 24 27 | - 14400 # = 12 * 30 * 24 + 8 * 30 * 24 28 | 29 | 30 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/ettm1.yaml: -------------------------------------------------------------------------------- 1 | data_file: ETTm1.csv 2 | data_path: ETT-small/ 3 | id_columns: [] 4 | timestamp_column: date 5 | target_columns: [] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 15min 11 | 12 | scale: 13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 | train: 20 | - 0 21 | - 34560 # = 12 * 30 * 24 * 4 22 | valid: 23 | - 34560 # = 12 * 30 * 24 * 4 24 | - 46080 # = 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 25 | test: 26 | - 46080 # = 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 27 | - 57600 # = 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4 28 | 29 | 30 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/ettm2.yaml: -------------------------------------------------------------------------------- 1 | data_file: ETTm2.csv 2 | data_path: ETT-small/ 3 | id_columns: [] 4 | timestamp_column: date 5 | target_columns: [] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 15min 11 | 12 | scale: 
13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 | train: 20 | - 0 21 | - 34560 # = 12 * 30 * 24 * 4 22 | valid: 23 | - 34560 # = 12 * 30 * 24 * 4 24 | - 46080 # = 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 25 | test: 26 | - 46080 # = 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 27 | - 57600 # = 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4 28 | 29 | 30 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/exchange.yaml: -------------------------------------------------------------------------------- 1 | data_file: Exchange.csv 2 | data_path: Exchange 3 | id_columns: ["cols"] 4 | timestamp_column: date 5 | target_columns: ["data"] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 1d 11 | 12 | scale: 13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 | train: 0.7 20 | test: 0.2 21 | 22 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/solar.yaml: -------------------------------------------------------------------------------- 1 | data_file: Solar.csv 2 | data_path: Solar 3 | id_columns: ["cols"] 4 | timestamp_column: date 5 | target_columns: ["data"] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 10min 11 | 12 | scale: 13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 | train: 0.7 20 | test: 0.2 -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/traffic.yaml: -------------------------------------------------------------------------------- 1 | data_file: traffic.csv 2 | data_path: traffic/ 3 | id_columns: [] 4 | timestamp_column: date 5 | target_columns: [] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 1h 11 | 12 | scale: 13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 | train: 0.7 20 | test: 0.2 21 | 22 | 23 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/weather.yaml: -------------------------------------------------------------------------------- 1 | data_file: weather.csv 2 | data_path: weather/ 3 | id_columns: [] 4 | timestamp_column: date 5 | target_columns: [] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 10min 11 | 12 | scale: 13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 | train: 0.7 20 | test: 0.2 21 | 22 | 23 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/data_config/zafnoo.yaml: -------------------------------------------------------------------------------- 1 | data_file: ZafNoo.csv 2 | data_path: ZafNoo 3 | id_columns: ["cols"] 4 | timestamp_column: date 5 | target_columns: ["data"] 6 | observable_columns: [] 7 | control_columns: [] 8 | conditional_columns: [] 9 | static_categorical_columns: [] 10 | freq: 30min 11 | 12 | scale: 13 | scaling: True 14 | scaler_type: standard 15 | 16 | encode_categorical: False 17 | 18 | split: 19 
| train: 0.7 20 | test: 0.2 21 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/resources/model_paths_config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/toolkit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | 4 | from .callbacks import TrackingCallback 5 | from .data_handling import load_dataset 6 | from .dataset import ForecastDFDataset, PretrainDFDataset, RegressionDFDataset 7 | from .recursive_predictor import RecursivePredictor, RecursivePredictorConfig, RecursivePredictorOutput 8 | from .time_series_forecasting_pipeline import TimeSeriesForecastingPipeline 9 | from .time_series_preprocessor import TimeSeriesPreprocessor, get_datasets 10 | from .util import count_parameters 11 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/toolkit/callbacks.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | """Some basic callbacks for training with HF Trainer""" 4 | 5 | import time 6 | 7 | import numpy as np 8 | from transformers import TrainerCallback 9 | from transformers.trainer_callback import TrainerControl, TrainerState 10 | from transformers.training_args import TrainingArguments 11 | 12 | 13 | class TrackingCallback(TrainerCallback): 14 | """Simple tracking callback that tracks per epoch run times and calculates some statistics after training completes. 15 | 16 | Args: 17 | verbose (bool, optional): If true, prints additional information at the completion of each epoch. Defaults to False. 
18 | """ 19 | 20 | def __init__(self, verbose: bool = False): 21 | self.verbose = verbose 22 | 23 | def on_train_begin( 24 | self, 25 | args: TrainingArguments, 26 | state: TrainerState, 27 | control: TrainerControl, 28 | **kwargs, 29 | ): 30 | self.all_epoch_times = [] 31 | self.train_start_time = time.time() 32 | return super().on_train_begin(args, state, control, **kwargs) 33 | 34 | def on_train_end( 35 | self, 36 | args: TrainingArguments, 37 | state: TrainerState, 38 | control: TrainerControl, 39 | **kwargs, 40 | ): 41 | self.train_end_time = time.time() 42 | self.mean_epoch_time = np.mean(self.all_epoch_times) 43 | self.total_train_time = self.train_end_time - self.train_start_time 44 | self.best_eval_metric = state.best_metric 45 | print( 46 | f"[{self.__class__.__name__}] Mean Epoch Time = {self.mean_epoch_time} seconds, Total Train Time = {self.total_train_time}" 47 | ) 48 | return super().on_train_end(args, state, control, **kwargs) 49 | 50 | def on_epoch_begin( 51 | self, 52 | args: TrainingArguments, 53 | state: TrainerState, 54 | control: TrainerControl, 55 | **kwargs, 56 | ): 57 | self.epoch_start_time = time.time() 58 | return super().on_epoch_begin(args, state, control, **kwargs) 59 | 60 | def on_epoch_end( 61 | self, 62 | args: TrainingArguments, 63 | state: TrainerState, 64 | control: TrainerControl, 65 | **kwargs, 66 | ): 67 | self.epoch_end_time = time.time() 68 | self.last_epoch_time = self.epoch_end_time - self.epoch_start_time 69 | if self.verbose: 70 | print(f"[{self.__class__.__name__}] Epoch Time = {self.last_epoch_time} seconds") 71 | self.all_epoch_times.append(self.last_epoch_time) 72 | return super().on_epoch_end(args, state, control, **kwargs) 73 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/toolkit/data_handling.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | """Utilities for handling datasets""" 4 | 5 | import logging 6 | from importlib import resources 7 | from pathlib import Path 8 | from typing import Optional 9 | 10 | import pandas as pd 11 | import yaml 12 | 13 | from .time_series_preprocessor import TimeSeriesPreprocessor, get_datasets 14 | 15 | 16 | LOGGER = logging.getLogger(__file__) 17 | 18 | 19 | def load_dataset( 20 | dataset_name: str, 21 | context_length, 22 | forecast_length, 23 | fewshot_fraction=1.0, 24 | fewshot_location="first", 25 | dataset_root_path: str = "datasets/", 26 | dataset_path: Optional[str] = None, 27 | use_frequency_token: bool = False, 28 | enable_padding: bool = True, 29 | seed: int = 42, 30 | **dataset_kwargs, 31 | ): 32 | LOGGER.info(f"Dataset name: {dataset_name}, context length: {context_length}, prediction length {forecast_length}") 33 | 34 | config_path = resources.files("tsfm_public.resources.data_config") 35 | names_to_config = {p.stem: p for p in config_path.iterdir() if p.suffix == ".yaml"} 36 | 37 | config_path = names_to_config.get(dataset_name, None) 38 | 39 | if config_path is None: 40 | raise ValueError( 41 | f"Currently the `load_dataset()` function supports the following datasets: {names_to_config.keys()}\n \ 42 | For other datasets, please provide the proper configs to the TimeSeriesPreprocessor (TSP) module." 
43 | ) 44 | 45 | config = yaml.safe_load(open(config_path, "r")) 46 | 47 | tsp = TimeSeriesPreprocessor( 48 | id_columns=config["id_columns"], 49 | timestamp_column=config["timestamp_column"], 50 | target_columns=config["target_columns"], 51 | observable_columns=config["observable_columns"], 52 | control_columns=config["control_columns"], 53 | conditional_columns=config["conditional_columns"], 54 | static_categorical_columns=config["static_categorical_columns"], 55 | scaling=config["scale"]["scaling"], 56 | scaler_type=config["scale"]["scaler_type"], 57 | encode_categorical=config["encode_categorical"], 58 | freq=config["freq"], 59 | context_length=context_length, 60 | prediction_length=forecast_length, 61 | ) 62 | 63 | split_config = config["split"] 64 | 65 | # if dataset_path is provided we will ignore the config file 66 | if dataset_path is None: 67 | dataset_path = Path(dataset_root_path) / config["data_path"] / config["data_file"] 68 | 69 | data = pd.read_csv( 70 | dataset_path, 71 | parse_dates=[config["timestamp_column"]], 72 | ) 73 | 74 | train_dataset, valid_dataset, test_dataset = get_datasets( 75 | tsp, 76 | data, 77 | split_config=split_config, 78 | fewshot_fraction=fewshot_fraction, 79 | fewshot_location=fewshot_location, 80 | use_frequency_token=use_frequency_token, 81 | enable_padding=enable_padding, 82 | seed=seed, 83 | **dataset_kwargs, 84 | ) 85 | LOGGER.info(f"Data lengths: train = {len(train_dataset)}, val = {len(valid_dataset)}, test = {len(test_dataset)}") 86 | 87 | return train_dataset, valid_dataset, test_dataset 88 | -------------------------------------------------------------------------------- /src/samay/models/TinyTimeMixer/version.py: -------------------------------------------------------------------------------- 1 | # Copyright contributors to the TSFM project 2 | # 3 | 4 | try: 5 | # Local 6 | from ._version import __version__, __version_tuple__ # noqa: F401 # unused import 7 | except ImportError: 8 | __version__ = "unknown" 9 | __version_tuple__ = (0, 0, __version__) 10 | -------------------------------------------------------------------------------- /src/samay/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/__init__.py -------------------------------------------------------------------------------- /src/samay/models/chronosforecasting/chronos/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # from samay.models.chronosforecasting.chronos.chronos import ( 5 | # ChronosConfig, 6 | # ChronosModel, 7 | # ChronosPipeline, 8 | # ChronosTokenizer, 9 | # MeanScaleUniformBins, 10 | # ) 11 | 12 | # __all__ = [ 13 | # "ChronosConfig", 14 | # "ChronosModel", 15 | # "ChronosPipeline", 16 | # "ChronosTokenizer", 17 | # "MeanScaleUniformBins", 18 | # ] 19 | -------------------------------------------------------------------------------- /src/samay/models/chronosforecasting/chronos/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | from typing import List 6 | 7 | import torch 8 | 9 | 10 | def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: 11 | max_len = max(len(c) for c in tensors) 12 | padded = [] 13 | for c in tensors: 14 | assert isinstance(c, torch.Tensor) 15 | assert c.ndim == 1 16 | padding = torch.full( 17 | size=(max_len - len(c),), fill_value=torch.nan, device=c.device 18 | ) 19 | padded.append(torch.concat((padding, c), dim=-1)) 20 | return torch.stack(padded) -------------------------------------------------------------------------------- /src/samay/models/chronosforecasting/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # from tsfmproject.models.chronosforecasting.scripts.finetune import ( 5 | # train_model, 6 | # ) 7 | 8 | # # import all the functions from the scripts/jsonlogger.py 9 | # from tsfmproject.models.chronosforecasting.scripts.jsonlogger import * 10 | 11 | 12 | # __all__ = [ "train_model", ] -------------------------------------------------------------------------------- /src/samay/models/chronosforecasting/scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | import torch 4 | from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM 5 | from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error 6 | from chronos import ChronosPipeline 7 | from jsonlogger import JsonFileHandler, JsonFormatter  # the sibling module is jsonlogger.py, not json_logger 8 | 9 | # Configure logging 10 | log_file = Path("evaluation_results.json") 11 | json_handler = JsonFileHandler(log_file) 12 | json_handler.setFormatter(JsonFormatter(log_type="default"))  # JsonFormatter requires a log_type argument 13 | 14 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | logger.addHandler(json_handler) 18 | 19 | def load_model(model_dir: str, model_type: str = "seq2seq", device: str = "cuda", **kwargs): 20 | model_class = AutoModelForSeq2SeqLM if model_type == "seq2seq" else AutoModelForCausalLM  # noqa: F841 (unused; ChronosPipeline resolves the model class itself) 21 | model = ChronosPipeline.from_pretrained(model_dir, model_type=model_type) 22 | return model 23 | 24 | def evaluate_model(model, fit_data, test_data, prediction_length, metrics, logger=None): 25 | predictions, targets = [], [] 26 | context = torch.tensor(fit_data) 27 | predictions = model.predict(context, prediction_length=prediction_length).squeeze().tolist() 28 | 29 | results = {} 30 | results['num_samples'] = len(predictions) 31 | results['predictions'] = predictions[0] 32 | 33 | if 'RMSE' in metrics: 34 | results['RMSE'] = mean_squared_error(test_data, predictions[0], squared=False) 35 | if 'MAPE' in metrics: 36 | results['MAPE'] = mean_absolute_percentage_error(test_data, predictions[0]) 37 | 38 | if logger is not None: logger.info(f"Evaluation results: {results}")  # logger defaults to None, so guard the call 39 | 40 | return results 41 | 42 | # Example usage 43 | if __name__ == "__main__": 44 | pass 45 | # model_dir = "./output/run-15/checkpoint-final" 46 | # model_type = "seq2seq" 47 | # metrics = ['RMSE', 'MAPE'] 48 | # model = load_model(model_dir, model_type) 49 | # logger.info(f"Model loaded from {model_dir}") 50 | 51 | # # Assuming data_train and data_test are already defined 52 | # column_id = column_list[0] 53 | # results = evaluate_model(model, data_train[column_id].values, data_test[column_id].values,
abs(offset), metrics, logger) 54 | # print(results) -------------------------------------------------------------------------------- /src/samay/models/chronosforecasting/scripts/jsonlogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | 4 | class JsonFileHandler(logging.Handler): 5 | def __init__(self, filename): 6 | super().__init__() 7 | self.filename = filename 8 | self.logs = [] 9 | 10 | def emit(self, record): 11 | log_entry = self.format(record) 12 | self.logs.append(json.loads(log_entry)) 13 | with open(self.filename, 'w') as f: 14 | json.dump(self.logs, f, indent=4) 15 | 16 | class JsonFormatter(logging.Formatter): 17 | def __init__(self, log_type): 18 | super().__init__() 19 | self.log_type = log_type 20 | 21 | def format(self, record): 22 | if self.log_type == "evaluation": 23 | log_record = { 24 | 'column_id': record.msg.get('column_id'), 25 | 'num_samples': record.msg.get('num_samples'), 26 | 'predictions': record.msg.get('predictions'), 27 | 'eval_results': record.msg.get('eval_results') 28 | } 29 | else: 30 | log_record = { 31 | 'time': self.formatTime(record, self.datefmt), 32 | 'name': record.name, 33 | 'level': record.levelname, 34 | 'message': record.getMessage() 35 | } 36 | return json.dumps(log_record) -------------------------------------------------------------------------------- /src/samay/models/lptm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/lptm/__init__.py -------------------------------------------------------------------------------- /src/samay/models/lptm/segment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/lptm/segment/__init__.py -------------------------------------------------------------------------------- /src/samay/models/lptm/segment/scoring.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ScoringModuleBase(nn.Module, ABC): 8 | def __init__(self, embed_size: int, hidden_size: int | None = None): 9 | super(ScoringModuleBase, self).__init__() 10 | self.embed_size = embed_size 11 | self.hidden_size = hidden_size or embed_size 12 | self.W1 = nn.Linear(self.embed_size, self.hidden_size) 13 | self.W2 = nn.Linear(self.embed_size, self.hidden_size) 14 | 15 | def forward( 16 | self, time_embeds: torch.Tensor, mask: torch.Tensor | None = None 17 | ) -> torch.Tensor: 18 | """ 19 | Args: 20 | time_embeds: torch.Tensor, shape (batch_size, seq_len, embed_size) 21 | mask: torch.Tensor, shape (batch_size, seq_len) 22 | Returns: 23 | torch.Tensor, shape (batch_size, seq_len, seq_len) 24 | """ 25 | if time_embeds.size(-1) != self.embed_size: 26 | time_embeds = torch.nn.functional.interpolate( 27 | time_embeds, size=self.embed_size, mode="linear" 28 | ) 29 | batch_size, seq_len, _ = time_embeds.size() 30 | # Compute the scores 31 | W1: torch.Tensor = self.W1(time_embeds) # (batch_size, seq_len, hidden_size) 32 | W2: torch.Tensor = self.W2(time_embeds) # (batch_size, seq_len, hidden_size) 33 | # Pairwise addition 34 | scores = self.compute_scores(W1, W2) # (batch_size, seq_len, seq_len) 35 | 36 | # Mask out the scores 37 | # if mask is not None: 38 
| # mask = mask.unsqueeze(1) 39 | # mask = mask.expand(batch_size, seq_len, seq_len) 40 | # scores = scores.masked_fill(mask, float("-inf")) 41 | 42 | return scores 43 | 44 | @abstractmethod 45 | def compute_scores(self, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor: 46 | raise NotImplementedError("Subclasses must implement compute_scores method") 47 | 48 | 49 | class ScoringModuleAddn(ScoringModuleBase): 50 | def __init__(self, embed_size: int, hidden_size: int | None = None): 51 | super(ScoringModuleAddn, self).__init__(embed_size, hidden_size) 52 | self.b = nn.Parameter(torch.zeros(self.hidden_size)) 53 | self.v = nn.Parameter(torch.zeros(self.hidden_size)) 54 | # Initialize weights 55 | nn.init.uniform_(self.b) 56 | nn.init.uniform_(self.v) 57 | 58 | def compute_scores(self, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor: 59 | scores = ( 60 | W1.unsqueeze(1) + W2.unsqueeze(2) + self.b[None, None, None, :] 61 | ) # (batch_size, seq_len, seq_len, hidden_size) 62 | scores = torch.tanh(scores) 63 | scores = scores @ self.v # (batch_size, seq_len, seq_len) 64 | return scores 65 | 66 | 67 | class ScoringModuleMult(ScoringModuleBase): 68 | def __init__(self, embed_size: int, hidden_size: int | None = None): 69 | super(ScoringModuleMult, self).__init__(embed_size, hidden_size) 70 | 71 | def compute_scores(self, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor: 72 | scores = torch.bmm(W1, W2.transpose(1, 2)) # (batch_size, seq_len, seq_len) 73 | return scores 74 | -------------------------------------------------------------------------------- /src/samay/models/lptm/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from argparse import Namespace 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class NamespaceWithDefaults(Namespace): 10 | @classmethod 11 | def from_namespace(cls, namespace): 12 | new_instance = cls() 13 | for attr in dir(namespace): 14 | if not attr.startswith("__"): 15 | setattr(new_instance, attr, getattr(namespace, attr)) 16 | return new_instance 17 | 18 | def getattr(self, key, default=None): 19 | return getattr(self, key, default) 20 | 21 | 22 | def parse_config(config: dict) -> NamespaceWithDefaults: 23 | args = NamespaceWithDefaults(**config) 24 | return args 25 | 26 | 27 | def make_dir_if_not_exists(path, verbose=True): 28 | if not is_directory(path): 29 | path = path.split(".")[0] 30 | if not os.path.exists(path=path): 31 | os.makedirs(path) 32 | if verbose: 33 | print(f"Making directory: {path}...") 34 | return True 35 | 36 | 37 | def is_directory(path): 38 | extensions = [".pth", ".txt", ".json", ".yaml"] 39 | 40 | for ext in extensions: 41 | if ext in path: 42 | return False 43 | return True 44 | 45 | 46 | def control_randomness(seed: int = 13): 47 | random.seed(seed) 48 | os.environ["PYTHONHASHSEED"] = str(seed) 49 | np.random.seed(seed) 50 | torch.manual_seed(seed) 51 | torch.cuda.manual_seed(seed) 52 | torch.backends.cudnn.deterministic = True 53 | torch.backends.cudnn.benchmark = False 54 | 55 | 56 | def dtype_map(dtype: str): 57 | map = { 58 | "float16": torch.float16, 59 | "float32": torch.float32, 60 | "float64": torch.float64, 61 | "bfloat16": torch.bfloat16, 62 | "uint8": torch.uint8, 63 | "int8": torch.int8, 64 | "int16": torch.int16, 65 | "int32": torch.int32, 66 | "int64": torch.int64, 67 | "bool": torch.bool, 68 | } 69 | return map[dtype] 70 | 71 | 72 | def get_huggingface_model_dimensions(model_name: str = "flan-t5-base"): 73 | from transformers import 
T5Config 74 | 75 | config = T5Config.from_pretrained(model_name) 76 | return config.d_model 77 | 78 | 79 | def get_anomaly_criterion(anomaly_criterion: str = "mse"): 80 | if anomaly_criterion == "mse": 81 | return torch.nn.MSELoss(reduction="none") 82 | elif anomaly_criterion == "mae": 83 | return torch.nn.L1Loss(reduction="none") 84 | else: 85 | raise ValueError(f"Anomaly criterion {anomaly_criterion} not supported.") 86 | 87 | 88 | def _reduce(metric, reduction="mean", axis=None): 89 | if reduction == "mean": 90 | return np.nanmean(metric, axis=axis) 91 | elif reduction == "sum": 92 | return np.nansum(metric, axis=axis) 93 | elif reduction == "none": 94 | return metric 95 | 96 | 97 | class EarlyStopping: 98 | def __init__(self, patience: int = 3, verbose: bool = False, delta: float = 0): 99 | self.patience = patience 100 | self.verbose = verbose 101 | self.counter = 0 102 | self.best_score = None 103 | self.early_stop = False 104 | self.val_loss_min = np.inf 105 | self.delta = delta 106 | 107 | def __call__(self, validation_loss): 108 | score = -validation_loss 109 | if self.best_score is None: 110 | self.best_score = score 111 | 112 | elif score < self.best_score + self.delta: 113 | self.counter += 1 114 | if self.verbose: 115 | print(f"EarlyStopping counter: {self.counter} out of {self.patience}") 116 | if self.counter >= self.patience: 117 | self.early_stop = True 118 | else: 119 | self.best_score = score 120 | self.counter = 0 121 | -------------------------------------------------------------------------------- /src/samay/models/moment/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Auton Lab, Carnegie Mellon University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
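[Editor's note] The `EarlyStopping` helper defined above in `src/samay/models/lptm/utils.py` is framework-agnostic: it is called once per epoch with a validation loss and flips `early_stop` after `patience` consecutive epochs without improvement (beyond `delta`). A minimal usage sketch, assuming the package is importable as `samay` and that `model` and `validate()` are hypothetical user-supplied stand-ins:

from samay.models.lptm.utils import EarlyStopping, control_randomness

control_randomness(seed=13)                 # fix Python/NumPy/Torch RNG state for reproducibility
stopper = EarlyStopping(patience=3, verbose=True)
for epoch in range(50):
    val_loss = validate(model)              # hypothetical routine returning a float validation loss
    stopper(val_loss)                       # internally tracks score = -val_loss; improvement resets the counter
    if stopper.early_stop:                  # True after `patience` non-improving epochs
        break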
-------------------------------------------------------------------------------- /src/samay/models/moment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/__init__.py -------------------------------------------------------------------------------- /src/samay/models/moment/assets/MOMENT Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/assets/MOMENT Logo.png -------------------------------------------------------------------------------- /src/samay/models/moment/assets/autonlab_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/assets/autonlab_logo.png -------------------------------------------------------------------------------- /src/samay/models/moment/assets/cmu_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/assets/cmu_logo.png -------------------------------------------------------------------------------- /src/samay/models/moment/assets/moment_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/assets/moment_architecture.png -------------------------------------------------------------------------------- /src/samay/models/moment/assets/moment_comparison .png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/assets/moment_comparison .png -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.moment import MOMENT, MOMENTPipeline 2 | -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/common.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class TASKS: 6 | RECONSTRUCTION: str = "reconstruction" 7 | FORECASTING: str = "forecasting" 8 | CLASSIFICATION: str = "classification" 9 | EMBED: str = "embedding" 10 | -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/momentfm/data/__init__.py -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/data/anomaly_detection_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.preprocessing import StandardScaler 4 | 5 | 6 | class AnomalyDetectionDataset: 7 | def 
__init__( 8 | self, 9 | data_split: str = "train", 10 | data_stride_len: int = 512, 11 | random_seed: int = 42, 12 | ): 13 | """ 14 | Parameters 15 | ---------- 16 | data_split : str 17 | Split of the dataset, 'train', or 'test' 18 | data_stride_len : int 19 | Stride length for the data. 20 | random_seed : int 21 | Random seed for reproducibility. 22 | """ 23 | 24 | self.full_file_path_and_name = ( 25 | "../data/198_UCR_Anomaly_tiltAPB2_50000_124159_124985.out" 26 | ) 27 | self.series = "198_UCR_Anomaly_tiltAPB2_50000_124159_124985" 28 | self.data_split = data_split 29 | self.data_stride_len = data_stride_len 30 | self.random_seed = random_seed 31 | self.seq_len = 512 32 | 33 | # Downsampling for experiments. Refer 34 | # https://github.com/mononitogoswami/tsad-model-selection for more details 35 | self.downsampling_factor = 10 36 | self.min_length = ( 37 | 2560 # Minimum length of time-series after downsampling for experiments 38 | ) 39 | 40 | # Read data 41 | self._read_data() 42 | 43 | def _get_borders(self): 44 | details = self.series.split("_") 45 | n_train = int(details[4]) 46 | train_end = n_train 47 | test_start = train_end 48 | 49 | return slice(0, train_end), slice(test_start, None) 50 | 51 | def _read_data(self): 52 | self.scaler = StandardScaler() 53 | df = pd.read_csv(self.full_file_path_and_name) 54 | df.interpolate(inplace=True, method="cubic") 55 | 56 | self.length_timeseries = len(df) 57 | self.n_channels = 1 58 | labels = df.iloc[:, -1].values 59 | timeseries = df.iloc[:, 0].values.reshape(-1, 1) 60 | 61 | data_splits = self._get_borders() 62 | 63 | self.scaler.fit(timeseries[data_splits[0]]) 64 | timeseries = self.scaler.transform(timeseries) 65 | timeseries = timeseries.squeeze() 66 | 67 | if self.data_split == "train": 68 | self.data, self.labels = timeseries[data_splits[0]], labels[data_splits[0]] 69 | elif self.data_split == "test": 70 | self.data, self.labels = timeseries[data_splits[1]], labels[data_splits[1]] 71 | 72 | self.length_timeseries = self.data.shape[0] 73 | 74 | def __getitem__(self, index): 75 | seq_start = self.data_stride_len * index 76 | seq_end = seq_start + self.seq_len 77 | input_mask = np.ones(self.seq_len) 78 | 79 | if seq_end > self.length_timeseries: 80 | seq_start = self.length_timeseries - self.seq_len 81 | seq_end = None 82 | 83 | timeseries = self.data[seq_start:seq_end].reshape( 84 | (self.n_channels, self.seq_len) 85 | ) 86 | labels = ( 87 | self.labels[seq_start:seq_end] 88 | .astype(int) 89 | .reshape((self.n_channels, self.seq_len)) 90 | ) 91 | 92 | return timeseries, input_mask, labels 93 | 94 | def __len__(self): 95 | return (self.length_timeseries // self.data_stride_len) + 1 96 | -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/data/classification_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.preprocessing import StandardScaler 3 | 4 | from moment.momentfm.utils.data import load_from_tsfile 5 | 6 | 7 | class ClassificationDataset: 8 | def __init__(self, data_split="train"): 9 | """ 10 | Parameters 11 | ---------- 12 | data_split : str 13 | Split of the dataset, 'train', 'val' or 'test'. 
14 | """ 15 | 16 | self.seq_len = 512 17 | self.train_file_path_and_name = "../data/ECG5000_TRAIN.ts" 18 | self.test_file_path_and_name = "../data/ECG5000_TEST.ts" 19 | self.data_split = data_split # 'train' or 'test' 20 | 21 | # Read data 22 | self._read_data() 23 | 24 | def _transform_labels(self, train_labels: np.ndarray, test_labels: np.ndarray): 25 | labels = np.unique(train_labels) # Move the labels to {0, ..., L-1} 26 | transform = {} 27 | for i, l in enumerate(labels): 28 | transform[l] = i 29 | 30 | train_labels = np.vectorize(transform.get)(train_labels) 31 | test_labels = np.vectorize(transform.get)(test_labels) 32 | 33 | return train_labels, test_labels 34 | 35 | def __len__(self): 36 | return self.num_timeseries 37 | 38 | def _read_data(self): 39 | self.scaler = StandardScaler() 40 | 41 | self.train_data, self.train_labels = load_from_tsfile( 42 | self.train_file_path_and_name 43 | ) 44 | self.test_data, self.test_labels = load_from_tsfile( 45 | self.test_file_path_and_name 46 | ) 47 | 48 | self.train_labels, self.test_labels = self._transform_labels( 49 | self.train_labels, self.test_labels 50 | ) 51 | 52 | if self.data_split == "train": 53 | self.data = self.train_data 54 | self.labels = self.train_labels 55 | else: 56 | self.data = self.test_data 57 | self.labels = self.test_labels 58 | 59 | self.num_timeseries = self.data.shape[0] 60 | self.len_timeseries = self.data.shape[2] 61 | 62 | self.data = self.data.reshape(-1, self.len_timeseries) 63 | self.scaler.fit(self.data) 64 | self.data = self.scaler.transform(self.data) 65 | self.data = self.data.reshape(self.num_timeseries, self.len_timeseries) 66 | 67 | self.data = self.data.T 68 | 69 | def __getitem__(self, index): 70 | assert index < self.__len__() 71 | 72 | timeseries = self.data[:, index] 73 | timeseries_len = len(timeseries) 74 | labels = self.labels[index,].astype(int) 75 | input_mask = np.ones(self.seq_len) 76 | input_mask[: self.seq_len - timeseries_len] = 0 77 | 78 | timeseries = np.pad(timeseries, (self.seq_len - timeseries_len, 0)) 79 | 80 | return np.expand_dims(timeseries, axis=0), input_mask, labels 81 | -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/dataclass/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy.typing as npt 4 | 5 | 6 | @dataclass 7 | class TimeseriesOutputs: 8 | forecast: npt.NDArray = None 9 | anomaly_scores: npt.NDArray = None 10 | logits: npt.NDArray = None 11 | labels: int = None 12 | input_mask: npt.NDArray = None 13 | pretrain_mask: npt.NDArray = None 14 | reconstruction: npt.NDArray = None 15 | embeddings: npt.NDArray = None 16 | metadata: dict = None 17 | illegal_output: bool = False -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/momentfm/models/__init__.py -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/models/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/momentfm/models/layers/__init__.py 
-------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/models/layers/revin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def nanvar(tensor, dim=None, keepdim=False): 6 | tensor_mean = tensor.nanmean(dim=dim, keepdim=True) 7 | output = (tensor - tensor_mean).square().nanmean(dim=dim, keepdim=keepdim) 8 | return output 9 | 10 | 11 | def nanstd(tensor, dim=None, keepdim=False): 12 | output = nanvar(tensor, dim=dim, keepdim=keepdim) 13 | output = output.sqrt() 14 | return output 15 | 16 | 17 | class RevIN(nn.Module): 18 | def __init__(self, num_features: int, eps: float = 1e-5, affine: bool = False): 19 | """ 20 | :param num_features: the number of features or channels 21 | :param eps: a value added for numerical stability 22 | :param affine: if True, RevIN has learnable affine parameters 23 | """ 24 | super(RevIN, self).__init__() 25 | self.num_features = num_features 26 | self.eps = eps 27 | self.affine = affine 28 | 29 | if self.affine: 30 | self._init_params() 31 | 32 | def forward(self, x: torch.Tensor, mode: str = "norm", mask: torch.Tensor = None): 33 | """ 34 | :param x: input tensor of shape (batch_size, n_channels, seq_len) 35 | :param mode: 'norm' or 'denorm' 36 | :param mask: input mask of shape (batch_size, seq_len) 37 | :return: RevIN transformed tensor 38 | """ 39 | if mode == "norm": 40 | self._get_statistics(x, mask=mask) 41 | x = self._normalize(x) 42 | elif mode == "denorm": 43 | x = self._denormalize(x) 44 | else: 45 | raise NotImplementedError 46 | return x 47 | 48 | def _init_params(self): 49 | # initialize RevIN params: (C,) 50 | self.affine_weight = nn.Parameter(torch.ones(1, self.num_features, 1)) 51 | self.affine_bias = nn.Parameter(torch.zeros(1, self.num_features, 1)) 52 | 53 | def _get_statistics(self, x, mask=None): 54 | """ 55 | x : batch_size x n_channels x seq_len 56 | mask : batch_size x seq_len 57 | """ 58 | if mask is None: 59 | mask = torch.ones((x.shape[0], x.shape[-1]), device=x.device)  # build the default mask on x's device so torch.where below does not hit a CPU/GPU mismatch 60 | n_channels = x.shape[1] 61 | mask = mask.unsqueeze(1).repeat(1, n_channels, 1).bool() 62 | # Set masked positions to NaN, and unmasked positions are taken from x 63 | masked_x = torch.where(mask, x, torch.nan) 64 | self.mean = torch.nanmean(masked_x, dim=-1, keepdim=True).detach() 65 | self.stdev = nanstd(masked_x, dim=-1, keepdim=True).detach() + self.eps 66 | # self.stdev = torch.sqrt( 67 | # torch.var(masked_x, dim=-1, keepdim=True) + self.eps).get_data().detach() 68 | # NOTE: no Bessel correction is applied by default 69 | 70 | def _normalize(self, x): 71 | x = x - self.mean 72 | x = x / self.stdev 73 | 74 | if self.affine: 75 | x = x * self.affine_weight 76 | x = x + self.affine_bias 77 | return x 78 | 79 | def _denormalize(self, x): 80 | if self.affine: 81 | x = x - self.affine_bias 82 | x = x / (self.affine_weight + self.eps * self.eps) 83 | x = x * self.stdev 84 | x = x + self.mean 85 | return x 86 | -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/models/statistical_classifiers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.typing as npt 3 | from sklearn.model_selection import GridSearchCV, train_test_split 4 | from sklearn.svm import SVC 5 | 6 | 7 | def fit_svm(features: npt.NDArray, y: npt.NDArray, MAX_SAMPLES: int = 10000): 8 | nb_classes = np.unique(y, return_counts=True)[1].shape[0] 9 |
train_size = features.shape[0] 10 | 11 | svm = SVC(C=100000, gamma="scale") 12 | if train_size // nb_classes < 5 or train_size < 50: 13 | return svm.fit(features, y) 14 | else: 15 | grid_search = GridSearchCV( 16 | svm, 17 | { 18 | "C": [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000], 19 | "kernel": ["rbf"], 20 | "degree": [3], 21 | "gamma": ["scale"], 22 | "coef0": [0], 23 | "shrinking": [True], 24 | "probability": [False], 25 | "tol": [0.001], 26 | "cache_size": [200], 27 | "class_weight": [None], 28 | "verbose": [False], 29 | "max_iter": [10000000], 30 | "decision_function_shape": ["ovr"], 31 | }, 32 | cv=5, 33 | n_jobs=10, 34 | ) 35 | # If the training set is too large, subsample MAX_SAMPLES examples 36 | if train_size > MAX_SAMPLES: 37 | split = train_test_split( 38 | features, y, train_size=MAX_SAMPLES, random_state=0, stratify=y 39 | ) 40 | features = split[0] 41 | y = split[2] 42 | 43 | grid_search.fit(features, y) 44 | return grid_search.best_estimator_ 45 | -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/moment/momentfm/utils/__init__.py -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/utils/anomaly_detection_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def f1_score(predict, actual): 5 | TP = np.sum(predict * actual) 6 | TN = np.sum((1 - predict) * (1 - actual)) 7 | FP = np.sum(predict * (1 - actual)) 8 | FN = np.sum((1 - predict) * actual) 9 | precision = TP / (TP + FP + 0.00001) 10 | recall = TP / (TP + FN + 0.00001) 11 | f1 = 2 * precision * recall / (precision + recall + 0.00001) 12 | return f1 13 | 14 | 15 | def adjust_predicts(score, label, threshold=None, pred=None, calc_latency=False): 16 | """ 17 | Calculate adjusted predict labels using given `score`, `threshold` (or given `pred`) and `label`. 18 | Args: 19 | score (np.ndarray): The anomaly score 20 | label (np.ndarray): The ground-truth label 21 | threshold (float): The threshold of anomaly score. 22 | A point is labeled as "anomaly" if its score is lower than the threshold. 
23 | pred (np.ndarray or None): if not None, adjust `pred` and ignore `score` and `threshold`, 24 | calc_latency (bool): 25 | Returns: 26 | np.ndarray: predict labels 27 | """ 28 | if len(score) != len(label): 29 | raise ValueError("score and label must have the same length") 30 | score = np.asarray(score) 31 | label = np.asarray(label) 32 | latency = 0 33 | if pred is None: 34 | predict = score < threshold 35 | else: 36 | predict = pred 37 | actual = label > 0.1 38 | anomaly_state = False 39 | anomaly_count = 0 40 | for i in range(len(score)): 41 | if actual[i] and predict[i] and not anomaly_state: 42 | anomaly_state = True 43 | anomaly_count += 1 44 | for j in range(i, 0, -1): 45 | if not actual[j]: 46 | break 47 | else: 48 | if not predict[j]: 49 | predict[j] = True 50 | latency += 1 51 | elif not actual[i]: 52 | anomaly_state = False 53 | if anomaly_state: 54 | predict[i] = True 55 | if calc_latency: 56 | return predict, latency / (anomaly_count + 1e-4) 57 | else: 58 | return predict 59 | 60 | 61 | def adjbestf1(y_true: np.array, y_scores: np.array, n_splits: int = 100): 62 | thresholds = np.linspace(y_scores.min(), y_scores.max(), n_splits) 63 | adjusted_f1 = np.zeros(thresholds.shape) 64 | 65 | for i, threshold in enumerate(thresholds): 66 | y_pred = y_scores >= threshold 67 | y_pred = adjust_predicts( 68 | score=y_scores, 69 | label=(y_true > 0), 70 | pred=y_pred, 71 | threshold=None, 72 | calc_latency=False, 73 | ) 74 | adjusted_f1[i] = f1_score(y_pred, y_true) 75 | 76 | best_adjusted_f1 = np.max(adjusted_f1) 77 | return best_adjusted_f1 78 | -------------------------------------------------------------------------------- /src/samay/models/moment/momentfm/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from argparse import Namespace 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class NamespaceWithDefaults(Namespace): 10 | @classmethod 11 | def from_namespace(cls, namespace): 12 | new_instance = cls() 13 | for attr in dir(namespace): 14 | if not attr.startswith("__"): 15 | setattr(new_instance, attr, getattr(namespace, attr)) 16 | return new_instance 17 | 18 | def getattr(self, key, default=None): 19 | return getattr(self, key, default) 20 | 21 | 22 | def parse_config(config: dict) -> NamespaceWithDefaults: 23 | args = NamespaceWithDefaults(**config) 24 | return args 25 | 26 | 27 | def make_dir_if_not_exists(path, verbose=True): 28 | if not is_directory(path): 29 | path = path.split(".")[0] 30 | if not os.path.exists(path=path): 31 | os.makedirs(path) 32 | if verbose: 33 | print(f"Making directory: {path}...") 34 | return True 35 | 36 | 37 | def is_directory(path): 38 | extensions = [".pth", ".txt", ".json", ".yaml"] 39 | 40 | for ext in extensions: 41 | if ext in path: 42 | return False 43 | return True 44 | 45 | 46 | def control_randomness(seed: int = 13): 47 | random.seed(seed) 48 | os.environ["PYTHONHASHSEED"] = str(seed) 49 | np.random.seed(seed) 50 | torch.manual_seed(seed) 51 | torch.cuda.manual_seed(seed) 52 | torch.backends.cudnn.deterministic = True 53 | torch.backends.cudnn.benchmark = False 54 | 55 | 56 | def dtype_map(dtype: str): 57 | map = { 58 | "float16": torch.float16, 59 | "float32": torch.float32, 60 | "float64": torch.float64, 61 | "bfloat16": torch.bfloat16, 62 | "uint8": torch.uint8, 63 | "int8": torch.int8, 64 | "int16": torch.int16, 65 | "int32": torch.int32, 66 | "int64": torch.int64, 67 | "bool": torch.bool, 68 | } 69 | return map[dtype] 70 | 71 | 72 | def 
get_huggingface_model_dimensions(model_name: str = "flan-t5-base"): 73 | from transformers import T5Config 74 | 75 | config = T5Config.from_pretrained(model_name) 76 | return config.d_model 77 | 78 | 79 | def get_anomaly_criterion(anomaly_criterion: str = "mse"): 80 | if anomaly_criterion == "mse": 81 | return torch.nn.MSELoss(reduction="none") 82 | elif anomaly_criterion == "mae": 83 | return torch.nn.L1Loss(reduction="none") 84 | else: 85 | raise ValueError(f"Anomaly criterion {anomaly_criterion} not supported.") 86 | 87 | 88 | def _reduce(metric, reduction="mean", axis=None): 89 | if reduction == "mean": 90 | return np.nanmean(metric, axis=axis) 91 | elif reduction == "sum": 92 | return np.nansum(metric, axis=axis) 93 | elif reduction == "none": 94 | return metric 95 | 96 | 97 | class EarlyStopping: 98 | def __init__(self, patience: int = 3, verbose: bool = False, delta: float = 0): 99 | self.patience = patience 100 | self.verbose = verbose 101 | self.counter = 0 102 | self.best_score = None 103 | self.early_stop = False 104 | self.val_loss_min = np.inf 105 | self.delta = delta 106 | 107 | def __call__(self, validation_loss): 108 | score = -validation_loss 109 | if self.best_score is None: 110 | self.best_score = score 111 | 112 | elif score < self.best_score + self.delta: 113 | self.counter += 1 114 | if self.verbose: 115 | print(f"EarlyStopping counter: {self.counter} out of {self.patience}") 116 | if self.counter >= self.patience: 117 | self.early_stop = True 118 | else: 119 | self.best_score = score 120 | self.counter = 0 121 | -------------------------------------------------------------------------------- /src/samay/models/moment/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "momentfm" 7 | version = "0.1.2" 8 | dependencies = [ 9 | "huggingface-hub==0.24.0", 10 | "numpy==1.25.2", 11 | "torch~=2.0", # package was tested on 2.0.1 12 | "transformers==4.33.3", 13 | ] 14 | 15 | requires-python = ">=3.10" 16 | 17 | authors = [ 18 | {name = "Mononito Goswami", email = "mononitog@hotmail.com"}, 19 | {name = "Konrad Szafer", email = ""}, 20 | {name = "Arjun Choudhry"}, 21 | {name = "Yifu Cai"}, 22 | ] 23 | 24 | maintainers = [ 25 | {name = "Mononito Goswami", email = "mononitog@hotmail.com"}, 26 | {name = "Konrad Szafer", email = "szafer.konrad@gmail.com"}, 27 | {name = "Yifu Cai", email = "yifuc@andrew.cmu.edu"}, 28 | ] 29 | 30 | description = "MOMENT: A Family of Open Time Series Foundation Models" 31 | readme = "README.md" 32 | license = {file = "LICENSE"} 33 | 34 | keywords = [ 35 | "time series", 36 | "forecasting", 37 | "classification", 38 | "imputation", 39 | "anomaly detection", 40 | "transformers", 41 | "pytorch", 42 | "huggingface", 43 | "moment", 44 | "foundation models", 45 | "large language models", 46 | ] 47 | 48 | classifiers = [ 49 | "Programming Language :: Python" 50 | ] 51 | 52 | [project.urls] 53 | Homepage = "https://moment-timeseries-foundation-model.github.io/" 54 | Repository = "https://github.com/moment-timeseries-foundation-model/moment" 55 | "Bug Tracker" = "https://github.com/moment-timeseries-foundation-model/moment/issues" 56 | -------------------------------------------------------------------------------- /src/samay/models/moment/requirements.txt: -------------------------------------------------------------------------------- 1 | huggingface-hub==0.24.0 2 | numpy==1.25.2 3 | torch==2.0.1 4 | 
transformers==4.33.3 5 | -------------------------------------------------------------------------------- /src/samay/models/moment/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # read the contents of your README file 4 | from pathlib import Path 5 | this_directory = Path(__file__).parent 6 | long_description = (this_directory / "README.md").read_text() 7 | 8 | with open("requirements.txt") as f: 9 | required = f.read().splitlines() 10 | 11 | setup( 12 | name="momentfm", 13 | version="0.1.2", 14 | description="MOMENT: A Family of Open Time-series Foundation Models", 15 | author="Mononito Goswami, Konrad Szafer, Arjun Choudhry, Yifu Cai, Shuo Li, Artur Dubrawski", 16 | author_email="mgoswami@andrew.cmu.edu", 17 | license="MIT", 18 | url="https://moment-timeseries-foundation-model.github.io/", 19 | zip_safe=False, 20 | packages=find_packages(exclude=["data", "tutorials"]), 21 | install_requires=required, 22 | long_description=long_description, 23 | long_description_content_type='text/markdown' 24 | ) 25 | -------------------------------------------------------------------------------- /src/samay/models/moment/tutorials/finetune_demo/classification.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4,5 2 | 3 | # use this for full finetuning 4 | accelerate launch --config_file tutorials/finetune_demo/ds.yaml \ 5 | tutorials/finetune_demo/classification.py \ 6 | --base_path path to your ptbxl base folder \ 7 | --cache_dir path to cache directory for preprocessed dataset \ 8 | --mode full_finetuning \ 9 | --output_path path to store train log and checkpoint \ 10 | 11 | # #use this for linear_probing, svm, unsupervised_representation_learning 12 | python3 tutorials/finetune_demo/classification.py \ 13 | --base_path path to your ptbxl base folder \ 14 | --cache_dir path to cache directory for preprocessed dataset \ 15 | --mode linear_probing \ 16 | --output_path path to store train log and checkpoint \ -------------------------------------------------------------------------------- /src/samay/models/moment/tutorials/finetune_demo/ds.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: true 3 | deepspeed_config: 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero_stage: 2 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | machine_rank: 0 11 | main_training_function: main 12 | mixed_precision: bf16 13 | num_machines: 1 14 | num_processes: 2 15 | rdzv_backend: static 16 | same_network: true 17 | use_cpu: false -------------------------------------------------------------------------------- /src/samay/models/timesfm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/models/timesfm/__init__.py -------------------------------------------------------------------------------- /src/samay/models/timesfm/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """adapter init file.""" 16 | 17 | from .dora_layers import DoraAttentionProjection, DoraCombinedQKVProjection, DoraLinear 18 | from .lora_layers import LoraAttentionProjection, LoraCombinedQKVProjection, LoraLinear 19 | -------------------------------------------------------------------------------- /src/samay/models/timesfm/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="timesfm", 5 | version="1.0.1", 6 | packages=find_packages(), 7 | install_requires=[], 8 | ) 9 | -------------------------------------------------------------------------------- /src/samay/models/timesfm/timesfm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # """TimesFM init file.""" 15 | # print( 16 | # "TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs." 
17 | # ) 18 | from samay.models.timesfm.timesfm.timesfm_base import (TimesFmCheckpoint, 19 | TimesFmHparams, 20 | TimesFmBase, 21 | freq_map, 22 | ) 23 | # try: 24 | # print("Loaded Jax TimesFM.") 25 | # from timesfm.src.timesfm.timesfm_jax import TimesFmJax as TimesFm 26 | # from timesfm.src.timesfm import data_loader 27 | # except Exception as _: 28 | # print("Loaded PyTorch TimesFM.") 29 | from samay.models.timesfm.timesfm.timesfm_torch import TimesFmTorch as TimesFm 30 | -------------------------------------------------------------------------------- /src/samay/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/samay/py.typed -------------------------------------------------------------------------------- /src/samay/trial.py: -------------------------------------------------------------------------------- 1 | from .dataset import MomentDataset 2 | from .model import MomentModel 3 | 4 | 5 | def main(): 6 | # dataset = TimesfmDataset(name="tycho", path='/nethome/sli999/data/Tycho/timesfm_US_covid_pivot.csv') 7 | dataset = MomentDataset( 8 | name="ett", 9 | datetime_col="date", 10 | path="/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv", 11 | ) 12 | 13 | # repo = "google/timesfm-1.0-200m-pytorch" 14 | repo = "AutonLab/MOMENT-1-large" 15 | # config = { 16 | # "context_len": 128, 17 | # "horizon_len": 32, 18 | # "backend": "gpu", 19 | # "per_core_batch_size": 32, 20 | # "input_patch_len": 32, 21 | # "output_patch_len": 128, 22 | # "num_layers": 20, 23 | # "model_dims": 1280, 24 | # "quantiles": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], 25 | # } 26 | config = { 27 | "task_name": "forecasting", 28 | "forecast_horizon": 192, 29 | "head_dropout": 0.1, 30 | "weight_decay": 0, 31 | "freeze_encoder": True, # Freeze the patch embedding layer 32 | "freeze_embedder": True, # Freeze the transformer encoder 33 | "freeze_head": False, # The linear forecasting head must be trained 34 | } 35 | 36 | # tfm = TimesfmModel(config=config, repo=repo) 37 | mmt = MomentModel(config=config, repo=repo) 38 | # finetuned_model = tfm.finetune(dataset) 39 | finetuned_model = mmt.finetune(dataset) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /src/uni2ts/__about__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | __version__ = "1.2.0" 17 | -------------------------------------------------------------------------------- /src/uni2ts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 
2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from .__about__ import __version__ 16 | -------------------------------------------------------------------------------- /src/uni2ts/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /src/uni2ts/cli/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/data/etth1_test.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.eval_util.data.get_custom_eval_dataset 2 | dataset_name: ETTh1_eval 3 | offset: 14400 4 | windows: 2785 5 | distance: 1 6 | prediction_length: 96 7 | mode: null -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/data/etth1_val.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.eval_util.data.get_custom_eval_dataset 2 | dataset_name: ETTh1_eval 3 | offset: 11520 4 | windows: 2785 5 | distance: 1 6 | prediction_length: 96 7 | mode: null -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/data/gluonts_test.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.eval_util.data.get_gluonts_test_dataset 2 | dataset_name: ??? 
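# (Editor's note, kept as YAML comments so this file stays parseable: `???` is
# OmegaConf's mandatory-value marker. Hydra raises MissingMandatoryValue unless
# the field is supplied at composition time, e.g. an override such as
# `data.dataset_name=electricity` on the eval command line.)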
3 | prediction_length: null 4 | mode: S -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/data/gluonts_val.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.eval_util.data.get_gluonts_val_dataset 2 | dataset_name: ??? 3 | prediction_length: null 4 | mode: S -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/data/lsf_test.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.eval_util.data.get_lsf_test_dataset 2 | dataset_name: ??? 3 | prediction_length: ??? 4 | mode: S -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/data/lsf_val.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.eval_util.data.get_lsf_val_dataset 2 | dataset_name: ??? 3 | prediction_length: ??? 4 | mode: S -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/data/monash.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.eval_util.data.get_gluonts_test_dataset 2 | dataset_name: ??? 3 | prediction_length: null 4 | mode: S -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/default.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: outputs/${hydra:job.name}/${hydra:runtime.choices.data}/${data.dataset_name}/${data.mode}/prediction_length=${data.prediction_length}/${run_name} 4 | defaults: 5 | - model: ??? 6 | - data: ??? 7 | - _self_ 8 | run_name: ??? 9 | metrics: 10 | - _target_: gluonts.ev.metrics.MSE 11 | - _target_: uni2ts.eval_util.metrics.MedianMSE 12 | - _target_: gluonts.ev.metrics.MAE 13 | - _target_: gluonts.ev.metrics.MASE 14 | - _target_: gluonts.ev.metrics.MAPE 15 | - _target_: gluonts.ev.metrics.SMAPE 16 | - _target_: gluonts.ev.metrics.MSIS 17 | - _target_: gluonts.ev.metrics.RMSE 18 | - _target_: gluonts.ev.metrics.NRMSE 19 | - _target_: gluonts.ev.metrics.ND 20 | - _target_: gluonts.ev.metrics.MeanWeightedSumQuantileLoss 21 | quantile_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 22 | batch_size: 512 23 | min_batch_size: 1 24 | device: auto -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/model/moirai_1.0_R_base.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiForecast 2 | module: 3 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 4 | pretrained_model_name_or_path: Salesforce/moirai-1.0-R-base 5 | num_samples: 100 6 | patch_size: ??? 7 | context_length: ??? -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/model/moirai_1.0_R_large.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiForecast 2 | module: 3 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 4 | pretrained_model_name_or_path: Salesforce/moirai-1.0-R-large 5 | num_samples: 100 6 | patch_size: ??? 7 | context_length: ??? 
-------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/model/moirai_1.0_R_small.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiForecast 2 | module: 3 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 4 | pretrained_model_name_or_path: Salesforce/moirai-1.0-R-small 5 | num_samples: 100 6 | patch_size: ??? 7 | context_length: ??? -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/model/moirai_1.1_R_base.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiForecast 2 | module: 3 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 4 | pretrained_model_name_or_path: Salesforce/moirai-1.1-R-base 5 | num_samples: 100 6 | patch_size: ??? 7 | context_length: ??? -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/model/moirai_1.1_R_large.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiForecast 2 | module: 3 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 4 | pretrained_model_name_or_path: Salesforce/moirai-1.1-R-large 5 | num_samples: 100 6 | patch_size: ??? 7 | context_length: ??? -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/model/moirai_1.1_R_small.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiForecast 2 | module: 3 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 4 | pretrained_model_name_or_path: Salesforce/moirai-1.1-R-small 5 | num_samples: 100 6 | patch_size: ??? 7 | context_length: ??? -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/model/moirai_lightning_ckpt.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiForecast.load_from_checkpoint 2 | checkpoint_path: ... 3 | num_samples: 100 4 | patch_size: ??? 5 | context_length: ??? 6 | -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/model/moirai_moe_1.0_R_base.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai_moe.MoiraiMoEForecast 2 | module: 3 | _target_: uni2ts.model.moirai_moe.MoiraiMoEModule.from_pretrained 4 | pretrained_model_name_or_path: Salesforce/moirai-moe-1.0-R-base 5 | num_samples: 100 6 | patch_size: 16 7 | context_length: ??? -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/eval/model/moirai_moe_1.0_R_small.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai_moe.MoiraiMoEForecast 2 | module: 3 | _target_: uni2ts.model.moirai_moe.MoiraiMoEModule.from_pretrained 4 | pretrained_model_name_or_path: Salesforce/moirai-moe-1.0-R-small 5 | num_samples: 100 6 | patch_size: 16 7 | context_length: ??? 
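The `???` values in the eval model configs above are OmegaConf's mandatory-missing markers: Hydra refuses to run until they are overridden, usually on the command line. As a minimal sketch of how one of these configs becomes a model object, the snippet below mirrors the `call(cfg.model, _partial_=True, _convert_="all")` pattern used by `cli/eval.py` later in this tree; the patch size, context length, and metadata values are illustrative only, and instantiating the module downloads weights from the Hugging Face hub.

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Same structure as moirai_1.0_R_small.yaml, with the ??? fields filled in.
cfg = OmegaConf.create(
    {
        "_target_": "uni2ts.model.moirai.MoiraiForecast",
        "module": {
            "_target_": "uni2ts.model.moirai.MoiraiModule.from_pretrained",
            "pretrained_model_name_or_path": "Salesforce/moirai-1.0-R-small",
        },
        "num_samples": 100,
        "patch_size": 32,  # was ???
        "context_length": 1000,  # was ???
    }
)

# Instantiate partially, then supply the dataset-dependent arguments that
# cli/eval.py reads from the test-set metadata.
model = instantiate(cfg, _partial_=True, _convert_="all")(
    prediction_length=96,
    target_dim=1,
    feat_dynamic_real_dim=0,
    past_feat_dynamic_real_dim=0,
)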
-------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/data/etth1.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.simple.SimpleDatasetBuilder 2 | dataset: ETTh1 3 | weight: 1000 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/default.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: outputs/finetune/${hydra:runtime.choices.model}/${hydra:runtime.choices.data}/${run_name} 4 | defaults: 5 | - model: ??? 6 | - data: ??? 7 | - val_data: null 8 | - _self_ 9 | run_name: ??? 10 | seed: 0 11 | tf32: true 12 | compile: false # set to mode: default, reduce-overhead, max-autotune 13 | ckpt_path: null 14 | trainer: 15 | _target_: lightning.Trainer 16 | accelerator: auto 17 | strategy: auto 18 | devices: auto 19 | num_nodes: 1 20 | precision: 32 21 | logger: 22 | _target_: lightning.pytorch.loggers.TensorBoardLogger 23 | save_dir: ${hydra:runtime.output_dir} 24 | name: logs 25 | callbacks: 26 | - _target_: lightning.pytorch.callbacks.LearningRateMonitor 27 | logging_interval: epoch 28 | - _target_: lightning.pytorch.callbacks.ModelCheckpoint 29 | dirpath: ${hydra:runtime.output_dir}/checkpoints 30 | monitor: val/PackedNLLLoss 31 | save_weights_only: true 32 | mode: min 33 | save_top_k: 1 34 | every_n_epochs: 1 35 | - _target_: lightning.pytorch.callbacks.EarlyStopping 36 | monitor: val/PackedNLLLoss 37 | min_delta: 0.0 38 | patience: 3 39 | mode: min 40 | strict: false 41 | verbose: true 42 | max_epochs: 100 43 | enable_progress_bar: true 44 | accumulate_grad_batches: 1 45 | gradient_clip_val: 1.0 46 | gradient_clip_algorithm: norm 47 | train_dataloader: 48 | _target_: uni2ts.data.loader.DataLoader 49 | batch_size: 128 50 | batch_size_factor: 2.0 51 | cycle: true 52 | num_batches_per_epoch: 100 53 | shuffle: true 54 | num_workers: 11 55 | collate_fn: 56 | _target_: uni2ts.data.loader.PackCollate 57 | max_length: ${model.module_kwargs.max_seq_len} 58 | seq_fields: ${cls_getattr:${model._target_},seq_fields} 59 | pad_func_map: ${cls_getattr:${model._target_},pad_func_map} 60 | pin_memory: true 61 | drop_last: false 62 | fill_last: false 63 | worker_init_fn: null 64 | prefetch_factor: 2 65 | persistent_workers: true 66 | val_dataloader: 67 | _target_: uni2ts.data.loader.DataLoader 68 | batch_size: 128 69 | batch_size_factor: 2.0 70 | cycle: false 71 | num_batches_per_epoch: null 72 | shuffle: false 73 | num_workers: 11 74 | collate_fn: 75 | _target_: uni2ts.data.loader.PackCollate 76 | max_length: ${model.module_kwargs.max_seq_len} 77 | seq_fields: ${cls_getattr:${model._target_},seq_fields} 78 | pad_func_map: ${cls_getattr:${model._target_},pad_func_map} 79 | pin_memory: false 80 | drop_last: false 81 | fill_last: true 82 | worker_init_fn: null 83 | prefetch_factor: 2 84 | persistent_workers: true -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_1.0_R_base.yaml: -------------------------------------------------------------------------------- 1 | # load a pretrained checkpoint from huggingface hub 2 | _target_: uni2ts.model.moirai.MoiraiFinetune 3 | module: 4 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 5 | pretrained_model_name_or_path: Salesforce/moirai-1.0-R-base 6 | module_kwargs: 7 | _target_: builtins.dict 8 | distr_output: 9 | _target_: 
uni2ts.distribution.MixtureOutput 10 | components: 11 | - _target_: uni2ts.distribution.StudentTOutput 12 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 13 | - _target_: uni2ts.distribution.NegativeBinomialOutput 14 | - _target_: uni2ts.distribution.LogNormalOutput 15 | d_model: 768 16 | num_layers: 12 17 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 18 | max_seq_len: 512 19 | attn_dropout_p: 0.0 20 | dropout_p: 0.0 21 | scaling: true 22 | min_patches: 2 23 | min_mask_ratio: 0.15 24 | max_mask_ratio: 0.5 25 | max_dim: 128 26 | loss_func: 27 | _target_: uni2ts.loss.packed.PackedNLLLoss 28 | lr: 1e-3 29 | weight_decay: 1e-1 30 | beta1: 0.9 31 | beta2: 0.98 32 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 33 | num_warmup_steps: 0 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_1.0_R_large.yaml: -------------------------------------------------------------------------------- 1 | # load a pretrained checkpoint from huggingface hub 2 | _target_: uni2ts.model.moirai.MoiraiFinetune 3 | module: 4 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 5 | pretrained_model_name_or_path: Salesforce/moirai-1.0-R-large 6 | module_kwargs: 7 | _target_: builtins.dict 8 | distr_output: 9 | _target_: uni2ts.distribution.MixtureOutput 10 | components: 11 | - _target_: uni2ts.distribution.StudentTOutput 12 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 13 | - _target_: uni2ts.distribution.NegativeBinomialOutput 14 | - _target_: uni2ts.distribution.LogNormalOutput 15 | d_model: 1024 16 | num_layers: 24 17 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 18 | max_seq_len: 512 19 | attn_dropout_p: 0.0 20 | dropout_p: 0.0 21 | scaling: true 22 | min_patches: 2 23 | min_mask_ratio: 0.15 24 | max_mask_ratio: 0.5 25 | max_dim: 128 26 | loss_func: 27 | _target_: uni2ts.loss.packed.PackedNLLLoss 28 | lr: 1e-3 29 | weight_decay: 1e-1 30 | beta1: 0.9 31 | beta2: 0.98 32 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 33 | num_warmup_steps: 0 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_1.0_R_small.yaml: -------------------------------------------------------------------------------- 1 | # load a pretrained checkpoint from huggingface hub 2 | _target_: uni2ts.model.moirai.MoiraiFinetune 3 | module: 4 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 5 | pretrained_model_name_or_path: Salesforce/moirai-1.0-R-small 6 | module_kwargs: 7 | _target_: builtins.dict 8 | distr_output: 9 | _target_: uni2ts.distribution.MixtureOutput 10 | components: 11 | - _target_: uni2ts.distribution.StudentTOutput 12 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 13 | - _target_: uni2ts.distribution.NegativeBinomialOutput 14 | - _target_: uni2ts.distribution.LogNormalOutput 15 | d_model: 384 16 | num_layers: 6 17 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 18 | max_seq_len: 512 19 | attn_dropout_p: 0.0 20 | dropout_p: 0.0 21 | scaling: true 22 | min_patches: 2 23 | min_mask_ratio: 0.15 24 | max_mask_ratio: 0.5 25 | max_dim: 128 26 | loss_func: 27 | _target_: uni2ts.loss.packed.PackedNLLLoss 28 | val_metric: 29 | - _target_: uni2ts.loss.packed.PackedMSELoss 30 | - _target_: uni2ts.loss.packed.PackedNRMSELoss 31 | normalize: absolute_target_squared 32 | lr: 1e-3 33 | weight_decay: 1e-1 34 | beta1: 0.9 35 | beta2: 0.98 36 | num_training_steps: 
${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 37 | num_warmup_steps: 0 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_1.1_R_base.yaml: -------------------------------------------------------------------------------- 1 | # load a pretrained checkpoint from huggingface hub 2 | _target_: uni2ts.model.moirai.MoiraiFinetune 3 | module: 4 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 5 | pretrained_model_name_or_path: Salesforce/moirai-1.1-R-base 6 | module_kwargs: 7 | _target_: builtins.dict 8 | distr_output: 9 | _target_: uni2ts.distribution.MixtureOutput 10 | components: 11 | - _target_: uni2ts.distribution.StudentTOutput 12 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 13 | - _target_: uni2ts.distribution.NegativeBinomialOutput 14 | - _target_: uni2ts.distribution.LogNormalOutput 15 | d_model: 768 16 | num_layers: 12 17 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 18 | max_seq_len: 512 19 | attn_dropout_p: 0.0 20 | dropout_p: 0.0 21 | scaling: true 22 | min_patches: 2 23 | min_mask_ratio: 0.15 24 | max_mask_ratio: 0.5 25 | max_dim: 128 26 | loss_func: 27 | _target_: uni2ts.loss.packed.PackedNLLLoss 28 | lr: 1e-3 29 | weight_decay: 1e-1 30 | beta1: 0.9 31 | beta2: 0.98 32 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 33 | num_warmup_steps: 0 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_1.1_R_large.yaml: -------------------------------------------------------------------------------- 1 | # load a pretrained checkpoint from huggingface hub 2 | _target_: uni2ts.model.moirai.MoiraiFinetune 3 | module: 4 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 5 | pretrained_model_name_or_path: Salesforce/moirai-1.1-R-large 6 | module_kwargs: 7 | _target_: builtins.dict 8 | distr_output: 9 | _target_: uni2ts.distribution.MixtureOutput 10 | components: 11 | - _target_: uni2ts.distribution.StudentTOutput 12 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 13 | - _target_: uni2ts.distribution.NegativeBinomialOutput 14 | - _target_: uni2ts.distribution.LogNormalOutput 15 | d_model: 1024 16 | num_layers: 24 17 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 18 | max_seq_len: 512 19 | attn_dropout_p: 0.0 20 | dropout_p: 0.0 21 | scaling: true 22 | min_patches: 2 23 | min_mask_ratio: 0.15 24 | max_mask_ratio: 0.5 25 | max_dim: 128 26 | loss_func: 27 | _target_: uni2ts.loss.packed.PackedNLLLoss 28 | lr: 1e-3 29 | weight_decay: 1e-1 30 | beta1: 0.9 31 | beta2: 0.98 32 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 33 | num_warmup_steps: 0 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_1.1_R_small.yaml: -------------------------------------------------------------------------------- 1 | # load a pretrained checkpoint from huggingface hub 2 | _target_: uni2ts.model.moirai.MoiraiFinetune 3 | module: 4 | _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained 5 | pretrained_model_name_or_path: Salesforce/moirai-1.1-R-small 6 | module_kwargs: 7 | _target_: builtins.dict 8 | distr_output: 9 | _target_: uni2ts.distribution.MixtureOutput 10 | components: 11 | - _target_: uni2ts.distribution.StudentTOutput 12 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 13 | - _target_: uni2ts.distribution.NegativeBinomialOutput 14 
| - _target_: uni2ts.distribution.LogNormalOutput 15 | d_model: 384 16 | num_layers: 6 17 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 18 | max_seq_len: 512 19 | attn_dropout_p: 0.0 20 | dropout_p: 0.0 21 | scaling: true 22 | min_patches: 2 23 | min_mask_ratio: 0.15 24 | max_mask_ratio: 0.5 25 | max_dim: 128 26 | loss_func: 27 | _target_: uni2ts.loss.packed.PackedNLLLoss 28 | val_metric: 29 | - _target_: uni2ts.loss.packed.PackedMSELoss 30 | - _target_: uni2ts.loss.packed.PackedNRMSELoss 31 | normalize: absolute_target_squared 32 | lr: 1e-3 33 | weight_decay: 1e-1 34 | beta1: 0.9 35 | beta2: 0.98 36 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 37 | num_warmup_steps: 0 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_base.yaml: -------------------------------------------------------------------------------- 1 | # load a pytorch lightning checkpoint 2 | _target_: uni2ts.model.moirai.MoiraiFinetune.load_from_checkpoint 3 | module_kwargs: 4 | _target_: builtins.dict 5 | distr_output: 6 | _target_: uni2ts.distribution.MixtureOutput 7 | components: 8 | - _target_: uni2ts.distribution.StudentTOutput 9 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 10 | - _target_: uni2ts.distribution.NegativeBinomialOutput 11 | - _target_: uni2ts.distribution.LogNormalOutput 12 | d_model: 768 13 | num_layers: 12 14 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 15 | max_seq_len: 512 16 | attn_dropout_p: 0.0 17 | dropout_p: 0.0 18 | scaling: true 19 | min_patches: 2 20 | min_mask_ratio: 0.15 21 | max_mask_ratio: 0.5 22 | max_dim: 128 23 | loss_func: 24 | _target_: uni2ts.loss.packed.PackedNLLLoss 25 | lr: 1e-3 26 | weight_decay: 1e-1 27 | beta1: 0.9 28 | beta2: 0.98 29 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 30 | num_warmup_steps: 0 31 | checkpoint_path: ... -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_large.yaml: -------------------------------------------------------------------------------- 1 | # load a pytorch lightning checkpoint 2 | _target_: uni2ts.model.moirai.MoiraiFinetune.load_from_checkpoint 3 | module_kwargs: 4 | _target_: builtins.dict 5 | distr_output: 6 | _target_: uni2ts.distribution.MixtureOutput 7 | components: 8 | - _target_: uni2ts.distribution.StudentTOutput 9 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 10 | - _target_: uni2ts.distribution.NegativeBinomialOutput 11 | - _target_: uni2ts.distribution.LogNormalOutput 12 | d_model: 1024 13 | num_layers: 24 14 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 15 | max_seq_len: 512 16 | attn_dropout_p: 0.0 17 | dropout_p: 0.0 18 | scaling: true 19 | min_patches: 2 20 | min_mask_ratio: 0.15 21 | max_mask_ratio: 0.5 22 | max_dim: 128 23 | loss_func: 24 | _target_: uni2ts.loss.packed.PackedNLLLoss 25 | lr: 1e-3 26 | weight_decay: 1e-1 27 | beta1: 0.9 28 | beta2: 0.98 29 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 30 | num_warmup_steps: 0 31 | checkpoint_path: ... 
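Note the split in loading styles: the `moirai_1.x_R_*` configs above pull published weights from the Hugging Face hub via `MoiraiModule.from_pretrained`, while `moirai_base` and `moirai_large` target `MoiraiFinetune.load_from_checkpoint` and expect a local PyTorch Lightning checkpoint (`moirai_small` and `moirai_moe_1.0_R_small` below follow the same pattern); the literal `...` under `checkpoint_path` is a placeholder to be overridden at launch. A minimal sketch of the equivalent call in plain Python, with a hypothetical checkpoint path that is not a file in this repo:

from uni2ts.model.moirai import MoiraiFinetune

# Lightning restores hyperparameters stored in the checkpoint; when they were
# not saved (e.g. a weights-only checkpoint), they must be passed explicitly,
# which is what the module_kwargs block in the YAML above does.
model = MoiraiFinetune.load_from_checkpoint(
    checkpoint_path="outputs/pretrain/moirai_base/my_run/checkpoints/last.ckpt"
)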
-------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_moe_1.0_R_small.yaml: -------------------------------------------------------------------------------- 1 | # load a pytorch lightning checkpoint 2 | _target_: uni2ts.model.moirai.MoiraiFinetune.load_from_checkpoint 3 | module_kwargs: 4 | _target_: builtins.dict 5 | attn_dropout_p: 0 6 | d_ff: 512 7 | d_model: 384 8 | distr_output: 9 | _target_: uni2ts.distribution.mixture.MixtureOutput 10 | components: 11 | - _target_: uni2ts.distribution.student_t.StudentTOutput 12 | - _target_: uni2ts.distribution.normal.NormalFixedScaleOutput 13 | scale: 0.001 14 | - _target_: uni2ts.distribution.negative_binomial.NegativeBinomialOutput 15 | - _target_: uni2ts.distribution.log_normal.LogNormalOutput 16 | dropout_p: 0 17 | max_seq_len: 512 18 | num_layers: 6 19 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 20 | scaling: true 21 | min_patches: 2 22 | min_mask_ratio: 0.15 23 | max_mask_ratio: 0.5 24 | max_dim: 128 25 | loss_func: 26 | _target_: uni2ts.loss.packed.PackedNLLLoss 27 | val_metric: 28 | - _target_: uni2ts.loss.packed.PackedMSELoss 29 | - _target_: uni2ts.loss.packed.PackedNRMSELoss 30 | normalize: absolute_target_squared 31 | lr: 1e-3 32 | weight_decay: 1e-1 33 | beta1: 0.9 34 | beta2: 0.98 35 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 36 | num_warmup_steps: 0 37 | checkpoint_path: ... 38 | -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/model/moirai_small.yaml: -------------------------------------------------------------------------------- 1 | # load a pytorch lightning checkpoint 2 | _target_: uni2ts.model.moirai.MoiraiFinetune.load_from_checkpoint 3 | module_kwargs: 4 | _target_: builtins.dict 5 | distr_output: 6 | _target_: uni2ts.distribution.MixtureOutput 7 | components: 8 | - _target_: uni2ts.distribution.StudentTOutput 9 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 10 | - _target_: uni2ts.distribution.NegativeBinomialOutput 11 | - _target_: uni2ts.distribution.LogNormalOutput 12 | d_model: 384 13 | num_layers: 6 14 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 15 | max_seq_len: 512 16 | attn_dropout_p: 0.0 17 | dropout_p: 0.0 18 | scaling: true 19 | min_patches: 2 20 | min_mask_ratio: 0.15 21 | max_mask_ratio: 0.5 22 | max_dim: 128 23 | loss_func: 24 | _target_: uni2ts.loss.packed.PackedNLLLoss 25 | val_metric: 26 | - _target_: uni2ts.loss.packed.PackedMSELoss 27 | - _target_: uni2ts.loss.packed.PackedNRMSELoss 28 | normalize: absolute_target_squared 29 | lr: 1e-3 30 | weight_decay: 1e-1 31 | beta1: 0.9 32 | beta2: 0.98 33 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 34 | num_warmup_steps: 0 35 | checkpoint_path: ... 
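All of these finetune model configs compute `num_training_steps` with the `${mul:...}` resolver registered in `uni2ts/common/hydra_util.py` (shown later in this tree). With the defaults from `conf/finetune/default.yaml` above, the interpolation resolves to:

# ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}}
max_epochs = 100             # trainer.max_epochs in conf/finetune/default.yaml
num_batches_per_epoch = 100  # train_dataloader.num_batches_per_epoch
num_training_steps = max_epochs * num_batches_per_epoch  # 10_000

which, together with `num_warmup_steps`, bounds the module's learning-rate schedule.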
-------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/val_data/etth1.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.ConcatDatasetBuilder 2 | _args_: 3 | _target_: uni2ts.data.builder.simple.generate_eval_builders 4 | dataset: ETTh1_eval 5 | offset: 11520 6 | eval_length: 2880 7 | prediction_lengths: [96, 192, 336, 720] 8 | context_lengths: [1000, 2000, 3000, 4000, 5000] 9 | patch_sizes: [32, 64] -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/finetune/val_data/etth1_multi.yaml: -------------------------------------------------------------------------------- 1 | - _target_: uni2ts.data.builder.simple.SimpleEvalDatasetBuilder 2 | dataset: ETTh1_eval 3 | offset: 11520 4 | windows: 10 5 | distance: 96 6 | prediction_length: 96 7 | context_length: 1000 8 | patch_size: 32 9 | - _target_: uni2ts.data.builder.simple.SimpleEvalDatasetBuilder 10 | dataset: ETTh1_eval 11 | offset: 11520 12 | windows: 10 13 | distance: 192 14 | prediction_length: 192 15 | context_length: 1000 16 | patch_size: 32 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/buildings_900k.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.Buildings900KDatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/buildings_bench.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.BuildingsBenchDatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/cloudops_tsf.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.CloudOpsTSFDatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/cmip6.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.CMIP6DatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/era5.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.ERA5DatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/gluonts.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.GluonTSDatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/largest.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.LargeSTDatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} 
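Each of these pretrain data configs (continuing below) is the same two-liner: `${cls_getattr:${._target_},dataset_list}` asks the `cls_getattr` resolver from `uni2ts/common/hydra_util.py` for the builder class's own `dataset_list` attribute, i.e. "build every dataset this builder knows about". In plain Python, taking the CloudOps builder shown later in this tree as an example:

from hydra.utils import get_class

# Equivalent of `datasets: ${cls_getattr:${._target_},dataset_list}`.
builder_cls = get_class("uni2ts.data.builder.lotsa_v1.CloudOpsTSFDatasetBuilder")
datasets = builder_cls.dataset_list
# ["azure_vm_traces_2017", "borg_cluster_data_2011", "alibaba_cluster_trace_2018"]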
-------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/lib_city.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.LibCityDatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/lotsa_v1_unweighted.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.ConcatDatasetBuilder 2 | _args_: 3 | - _target_: uni2ts.data.builder.lotsa_v1.Buildings900KDatasetBuilder 4 | datasets: ${cls_getattr:${._target_},dataset_list} 5 | - _target_: uni2ts.data.builder.lotsa_v1.BuildingsBenchDatasetBuilder 6 | datasets: ${cls_getattr:${._target_},dataset_list} 7 | - _target_: uni2ts.data.builder.lotsa_v1.CloudOpsTSFDatasetBuilder 8 | datasets: ${cls_getattr:${._target_},dataset_list} 9 | - _target_: uni2ts.data.builder.lotsa_v1.CMIP6DatasetBuilder 10 | datasets: ${cls_getattr:${._target_},dataset_list} 11 | - _target_: uni2ts.data.builder.lotsa_v1.ERA5DatasetBuilder 12 | datasets: ${cls_getattr:${._target_},dataset_list} 13 | - _target_: uni2ts.data.builder.lotsa_v1.GluonTSDatasetBuilder 14 | datasets: ${cls_getattr:${._target_},dataset_list} 15 | - _target_: uni2ts.data.builder.lotsa_v1.LargeSTDatasetBuilder 16 | datasets: ${cls_getattr:${._target_},dataset_list} 17 | - _target_: uni2ts.data.builder.lotsa_v1.LibCityDatasetBuilder 18 | datasets: ${cls_getattr:${._target_},dataset_list} 19 | - _target_: uni2ts.data.builder.lotsa_v1.OthersLOTSADatasetBuilder 20 | datasets: ${cls_getattr:${._target_},dataset_list} 21 | - _target_: uni2ts.data.builder.lotsa_v1.ProEnFoDatasetBuilder 22 | datasets: ${cls_getattr:${._target_},dataset_list} 23 | - _target_: uni2ts.data.builder.lotsa_v1.SubseasonalDatasetBuilder 24 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/others.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.OthersLOTSADatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/proenfo.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.ProEnFoDatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/data/subseasonal.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.data.builder.lotsa_v1.SubseasonalDatasetBuilder 2 | datasets: ${cls_getattr:${._target_},dataset_list} -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/default.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: outputs/pretrain/${hydra:runtime.choices.model}/${hydra:runtime.choices.data}/${run_name} 4 | defaults: 5 | - model: ??? 6 | - data: ??? 7 | - val_data: null 8 | - _self_ 9 | run_name: ??? 
10 | seed: 0 11 | tf32: true 12 | compile: false # set to mode: default, reduce-overhead, max-autotune 13 | ckpt_path: null # set to "last" to resume training 14 | trainer: 15 | _target_: lightning.Trainer 16 | accelerator: auto 17 | strategy: auto 18 | devices: auto 19 | num_nodes: 1 20 | precision: 32 21 | logger: 22 | _target_: lightning.pytorch.loggers.TensorBoardLogger 23 | save_dir: ${hydra:runtime.output_dir} 24 | name: logs 25 | callbacks: 26 | - _target_: lightning.pytorch.callbacks.LearningRateMonitor 27 | logging_interval: epoch 28 | - _target_: lightning.pytorch.callbacks.ModelCheckpoint 29 | dirpath: ${hydra:runtime.output_dir}/checkpoints 30 | filename: last 31 | monitor: epoch 32 | mode: max 33 | save_top_k: 1 34 | every_n_epochs: 10 35 | - _target_: lightning.pytorch.callbacks.ModelCheckpoint 36 | dirpath: ${hydra:runtime.output_dir}/checkpoints 37 | monitor: epoch 38 | save_weights_only: true 39 | mode: max 40 | save_top_k: -1 41 | every_n_epochs: ${floordiv:${trainer.max_epochs},10} 42 | - _target_: uni2ts.callbacks.HuggingFaceCheckpoint.HuggingFaceCheckpoint 43 | dirpath: ${hydra:runtime.output_dir}/HF_checkpoints 44 | filename: last 45 | monitor: epoch 46 | mode: max 47 | save_top_k: 1 48 | every_n_epochs: 1 49 | # epoch-based training provides averaged metrics 50 | # cannot use max_steps with epoch-based training - resume from checkpoint on wrong epoch 51 | max_epochs: 1_000 52 | enable_progress_bar: true 53 | accumulate_grad_batches: 1 54 | gradient_clip_val: 1.0 55 | gradient_clip_algorithm: norm 56 | train_dataloader: 57 | _target_: uni2ts.data.loader.DataLoader 58 | batch_size: 128 59 | batch_size_factor: 2.0 60 | cycle: true 61 | num_batches_per_epoch: 100 62 | shuffle: true 63 | num_workers: 11 64 | collate_fn: 65 | _target_: uni2ts.data.loader.PackCollate 66 | max_length: ${model.module_kwargs.max_seq_len} 67 | seq_fields: ${cls_getattr:${model._target_},seq_fields} 68 | pad_func_map: ${cls_getattr:${model._target_},pad_func_map} 69 | pin_memory: true 70 | drop_last: true 71 | fill_last: false 72 | worker_init_fn: null 73 | prefetch_factor: 2 74 | persistent_workers: true -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/model/moirai_base.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiPretrain 2 | module_kwargs: 3 | _target_: builtins.dict 4 | distr_output: 5 | _target_: uni2ts.distribution.MixtureOutput 6 | components: 7 | - _target_: uni2ts.distribution.StudentTOutput 8 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 9 | - _target_: uni2ts.distribution.NegativeBinomialOutput 10 | - _target_: uni2ts.distribution.LogNormalOutput 11 | d_model: 768 12 | num_layers: 12 13 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 14 | max_seq_len: 512 15 | attn_dropout_p: 0.0 16 | dropout_p: 0.0 17 | scaling: true 18 | min_patches: 2 19 | min_mask_ratio: 0.15 20 | max_mask_ratio: 0.5 21 | max_dim: 128 22 | loss_func: 23 | _target_: uni2ts.loss.packed.PackedNLLLoss 24 | lr: 1e-3 25 | weight_decay: 1e-1 26 | beta1: 0.9 27 | beta2: 0.98 28 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 29 | num_warmup_steps: 10_000 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/model/moirai_large.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiPretrain 2 | 
module_kwargs: 3 | _target_: builtins.dict 4 | distr_output: 5 | _target_: uni2ts.distribution.MixtureOutput 6 | components: 7 | - _target_: uni2ts.distribution.StudentTOutput 8 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 9 | - _target_: uni2ts.distribution.NegativeBinomialOutput 10 | - _target_: uni2ts.distribution.LogNormalOutput 11 | d_model: 1024 12 | num_layers: 24 13 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 14 | max_seq_len: 512 15 | attn_dropout_p: 0.0 16 | dropout_p: 0.0 17 | scaling: true 18 | min_patches: 2 19 | min_mask_ratio: 0.15 20 | max_mask_ratio: 0.5 21 | max_dim: 128 22 | loss_func: 23 | _target_: uni2ts.loss.packed.PackedNLLLoss 24 | lr: 1e-3 25 | weight_decay: 1e-1 26 | beta1: 0.9 27 | beta2: 0.98 28 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 29 | num_warmup_steps: 10_000 -------------------------------------------------------------------------------- /src/uni2ts/cli/conf/pretrain/model/moirai_small.yaml: -------------------------------------------------------------------------------- 1 | _target_: uni2ts.model.moirai.MoiraiPretrain 2 | module_kwargs: 3 | _target_: builtins.dict 4 | distr_output: 5 | _target_: uni2ts.distribution.MixtureOutput 6 | components: 7 | - _target_: uni2ts.distribution.StudentTOutput 8 | - _target_: uni2ts.distribution.NormalFixedScaleOutput 9 | - _target_: uni2ts.distribution.NegativeBinomialOutput 10 | - _target_: uni2ts.distribution.LogNormalOutput 11 | d_model: 384 12 | num_layers: 6 13 | patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} 14 | max_seq_len: 512 15 | attn_dropout_p: 0.0 16 | dropout_p: 0.0 17 | scaling: true 18 | min_patches: 2 19 | min_mask_ratio: 0.15 20 | max_mask_ratio: 0.5 21 | max_dim: 128 22 | loss_func: 23 | _target_: uni2ts.loss.packed.PackedNLLLoss 24 | lr: 1e-3 25 | weight_decay: 1e-1 26 | beta1: 0.9 27 | beta2: 0.98 28 | num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} 29 | num_warmup_steps: 10_000 -------------------------------------------------------------------------------- /src/uni2ts/cli/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
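# Example invocation (a sketch, assuming the package is installed so that
# uni2ts.cli.eval is importable; run_name and the dataset name are
# placeholders):
#
#   python -m uni2ts.cli.eval run_name=example_eval \
#       model=moirai_1.0_R_small model.patch_size=32 model.context_length=1000 \
#       data=gluonts_test data.dataset_name=m4_hourly
#
# Hydra composes conf/eval/default.yaml with the chosen model/data groups and
# fills their ??? fields from these command-line overrides. The loop below
# then halves the batch size on CUDA OOM until evaluation fits in memory.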
15 | 16 | import hydra 17 | import torch 18 | from gluonts.time_feature import get_seasonality 19 | from hydra.core.hydra_config import HydraConfig 20 | from hydra.utils import call, instantiate 21 | from omegaconf import DictConfig 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | from uni2ts.eval_util.evaluation import evaluate_model 25 | 26 | 27 | @hydra.main(version_base="1.3", config_path="conf/eval", config_name="default") 28 | def main(cfg: DictConfig): 29 | test_data, metadata = call(cfg.data) 30 | batch_size = cfg.batch_size 31 | while True: 32 | model = call(cfg.model, _partial_=True, _convert_="all")( 33 | prediction_length=metadata.prediction_length, 34 | target_dim=metadata.target_dim, 35 | feat_dynamic_real_dim=metadata.feat_dynamic_real_dim, 36 | past_feat_dynamic_real_dim=metadata.past_feat_dynamic_real_dim, 37 | ) 38 | metrics = instantiate(cfg.metrics, _convert_="all") 39 | try: 40 | predictor = model.create_predictor(batch_size, cfg.device) 41 | res = evaluate_model( 42 | predictor, 43 | test_data=test_data, 44 | metrics=metrics, 45 | batch_size=cfg.batch_size, 46 | axis=None, 47 | mask_invalid_label=True, 48 | allow_nan_forecast=False, 49 | seasonality=get_seasonality(metadata.freq), 50 | ) 51 | print(res) 52 | output_dir = HydraConfig.get().runtime.output_dir 53 | writer = SummaryWriter(log_dir=output_dir) 54 | for name, metric in res.to_dict("records")[0].items(): 55 | writer.add_scalar(f"{metadata.split}_metrics/{name}", metric) 56 | writer.close() 57 | break 58 | except torch.cuda.OutOfMemoryError: 59 | print( 60 | f"OutOfMemoryError at batch_size {batch_size}, reducing to {batch_size // 2}" 61 | ) 62 | batch_size //= 2 63 | if batch_size < cfg.min_batch_size: 64 | print( 65 | f"batch_size {batch_size} smaller than " 66 | f"min_batch_size {cfg.min_batch_size}, ending evaluation" 67 | ) 68 | break 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /src/uni2ts/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /src/uni2ts/common/core.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from collections.abc import Callable 17 | from typing import TypeVar 18 | 19 | T = TypeVar("T") 20 | 21 | 22 | def abstract_class_property(*names: str) -> Callable[[type[T]], type[T]]: 23 | def _func(cls: type[T]) -> type[T]: 24 | original_init_subclass = cls.__init_subclass__ 25 | 26 | def _init_subclass(_cls, **kwargs): 27 | # The default implementation of __init_subclass__ takes no 28 | # positional arguments, but a custom implementation does. 29 | # If the user has not reimplemented __init_subclass__ then 30 | # the first signature will fail and we try the second. 31 | try: 32 | original_init_subclass(_cls, **kwargs) 33 | except TypeError: 34 | original_init_subclass(**kwargs) 35 | 36 | # Check that each attribute is defined. 37 | for name in names: 38 | if not hasattr(_cls, name): 39 | raise NotImplementedError( 40 | f"{name} has not been defined for {_cls.__name__}" 41 | ) 42 | if getattr(_cls, name, NotImplemented) is NotImplemented: 43 | raise NotImplementedError( 44 | f"{name} has not been defined for {_cls.__name__}" 45 | ) 46 | 47 | cls.__init_subclass__ = classmethod(_init_subclass) 48 | return cls 49 | 50 | return _func 51 | -------------------------------------------------------------------------------- /src/uni2ts/common/env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
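# Usage sketch (illustrative values): with a .env file at the project root
# such as
#
#   LOTSA_V1_PATH=/data/lotsa_v1
#   CUSTOM_DATA_PATH=/data/custom
#
# the singleton below exposes each variable in path_vars as a class
# attribute:
#
#   from uni2ts.common.env import env
#   storage = env.CUSTOM_DATA_PATH  # a pathlib.Path, or None if unset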
15 | 16 | import os 17 | import warnings 18 | from pathlib import Path 19 | from typing import Optional 20 | 21 | from dotenv import load_dotenv 22 | 23 | 24 | def get_path_var(var: Optional[str]) -> Optional[Path]: 25 | if (path := os.getenv(var)) is not None: 26 | return Path(path) 27 | return None 28 | 29 | 30 | class Env: 31 | _instance: Optional["Env"] = None 32 | path_vars: list[str] = [ 33 | "LOTSA_V1_PATH", 34 | "LSF_PATH", 35 | "CUSTOM_DATA_PATH", 36 | "HF_CACHE_PATH", 37 | ] 38 | 39 | def __new__(cls): 40 | if cls._instance is None: 41 | cls._instance = super().__new__(cls) 42 | if not load_dotenv(): 43 | warnings.warn("Failed to load .env file.") 44 | cls.monkey_patch_path_vars() 45 | return cls._instance 46 | 47 | @classmethod 48 | def monkey_patch_path_vars(cls): 49 | for var in cls.path_vars: 50 | setattr(cls, var, get_path_var(var)) 51 | 52 | 53 | env = Env() 54 | -------------------------------------------------------------------------------- /src/uni2ts/common/hydra_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from collections.abc import Callable 17 | from typing import Any 18 | 19 | from hydra.utils import get_class 20 | from omegaconf import OmegaConf 21 | 22 | 23 | def register_resolver(name: str) -> Callable[[Callable], Callable]: 24 | def decorator(resolver: Callable) -> Callable: 25 | OmegaConf.register_new_resolver(name, resolver) 26 | return resolver 27 | 28 | return decorator 29 | 30 | 31 | @register_resolver("as_tuple") 32 | def resolve_as_tuple(ls: list) -> tuple: 33 | return tuple(ls) 34 | 35 | 36 | @register_resolver("cls_getattr") 37 | def resolve_cls_getattr(cls_name: str, attribute_name: str) -> Any: 38 | if cls_name.endswith(".load_from_checkpoint"): 39 | cls_name = cls_name[: -len(".load_from_checkpoint")] 40 | cls = get_class(cls_name) 41 | return getattr(cls, attribute_name) 42 | 43 | 44 | @register_resolver("floordiv") 45 | def resolve_floordiv(a: int, b: int) -> int: 46 | return a // b 47 | 48 | 49 | @register_resolver("mul") 50 | def resolve_mul(a: float, b: float) -> float: 51 | return a * b 52 | -------------------------------------------------------------------------------- /src/uni2ts/common/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from collections.abc import Callable 17 | from functools import partial 18 | from typing import cast 19 | 20 | import numpy as np 21 | 22 | Sampler = Callable[[int | np.ndarray], int | np.ndarray] 23 | 24 | 25 | def uniform_sampler(n: int | np.ndarray) -> int | np.ndarray: 26 | return np.random.randint(1, n + 1) 27 | 28 | 29 | def binomial_sampler(n: int | np.ndarray, p: float = 0.5) -> int | np.ndarray: 30 | return np.random.binomial(n - 1, p) + 1 31 | 32 | 33 | def beta_binomial_sampler( 34 | n: int | np.ndarray, a: float = 1, b: float = 1 35 | ) -> int | np.ndarray: 36 | # equivalent to uniform_sampler when a = b = 1 37 | if isinstance(n, np.ndarray): 38 | p = np.random.beta(a, b, size=n.shape) 39 | else: 40 | p = np.random.beta(a, b) 41 | return np.random.binomial(n - 1, p) + 1 42 | 43 | 44 | def get_sampler(distribution: str, **kwargs) -> Sampler: 45 | if distribution == "uniform": 46 | return uniform_sampler 47 | elif distribution == "binomial": 48 | p = kwargs.get("p", 0.5) 49 | return cast(Sampler, partial(binomial_sampler, p=p)) 50 | elif distribution == "beta_binomial": 51 | a = kwargs.get("a", 1) 52 | b = kwargs.get("b", 1) 53 | return cast(Sampler, partial(beta_binomial_sampler, a=a, b=b)) 54 | else: 55 | raise NotImplementedError(f"distribution {distribution} not implemented") 56 | -------------------------------------------------------------------------------- /src/uni2ts/common/typing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
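# Shape conventions for the aliases below, in jaxtyping notation:
#   UnivarTimeSeries   -> np.ndarray of shape (time,)
#   MultivarTimeSeries -> np.ndarray of shape (var, time)
#   Sample             -> dict of torch.Tensors holding one sample's fields
#   BatchedSample      -> the same fields with a leading batch dimension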
15 | 16 | from collections.abc import Callable, Iterable 17 | from typing import Any 18 | 19 | import numpy as np 20 | import torch 21 | from jaxtyping import AbstractDtype, Num 22 | 23 | 24 | class DateTime64(AbstractDtype): 25 | dtypes = ["datetime64"] 26 | 27 | 28 | class Character(AbstractDtype): 29 | dtypes = ["str_"] 30 | 31 | 32 | # Data preparation 33 | GenFunc = Callable[[], Iterable[dict[str, Any]]] 34 | SliceableGenFunc = Callable[..., Iterable[dict[str, Any]]] 35 | 36 | 37 | # Indexer 38 | DateTime = DateTime64[np.ndarray, ""] 39 | BatchedDateTime = DateTime64[np.ndarray, "batch"] 40 | String = np.character 41 | BatchedString = Character[np.ndarray, "batch"] 42 | UnivarTimeSeries = Num[np.ndarray, "time"] 43 | MultivarTimeSeries = Num[np.ndarray, "var time"] 44 | Data = DateTime | String | UnivarTimeSeries | MultivarTimeSeries 45 | BatchedData = ( 46 | BatchedDateTime | BatchedString | list[UnivarTimeSeries] | list[MultivarTimeSeries] 47 | ) 48 | FlattenedData = DateTime | String | list[UnivarTimeSeries] 49 | 50 | 51 | # Loader 52 | Sample = dict[str, Num[torch.Tensor, "*sample"]] 53 | BatchedSample = dict[str, Num[torch.Tensor, "batch *sample"]] 54 | -------------------------------------------------------------------------------- /src/uni2ts/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /src/uni2ts/data/builder/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from ._base import ConcatDatasetBuilder, DatasetBuilder 17 | 18 | __all__ = [ 19 | "DatasetBuilder", 20 | "ConcatDatasetBuilder", 21 | ] 22 | -------------------------------------------------------------------------------- /src/uni2ts/data/builder/_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import abc 17 | from typing import Any, Callable 18 | 19 | from torch.utils.data import ConcatDataset, Dataset 20 | 21 | from uni2ts.transform import Transformation 22 | 23 | 24 | # TODO: Add __repr__ 25 | class DatasetBuilder(abc.ABC): 26 | """ 27 | Base class for DatasetBuilders. 28 | """ 29 | 30 | @abc.abstractmethod 31 | def build_dataset(self, *args, **kwargs): 32 | """ 33 | Builds the dataset into the required file format. 34 | """ 35 | ... 36 | 37 | @abc.abstractmethod 38 | def load_dataset( 39 | self, transform_map: dict[Any, Callable[..., Transformation]] 40 | ) -> Dataset: 41 | """ 42 | Load the dataset. 43 | 44 | :param transform_map: a map which returns the required dataset transformations to be applied 45 | :return: the dataset ready for training 46 | """ 47 | ... 48 | 49 | 50 | class ConcatDatasetBuilder(DatasetBuilder): 51 | """ 52 | Concatenates DatasetBuilders such that they can be loaded together. 53 | """ 54 | 55 | def __init__(self, *builders: DatasetBuilder): 56 | """ 57 | :param builders: DatasetBuilders to be concatenated together. 58 | """ 59 | super().__init__() 60 | assert len(builders) > 0, "Must provide at least one builder to ConcatBuilder" 61 | assert all( 62 | isinstance(builder, DatasetBuilder) for builder in builders 63 | ), "All builders must be instances of DatasetBuilder" 64 | self.builders: tuple[DatasetBuilder, ...] = builders 65 | 66 | def build_dataset(self): 67 | raise ValueError( 68 | "Do not use ConcatBuilder to build datasets, build sub datasets individually instead." 69 | ) 70 | 71 | def load_dataset( 72 | self, transform_map: dict[Any, Callable[..., Transformation]] 73 | ) -> ConcatDataset: 74 | """ 75 | Loads all builders with ConcatDataset. 76 | 77 | :param transform_map: a map which returns the required dataset transformations to be applied 78 | :return: the dataset ready for training 79 | """ 80 | return ConcatDataset( 81 | [builder.load_dataset(transform_map) for builder in self.builders] 82 | ) 83 | -------------------------------------------------------------------------------- /src/uni2ts/data/builder/lotsa_v1/README.md: -------------------------------------------------------------------------------- 1 | ### ProEnFo 2 | 3 | 1. Download the relevant data files from the [ProEnFo repository](https://github.com/Leo-VK/ProEnFo), and place them in folders with their respective dataset names. 2. Register the directory path with the ProEnFo datasets by running: ```echo "PROENFO_PATH=ADD_YOUR_PATH" >> .env``` 3. Run the following command to process the datasets: ```python -m uni2ts.data.builder.lotsa_v1 proenfo``` 7 | ` 8 | Warning! The command above reads the pickle data files provided by the ProEnFo repository. 9 | Deserializing pickle files can lead to remote code execution. 10 | Please be careful when dealing with untrusted pickle files.
11 | ` -------------------------------------------------------------------------------- /src/uni2ts/data/builder/lotsa_v1/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from ._base import LOTSADatasetBuilder 17 | from .buildings_bench import Buildings900KDatasetBuilder, BuildingsBenchDatasetBuilder 18 | from .cloudops_tsf import CloudOpsTSFDatasetBuilder 19 | from .cmip6 import CMIP6DatasetBuilder 20 | from .era5 import ERA5DatasetBuilder 21 | from .gluonts import GluonTSDatasetBuilder 22 | from .largest import LargeSTDatasetBuilder 23 | from .lib_city import LibCityDatasetBuilder 24 | from .others import OthersLOTSADatasetBuilder 25 | from .proenfo import ProEnFoDatasetBuilder 26 | from .subseasonal import SubseasonalDatasetBuilder 27 | 28 | __all__ = [ 29 | "LOTSADatasetBuilder", 30 | "Buildings900KDatasetBuilder", 31 | "BuildingsBenchDatasetBuilder", 32 | "CloudOpsTSFDatasetBuilder", 33 | "CMIP6DatasetBuilder", 34 | "ERA5DatasetBuilder", 35 | "GluonTSDatasetBuilder", 36 | "LargeSTDatasetBuilder", 37 | "LibCityDatasetBuilder", 38 | "OthersLOTSADatasetBuilder", 39 | "ProEnFoDatasetBuilder", 40 | "SubseasonalDatasetBuilder", 41 | ] 42 | -------------------------------------------------------------------------------- /src/uni2ts/data/builder/lotsa_v1/cloudops_tsf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
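# Build sketch: the LOTSA builders are normally driven through the package
# CLI, as in the README above, e.g.
#
#   python -m uni2ts.data.builder.lotsa_v1 cloudops_tsf
#
# assuming LOTSA_V1_PATH (see common/env.py earlier) provides the builder's
# storage path. build_dataset() below pulls the Salesforce/cloudops_tsf
# dataset from the Hugging Face hub, keeps the portion before the configured
# test split date, and saves it to disk in Hugging Face datasets format.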
15 | 16 | import os 17 | from collections import defaultdict 18 | from functools import partial 19 | from typing import Any, Generator 20 | 21 | import datasets 22 | from datasets import Features, Sequence, Value, load_dataset, load_dataset_builder 23 | from gluonts.dataset.common import ProcessDataEntry 24 | from gluonts.dataset.split import DateSplitter 25 | 26 | from uni2ts.common.env import env 27 | from uni2ts.data.dataset import TimeSeriesDataset 28 | 29 | from ._base import LOTSADatasetBuilder 30 | 31 | 32 | class CloudOpsTSFDatasetBuilder(LOTSADatasetBuilder): 33 | dataset_list = [ 34 | "azure_vm_traces_2017", 35 | "borg_cluster_data_2011", 36 | "alibaba_cluster_trace_2018", 37 | ] 38 | dataset_type_map = defaultdict(lambda: TimeSeriesDataset) 39 | dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset)) 40 | 41 | def build_dataset(self, dataset: str, num_proc: int = os.cpu_count()): 42 | cloudops_dataset = load_dataset( 43 | path="Salesforce/cloudops_tsf", name=dataset, split="pretrain" 44 | ) 45 | cfg = load_dataset_builder( 46 | path="Salesforce/cloudops_tsf", 47 | name=dataset, 48 | ).config 49 | pde = ProcessDataEntry( 50 | freq=cfg.freq, one_dim_target=cfg.univariate, use_timestamp=False 51 | ) 52 | splitter = DateSplitter(cfg.test_split_date) 53 | 54 | def process(entry): 55 | return next(iter(splitter.split([pde(entry)])[0])) 56 | 57 | def gen_func(ids: list[int]) -> Generator[dict[str, Any], None, None]: 58 | for item in cloudops_dataset.select(ids): 59 | item = process(item) 60 | yield dict( 61 | item_id=item["item_id"], 62 | start=item["start"].to_timestamp(), 63 | freq=cfg.freq, 64 | target=item["target"], 65 | past_feat_dynamic_real=item["past_feat_dynamic_real"], 66 | ) 67 | 68 | target_feature = ( 69 | Sequence(Value("float32")) 70 | if cfg.target_dim == 1 71 | else Sequence(Sequence(Value("float32")), length=cfg.target_dim) 72 | ) 73 | past_feat_dynamic_real_feature = Sequence( 74 | Sequence(Value("float32")), length=cfg.past_feat_dynamic_real_dim 75 | ) 76 | 77 | hf_dataset = datasets.Dataset.from_generator( 78 | gen_func, 79 | features=Features( 80 | dict( 81 | item_id=Value("string"), 82 | start=Value("timestamp[s]"), 83 | freq=Value("string"), 84 | target=target_feature, 85 | past_feat_dynamic_real=past_feat_dynamic_real_feature, 86 | ) 87 | ), 88 | gen_kwargs={"ids": [i for i in range(len(cloudops_dataset))]}, 89 | num_proc=num_proc, 90 | cache_dir=env.HF_CACHE_PATH, 91 | ) 92 | hf_dataset.info.dataset_name = dataset 93 | hf_dataset.save_to_disk( 94 | self.storage_path / dataset, 95 | num_proc=10, 96 | ) 97 | -------------------------------------------------------------------------------- /src/uni2ts/data/builder/lotsa_v1/largest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
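`CloudOpsTSFDatasetBuilder` above pulls the `Salesforce/cloudops_tsf` pretraining split from the Hugging Face Hub, cuts each series at the configured `test_split_date`, and materializes an Arrow dataset on disk. A hypothetical one-off driver, again assuming the base class can be instantiated without arguments:

```python
# Hypothetical driver for the CloudOps builder above; the no-argument
# construction is an assumption (LOTSADatasetBuilder.__init__ is not shown).
import os

from uni2ts.data.builder.lotsa_v1 import CloudOpsTSFDatasetBuilder

builder = CloudOpsTSFDatasetBuilder()
for name in CloudOpsTSFDatasetBuilder.dataset_list:
    # downloads Salesforce/cloudops_tsf, keeps the pre-test_split_date portion,
    # and writes an Arrow dataset under builder.storage_path / name
    builder.build_dataset(name, num_proc=os.cpu_count())
```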
15 | 16 | import os 17 | from collections import defaultdict 18 | from functools import partial 19 | from pathlib import Path 20 | from typing import Any, Generator 21 | 22 | import datasets 23 | import pandas as pd 24 | from datasets import Features, Sequence, Value 25 | 26 | from uni2ts.common.env import env 27 | from uni2ts.data.dataset import TimeSeriesDataset 28 | 29 | from ._base import LOTSADatasetBuilder 30 | 31 | 32 | class LargeSTDatasetBuilder(LOTSADatasetBuilder): 33 | dataset_list = [ 34 | "largest_2017", 35 | "largest_2018", 36 | "largest_2019", 37 | "largest_2020", 38 | "largest_2021", 39 | ] 40 | dataset_type_map = defaultdict(lambda: TimeSeriesDataset) 41 | dataset_load_func_map = defaultdict(lambda: partial(TimeSeriesDataset)) 42 | 43 | def build_dataset(self, dataset: str, num_proc: int = os.cpu_count()): 44 | year = dataset.split("_")[-1] 45 | df = pd.read_hdf(Path(os.getenv("LARGEST_PATH")) / f"ca_his_raw_{year}.h5") 46 | 47 | def gen_func(cols: list[int]) -> Generator[dict[str, Any], None, None]: 48 | for col in cols: 49 | if df[col].isnull().all(): 50 | continue 51 | yield dict( 52 | item_id=f"{col}", 53 | start=df.index[0], 54 | target=df[col], 55 | freq="5T", 56 | ) 57 | 58 | hf_dataset = datasets.Dataset.from_generator( 59 | gen_func, 60 | features=Features( 61 | dict( 62 | item_id=Value("string"), 63 | start=Value("timestamp[s]"), 64 | freq=Value("string"), 65 | target=Sequence(Value("float32")), 66 | ) 67 | ), 68 | num_proc=num_proc, 69 | gen_kwargs={"cols": list(df.columns)}, 70 | cache_dir=env.HF_CACHE_PATH, 71 | ) 72 | hf_dataset.info.dataset_name = dataset 73 | hf_dataset.save_to_disk(self.storage_path / dataset) 74 | -------------------------------------------------------------------------------- /src/uni2ts/data/indexer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from ._base import Indexer 17 | from .hf_dataset_indexer import HuggingFaceDatasetIndexer 18 | 19 | __all__ = ["Indexer", "HuggingFaceDatasetIndexer"] 20 | -------------------------------------------------------------------------------- /src/uni2ts/distribution/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from ._base import DistributionOutput, DistrParamProj 17 | from .laplace import LaplaceFixedScaleOutput, LaplaceOutput 18 | from .log_normal import LogNormalOutput 19 | from .mixture import MixtureOutput 20 | from .negative_binomial import NegativeBinomialOutput 21 | from .normal import NormalFixedScaleOutput, NormalOutput 22 | from .pareto import ParetoFixedAlphaOutput, ParetoOutput 23 | from .student_t import StudentTOutput 24 | 25 | DISTRIBUTION_OUTPUTS = [ 26 | "LaplaceFixedScaleOutput", 27 | "LaplaceOutput", 28 | "LogNormalOutput", 29 | "MixtureOutput", 30 | "NegativeBinomialOutput", 31 | "NormalFixedScaleOutput", 32 | "NormalOutput", 33 | "ParetoFixedAlphaOutput", 34 | "StudentTOutput", 35 | ] 36 | 37 | __all__ = ["DistrParamProj", "DistributionOutput"] + DISTRIBUTION_OUTPUTS 38 | -------------------------------------------------------------------------------- /src/uni2ts/distribution/laplace.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
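All of the `DistributionOutput` subclasses that follow (Laplace, log-normal, normal, Pareto, Student-t) share one mechanism: the projection head emits one tensor of shape `(*batch, 1)` per parameter declared in `args_dim`, and `domain_map` squeezes and constrains each one. A standalone PyTorch illustration of that mapping, mirroring the `LaplaceOutput` code below rather than the library's actual training path:

```python
# Standalone demo of the domain_map pattern: squeeze the trailing dim and force
# scale to be strictly positive via softplus + clamp_min(eps).
import torch
import torch.nn.functional as F
from torch.distributions import Laplace

raw_loc = torch.randn(8, 32, 1)    # (*batch, 1), as declared by args_dim["loc"]
raw_scale = torch.randn(8, 32, 1)  # unconstrained network output

loc = raw_loc.squeeze(-1)
eps = torch.finfo(raw_scale.dtype).eps
scale = F.softplus(raw_scale).clamp_min(eps).squeeze(-1)  # > 0 everywhere

distr = Laplace(loc, scale)
print(distr.sample().shape)  # torch.Size([8, 32])
```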
15 | 16 | from typing import Callable, Optional 17 | 18 | import torch 19 | from jaxtyping import Float, PyTree 20 | from torch.distributions import Laplace 21 | from torch.nn import functional as F 22 | 23 | from ._base import DistributionOutput 24 | 25 | 26 | class LaplaceOutput(DistributionOutput): 27 | distr_cls = Laplace 28 | args_dim = dict(loc=1, scale=1) 29 | 30 | @property 31 | def domain_map( 32 | self, 33 | ) -> PyTree[ 34 | Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" 35 | ]: 36 | return dict(loc=self._loc, scale=self._scale) 37 | 38 | @staticmethod 39 | def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 40 | return loc.squeeze(-1) 41 | 42 | @staticmethod 43 | def _scale(scale: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 44 | epsilon = torch.finfo(scale.dtype).eps 45 | return F.softplus(scale).clamp_min(epsilon).squeeze(-1) 46 | 47 | 48 | class LaplaceFixedScaleOutput(DistributionOutput): 49 | distr_cls = Laplace 50 | args_dim = dict(loc=1) 51 | 52 | def __init__(self, scale: float = 1e-3): 53 | self.scale = scale 54 | 55 | @property 56 | def domain_map( 57 | self, 58 | ) -> PyTree[ 59 | Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" 60 | ]: 61 | return dict(loc=self._loc) 62 | 63 | @staticmethod 64 | def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 65 | return loc.squeeze(-1) 66 | 67 | def _distribution( 68 | self, 69 | distr_params: PyTree[Float[torch.Tensor, "*batch 1"], "T"], 70 | validate_args: Optional[bool] = None, 71 | ) -> Laplace: 72 | loc = distr_params["loc"] 73 | distr_params["scale"] = torch.as_tensor( 74 | self.scale, dtype=loc.dtype, device=loc.device 75 | ) 76 | return self.distr_cls(**distr_params, validate_args=validate_args) 77 | -------------------------------------------------------------------------------- /src/uni2ts/distribution/log_normal.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
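`LaplaceFixedScaleOutput` above predicts only `loc` and injects a constant `scale` inside `_distribution`, giving a point-forecast-like head with a pinned spread. A quick check of the broadcasting it relies on:

```python
import torch
from torch.distributions import Laplace

loc = torch.randn(4, 16)
scale = torch.as_tensor(1e-3, dtype=loc.dtype, device=loc.device)  # 0-dim tensor
d = Laplace(loc, scale)       # the scalar scale broadcasts over the whole batch
print(d.stddev.shape)         # torch.Size([4, 16])
print(d.stddev.flatten()[0])  # sqrt(2) * 1e-3, the Laplace standard deviation
```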
15 | 16 | from typing import Callable 17 | 18 | import torch 19 | from jaxtyping import Float, PyTree 20 | from torch.distributions import LogNormal 21 | from torch.nn import functional as F 22 | 23 | from ._base import DistributionOutput 24 | 25 | 26 | class LogNormalOutput(DistributionOutput): 27 | distr_cls = LogNormal 28 | args_dim = dict(loc=1, scale=1) 29 | 30 | @property 31 | def domain_map( 32 | self, 33 | ) -> PyTree[ 34 | Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" 35 | ]: 36 | return dict(loc=self._loc, scale=self._scale) 37 | 38 | @staticmethod 39 | def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 40 | return loc.squeeze(-1) 41 | 42 | @staticmethod 43 | def _scale(scale: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 44 | epsilon = torch.finfo(scale.dtype).eps 45 | return F.softplus(scale).clamp_min(epsilon).squeeze(-1) 46 | -------------------------------------------------------------------------------- /src/uni2ts/distribution/normal.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
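One interpretation note on `LogNormalOutput` above: `loc` and `scale` parameterize the distribution of `log(Y)`, not `Y` itself, so the median forecast is `exp(loc)` and the mean picks up a `scale`-dependent correction:

```python
import torch
from torch.distributions import LogNormal

mu, sigma = torch.tensor(0.5), torch.tensor(0.25)
d = LogNormal(mu, sigma)
print(d.icdf(torch.tensor(0.5)))  # median = exp(mu) ~= 1.6487
print(d.mean)                     # exp(mu + sigma**2 / 2) ~= 1.7011
```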
15 | 16 | from typing import Callable, Optional 17 | 18 | import torch 19 | from jaxtyping import Float, PyTree 20 | from torch.distributions import Normal 21 | from torch.nn import functional as F 22 | 23 | from ._base import DistributionOutput 24 | 25 | 26 | class NormalOutput(DistributionOutput): 27 | distr_cls = Normal 28 | args_dim = dict(loc=1, scale=1) 29 | 30 | @property 31 | def domain_map(self) -> PyTree[Callable, "T"]: 32 | return dict( 33 | loc=self._loc, 34 | scale=self._scale, 35 | ) 36 | 37 | @staticmethod 38 | def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 39 | return loc.squeeze(-1) 40 | 41 | @staticmethod 42 | def _scale(scale: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 43 | epsilon = torch.finfo(scale.dtype).eps 44 | return F.softplus(scale).clamp_min(epsilon).squeeze(-1) 45 | 46 | 47 | class NormalFixedScaleOutput(DistributionOutput): 48 | distr_cls = Normal 49 | args_dim = dict(loc=1) 50 | 51 | def __init__(self, scale: float = 1e-3): 52 | self.scale = scale 53 | 54 | @property 55 | def domain_map( 56 | self, 57 | ) -> PyTree[ 58 | Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" 59 | ]: 60 | return dict(loc=self._loc) 61 | 62 | @staticmethod 63 | def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 64 | return loc.squeeze(-1) 65 | 66 | def _distribution( 67 | self, 68 | distr_params: PyTree[Float[torch.Tensor, "*batch 1"], "T"], 69 | validate_args: Optional[bool] = None, 70 | ) -> Normal: 71 | loc = distr_params["loc"] 72 | distr_params["scale"] = torch.as_tensor( 73 | self.scale, dtype=loc.dtype, device=loc.device 74 | ) 75 | return self.distr_cls(**distr_params, validate_args=validate_args) 76 | -------------------------------------------------------------------------------- /src/uni2ts/distribution/pareto.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
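For `NormalFixedScaleOutput` above, fixing `scale` means the negative log-likelihood reduces to a rescaled squared error plus a constant, so training it under an NLL loss behaves like an MSE objective on the `loc` head. A quick numeric confirmation:

```python
import math

import torch
from torch.distributions import Normal

mu, x, sigma = torch.randn(5), torch.randn(5), 1e-3
nll = -Normal(mu, sigma).log_prob(x)
mse_term = (x - mu) ** 2 / (2 * sigma**2) + math.log(sigma * math.sqrt(2 * math.pi))
print(torch.allclose(nll, mse_term))  # True
```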
15 | 16 | from typing import Callable, Optional 17 | 18 | import torch 19 | from jaxtyping import Float, PyTree 20 | from torch.distributions import Pareto 21 | from torch.nn import functional as F 22 | 23 | from ._base import DistributionOutput 24 | 25 | 26 | class ParetoOutput(DistributionOutput): 27 | distr_cls = Pareto 28 | args_dim = dict(scale=1, alpha=1) 29 | 30 | @property 31 | def domain_map( 32 | self, 33 | ) -> PyTree[ 34 | Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" 35 | ]: 36 | return dict(scale=self._scale, alpha=self._alpha) 37 | 38 | def _scale( 39 | self, scale: Float[torch.Tensor, "*batch 1"] 40 | ) -> Float[torch.Tensor, "*batch"]: 41 | epsilon = torch.finfo(scale.dtype).eps 42 | return F.softplus(scale).clamp_min(epsilon).squeeze(-1) 43 | 44 | def _alpha( 45 | self, alpha: Float[torch.Tensor, "*batch 1"] 46 | ) -> Float[torch.Tensor, "*batch"]: 47 | epsilon = torch.finfo(alpha.dtype).eps 48 | return (2.0 + F.softplus(alpha).clamp_min(epsilon)).squeeze(-1) 49 | 50 | 51 | class ParetoFixedAlphaOutput(DistributionOutput): 52 | distr_cls = Pareto 53 | args_dim = dict(scale=1) 54 | 55 | def __init__(self, alpha: float = 3.0): 56 | assert alpha > 0.0 57 | self.alpha = alpha 58 | 59 | @property 60 | def domain_map( 61 | self, 62 | ) -> PyTree[ 63 | Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" 64 | ]: 65 | return dict(scale=self._scale) 66 | 67 | def _scale( 68 | self, scale: Float[torch.Tensor, "*batch 1"] 69 | ) -> Float[torch.Tensor, "*batch"]: 70 | epsilon = torch.finfo(scale.dtype).eps 71 | return F.softplus(scale).clamp_min(epsilon).squeeze(-1) 72 | 73 | def _distribution( 74 | self, 75 | distr_params: PyTree[Float[torch.Tensor, "*batch 1"], "T"], 76 | validate_args: Optional[bool] = None, 77 | ) -> Pareto: 78 | scale = distr_params["scale"] 79 | distr_params["alpha"] = torch.as_tensor( 80 | self.alpha, dtype=scale.dtype, device=scale.device 81 | ) 82 | return self.distr_cls(**distr_params, validate_args=validate_args) 83 | -------------------------------------------------------------------------------- /src/uni2ts/distribution/student_t.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
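The `2.0 + softplus(...)` mapping in `ParetoOutput._alpha` above is not cosmetic: a Pareto distribution only has finite variance when `alpha > 2`, so the constraint keeps second-moment quantities well defined. (The same trick appears in `StudentTOutput._df` below.) A quick check:

```python
import torch
import torch.nn.functional as F
from torch.distributions import Pareto

raw = torch.tensor([-5.0, 0.0, 5.0])  # unconstrained network outputs
alpha = 2.0 + F.softplus(raw)         # tensor([2.0067, 2.6931, 7.0067]), all > 2
d = Pareto(scale=torch.ones(3), alpha=alpha)
print(d.variance)                     # finite for every element
```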
15 | 16 | from typing import Callable 17 | 18 | import torch 19 | from jaxtyping import Float, PyTree 20 | from torch.distributions import StudentT 21 | from torch.nn import functional as F 22 | 23 | from ._base import DistributionOutput 24 | 25 | 26 | class StudentTOutput(DistributionOutput): 27 | distr_cls = StudentT 28 | args_dim = dict(df=1, loc=1, scale=1) 29 | 30 | @property 31 | def domain_map( 32 | self, 33 | ) -> PyTree[ 34 | Callable[[Float[torch.Tensor, "*batch 1"]], Float[torch.Tensor, "*batch"]], "T" 35 | ]: 36 | return dict(df=self._df, loc=self._loc, scale=self._scale) 37 | 38 | @staticmethod 39 | def _df(df: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 40 | return (2.0 + F.softplus(df)).squeeze(-1) 41 | 42 | @staticmethod 43 | def _loc(loc: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 44 | return loc.squeeze(-1) 45 | 46 | @staticmethod 47 | def _scale(scale: Float[torch.Tensor, "*batch 1"]) -> Float[torch.Tensor, "*batch"]: 48 | epsilon = torch.finfo(scale.dtype).eps 49 | return F.softplus(scale).clamp_min(epsilon).squeeze(-1) 50 | -------------------------------------------------------------------------------- /src/uni2ts/eval_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdityaLab/Samay/6b677e8ca3666259034412aeb5ae4765732b2172/src/uni2ts/eval_util/__init__.py -------------------------------------------------------------------------------- /src/uni2ts/eval_util/_hf_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import datasets 4 | 5 | from uni2ts.common.env import env 6 | 7 | 8 | class HFDataset: 9 | def __init__(self, dataset_name: str, storage_path: Path = env.CUSTOM_DATA_PATH): 10 | self.hf_dataset = datasets.load_from_disk( 11 | str(storage_path / dataset_name) 12 | ).with_format("numpy") 13 | self.freq = self.hf_dataset[0]["freq"] 14 | self.target_dim = ( 15 | target.shape[-1] 16 | if len((target := self.hf_dataset[0]["target"]).shape) > 1 17 | else 1 18 | ) 19 | 20 | def __iter__(self): 21 | for sample in self.hf_dataset: 22 | sample["start"] = sample["start"].item() 23 | yield sample 24 | -------------------------------------------------------------------------------- /src/uni2ts/eval_util/metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | from typing import Optional 4 | 5 | from gluonts.ev.aggregations import Mean 6 | from gluonts.ev.metrics import BaseMetricDefinition, DirectMetric 7 | from gluonts.ev.stats import squared_error 8 | 9 | 10 | @dataclass 11 | class MedianMSE(BaseMetricDefinition): 12 | """Mean Squared Error""" 13 | 14 | forecast_type: str = "0.5" 15 | 16 | def __call__(self, axis: Optional[int] = None) -> DirectMetric: 17 | return DirectMetric( 18 | name=f"MSE[{self.forecast_type}]", 19 | stat=partial(squared_error, forecast_type=self.forecast_type), 20 | aggregate=Mean(axis=axis), 21 | ) 22 | -------------------------------------------------------------------------------- /src/uni2ts/eval_util/plot.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | from gluonts import maybe 7 | from gluonts.model import Forecast 8 | 9 | 10 | def plot_single( 11 | inp: dict, 12 
| label: dict, 13 | forecast: Forecast, 14 | context_length: int, 15 | intervals: tuple[float, ...] = (0.5, 0.9), 16 | ax: Optional[plt.axis] = None, 17 | dim: Optional[int] = None, 18 | name: Optional[str] = None, 19 | show_label: bool = False, 20 | ): 21 | ax = maybe.unwrap_or_else(ax, plt.gca) 22 | 23 | target = np.concatenate([inp["target"], label["target"]], axis=-1) 24 | start = inp["start"] 25 | if dim is not None: 26 | target = target[dim] 27 | forecast = forecast.copy_dim(dim) 28 | 29 | index = pd.period_range(start, periods=len(target), freq=start.freq) 30 | ax.plot( 31 | index.to_timestamp()[-context_length - forecast.prediction_length :], 32 | target[-context_length - forecast.prediction_length :], 33 | label="target", 34 | color="black", 35 | ) 36 | forecast.plot( 37 | intervals=intervals, 38 | ax=ax, 39 | color="blue", 40 | name=name, 41 | show_label=show_label, 42 | ) 43 | ax.set_xticks(ax.get_xticks()) 44 | ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right") 45 | ax.legend(loc="lower left") 46 | 47 | 48 | def plot_next_multi( 49 | axes: np.ndarray, 50 | input_it: Iterator[dict], 51 | label_it: Iterator[dict], 52 | forecast_it: Iterator[Forecast], 53 | context_length: int, 54 | intervals: tuple[float, ...] = (0.5, 0.9), 55 | dim: Optional[int] = None, 56 | name: Optional[str] = None, 57 | show_label: bool = False, 58 | ): 59 | axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes] 60 | for ax, inp, label, forecast in zip(axes, input_it, label_it, forecast_it): 61 | plot_single( 62 | inp, 63 | label, 64 | forecast, 65 | context_length, 66 | intervals=intervals, 67 | ax=ax, 68 | dim=dim, 69 | name=name, 70 | show_label=show_label, 71 | ) 72 | -------------------------------------------------------------------------------- /src/uni2ts/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /src/uni2ts/loss/packed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
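`plot_single` above only needs an input window, a label window, and a GluonTS `Forecast`; everything else is optional. A self-contained sketch on synthetic data (the series values and window sizes are arbitrary):

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from gluonts.model.forecast import SampleForecast

from uni2ts.eval_util.plot import plot_single

context, horizon = 48, 24
start = pd.Period("2020-01-01", freq="D")
inp = {"target": np.random.randn(context).cumsum(), "start": start}
label = {"target": np.random.randn(horizon).cumsum()}
forecast = SampleForecast(
    samples=np.random.randn(100, horizon),  # (num_samples, prediction_length)
    start_date=start + context,             # first period after the context
)

plot_single(inp, label, forecast, context_length=context, intervals=(0.5, 0.9))
plt.show()
```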
15 | 16 | from ._base import PackedDistributionLoss, PackedLoss, PackedPointLoss 17 | from .distribution import PackedNLLLoss 18 | from .normalized import ( 19 | PackedNMAELoss, 20 | PackedNMLSELoss, 21 | PackedNMSELoss, 22 | PackedNRMSELoss, 23 | PackedPointNormalizedLoss, 24 | PointNormType, 25 | ) 26 | from .percentage_error import PackedMAPELoss, PackedSMAPELoss 27 | from .point import PackedMAELoss, PackedMSELoss, PackedRMSELoss 28 | 29 | __all__ = [ 30 | "PackedDistributionLoss", 31 | "PackedLoss", 32 | "PackedMAELoss", 33 | "PackedMAPELoss", 34 | "PackedMSELoss", 35 | "PackedNLLLoss", 36 | "PackedNMAELoss", 37 | "PackedNMLSELoss", 38 | "PackedNMSELoss", 39 | "PackedNRMSELoss", 40 | "PackedPointLoss", 41 | "PackedPointNormalizedLoss", 42 | "PackedRMSELoss", 43 | "PackedSMAPELoss", 44 | "PointNormType", 45 | ] 46 | -------------------------------------------------------------------------------- /src/uni2ts/loss/packed/distribution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | from jaxtyping import Bool, Float, Int 18 | from torch.distributions import Distribution 19 | 20 | from ._base import PackedDistributionLoss 21 | 22 | 23 | class PackedNLLLoss(PackedDistributionLoss): 24 | def _loss_func( 25 | self, 26 | pred: Distribution, 27 | target: Float[torch.Tensor, "*batch seq_len #dim"], 28 | prediction_mask: Bool[torch.Tensor, "*batch seq_len"], 29 | observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], 30 | sample_id: Int[torch.Tensor, "*batch seq_len"], 31 | variate_id: Int[torch.Tensor, "*batch seq_len"], 32 | ): 33 | l = -pred.log_prob(target) 34 | return l 35 | -------------------------------------------------------------------------------- /src/uni2ts/loss/packed/percentage_error.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
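`PackedNLLLoss._loss_func` above is deliberately thin: it returns one `-log_prob` value per position, and the masking with `prediction_mask`/`observed_mask` plus the packed-sequence reduction happen in the `PackedLoss` base class (defined in `_base.py`, which this listing does not include). The elementwise part in isolation:

```python
import torch
from torch.distributions import StudentT

# pred stands in for the Distribution produced by the model head
pred = StudentT(
    df=torch.full((4, 32), 3.0), loc=torch.zeros(4, 32), scale=torch.ones(4, 32)
)
target = torch.randn(4, 32)
elementwise_nll = -pred.log_prob(target)  # one loss value per token
print(elementwise_nll.shape)              # torch.Size([4, 32])
```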
15 | 16 | import torch 17 | from jaxtyping import Bool, Float, Int 18 | from torch.nn import functional as F 19 | 20 | from uni2ts.common.torch_util import safe_div 21 | 22 | from ._base import PackedPointLoss 23 | 24 | 25 | class PackedMAPELoss(PackedPointLoss): 26 | def _loss_func( 27 | self, 28 | pred: Float[torch.Tensor, "*batch seq_len #dim"], 29 | target: Float[torch.Tensor, "*batch seq_len #dim"], 30 | prediction_mask: Bool[torch.Tensor, "*batch seq_len"], 31 | observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], 32 | sample_id: Int[torch.Tensor, "*batch seq_len"], 33 | variate_id: Int[torch.Tensor, "*batch seq_len"], 34 | ) -> Float[torch.Tensor, "*batch seq_len #dim"]: 35 | loss = F.l1_loss(pred, target, reduction="none") 36 | loss = safe_div(loss, target.abs()) 37 | return 100 * loss 38 | 39 | 40 | class PackedSMAPELoss(PackedPointLoss): 41 | def _loss_func( 42 | self, 43 | pred: Float[torch.Tensor, "*batch seq_len #dim"], 44 | target: Float[torch.Tensor, "*batch seq_len #dim"], 45 | prediction_mask: Bool[torch.Tensor, "*batch seq_len"], 46 | observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], 47 | sample_id: Int[torch.Tensor, "*batch seq_len"], 48 | variate_id: Int[torch.Tensor, "*batch seq_len"], 49 | ) -> Float[torch.Tensor, "*batch seq_len #dim"]: 50 | loss = F.l1_loss(pred, target, reduction="none") 51 | loss = safe_div(loss, target.abs() + pred.detach().abs()) 52 | return 200 * loss 53 | -------------------------------------------------------------------------------- /src/uni2ts/loss/packed/point.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .normalized import PackedNMAELoss, PackedNMSELoss, PackedNRMSELoss, PointNormType 17 | 18 | 19 | class PackedMAELoss(PackedNMAELoss): 20 | def __init__(self): 21 | super().__init__(normalize=PointNormType.NONE) 22 | 23 | 24 | class PackedMSELoss(PackedNMSELoss): 25 | def __init__(self): 26 | super().__init__(normalize=PointNormType.NONE) 27 | 28 | 29 | class PackedRMSELoss(PackedNRMSELoss): 30 | def __init__(self): 31 | super().__init__(normalize=PointNormType.NONE) 32 | -------------------------------------------------------------------------------- /src/uni2ts/model/moirai/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from uni2ts.model.moirai.finetune import MoiraiFinetune 18 | from uni2ts.model.moirai.forecast import MoiraiForecast 19 | from uni2ts.model.moirai.module import MoiraiModule 20 | from uni2ts.model.moirai.pretrain import MoiraiPretrain 21 | 22 | # from .finetune import MoiraiFinetune 23 | # from .forecast import MoiraiForecast 24 | # from .module import MoiraiModule 25 | # from .pretrain import MoiraiPretrain 26 | 27 | __all__ = [ 28 | "MoiraiFinetune", 29 | "MoiraiForecast", 30 | "MoiraiModule", 31 | "MoiraiPretrain", 32 | ] 33 | -------------------------------------------------------------------------------- /src/uni2ts/model/moirai_moe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from uni2ts.model.moirai_moe.forecast import MoiraiMoEForecast 16 | from uni2ts.model.moirai_moe.module import MoiraiMoEModule 17 | 18 | # from .forecast import MoiraiMoEForecast 19 | # from .module import MoiraiMoEModule 20 | 21 | __all__ = [ 22 | "MoiraiMoEForecast", 23 | "MoiraiMoEModule", 24 | ] 25 | -------------------------------------------------------------------------------- /src/uni2ts/module/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /src/uni2ts/module/norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Optional 17 | 18 | import torch 19 | from jaxtyping import Float 20 | from torch import nn 21 | 22 | 23 | class RMSNorm(nn.Module): 24 | def __init__( 25 | self, 26 | normalized_shape: int | list[int] | torch.Size, 27 | eps: float = 1e-5, 28 | weight: bool = True, 29 | dtype: Optional[torch.dtype] = None, 30 | ): 31 | super().__init__() 32 | if isinstance(normalized_shape, int): 33 | normalized_shape = (normalized_shape,) 34 | 35 | self.normalized_shape = normalized_shape 36 | self.eps = eps 37 | self.mean_dim = tuple(range(-len(normalized_shape), 0)) 38 | 39 | if weight: 40 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype)) 41 | else: 42 | self.register_parameter("weight", None) 43 | 44 | def forward( 45 | self, x: Float[torch.Tensor, "*batch normalized_shape"] 46 | ) -> Float[torch.Tensor, "*batch normalized_shape"]: 47 | output = x * torch.rsqrt( 48 | x.pow(2).mean(dim=self.mean_dim, keepdim=True) + self.eps 49 | ) 50 | if self.weight is not None: 51 | return output * self.weight 52 | return output 53 | 54 | def extra_repr(self) -> str: 55 | return ( 56 | f"normalized_shape={self.normalized_shape}, " 57 | f"eps={self.eps}, " 58 | f"weight={self.weight is not None}" 59 | ) 60 | -------------------------------------------------------------------------------- /src/uni2ts/module/position/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .additive import LearnedEmbedding, SinusoidalPositionEncoding 17 | from .attn_bias import ( 18 | AttentionBias, 19 | BinaryAttentionBias, 20 | LinearAttentionBias, 21 | RelativeAttentionBias, 22 | ) 23 | from .attn_projection import ( 24 | IdentityProjection, 25 | LearnedProjection, 26 | Projection, 27 | QueryKeyProjection, 28 | RotaryProjection, 29 | ) 30 | 31 | __all__ = [ 32 | "AttentionBias", 33 | "IdentityProjection", 34 | "RelativeAttentionBias", 35 | "BinaryAttentionBias", 36 | "LearnedEmbedding", 37 | "LearnedProjection", 38 | "LinearAttentionBias", 39 | "Projection", 40 | "QueryKeyProjection", 41 | "RotaryProjection", 42 | "SinusoidalPositionEncoding", 43 | ] 44 | -------------------------------------------------------------------------------- /src/uni2ts/module/position/additive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 
2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import math 17 | 18 | import torch 19 | from jaxtyping import Float, Int 20 | from torch import nn 21 | 22 | 23 | class SinusoidalPositionEncoding(nn.Module): 24 | def __init__( 25 | self, 26 | *, 27 | width: int, 28 | max_len: int, 29 | normalize: bool = True, 30 | ): 31 | """ 32 | Construct a sinusoidal positional embedding module. 33 | 34 | :param width: 35 | Width of the embedding. 36 | :param max_len: 37 | Maximum length of the embedding. 38 | :param normalize: 39 | Perform L2 normalization of the embedding. 40 | """ 41 | super().__init__() 42 | 43 | position = torch.arange(max_len).unsqueeze(1) 44 | div_term = torch.exp(torch.arange(0, width, 2) * (-math.log(10000.0) / width)) 45 | 46 | pe = torch.zeros(max_len, width) 47 | pe[:, 0::2] = torch.sin(position * div_term) 48 | pe[:, 1::2] = torch.cos(position * div_term) 49 | 50 | if normalize: 51 | l2 = torch.linalg.vector_norm(pe, dim=-1) 52 | pe /= l2.unsqueeze(-1) 53 | 54 | self.register_buffer("pe", pe, persistent=False) 55 | 56 | def forward( 57 | self, pos_id: Int[torch.Tensor, "*batch length"] 58 | ) -> Float[torch.Tensor, "*batch length dim"]: 59 | return self.pe[pos_id] 60 | 61 | 62 | class LearnedEmbedding(nn.Module): 63 | def __init__( 64 | self, 65 | *, 66 | width: int, 67 | max_len: int, 68 | ): 69 | super().__init__() 70 | self.pe = nn.Embedding( 71 | max_len, 72 | width, 73 | ) 74 | 75 | def forward( 76 | self, pos_id: Int[torch.Tensor, "*batch length"] 77 | ) -> Float[torch.Tensor, "*batch length dim"]: 78 | return self.pe(pos_id) 79 | -------------------------------------------------------------------------------- /src/uni2ts/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import SchedulerType, get_scheduler 2 | 3 | __all__ = ["SchedulerType", "get_scheduler"] 4 | -------------------------------------------------------------------------------- /src/uni2ts/transform/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
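A quick sanity check on `SinusoidalPositionEncoding` above: the buffer holds `max_len` rows of interleaved sin/cos features, indexing with an integer `pos_id` tensor is just a table lookup, and `normalize=True` rescales each row to unit L2 norm:

```python
import torch

from uni2ts.module.position import SinusoidalPositionEncoding

pe = SinusoidalPositionEncoding(width=64, max_len=512, normalize=True)
pos_id = torch.arange(10).unsqueeze(0)               # (batch=1, length=10)
emb = pe(pos_id)                                     # (1, 10, 64)
print(emb.shape)
print(torch.linalg.vector_norm(emb, dim=-1)[0, :3])  # ~1.0 per position
```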
15 | 16 | from ._base import Chain, Identity, Transformation 17 | from .crop import EvalCrop, PatchCrop 18 | from .feature import AddObservedMask, AddTimeIndex, AddVariateIndex 19 | from .field import LambdaSetFieldIfNotPresent, RemoveFields, SelectFields, SetValue 20 | from .imputation import DummyValueImputation, ImputeTimeSeries, LastValueImputation 21 | from .pad import EvalPad, Pad, PadFreq 22 | from .patch import ( 23 | DefaultPatchSizeConstraints, 24 | FixedPatchSizeConstraints, 25 | GetPatchSize, 26 | Patchify, 27 | PatchSizeConstraints, 28 | ) 29 | from .resample import SampleDimension 30 | from .reshape import ( 31 | FlatPackCollection, 32 | FlatPackFields, 33 | PackCollection, 34 | PackFields, 35 | SequencifyField, 36 | Transpose, 37 | ) 38 | from .task import EvalMaskedPrediction, ExtendMask, MaskedPrediction 39 | 40 | __all__ = [ 41 | "AddObservedMask", 42 | "AddTimeIndex", 43 | "AddVariateIndex", 44 | "Chain", 45 | "DefaultPatchSizeConstraints", 46 | "DummyValueImputation", 47 | "EvalCrop", 48 | "EvalMaskedPrediction", 49 | "EvalPad", 50 | "ExtendMask", 51 | "FixedPatchSizeConstraints", 52 | "FlatPackCollection", 53 | "FlatPackFields", 54 | "GetPatchSize", 55 | "Identity", 56 | "ImputeTimeSeries", 57 | "LambdaSetFieldIfNotPresent", 58 | "LastValueImputation", 59 | "MaskedPrediction", 60 | "PackCollection", 61 | "PackFields", 62 | "Pad", 63 | "PadFreq", 64 | "PatchCrop", 65 | "PatchSizeConstraints", 66 | "Patchify", 67 | "RemoveFields", 68 | "SampleDimension", 69 | "SelectFields", 70 | "SequencifyField", 71 | "SetValue", 72 | "Transformation", 73 | "Transpose", 74 | ] 75 | -------------------------------------------------------------------------------- /src/uni2ts/transform/_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import abc 17 | from dataclasses import dataclass 18 | from typing import Any 19 | 20 | 21 | class Transformation(abc.ABC): 22 | @abc.abstractmethod 23 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: ... 24 | 25 | def chain(self, other: "Transformation") -> "Chain": 26 | return Chain([self, other]) 27 | 28 | def __add__(self, other: "Transformation") -> "Chain": 29 | return self.chain(other) 30 | 31 | def __radd__(self, other): 32 | if other == 0: 33 | return self 34 | return other + self 35 | 36 | 37 | @dataclass 38 | class Chain(Transformation): 39 | """ 40 | Chain multiple transformations together. 
41 | """ 42 | 43 | transformations: list[Transformation] 44 | 45 | def __post_init__(self) -> None: 46 | transformations = [] 47 | 48 | for transformation in self.transformations: 49 | if isinstance(transformation, Identity): 50 | continue 51 | elif isinstance(transformation, Chain): 52 | transformations.extend(transformation.transformations) 53 | else: 54 | assert isinstance(transformation, Transformation) 55 | transformations.append(transformation) 56 | 57 | self.transformations = transformations 58 | self.__init_passed_kwargs__ = {"transformations": transformations} 59 | 60 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 61 | for t in self.transformations: 62 | data_entry = t(data_entry) 63 | return data_entry 64 | 65 | 66 | class Identity(Transformation): 67 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 68 | return data_entry 69 | -------------------------------------------------------------------------------- /src/uni2ts/transform/field.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from collections.abc import Callable 17 | from dataclasses import dataclass 18 | from typing import Any 19 | 20 | from ._base import Transformation 21 | 22 | 23 | @dataclass 24 | class SetValue: 25 | value: Any 26 | 27 | def __call__(self, data_entry: dict[str, Any]) -> Any: 28 | return self.value 29 | 30 | 31 | @dataclass 32 | class LambdaSetFieldIfNotPresent(Transformation): 33 | field: str 34 | get_value: Callable[[dict[str, Any]], Any] 35 | 36 | @staticmethod 37 | def set_field(data_entry: dict[str, Any], field: str, value: Any) -> dict[str, Any]: 38 | if field not in data_entry.keys(): 39 | data_entry[field] = value 40 | return data_entry 41 | 42 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 43 | return self.set_field(data_entry, self.field, self.get_value(data_entry)) 44 | 45 | 46 | @dataclass 47 | class SelectFields(Transformation): 48 | fields: list[str] 49 | allow_missing: bool = False 50 | 51 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 52 | if self.allow_missing: 53 | return {f: data_entry[f] for f in self.fields if f in data_entry} 54 | return {f: data_entry[f] for f in self.fields} 55 | 56 | 57 | @dataclass 58 | class RemoveFields(Transformation): 59 | fields: list[str] 60 | 61 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 62 | for k in self.fields: 63 | data_entry.pop(k, None) 64 | return data_entry 65 | -------------------------------------------------------------------------------- /src/uni2ts/transform/imputation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 
2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass 17 | from typing import Any 18 | 19 | import numpy as np 20 | from jaxtyping import Num 21 | 22 | from ._base import Transformation 23 | from ._mixin import ApplyFuncMixin 24 | 25 | 26 | class ImputationMethod: 27 | def __call__( 28 | self, x: Num[np.ndarray, "length *dim"] 29 | ) -> Num[np.ndarray, "length *dim"]: ... 30 | 31 | 32 | @dataclass(frozen=True) 33 | class DummyValueImputation(ImputationMethod): 34 | value: int | float | complex = 0.0 35 | 36 | def __call__( 37 | self, x: Num[np.ndarray, "length *dim"] 38 | ) -> Num[np.ndarray, "length *dim"]: 39 | x[np.isnan(x)] = self.value 40 | return x 41 | 42 | 43 | @dataclass(frozen=True) 44 | class LastValueImputation(ImputationMethod): 45 | value: int | float | complex = 0.0 46 | 47 | def __call__( 48 | self, x: Num[np.ndarray, "length *dim"] 49 | ) -> Num[np.ndarray, "length *dim"]: 50 | x = x.T 51 | x[0:1][np.isnan(x[0:1])] = self.value 52 | mask = np.isnan(x) 53 | idx = np.arange(len(x)) 54 | if x.ndim == 2: 55 | idx = np.expand_dims(idx, axis=1) 56 | idx = np.where(~mask, idx, 0) 57 | idx = np.maximum.accumulate(idx, axis=0) 58 | if x.ndim == 2: 59 | x = x[idx, np.arange(x.shape[1])] 60 | else: 61 | x = x[idx] 62 | return x.T 63 | 64 | 65 | class CausalMeanImputation(ImputationMethod): 66 | # TODO: implement causal mean imputation 67 | def __call__( 68 | self, x: Num[np.ndarray, "length *dim"], value: int | float | complex = 0.0 69 | ) -> Num[np.ndarray, "length *dim"]: ... 70 | 71 | 72 | @dataclass 73 | class ImputeTimeSeries(ApplyFuncMixin, Transformation): 74 | fields: tuple[str, ...] 75 | optional_fields: tuple[str, ...] = tuple() 76 | imputation_method: ImputationMethod = DummyValueImputation(value=0.0) 77 | 78 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 79 | self.apply_func( 80 | self._impute, 81 | data_entry, 82 | self.fields, 83 | optional_fields=self.optional_fields, 84 | ) 85 | return data_entry 86 | 87 | def _impute(self, data_entry: dict[str, Any], field: str): 88 | value = data_entry[field] 89 | nan_entries = np.isnan(value) 90 | if nan_entries.any(): 91 | data_entry[field] = self.imputation_method(value) 92 | -------------------------------------------------------------------------------- /src/uni2ts/transform/resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Salesforce, Inc. 2 | # SPDX-License-Identifier: Apache-2 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass 17 | from functools import partial 18 | from typing import Any 19 | 20 | import numpy as np 21 | 22 | from uni2ts.common.sampler import Sampler, get_sampler 23 | from uni2ts.common.typing import UnivarTimeSeries 24 | 25 | from ._base import Transformation 26 | from ._mixin import CheckArrNDimMixin, CollectFuncMixin, MapFuncMixin 27 | 28 | 29 | @dataclass 30 | class SampleDimension( 31 | CheckArrNDimMixin, CollectFuncMixin, MapFuncMixin, Transformation 32 | ): 33 | max_dim: int 34 | fields: tuple[str, ...] 35 | optional_fields: tuple[str, ...] = tuple() 36 | sampler: Sampler = get_sampler("uniform") 37 | 38 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 39 | total_field_dim = sum( 40 | self.collect_func_list( 41 | self._get_dim, 42 | data_entry, 43 | self.fields, 44 | optional_fields=self.optional_fields, 45 | ) 46 | ) 47 | self.map_func( 48 | partial(self._process, total_field_dim=total_field_dim), # noqa 49 | data_entry, 50 | self.fields, 51 | optional_fields=self.optional_fields, 52 | ) 53 | return data_entry 54 | 55 | def _get_dim(self, data_entry: dict[str, Any], field: str) -> int: 56 | self.check_ndim(field, data_entry[field], 2) 57 | return len(data_entry[field]) 58 | 59 | def _process( 60 | self, data_entry: dict[str, Any], field: str, total_field_dim: int 61 | ) -> list[UnivarTimeSeries]: 62 | arr: list[UnivarTimeSeries] = data_entry[field] 63 | rand_idx = np.random.permutation(len(arr)) 64 | field_max_dim = (self.max_dim * len(arr)) // total_field_dim 65 | n = self.sampler(min(len(arr), field_max_dim)) 66 | return [arr[idx] for idx in rand_idx[:n]] 67 | 68 | 69 | @dataclass 70 | class Subsample(Transformation): # just take every n-th element 71 | fields: tuple[str, ...] 
= ("target", "past_feat_dynamic_real") 72 | 73 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 74 | pass 75 | 76 | 77 | class GaussianFilterSubsample( 78 | Subsample 79 | ): # blur using gaussian filter before subsampling 80 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 81 | # gaussian filter 82 | return super()(data_entry) 83 | 84 | 85 | class Downsample(Transformation): # aggregate 86 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 87 | pass 88 | 89 | 90 | class Upsample(Transformation): 91 | def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: 92 | pass 93 | -------------------------------------------------------------------------------- /transform_ILI.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | if __name__ == "__main__": 4 | df = pd.read_csv("data/Flu_USA/ILINet.csv") 5 | 6 | df = df[df["REGION TYPE"] == "National"] 7 | df = df[df["WEEK"] != 53] 8 | 9 | df["date"] = pd.to_datetime( 10 | df["YEAR"].astype(str) + "-W" + df["WEEK"].astype(str) + "-1", 11 | format="%Y-W%U-%w", 12 | ) 13 | 14 | result = [] 15 | for i in range(len(df) - 1): 16 | result.append(df.iloc[i]) 17 | 18 | if (df.iloc[i + 1]["date"] - df.iloc[i]["date"]).days == 14: 19 | new_row = df.iloc[i].copy() 20 | new_row["date"] = df.iloc[i]["date"] + pd.Timedelta(days=7) 21 | result.append(new_row) 22 | 23 | result.append(df.iloc[-1]) 24 | df = pd.DataFrame(result) 25 | 26 | df = df.drop(columns=["YEAR", "WEEK"]) 27 | gaps = df["date"].diff().dropna().unique() 28 | print("Unique time intervals:", gaps) 29 | df["time_diff"] = df["date"].diff() 30 | 31 | rows_with_14_days = df[df["time_diff"] == pd.Timedelta(days=14)] 32 | print(rows_with_14_days) 33 | df = df.drop(columns=["time_diff"]) 34 | infered_freq = pd.infer_freq(df["date"]) 35 | print(f"Infered frequency: {infered_freq}") 36 | 37 | df.to_csv("data/Flu_USA/Flu_USA.csv", index=False) 38 | 39 | print("Data saved to output.csv") 40 | -------------------------------------------------------------------------------- /transform_monash.py: -------------------------------------------------------------------------------- 1 | from samay.utils import arrow_to_csv 2 | import os 3 | import pandas as pd 4 | 5 | FREQS = { 6 | "weather": "1D", 7 | "tourism_yearly": "1YE", 8 | "tourism_quarterly": "1Q", 9 | "tourism_monthly": "1M", 10 | "cif_2016": "1M", 11 | "london_smart_meters": "30min", 12 | "australian_electricity_demand": "30min", 13 | "wind_farms_minutely": "1min", 14 | "bitcoin": "1D", 15 | "pedestrian_counts": "1h", 16 | "vehicle_trips": "1D", 17 | "kdd_cup_2018": "1H", 18 | "nn5_daily": "1D", 19 | "nn5_weekly": "1W", 20 | "kaggle_web_traffic": "1D", 21 | "kaggle_web_traffic_weekly": "1W", 22 | "solar_10_minutes": "10min", 23 | "solar_weekly": "1W", 24 | "car_parts": "1M", 25 | "fred_md": "1M", 26 | "traffic_hourly": "1h", 27 | "traffic_weekly": "1W", 28 | "hospital": "1M", 29 | "covid_deaths": "1D", 30 | "sunspot": "1D", 31 | "saugeenday": "1D", 32 | "us_births": "1D", 33 | "solar_4_seconds": "4s", 34 | "wind_4_seconds": "4s", 35 | "rideshare": "1h", 36 | "oikolab_weather": "1h", 37 | "temperature_rain": "1D" 38 | } 39 | 40 | 41 | if __name__ == "__main__": 42 | monash_dir = "data/monash" 43 | dataset_list = os.listdir(monash_dir) 44 | splits = ["train", "validation", "test"] 45 | for dataset in dataset_list: 46 | # if not dataset in ["rideshare"]: 47 | # continue 48 | print(f"Converting {dataset} dataset") 49 | for split in splits: 
50 | arrow_dir = os.path.join(monash_dir, dataset, split) 51 | freq = FREQS[dataset] 52 | if not os.path.exists(os.path.join(monash_dir, dataset, split + "/data.csv")): 53 | if freq == "1YE": 54 | freq = "1Y" 55 | arrow_to_csv(arrow_dir, freq) 56 | csv_file = os.path.join(monash_dir, dataset, split + "/data.csv") 57 | df = pd.read_csv(csv_file) 58 | # fill missing values with 0 59 | df.fillna(0, inplace=True) 60 | df.to_csv(csv_file, index=False) 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | --------------------------------------------------------------------------------
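As a side note on `transform_ILI.py` above: its explicit gap-filling loop can also be written with pandas resampling, which forward-fills any number of missing weeks in one pass. A sketch on toy data (the `ILITOTAL` column name is illustrative, not taken from the script):

```python
import pandas as pd

df = pd.DataFrame(
    {"date": pd.to_datetime(["2021-01-04", "2021-01-18"]), "ILITOTAL": [100, 120]}
)
# weekly bins anchored on Mondays; the missing 2021-01-11 week is forward-filled
weekly = df.set_index("date").resample("W-MON").ffill().reset_index()
print(weekly)
```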