├── .circleci └── config.yml ├── .flake8 ├── .github ├── dependabot.yml └── workflows │ ├── codesee-arch-diagram.yml │ └── python-publish.yml ├── .gitignore ├── .idea └── workspace.xml ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── basic_ae.rst │ ├── basic_utils.rst │ ├── buil_dataset.rst │ ├── closest_station.rst │ ├── conf.py │ ├── crossformer.rst │ ├── custom_opt.rst │ ├── custom_types.rst │ ├── d_linear.rst │ ├── data_converter.rst │ ├── dummy_torch.rst │ ├── evaluator.rst │ ├── explain_model_output.rst │ ├── index.rst │ ├── inference.rst │ ├── informer.rst │ ├── interpolate_preprocess.rst │ ├── itransformer.rst │ ├── long_train.rst │ ├── lower_upper_config.rst │ ├── meta_models.rst │ ├── model.rst │ ├── model_dict_function.rst │ ├── modules.rst │ ├── multi_head_base.rst │ ├── pre_dict.rst │ ├── preprocess_da_rnn.rst │ ├── preprocess_metadata.rst │ ├── process_usgs.rst │ ├── pytorch_loaders.rst │ ├── pytorch_training.rst │ ├── temporal_feats.rst │ ├── time_model.rst │ ├── train_da.rst │ ├── trainer.rst │ ├── training_utils.rst │ ├── transformer_basic.rst │ ├── transformer_bottleneck.rst │ ├── transformer_xl.rst │ ├── utils.rst │ └── utils_da.rst ├── flood_forecast ├── basic │ ├── base_line_methods.py │ ├── d_n_linear.py │ ├── gru_vanilla.py │ ├── linear_regression.py │ └── lstm_vanilla.py ├── custom │ ├── custom_activation.py │ ├── custom_opt.py │ ├── dilate_loss.py │ └── focal_loss.py ├── da_rnn │ ├── README.md │ ├── checkpoint │ │ ├── .gitkeep │ │ ├── decoder.pth │ │ └── encoder.pth │ ├── config │ │ ├── dec_kwargs.json │ │ └── enc_kwargs.json │ ├── constants.py │ ├── custom_types.py │ ├── main_predict.py │ ├── model.py │ ├── modules.py │ ├── train_da.py │ └── utils.py ├── data_analysis │ └── Flood Severity Info.ipynb ├── deployment │ └── inference.py ├── evaluator.py ├── explain_model_output.py ├── gcp_integration │ └── basic_utils.py ├── long_train.py ├── meta_models │ ├── basic_ae.py │ └── merging_model.py ├── meta_train.py ├── model_dict_function.py ├── multi_models │ └── crossvivit.py ├── plot_functions.py ├── pre_dict.py ├── preprocessing │ ├── buil_dataset.py │ ├── closest_station.py │ ├── data_converter.py │ ├── eco_gage_set.py │ ├── interpolate_preprocess.py │ ├── preprocess_da_rnn.py │ ├── preprocess_metadata.py │ ├── process_usgs.py │ ├── pytorch_loaders.py │ ├── temporal_feats.py │ └── virus_dataset │ │ └── .gitkeep ├── pytorch_training.py ├── series_id_helper.py ├── temporal_decoding.py ├── time_model.py ├── trainer.py ├── training_utils.py ├── transformer_xl │ ├── anomaly_transformer.py │ ├── attn.py │ ├── basis_former.py │ ├── cross_former.py │ ├── data_embedding.py │ ├── dsanet.py │ ├── dummy_torch.py │ ├── informer.py │ ├── itransformer.py │ ├── lower_upper_config.py │ ├── masks.py │ ├── multi_head_base.py │ ├── transformer_basic.py │ ├── transformer_bottleneck.py │ └── transformer_xl.py └── utils.py ├── requirements.txt ├── setup.py └── tests ├── 24_May_202202_25PM_1.json ├── __init__.py ├── auto_encoder.json ├── classification_test.json ├── config.json ├── cross_former.json ├── custom_encode.json ├── da_meta.json ├── da_rnn.json ├── da_rnn_probabilistic.json ├── data_format_tests.py ├── data_loader_tests.py ├── decoder_test.json ├── dlinear.json ├── dsanet.json ├── dsanet_3.json ├── full_transformer.json ├── gru_vanilla.json ├── lstm_probabilistic_test.json ├── lstm_test.json ├── meta_data_test.json ├── meta_data_tests.py ├── multi_config.json ├── 
multi_decoder_test.json ├── multi_modal_tests └── test_cross_vivit.py ├── multi_test.json ├── multitask_decoder.json ├── nlinear.json ├── probabilistic_linear_regression_test.json ├── pytorc_train_tests.py ├── scaling_json.json ├── test2.csv ├── test_attn.py ├── test_classification2_loader.py ├── test_da_rnn.py ├── test_data ├── asos-12N.csv ├── asos-12N_small.csv ├── asos_process.json ├── asos_raw.csv ├── avocad_small.csv ├── big_black_md.json ├── big_black_test_small.csv ├── fake_test_small.csv ├── farm_ex.csv ├── ff_test.csv ├── full_out.json ├── imputation_test.csv ├── keag_small.csv ├── river_test.csv ├── river_test_sm.csv ├── small_test.csv ├── solar_small.csv ├── test2.csv ├── test_asos_raw.csv └── test_format_data.csv ├── test_data2 ├── keag_small.csv └── test_csv.csv ├── test_decoder.py ├── test_deployment.py ├── test_dual.json ├── test_evaluation.py ├── test_explain_model_output.py ├── test_handle_multi_crit.py ├── test_iTransformer.json ├── test_inf_single.json ├── test_informer.json ├── test_informer.py ├── test_init ├── chick_final.csv └── keag_small.csv ├── test_join.py ├── test_loss.py ├── test_merging_models.py ├── test_meta_pr.py ├── test_multitask_decoder.py ├── test_plot.py ├── test_preprocessing.py ├── test_preprocessing_ae.py ├── test_series_id.py ├── test_squashed.py ├── test_variable_length.py ├── time_model_test.py ├── transformer_b_series.json ├── transformer_bottleneck.json ├── transformer_gaussian.json ├── tsmixer_test.json ├── usgs_tests.py ├── validation_loop_test.py └── variable_autoencoderl.json /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max_line_length=122 3 | ignore=E305,W504,E126,E401,E721,F722 4 | max-complexity=20 5 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | open-pull-requests-limit: 10 8 | ignore: 9 | - dependency-name: mpld3 10 | versions: 11 | - 0.5.2 12 | - dependency-name: torchvision 13 | versions: 14 | - 0.8.2 15 | -------------------------------------------------------------------------------- /.github/workflows/codesee-arch-diagram.yml: -------------------------------------------------------------------------------- 1 | # This is v2.0 of this workflow file for codesee 1 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request_target: 7 | types: [opened, synchronize, reopened] 8 | 9 | name: CodeSee 10 | 11 | permissions: read-all 12 | 13 | jobs: 14 | codesee: 15 | runs-on: ubuntu-latest 16 | continue-on-error: true 17 | name: Analyze the repo with CodeSee 18 | steps: 19 | - uses: Codesee-io/codesee-action@v2 20 | with: 21 | codesee-token: ${{ secrets.CODESEE_ARCH_DIAG_API_TOKEN }} 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created (must bump version) 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: 
actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload --skip-existing dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | flood_forecast.egg-info 4 | *.DS_STORE 5 | *.pyc 6 | tests/runs 7 | tests/output/ 8 | .ipynb_checkpoints 9 | .pytest_cache/ 10 | data 11 | mypy 12 | .mypy_cache 13 | *.png 14 | *.svf 15 | *.svg 16 | .idea/flow-forecast.iml 17 | .idea/inspectionProfiles/profiles_settings.xml 18 | .idea/misc.xml 19 | .idea/vcs.xml 20 | .idea/workspace.xml 21 | .idea/workspace.xml 22 | .idea/workspace.xml 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | 9 | - repo: https://github.com/hhatto/autopep8 10 | rev: v2.0.4 11 | hooks: 12 | - id: autopep8 13 | args: [--in-place, --aggressive, --aggressive, --max-line-length=122] 14 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for Sphinx projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | # You can also specify other tool versions: 13 | # nodejs: "20" 14 | # rust: "1.70" 15 | # golang: "1.20" 16 | 17 | # Build documentation in the "docs/" directory with Sphinx 18 | sphinx: 19 | configuration: docs/source/conf.py 20 | # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs 21 | # builder: "dirhtml" 22 | # Fail on all warnings to avoid broken references 23 | # fail_on_warning: true 24 | 25 | # Optionally build your docs in additional formats such as PDF and ePub 26 | # formats: 27 | # - pdf 28 | # - epub 29 | 30 | # Optional but recommended, declare the Python requirements required 31 | # to build your documentation 32 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 33 | python: 34 | install: 35 | - requirements: docs/requirements.txt 36 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
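# For example, "make html" is routed through the catch-all target below and builds the HTML documentation into the build/ directory.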
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | shap==0.47.0 2 | scikit-learn>=1.0.1 3 | pandas 4 | torch 5 | tb-nightly 6 | seaborn 7 | future 8 | h5py 9 | wandb==0.19.3 10 | google-cloud 11 | google-cloud-storage 12 | plotly~=5.24.0 13 | pytz>=2022.1 14 | setuptools~=76.0.0 15 | numpy==1.26 16 | requests 17 | torchvision>=0.6.0 18 | mpld3>=0.5 19 | numba>=0.50 20 | sphinx 21 | sphinx-rtd-theme 22 | sphinx-autodoc-typehints 23 | sphinx 24 | einops 25 | pytorch-tsmixer 26 | einsum 27 | jaxtyping 28 | -------------------------------------------------------------------------------- /docs/source/basic_ae.rst: -------------------------------------------------------------------------------- 1 | Simple Auto Encoder 2 | ================== 3 | 4 | .. automodule:: flood_forecast.meta_models.basic_ae 5 | :members: 6 | 7 | A simple auto-encoder model. 8 | -------------------------------------------------------------------------------- /docs/source/basic_utils.rst: -------------------------------------------------------------------------------- 1 | Basic Google Cloud Platform Utilities 2 | ====================================== 3 | 4 | Flow Forecast natively integrates with Google Cloud Platform. 5 | 6 | .. automodule:: flood_forecast.gcp_integration.basic_utils 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/buil_dataset.rst: -------------------------------------------------------------------------------- 1 | Build FF original Dataset 2 | ========================= 3 | 4 | 5 | .. automodule:: flood_forecast.preprocessing.buil_dataset 6 | :members: 7 | -------------------------------------------------------------------------------- /docs/source/closest_station.rst: -------------------------------------------------------------------------------- 1 | Closest Station 2 | ================ 3 | 4 | .. 
automodule:: flood_forecast.preprocessing.closest_station 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | # sys.path.insert(0, os.path.abspath('.')) 17 | sys.path.insert(0, os.path.abspath('../../')) 18 | # sys.path.insert(0, os.path.abspath('../flood_forecast')) 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'Flow Forecast' 23 | copyright = '2020, Isaac Godfried' 24 | author = 'Isaac Godfried' 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = '0.0.1' 28 | 29 | # -- General configuration --------------------------------------------------- 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = [ 35 | 'sphinx.ext.autodoc', 36 | # 'sphinx.ext.intersphinx', 37 | 'sphinx.ext.todo', 38 | 'sphinx.ext.coverage', 39 | 'sphinx.ext.viewcode', 40 | # 'sphinx_autodoc_typehints', 41 | ] 42 | 43 | # Add any paths that contain templates here, relative to this directory. 44 | templates_path = ['_templates'] 45 | 46 | # List of patterns, relative to source directory, that match files and 47 | # directories to ignore when looking for source files. 48 | # This pattern also affects html_static_path and html_extra_path. 49 | exclude_patterns = [] 50 | 51 | master_doc = 'index' 52 | 53 | # -- Options for HTML output ------------------------------------------------- 54 | 55 | # The theme to use for HTML and HTML Help pages. See the documentation for 56 | # a list of builtin themes. 57 | # 58 | html_theme = 'sphinx_rtd_theme' 59 | 60 | # Add any paths that contain custom static files (such as style sheets) here, 61 | # relative to this directory. They are copied after the builtin static files, 62 | # so a file named "default.css" will overwrite the builtin "default.css". 63 | html_static_path = [] # '_static' 64 | 65 | # Example configuration for intersphinx: refer to the Python standard library. 
66 | # intersphinx_mapping = { 67 | # 'python': ('https://docs.python.org/3', None), 68 | # 'setuptools': ('https://setuptools.readthedocs.io/en/latest/', None), 69 | # } 70 | 71 | # autodoc_member_order = 'bysource' 72 | # autoclass_content = 'both' 73 | 74 | # if os.environ.get('READTHEDOCS', None): 75 | # tags.add('readthedocs') 76 | 77 | # if 'READTHEDOCS' not in os.environ: 78 | # import cython_generated_ext 79 | 80 | autodoc_default_options = { 81 | 'members': True, 82 | 'member-order': 'bysource', 83 | 'special-members': '__init__', 84 | 'undoc-members': True, 85 | 'exclude-members': '__weakref__' 86 | } 87 | -------------------------------------------------------------------------------- /docs/source/crossformer.rst: -------------------------------------------------------------------------------- 1 | Crossformer 2 | ========================= 3 | .. automodule:: flood_forecast.transformer_xl.cross_former 4 | :members: 5 | -------------------------------------------------------------------------------- /docs/source/custom_opt.rst: -------------------------------------------------------------------------------- 1 | Custom Optimizers 2 | ==================== 3 | 4 | .. automodule:: flood_forecast.custom.custom_opt 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/custom_types.rst: -------------------------------------------------------------------------------- 1 | Custom Types 2 | ============= 3 | 4 | .. automodule:: flood_forecast.da_rnn.custom_types 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/d_linear.rst: -------------------------------------------------------------------------------- 1 | D and N Linear 2 | ================== 3 | 4 | .. automodule:: flood_forecast.basic.d_n_linear 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/data_converter.rst: -------------------------------------------------------------------------------- 1 | Data Converter 2 | ============== 3 | This module holds functions to convert data effectively. 4 | 5 | .. automodule:: flood_forecast.preprocessing.data_converter 6 | :members: 7 | -------------------------------------------------------------------------------- /docs/source/dummy_torch.rst: -------------------------------------------------------------------------------- 1 | Dummy Torch Model 2 | ================== 3 | 4 | .. automodule:: flood_forecast.transformer_xl.dummy_torch 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/evaluator.rst: -------------------------------------------------------------------------------- 1 | Model Evaluation 2 | ================= 3 | 4 | .. automodule:: flood_forecast.evaluator 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/explain_model_output.rst: -------------------------------------------------------------------------------- 1 | Explain Model Output 2 | ==================== 3 | 4 | .. automodule:: flood_forecast.explain_model_output 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Flow Forecast documentation master file, created by 2 | sphinx-quickstart on Sun Aug 2 16:20:18 2020. 
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Flow Forecast's documentation! 7 | ========================================= 8 | 9 | Flow Forecast is a deep learning framework for time series forecasting written in PyTorch. Flow Forecast makes it easy to train PyTorch forecasting models on a wide variety 10 | of datasets. This documentation describes the internal Python code that makes up Flow Forecast. 11 | 12 | .. automodule:: flood_forecast 13 | 14 | .. toctree:: 15 | :maxdepth: 2 16 | :caption: General: 17 | 18 | evaluator 19 | long_train 20 | model_dict_function 21 | pre_dict 22 | pytorch_training 23 | time_model 24 | trainer 25 | explain_model_output 26 | utils 27 | 28 | .. automodule:: flood_forecast.deployment 29 | 30 | .. toctree:: 31 | :maxdepth: 2 32 | :caption: Deployment 33 | 34 | inference 35 | 36 | .. automodule:: flood_forecast.preprocessing 37 | 38 | .. toctree:: 39 | :maxdepth: 2 40 | :caption: Preprocessing: 41 | 42 | interpolate_preprocess 43 | buil_dataset 44 | closest_station 45 | data_converter 46 | preprocess_da_rnn 47 | preprocess_metadata 48 | process_usgs 49 | pytorch_loaders 50 | temporal_feats 51 | 52 | .. automodule:: flood_forecast.custom 53 | 54 | .. toctree:: 55 | :maxdepth: 2 56 | :caption: Custom: 57 | 58 | custom_opt 59 | 60 | .. automodule:: flood_forecast.transformer_xl 61 | 62 | .. toctree:: 63 | :maxdepth: 2 64 | :caption: TransformerXL: 65 | 66 | dummy_torch 67 | itransformer 68 | lower_upper_config 69 | multi_head_base 70 | transformer_basic 71 | transformer_xl 72 | transformer_bottleneck 73 | informer 74 | 75 | .. automodule:: flood_forecast.gcp_integration 76 | 77 | .. toctree:: 78 | :maxdepth: 3 79 | :caption: GCP Integration: 80 | 81 | basic_utils 82 | 83 | .. automodule:: flood_forecast.da_rnn 84 | 85 | .. toctree:: 86 | :maxdepth: 3 87 | :caption: DA RNN: 88 | 89 | model 90 | 91 | 92 | Indices and tables 93 | ================== 94 | 95 | * :ref:`genindex` 96 | * :ref:`modindex` 97 | * :ref:`search` 98 | -------------------------------------------------------------------------------- /docs/source/inference.rst: -------------------------------------------------------------------------------- 1 | Inference 2 | ========================= 3 | 4 | This API makes it easy to run inference on trained PyTorchForecast modules. To use this code you 5 | need three main files: your model's configuration file, a CSV containing your data, and a path to 6 | your model weights. 7 | 8 | .. code-block:: python 9 | :caption: example initialization 10 | 11 | import json 12 | from datetime import datetime 13 | from flood_forecast.deployment.inference import InferenceMode 14 | new_water_data_path = "gs://predict_cfs/day_addition/01046000KGNR_flow.csv" 15 | weight_path = "gs://predict_cfs/experiments/10_December_202009_34PM_model.pth" 16 | with open("config.json") as y: 17 | config_test = json.load(y) 18 | infer_model = InferenceMode(336, 30, config_test, new_water_data_path, weight_path, "river") 19 | 20 | .. code-block:: python 21 | :caption: example plotting 22 | 23 | # An illustrative sketch only: the plotting method and its signature are 24 | # assumptions here; see the member documentation below for the exact API. 25 | infer_model.make_plots(datetime(2020, 11, 1)) 26 | 27 | .. automodule:: flood_forecast.deployment.inference 28 | :members: 29 | -------------------------------------------------------------------------------- /docs/source/informer.rst: -------------------------------------------------------------------------------- 1 | Informer 2 | ========================= 3 | 
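An implementation of the Informer model from the paper `Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting <https://arxiv.org/abs/2012.07436>`_ (Zhou et al., 2021). 4 | 5 | 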
.. automodule:: flood_forecast.transformer_xl.informer 6 | :members: 7 | -------------------------------------------------------------------------------- /docs/source/interpolate_preprocess.rst: -------------------------------------------------------------------------------- 1 | Interpolate Preprocessing 2 | ========================= 3 | This module allows easy pre-processing of data. 4 | 5 | .. code-block:: python 6 | :emphasize-lines: 3,5 7 | 8 | import pandas as pd 9 | from flood_forecast.preprocessing.interpolate_preprocess import interpolate_missing_values 10 | df = pd.read_csv("river_data.csv")  # illustrative file name only 11 | # fill the missing values with the nearest available values 12 | df = interpolate_missing_values(df) 13 | 14 | .. automodule:: flood_forecast.preprocessing.interpolate_preprocess 15 | :members: 16 | -------------------------------------------------------------------------------- /docs/source/itransformer.rst: -------------------------------------------------------------------------------- 1 | I-Transformer Model 2 | =================== 3 | 4 | .. automodule:: flood_forecast.transformer_xl.itransformer 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/long_train.rst: -------------------------------------------------------------------------------- 1 | Long Train 2 | =========== 3 | 4 | .. automodule:: flood_forecast.long_train 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/lower_upper_config.rst: -------------------------------------------------------------------------------- 1 | Lower Upper Configuration 2 | ========================= 3 | 4 | .. automodule:: flood_forecast.transformer_xl.lower_upper_config 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/meta_models.rst: -------------------------------------------------------------------------------- 1 | Meta Models for FF 2 | ========================= 3 | 4 | .. automodule:: flood_forecast.meta_models.merging_model 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/model.rst: -------------------------------------------------------------------------------- 1 | Model 2 | ===== 3 | 4 | .. automodule:: flood_forecast.da_rnn.model 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/model_dict_function.rst: -------------------------------------------------------------------------------- 1 | Model Dictionaries 2 | ==================== 3 | 4 | .. automodule:: flood_forecast.model_dict_function 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ======== 3 | 4 | .. automodule:: flood_forecast.da_rnn.modules 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/multi_head_base.rst: -------------------------------------------------------------------------------- 1 | Simple Multi Attention Head Model 2 | ================================== 3 | 4 | .. automodule:: flood_forecast.transformer_xl.multi_head_base 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/pre_dict.rst: -------------------------------------------------------------------------------- 1 | Pre Dictionaries 2 | ================= 3 | 4 | 
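This module holds the dictionaries that map configuration strings (e.g. "StandardScaler" or "back_forward") to the corresponding scaler classes and interpolation functions. 5 | 6 | 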
.. automodule:: flood_forecast.pre_dict 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/preprocess_da_rnn.rst: -------------------------------------------------------------------------------- 1 | Preprocess DA RNN 2 | ================== 3 | 4 | .. automodule:: flood_forecast.preprocessing.preprocess_da_rnn 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/preprocess_metadata.rst: -------------------------------------------------------------------------------- 1 | Preprocess Metadata 2 | ==================== 3 | 4 | .. automodule:: flood_forecast.preprocessing.preprocess_metadata 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/process_usgs.rst: -------------------------------------------------------------------------------- 1 | Process USGS 2 | ============= 3 | 4 | .. automodule:: flood_forecast.preprocessing.process_usgs 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/pytorch_loaders.rst: -------------------------------------------------------------------------------- 1 | PyTorch Loaders 2 | ================ 3 | 4 | .. automodule:: flood_forecast.preprocessing.pytorch_loaders 5 | :members: 6 | :undoc-members: 7 | :inherited-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/source/pytorch_training.rst: -------------------------------------------------------------------------------- 1 | PyTorch Training 2 | ================== 3 | 4 | .. automodule:: flood_forecast.pytorch_training 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/temporal_feats.rst: -------------------------------------------------------------------------------- 1 | Temporal Features 2 | ================== 3 | 4 | .. automodule:: flood_forecast.preprocessing.temporal_feats 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/time_model.rst: -------------------------------------------------------------------------------- 1 | Time Model 2 | ============ 3 | 4 | .. automodule:: flood_forecast.time_model 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/train_da.rst: -------------------------------------------------------------------------------- 1 | Train DA 2 | ========= 3 | 4 | .. automodule:: flood_forecast.da_rnn.train_da 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/trainer.rst: -------------------------------------------------------------------------------- 1 | Trainer 2 | ======== 3 | 4 | .. automodule:: flood_forecast.trainer 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/training_utils.rst: -------------------------------------------------------------------------------- 1 | Training Utils 2 | ================== 3 | 4 | This module includes functions that are useful for training models. 5 | 6 | .. 
automodule:: flood_forecast.training_utils 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/transformer_basic.rst: -------------------------------------------------------------------------------- 1 | Basic Transformer 2 | ================== 3 | 4 | .. automodule:: flood_forecast.transformer_xl.transformer_basic 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/transformer_bottleneck.rst: -------------------------------------------------------------------------------- 1 | Convolutional Transformer 2 | ========================= 3 | 4 | This is an implementation of the model from the paper `Enhancing the Locality and Breaking the Memory Bottleneck of Transformer on Time Series Forecasting <https://arxiv.org/abs/1907.00235>`_. 5 | 6 | .. automodule:: flood_forecast.transformer_xl.transformer_bottleneck 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/transformer_xl.rst: -------------------------------------------------------------------------------- 1 | Transformer XL 2 | ================ 3 | 4 | .. automodule:: flood_forecast.transformer_xl.transformer_xl 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | 4 | .. automodule:: flood_forecast.utils 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/utils_da.rst: -------------------------------------------------------------------------------- 1 | DA RNN Utilities 2 | ================ 3 | 4 | .. automodule:: flood_forecast.da_rnn.utils 5 | :members: 6 | -------------------------------------------------------------------------------- /flood_forecast/basic/base_line_methods.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class NaiveBase(torch.nn.Module): 5 | """A very simple baseline model that returns a fixed value derived from the input sequence. 6 | 7 | No learning is used at all. 8 | """ 9 | 10 | def __init__(self, seq_length: int, n_time_series: int, output_seq_len=1, metric: str = "last"): 11 | super().__init__() 12 | self.forecast_history = seq_length 13 | self.n_time_series = n_time_series 14 | self.initial_layer = torch.nn.Linear(n_time_series, 1) 15 | self.output_layer = torch.nn.Linear(seq_length, output_seq_len) 16 | self.metric_dict = {"last": the_last1} 17 | self.output_seq_len = output_seq_len 18 | self.metric_function = self.metric_dict[metric] 19 | 20 | def forward(self, x: torch.Tensor) -> torch.Tensor: 21 | """Applies the chosen metric function to produce a naive forecast. 22 | 23 | Args: 24 | x (torch.Tensor): A tensor of shape (batch_size, seq_len, n_time_series) 25 | 26 | Returns: 27 | torch.Tensor: A tensor of shape (batch_size, output_seq_len, n_time_series) 28 | """ 29 | return self.metric_function(x, self.output_seq_len) 30 | 31 | 32 | def the_last(index_in_tensor: int, the_tensor: torch.Tensor) -> torch.Tensor: 33 | """Warning: this assumes that the target is the last column. Will return a torch tensor of the proper dim.""" 34 | for batch_num in range(0, the_tensor.shape[0]): 35 | value = the_tensor[batch_num, -1, -1] 36 | the_tensor[batch_num, :, -1] = value 37 | return the_tensor 38 | 39 | 
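# Example (illustrative, not part of the module): with metric="last", a NaiveBase model simply repeats the final observed time step, so NaiveBase(seq_length=36, n_time_series=3, output_seq_len=10)(torch.rand(4, 36, 3)) returns a tensor of shape (4, 10, 3).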
40 | def the_last1(tensor: torch.Tensor, out_len: int) -> torch.Tensor: 41 | """Creates a tensor based on the last element. 42 | 43 | :param tensor: A tensor of dimension (batch_size, seq_len, n_time_series) 44 | :param out_len: The length of the forecast (forecast_length) 45 | :type out_len: int 46 | 47 | :return: Returns a tensor of (batch_size, out_len, n_time_series) 48 | :rtype: torch.Tensor 49 | """ 50 | return tensor[:, -1, :].unsqueeze(0).permute(1, 0, 2).repeat(1, out_len, 1) 51 | -------------------------------------------------------------------------------- /flood_forecast/basic/gru_vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class VanillaGRU(torch.nn.Module): 5 | def __init__(self, n_time_series: int, hidden_dim: int, num_layers: int, n_target: int, dropout: float, 6 | forecast_length=1, use_hidden=False, probabilistic=False): 7 | """Simple GRU to perform deep time series forecasting. 8 | 9 | :param n_time_series: The number of time series present in the data 10 | :type n_time_series: int 11 | :param hidden_dim: The dimension of the GRU hidden state 12 | :type hidden_dim: int 13 | 14 | Note: for probabilistic forecasting n_target must be set to two; multiple actual targets are not supported yet. 15 | """ 16 | super(VanillaGRU, self).__init__() 17 | 18 | # Defining the number of layers and the nodes in each layer 19 | self.layer_dim = num_layers 20 | self.hidden_dim = hidden_dim 21 | self.hidden = None 22 | self.use_hidden = use_hidden 23 | self.forecast_length = forecast_length 24 | self.probablistic = probabilistic 25 | 26 | # GRU layers 27 | self.gru = torch.nn.GRU( 28 | n_time_series, hidden_dim, num_layers, batch_first=True, dropout=dropout 29 | ) 30 | 31 | # Fully connected layer 32 | self.fc = torch.nn.Linear(hidden_dim, n_target) 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | """Forward function for the GRU. 36 | 37 | :param x: A tensor of shape (batch_size, seq_len, n_time_series) 38 | :type x: torch.Tensor 39 | :return: Returns a tensor of shape (batch_size, forecast_length, n_target) or (batch_size, n_target) 40 | :rtype: torch.Tensor 41 | """ 42 | # Initializing hidden state for first input with zeros 43 | if self.hidden is not None and self.use_hidden: 44 | h0 = self.hidden 45 | else: 46 | h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_() 47 | 48 | # Forward propagation by passing in the input and hidden state into the model 49 | out, self.hidden = self.gru(x, h0.detach()) 50 | 51 | # Reshaping the outputs in the shape of (batch_size, seq_length, hidden_size) 52 | # so that it can fit into the fully connected layer 53 | out = out[:, -self.forecast_length:, :] 54 | 55 | # Convert the final state to our desired output shape (batch_size, output_dim) 56 | out = self.fc(out) 57 | if self.probablistic: 58 | mean = out[..., 0][..., None] 59 | std = torch.clamp(out[..., 1][..., None], min=0.01) 60 | y_pred = torch.distributions.Normal(mean, std) 61 | return y_pred 62 | if self.fc.out_features == 1: 63 | return out[:, :, 0] 64 | return out 65 | -------------------------------------------------------------------------------- /flood_forecast/basic/lstm_vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LSTMForecast(torch.nn.Module): 5 | """A very simple baseline LSTM model that returns an output sequence given a multi-dimensional input seq. 6 | 7 | Inspired by the StackOverflow link below. 
8 | https://stackoverflow.com/questions/56858924/multivariate-input-lstm-in-pytorch 9 | """ 10 | 11 | def __init__( 12 | self, 13 | seq_length: int, 14 | n_time_series: int, 15 | output_seq_len=1, 16 | hidden_states: int = 20, 17 | num_layers=2, 18 | bias=True, 19 | batch_size=100, 20 | probabilistic=False): 21 | super().__init__() 22 | self.forecast_history = seq_length 23 | self.n_time_series = n_time_series 24 | self.hidden_dim = hidden_states 25 | self.num_layers = num_layers 26 | self.lstm = torch.nn.LSTM(n_time_series, hidden_states, num_layers, bias, batch_first=True) 27 | self.probabilistic = probabilistic 28 | if self.probabilistic: 29 | output_seq_len = 2 30 | self.final_layer = torch.nn.Linear(seq_length * hidden_states, output_seq_len) 31 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 32 | self.init_hidden(batch_size) 33 | 34 | def init_hidden(self, batch_size: int) -> None: 35 | """Initializes the LSTM hidden state and cell state with zeros. 36 | 37 | :param batch_size: The batch size of the current input 38 | :type batch_size: int 39 | """ 40 | # This is what we'll initialise our hidden state 41 | self.hidden = ( 42 | torch.zeros( 43 | self.num_layers, 44 | batch_size, 45 | self.hidden_dim).to( 46 | self.device), 47 | torch.zeros( 48 | self.num_layers, 49 | batch_size, 50 | self.hidden_dim).to( 51 | self.device)) 52 | 53 | def forward(self, x: torch.Tensor) -> torch.Tensor: 54 | batch_size = x.size()[0] 55 | self.init_hidden(batch_size) 56 | out_x, self.hidden = self.lstm(x, self.hidden) 57 | x = self.final_layer(out_x.contiguous().view(batch_size, -1)) 58 | 59 | if self.probabilistic: 60 | mean = x[..., 0][..., None] 61 | std = torch.clamp(x[..., 1][..., None], min=0.01) 62 | x = torch.distributions.Normal(mean, std) 63 | return x 64 | -------------------------------------------------------------------------------- /flood_forecast/da_rnn/README.md: -------------------------------------------------------------------------------- 1 | This code is taken pretty much verbatim from [Seanny123](https://github.com/Seanny123/da-rnn). It is a simple Dual-Stage Attention-Based Recurrent Neural Net for Time Series Prediction. 
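The original paper is [A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction](https://arxiv.org/abs/1704.02971) (Qin et al., 2017). 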
2 | -------------------------------------------------------------------------------- /flood_forecast/da_rnn/checkpoint/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIStream-Peelout/flow-forecast/9a2af06685db4a635eb21e57d8e522a355f85286/flood_forecast/da_rnn/checkpoint/.gitkeep -------------------------------------------------------------------------------- /flood_forecast/da_rnn/checkpoint/decoder.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIStream-Peelout/flow-forecast/9a2af06685db4a635eb21e57d8e522a355f85286/flood_forecast/da_rnn/checkpoint/decoder.pth -------------------------------------------------------------------------------- /flood_forecast/da_rnn/checkpoint/encoder.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIStream-Peelout/flow-forecast/9a2af06685db4a635eb21e57d8e522a355f85286/flood_forecast/da_rnn/checkpoint/encoder.pth -------------------------------------------------------------------------------- /flood_forecast/da_rnn/config/dec_kwargs.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoder_hidden_size": 64, 3 | "decoder_hidden_size": 64, 4 | "T": 10, 5 | "out_feats": 1 6 | } 7 | -------------------------------------------------------------------------------- /flood_forecast/da_rnn/config/enc_kwargs.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_size": 2, 3 | "hidden_size": 64, 4 | "T": 10 5 | } 6 | -------------------------------------------------------------------------------- /flood_forecast/da_rnn/constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 4 | -------------------------------------------------------------------------------- /flood_forecast/da_rnn/custom_types.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import typing 3 | 4 | import numpy as np 5 | 6 | 7 | class TrainConfig(typing.NamedTuple): 8 | T: int 9 | train_size: int 10 | batch_size: int 11 | loss_func: typing.Callable 12 | 13 | 14 | class TrainData(typing.NamedTuple): 15 | feats: np.ndarray 16 | targs: np.ndarray 17 | 18 | 19 | DaRnnNet = collections.namedtuple("DaRnnNet", ["encoder", "decoder", "enc_opt", "dec_opt"]) 20 | -------------------------------------------------------------------------------- /flood_forecast/da_rnn/main_predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | import numpy as np 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | import joblib  # sklearn.externals.joblib was removed from modern scikit-learn 9 | 10 | from modules import Encoder, Decoder 11 | from utils import numpy_to_tvar 12 | import utils 13 | from custom_types import TrainData 14 | from constants import device 15 | 16 | 17 | def preprocess_data(dat, col_names, scale) -> TrainData: 18 | proc_dat = scale.transform(dat) 19 | 20 | mask = np.ones(proc_dat.shape[1], dtype=bool) 21 | dat_cols = list(dat.columns) 22 | for col_name in col_names: 23 | mask[dat_cols.index(col_name)] = False 24 | 25 | feats = proc_dat[:, mask] 26 | targs = proc_dat[:, ~mask] 27 | 28 | return TrainData(feats, targs) 29 | 30 | 31 | def 
predict(encoder, decoder, t_dat, batch_size: int, T: int) -> np.ndarray: 32 | y_pred = np.zeros((t_dat.feats.shape[0] - T + 1, t_dat.targs.shape[0])) 33 | 34 | for y_i in range(0, len(y_pred), batch_size): 35 | y_slc = slice(y_i, y_i + batch_size) 36 | batch_idx = range(len(y_pred))[y_slc] 37 | b_len = len(batch_idx) 38 | X = np.zeros((b_len, T - 1, t_dat.feats.shape[1])) 39 | y_history = np.zeros((b_len, T - 1, t_dat.targs.shape[0])) 40 | 41 | for b_i, b_idx in enumerate(batch_idx): 42 | idx = range(b_idx, b_idx + T - 1) 43 | 44 | X[b_i, :, :] = t_dat.feats[idx, :] 45 | y_history[b_i, :] = t_dat.targs[idx] 46 | 47 | y_history = numpy_to_tvar(y_history) 48 | _, input_encoded = encoder(numpy_to_tvar(X)) 49 | y_pred[y_slc] = decoder(input_encoded, y_history).cpu().data.numpy() 50 | 51 | return y_pred 52 | 53 | 54 | debug = False 55 | save_plots = False 56 | 57 | with open(os.path.join("data", "enc_kwargs.json"), "r") as fi: 58 | enc_kwargs = json.load(fi) 59 | enc = Encoder(**enc_kwargs) 60 | enc.load_state_dict(torch.load(os.path.join("data", "encoder.torch"), map_location=device)) 61 | 62 | with open(os.path.join("data", "dec_kwargs.json"), "r") as fi: 63 | dec_kwargs = json.load(fi) 64 | dec = Decoder(**dec_kwargs) 65 | dec.load_state_dict(torch.load(os.path.join("data", "decoder.torch"), map_location=device)) 66 | 67 | scaler = joblib.load(os.path.join("data", "scaler.pkl")) 68 | raw_data = pd.read_csv(os.path.join("data", "nasdaq100_padding.csv"), nrows=100 if debug else None) 69 | targ_cols = ("NDX",) 70 | data = preprocess_data(raw_data, targ_cols, scaler) 71 | 72 | with open(os.path.join("data", "da_rnn_kwargs.json"), "r") as fi: 73 | da_rnn_kwargs = json.load(fi) 74 | final_y_pred = predict(enc, dec, data, **da_rnn_kwargs) 75 | 76 | plt.figure() 77 | plt.plot(final_y_pred, label='Predicted') 78 | plt.plot(data.targs[(da_rnn_kwargs["T"] - 1):], label="True") 79 | plt.legend(loc='upper left') 80 | utils.save_or_show_plot("final_predicted_reloaded.png", save_plots) 81 | -------------------------------------------------------------------------------- /flood_forecast/da_rnn/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | from flood_forecast.da_rnn.constants import device 9 | 10 | 11 | def setup_log(tag='VOC_TOPICS'): 12 | # create logger 13 | logger = logging.getLogger(tag) 14 | # logger.handlers = [] 15 | logger.propagate = False 16 | logger.setLevel(logging.DEBUG) 17 | # create console handler and set level to debug 18 | ch = logging.StreamHandler() 19 | ch.setLevel(logging.DEBUG) 20 | # create formatter 21 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 22 | # add formatter to ch 23 | ch.setFormatter(formatter) 24 | # add ch to logger 25 | # logger.handlers = [] 26 | logger.addHandler(ch) 27 | return logger 28 | 29 | 30 | def save_or_show_plot(file_nm: str, save: bool, save_path=""): 31 | if save: 32 | plt.savefig(os.path.join(save_path, file_nm)) 33 | else: 34 | plt.show() 35 | 36 | 37 | def numpy_to_tvar(x: torch.Tensor): 38 | return Variable(torch.from_numpy(x).type(torch.FloatTensor).to(device)) 39 | -------------------------------------------------------------------------------- /flood_forecast/gcp_integration/basic_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from google.cloud 
import storage 3 | from google.oauth2.service_account import Credentials 4 | import os 5 | 6 | 7 | def get_storage_client( 8 | service_key_path: Optional[str] = None, 9 | ) -> storage.Client: 10 | """Utility function to return a properly authenticated GCS storage client whether working in Colab, CircleCI, 11 | Dataverse, or other environments.""" 12 | if service_key_path is None: 13 | if os.environ["ENVIRONMENT_GCP"] == "Colab": 14 | return storage.Client(project=os.environ["GCP_PROJECT"]) 15 | else: 16 | import ast 17 | cred_dict = ast.literal_eval(os.environ["ENVIRONMENT_GCP"]) 18 | credentials = Credentials.from_service_account_info(cred_dict) 19 | return storage.Client(credentials=credentials, project=credentials.project_id) 20 | else: 21 | return storage.Client.from_service_account_json(service_key_path) 22 | 23 | 24 | def upload_file( 25 | bucket_name: str, file_name: str, upload_name: str, client: storage.Client 26 | ): 27 | """A function to upload a file to a GCP bucket. 28 | 29 | :param bucket_name: The name of the bucket 30 | :type bucket_name: str 31 | :param file_name: The name to give the file (blob) in the bucket 32 | :type file_name: str 33 | :param upload_name: The local path of the file to upload 34 | :type upload_name: str 35 | :param client: An authenticated GCS storage client 36 | :type client: storage.Client 37 | """ 38 | bucket = client.get_bucket(bucket_name) 39 | blob = bucket.blob(file_name) 40 | blob.upload_from_filename(upload_name) 41 | 42 | 43 | def download_file( 44 | bucket_name: str, 45 | source_blob_name: str, 46 | destination_file_name: str, 47 | service_key_path: Optional[str] = None, 48 | ) -> None: 49 | """Download data file from GCS. 50 | 51 | Args: 52 | bucket_name ([str]): bucket name on GCS, e.g. task_ts_data 53 | source_blob_name ([str]): storage object name 54 | destination_file_name ([str]): filepath to save to local 55 | """ 56 | storage_client = get_storage_client(service_key_path) 57 | 58 | bucket = storage_client.bucket(bucket_name) 59 | blob = bucket.blob(source_blob_name) 60 | blob.download_to_filename(destination_file_name) 61 | 62 | print( 63 | "Blob {} downloaded to {}.".format( 64 | source_blob_name, destination_file_name 65 | ) 66 | ) 67 | -------------------------------------------------------------------------------- /flood_forecast/long_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | import argparse 5 | import traceback 6 | from flood_forecast.trainer import train_function 7 | from typing import List 8 | 9 | 10 | def split_on_letter(s: str) -> List: 11 | match = re.compile(r"[^\W\d]").search(s) 12 | return [s[:match.start()], s[match.start():]] 13 | 14 | 15 | def loop_through( 16 | data_dir: str, 17 | interrmittent_gcs: bool = False, 18 | use_transfer: bool = True, 19 | start_index: int = 0, 20 | end_index: int = 25) -> None: 21 | """Function that makes and executes a set of config files, training one model per data file. This is needed because there are over 9k files, so they are processed in indexed batches.""" 22 | if not os.path.exists("model_save"): 23 | os.mkdir("model_save") 24 | sorted_dir_list = sorted(os.listdir(data_dir)) 25 | # total = len(sorted_dir_list) 26 | for i in range(start_index, end_index): 27 | file_name = sorted_dir_list[i] 28 | station_id_gage = file_name.split("_flow.csv")[0] 29 | res = split_on_letter(station_id_gage) 30 | gage_id = res[0] 31 | station_id = res[1] 32 | file_path_name = os.path.join(data_dir, file_name) 33 | print("Training on: " + file_path_name) 34 | correct_file = None 35 | if use_transfer and len(os.listdir("model_save")) > 1: 36 | weight_files = 
filter(lambda x: x.endswith(".pth"), os.listdir("model_save")) 37 | paths = [] 38 | for weight_file in weight_files: 39 | paths.append(os.path.join("model_save", weight_file)) 40 | correct_file = max(paths, key=os.path.getctime) 41 | print(correct_file) 42 | config = make_config_file(file_path_name, gage_id, station_id, correct_file) 43 | extension = ".json" 44 | file_name_json = station_id + "config_f" + extension 45 | with open(file_name_json, "w+") as f: 46 | json.dump(config, f) 47 | try: 48 | train_function("PyTorch", config) 49 | except Exception as e: 50 | print("An exception occurred for: " + file_name_json) 51 | traceback.print_exc() 52 | print(e) 53 | 54 | 55 | def make_config_file(flow_file_path: str, gage_id: str, station_id: str, weight_path: str = None): 56 | the_config = { 57 | "model_name": "MultiAttnHeadSimple", 58 | "model_type": "PyTorch", 59 | # "weight_path": "31_December_201906_32AM_model.pth", 60 | "model_params": { 61 | "number_time_series": 3, 62 | "seq_len": 36 63 | }, 64 | "dataset_params": 65 | {"class": "default", 66 | "training_path": flow_file_path, 67 | "validation_path": flow_file_path, 68 | "test_path": flow_file_path, 69 | "batch_size": 20, 70 | "forecast_history": 36, 71 | "forecast_length": 36, 72 | "train_end": 35000, 73 | "valid_start": 35001, 74 | "valid_end": 40000, 75 | "target_col": ["cfs1"], 76 | "relevant_cols": ["cfs1", "precip", "temp"], 77 | "scaler": "StandardScaler" 78 | }, 79 | "training_params": 80 | { 81 | "criterion": "MSE", 82 | "optimizer": "Adam", 83 | "optim_params": { 84 | "lr": 0.001 85 | # Default is lr=0.001 86 | }, 87 | 88 | "epochs": 14, 89 | "batch_size": 20 90 | }, 91 | "GCS": True, 92 | 93 | "wandb": { 94 | "name": "flood_forecast_" + str(gage_id), 95 | "tags": [gage_id, station_id, "MultiAttnHeadSimple", "36", "corrected"] 96 | }, 97 | "forward_params": {} 98 | } 99 | if weight_path: 100 | the_config["weight_path"] = weight_path 101 | # 31_December_201906_12AM_model.pth 102 | return the_config 103 | 104 | 105 | def main(): 106 | parser = argparse.ArgumentParser(description="Argument parsing for training and evaluation") 107 | parser.add_argument("-p", "--path", help="Data path") 108 | args = parser.parse_args() 109 | loop_through(args.path) 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /flood_forecast/meta_models/basic_ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class AE(nn.Module): 6 | def __init__(self, input_shape: int, out_features: int): 7 | """A basic and simple-to-use AutoEncoder. 8 | 9 | :param input_shape: The number of features for input. 10 | :type input_shape: int 11 | :param out_features: The number of output features (that will be merged) 12 | :type out_features: int 13 | """ 14 | super().__init__() 15 | self.encoder_hidden_layer = nn.Linear( 16 | in_features=input_shape, out_features=out_features 17 | ) 18 | self.encoder_output_layer = nn.Linear( 19 | in_features=out_features, out_features=out_features 20 | ) 21 | self.decoder_hidden_layer = nn.Linear( 22 | in_features=out_features, out_features=out_features 23 | ) 24 | self.decoder_output_layer = nn.Linear( 25 | in_features=out_features, out_features=input_shape 26 | ) 27 | 28 | def forward(self, features: torch.Tensor): 29 | """Runs the full forward pass on the model. In practice this will only be done during training. 
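At inference time, the generate_representation method below can be used to obtain just the encoded representation. 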
30 | 31 | :param features: The input tensor of shape (batch_size, input_shape) 32 | :type features: torch.Tensor 33 | :return: The reconstructed tensor of shape (batch_size, input_shape) 34 | :rtype: torch.Tensor 35 | 36 | .. code-block:: python 37 | auto_model = AE(10, 4) 38 | x = torch.rand(2, 10) # batch_size, n_features 39 | result = auto_model(x) 40 | print(result.shape) # (2, 10) 41 | """ 42 | activation = self.encoder_hidden_layer(features) 43 | activation = torch.relu(activation) 44 | code = self.encoder_output_layer(activation) 45 | code = torch.relu(code) 46 | activation = self.decoder_hidden_layer(code) 47 | activation = torch.relu(activation) 48 | activation = self.decoder_output_layer(activation) 49 | reconstructed = torch.relu(activation) 50 | return reconstructed 51 | 52 | def generate_representation(self, features): 53 | activation = self.encoder_hidden_layer(features) 54 | activation = torch.relu(activation) 55 | code = self.encoder_output_layer(activation) 56 | code = torch.relu(code) 57 | return code 58 | -------------------------------------------------------------------------------- /flood_forecast/meta_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Dict 3 | import json 4 | from flood_forecast.pytorch_training import train_transformer_style 5 | from flood_forecast.time_model import PyTorchForecast 6 | 7 | 8 | def train_function(model_type: str, params: Dict) -> PyTorchForecast: 9 | """Function to train meta-data models.""" 10 | params["forward_params"] = {} 11 | dataset_params = params["dataset_params"] 12 | if "forecast_history" not in dataset_params: 13 | dataset_params["forecast_history"] = 1 14 | dataset_params["forecast_length"] = 1 15 | trained_model = PyTorchForecast( 16 | params["model_name"], 17 | dataset_params["training_path"], 18 | dataset_params["validation_path"], 19 | dataset_params["test_path"], 20 | params) 21 | train_transformer_style(trained_model, params["training_params"], params["forward_params"]) 22 | return trained_model 23 | 24 | 25 | def main(): 26 | """Main meta training function which is called from the command line. 27 | 28 | Entrypoint for all AutoEncoder models. 
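A single -p/--params command line argument supplies the path to the JSON model configuration file. 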
29 | """ 30 | parser = argparse.ArgumentParser(description="Argument parsing for model training") 31 | parser.add_argument("-p", "--params", help="Path to the model config file") 32 | args = parser.parse_args() 33 | with open(args.params) as f: 34 | training_config = json.load(f) 35 | train_function(training_config["model_type"], training_config) 36 | print("Meta-training of model is now complete.") 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /flood_forecast/model_dict_function.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.multi_models.crossvivit import RoCrossViViT 2 | from flood_forecast.transformer_xl.multi_head_base import MultiAttnHeadSimple 3 | from flood_forecast.transformer_xl.transformer_basic import SimpleTransformer, CustomTransformerDecoder 4 | from flood_forecast.transformer_xl.informer import Informer 5 | from flood_forecast.transformer_xl.transformer_xl import TransformerXL 6 | from flood_forecast.transformer_xl.dummy_torch import DummyTorchModel 7 | from flood_forecast.basic.linear_regression import SimpleLinearModel 8 | from flood_forecast.basic.lstm_vanilla import LSTMForecast 9 | from flood_forecast.custom.custom_opt import BertAdam, QuantileLoss 10 | from torch.optim import Adam, SGD 11 | from torch.nn import MSELoss, SmoothL1Loss, PoissonNLLLoss, L1Loss, CrossEntropyLoss, BCELoss, BCEWithLogitsLoss 12 | from flood_forecast.basic.linear_regression import simple_decode 13 | from flood_forecast.transformer_xl.transformer_basic import greedy_decode 14 | from flood_forecast.custom.focal_loss import FocalLoss 15 | from flood_forecast.da_rnn.model import DARNN 16 | from flood_forecast.custom.custom_opt import (RMSELoss, MAPELoss, PenalizedMSELoss, NegativeLogLikelihood, MASELoss, 17 | GaussianLoss) 18 | from flood_forecast.transformer_xl.transformer_bottleneck import DecoderTransformer 19 | from flood_forecast.custom.dilate_loss import DilateLoss 20 | from flood_forecast.meta_models.basic_ae import AE 21 | from flood_forecast.transformer_xl.dsanet import DSANet 22 | from flood_forecast.basic.gru_vanilla import VanillaGRU 23 | from flood_forecast.basic.d_n_linear import DLinear, NLinear 24 | from flood_forecast.transformer_xl.itransformer import ITransformer 25 | from flood_forecast.transformer_xl.cross_former import Crossformer as Crossformer10 26 | from torchtsmixer import TSMixer 27 | from torchtsmixer import TSMixerExt 28 | 29 | 30 | """ 31 | Utility dictionaries to map a string to a class in the flood_forecast package. 
32 | """ 33 | pytorch_model_dict = { 34 | "MultiAttnHeadSimple": MultiAttnHeadSimple, 35 | "SimpleTransformer": SimpleTransformer, 36 | "TransformerXL": TransformerXL, 37 | "DummyTorchModel": DummyTorchModel, 38 | "LSTM": LSTMForecast, 39 | "SimpleLinearModel": SimpleLinearModel, 40 | "CustomTransformerDecoder": CustomTransformerDecoder, 41 | "DARNN": DARNN, 42 | "DecoderTransformer": DecoderTransformer, 43 | "BasicAE": AE, 44 | "Informer": Informer, 45 | "DSANet": DSANet, 46 | "VanillaGRU": VanillaGRU, 47 | "DLinear": DLinear, 48 | "Crossformer": Crossformer10, 49 | "NLinear": NLinear, 50 | "TSMixer": TSMixer, 51 | "TSMixerExt": TSMixerExt, 52 | "ITransformer": ITransformer, 53 | "CrossVIVIT": RoCrossViViT, 54 | } 55 | 56 | pytorch_criterion_dict = { 57 | "GaussianLoss": GaussianLoss, 58 | "MASELoss": MASELoss, 59 | "MSE": MSELoss, 60 | "SmoothL1Loss": SmoothL1Loss, 61 | "PoissonNLLLoss": PoissonNLLLoss, 62 | "RMSE": RMSELoss, 63 | "MAPE": MAPELoss, 64 | "DilateLoss": DilateLoss, 65 | "L1": L1Loss, 66 | "PenalizedMSELoss": PenalizedMSELoss, 67 | "CrossEntropyLoss": CrossEntropyLoss, 68 | "NegativeLogLikelihood": NegativeLogLikelihood, 69 | "BCELossLogits": BCEWithLogitsLoss, 70 | "FocalLoss": FocalLoss, 71 | "QuantileLoss": QuantileLoss, 72 | "BinaryCrossEntropy": BCELoss} 73 | 74 | decoding_functions = {"greedy_decode": greedy_decode, "simple_decode": simple_decode} 75 | 76 | pytorch_opt_dict = {"Adam": Adam, "SGD": SGD, "BertAdam": BertAdam} 77 | -------------------------------------------------------------------------------- /flood_forecast/pre_dict.py: -------------------------------------------------------------------------------- 1 | from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler, MaxAbsScaler 2 | from flood_forecast.preprocessing.interpolate_preprocess import (interpolate_missing_values, 3 | back_forward_generic, forward_back_generic) 4 | # SAMMY IS TOO LITTLE TO BE A REAL DOG. 
5 | scaler_dict = {
6 |     "StandardScaler": StandardScaler,
7 |     "RobustScaler": RobustScaler,
8 |     "MinMaxScaler": MinMaxScaler,
9 |     "MaxAbsScaler": MaxAbsScaler}
10 | 
11 | interpolate_dict = {"back_forward": interpolate_missing_values, "back_forward_generic": back_forward_generic,
12 |                     "forward_back_generic": forward_back_generic}
13 | 
-------------------------------------------------------------------------------- /flood_forecast/preprocessing/data_converter.py: --------------------------------------------------------------------------------
1 | """A set of functions aimed at making it easy to convert other time series datasets to our format for transfer
2 | learning purposes."""
3 | 
4 | 
5 | def make_column_names(df):
6 |     num_cols = len(list(df))
7 |     # generate a generic "solar_<i>" name for each of the num_cols columns
8 |     column_arr = []
9 |     for i in range(0, num_cols):
10 |         column_arr.append("solar_" + str(i))
11 | 
12 |     # the new column list has exactly one name per existing column, so the
13 |     # assignment below preserves the dataframe's width
14 |     df.columns = column_arr
15 |     return df
-------------------------------------------------------------------------------- /flood_forecast/preprocessing/interpolate_preprocess.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | from typing import List
3 | 
4 | 
5 | def fix_timezones(df: pd.DataFrame) -> pd.DataFrame:
6 |     """Basic function to fix an initial data bug related to NaN values in non-eastern time zones due to UTC conversion."""
7 |     # Count how many of the first two 'cfs' rows are NaN and drop that many leading rows
8 |     the_count = df[0:2]['cfs'].isna().sum()
9 |     return df[the_count:]
10 | 
11 | 
12 | def interpolate_missing_values(df: pd.DataFrame) -> pd.DataFrame:
13 |     """Function to fill missing values with the nearest value.
14 | 
15 |     Should be run only after splitting on the NaN chunks.
15 | """ 16 | df = fix_timezones(df) 17 | df['cfs1'] = df['cfs'].interpolate(method='nearest').ffill().bfill() 18 | df['precip'] = df['p01m'].interpolate(method='nearest').ffill().bfill() 19 | df['temp'] = df['tmpf'].interpolate(method='nearest').ffill().bfill() 20 | return df 21 | 22 | 23 | def forward_back_generic(df: pd.DataFrame, relevant_columns: List) -> pd.DataFrame: 24 | """Function to fill missing values with nearest value (forward first)""" 25 | for col in relevant_columns: 26 | df[col] = df[col].interpolate(method='nearest').ffill().bfill() 27 | return df 28 | 29 | 30 | def back_forward_generic(df: pd.DataFrame, relevant_columns: List[str]) -> pd.DataFrame: 31 | """Function to fill missing values with nearest values (backward first)""" 32 | for col in relevant_columns: 33 | df[col] = df[col].interpolate(method='nearest').bfill().ffill() 34 | return df 35 | -------------------------------------------------------------------------------- /flood_forecast/preprocessing/preprocess_da_rnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from typing import List 4 | from flood_forecast.da_rnn.custom_types import TrainData 5 | 6 | 7 | def format_data(dat, targ_column: List[str]) -> TrainData: 8 | # Test numpy conversion 9 | proc_dat = dat.to_numpy() 10 | mask = np.ones(proc_dat.shape[1], dtype=bool) 11 | dat_cols = list(dat.columns) 12 | for col_name in targ_column: 13 | mask[dat_cols.index(col_name)] = False 14 | feats = proc_dat[:, mask].astype(float) 15 | targs = proc_dat[:, ~mask].astype(float) 16 | return TrainData(feats, targs) 17 | 18 | 19 | def make_data( 20 | csv_path: str, 21 | target_col: List[str], 22 | test_length: int, 23 | relevant_cols=[ 24 | "cfs", 25 | "temp", 26 | "precip"]) -> TrainData: 27 | """Returns full preprocessed data. 28 | 29 | Does not split train/test that must be done later. 30 | """ 31 | final_df = pd.read_csv(csv_path) 32 | print(final_df.shape[0]) 33 | if len(target_col) > 1: 34 | # Restrict target columns to height and cfs. Alternatively could replace this with loop 35 | height_df = final_df[[target_col[0], target_col[1], 'precip', 'temp']] 36 | height_df.columns = [target_col[0], target_col[1], 'precip', 'temp'] 37 | else: 38 | height_df = final_df[[target_col[0]] + relevant_cols] 39 | preprocessed_data2 = format_data(height_df, target_col) 40 | return preprocessed_data2 41 | -------------------------------------------------------------------------------- /flood_forecast/preprocessing/preprocess_metadata.py: -------------------------------------------------------------------------------- 1 | # import json 2 | import pandas as pd 3 | 4 | 5 | def make_gage_data_csv(file_path: str): 6 | "returns df" 7 | with open(file_path) as f: 8 | df = pd.read_json(f) 9 | df = df.T 10 | df.index.name = "id" 11 | return df 12 | 13 | # todo define this function properly (what is econet?) 
14 | # def make_station_meta(file_path_eco: str, file_path_assos: str):
15 | #     core_columns = econet[['Station', 'Name', 'Latitude', 'Longitude',
16 | #                            'Elevation', 'First Ob', 'Supported By', 'Time Interval(s)', 'Precip']]
17 | 
18 | # todo define this function properly (haversine is not defined)
19 | # def get_closest_gage_list(station_df: pd.DataFrame, gage_df: pd.DataFrame):
20 | #     for row in gage_df.iterrows():
21 | #         gage_info = {}
22 | #         gage_info["river_id"] = row[1]['id']
23 | #         gage_lat = row[1]['latitude']
24 | #         gage_long = row[1]['logitude']
25 | #         gage_info["stations"] = []
26 | #         for stat_row in station_df.iterrows():
27 | #             dist = haversine(stat_row[1]["lon"], stat_row[1]["lat"], gage_long, gage_lat)
28 | #             st_id = stat_row[1]['stid']
29 | #             gage_info["stations"].append({"station_id": st_id, "dist": dist})
30 | #         gage_info["stations"] = sorted(gage_info['stations'], key=lambda i: i["dist"], reverse=True)
31 | #         print(gage_info)
32 | #         with open(str(gage_info["river_id"]) + "stations.json", 'w') as w:
33 | #             json.dump(gage_info, w)
-------------------------------------------------------------------------------- /flood_forecast/preprocessing/process_usgs.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | import requests
3 | from datetime import datetime
4 | from typing import Tuple, Dict
5 | import pytz
6 | # url format
7 | 
8 | 
9 | def make_usgs_data(start_date: datetime, end_date: datetime, site_number: str) -> pd.DataFrame:
10 |     """Downloads stream gage data from the USGS NWIS service and saves it to a CSV file.
11 | 
12 |     :param start_date: The first date of the data pull
13 |     :type start_date: datetime
14 |     :param end_date: The last date of the data pull
15 |     :type end_date: datetime
16 |     :param site_number: The USGS site number of the gage
17 |     :type site_number: str
18 |     :return: A DataFrame of the downloaded flow data
19 |     :rtype: pd.DataFrame
20 |     """
21 |     base_url = "https://nwis.waterdata.usgs.gov/usa/nwis/uv/?cb_00060=on&cb_00065=on&format=rdb&"  # discharge (00060) and gage height (00065)
22 |     full_url = base_url + "site_no=" + site_number + "&period=&begin_date=" + \
23 |         start_date.strftime("%Y-%m-%d") + "&end_date=" + end_date.strftime("%Y-%m-%d")
24 |     print("Getting request from USGS")
25 |     print(full_url)
26 |     r = requests.get(full_url)
27 |     with open(site_number + ".txt", "w") as f:
28 |         f.write(r.text)
29 |     print("Request finished")
30 |     response_data = process_response_text(site_number + ".txt")
31 |     create_csv(response_data[0], response_data[1], site_number)
32 |     return pd.read_csv(site_number + "_flow_data.csv")
33 | 
34 | 
35 | def process_response_text(file_name: str) -> Tuple[str, Dict]:
36 |     extractive_params = {}
37 |     with open(file_name, "r") as f:
38 |         lines = f.readlines()
39 |         i = 0
40 |         params = False
41 |         while "#" in lines[i]:
42 |             # TODO figure out getting height and discharge code efficiently
43 |             the_split_line = lines[i].split()[1:]
44 |             if params:
45 |                 print(the_split_line)
46 |                 if len(the_split_line) < 2:
47 |                     params = False
48 |                 else:
49 |                     extractive_params[the_split_line[0] + "_" + the_split_line[1]] = df_label(the_split_line[2])
50 |             if len(the_split_line) > 2:
51 |                 if the_split_line[0] == "TS":
52 |                     params = True
53 |             i += 1
54 |     with open(file_name.split(".")[0] + "data.tsv", "w") as t:
55 |         t.write("".join(lines[i:]))
56 |     return file_name.split(".")[0] + "data.tsv", extractive_params
57 | 
58 | 
59 | def df_label(usgs_text: str) -> str:
60 |     usgs_text = usgs_text.replace(",", "")
61 |     if usgs_text == "Discharge":
62 |         return "cfs"
63 |     elif usgs_text == "Gage":
64 |         return "height"
65 |     else:
66 |         return usgs_text
67 | 
68 | 
69 | def create_csv(file_path: str, params_names: dict, site_number: str):
70 |     """Function that creates the final version of the CSV file."""
71 |     df = pd.read_csv(file_path, sep="\t")
72 |     for key, value in params_names.items():
73 |         df[value] = df[key]
74 |     df.to_csv(site_number + "_flow_data.csv")
75 | 
76 | 
77 | def get_timezone_map():
78 |     timezone_map = {
79 |         "EST": "America/New_York",
80 |         "EDT": "America/New_York",
81 |         "CST": "America/Chicago",
82 |         "CDT": "America/Chicago",
83 |         "MDT": "America/Denver",
84 |         "MST": "America/Denver",
85 |         "PST": "America/Los_Angeles",
86 |         "PDT": "America/Los_Angeles"}
87 |     return timezone_map
88 | 
89 | 
90 | def process_intermediate_csv(df: pd.DataFrame) -> Tuple[pd.DataFrame, float, float]:
91 |     # Remove garbage first row
92 |     # TODO check if more rows are garbage
93 |     df = df.iloc[1:]
94 |     time_zone = df["tz_cd"].iloc[0]
95 |     time_zone = get_timezone_map()[time_zone]
96 |     old_timezone = pytz.timezone(time_zone)
97 |     new_timezone = pytz.timezone("UTC")
98 |     # This assumes timezones are consistent throughout the USGS stream (this should be true)
99 |     df["datetime"] = df["datetime"].map(lambda x: old_timezone.localize(
100 |         datetime.strptime(x, "%Y-%m-%d %H:%M")).astimezone(new_timezone))
101 |     df["cfs"] = pd.to_numeric(df['cfs'], errors='coerce')
102 |     max_flow = df["cfs"].max()
103 |     min_flow = df["cfs"].min()
104 |     count_nan = len(df["cfs"]) - df["cfs"].count()
105 |     print(f"There are {count_nan} NaN values")
106 |     return df[df.datetime.dt.minute == 0], max_flow, min_flow
-------------------------------------------------------------------------------- /flood_forecast/preprocessing/temporal_feats.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | from typing import Dict
3 | import numpy as np
4 | 
5 | 
6 | def create_feature(key: str, value: str, df: pd.DataFrame, dt_column: str):
7 |     """Function to create a single temporal feature and append it to the dataframe.
8 | 
9 |     :param key: The datetime feature you would like to create from the datetime column.
10 |     :type key: str
11 |     :param value: The type of feature you would like to create (cyclical or numerical)
12 |     :type value: str
13 |     :param df: The Pandas dataframe with the datetime column.
14 |     :type df: pd.DataFrame
15 |     :param dt_column: The name of the datetime column
16 |     :type dt_column: str
17 |     :return: The dataframe with the newly added columns.
18 |     :rtype: pd.DataFrame
19 |     """
20 |     if key == "day_of_week":
21 |         df[key] = df[dt_column].map(lambda x: x.weekday())
22 |     elif key == "minute":
23 |         df[key] = df[dt_column].map(lambda x: x.minute)
24 |     elif key == "hour":
25 |         df[key] = df[dt_column].map(lambda x: x.hour)
26 |     elif key == "day":
27 |         df[key] = df[dt_column].map(lambda x: x.day)
28 |     elif key == "month":
29 |         df[key] = df[dt_column].map(lambda x: x.month)
30 |     elif key == "year":
31 |         df[key] = df[dt_column].map(lambda x: x.year)
32 |     if value == "cyclical":
33 |         df = cyclical(df, key)
34 |     return df
35 | 
36 | 
37 | def feature_fix(preprocess_params: Dict, dt_column: str, df: pd.DataFrame):
38 |     """Adds temporal features.
39 | 
40 |     :param preprocess_params: Dictionary of preprocessing parameters containing a "datetime_params"
41 |         dict, e.g. {"datetime_params": {"day": "numerical"}}
42 |     :type preprocess_params: Dict
43 |     :param dt_column: The name of the datetime column
44 |     :type dt_column: str
45 |     :param df: The dataframe to add the temporal features to
46 |     :type df: pd.DataFrame
47 |     :return: Returns the new data-frame and a list of the new column names
48 |     :rtype: Tuple[pd.DataFrame, List[str]]
49 | 
50 |     .. code-block:: python
51 |         feats_to_add = {"datetime_params": {"month": "cyclical", "day": "numerical"}}
52 |         df, column_names = feature_fix(feats_to_add, "datetime", df)
53 |         print(column_names)  # ["cos_month", "sin_month", "day"]
54 |     """
55 |     print("Running the code to add temporal features")
56 |     column_names = []
57 |     if "datetime_params" in preprocess_params:
58 |         for key, value in preprocess_params["datetime_params"].items():
59 |             df = create_feature(key, value, df, dt_column)
60 |             if value == "cyclical":
61 |                 column_names.append("cos_" + key)
62 |                 column_names.append("sin_" + key)
63 |             else:
64 |                 column_names.append(key)
65 |     return df, column_names
66 | 
67 | 
68 | def cyclical(df: pd.DataFrame, feature_column: str) -> pd.DataFrame:
69 |     """A function to create cyclical encodings for Pandas data-frames.
70 | 
71 |     :param df: A Pandas DataFrame where you want the datetime feature encoded
72 |     :type df: pd.DataFrame
73 |     :param feature_column: The name of the feature column (e.g. day_of_week, hour, day, month, or year)
74 |     :type feature_column: str
75 |     :return: The dataframe with three new columns: norm, cos_<feature>, sin_<feature>
76 |     :rtype: pd.DataFrame
77 |     """
78 |     df["norm"] = 2 * np.pi * df[feature_column] / df[feature_column].max()
79 |     df['cos_' + feature_column] = np.cos(df['norm'])
80 |     df['sin_' + feature_column] = np.sin(df['norm'])
81 |     return df
-------------------------------------------------------------------------------- /flood_forecast/preprocessing/virus_dataset/.gitkeep: https://raw.githubusercontent.com/AIStream-Peelout/flow-forecast/9a2af06685db4a635eb21e57d8e522a355f85286/flood_forecast/preprocessing/virus_dataset/.gitkeep -------------------------------------------------------------------------------- /flood_forecast/series_id_helper.py: --------------------------------------------------------------------------------
1 | import torch
2 | from typing import Dict, List
3 | 
4 | 
5 | def handle_csv_id_output(src: Dict[int, torch.Tensor], trg: Dict[int, torch.Tensor], model, criterion, opt,
6 |                          random_sample: bool = False, n_targs: int = 1):
7 |     """A helper function to better handle the output of models with a series_id and compute the full loss.
8 | 
9 |     :param src: A dictionary of src sequences (partitioned by series_id)
10 |     :type src: Dict[int, torch.Tensor]
11 |     :param trg: A dictionary of target sequences (keyed by series_id)
12 |     :type trg: Dict[int, torch.Tensor]
13 |     :param model: A model that takes both a src and a series_id
14 |     :type model: torch.nn.Module
15 |     :param criterion: The loss function used for each series
16 |     :param opt: The optimizer used to update the model's weights
17 |     :param n_targs: The number of target columns, defaults to 1
18 |     :type n_targs: int, optional
19 |     :return: The loss averaged over all series_ids
20 |     """
21 |     total_loss = 0.00
22 |     for (k, v), (k2, v2) in zip(src.items(), trg.items()):
23 |         output = model.model(v, k)
24 |         loss = criterion(output, v2[:, :, :n_targs])
25 |         total_loss += loss.item()
26 |         loss.backward()
27 |         opt.step()
28 |     total_loss /= len(src.keys())
29 |     return total_loss
30 | 
31 | 
32 | def handle_csv_id_validation(src: Dict[int, torch.Tensor], trg: Dict[int, torch.Tensor], model: torch.nn.Module,
33 |                              criterion: List, random_sample: bool = False, n_targs: int = 1, max_seq_len: int = 100):
34 |     """Function that handles validation of models with a series_id. Returns a dictionary of losses per criterion.
35 | 
36 |     :param src: A dictionary of source sequences keyed by series_id
37 |     :type src: Dict[int, torch.Tensor]
38 |     :param trg: A dictionary of target sequences keyed by series_id
39 |     :type trg: Dict[int, torch.Tensor]
40 |     :param model: The model to validate; takes both a sequence and a series_id
41 |     :type model: torch.nn.Module
42 |     :param criterion: The list of loss functions to evaluate
43 |     :type criterion: List
44 |     :param random_sample: Currently unused, defaults to False
45 |     :type random_sample: bool, optional
46 |     :param n_targs: The number of target columns, defaults to 1
47 |     :type n_targs: int, optional
48 |     :param max_seq_len: The maximum sequence length (currently unused), defaults to 100
49 |     :type max_seq_len: int, optional
50 |     :return: Returns a dictionary of losses for each criterion
51 |     :rtype: Dict[str, float]
52 |     """
53 |     scaled_crit = dict.fromkeys(criterion, 0)
54 |     for (k, v), (k2, v2) in zip(src.items(), trg.items()):
55 |         output = model(v, k)
56 |         for critt in criterion:
57 |             loss = critt(output, v2[:, :, :n_targs])
58 |             scaled_crit[critt] += loss.item()
59 |     return scaled_crit
-------------------------------------------------------------------------------- /flood_forecast/training_utils.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class EarlyStopper(object):
5 |     """EarlyStopping handler that can be used to stop training if no improvement is seen after a given number of events.
6 | 
7 |     Adapted from the ignite ``EarlyStopping`` handler, but driven directly by the
8 |     validation loss through :meth:`check_loss`.
9 | 
10 |     Args:
11 |         patience (int):
12 |             Number of events to wait for an improvement before stopping the training.
13 |         min_delta (float, optional):
14 |             A minimum decrease in the validation loss to qualify as an improvement,
15 |             i.e. a decrease of less than or equal to `min_delta` will count as no improvement.
16 |         cumulative_delta (bool, optional):
17 |             If True, `min_delta` defines a decrease since the last `patience` reset, otherwise,
18 |             it defines a decrease after the last event. Default value is False.
19 |     Examples:
20 |         .. code-block:: python
21 |             stopper = EarlyStopper(patience=10, min_delta=0.01)
22 |             for epoch in range(max_epochs):
23 |                 # validate_model here is a placeholder for your own validation routine
24 |                 validation_loss = validate_model(model)
25 |                 if not stopper.check_loss(model, validation_loss):
26 |                     break
27 |     """
28 | 
29 |     def __init__(
30 |         self,
31 |         patience: int,
32 |         min_delta: float = 0.0,
33 |         cumulative_delta: bool = False,
34 |     ):
35 | 
36 |         if patience < 1:
37 |             raise ValueError("Argument patience should be a positive integer.")
38 | 
39 |         if min_delta < 0.0:
40 |             raise ValueError("Argument min_delta should not be a negative number.")
41 | 
42 |         self.patience = patience
43 |         self.min_delta = min_delta
44 |         self.cumulative_delta = cumulative_delta
45 |         self.counter = 0
46 |         self.best_score = None
47 | 
48 |     def check_loss(self, model, validation_loss) -> bool:
49 |         """Checks the validation loss; returns False when training should stop."""
50 |         score = validation_loss
51 |         if self.best_score is None:
52 |             self.save_model_checkpoint(model)
53 |             self.best_score = score
54 | 
55 |         elif score + self.min_delta >= self.best_score:
56 |             # No improvement of at least min_delta over the best loss so far
57 |             if not self.cumulative_delta and score > self.best_score:
58 |                 self.best_score = score
59 |             self.counter += 1
60 |             print(f"Early stopping counter: {self.counter} out of {self.patience}")
61 |             if self.counter >= self.patience:
62 |                 return False
63 |         else:
64 |             self.save_model_checkpoint(model)
65 |             self.best_score = score
66 |             self.counter = 0
67 |         return True
68 | 
69 |     def save_model_checkpoint(self, model):
70 |         """Saves the current best model weights to ``checkpoint.pth``."""
71 |         torch.save(model.state_dict(), "checkpoint.pth")
-------------------------------------------------------------------------------- /flood_forecast/transformer_xl/anomaly_transformer.py: --------------------------------------------------------------------------------
1 | class AnomalyTransformer():
2 |     """Placeholder for a future Anomaly Transformer implementation."""
3 |     pass
-------------------------------------------------------------------------------- /flood_forecast/transformer_xl/basis_former.py: --------------------------------------------------------------------------------
1 | # TODO: implement BasisFormer
-------------------------------------------------------------------------------- /flood_forecast/transformer_xl/dummy_torch.py: --------------------------------------------------------------------------------
1 | """A small dummy model specifically for unit and integration testing purposes."""
2 | import torch
3 | from torch import nn
4 | 
5 | 
6 | class DummyTorchModel(nn.Module):
7 |     def __init__(self, forecast_length: int) -> None:
8 |         """A dummy model that returns a tensor of ones of shape (batch_size, forecast_length).
9 | 
10 |         :param forecast_length: The length to forecast
11 |         :type forecast_length: int
12 |         """
13 |         super(DummyTorchModel, self).__init__()
14 |         self.out_len = forecast_length
15 |         # Layer included specifically so the model has at least one (non-null) parameter
16 |         self.linear_test_layer = nn.Linear(3, 10)
17 | 
18 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
19 |         """The forward pass for the dummy model.
20 | 
21 |         :param x: The input data; the values themselves are irrelevant, only the batch size is used
22 |         :type x: torch.Tensor
23 |         :return: A tensor of ones of shape (batch_size, forecast_length)
24 |         :rtype: torch.Tensor
25 |         """
26 |         batch_sz = x.size(0)
27 |         result = torch.ones(batch_sz, self.out_len, requires_grad=True, device=x.device)
28 |         return result
-------------------------------------------------------------------------------- /flood_forecast/transformer_xl/lower_upper_config.py: --------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import Dict
3 | import torch.nn.functional as F
4 | import torch
5 | from flood_forecast.custom.custom_activation import entmax15, sparsemax
6 | 
7 | 
8 | def initial_layer(layer_type: str, layer_params: Dict, layer_number: int = 1):
9 |     """Instantiates the initial layer named by layer_type with the given parameters."""
10 |     layer_map = {"1DCon2v": nn.Conv1d, "Linear": nn.Linear}
11 |     return layer_map[layer_type](**layer_params)
12 | 
13 | 
14 | def swish(x):
15 |     """The swish activation function: x * sigmoid(x)."""
16 |     return x * torch.sigmoid(x)
17 | 
18 | 
19 | # NOTE: some entries are already-instantiated modules or plain functions, while
20 | # others (e.g. Softmax) are classes; callers that do activation_dict[name]() expect the latter.
21 | activation_dict = {"ReLU": torch.nn.ReLU(), "Softplus": torch.nn.Softplus(), "Swish": swish,
22 |                    "entmax": entmax15, "sparsemax": sparsemax, "Softmax": torch.nn.Softmax}
23 | 
24 | 
25 | def variable_forecast_layer(layer_type, layer_params):
26 |     """Instantiates the final forecast layer named by layer_type with the given parameters."""
27 |     final_layer_map = {"Linear": nn.Linear, "PositionWiseFeedForward": PositionwiseFeedForward}
28 |     return final_layer_map[layer_type](**layer_params)
29 | 
30 | 
31 | class PositionwiseFeedForward(nn.Module):
32 |     """A two-feed-forward-layer module, taken from the DSANet repo."""
33 | 
34 |     def __init__(self, d_in, d_hid, dropout=0.1):
35 |         super().__init__()
36 |         self.w_1 = nn.Conv1d(d_in, d_hid, 1)
37 |         self.w_2 = nn.Conv1d(d_hid, d_in, 1)
38 |         self.layer_norm = nn.LayerNorm(d_in)
39 |         self.dropout = nn.Dropout(dropout)
40 | 
41 |     def forward(self, x):
42 |         residual = x
43 |         output = x.transpose(1, 2)
44 |         output = self.w_2(F.relu(self.w_1(output)))
45 |         output = output.transpose(1, 2)
46 |         output = self.dropout(output)
47 |         output = self.layer_norm(output + residual)
48 |         return output
49 | 
50 | 
51 | class AR(nn.Module):
52 |     """A simple auto-regressive (linear) layer applied across the time window."""
53 | 
54 |     def __init__(self, window):
55 |         super(AR, self).__init__()
56 |         self.linear = nn.Linear(window, 1)
57 | 
58 |     def forward(self, x):
59 |         x = torch.transpose(x, 1, 2)
60 |         x = self.linear(x)
61 |         x = torch.transpose(x, 1, 2)
62 |         return x
-------------------------------------------------------------------------------- /flood_forecast/transformer_xl/masks.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
5 |     """Generates a square mask for the sequence.
6 | 
7 |     The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
8 |     """
9 |     mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
10 |     mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
11 |     return mask
12 | 
13 | 
14 | class TriangularCausalMask(object):
15 |     def __init__(self, batch_size: int, seq_len: int, device="cpu"):
16 |         """This is a mask for the attention mechanism.
17 | 
18 |         :param batch_size: The batch size, which should be passed on init
19 |         :type batch_size: int
20 |         :param seq_len: Number of historical time steps.
21 |         :type seq_len: int
22 |         :param device: The device, typically cpu, cuda, or tpu, defaults to "cpu"
23 |         :type device: str, optional
24 |         """
25 |         mask_shape = [batch_size, 1, seq_len, seq_len]
26 |         with torch.no_grad():
27 |             self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
28 | 
29 |     @property
30 |     def mask(self):
31 |         return self._mask
32 | 
33 | 
34 | class ProbMask(object):
35 |     def __init__(self, B: int, H, L, index, scores, device="cpu"):
36 |         """Creates a probabilistic attention mask (as used by the Informer model).
37 | 
38 |         :param B: The batch size
39 |         :type B: int
40 |         :param H: Number of heads
41 |         :type H: int
42 |         :param L: Sequence length
43 |         :type L: int
44 |         :param index: The indices of the selected (sampled) queries
45 |         :param scores: The attention scores, used to determine the shape of the output mask
46 |         :param device: The device to create the mask on, defaults to "cpu"
47 |         :type device: str, optional
48 |         """
49 |         _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
50 |         _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
51 |         indicator = _mask_ex[torch.arange(B)[:, None, None],
52 |                              torch.arange(H)[None, :, None],
53 |                              index, :].to(device)
54 |         self._mask = indicator.view(scores.shape).to(device)
55 | 
56 |     @property
57 |     def mask(self):
58 |         return self._mask
-------------------------------------------------------------------------------- /flood_forecast/transformer_xl/multi_head_base.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.activation import MultiheadAttention
3 | from flood_forecast.transformer_xl.lower_upper_config import activation_dict
4 | from flood_forecast.transformer_xl.transformer_basic import SimplePositionalEncoding
5 | 
6 | 
7 | class MultiAttnHeadSimple(torch.nn.Module):
8 |     """A simple multi-head attention model inspired by Vaswani et al."""
9 | 
10 |     def __init__(
11 |             self,
12 |             number_time_series: int,
13 |             seq_len=10,
14 |             output_seq_len=None,
15 |             d_model=128,
16 |             num_heads=8,
17 |             dropout=0.1,
18 |             output_dim=1,
19 |             final_layer=False):
20 | 
21 |         super().__init__()
22 |         self.dense_shape = torch.nn.Linear(number_time_series, d_model)
23 |         self.pe = SimplePositionalEncoding(d_model)
24 |         self.multi_attn = MultiheadAttention(
25 |             embed_dim=d_model, num_heads=num_heads, dropout=dropout)
26 |         self.final_layer = torch.nn.Linear(d_model, output_dim)
27 |         self.length_data = seq_len
28 |         self.forecast_length = output_seq_len
29 |         self.sigmoid = None
30 |         self.output_dim = output_dim
31 |         if self.forecast_length:
32 |             self.last_layer = torch.nn.Linear(seq_len, output_seq_len)
33 |         if final_layer:
34 |             self.sigmoid = activation_dict[final_layer]()
35 | 
36 |     def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
37 |         """
38 |         :param x: A torch.Tensor of shape (B, L, M), where B is the batch size, L is the
39 |             sequence length, and M is the number of time series (features)
40 |         :return: A tensor of dimension (B, forecast_length)
41 |         """
42 |         x = self.dense_shape(x)
43 |         x = self.pe(x)
44 |         # Permute to (L, B, M)
45 |         x = x.permute(1, 0, 2)
46 |         if mask is None:
47 |             x = self.multi_attn(x, x, x)[0]
48 |         else:
49 |             x = self.multi_attn(x, x, x, attn_mask=mask)[0]
50 |         x = self.final_layer(x)
51 |         if self.forecast_length:
52 |             # Switch to (B, M, L)
53 |             x = x.permute(1, 2, 0)
54 |             x = self.last_layer(x)
55 |             if self.sigmoid:
56 |                 x = self.sigmoid(x)
57 |                 return x.permute(0, 2, 1)
58 |             return x.view(-1, self.forecast_length)
59 |         return x.view(-1, self.length_data)
60 | 
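# A minimal usage sketch for MultiAttnHeadSimple (shapes follow the forward
# docstring above; the numbers here are illustrative, not taken from any config):
#
#     model = MultiAttnHeadSimple(number_time_series=3, seq_len=10, output_seq_len=5)
#     x = torch.rand(4, 10, 3)  # (B, L, M)
#     out = model(x)  # shape (4, 5): a forecast_length=5 prediction per batch element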
-------------------------------------------------------------------------------- /flood_forecast/utils.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import List
4 | from torch.autograd import Variable
5 | from flood_forecast.model_dict_function import pytorch_criterion_dict
6 | 
7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
8 | 
9 | 
10 | def numpy_to_tvar(x: np.ndarray) -> torch.autograd.Variable:
11 |     """Converts a numpy array into a PyTorch tensor on the default device.
12 | 
13 |     :param x: A numpy array you want to convert to a PyTorch tensor
14 |     :type x: np.ndarray
15 |     :return: A tensor variable
16 |     :rtype: torch.autograd.Variable
17 |     """
18 |     return Variable(torch.from_numpy(x).type(torch.FloatTensor).to(device))
19 | 
20 | 
21 | def flatten_list_function(input_list: List) -> List:
22 |     """A function to flatten a list of lists."""
23 |     return [item for sublist in input_list for item in sublist]
24 | 
25 | 
26 | def make_criterion_functions(crit_list: List) -> List:
27 |     """Instantiates the criterion (loss) functions named in crit_list.
28 | 
29 |     crit_list should be either a list of criterion names or a dict mapping criterion
30 |     names to keyword arguments, e.g. {"MASELoss": {"baseline_method": "mean"}}.
31 |     Returns a list of instantiated criterion functions.
32 |     """
33 |     final_list = []
34 |     if isinstance(crit_list, list):
35 |         for crit in crit_list:
36 |             final_list.append(pytorch_criterion_dict[crit]())
37 |     else:
38 |         for k, v in crit_list.items():
39 |             final_list.append(pytorch_criterion_dict[k](**v))
40 |     return final_list
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | shap==0.47.0
2 | scikit-learn>=1.0.1
3 | pandas
4 | torch
5 | tb-nightly
6 | seaborn
7 | future
8 | h5py
9 | wandb==0.19.3
10 | google-cloud
11 | google-cloud-storage
12 | plotly~=5.24.0
13 | pytz>=2022.1
14 | setuptools~=76.0.0
15 | numpy==1.26
16 | requests
17 | torchvision>=0.6.0
18 | mpld3>=0.5
19 | numba>=0.50
20 | sphinx
21 | sphinx-rtd-theme
22 | sphinx-autodoc-typehints
23 | einops
24 | pytorch-tsmixer
25 | einsum
26 | jaxtyping
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from setuptools import setup
3 | import os
4 | 
5 | library_folder = os.path.dirname(os.path.realpath(__file__))
6 | requirementPath = f'{library_folder}/requirements.txt'
7 | install_requires = []
8 | if os.path.isfile(requirementPath):
9 |     with open(requirementPath) as f:
10 |         install_requires = f.read().splitlines()
11 | 
12 | dev_requirements = [
13 |     'autopep8',
14 |     'flake8'
15 | ]
16 | 
17 | setup(
18 |     name='flood_forecast',
19 |     version='1.001dev',
20 |     packages=[
21 |         'flood_forecast',
22 |         'flood_forecast.transformer_xl',
23 |         'flood_forecast.preprocessing',
24 |         'flood_forecast.da_rnn',
25 |         "flood_forecast.basic",
26 |         "flood_forecast.meta_models",
27 |         "flood_forecast.gcp_integration",
28 |         "flood_forecast.deployment",
29 |         "flood_forecast.custom"],
30 |     license='GPL 3.0',
31 |     description="An open source framework for deep time series forecasting and classification built with PyTorch.",
32 |     long_description='Flow Forecast is the top open source deep learning for time series forecasting and classification framework. 
We were the original TS framework to contain models like the transformer and have now expanded to include all popular deep learning models.', 33 | install_requires=install_requires, 34 | extras_require={ 35 | 'dev': dev_requirements}) 36 | -------------------------------------------------------------------------------- /tests/24_May_202202_25PM_1.json: -------------------------------------------------------------------------------- 1 | {"model_name": "CustomTransformerDecoder", "n_targets": 2, "model_type": "PyTorch", "model_params": {"n_time_series": 4, "seq_length": 26, "output_seq_length": 1, "output_dim": 2, "n_layers_encoder": 6}, "dataset_params": {"class": "GeneralClassificationLoader", "n_classes": 2, "training_path": "tests/test_data/ff_test.csv", "validation_path": "tests/test_data/ff_test.csv", "test_path": "tests/test_data/ff_test.csv", "sequence_length": 26, "batch_size": 101, "forecast_history": 26, "train_end": 80, "valid_start": 4, "valid_end": 90, "target_col": ["anomalous_rain"], "relevant_cols": ["anomalous_rain", "tmpf", "cfs", "dwpf", "height"], "scaler": "StandardScaler", "interpolate": {"method": "back_forward_generic", "params": {"relevant_columns": ["cfs", "tmpf", "p01m", "dwpf"]}}, "forecast_length": 1}, "training_params": {"criterion": "CrossEntropyLoss", "optimizer": "Adam", "optim_params": {}, "lr": 0.03, "epochs": 4, "batch_size": 100, "shuffle": false}, "GCS": false, "wandb": {"name": "flood_forecast_circleci", "tags": ["dummy_run", "circleci", "multi_head", "classification"], "project": "repo-flood_forecast"}, "forward_params": {}, "metrics": ["CrossEntropyLoss"], "run": [{"epoch": 0, "train_loss": "0.0680028973509454", "validation_loss": "113.27512431330979"}, {"epoch": 1, "train_loss": "0.05458420396058096", "validation_loss": "108.99862844124436"}, {"epoch": 2, "train_loss": "0.054659905693390305", "validation_loss": "106.30307429283857"}, {"epoch": 3, "train_loss": "0.054730736438391936", "validation_loss": "104.98548858333379"}]} 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIStream-Peelout/flow-forecast/9a2af06685db4a635eb21e57d8e522a355f85286/tests/__init__.py -------------------------------------------------------------------------------- /tests/auto_encoder.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "BasicAE", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "input_shape":3, 6 | "out_features":128 7 | }, 8 | "n_targets": 3, 9 | "dataset_params": 10 | { "class": "AutoEncoder", 11 | "training_path": "tests/test_data/keag_small.csv", 12 | "validation_path": "tests/test_data/keag_small.csv", 13 | "test_path": "tests/test_data/keag_small.csv", 14 | "batch_size":4, 15 | "forecast_history": 1, 16 | "train_end": 100, 17 | "valid_start":301, 18 | "valid_end": 401, 19 | "relevant_cols": ["cfs", "precip", "temp"], 20 | "scaler": "StandardScaler", 21 | "interpolate": false 22 | }, 23 | "training_params": 24 | { 25 | "criterion":"MSE", 26 | "optimizer": "Adam", 27 | "lr": 0.3, 28 | "epochs": 1, 29 | "batch_size":4, 30 | "optim_params": 31 | { 32 | } 33 | }, 34 | "GCS": false, 35 | 36 | "wandb": { 37 | "name": "flood_forecast_circleci", 38 | "project": "repo-flood_forecast", 39 | "tags": ["dummy_run", "circleci", "ae"] 40 | }, 41 | "metrics":["MSE"], 42 | 43 | "inference_params":{ 44 | "hours_to_forecast":1 45 | 46 | } 
47 | } 48 | -------------------------------------------------------------------------------- /tests/classification_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "MultiAttnHeadSimple", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "number_time_series":2, 6 | "seq_len":6, 7 | "final_layer": "Softmax", 8 | "output_seq_len": 1, 9 | "output_dim": 9 10 | }, 11 | "dataset_params": 12 | { "class": "GeneralClassificationLoader", 13 | "n_classes": 9, 14 | "training_path": "tests/test_data/keag_small.csv", 15 | "validation_path": "tests/test_data/keag_small.csv", 16 | "test_path": "tests/test_data/keag_small.csv", 17 | "sequence_length":6, 18 | "batch_size":20, 19 | "forecast_history":6, 20 | "train_end": 300, 21 | "valid_start":0, 22 | "valid_end": 300, 23 | "test_end": 303, 24 | "target_col": ["precip"], 25 | "relevant_cols": ["precip", "temp", "cfs"], 26 | "scaler": "StandardScaler", 27 | "interpolate": false 28 | }, 29 | 30 | "training_params": 31 | { 32 | "criterion":"FocalLoss", 33 | "optimizer": "Adam", 34 | "criterion_params": 35 | {"alpha": 0.5, 36 | "reduction": "sum"}, 37 | "optim_params": 38 | {}, 39 | "lr": 0.3, 40 | "epochs": 1, 41 | "batch_size":4 42 | }, 43 | "GCS": false, 44 | 45 | "wandb": { 46 | "name": "flood_forecast_circleci", 47 | "tags": ["dummy_run", "circleci", "multi_head", "classification"], 48 | "project": "repo-flood_forecast" 49 | }, 50 | "forward_params":{}, 51 | "metrics":["CrossEntropyLoss"] 52 | } 53 | -------------------------------------------------------------------------------- /tests/config.json: -------------------------------------------------------------------------------- 1 | {"model_name": "CustomTransformerDecoder", "model_type": "PyTorch", "model_params": {"seq_length": 11, "n_time_series": 9, "output_seq_length": 2, "n_layers_encoder": 3, "use_mask": true}, "dataset_params": {"class": "default", "training_path": "United_States__Florida__Palm_Beach_County.csv", "validation_path": "United_States__Florida__Palm_Beach_County.csv", "test_path": "United_States__Florida__Palm_Beach_County.csv", "forecast_test_len": 15, "batch_size": 20, "forecast_history": 11, "forecast_length": 2, "train_end": 61, "valid_start": 62, "valid_end": 88, "test_start": 61, "test_end": 90, "target_col": ["new_cases"], "relevant_cols": ["new_cases", "month", "weekday", "mobility_retail_recreation", "mobility_grocery_pharmacy", "mobility_parks", "mobility_transit_stations", "mobility_workplaces", "mobility_residential"], "scaler": "StandardScaler", "interpolate": false}, "training_params": {"criterion": "MSE", "optimizer": "Adam", "optim_params": {}, "lr": 0.0001, "epochs": 10, "batch_size": 19}, "GCS": true, "early_stopping": {"patience": 3}, "sweep": true, "wandb": false, "forward_params": {}, "metrics": ["MSE"], "inference_params": {"datetime_start": "2020-06-10", "hours_to_forecast": 15, "num_prediction_samples": 100, "test_csv_path": "United_States__Florida__Palm_Beach_County.csv", "decoder_params": {"decoder_function": "simple_decode", "unsqueeze_dim": 1}, "dataset_params": {"file_path": "United_States__Florida__Palm_Beach_County.csv", "forecast_history": 11, "forecast_length": 2, "relevant_cols": ["new_cases", "month", "weekday", "mobility_retail_recreation", "mobility_grocery_pharmacy", "mobility_parks", "mobility_transit_stations", "mobility_workplaces", "mobility_residential"], "target_col": ["new_cases"], "scaling": "StandardScaler", "interpolate_param": false}}, "weight_path": 
"/content/github_aistream-peelout_flow-forecast/29_June_202009_26AM_model.pth", "run": [{"epoch": 0, "train_loss": "0.6654510299364725", "validation_loss": "1.1439250165765935"}, {"epoch": 1, "train_loss": "0.7072166502475739", "validation_loss": "1.095255504954945"}, {"epoch": 2, "train_loss": "0.5650965571403503", "validation_loss": "1.2630093747919255"}, {"epoch": 3, "train_loss": "0.504930337270101", "validation_loss": "1.8901219367980957"}, {"epoch": 4, "train_loss": "0.49202097455660504", "validation_loss": "2.055953329259699"}], "epochs": 0} 2 | -------------------------------------------------------------------------------- /tests/cross_former.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Crossformer", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "forecast_history":36, 7 | "n_time_series": 3, 8 | "forecast_length": 10, 9 | "seg_len": 4 10 | 11 | }, 12 | "n_targets": 4, 13 | "dataset_params": 14 | { "class": "default", 15 | "training_path": "tests/test_data/keag_small.csv", 16 | "validation_path": "tests/test_data/keag_small.csv", 17 | "test_path": "tests/test_data/keag_small.csv", 18 | "batch_size":10, 19 | "forecast_history":36, 20 | "forecast_length":10, 21 | "train_start": 1, 22 | "train_end": 300, 23 | "valid_start":302, 24 | "valid_end": 401, 25 | "test_start":50, 26 | "test_end": 450, 27 | "target_col": ["cfs"], 28 | "relevant_cols": ["cfs", "precip", "temp"], 29 | "scaler": "StandardScaler", 30 | "interpolate": false 31 | }, 32 | "training_params": 33 | { 34 | "criterion":"MSE", 35 | "optimizer": "Adam", 36 | "optim_params": 37 | { 38 | 39 | }, 40 | "lr": 0.03, 41 | "epochs": 1, 42 | "batch_size":4 43 | 44 | }, 45 | "GCS": false, 46 | 47 | "wandb": { 48 | "name": "flood_forecast_circleci", 49 | "tags": ["dummy_run", "circleci"], 50 | "project": "repo-flood_forecast" 51 | }, 52 | "forward_params":{}, 53 | "metrics":["MSE"], 54 | "inference_params": 55 | { 56 | "datetime_start":"2016-05-31", 57 | "hours_to_forecast":334, 58 | "test_csv_path":"tests/test_data/keag_small.csv", 59 | "decoder_params":{ 60 | "decoder_function": "simple_decode", 61 | "unsqueeze_dim": 1} 62 | , 63 | "dataset_params":{ 64 | "file_path": "tests/test_data/keag_small.csv", 65 | "forecast_history":36, 66 | "forecast_length":10, 67 | "relevant_cols": ["cfs", "precip", "temp"], 68 | "target_col": ["cfs"], 69 | "scaling": "StandardScaler", 70 | "interpolate_param": false 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /tests/custom_encode.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "CustomTransformerDecoder", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | "seq_length":5, 7 | "output_seq_length": 1, 8 | "n_layers_encoder": 6 9 | }, 10 | "dataset_params": 11 | { "class": "default", 12 | "training_path": "tests/test_data/keag_small.csv", 13 | "validation_path": "tests/test_data/keag_small.csv", 14 | "test_path": "tests/test_data/keag_small.csv", 15 | "forecast_history":5, 16 | "forecast_length":1, 17 | "train_end": 99, 18 | "valid_start":301, 19 | "valid_end": 401, 20 | "test_start":125, 21 | "test_end":500, 22 | "target_col": ["cfs"], 23 | "relevant_cols": ["cfs", "precip", "temp"], 24 | "scaler": "StandardScaler", 25 | "interpolate": false 26 | }, 27 | "early_stopping":{ 28 | "patience":1 29 | 30 | }, 31 | "training_params": 32 | { 33 | "criterion":"MAPE", 
34 | "optimizer": "Adam", 35 | "optim_params": 36 | { 37 | 38 | }, 39 | "lr": 0.3, 40 | "epochs": 2, 41 | "batch_size":4 42 | 43 | }, 44 | "GCS": false, 45 | "wandb": { 46 | "name": "flood_forecast_circleci", 47 | "project": "repo-flood_forecast", 48 | "tags": ["dummy_run", "circleci"] 49 | }, 50 | "forward_params":{}, 51 | "metrics":["MSE"], 52 | "inference_params": 53 | { 54 | "datetime_start":"2016-05-31", 55 | "hours_to_forecast":336, 56 | "test_csv_path":"tests/test_data/keag_small.csv", 57 | "decoder_params":{ 58 | "decoder_function": "simple_decode", 59 | "unsqueeze_dim": 1}, 60 | "num_prediction_samples":20, 61 | "dataset_params":{ 62 | "file_path": "tests/test_data/keag_small.csv", 63 | "forecast_history":5, 64 | "forecast_length":1, 65 | "relevant_cols": ["cfs", "precip", "temp"], 66 | "target_col": ["cfs"], 67 | "scaling": "StandardScaler", 68 | "interpolate_param": false 69 | } 70 | } 71 | 72 | 73 | } 74 | -------------------------------------------------------------------------------- /tests/da_meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "DARNN", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | "hidden_size_encoder":128, 7 | "decoder_hidden_size":128, 8 | "out_feats":1, 9 | "forecast_history":6, 10 | "dropout": 0.4, 11 | "gru_lstm": false, 12 | "meta_data": 13 | { 14 | "method":"Bilinear", 15 | "da_method": "down_sample", 16 | "meta_dim": 128, 17 | "params": { 18 | "in1_features": 5, 19 | "in2_features": 1, 20 | "out_features": 5 21 | } 22 | } 23 | }, 24 | "meta_data":{ 25 | "path":"tests/auto_encoder.json", 26 | "column_id": "datetime", 27 | "uuid": "2014-05-02 01:00:00" 28 | }, 29 | "dataset_params": 30 | { "class": "default", 31 | "training_path": "tests/test_data/keag_small.csv", 32 | "validation_path": "tests/test_data/keag_small.csv", 33 | "test_path": "tests/test_data/keag_small.csv", 34 | "batch_size":4, 35 | "forecast_history":5, 36 | "forecast_length":1, 37 | "train_end": 300, 38 | "valid_start":302, 39 | "valid_end": 404, 40 | "test_end": 500, 41 | "target_col": ["cfs"], 42 | "relevant_cols": ["cfs", "precip", "temp"], 43 | "scaler": "StandardScaler", 44 | "interpolate": false 45 | }, 46 | 47 | "training_params": 48 | { 49 | "criterion":"MSE", 50 | "optimizer": "Adam", 51 | "optim_params": 52 | { 53 | }, 54 | "lr": 0.3, 55 | "epochs": 1, 56 | "batch_size":4 57 | 58 | }, 59 | "GCS": false, 60 | 61 | "wandb": { 62 | "name": "flood_forecast_circleci", 63 | "tags": ["dummy_run", "circleci", "da_revised"], 64 | "project": "repo-flood_forecast" 65 | }, 66 | "forward_params":{}, 67 | "metrics":["MSE"], 68 | "inference_params": 69 | { "num_prediction_samples": 10, 70 | "datetime_start":"2016-05-31", 71 | "hours_to_forecast":336, 72 | "test_csv_path":"tests/test_data/keag_small.csv", 73 | "dataset_params":{ 74 | "file_path": "tests/test_data/keag_small.csv", 75 | "forecast_history":5, 76 | "forecast_length":1, 77 | "relevant_cols": ["cfs", "precip", "temp"], 78 | "target_col": ["cfs"], 79 | "scaling": "StandardScaler", 80 | "interpolate_param": false 81 | }, 82 | "decoder_params":{ 83 | "decoder_function": "simple_decode", "unsqueeze_dim": 1} 84 | } 85 | 86 | } 87 | -------------------------------------------------------------------------------- /tests/da_rnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "DARNN", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | 
"hidden_size_encoder":128, 7 | "decoder_hidden_size":128, 8 | "out_feats":1, 9 | "forecast_history":6, 10 | "gru_lstm": false 11 | }, 12 | 13 | "dataset_params": 14 | { "class": "default", 15 | "training_path": "tests/test_data/keag_small.csv", 16 | "validation_path": "tests/test_data/keag_small.csv", 17 | "test_path": "tests/test_data/keag_small.csv", 18 | "batch_size":4, 19 | "forecast_history":5, 20 | "forecast_length":1, 21 | "train_end": 300, 22 | "valid_start":301, 23 | "valid_end": 401, 24 | "test_end": 500, 25 | "target_col": ["cfs"], 26 | "relevant_cols": ["cfs", "precip", "temp"], 27 | "scaler": "StandardScaler", 28 | "interpolate": false 29 | }, 30 | 31 | "training_params": 32 | { 33 | "criterion":"MASELoss", 34 | "optimizer": "Adam", 35 | "optim_params": 36 | { 37 | }, 38 | "criterion_params": 39 | { 40 | "baseline_method":"mean" 41 | }, 42 | "lr": 0.3, 43 | "epochs": 1, 44 | "batch_size":4 45 | 46 | }, 47 | "GCS": false, 48 | 49 | "wandb": { 50 | "name": "flood_forecast_circleci", 51 | "tags": ["dummy_run", "circleci", "da_revised"], 52 | "project": "repo-flood_forecast" 53 | }, 54 | "forward_params":{}, 55 | "metrics":{"MASELoss":{"baseline_method":"mean"}}, 56 | "inference_params": 57 | { "num_prediction_samples": 10, 58 | "criterion_params":{ 59 | "baseline_method":"mean" 60 | }, 61 | "datetime_start":"2016-05-31", 62 | "hours_to_forecast":336, 63 | "test_csv_path":"tests/test_data/keag_small.csv", 64 | 65 | "decoder_params":{ 66 | "decoder_function": "simple_decode", "unsqueeze_dim": 1} 67 | } 68 | 69 | } 70 | -------------------------------------------------------------------------------- /tests/da_rnn_probabilistic.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "DARNN", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | "hidden_size_encoder":128, 7 | "decoder_hidden_size":128, 8 | "out_feats":1, 9 | "forecast_history":6, 10 | "gru_lstm": true, 11 | "probabilistic": true 12 | }, 13 | 14 | "dataset_params": 15 | { "class": "default", 16 | "training_path": "tests/test_data/keag_small.csv", 17 | "validation_path": "tests/test_data/keag_small.csv", 18 | "test_path": "tests/test_data/keag_small.csv", 19 | "batch_size":4, 20 | "forecast_history":5, 21 | "forecast_length":1, 22 | "train_end": 300, 23 | "valid_start":301, 24 | "valid_end": 401, 25 | "test_end": 500, 26 | "target_col": ["cfs"], 27 | "relevant_cols": ["cfs", "precip", "temp"], 28 | "scaler": "StandardScaler", 29 | "interpolate": false 30 | }, 31 | 32 | "training_params": 33 | { 34 | "criterion":"NegativeLogLikelihood", 35 | "probabilistic": true, 36 | "optimizer": "Adam", 37 | "optim_params": 38 | { 39 | }, 40 | "lr": 0.3, 41 | "epochs": 1, 42 | "batch_size":4 43 | 44 | }, 45 | "GCS": false, 46 | "wandb": { 47 | "name": "flood_forecast_circleci", 48 | "tags": ["dummy_run", "circleci", "da_revised_probabilistic"], 49 | "project": "repo-flood_forecast" 50 | }, 51 | "forward_params":{}, 52 | "metrics":["NegativeLogLikelihood"], 53 | "inference_params": 54 | { "num_prediction_samples": 10, 55 | "datetime_start":"2016-05-31", 56 | "hours_to_forecast":336, 57 | "test_csv_path":"tests/test_data/keag_small.csv", 58 | "probabilistic": true, 59 | "dataset_params":{ 60 | "file_path": "tests/test_data/keag_small.csv", 61 | "forecast_history":5, 62 | "forecast_length":1, 63 | "relevant_cols": ["cfs", "precip", "temp"], 64 | "target_col": ["cfs"], 65 | "scaling": "StandardScaler", 66 | "interpolate_param": false 67 | }, 68 | 
"decoder_params":{ 69 | "decoder_function": "simple_decode", "unsqueeze_dim": 1, "probabilistic": true} 70 | } 71 | 72 | } 73 | -------------------------------------------------------------------------------- /tests/data_format_tests.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.preprocessing.closest_station import get_weather_data, format_dt, convert_temp, \ 2 | process_asos_csv, process_asos_data 3 | from datetime import datetime 4 | import unittest 5 | import os 6 | 7 | 8 | class DataQualityTests(unittest.TestCase): 9 | def setUp(self): 10 | self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data") 11 | 12 | def test_format_dt(self): 13 | self.assertEqual(format_dt("2017-04-07 08:55"), datetime(year=2017, month=4, day=7, hour=9)) 14 | self.assertEqual(format_dt("2018-04-08 23:55"), datetime(year=2018, month=4, day=9, hour=0)) 15 | 16 | def test_convert_temp(self): 17 | self.assertEqual(convert_temp("50.3"), 50.3) 18 | self.assertEqual(convert_temp("-12.0"), -12.0) 19 | self.assertEqual(convert_temp("M"), 50) 20 | 21 | def test_process_asos_csv(self): 22 | df, precip_missing, temp_missing = process_asos_csv( 23 | os.path.join(self.test_data_path, "small_test.csv")) 24 | self.assertEqual(df.iloc[1]['p01m'], 47) 25 | self.assertEqual(df.iloc[0]['tmpf'], 50) 26 | self.assertEqual(df.iloc[1]['hour_updated'].hour, 1) 27 | self.assertEqual(df.iloc[1]['tmpf'], 53) 28 | self.assertEqual(precip_missing, 0) 29 | self.assertEqual(temp_missing, 0) 30 | 31 | def test_process_asos_full(self): 32 | df, precip_missing, temp_missing = process_asos_csv( 33 | os.path.join(self.test_data_path, "asos_raw.csv")) 34 | self.assertGreater(temp_missing, 10) 35 | self.assertGreater(precip_missing, 2) 36 | 37 | def test_value_imputation(self): 38 | df, precip_missing, temp_missing = process_asos_csv( 39 | os.path.join(self.test_data_path, "imputation_test.csv")) 40 | self.assertEqual(df.iloc[0]['p01m'], 0) 41 | self.assertEqual(df.iloc[2]['p01m'], 23) 42 | 43 | def test_get_weather_data(self): 44 | url = ( 45 | "https://mesonet.agron.iastate.edu/cgi-bin/request/asos.py?" 46 | "station={}&data=tmpf&data=p01m&year1=2019&month1=1&day1=1&year2=2019&month2=1&" 47 | "day2=2&tz=Etc%2FUTC&format=onlycomma&latlon=no&missing=M&trace=T&direct=no&report_type=1&report_type=2" 48 | ) 49 | print(url) 50 | get_weather_data(os.path.join(self.test_data_path, "full_out.json"), {}, url) 51 | self.assertEqual(1, 1) 52 | 53 | def test_process_asos_data(self): 54 | full_data_url = ( 55 | "https://mesonet.agron.iastate.edu/cgi-bin/request/asos.py?" 
56 | "station={}&data=tmpf&data=p01m&year1=2014&month1=1&day1=1&year2=2019&month2=1&day2=2" 57 | "&tz=Etc%2FUTC&format=onlycomma&latlon=no&missing=M&trace=T&direct=no&report_type=1&report_type=2" 58 | ) 59 | river_result = process_asos_data(os.path.join(self.test_data_path, "asos_process.json"), full_data_url) 60 | self.assertGreater(river_result["stations"][1]["missing_temp"], -1) 61 | self.assertGreater(river_result["stations"][2]["missing_precip"], -1) 62 | 63 | 64 | if __name__ == '__main__': 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /tests/data_loader_tests.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.preprocessing.pytorch_loaders import ( 2 | CSVTestLoader, 3 | CSVDataLoader, 4 | AEDataloader, 5 | ) 6 | import unittest 7 | import os 8 | import torch 9 | from datetime import datetime 10 | 11 | 12 | class DataLoaderTests(unittest.TestCase): 13 | """Class to test data loader functionality for the code mod. 14 | 15 | Specifically, reuturn types and indexing to make sure there is no overlap. 16 | """ 17 | 18 | def setUp(self): 19 | self.test_data_path = os.path.join( 20 | os.path.dirname(os.path.abspath(__file__)), "test_data" 21 | ) 22 | data_base_params = { 23 | "file_path": os.path.join(self.test_data_path, "keag_small.csv"), 24 | "forecast_history": 20, 25 | "forecast_length": 20, 26 | "relevant_cols": ["cfs", "temp", "precip"], 27 | "target_col": ["cfs"], 28 | "interpolate_param": False, 29 | } 30 | self.train_loader = CSVDataLoader( 31 | os.path.join(self.test_data_path, "keag_small.csv"), 32 | 30, 33 | 20, 34 | target_col=["cfs"], 35 | relevant_cols=["cfs", "precip", "temp"], 36 | interpolate_param=False, 37 | ) 38 | data_base_params["start_stamp"] = 20 39 | self.test_loader = CSVTestLoader( 40 | os.path.join(self.test_data_path, "keag_small.csv"), 41 | 336, 42 | **data_base_params 43 | ) 44 | self.ae_loader = AEDataloader( 45 | os.path.join(self.test_data_path, "keag_small.csv"), 46 | relevant_cols=["cfs", "temp", "precip"], 47 | ) 48 | data_base_params["end_stamp"] = 220 49 | self.train_loader2 = CSVDataLoader( 50 | **data_base_params 51 | ) 52 | 53 | def test_loader2_get_item(self): 54 | src, df, forecast_start_index = self.test_loader[0] 55 | self.assertEqual(type(src), torch.Tensor) 56 | self.assertEqual(forecast_start_index, 20) 57 | self.assertEqual(df.iloc[2]["cfs"], 445) 58 | self.assertEqual(len(df), 356) 59 | 60 | def test_loader2_get_date(self): 61 | src, df, forecast_start_index, = self.test_loader.get_from_start_date( 62 | datetime(2014, 6, 3, 0) 63 | ) 64 | self.assertEqual(type(src), torch.Tensor) 65 | self.assertEqual(forecast_start_index, 783) 66 | self.assertEqual( 67 | df.iloc[0]["datetime"].day, datetime(2014, 6, 2, 4).day 68 | ) 69 | 70 | def test_loader_get_gcs_data(self): 71 | test_loader = CSVDataLoader( 72 | file_path="gs://flow_datasets/Afghanistan____.csv", 73 | forecast_history=14, 74 | forecast_length=14, 75 | target_col=["cases"], 76 | relevant_cols=["cases", "recovered", "active", "deaths"], 77 | sort_column="date", 78 | interpolate_param=False, 79 | gcp_service_key=None, # for CircleCI tests, local test needs key.json 80 | ) 81 | self.assertIsInstance(test_loader, CSVDataLoader) 82 | 83 | def test_ae(self): 84 | x, y = self.ae_loader[0] 85 | self.assertEqual(x.shape, y.squeeze(1).shape) 86 | 87 | def test_trainer(self): 88 | x, y = self.train_loader[0] 89 | self.assertEqual(x.shape[0], 30) 90 | self.assertEqual(x.shape[1], 
3) 91 | self.assertEqual(y.shape[0], 20) 92 | # Check first and last dim are not overlap 93 | self.assertFalse(torch.eq(x[29, 0], y[0, 0])) 94 | 95 | def test_start_end(self): 96 | self.assertEqual(len(self.train_loader.df), len(self.test_loader.df) + 20) 97 | self.assertEqual(len(self.train_loader2.df), 200) 98 | 99 | 100 | if __name__ == "__main__": 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /tests/decoder_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "CustomTransformerDecoder", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | "seq_length":5, 7 | "output_seq_length": 1, 8 | "n_layers_encoder": 4 9 | }, 10 | "dataset_params": 11 | { "class": "default", 12 | "training_path": "tests/test_data/keag_small.csv", 13 | "validation_path": "tests/test_data/keag_small.csv", 14 | "test_path": "tests/test_data/keag_small.csv", 15 | "batch_size":10, 16 | "forecast_history":5, 17 | "forecast_length":1, 18 | "train_end": 100, 19 | "valid_start":301, 20 | "valid_end": 401, 21 | "test_end":400, 22 | "target_col": ["cfs"], 23 | "relevant_cols": ["cfs", "precip", "temp"], 24 | "scaler": "StandardScaler", 25 | "interpolate": false 26 | }, 27 | "training_params": 28 | { 29 | "criterion":"RMSE", 30 | "optimizer": "Adam", 31 | "optim_params": 32 | { 33 | 34 | }, 35 | "lr": 0.3, 36 | "epochs": 1, 37 | "batch_size":4 38 | }, 39 | "GCS": false, 40 | 41 | "wandb": { 42 | "name": "flood_forecast_circleci", 43 | "project": "repo-flood_forecast", 44 | "tags": ["dummy_run", "circleci"] 45 | }, 46 | "forward_params":{}, 47 | "metrics":["MSE"], 48 | "inference_params": 49 | { 50 | "datetime_start":"2016-05-31", 51 | "hours_to_forecast":336, 52 | "test_csv_path":"tests/test_data/keag_small.csv", 53 | "decoder_params":{ 54 | "decoder_function": "simple_decode", 55 | "unsqueeze_dim": 1}, 56 | "dataset_params":{ 57 | "file_path": "tests/test_data/keag_small.csv", 58 | "forecast_history":5, 59 | "forecast_length":1, 60 | "relevant_cols": ["cfs", "precip", "temp"], 61 | "target_col": ["cfs"], 62 | "scaling": "StandardScaler", 63 | "interpolate_param": false 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /tests/dlinear.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "DLinear", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "forecast_history":20, 7 | "forecast_length": 10, 8 | "enc_in": 321, 9 | "individual": false 10 | }, 11 | "dataset_params": 12 | { "class": "default", 13 | "training_path": "tests/test_data/keag_small.csv", 14 | "validation_path": "tests/test_data/keag_small.csv", 15 | "test_path": "tests/test_data/keag_small.csv", 16 | "batch_size":10, 17 | "forecast_history":20, 18 | "forecast_length":10, 19 | "train_start": 1, 20 | "train_end": 300, 21 | "valid_start":302, 22 | "valid_end": 401, 23 | "test_start":50, 24 | "test_end": 450, 25 | "target_col": ["cfs"], 26 | "relevant_cols": ["cfs", "precip", "temp"], 27 | "scaler": "StandardScaler", 28 | "interpolate": false 29 | }, 30 | "training_params": 31 | { 32 | "criterion":"MSE", 33 | "optimizer": "Adam", 34 | "optim_params": 35 | { 36 | 37 | }, 38 | "lr": 0.03, 39 | "epochs": 1, 40 | "batch_size":4 41 | 42 | }, 43 | "GCS": false, 44 | 45 | "wandb": { 46 | "name": "flood_forecast_circleci", 47 | "tags": ["dummy_run", "circleci"], 48 | "project": 
"repo-flood_forecast" 49 | }, 50 | "forward_params":{}, 51 | "metrics":["MSE"], 52 | "inference_params": 53 | { 54 | "datetime_start":"2016-05-31", 55 | "hours_to_forecast":334, 56 | "test_csv_path":"tests/test_data/keag_small.csv", 57 | "decoder_params":{ 58 | "decoder_function": "simple_decode", 59 | "unsqueeze_dim": 1} 60 | , 61 | "dataset_params":{ 62 | "file_path": "tests/test_data/keag_small.csv", 63 | "forecast_history":20, 64 | "forecast_length":10, 65 | "relevant_cols": ["cfs", "precip", "temp"], 66 | "target_col": ["cfs"], 67 | "scaling": "StandardScaler", 68 | "interpolate_param": false 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /tests/dsanet.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "DSANet", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "forecast_history":10, 7 | "n_time_series":3, 8 | "dsa_local": 2, 9 | "dsanet_n_kernels": 32, 10 | "dsanet_w_kernals": 1, 11 | "dsanet_d_model":128, 12 | "dsanet_d_inner": 2048, 13 | "dsanet_n_layers": 2, 14 | "dsanet_n_head": 8, 15 | "dropout": 0.11 16 | }, 17 | "n_targets":3, 18 | "dataset_params": 19 | { "class": "default", 20 | "training_path": "tests/test_data/keag_small.csv", 21 | "validation_path": "tests/test_data/keag_small.csv", 22 | "test_path": "tests/test_data/keag_small.csv", 23 | "batch_size":10, 24 | "forecast_history":10, 25 | "forecast_length":1, 26 | "train_start": 1, 27 | "train_end": 300, 28 | "valid_start":301, 29 | "valid_end": 401, 30 | "test_start":50, 31 | "test_end": 450, 32 | "target_col": ["cfs"], 33 | "relevant_cols": ["cfs", "precip", "temp"], 34 | "scaler": "StandardScaler", 35 | "interpolate": false 36 | }, 37 | "training_params": 38 | { 39 | "criterion":"MSE", 40 | "optimizer": "Adam", 41 | "optim_params": 42 | { 43 | 44 | }, 45 | "lr": 0.03, 46 | "epochs": 1, 47 | "batch_size":4 48 | 49 | }, 50 | "GCS": false, 51 | 52 | "wandb": { 53 | "name": "flood_forecast_circleci", 54 | "tags": ["dummy_run", "circleci"], 55 | "project": "repo-flood_forecast" 56 | }, 57 | "forward_params":{}, 58 | "metrics":["MSE"], 59 | "inference_params": 60 | { 61 | "datetime_start":"2016-05-31", 62 | "hours_to_forecast":334, 63 | "test_csv_path":"tests/test_data/keag_small.csv", 64 | "decoder_params":{ 65 | "decoder_function": "simple_decode", 66 | "unsqueeze_dim": 1} 67 | , 68 | "dataset_params":{ 69 | "file_path": "tests/test_data/keag_small.csv", 70 | "forecast_history":10, 71 | "forecast_length":1, 72 | "relevant_cols": ["cfs", "precip", "temp"], 73 | "target_col": ["cfs"], 74 | "scaling": "StandardScaler", 75 | "interpolate_param": false 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /tests/dsanet_3.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "DSANet", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "forecast_history":10, 7 | "dsa_targs": 1, 8 | "n_time_series":3, 9 | "dsa_local": 2, 10 | "dsanet_n_kernels": 32, 11 | "dsanet_w_kernals": 1, 12 | "dsanet_d_model":128, 13 | "dsanet_d_inner": 2048, 14 | "dsanet_n_layers": 2, 15 | "dsanet_n_head": 8, 16 | "dropout": 0.11 17 | }, 18 | "dataset_params": 19 | { "class": "default", 20 | "training_path": "tests/test_data/keag_small.csv", 21 | "validation_path": "tests/test_data/keag_small.csv", 22 | "test_path": "tests/test_data/keag_small.csv", 23 | "batch_size":10, 24 
| "forecast_history":10, 25 | "forecast_length":1, 26 | "train_start": 1, 27 | "train_end": 300, 28 | "valid_start":301, 29 | "valid_end": 401, 30 | "test_start":50, 31 | "test_end": 450, 32 | "target_col": ["cfs"], 33 | "relevant_cols": ["cfs", "precip", "temp"], 34 | "scaler": "StandardScaler", 35 | "interpolate": false 36 | }, 37 | "training_params": 38 | { 39 | "criterion":"MSE", 40 | "optimizer": "Adam", 41 | "optim_params": 42 | { 43 | 44 | }, 45 | "lr": 0.03, 46 | "epochs": 1, 47 | "batch_size":4 48 | 49 | }, 50 | "GCS": false, 51 | 52 | "wandb": { 53 | "name": "flood_forecast_circleci", 54 | "tags": ["dummy_run", "circleci"], 55 | "project": "repo-flood_forecast" 56 | }, 57 | "forward_params":{}, 58 | "metrics":["MSE"], 59 | "inference_params": 60 | { 61 | "datetime_start":"2016-05-31", 62 | "hours_to_forecast":334, 63 | "test_csv_path":"tests/test_data/keag_small.csv", 64 | "decoder_params":{ 65 | "decoder_function": "simple_decode", 66 | "unsqueeze_dim": 1} 67 | , 68 | "dataset_params":{ 69 | "file_path": "tests/test_data/keag_small.csv", 70 | "forecast_history":10, 71 | "forecast_length":1, 72 | "relevant_cols": ["cfs", "precip", "temp"], 73 | "target_col": ["cfs"], 74 | "scaling": "StandardScaler", 75 | "interpolate_param": false 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /tests/full_transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "SimpleTransformer", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "number_time_series":3, 7 | "seq_length":10, 8 | "output_seq_len": 2 9 | }, 10 | "dataset_params": 11 | { "class": "default", 12 | "training_path": "tests/test_data/keag_small.csv", 13 | "validation_path": "tests/test_data/keag_small.csv", 14 | "test_path": "tests/test_data/keag_small.csv", 15 | "batch_size":5, 16 | "forecast_history":10, 17 | "forecast_length":2, 18 | "train_end": 201, 19 | "valid_start":201, 20 | "valid_end": 220, 21 | "test_start":299, 22 | "test_end": 400, 23 | "target_col": ["cfs"], 24 | "relevant_cols": ["cfs", "precip", "temp"], 25 | "scaler": "StandardScaler", 26 | "interpolate": false 27 | }, 28 | "early_stopping": 29 | { 30 | "patience":2 31 | 32 | }, 33 | "training_params": 34 | { 35 | "criterion":"MSE", 36 | "optimizer": "Adam", 37 | "optim_params": 38 | { 39 | }, 40 | "lr": 0.3, 41 | "epochs": 4, 42 | "batch_size":4 43 | 44 | }, 45 | "GCS": false, 46 | 47 | "wandb": { 48 | "name": "flood_forecast_circleci", 49 | "tags": ["dummy_run", "circleci"], 50 | "project":"repo-flood_forecast" 51 | }, 52 | "forward_params":{ 53 | "t":{} 54 | }, 55 | "takes_target": true, 56 | "metrics":["MSE"], 57 | "inference_params": 58 | { 59 | "datetime_start":"2016-05-31", 60 | "hours_to_forecast":10, 61 | "test_csv_path":"tests/test_data/keag_small.csv", 62 | "decoder_params":{ 63 | "decoder_function": "greedy_decode", 64 | "unsqueeze_dim": 1}, 65 | "dataset_params":{ 66 | "file_path": "tests/test_data/keag_small.csv", 67 | "forecast_history":10, 68 | "forecast_length":2, 69 | "relevant_cols": ["cfs", "precip", "temp"], 70 | "target_col": ["cfs"], 71 | "scaling": "StandardScaler", 72 | "interpolate_param": false 73 | } 74 | } 75 | 76 | 77 | } 78 | -------------------------------------------------------------------------------- /tests/gru_vanilla.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "VanillaGRU", 3 | "use_decoder": true, 4 | 
"model_type": "PyTorch", 5 | "model_params": { 6 | "n_time_series":3, 7 | "hidden_dim": 20, 8 | "num_layers": 1, 9 | "n_target": 1, 10 | "dropout": 0.11 11 | }, 12 | "dataset_params": 13 | { "class": "default", 14 | "training_path": "tests/test_data/keag_small.csv", 15 | "validation_path": "tests/test_data/keag_small.csv", 16 | "test_path": "tests/test_data/keag_small.csv", 17 | "batch_size":10, 18 | "forecast_history":10, 19 | "forecast_length":1, 20 | "train_start": 1, 21 | "train_end": 300, 22 | "valid_start":302, 23 | "valid_end": 401, 24 | "test_start":50, 25 | "test_end": 450, 26 | "target_col": ["cfs"], 27 | "relevant_cols": ["cfs", "precip", "temp"], 28 | "scaler": "StandardScaler", 29 | "interpolate": false 30 | }, 31 | "training_params": 32 | { 33 | "criterion":"MSE", 34 | "optimizer": "Adam", 35 | "optim_params": 36 | { 37 | 38 | }, 39 | "lr": 0.03, 40 | "epochs": 1, 41 | "batch_size":4 42 | 43 | }, 44 | "GCS": false, 45 | 46 | "wandb": { 47 | "name": "flood_forecast_circleci", 48 | "tags": ["dummy_run", "circleci"], 49 | "project": "repo-flood_forecast" 50 | }, 51 | "forward_params":{}, 52 | "metrics":["MSE"], 53 | "inference_params": 54 | { 55 | "datetime_start":"2016-05-31", 56 | "hours_to_forecast":334, 57 | "test_csv_path":"tests/test_data/keag_small.csv", 58 | "decoder_params":{ 59 | "decoder_function": "simple_decode", 60 | "unsqueeze_dim": 1} 61 | , 62 | "dataset_params":{ 63 | "file_path": "tests/test_data/keag_small.csv", 64 | "forecast_history":10, 65 | "forecast_length":1, 66 | "relevant_cols": ["cfs", "precip", "temp"], 67 | "target_col": ["cfs"], 68 | "scaling": "StandardScaler", 69 | "interpolate_param": false 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /tests/lstm_probabilistic_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "LSTM", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "seq_length": 10, 6 | "n_time_series":3, 7 | "output_seq_len":1, 8 | "batch_size":4, 9 | "probabilistic": true 10 | }, 11 | "dataset_params": 12 | { "class": "default", 13 | "training_path": "tests/test_data/keag_small.csv", 14 | "validation_path": "tests/test_data/keag_small.csv", 15 | "test_path": "tests/test_data/keag_small.csv", 16 | "batch_size":4, 17 | "forecast_history":10, 18 | "forecast_length":1, 19 | "train_start": 2, 20 | "train_end": 301, 21 | "valid_start":301, 22 | "valid_end": 401, 23 | "test_start":50, 24 | "test_end": 450, 25 | "target_col": ["cfs"], 26 | "relevant_cols": ["cfs", "precip", "temp"], 27 | "scaler": "StandardScaler", 28 | "interpolate": false 29 | }, 30 | "training_params": 31 | { 32 | "criterion":"NegativeLogLikelihood", 33 | "probabilistic": true, 34 | "optimizer": "Adam", 35 | "optim_params": 36 | { 37 | 38 | }, 39 | "lr": 0.01, 40 | "epochs": 1, 41 | "batch_size":4 42 | 43 | }, 44 | "GCS": false, 45 | 46 | "wandb": { 47 | "name": "flood_forecast_circleci", 48 | "tags": ["dummy_run", "circleci"], 49 | "project": "repo-flood_forecast" 50 | }, 51 | "forward_params":{}, 52 | "metrics":["NegativeLogLikelihood"], 53 | "inference_params": 54 | { 55 | "datetime_start":"2016-05-31", 56 | "hours_to_forecast":334, 57 | "test_csv_path":"tests/test_data/keag_small.csv", 58 | "probabilistic": true, 59 | "decoder_params":{ 60 | "decoder_function": "simple_decode", 61 | "unsqueeze_dim": 1, "probabilistic": true} 62 | , 63 | "dataset_params":{ 64 | "file_path": "tests/test_data/keag_small.csv", 65 | "forecast_history":10, 66 | 
"forecast_length":1, 67 | "relevant_cols": ["cfs", "precip", "temp"], 68 | "target_col": ["cfs"], 69 | "scaling": "StandardScaler", 70 | "interpolate_param": false 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /tests/lstm_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "LSTM", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "seq_length": 10, 6 | "n_time_series":3, 7 | "output_seq_len":1, 8 | "batch_size":4 9 | }, 10 | "dataset_params": 11 | { "class": "default", 12 | "training_path": "tests/test_data/keag_small.csv", 13 | "validation_path": "tests/test_data/keag_small.csv", 14 | "test_path": "tests/test_data/keag_small.csv", 15 | "batch_size":4, 16 | "forecast_history":10, 17 | "forecast_length":1, 18 | "train_start": 1, 19 | "train_end": 300, 20 | "valid_start":301, 21 | "valid_end": 401, 22 | "test_start":50, 23 | "test_end": 450, 24 | "target_col": ["cfs"], 25 | "relevant_cols": ["cfs", "precip", "temp"], 26 | "scaler": "StandardScaler", 27 | "interpolate": false 28 | }, 29 | "training_params": 30 | { 31 | "criterion":"MSE", 32 | "optimizer": "Adam", 33 | "optim_params": 34 | { 35 | 36 | }, 37 | "lr": 0.3, 38 | "epochs": 1, 39 | "batch_size":4 40 | 41 | }, 42 | "GCS": false, 43 | 44 | "wandb": { 45 | "name": "flood_forecast_circleci", 46 | "tags": ["dummy_run", "circleci"], 47 | "project": "repo-flood_forecast" 48 | }, 49 | "forward_params":{}, 50 | "metrics":["MSE"], 51 | "inference_params": 52 | { 53 | "datetime_start":"2016-05-31", 54 | "hours_to_forecast":334, 55 | "test_csv_path":"tests/test_data/keag_small.csv", 56 | "decoder_params":{ 57 | "decoder_function": "simple_decode", 58 | "unsqueeze_dim": 1} 59 | , 60 | "dataset_params":{ 61 | "file_path": "tests/test_data/keag_small.csv", 62 | "forecast_history":10, 63 | "forecast_length":1, 64 | "relevant_cols": ["cfs", "precip", "temp"], 65 | "target_col": ["cfs"], 66 | "scaling": "StandardScaler", 67 | "interpolate_param": false 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /tests/meta_data_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "CustomTransformerDecoder", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | "seq_length":5, 7 | "output_seq_length": 1, 8 | "n_layers_encoder": 6, 9 | "meta_data": { 10 | "method": "Bilinear", 11 | "params": 12 | {"in1_features":5, 13 | "in2_features":1, 14 | "out_features":5} 15 | } 16 | 17 | }, 18 | "meta_data":{ 19 | "path":"tests/auto_encoder.json", 20 | "column_id": "datetime", 21 | "uuid": "2014-05-01 01:00:00", 22 | "meta_loss":"MSE" 23 | }, 24 | "dataset_params": 25 | { "class": "default", 26 | "training_path": "tests/test_data/keag_small.csv", 27 | "validation_path": "tests/test_data/keag_small.csv", 28 | "test_path": "tests/test_data/keag_small.csv", 29 | "batch_size":4, 30 | "forecast_history":5, 31 | "forecast_length":1, 32 | "train_end": 100, 33 | "valid_start":301, 34 | "valid_end": 401, 35 | "test_end":400, 36 | "target_col": ["cfs"], 37 | "relevant_cols": ["cfs", "precip", "temp"], 38 | "scaler": "StandardScaler", 39 | "interpolate": false 40 | }, 41 | "training_params": 42 | { 43 | "criterion":"RMSE", 44 | "optimizer": "Adam", 45 | "optim_params": 46 | { 47 | 48 | }, 49 | "lr": 0.3, 50 | "epochs": 1, 51 | "batch_size":4 52 | 53 | }, 54 | "GCS": false, 55 | 56 | "wandb": { 57 | "name": 
"flood_forecast_circleci", 58 | "project": "repo-flood_forecast", 59 | "tags": ["dummy_run", "circleci"] 60 | }, 61 | "forward_params":{}, 62 | "metrics":["MSE"], 63 | "inference_params": 64 | { 65 | "datetime_start":"2016-05-31", 66 | "hours_to_forecast":336, 67 | "test_csv_path":"tests/test_data/keag_small.csv", 68 | "decoder_params":{ 69 | "decoder_function": "simple_decode", 70 | "unsqueeze_dim": 1}, 71 | "dataset_params":{ 72 | "file_path": "tests/test_data/keag_small.csv", 73 | "forecast_history":5, 74 | "forecast_length":1, 75 | "relevant_cols": ["cfs", "precip", "temp"], 76 | "target_col": ["cfs"], 77 | "scaling": "StandardScaler", 78 | "interpolate_param": false 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /tests/meta_data_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | 4 | 5 | class MetaDataTests(unittest.TestCase): 6 | def setUp(self): 7 | self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data") 8 | 9 | if __name__ == '__main__': 10 | unittest.main() 11 | -------------------------------------------------------------------------------- /tests/multi_config.json: -------------------------------------------------------------------------------- 1 | {"model_name": "CustomTransformerDecoder", "model_type": "PyTorch", "n_targets": 2, "model_params": {"dropout": 0.1, "seq_length": 11, "n_time_series": 18, "output_dim": 2, "output_seq_length": 1, "n_layers_encoder": 2, "use_mask": true}, "dataset_params": {"class": "default", "num_workers": 5, "forecast_test_len": 20, "pin_memory": true, "training_path": "/content/flow-forecast/miami_f.csv", "validation_path": "/content/flow-forecast/miami_f.csv", "test_path": "/content/flow-forecast/miami_f.csv", "batch_size": 10, "forecast_history": 11, "forecast_length": 1, "scaler": "StandardScaler", "train_start": 0, "train_end": 170, "valid_start": 170, "valid_end": 310, "sort_column": "date", "test_start": 170, "test_end": 310, "target_col": ["rolling_7", "rolling_deaths"], "relevant_cols": ["rolling_7", "rolling_deaths", "mobility_retail_recreation", "mobility_grocery_pharmacy", "mobility_parks", "mobility_transit_stations", "mobility_workplaces", "mobility_residential", "avg_temperature", "min_temperature", "max_temperature", "relative_humidity", "specific_humidity", "pressure"], "feature_param": {"datetime_params": {"day_of_week": "cyclical", "month": "cyclical"}}, "interpolate": false}, "training_params": {"criterion": "MSE", "optimizer": "SGD", "optim_params": {"lr": 0.0001}, "epochs": 10, "batch_size": 10}, "early_stopping": {"patience": 3}, "GCS": true, "sweep": true, "wandb": false, "forward_params": {}, "metrics": ["MSE"], "inference_params": {"datetime_start": "2020-12-14", "hours_to_forecast": 18, "num_prediction_samples": 20, "test_csv_path": "/content/flow-forecast/miami_f.csv", "decoder_params": {"decoder_function": "simple_decode", "unsqueeze_dim": 1}, "dataset_params": {"file_path": "/content/flow-forecast/miami_f.csv", "sort_column": "date", "scaling": "StandardScaler", "forecast_history": 11, "forecast_length": 1, "relevant_cols": ["rolling_7", "rolling_deaths", "mobility_retail_recreation", "mobility_grocery_pharmacy", "mobility_parks", "mobility_transit_stations", "mobility_workplaces", "mobility_residential", "avg_temperature", "min_temperature", "max_temperature", "relative_humidity", "specific_humidity", "pressure"], "target_col": ["rolling_7", 
"rolling_deaths"], "interpolate_param": false, "feature_params": {"datetime_params": {"day_of_week": "cyclical", "month": "cyclical"}}}}, "meta_data": false, "run": [{"epoch": 0, "train_loss": "1.1954958769492805", "validation_loss": "85.43445341289043"}, {"epoch": 1, "train_loss": "1.1476804590784013", "validation_loss": "84.1799928843975"}, {"epoch": 2, "train_loss": "1.065674600424245", "validation_loss": "84.03104758262634"}, {"epoch": 3, "train_loss": "1.0211504658218473", "validation_loss": "84.54550993442535"}, {"epoch": 4, "train_loss": "0.9789167386479676", "validation_loss": "85.40744817256927"}, {"epoch": 5, "train_loss": "0.9342440171167254", "validation_loss": "86.52448198199272"}]} 2 | -------------------------------------------------------------------------------- /tests/multi_decoder_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "CustomTransformerDecoder", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | "seq_length":5, 7 | "output_seq_length": 1, 8 | "n_layers_encoder": 6, 9 | "output_dim":1, 10 | "final_act":"Swish" 11 | }, 12 | "dataset_params": 13 | { "class": "default", 14 | "training_path": "tests/test_data/keag_small.csv", 15 | "validation_path": "tests/test_data/keag_small.csv", 16 | "test_path": "tests/test_data/keag_small.csv", 17 | "batch_size":11, 18 | "forecast_history":5, 19 | "forecast_length":1, 20 | "train_end": 100, 21 | "valid_start":101, 22 | "valid_end": 201, 23 | "test_start": 202, 24 | "test_end": 290, 25 | "target_col": ["cfs"], 26 | "relevant_cols": ["cfs", "precip", "temp"], 27 | "scaler": "MinMaxScaler", 28 | "scaler_params":{ 29 | "feature_range":[0, 2] 30 | }, 31 | "interpolate": false 32 | }, 33 | "training_params": 34 | { 35 | "criterion":"MAPE", 36 | "optimizer": "Adam", 37 | "optim_params": 38 | { 39 | 40 | }, 41 | "lr": 0.3, 42 | "epochs": 1, 43 | "batch_size":4 44 | 45 | }, 46 | "GCS": false, 47 | 48 | "wandb": { 49 | "name": "flood_forecast_circleci", 50 | "project": "repo-flood_forecast", 51 | "tags": ["dummy_run", "circleci"] 52 | }, 53 | "forward_params":{}, 54 | "metrics":["MSE"], 55 | "inference_params": 56 | { 57 | "datetime_start":"2016-05-31", 58 | "num_prediction_samples":10, 59 | "hours_to_forecast":336, 60 | "test_csv_path":"tests/test_data/keag_small.csv", 61 | "decoder_params":{ 62 | "decoder_function": "simple_decode", 63 | "unsqueeze_dim": 1}, 64 | "dataset_params":{ 65 | "file_path": "tests/test_data/keag_small.csv", 66 | "forecast_history":5, 67 | "forecast_length":1, 68 | "relevant_cols": ["cfs", "precip", "temp"], 69 | "target_col": ["cfs"], 70 | "scaling": "MinMaxScaler", 71 | "scaler_params":{ 72 | "feature_range":[0,2] 73 | }, 74 | "interpolate_param": false 75 | } 76 | } 77 | 78 | 79 | } 80 | -------------------------------------------------------------------------------- /tests/multi_modal_tests/test_cross_vivit.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from flood_forecast.multi_models.crossvivit import RoCrossViViT, VisionTransformer 4 | from flood_forecast.transformer_xl.attn import SelfAttention 5 | from flood_forecast.transformer_xl.data_embedding import ( 6 | CyclicalEmbedding, 7 | NeRF_embedding, 8 | PositionalEncoding2D, 9 | ) 10 | 11 | 12 | class TestCrossVivVit(unittest.TestCase): 13 | def setUp(self): 14 | self.crossvivit = RoCrossViViT( 15 | image_size=(120, 120), 16 | patch_size=(8, 8), 17 | 
time_coords_encoder=CyclicalEmbedding(), 18 | ctx_channels=12, 19 | num_time_series=12, 20 | dim=128, 21 | depth=4, 22 | heads=4, 23 | mlp_ratio=4, 24 | forecast_history=10, 25 | out_dim=1, 26 | dropout=0.0, 27 | video_cat_dim=2, 28 | axial_kwargs={"max_freq": 12}, 29 | ) 30 | 31 | def test_positional_encoding_forward(self): 32 | """Test the positional encoding forward pass with a PositionalEncoding2D layer.""" 33 | positional_encoding = PositionalEncoding2D(channels=2) 34 | # Coordinates with format [B, 2, H, W] 35 | coords = torch.rand(5, 2, 32, 32) 36 | output = positional_encoding(coords) 37 | self.assertEqual(output.shape, (5, 32, 32, 4)) 38 | 39 | def test_vivit_model(self): 40 | """Tests the Video Vision Transformer (ViViT) model with simulated image data.""" 41 | self.vivit_model = VisionTransformer( 42 | dim=128, depth=5, heads=8, dim_head=128, mlp_dim=128, dropout=0.1 43 | ) 44 | out = self.vivit_model( 45 | torch.rand(5, 512, 128), (torch.rand(5, 512, 64), torch.rand(5, 512, 64)) 46 | ) 47 | assert out[0].shape == (5, 512, 128) 48 | 49 | def test_forward(self): 50 | """Tests the forward pass of the RoCrossViViT model from the CrossViViT paper. 51 | 52 | The forward method expects the following inputs: 53 | 54 | video_context (torch.Tensor): Context frames of shape [batch, time, ctx_channels, height, width] 55 | context_coords (torch.Tensor): Coordinates of the context frames of shape [batch, 2, height, width] 56 | timeseries (torch.Tensor): Station time series of shape [batch, time, num_time_series] 57 | timeseries_spatial_coordinates (torch.Tensor): Station coordinates of shape [batch, 2, 1, 1] 58 | ts_positional_encoding (torch.Tensor): Time coordinates of shape [batch, time, encoding_channels, height, width] 59 | mask (bool): Whether to mask or not. Useful for inference. 60 | 61 | The model returns a tuple whose first element holds the predictions, 62 | asserted below to have shape (batch, time, out_dim, 1). 63 | """ 64 | # Construct a simulated video context tensor of shape [batch, time, channels, height, width]. 65 | ctx_tensor = torch.rand(5, 10, 12, 120, 120) 66 | ctx_coords = torch.rand(5, 2, 120, 120) 67 | ts = torch.rand(5, 10, 12) 68 | time_coords1 = torch.rand(5, 10, 4, 120, 120) 69 | ts_coords = torch.rand(5, 2, 1, 1) 70 | x = self.crossvivit( 71 | video_context=ctx_tensor, 72 | context_coords=ctx_coords, 73 | timeseries=ts, 74 | timeseries_spatial_coordinates=ts_coords, 75 | ts_positional_encoding=time_coords1, 76 | ) 77 | self.assertEqual(x[0].shape, (5, 10, 1, 1)) 78 | 79 | def test_self_attention_dims(self): 80 | """Test the self attention layer with the correct dimensions.""" 81 | self.self_attention = SelfAttention(dim=128, use_rotary=True) 82 | self.self_attention( 83 | torch.rand(5, 512, 128), (torch.rand(5, 512, 64), torch.rand(5, 512, 64)) 84 | ) 85 | 86 | def test_neRF_embedding(self): 87 | """Test the NeRF embedding layer.""" 88 | nerf_embedding = NeRF_embedding(n_layers=128) 89 | coords = torch.rand(5, 2, 32, 32) 90 | output = nerf_embedding(coords) 91 | self.assertEqual(output.shape, (5, 512, 32, 32)) 92 | 93 | 94 | if __name__ == "__main__": 95 | unittest.main() 96 | -------------------------------------------------------------------------------- /tests/multi_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "MultiAttnHeadSimple", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "number_time_series":3, 6 | "seq_len":5 7 | }, 8 | "dataset_params": 9 | { "class": "default", 10 | "num_workers": 3, 11 | "training_path": 
"tests/test_data/keag_small.csv", 12 | "validation_path": "tests/test_data/keag_small.csv", 13 | "test_path": "tests/test_data/keag_small.csv", 14 | "batch_size":4, 15 | "forecast_history":5, 16 | "forecast_length":5, 17 | "train_end": 190, 18 | "valid_start":301, 19 | "valid_end": 401, 20 | "test_end": 500, 21 | "target_col": ["cfs"], 22 | "relevant_cols": ["cfs", "precip", "temp"], 23 | "sort_column":"datetime", 24 | "interpolate": false 25 | }, 26 | "training_params": 27 | { 28 | "criterion":"MSE", 29 | "optimizer": "Adam", 30 | "optim_params": 31 | { 32 | }, 33 | "lr": 0.3, 34 | "epochs": 1, 35 | "batch_size":4 36 | }, 37 | "GCS": false, 38 | 39 | "wandb": { 40 | "name": "flood_forecast_circleci", 41 | "tags": ["dummy_run", "circleci"], 42 | "project": "repo-flood_forecast" 43 | }, 44 | "forward_params":{}, 45 | "metrics":["MSE"], 46 | "inference_params": 47 | { "num_prediction_samples": 100, 48 | "datetime_start":"2016-05-31", 49 | "hours_to_forecast":336, 50 | "test_csv_path":"tests/test_data/keag_small.csv", 51 | "dataset_params":{ 52 | "file_path": "tests/test_data/keag_small.csv", 53 | "forecast_history":5, 54 | "forecast_length":5, 55 | "relevant_cols": ["cfs", "precip", "temp"], 56 | "target_col": ["cfs"], 57 | "interpolate_param": false 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /tests/multitask_decoder.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "CustomTransformerDecoder", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | "seq_length":5, 7 | "output_seq_length": 1, 8 | "n_layers_encoder": 6, 9 | "output_dim":2, 10 | "final_act":"Swish" 11 | }, 12 | "n_targets":2, 13 | "dataset_params": 14 | { "class": "default", 15 | "training_path": "tests/test_data/keag_small.csv", 16 | "validation_path": "tests/test_data/keag_small.csv", 17 | "test_path": "tests/test_data/keag_small.csv", 18 | "batch_size":10, 19 | "forecast_history":5, 20 | "forecast_length":1, 21 | "train_end": 100, 22 | "valid_start":101, 23 | "valid_end": 301, 24 | "test_start": 202, 25 | "test_end": 290, 26 | "no_scale": true, 27 | "target_col": ["cfs", "temp"], 28 | "relevant_cols": ["cfs", "precip", "temp"], 29 | "scaler": "MinMaxScaler", 30 | "scaler_params":{ 31 | "feature_range":[0, 2] 32 | }, 33 | "interpolate": false 34 | }, 35 | "training_params": 36 | { 37 | "criterion":"MSE", 38 | "optimizer": "Adam", 39 | "optim_params": 40 | { 41 | }, 42 | "lr": 0.3, 43 | "epochs": 1, 44 | "batch_size":4 45 | 46 | }, 47 | "GCS": false, 48 | 49 | "wandb": { 50 | "name": "flood_forecast_circleci", 51 | "project": "repo-flood_forecast", 52 | "tags": ["dummy_run", "circleci"] 53 | }, 54 | "forward_params":{}, 55 | "metrics":["MSE"], 56 | "inference_params": 57 | { 58 | "datetime_start":"2016-05-31", 59 | "hours_to_forecast":356, 60 | "test_csv_path":"tests/test_data/keag_small.csv", 61 | "decoder_params":{ 62 | "decoder_function": "simple_decode", 63 | "unsqueeze_dim": 1}, 64 | "dataset_params":{ 65 | "file_path": "tests/test_data/keag_small.csv", 66 | "forecast_history":5, 67 | "forecast_length":1, 68 | "relevant_cols": ["cfs", "precip", "temp"], 69 | "target_col": ["cfs", "temp"], 70 | "no_scale": true, 71 | "scaling": "MinMaxScaler", 72 | "scaler_params":{ 73 | "feature_range":[0,2] 74 | }, 75 | "interpolate_param": false 76 | } 77 | } 78 | 79 | 80 | } 81 | -------------------------------------------------------------------------------- /tests/nlinear.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_name": "NLinear", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "forecast_history":20, 7 | "forecast_length": 10, 8 | "enc_in": 3, 9 | "individual": true 10 | 11 | }, 12 | "dataset_params": 13 | { "class": "default", 14 | "training_path": "tests/test_data/keag_small.csv", 15 | "validation_path": "tests/test_data/keag_small.csv", 16 | "test_path": "tests/test_data/keag_small.csv", 17 | "forecast_history":20, 18 | "forecast_length":10, 19 | "train_start": 1, 20 | "train_end": 300, 21 | "valid_start":302, 22 | "valid_end": 401, 23 | "test_start":50, 24 | "test_end": 450, 25 | "target_col": ["cfs"], 26 | "relevant_cols": ["cfs", "precip", "temp"], 27 | "scaler": "StandardScaler", 28 | "interpolate": false 29 | }, 30 | "training_params": 31 | { 32 | "batch_size":10, 33 | "criterion":"MSE", 34 | "optimizer": "Adam", 35 | "optim_params": 36 | { 37 | 38 | }, 39 | "lr": 0.03, 40 | "epochs": 1 41 | 42 | }, 43 | "GCS": false, 44 | 45 | "wandb": { 46 | "name": "flood_forecast_circleci", 47 | "tags": ["dummy_run", "circleci"], 48 | "project": "repo-flood_forecast" 49 | }, 50 | "forward_params":{}, 51 | "metrics":["MSE"], 52 | "inference_params": 53 | { 54 | "datetime_start":"2016-05-31", 55 | "hours_to_forecast":334, 56 | "test_csv_path":"tests/test_data/keag_small.csv", 57 | "decoder_params":{ 58 | "decoder_function": "simple_decode", 59 | "unsqueeze_dim": 1} 60 | , 61 | "dataset_params":{ 62 | "file_path": "tests/test_data/keag_small.csv", 63 | "forecast_history":20, 64 | "forecast_length":10, 65 | "relevant_cols": ["cfs", "precip", "temp"], 66 | "target_col": ["cfs"], 67 | "scaling": "StandardScaler", 68 | "interpolate_param": false 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /tests/probabilistic_linear_regression_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "SimpleLinearModel", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "seq_length": 10, 6 | "n_time_series":3, 7 | "output_seq_len":1, 8 | "probabilistic": true 9 | }, 10 | "dataset_params": 11 | { "class": "default", 12 | "training_path": "tests/test_data/keag_small.csv", 13 | "validation_path": "tests/test_data/keag_small.csv", 14 | "test_path": "tests/test_data/keag_small.csv", 15 | "batch_size":4, 16 | "forecast_history":10, 17 | "forecast_length":1, 18 | "forecast_test_len": 10, 19 | "train_start": 1, 20 | "train_end": 300, 21 | "valid_start":301, 22 | "valid_end": 401, 23 | "test_start":380, 24 | "test_end": 450, 25 | "target_col": ["cfs"], 26 | "relevant_cols": ["cfs", "precip", "temp"], 27 | "scaler": "StandardScaler", 28 | "interpolate": { 29 | "method":"back_forward_generic", 30 | "params":{ 31 | "relevant_columns":["cfs"] 32 | } 33 | } 34 | }, 35 | "training_params": 36 | { 37 | "criterion":"NegativeLogLikelihood", 38 | "probabilistic": true, 39 | "optimizer": "Adam", 40 | "optim_params": 41 | { 42 | }, 43 | "lr": 0.01, 44 | "epochs": 1, 45 | "batch_size":4 46 | 47 | }, 48 | "GCS": false, 49 | 50 | "wandb": { 51 | "name": "flood_forecast_circleci", 52 | "tags": ["dummy_run", "circleci"], 53 | "project": "repo-flood_forecast" 54 | }, 55 | "forward_params":{}, 56 | "metrics":["NegativeLogLikelihood"], 57 | "inference_params": 58 | { 59 | "datetime_start":"2016-05-31", 60 | "hours_to_forecast":334, 61 | "test_csv_path":"tests/test_data/keag_small.csv", 62 | 
"probabilistic": true, 63 | "decoder_params":{ 64 | "decoder_function": "simple_decode", 65 | "unsqueeze_dim": 1, "probabilistic": true} 66 | , 67 | "dataset_params":{ 68 | "file_path": "tests/test_data/keag_small.csv", 69 | "forecast_history":10, 70 | "forecast_length":1, 71 | "relevant_cols": ["cfs", "precip", "temp"], 72 | "target_col": ["cfs"], 73 | "scaling": "StandardScaler", 74 | "interpolate_param": false 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /tests/scaling_json.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "CustomTransformerDecoder", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | "seq_length":5, 7 | "output_seq_length": 1, 8 | "n_layers_encoder": 6, 9 | "output_dim":2, 10 | "final_act":"Swish" 11 | }, 12 | "n_targets":2, 13 | "dataset_params": 14 | { "class": "default", 15 | "training_path": "tests/test_data/keag_small.csv", 16 | "validation_path": "tests/test_data/keag_small.csv", 17 | "test_path": "tests/test_data/keag_small.csv", 18 | "batch_size":10, 19 | "forecast_history":5, 20 | "forecast_length":1, 21 | "train_end": 100, 22 | "valid_start":101, 23 | "valid_end": 301, 24 | "test_start": 202, 25 | "test_end": 290, 26 | "no_scale": true, 27 | "target_col": ["cfs", "temp"], 28 | "relevant_cols": ["cfs", "precip", "temp"], 29 | "scaler": "MinMaxScaler", 30 | "scaler_params":{ 31 | "feature_range":[0, 2] 32 | }, 33 | "interpolate": false 34 | }, 35 | "training_params": 36 | { 37 | "criterion":"MSE", 38 | "optimizer": "Adam", 39 | "optim_params": 40 | { 41 | }, 42 | "lr": 0.3, 43 | "epochs": 1, 44 | "batch_size":4 45 | 46 | }, 47 | "GCS": false, 48 | 49 | "wandb": { 50 | "name": "flood_forecast_circleci", 51 | "project": "repo-flood_forecast", 52 | "tags": ["dummy_run", "circleci"] 53 | }, 54 | "forward_params":{}, 55 | "metrics":["MSE"], 56 | "inference_params": 57 | { 58 | "datetime_start":"2016-05-31", 59 | "hours_to_forecast":336, 60 | "test_csv_path":"tests/test_data/keag_small.csv", 61 | "decoder_params":{ 62 | "decoder_function": "simple_decode", 63 | "unsqueeze_dim": 1} 64 | } 65 | 66 | 67 | } 68 | -------------------------------------------------------------------------------- /tests/test_attn.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.transformer_xl.attn import ProbAttention, FullAttention 2 | from flood_forecast.transformer_xl.masks import TriangularCausalMask 3 | from flood_forecast.transformer_xl.dsanet import Single_Local_SelfAttn_Module 4 | import unittest 5 | import torch 6 | 7 | 8 | class TestAttention(unittest.TestCase): 9 | def setUp(self): 10 | """""" 11 | self.prob_attention = ProbAttention() 12 | self.full_attention = FullAttention() 13 | self.triangle = TriangularCausalMask(2, 20) 14 | 15 | def test_prob_attn(self): 16 | # B, L, H, D (where B is batch_size, L is sequence length, H is number of heads, and D is embedding dim) 17 | a = torch.rand(2, 20, 8, 30) 18 | r = self.prob_attention(torch.rand(2, 20, 8, 30), a, torch.rand(2, 20, 8, 30), self.triangle) 19 | self.assertGreater(len(r[0].shape), 2) 20 | self.assertIsInstance(r[0], torch.Tensor) 21 | 22 | def test_full_attn(self): 23 | # Tests the full attention mechanism and 24 | t = torch.rand(2, 20, 8, 30) 25 | a = self.full_attention(torch.rand(2, 20, 8, 30), t, t, self.triangle) 26 | self.assertIsInstance(a[0], torch.Tensor) 27 | self.assertEqual(len(a[0].shape), 4) 28 | 
self.assertEqual(a[0].shape[0], 2) 29 | 30 | def test_single_local(self): 31 | Single_Local_SelfAttn_Module(10, 4, 10, 5, 1, 128, 128, 128, 32, 2, 8) 32 | -------------------------------------------------------------------------------- /tests/test_classification2_loader.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from flood_forecast.preprocessing.pytorch_loaders import GeneralClassificationLoader 4 | import torch 5 | from flood_forecast.model_dict_function import pytorch_criterion_dict 6 | 7 | 8 | class TestGeneralClassificationCSVLoader(unittest.TestCase): 9 | def setUp(self): 10 | self.test_data_path = os.path.join( 11 | os.path.dirname(os.path.abspath(__file__)), "test_data" 12 | ) 13 | self.dataset_params = { 14 | "file_path": os.path.join(self.test_data_path, "test2.csv"), 15 | "sequence_length": 20, 16 | "relevant_cols": ["vel", "obs", "day_of_week"], 17 | "target_col": ["vel"], 18 | "interpolate_param": False, 19 | } 20 | self.data_loader = GeneralClassificationLoader(self.dataset_params.copy(), 7) 21 | 22 | def test_classification_return(self): 23 | """Tests that the loader returns feature and label tensors with sane shapes.""" 24 | x, y = self.data_loader[0] 25 | self.assertIsInstance(x, torch.Tensor) 26 | self.assertIsInstance(y, torch.Tensor) 27 | self.assertGreater(x.shape[0], 1) 28 | self.assertGreater(x.shape[1], 1) 29 | 30 | def test_class(self): 31 | """Tests that the loader encodes the label over the expected number of classes.""" 32 | x, y = self.data_loader[1] 33 | self.assertIsInstance(x, torch.Tensor) 34 | self.assertIsInstance(y, torch.Tensor) 35 | print("y is below") 36 | print(y) 37 | self.assertEqual(y.shape[1], 7) 38 | 39 | def test_bce_stuff(self): 40 | loss = pytorch_criterion_dict["CrossEntropyLoss"]() 41 | x, y = self.data_loader[1] 42 | the_loss = loss(torch.rand(1, 7), y.max(dim=1)[1]).item() 43 | self.assertGreater(the_loss, 0) 44 | 45 | if __name__ == '__main__': 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /tests/test_da_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | import os 4 | import tempfile 5 | from flood_forecast.preprocessing.preprocess_da_rnn import make_data 6 | from flood_forecast.da_rnn.train_da import da_rnn, train 7 | 8 | 9 | class TestDARNN(unittest.TestCase): 10 | def setUp(self): 11 | self.preprocessed_data = make_data(os.path.join( 12 | os.path.dirname(__file__), "test_init", "keag_small.csv"), ["cfs"], 72) 13 | 14 | def test_train_model(self): 15 | with tempfile.TemporaryDirectory() as param_directory: 16 | config, da_network = da_rnn(self.preprocessed_data, 1, 64, 17 | param_output_path=param_directory) 18 | loss_results, model = train(da_network, self.preprocessed_data, 19 | config, n_epochs=1, tensorboard=True) 20 | self.assertTrue(model) 21 | 22 | def test_tf_data(self): 23 | dirname = os.path.dirname(__file__) 24 | # Test that the TensorBoard directory was indeed created 25 | self.assertTrue(os.listdir(os.path.join(dirname))) 26 | 27 | def test_create_model(self): 28 | with tempfile.TemporaryDirectory() as param_directory: 29 | config, dnn_network = da_rnn(self.preprocessed_data, 1, 64, 30 | param_output_path=param_directory) 31 | self.assertNotEqual(config.batch_size, 20) 32 | self.assertIsNotNone(dnn_network) 33 | 34 | def test_resume_ckpt(self): 35 | """This test is dependent on test_train_model succeeding.""" 36 | config, da = da_rnn(self.preprocessed_data, 1, 
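# The encoder and decoder weights are written to a temporary directory below, then the
# model is rebuilt with save_path pointing at that directory to exercise checkpoint resumption.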
64) 37 | with tempfile.TemporaryDirectory() as checkpoint: 38 | torch.save(da.encoder.state_dict(), os.path.join(checkpoint, "encoder.pth")) 39 | torch.save(da.decoder.state_dict(), os.path.join(checkpoint, "decoder.pth")) 40 | config, dnn_network = da_rnn(self.preprocessed_data, 1, 64, save_path=checkpoint) 41 | self.assertTrue(dnn_network) 42 | 43 | if __name__ == '__main__': 44 | unittest.main() 45 | -------------------------------------------------------------------------------- /tests/test_data/asos-12N_small.csv: -------------------------------------------------------------------------------- 1 | ,hour_updated,p01m,valid,tmpf 2 | 0,2014-01-01 01:00:00,0.0,2014-01-01 00:54,24.98 3 | 1,2014-01-01 02:00:00,0.0,2014-01-01 01:54,24.08 4 | 2,2014-01-01 03:00:00,0.0,2014-01-01 02:54,21.92 5 | 3,2014-01-01 04:00:00,0.0,2014-01-01 03:54,21.02 6 | 4,2014-01-01 05:00:00,0.0,2014-01-01 04:54,19.94 7 | 5,2014-01-01 06:00:00,0.0,2014-01-01 05:54,19.04 8 | 6,2014-01-01 07:00:00,0.0,2014-01-01 06:54,19.04 9 | 7,2014-01-01 08:00:00,0.0,2014-01-01 07:54,19.04 10 | 8,2014-01-01 09:00:00,0.0,2014-01-01 08:54,15.98 11 | 9,2014-01-01 10:00:00,0.0,2014-01-01 09:54,15.98 12 | 10,2014-01-01 11:00:00,0.0,2014-01-01 10:54,15.08 13 | 11,2014-01-01 12:00:00,0.0,2014-01-01 11:54,14.0 14 | 12,2014-01-01 13:00:00,0.0,2014-01-01 12:54,14.0 15 | 13,2014-01-01 14:00:00,0.0,2014-01-01 13:54,17.96 16 | 14,2014-01-01 15:00:00,0.0,2014-01-01 14:54,23.0 17 | 15,2014-01-01 16:00:00,0.0,2014-01-01 15:54,26.06 18 | 16,2014-01-01 17:00:00,0.0,2014-01-01 16:54,28.04 19 | 17,2014-01-01 18:00:00,0.0,2014-01-01 17:54,28.04 20 | 18,2014-01-01 19:00:00,0.0,2014-01-01 18:54,28.94 21 | 19,2014-01-01 20:00:00,0.0,2014-01-01 19:54,30.02 22 | -------------------------------------------------------------------------------- /tests/test_data/asos_process.json: -------------------------------------------------------------------------------- 1 | {"gage_id": 10302002, "stations": [{"station_id": "HTH", "dist": 46.435053412548825, "cat": "ASOS"}, {"station_id": "NFL", "dist": 53.70693602675465, "cat": "ASOS"}, {"station_id": "MEV", "dist": 82.04322213390148, "cat": "ASOS"}, {"station_id": "KCXP", "dist": 84.65071932778115, "cat": "ASOS"}, {"station_id": "CXP", "dist": 84.94858800475141, "cat": "ASOS"}, {"station_id": "BAN", "dist": 89.32502180345543, "cat": "ASOS"}, {"station_id": "KBAN", "dist": 89.90074913931772, "cat": "ASOS"}, {"station_id": "KRNO", "dist": 102.95297368839903, "cat": "ASOS"}, {"station_id": "RNO", "dist": 102.95950352597654, "cat": "ASOS"}, {"station_id": "KTVL", "dist": 103.1369712179468, "cat": "ASOS"}, {"station_id": "TVL", "dist": 103.21884385169895, "cat": "ASOS"}]} 2 | -------------------------------------------------------------------------------- /tests/test_data/big_black_md.json: -------------------------------------------------------------------------------- 1 | {"gage_id": 1010070, "stations": [{"station_id": "40B", "dist": 35.86286975647646, "cat": "ASOS", "missing_precip": 0, "missing_temp": 39}, {"station_id": "CWST", "dist": 55.57017989683566, "cat": "ASOS", "missing_precip": 0, "missing_temp": 0}, {"station_id": "CWIG", "dist": 62.55508196187881, "cat": "ASOS", "missing_precip": 0, "missing_temp": 42418}, {"station_id": "CWIS", "dist": 79.88055804274804, "cat": "ASOS", "missing_precip": 0, "missing_temp": 0}, {"station_id": "CWTN", "dist": 80.72267697891424, "cat": "ASOS", "missing_precip": 0, "missing_temp": 0}, {"station_id": "CWER", "dist": 81.02146645992319, "cat": "ASOS", "missing_precip": 0, 
"missing_temp": 42174}, {"station_id": "CWNH", "dist": 102.55052000344958, "cat": "ASOS", "missing_precip": 0, "missing_temp": 0}, {"station_id": "CWHV", "dist": 110.05272501753285, "cat": "ASOS", "missing_precip": 0, "missing_temp": 0}, {"station_id": "CXBO", "dist": 110.07973477851439, "cat": "ASOS", "missing_precip": 0, "missing_temp": 42654}, {"station_id": "CMFM", "dist": 115.95899164844458, "cat": "ASOS", "missing_precip": 0, "missing_temp": 1}, {"station_id": "KFVE", "dist": 117.27669180798962, "cat": "ASOS", "missing_precip": 203752, "missing_temp": 217205}, {"station_id": "FVE", "dist": 117.31715057603458, "cat": "ASOS", "missing_precip": 203752, "missing_temp": 217205}, {"station_id": "CWJB", "dist": 117.50331039853623, "cat": "ASOS", "missing_precip": 0, "missing_temp": 0}, {"station_id": "CERM", "dist": 122.57597870894188, "cat": "ASOS", "missing_precip": 0, "missing_temp": 315}, {"station_id": "CYQB", "dist": 125.38125174466414, "cat": "ASOS", "missing_precip": 0, "missing_temp": 3}, {"station_id": "CWAF", "dist": 131.48254619436221, "cat": "ASOS", "missing_precip": 0, "missing_temp": 0}, {"station_id": "CAR", "dist": 131.63453862667052, "cat": "ASOS", "missing_precip": 215342, "missing_temp": 232959}, {"station_id": "KPQI", "dist": 131.65095845184678, "cat": "ASOS", "missing_precip": 220363, "missing_temp": 234985}], "time_zone_code": "EDT", "max_flow": 5870.0, "min_flow": 12.3, "nan_flow": 13627, "nan_precip": 678, "files": ["101007040B_flow.csv"]} 2 | -------------------------------------------------------------------------------- /tests/test_data/farm_ex.csv: -------------------------------------------------------------------------------- 1 | FarmNumber,NumberOfAnimals,MeanMilkProduction (kg/lactation),FarmType, 2 | 2,95,8326,0, 3 | 11,190,9679,0, 4 | 10,165,9452,0, 5 | 13,,8552,0, 6 | 8,53,10529,1, 7 | 9,54,9924,0, 8 | 1,92,10533,0, 9 | 6,115,9167,1, 10 | 12,180,9289,0, 11 | 14,97,8343,1, 12 | 4,84,9005,1, 13 | 5,82,8867,1, 14 | 3,80,9352,1, 15 | 7,77,8900,1, 16 | -------------------------------------------------------------------------------- /tests/test_data/imputation_test.csv: -------------------------------------------------------------------------------- 1 | station,valid,tmpf,p01m 2 | FVE,2019-01-01 00:00,50,M 3 | FVE,2019-01-01 01:25,52,23 4 | FVE,2019-01-01 02:59,52,M 5 | FVE,2019-01-01 03:00,55,1 6 | FVE,2019-01-01 02:25,52,M 7 | FVE,2019-01-01 02:59,52,1 8 | FVE,2019-01-01 03:00,0,21 9 | -------------------------------------------------------------------------------- /tests/test_data/river_test_sm.csv: -------------------------------------------------------------------------------- 1 | ,Unnamed: 0_x,hour_updated,p01m,valid,tmpf,Unnamed: 0_y,agency_cd,site_no,datetime,tz_cd,103981_00060,103981_00060_cd,cfs 2 | 5,5.0,2014-01-01 06:00:00+00:00,0.0,2014-01-01 05:56,24.98,,,,,,,, 3 | 6,6.0,2014-01-01 07:00:00+00:00,0.0,2014-01-01 06:56,24.98,,,,,,,, 4 | 7,7.0,2014-01-01 08:00:00+00:00,0.0,2014-01-01 07:56,21.92,1.0,USGS,10321950,2014-01-01 08:00:00+00:00,PST,0.00,A,0.0 5 | 8,8.0,2014-01-01 09:00:00+00:00,0.0,2014-01-01 08:56,19.94,5.0,USGS,10321950,2014-01-01 09:00:00+00:00,PST,0.00,A,0.0 6 | 9,9.0,2014-01-01 10:00:00+00:00,0.0,2014-01-01 09:56,21.92,9.0,USGS,10321950,2014-01-01 10:00:00+00:00,PST,0.00,A,0.0 7 | 10,10.0,2014-01-01 11:00:00+00:00,0.0,2014-01-01 10:56,21.02,13.0,USGS,10321950,2014-01-01 11:00:00+00:00,PST,0.00,A,0.0 8 | 11,11.0,2014-01-01 12:00:00+00:00,0.0,2014-01-01 11:56,19.94,17.0,USGS,10321950,2014-01-01 12:00:00+00:00,PST,0.00,A,0.0 9 | 
-------------------------------------------------------------------------------- /tests/test_data/small_test.csv: -------------------------------------------------------------------------------- 1 | station,valid,tmpf,p01m 2 | FVE,2019-01-01 00:00,50,25 3 | FVE,2019-01-01 00:25,52,23 4 | FVE,2019-01-01 00:59,52,23 5 | FVE,2019-01-01 01:00,55,1 6 | FVE,2019-01-01 02:25,52,0 7 | FVE,2019-01-01 02:59,52,1 8 | FVE,2019-01-01 03:00,0,21s 9 | -------------------------------------------------------------------------------- /tests/test_data/test_format_data.csv: -------------------------------------------------------------------------------- 1 | precip,temp,height 2 | 0,25,.08 3 | .01,24,.06 4 | 0,25,.08 5 | .01,24,.06 6 | -------------------------------------------------------------------------------- /tests/test_decoder.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader 6 | from flood_forecast.model_dict_function import pytorch_criterion_dict 7 | from flood_forecast.transformer_xl.transformer_basic import SimpleTransformer, greedy_decode 8 | 9 | 10 | class TestDecoding(unittest.TestCase): 11 | def setUp(self): 12 | self.model = SimpleTransformer(3, 30, 20) 13 | self.data_test_path = os.path.join( 14 | os.path.dirname( 15 | os.path.abspath(__file__)), 16 | "test_init", 17 | "chick_final.csv") 18 | self.validation_loader = DataLoader( 19 | CSVDataLoader( 20 | self.data_test_path, 21 | forecast_history=30, 22 | forecast_length=20, 23 | target_col=['cfs'], 24 | relevant_cols=[ 25 | 'cfs', 26 | 'temp', 27 | 'precip'], 28 | interpolate_param=False), 29 | shuffle=False, 30 | sampler=None, 31 | batch_sampler=None, 32 | num_workers=0, 33 | collate_fn=None, 34 | pin_memory=False, 35 | drop_last=False, 36 | timeout=0, 37 | worker_init_fn=None) 38 | self.sequence_size = 30 39 | 40 | def test_full_forward_method(self): 41 | test_data = torch.rand(1, 30, 3) 42 | result = self.model(test_data, t=torch.rand(1, 20, 3)) 43 | self.assertEqual(result.shape, torch.Size([1, 20])) 44 | 45 | def test_encoder_seq(self): 46 | test_data = torch.rand(1, 30, 3) 47 | result = self.model.encode_sequence(test_data) 48 | self.assertEqual(result.shape, torch.Size([30, 1, 128])) 49 | 50 | def test_for_leakage(self): 51 | """Simple test to check that raw target data does NOT leak during validation steps.""" 52 | src, trg = next(iter(self.validation_loader)) 53 | trg_mem = trg.clone().detach() 54 | result = greedy_decode(self.model, src, 20, trg) 55 | self.assertNotEqual(result[0, 1, 0], trg_mem[0, 1, 0]) 56 | self.assertEqual(result[0, 1, 1], trg_mem[0, 1, 1]) 57 | self.assertEqual(result[0, 1, 2], trg_mem[0, 1, 2]) 58 | loss = pytorch_criterion_dict["MSE"]()(trg, trg_mem) 59 | 60 | self.assertNotEqual(result[0, 1, 0], result[0, 4, 0]) 61 | self.assertGreater(loss, 0) 62 | 63 | def test_make_function(self): 64 | self.assertEqual(1, 1) 65 | 66 | if __name__ == '__main__': 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /tests/test_deployment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from flood_forecast.deployment.inference import load_model, convert_to_torch_script, InferenceMode 4 | import unittest 5 | from datetime import datetime 6 | import torch 7 | 8 | 9 | class 
InferenceTests(unittest.TestCase): 10 | def setUp(self): 11 | """Set up the configs, weight paths, and data paths used to test model inference.""" 12 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.json")) as y: 13 | self.config_test = json.load(y) 14 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "multi_config.json")) as y: 15 | self.multi_config_test = json.load(y) 16 | self.new_csv_path = "gs://flow_datasets/Massachusetts_Middlesex_County.csv" 17 | self.weight_path = "gs://coronaviruspublicdata/experiments/01_July_202009_44PM_model.pth" 18 | self.multi_path = "gs://flow_datasets/miami_multi.csv" 19 | self.multi_weight_path = "gs://coronaviruspublicdata/experiments/28_January_202102_14AM_model.pth" 20 | self.classification_weight_path = "gs://flow_datasets/test_data/model_save/24_May_202202_25PM_model.pth" 21 | self.ff_class_data_1 = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_data/ff_test.csv") 22 | self.class_infer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "24_May_202202_25PM_1.json") 23 | with open(self.class_infer_path) as y: 24 | self.infer_class_mod = json.load(y) 25 | self.infer_class = InferenceMode(20, 30, self.config_test, self.new_csv_path, self.weight_path, "covid-core") 26 | 27 | def test_load_model(self): 28 | model = load_model(self.config_test, self.new_csv_path, self.weight_path) 29 | self.assertIsInstance(model, object) 30 | convert_to_torch_script(model, "test.pt") 31 | 32 | def test_infer_mode(self): 33 | # Test inference 34 | self.infer_class.infer_now(datetime(2020, 6, 1), self.new_csv_path) 35 | 36 | def test_plot_model(self): 37 | self.infer_class.make_plots(datetime(2020, 5, 1), self.new_csv_path, "flow_datasets", "tes1/t.csv", "prod_plot") 38 | 39 | def test_infer_multi(self): 40 | infer_multi = InferenceMode(20, 30, self.multi_config_test, self.multi_path, self.multi_weight_path, 41 | "covid-core") 42 | infer_multi.make_plots(datetime(2020, 12, 10), csv_bucket="flow_datasets", 43 | save_name="tes1/t2.csv", wandb_plot_id="prod_plot") 44 | 45 | def test_speed(self): 46 | # TODO compare TorchScript vs. the eager model here 47 | pass 48 | 49 | def test_classification_infer(self): 50 | m = InferenceMode(1, 1, self.infer_class_mod, self.ff_class_data_1, self.classification_weight_path) 51 | res = m.infer_now_classification() 52 | self.assertIsInstance(res, list) 53 | self.assertIsInstance(res[0], torch.Tensor) 54 | self.assertGreater(len(res), 10) 55 | self.assertTrue(torch.any(res[0] < 1)) 56 | self.assertTrue(torch.any(res[1] < 1)) 57 | 58 | def test_classification_infer_df(self): 59 | m = InferenceMode(1, 1, self.infer_class_mod, self.ff_class_data_1, self.classification_weight_path) 60 | original_df = m.model.training.original_df 61 | res = m.infer_now_classification(original_df[1:99]) 62 | self.assertIsInstance(res, list) 63 | self.assertGreater(len(res), 1) 64 | 65 | 66 | if __name__ == "__main__": 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /tests/test_dual.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "CustomTransformerDecoder", 3 | "model_type": "PyTorch", 4 | "n_targets": 2, 5 | "model_params": { 6 | "n_time_series":3, 7 | "seq_length":5, 8 | "output_seq_length": 1, 9 | "n_layers_encoder": 6, 10 | "output_dim":2 11 | }, 12 | "dataset_params": 13 | { "class": "default", 14 | "training_path": "tests/test_data/keag_small.csv", 15 | "validation_path": "tests/test_data/keag_small.csv", 16 | 
"test_path": "tests/test_data/keag_small.csv", 17 | "batch_size":10, 18 | "forecast_history":5, 19 | "forecast_length":1, 20 | "train_end": 100, 21 | "valid_start":101, 22 | "valid_end": 201, 23 | "test_start": 202, 24 | "test_end": 290, 25 | "target_col": ["cfs"], 26 | "relevant_cols": ["cfs", "precip", "temp"], 27 | "scaler": "MinMaxScaler", 28 | "scaler_params":{ 29 | "feature_range":[0, 2] 30 | }, 31 | "interpolate": false 32 | }, 33 | "training_params": 34 | { 35 | "criterion":["MAPE", "CrossEntropyLoss"], 36 | "criterion_params": [ 37 | {}, {} 38 | ], 39 | "optimizer": "Adam", 40 | "optim_params": 41 | { 42 | 43 | }, 44 | "lr": 0.003, 45 | "epochs": 2, 46 | "batch_size":4 47 | 48 | }, 49 | "GCS": false, 50 | 51 | "wandb": { 52 | "name": "flood_forecast_circleci", 53 | "project": "repo-flood_forecast", 54 | "tags": ["dummy_run_dual", "circleci"] 55 | }, 56 | "forward_params":{}, 57 | "metrics":["MSE"], 58 | "inference_params": 59 | { 60 | "datetime_start":"2016-05-31", 61 | "num_prediction_samples":10, 62 | "hours_to_forecast":336, 63 | "test_csv_path":"tests/test_data/keag_small.csv", 64 | "decoder_params":{ 65 | "decoder_function": "simple_decode", 66 | "unsqueeze_dim": 1}, 67 | "dataset_params":{ 68 | "file_path": "tests/test_data/keag_small.csv", 69 | "forecast_history":5, 70 | "forecast_length":1, 71 | "relevant_cols": ["cfs", "precip", "temp"], 72 | "target_col": ["cfs"], 73 | "scaling": "MinMaxScaler", 74 | "scaler_params":{ 75 | "feature_range":[0,2] 76 | }, 77 | "interpolate_param": false 78 | } 79 | } 80 | 81 | 82 | } 83 | -------------------------------------------------------------------------------- /tests/test_handle_multi_crit.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from flood_forecast.pytorch_training import multi_crit 3 | from torch.nn import BCELoss 4 | from flood_forecast.custom.focal_loss import FocalLoss, BinaryFocalLossWithLogits 5 | import torch 6 | 7 | 8 | class TestMulticrit(unittest.TestCase): 9 | def setUp(self): 10 | self.crit = [BCELoss(), FocalLoss(0.25, reduction="sum"), BinaryFocalLossWithLogits(0.25, reduction="mean")] 11 | 12 | def test_crit_function(self): 13 | r1 = multi_crit(self.crit, torch.rand(4, 20, 5), torch.ones(4, 20, 5, dtype=torch.int64)) 14 | self.assertGreater(r1, 0.25) 15 | 16 | def test_focal_loss(self): 17 | f = FocalLoss(0.3) 18 | r = f(torch.rand(4, 20, 30), torch.rand(4, 20, 30)) 19 | self.assertGreater(r.shape[0], 0) 20 | 21 | if __name__ == '__main__': 22 | unittest.main() 23 | -------------------------------------------------------------------------------- /tests/test_iTransformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "ITransformer", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "forecast_history": 96, 7 | "forecast_length": 96, 8 | "embed": "timeF", 9 | "dropout": 0.1, 10 | "d_model": 512, 11 | "use_norm": true, 12 | "targs": 3 13 | 14 | }, 15 | "n_targets":3, 16 | "dataset_params": 17 | { "class": "TemporalLoader", 18 | "temporal_feats": ["month", "day", "day_of_week", "hour"], 19 | "training_path": "tests/test_data/keag_small.csv", 20 | "validation_path": "tests/test_data/keag_small.csv", 21 | "test_path": "tests/test_data/keag_small.csv", 22 | "batch_size":100, 23 | "forecast_history":96, 24 | "forecast_length": 96, 25 | "train_end": 200, 26 | "valid_start":180, 27 | "valid_end": 500, 28 | "test_start":299, 29 | "test_end": 500, 30 | "target_col": 
["cfs", "precip", "temp"], 31 | "relevant_cols": ["cfs", "precip", "temp"], 32 | "scaler": "StandardScaler", 33 | "sort_column":"datetime", 34 | "interpolate": false, 35 | "feature_param": 36 | { 37 | "datetime_params":{ 38 | "month": "numerical", 39 | "day": "numerical", 40 | "day_of_week": "numerical", 41 | "hour":"numerical" 42 | } 43 | } 44 | }, 45 | "early_stopping": 46 | { 47 | "patience":3 48 | }, 49 | "training_params": 50 | { 51 | "criterion":"MSE", 52 | "optimizer": "Adam", 53 | "optim_params": 54 | { 55 | "lr": 0.002 56 | }, 57 | 58 | "epochs": 1, 59 | "batch_size":5 60 | 61 | }, 62 | "GCS": false, 63 | 64 | "wandb": { 65 | "name": "flood_forecast_circleci", 66 | "tags": ["dummy_run", "circleci"], 67 | "project":"repo-flood_forecast" 68 | }, 69 | "forward_params":{ 70 | }, 71 | "metrics":["MSE"], 72 | "inference_params": 73 | { "num_prediction_samples": 100, 74 | "datetime_start":"2016-05-30", 75 | "hours_to_forecast":300, 76 | "test_csv_path":"tests/test_data/keag_small.csv", 77 | "decoder_params":{ 78 | "decoder_function": "simple_decode", 79 | "unsqueeze_dim": 1}, 80 | "dataset_params":{ 81 | "file_path": "tests/test_data/keag_small.csv", 82 | "forecast_history":96, 83 | "forecast_length":96, 84 | "relevant_cols": ["cfs", "precip", "temp"], 85 | "target_col": ["cfs", "precip", "temp"], 86 | "scaling": "StandardScaler", 87 | "interpolate_param": false, 88 | "feature_params": 89 | { 90 | "datetime_params":{ 91 | "month": "numerical", 92 | "day": "numerical", 93 | "day_of_week": "numerical", 94 | "hour":"numerical" 95 | } 96 | }, 97 | "sort_column":"datetime" 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /tests/test_inf_single.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Informer", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "n_time_series":3, 7 | "dec_in":3, 8 | "c_out": 1, 9 | "seq_len":20, 10 | "label_len":10, 11 | "out_len":2, 12 | "factor":2 13 | }, 14 | "dataset_params": 15 | { "class": "TemporalLoader", 16 | "temporal_feats": ["month", "day", "day_of_week", "hour"], 17 | "training_path": "tests/test_data/keag_small.csv", 18 | "validation_path": "tests/test_data/keag_small.csv", 19 | "test_path": "tests/test_data/keag_small.csv", 20 | "batch_size":5, 21 | "forecast_history":20, 22 | "forecast_length":2, 23 | "train_end": 200, 24 | "valid_start":188, 25 | "valid_end": 220, 26 | "test_start":299, 27 | "test_end": 400, 28 | "target_col": ["cfs"], 29 | "relevant_cols": ["cfs", "precip", "temp"], 30 | "scaler": "StandardScaler", 31 | "sort_column":"datetime", 32 | "interpolate": false, 33 | "feature_param": 34 | { 35 | "datetime_params":{ 36 | "month": "numerical", 37 | "day": "numerical", 38 | "day_of_week": "numerical", 39 | "hour":"numerical" 40 | } 41 | } 42 | }, 43 | "early_stopping": 44 | { 45 | "patience":2 46 | 47 | }, 48 | "training_params": 49 | { 50 | "criterion":"MSE", 51 | "optimizer": "Adam", 52 | "optim_params": 53 | { 54 | 55 | }, 56 | "lr": 0.04, 57 | "epochs": 1, 58 | "batch_size":4 59 | 60 | }, 61 | "GCS": false, 62 | "wandb": { 63 | "name": "flood_forecast_circleci", 64 | "tags": ["dummy_run", "circleci"], 65 | "project":"repo-flood_forecast" 66 | }, 67 | "forward_params":{ 68 | }, 69 | "metrics":["MSE"], 70 | "inference_params": 71 | { 72 | "datetime_start":"2016-05-31", 73 | "hours_to_forecast":334, 74 | "test_csv_path":"tests/test_data/keag_small.csv", 75 | "decoder_params":{ 76 | 
"decoder_function": "greedy_decode", 77 | "unsqueeze_dim": 1}, 78 | "dataset_params":{ 79 | "file_path": "tests/test_data/keag_small.csv", 80 | "forecast_history":20, 81 | "forecast_length":2, 82 | "relevant_cols": ["cfs", "precip", "temp"], 83 | "target_col": ["cfs"], 84 | "scaling": "StandardScaler", 85 | "interpolate_param": false, 86 | "sort_column":"datetime", 87 | "feature_params": 88 | { 89 | "datetime_params":{ 90 | "month": "numerical", 91 | "day": "numerical", 92 | "day_of_week": "numerical", 93 | "hour":"numerical" 94 | } 95 | } 96 | } 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /tests/test_informer.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Informer", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "n_time_series":3, 7 | "dec_in":3, 8 | "c_out": 2, 9 | "seq_len":20, 10 | "label_len":10, 11 | "out_len":2, 12 | "factor":2 13 | }, 14 | "n_targets":2, 15 | "dataset_params": 16 | { "class": "TemporalLoader", 17 | "temporal_feats": ["month", "day", "day_of_week", "hour"], 18 | "training_path": "tests/test_data/keag_small.csv", 19 | "validation_path": "tests/test_data/keag_small.csv", 20 | "test_path": "tests/test_data/keag_small.csv", 21 | "batch_size":4, 22 | "forecast_history":20, 23 | "label_len":10, 24 | "forecast_length":2, 25 | "train_end": 200, 26 | "valid_start":191, 27 | "valid_end": 290, 28 | "test_start":299, 29 | "test_end": 400, 30 | "target_col": ["cfs", "precip"], 31 | "relevant_cols": ["cfs", "precip", "temp"], 32 | "scaler": "StandardScaler", 33 | "sort_column":"datetime", 34 | "interpolate": false, 35 | "feature_param": 36 | { 37 | "datetime_params":{ 38 | "month": "numerical", 39 | "day": "numerical", 40 | "day_of_week": "numerical", 41 | "hour":"numerical" 42 | } 43 | } 44 | }, 45 | "early_stopping": 46 | { 47 | "patience":3 48 | 49 | }, 50 | "training_params": 51 | { 52 | "criterion":"MSE", 53 | "optimizer": "Adam", 54 | "optim_params": 55 | { 56 | "lr": 0.004 57 | }, 58 | 59 | "epochs": 1, 60 | "batch_size":5 61 | 62 | }, 63 | "GCS": false, 64 | 65 | "wandb": { 66 | "name": "flood_forecast_circleci", 67 | "tags": ["dummy_run", "circleci"], 68 | "project":"repo-flood_forecast" 69 | }, 70 | "forward_params":{ 71 | }, 72 | "metrics":["MSE"], 73 | "inference_params": 74 | { 75 | "datetime_start":"2016-05-30", 76 | "num_prediction_samples": 5, 77 | "hours_to_forecast":300, 78 | "test_csv_path":"tests/test_data/keag_small.csv", 79 | "decoder_params":{ 80 | "decoder_function": "greedy_decode", 81 | "unsqueeze_dim": 1}, 82 | "dataset_params":{ 83 | "file_path": "tests/test_data/keag_small.csv", 84 | "forecast_history":20, 85 | "forecast_length":2, 86 | "relevant_cols": ["cfs", "precip", "temp"], 87 | "target_col": ["cfs", "precip"], 88 | "scaling": "StandardScaler", 89 | "interpolate_param": false, 90 | "feature_params": 91 | { 92 | "datetime_params":{ 93 | "month": "numerical", 94 | "day": "numerical", 95 | "day_of_week": "numerical", 96 | "hour":"numerical" 97 | } 98 | }, 99 | "sort_column":"datetime" 100 | } 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /tests/test_join.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from flood_forecast.preprocessing.buil_dataset import combine_data 3 | import pandas as pd 4 | import os 5 | import json 6 | from datetime import datetime 7 | import pytz 8 | 9 | 10 | 
class JoinTest(unittest.TestCase): 11 | def setUp(self): 12 | self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data") 13 | 14 | 15 | def test_join_function(self): 16 | df = pd.read_csv(os.path.join(self.test_data_path, "fake_test_small.csv"), sep="\t") 17 | asos_df = pd.read_csv(os.path.join(self.test_data_path, "asos-12N_small.csv")) 18 | old_timezone = pytz.timezone("America/New_York") 19 | new_timezone = pytz.timezone("UTC") 20 | # This assumes timezones are consistent throughout the USGS stream (this should be true for all) 21 | df["datetime"] = df["datetime"].map(lambda x: old_timezone.localize( 22 | datetime.strptime(x, "%Y-%m-%d %H:%M")).astimezone(new_timezone)) 23 | with open(os.path.join(self.test_data_path, "big_black_md.json")) as a: 24 | meta_data = json.load(a) 25 | self.assertEqual(meta_data['gage_id'], 1010070) 26 | result_df, nan_f, nan_p = combine_data(df, asos_df) 27 | self.assertEqual(result_df.iloc[0]['p01m'], 0) 28 | self.assertEqual(result_df.iloc[0]['cfs'], 2210) 29 | self.assertEqual(result_df.iloc[0]['tmpf'], 19.94) 30 | 31 | if __name__ == '__main__': 32 | unittest.main() 33 | -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.custom.custom_opt import MASELoss, MAPELoss, RMSELoss, BertAdam, l1_regularizer, orth_regularizer 2 | from flood_forecast.da_rnn.model import DARNN 3 | from flood_forecast.custom.dilate_loss import pairwise_distances 4 | from flood_forecast.training_utils import EarlyStopper 5 | from flood_forecast.basic.base_line_methods import NaiveBase 6 | from flood_forecast.custom.custom_activation import _sparsemax_threshold_and_support, _entmax_threshold_and_support 7 | from flood_forecast.custom.custom_activation import Sparsemax, Entmax15 8 | import torch 9 | import unittest 10 | 11 | 12 | class TestLossFunctions(unittest.TestCase): 13 | def setUp(self): 14 | self.mase = MASELoss("mean") 15 | 16 | def test_mase_runs(self): 17 | mase_input = torch.rand(2, 5, 1) 18 | mase_targ = torch.rand(2, 5, 1) 19 | mase_hist = torch.rand(2, 20, 20) 20 | self.mase(mase_input, mase_targ, mase_hist) 21 | 22 | def test_mase_mean_correct(self): 23 | m = MASELoss("mean") 24 | pred = torch.Tensor([2, 2]).repeat(2, 1) 25 | targ = torch.Tensor([4, 4]).repeat(2, 1) 26 | hist = torch.Tensor([6, 6]).repeat(2, 1) 27 | result = m(targ, pred, hist) 28 | self.assertEqual(result, 1) 29 | 30 | def test_mape_correct(self): 31 | m = MAPELoss() 32 | hist = torch.Tensor([7, 7]).repeat(2, 1) 33 | targ = torch.Tensor([4, 4]).repeat(2, 1) 34 | m(torch.rand(1, 3), torch.rand(1, 3)) 35 | self.assertEqual(.75, m(hist, targ)) 36 | 37 | def test_rmse_correct(self): 38 | pred = torch.Tensor([2, 2]).repeat(2, 1) 39 | targ = torch.Tensor([4, 4]).repeat(2, 1) 40 | r = RMSELoss() 41 | self.assertEqual(r(pred, targ), 2) 42 | 43 | def test_bert_adam(self): 44 | dd = DARNN(3, 128, 10, 128, 1, 0.2) 45 | b_adam = BertAdam(dd.parameters(), lr=.01, warmup=0.0) 46 | print(b_adam.get_lr()) 47 | self.assertEqual(1, 1) 48 | 49 | def test_regularizer(self): 50 | dd = DARNN(3, 128, 10, 128, 1, 0.2) 51 | l1_regularizer(dd) 52 | orth_regularizer(dd) 53 | self.assertIsInstance(dd, DARNN) 54 | 55 | def test_pairwise(self): 56 | pairwise_distances(torch.rand(2, 3)) 57 | 58 | def test_early_stopper(self): 59 | e = EarlyStopper(3, .2) 60 | n = NaiveBase(2, 2) 61 | e.check_loss(n, .9) 62 | e.check_loss(n, .8) 63 | 
e.check_loss(n, .9) 64 | self.assertFalse(e.check_loss(n, .75)) 65 | 66 | def test_early_stopper2(self): 67 | e = EarlyStopper(3, .2) 68 | n = NaiveBase(2, 2) 69 | e.check_loss(n, .9) 70 | e.check_loss(n, .7) 71 | self.assertTrue(e.check_loss(n, .6)) 72 | 73 | def test_early_stopper3(self): 74 | e = EarlyStopper(3, .2, True) 75 | n = NaiveBase(2, 2) 76 | e.check_loss(n, .9) 77 | e.check_loss(n, 1.1) 78 | e.check_loss(n, 1.2) 79 | self.assertFalse(e.check_loss(n, .8)) 80 | 81 | def test_dilate_correct(self): 82 | pass 83 | 84 | def test_sparse_max_runs(self): 85 | _entmax_threshold_and_support(torch.rand(2, 20, 3)) 86 | _sparsemax_threshold_and_support(torch.rand(2, 30, 3)) 87 | s = Sparsemax() 88 | s(torch.rand(2, 4, 2)) 89 | e = Entmax15() 90 | e(torch.rand(2, 4, 2)) 91 | 92 | if __name__ == '__main__': 93 | unittest.main() 94 | -------------------------------------------------------------------------------- /tests/test_merging_models.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.meta_models.merging_model import MergingModel, MultiModalSelfAttention 2 | from flood_forecast.utils import make_criterion_functions 3 | import unittest 4 | import torch 5 | 6 | 7 | class TestMerging(unittest.TestCase): 8 | def setUp(self): 9 | self.merging_model = MergingModel("Concat", {"cat_dim": 2, "repeat": True}) 10 | self.merging_model_bi = MergingModel("Bilinear", {"in1_features": 6, "in2_features": 3 - 2, "out_features": 40}) 11 | self.merging_model_2 = MergingModel("Bilinear2", {"in1_features": 20, "in2_features": 25, "out_features": 49}) 12 | self.merging_mode3 = MergingModel("Concat", {"cat_dim": 2, "repeat": True, "use_layer": True, "out_shape": 10, 13 | "combined_d": 15}) 14 | self.attn = MultiModalSelfAttention(128, 4, 0.2) 15 | 16 | def test_merger_runs(self): 17 | m = self.merging_model(torch.rand(2, 6, 10), torch.rand(4)) 18 | self.assertEqual(m.shape[0], 2) 19 | self.assertEqual(m.shape[1], 6) 20 | self.assertEqual(m.shape[2], 14) 21 | 22 | def test_merger_two(self): 23 | m = self.merging_model(torch.rand(2, 6, 20), torch.rand(4)) 24 | self.assertEqual(m.shape[2], 24) 25 | 26 | def test_crit_functions_list(self): 27 | res = make_criterion_functions(["MSE", "RMSE", "MAPE"]) 28 | self.assertIsInstance(res, list) 29 | 30 | def test_crit_functions_dict(self): 31 | res = make_criterion_functions({"MASELoss": {"baseline_method": "mean"}, "MSE": {}}) 32 | self.assertIsInstance(res, list) 33 | 34 | def test_bilinear_model(self): 35 | r = self.merging_model_bi(torch.rand(2, 6, 128), torch.rand(128)) 36 | self.assertEqual(r.shape[1], 40) 37 | 38 | def test_bilinear_2(self): 39 | m = self.merging_model_2(torch.rand(2, 6, 20), torch.rand(25)) 40 | self.assertEqual(m.shape[2], 49) 41 | 42 | def test_cat_out(self): 43 | m = self.merging_mode3(torch.rand(2, 6, 10), torch.rand(5)) 44 | self.assertEqual(m.shape[2], 10) 45 | 46 | if __name__ == '__main__': 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /tests/test_meta_pr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy 3 | import unittest 4 | from flood_forecast.meta_models.basic_ae import AE 5 | from flood_forecast.utils import numpy_to_tvar 6 | 7 | 8 | class MetaModels(unittest.TestCase): 9 | def setUp(self): 10 | self.AE = AE(10, 128) 11 | 12 | def test_ae_init(self): 13 | self.assertEqual(self.AE.encoder_hidden_layer.in_features, 10) 14 | 
self.assertEqual(self.AE(torch.rand(2, 10)).shape[0], 2) 15 | 16 | def test_ae_2(self): 17 | self.assertEqual(self.AE.decoder_output_layer.out_features, 10) 18 | res = numpy_to_tvar(numpy.random.rand(1, 2)) 19 | self.assertIsInstance(res, torch.Tensor) 20 | 21 | if __name__ == '__main__': 22 | unittest.main() 23 | -------------------------------------------------------------------------------- /tests/test_multitask_decoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import unittest 5 | from flood_forecast.basic.linear_regression import simple_decode 6 | from flood_forecast.trainer import train_function, correct_stupid_sklearn_error 7 | 8 | 9 | class MultiTaskTests(unittest.TestCase): 10 | @classmethod 11 | def setUpClass(cls): 12 | """Load the model configs used to test multitask decoding and inference.""" 13 | with open(os.path.join(os.path.dirname(__file__), "multi_decoder_test.json")) as a: 14 | cls.model_params = json.load(a) 15 | with open(os.path.join(os.path.dirname(__file__), "multitask_decoder.json")) as a: 16 | cls.model_params3 = json.load(a) 17 | cls.keag_path = os.path.join(os.path.dirname(__file__), "test_data", "keag_small.csv") 18 | if "save_path" in cls.model_params: 19 | del cls.model_params["save_path"] 20 | cls.model_params = correct_stupid_sklearn_error(cls.model_params) 21 | cls.model_params3 = correct_stupid_sklearn_error(cls.model_params3) 22 | # cls.forecast_model2 = train_function("PyTorch", cls.model_params) 23 | 24 | def test_decoder_multi_step(self): 25 | if "save_path" in self.model_params: 26 | del self.model_params["save_path"] 27 | forecast_model = train_function("PyTorch", self.model_params) 28 | t = torch.Tensor([3, 4, 5]).repeat(1, 336, 1) 29 | output = simple_decode(forecast_model.model, torch.ones(1, 5, 3), 336, t, output_len=1) 30 | # We want to check for leakage 31 | self.assertFalse(3 in output[:, :, 0]) 32 | 33 | def test_multivariate_single_step(self): 34 | # Remove any stale save_path left over from a previous run 35 | if "save_path" in self.model_params3: 36 | del self.model_params3["save_path"] 37 | t = torch.Tensor([3, 6, 5]).repeat(1, 100, 1) 38 | forecast_model3 = train_function("PyTorch", self.model_params3) 39 | output = simple_decode(forecast_model3.model, torch.ones(1, 5, 3), 100, t, output_len=1, multi_targets=2) 40 | self.assertFalse(3 in output) 41 | self.assertFalse(6 in output) 42 | 43 | """def test_decoder_single_step(self): 44 | t = torch.Tensor([3, 4, 5]).repeat(1, 336, 1) 45 | output = simple_decode(self.forecast_model2.model, torch.ones(1, 5, 3), 336, t, output_len=3) 46 | # We want to check for leakage here 47 | self.assertFalse(3 in output[:, :, 0])""" 48 | 49 | if __name__ == "__main__": 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /tests/test_plot.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pandas as pd 3 | import plotly.graph_objects as go 4 | from flood_forecast.plot_functions import calculate_confidence_intervals, plot_df_test_with_confidence_interval 5 | 6 | 7 | class PlotFunctionsTest(unittest.TestCase): 8 | """Tests the plot functions.""" 9 | df_test = pd.DataFrame({ 10 | 'preds': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], 11 | 'target_col': [4.0, 5.0, 6.0, 7.0, 8.0, 9.0] 12 | }) 13 | df_preds = pd.DataFrame({ 14 | 0: [-1.0, -2.0, -1.0, 0.0, -1.0, 6.0], 15 | 1: [1.0, 2.0, 4.0, 3.0, 2.0, 9.0] 16 | }) 17 | df_preds_empty = pd.DataFrame(index=[0, 1, 2, 3, 4, 5]) 18 | 19 | def 
test_calculate_confidence_intervals(self): 20 | ci_lower, ci_upper = 0.025, 0.975 21 | df_quantiles = calculate_confidence_intervals( 22 | self.df_preds, self.df_test['preds'], ci_lower, ci_upper) 23 | df_preds_mean = self.df_preds.mean(axis=1) 24 | self.assertTrue((df_quantiles[ci_lower] < df_preds_mean).all()) 25 | self.assertTrue((df_quantiles[ci_upper] > df_preds_mean).all()) 26 | self.assertTrue((df_quantiles[ci_lower] <= self.df_test['preds']).all()) 27 | self.assertTrue((df_quantiles[ci_upper] >= self.df_test['preds']).all()) 28 | 29 | def test_calculate_confidence_intervals_df_preds_empty(self): 30 | ci_lower, ci_upper = 0.025, 0.975 31 | df_quantiles = calculate_confidence_intervals( 32 | self.df_preds_empty, self.df_test['preds'], ci_lower, ci_upper) 33 | self.assertTrue(df_quantiles[ci_lower].isna().all()) 34 | self.assertTrue(df_quantiles[ci_upper].isna().all()) 35 | 36 | def test_plot_df_test_with_confidence_interval(self): 37 | params = {'dataset_params': {'target_col': ['target_col']}} 38 | fig = plot_df_test_with_confidence_interval(self.df_test, self.df_preds, 0, params, "target_col", 95) 39 | self.assertIsInstance(fig, go.Figure) 40 | 41 | def test_plot_df_test_with_confidence_interval_df_preds_empty(self): 42 | params = {'dataset_params': {'target_col': ['target_col']}} 43 | fig = plot_df_test_with_confidence_interval( 44 | self.df_test, self.df_preds_empty, 0, params, "target_col", 95) 45 | self.assertIsInstance(fig, go.Figure) 46 | 47 | 48 | if __name__ == '__main__': 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /tests/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.preprocessing.interpolate_preprocess import back_forward_generic 2 | from flood_forecast.preprocessing.temporal_feats import feature_fix 3 | import unittest 4 | import pandas as pd 5 | import os 6 | 7 | 8 | class TestInterpolationCode(unittest.TestCase): 9 | def setUp(self): 10 | file_path = os.path.join(os.path.dirname(__file__), "test_data", "farm_ex.csv") 11 | file_path_2 = os.path.join(os.path.dirname(__file__), "test_data", "fake_test_small.csv") 12 | self.df = pd.read_csv(file_path) 13 | self.df_2 = pd.read_csv(file_path_2, delimiter="\t") 14 | self.df_2["datetime"] = pd.to_datetime(self.df_2["datetime"]) 15 | 16 | def test_back_forward(self): 17 | """Test the generation of forward and backward data interp.""" 18 | df = back_forward_generic(self.df, ["NumberOfAnimals"]) 19 | self.assertEqual(df.iloc[3]["NumberOfAnimals"], 165) 20 | 21 | def test_make_temp_feats(self): 22 | feats = feature_fix({"datetime_params": {"hour": "cyclical"}}, "datetime", self.df_2) 23 | self.assertIn("sin_hour", feats[0].columns) 24 | self.assertIn("cos_hour", feats[0].columns) 25 | self.assertIn("norm", feats[0].columns) 26 | 27 | def test_make_temp_feats2(self): 28 | feats = feature_fix({"datetime_params": {"year": "numerical", "day": "cyclical"}}, "datetime", self.df_2) 29 | self.assertIn("year", feats[0].columns) 30 | self.assertIn("sin_day", feats[0].columns) 31 | self.assertIn("cos_day", feats[0].columns) 32 | self.assertIn("norm", feats[0].columns) 33 | 34 | if __name__ == '__main__': 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /tests/test_preprocessing_ae.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.preprocessing.preprocess_da_rnn import TrainData, 
format_data, make_data 2 | import unittest 3 | import pandas as pd 4 | import os 5 | 6 | 7 | class TestPreprocessingDA(unittest.TestCase): 8 | def test_format_data(self): 9 | df = pd.read_csv( 10 | os.path.join( 11 | os.path.dirname(__file__), 12 | "test_data", 13 | "test_format_data.csv")) 14 | self.assertEqual(type(format_data(df, ["height"])), TrainData) 15 | self.assertEqual(len(format_data(df, ["height"]).feats[0]), 2) 16 | 17 | def test_make_function(self): 18 | result = make_data( 19 | os.path.join( 20 | os.path.dirname(__file__), 21 | "test_data", 22 | "test_format_data.csv"), 23 | target_col=["height"], 24 | test_length=3, 25 | relevant_cols=[ 26 | "temp", 27 | "precip"]) 28 | self.assertEqual(len(result.feats), 4) 29 | self.assertEqual(len(result.targs), 4) 30 | 31 | if __name__ == '__main__': 32 | unittest.main() 33 | -------------------------------------------------------------------------------- /tests/test_series_id.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.preprocessing.pytorch_loaders import CSVSeriesIDLoader, SeriesIDTestLoader 2 | # from flood_forecast.evaluator import infer_on_torch_model 3 | import unittest 4 | import os 5 | from torch.nn import MSELoss 6 | from torch.utils.data import DataLoader 7 | import torch 8 | from flood_forecast.series_id_helper import handle_csv_id_output 9 | from flood_forecast.model_dict_function import DecoderTransformer 10 | from datetime import datetime 11 | 12 | 13 | class TestInterpolationCSVLoader(unittest.TestCase): 14 | def setUp(self): 15 | self.test_data_path = os.path.join( 16 | os.path.dirname(os.path.abspath(__file__)), "test_data" 17 | ) 18 | self.dataset_params = { 19 | "file_path": os.path.join(self.test_data_path, "solar_small.csv"), 20 | "forecast_history": 20, 21 | "forecast_length": 1, 22 | "relevant_cols": ["DAILY_YIELD", "DC_POWER", "AC_POWER"], 23 | "target_col": ["DAILY_YIELD"], 24 | "interpolate_param": False, 25 | } 26 | self.data_loader = CSVSeriesIDLoader("PLANT_ID", self.dataset_params, "r") 27 | 28 | def test_seriesid(self): 29 | """Tests fetching a single item from the series ID loader.""" 30 | x, y = self.data_loader[0] 31 | self.assertIsInstance(x, dict) 32 | self.assertIsInstance(y, dict) 33 | # self.assertGreater(x[1][0, 0], 1) redo test later 34 | self.assertEqual(x[1].shape[1], 3) 35 | 36 | def test_handle_series_id(self): 37 | """Tests the handle_csv_id_output helper on a batch from the loader.""" 38 | mse1 = MSELoss() 39 | d1 = DataLoader(self.data_loader, batch_size=2) 40 | d = DecoderTransformer(3, 8, 4, 128, 20, 0.2, 1, {}, seq_num1=3, forecast_length=1) 41 | 42 | class DummyHolder: 43 | def __init__(self, model): 44 | self.model = model 45 | mod = DummyHolder(d) 46 | x, y = next(iter(d1)) 47 | l1 = handle_csv_id_output(x, y, mod, mse1, torch.optim.Adam(d.parameters())) 48 | self.assertGreater(l1, 0) 49 | 50 | def test_series_test_loader(self): 51 | loader_ds1 = SeriesIDTestLoader("PLANT_ID", self.dataset_params, "r") 52 | res = loader_ds1.get_from_start_date_all(datetime(2020, 6, 6)) 53 | self.assertGreater(len(res), 1) 54 | historical_rows, all_rows_orig, forecast_start = res[0] 55 | self.assertEqual(historical_rows.shape[0], 20) 56 | self.assertEqual(historical_rows.shape[1], 3) 57 | print(all_rows_orig) 58 | # self.assertIsInstance(all_rows_orig, pd.DataFrame) 59 | self.assertGreater(forecast_start, 0) 60 | # self.assertIsInstance(df_train_test, pd.DataFrame) 61 | 62 | def test_eval_series_loader(self): 63 | # infer_on_torch_model("s") # to-do fill in 64 | 
self.assertFalse(False) 65 | pass 66 | 67 | 68 | if __name__ == '__main__': 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /tests/test_squashed.py: -------------------------------------------------------------------------------- 1 | from flood_forecast.transformer_xl.transformer_basic import CustomTransformerDecoder 2 | import unittest 3 | import torch 4 | 5 | 6 | class TestTransformerDecoderEmbedding(unittest.TestCase): 7 | def setUp(self): 8 | self.transformer_encoder = CustomTransformerDecoder(20, 20, 5, output_dim=5, squashed_embedding=True) 9 | 10 | def test_custom_full(self): 11 | m = self.transformer_encoder(torch.rand(10, 20, 5)) 12 | self.assertEqual(m.shape[0], 10) 13 | self.assertEqual(m.shape[1], 20) 14 | self.assertEqual(m.shape[2], 5) 15 | 16 | def test_encoder(self): 17 | m = self.transformer_encoder.make_embedding(torch.rand(10, 20, 5)) 18 | self.assertEqual(m.shape[2], 1) 19 | self.assertEqual(m.shape[1], 128) 20 | self.assertEqual(m.shape[0], 10) 21 | 22 | if __name__ == '__main__': 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /tests/test_variable_length.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch, os 3 | from flood_forecast.preprocessing.pytorch_loaders import VariableSequenceLength 4 | 5 | 6 | class TestVariableLength(unittest.TestCase): 7 | def setUp(self) -> None: 8 | self.test_data_path = os.path.join( 9 | os.path.dirname(os.path.abspath(__file__)), "test_data2" 10 | ) 11 | self.dataset_params = { 12 | "file_path": os.path.join(self.test_data_path, "test_csv.csv"), 13 | "forecast_history": 20, 14 | "forecast_length": 1, 15 | "relevant_cols": ["playId", "yardlineNumber", "yardsToGo"], 16 | "target_col": ["vel"], 17 | "interpolate_param": False, 18 | } 19 | self.loader = VariableSequenceLength("playId", self.dataset_params, 100, "auto") 20 | 21 | def test_padding(self): 22 | dat = torch.rand(2, 4) 23 | self.assertEqual(self.loader.pad_input_data(dat).shape[0], 100) 24 | self.assertEqual(self.loader.pad_input_data(dat).shape[1], 4) 25 | 26 | def test_get_item_classification(self): 27 | self.loader.get_item_classification(0) 28 | 29 | def test_get_item_auto(self): 30 | x, y = self.loader.get_item_auto_encoder(0) 31 | self.assertEqual(x.shape[0], 100) 32 | self.assertEqual(y.shape[0], 100) 33 | self.assertEqual(x.shape[1], 3) 34 | self.assertEqual(y.shape[1], 3) 35 | 36 | def test_forecast(self): 37 | self.assertEqual(0, 0) 38 | -------------------------------------------------------------------------------- /tests/time_model_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from flood_forecast.model_dict_function import pytorch_model_dict as pytorch_model_dict1 3 | from flood_forecast.time_model import PyTorchForecast 4 | import os 5 | import torch 6 | 7 | 8 | class TimeSeriesModelTest(unittest.TestCase): 9 | def setUp(self): 10 | self.test_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_init") 11 | self.model_params = { 12 | "metrics": ["MSE", "DilateLoss"], 13 | "model_params": { 14 | "number_time_series": 3}, 15 | "inference_params": { 16 | "hours_to_forecast": 16}, 17 | "dataset_params": { 18 | "forecast_history": 20, 19 | "class": "default", 20 | "forecast_length": 20, 21 | "relevant_cols": [ 22 | "cfs", 23 | "temp", 24 | "precip"], 25 | "target_col": ["cfs"], 26 | "interpolate": 
False}, 27 | "wandb": False} 28 | 29 | def test_pytorch_model_dict(self): 30 | self.assertEqual(type(pytorch_model_dict1), dict) 31 | 32 | def test_pytorch_wrapper_default(self): 33 | keag_file = os.path.join(self.test_path, "keag_small.csv") 34 | model = PyTorchForecast( 35 | "MultiAttnHeadSimple", 36 | keag_file, 37 | keag_file, 38 | keag_file, 39 | self.model_params) 40 | self.assertEqual(model.model.dense_shape.in_features, 3) 41 | self.assertEqual(model.model.multi_attn.embed_dim, 128) 42 | self.assertEqual(model.model.multi_attn.num_heads, 8) 43 | 44 | def test_pytorch_wrapper_custom(self): 45 | self.model_params["model_params"] = {"number_time_series": 6, "d_model": 112} 46 | keag_file = os.path.join(self.test_path, "keag_small.csv") 47 | model = PyTorchForecast( 48 | "MultiAttnHeadSimple", 49 | keag_file, 50 | keag_file, 51 | keag_file, 52 | self.model_params) 53 | self.assertEqual(model.model.dense_shape.in_features, 6) 54 | self.assertEqual(model.model.multi_attn.embed_dim, 112) 55 | 56 | def test_model_save(self): 57 | keag_file = os.path.join(self.test_path, "keag_small.csv") 58 | model = PyTorchForecast( 59 | "MultiAttnHeadSimple", 60 | keag_file, 61 | keag_file, 62 | keag_file, 63 | self.model_params) 64 | model.save_model("output", 0) 65 | self.assertEqual(model.training[0][0].shape, torch.Size([20, 3])) 66 | 67 | def test_simple_transformer(self): 68 | self.model_params["model_params"] = { 69 | "seq_length": 19, 70 | "number_time_series": 6, 71 | "d_model": 136, 72 | "n_heads": 8} 73 | keag_file = os.path.join(self.test_path, "keag_small.csv") 74 | model = PyTorchForecast( 75 | "SimpleTransformer", 76 | keag_file, 77 | keag_file, 78 | keag_file, 79 | self.model_params) 80 | self.assertEqual(model.model.dense_shape.in_features, 6) 81 | self.assertEqual(model.model.mask.shape, torch.Size([19, 19])) 82 | 83 | def test_data_correct(self): 84 | keag_file = os.path.join(self.test_path, "keag_small.csv") 85 | model = PyTorchForecast( 86 | "MultiAttnHeadSimple", 87 | keag_file, 88 | keag_file, 89 | keag_file, 90 | self.model_params) 91 | self.assertTrue(model) 92 | 93 | def test_informer_init(self): 94 | import json 95 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_informer.json")) as y: 96 | json_params = json.load(y) 97 | keag_file = os.path.join(self.test_path, "keag_small.csv") 98 | inf = PyTorchForecast("Informer", keag_file, keag_file, keag_file, json_params) 99 | self.assertTrue(inf) 100 | self.assertEqual(inf.model.label_len, 10) 101 | 102 | 103 | if __name__ == '__main__': 104 | unittest.main() 105 | -------------------------------------------------------------------------------- /tests/transformer_b_series.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "DecoderTransformer", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":5, 6 | "n_head": 8, 7 | "forecast_history":5, 8 | "n_embd": 1, 9 | "num_layer": 5, 10 | "dropout": 0.000001, 11 | "q_len": 1, 12 | "scale_att": false, 13 | "forecast_length": 1, 14 | "additional_params":{} 15 | }, 16 | "dataset_params": 17 | { "class": "SeriesIDLoader", 18 | "series_id_col": "PLANT_ID", 19 | "return_method": "all", 20 | "num_workers": 2, 21 | "training_path": "tests/test_data/solar_small.csv", 22 | "validation_path": "tests/test_data/solar_small.csv", 23 | "test_path": "tests/test_data/solar_small.csv", 24 | "batch_size":4, 25 | "forecast_history":5, 26 | "forecast_length":1, 27 | "train_end": 100, 28 | "valid_start":301, 29 | "valid_end": 
401, 30 | "test_end":400, 31 | "target_col": ["DAILY_YIELD"], 32 | "relevant_cols": ["DAILY_YIELD", "DC_POWER", "AC_POWER"], 33 | "scaler": "RobustScaler", 34 | "no_scale": true, 35 | "interpolate": false, 36 | "sort_column":"datetime", 37 | "feature_param": 38 | {"datetime_params":{ 39 | "hour":"cyclical" 40 | }} 41 | }, 42 | "training_params": 43 | { 44 | "criterion":"DilateLoss", 45 | "optimizer": "Adam", 46 | "optim_params": 47 | { 48 | }, 49 | "lr": 0.01, 50 | "epochs": 1, 51 | "batch_size":5 52 | }, 53 | "GCS": false, 54 | 55 | "wandb": { 56 | "name": "flood_forecast_circleci", 57 | "project": "repo-flood_forecast", 58 | "tags": ["dummy_run", "circleci"] 59 | }, 60 | "forward_params":{}, 61 | "metrics":["MSE"], 62 | "inference_params": 63 | { 64 | "datetime_start":"2020-06-06", 65 | "hours_to_forecast":4, 66 | "test_csv_path":"tests/test_data/solar_small.csv", 67 | "decoder_params":{ 68 | "decoder_function": "simple_decode", 69 | "unsqueeze_dim": 1} 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /tests/transformer_bottleneck.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "DecoderTransformer", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":5, 6 | "n_head": 8, 7 | "forecast_history":5, 8 | "n_embd": 1, 9 | "num_layer": 5, 10 | "dropout":0.1, 11 | "q_len": 1, 12 | "scale_att": false, 13 | "forecast_length": 1, 14 | "additional_params":{} 15 | }, 16 | "dataset_params": 17 | { "class": "default", 18 | "num_workers": 1, 19 | "training_path": "tests/test_data/keag_small.csv", 20 | "validation_path": "tests/test_data/keag_small.csv", 21 | "test_path": "tests/test_data/keag_small.csv", 22 | "batch_size":10, 23 | "forecast_history":5, 24 | "forecast_length":1, 25 | "train_end": 101, 26 | "valid_start":301, 27 | "valid_end": 401, 28 | "test_end":400, 29 | "target_col": ["cfs"], 30 | "relevant_cols": ["cfs", "precip", "temp"], 31 | "scaler": "RobustScaler", 32 | "no_scale": true, 33 | "interpolate": false, 34 | "sort_column":"datetime", 35 | "feature_param": 36 | {"datetime_params":{ 37 | "hour":"cyclical" 38 | }} 39 | }, 40 | "training_params": 41 | { 42 | "criterion":"DilateLoss", 43 | "optimizer": "Adam", 44 | "optim_params": 45 | { 46 | }, 47 | "lr": 0.03, 48 | "epochs": 1, 49 | "batch_size":8 50 | }, 51 | "GCS": false, 52 | 53 | "wandb": { 54 | "name": "flood_forecast_circleci", 55 | "project": "repo-flood_forecast", 56 | "tags": ["dummy_run", "circleci"] 57 | }, 58 | "forward_params":{}, 59 | "metrics":["DilateLoss"], 60 | "inference_params": 61 | { 62 | "datetime_start":"2016-05-31", 63 | "hours_to_forecast":336, 64 | "test_csv_path":"tests/test_data/keag_small.csv", 65 | "decoder_params":{ 66 | "decoder_function": "simple_decode", 67 | "unsqueeze_dim": 1} 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /tests/transformer_gaussian.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "DecoderTransformer", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "mu": true, 6 | "n_time_series":5, 7 | "n_head": 8, 8 | "forecast_history":5, 9 | "n_embd": 1, 10 | "num_layer": 5, 11 | "dropout":0.1, 12 | "q_len": 1, 13 | "scale_att": false, 14 | "forecast_length": 1, 15 | "additional_params":{} 16 | }, 17 | "dataset_params": 18 | { "class":"default", 19 | "num_workers": 2, 20 | "training_path": "tests/test_data/keag_small.csv", 21 | 
"validation_path": "tests/test_data/keag_small.csv", 22 | "test_path": "tests/test_data/keag_small.csv", 23 | "batch_size":4, 24 | "forecast_history":5, 25 | "forecast_length":5, 26 | "train_end": 102, 27 | "valid_start":301, 28 | "valid_end": 401, 29 | "test_end":400, 30 | "target_col": ["cfs"], 31 | "relevant_cols": ["cfs", "precip", "temp"], 32 | "scaler": "RobustScaler", 33 | "no_scale": true, 34 | "interpolate": false, 35 | "sort_column":"datetime", 36 | "feature_param": 37 | {"datetime_params":{ 38 | "hour":"cyclical" 39 | }} 40 | }, 41 | "training_params": 42 | { 43 | "criterion":"GaussianLoss", 44 | "optimizer": "Adam", 45 | "optim_params": 46 | { 47 | }, 48 | "criterion_params":{ 49 | "mu":0, 50 | "sigma":0 51 | }, 52 | "lr": 0.003, 53 | "epochs": 1, 54 | "batch_size": 3 55 | }, 56 | "GCS": false, 57 | 58 | "wandb": { 59 | "name": "flood_forecast_circleci", 60 | "project": "repo-flood_forecast", 61 | "tags": ["dummy_run", "circleci"] 62 | }, 63 | "forward_params":{}, 64 | "metrics":["GaussianLoss"], 65 | "inference_params": 66 | { "datetime_start":"2016-05-31", 67 | "hours_to_forecast":336, 68 | "test_csv_path":"tests/test_data/keag_small.csv", 69 | "decoder_params":{ 70 | "decoder_function": "simple_decode", 71 | "unsqueeze_dim": 1}, 72 | "dataset_params":{ 73 | "file_path": "tests/test_data/keag_small.csv", 74 | "forecast_history":5, 75 | "forecast_length":5, 76 | "no_scale": true, 77 | "relevant_cols": ["cfs", "precip", "temp"], 78 | "target_col": ["cfs"], 79 | "scaling": "RobustScaler", 80 | "interpolate_param": false, 81 | "sort_column":"datetime", 82 | "feature_params":{ 83 | "datetime_params":{ 84 | "hour":"cyclical" 85 | } 86 | } 87 | } 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /tests/tsmixer_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "TSMixer", 3 | "use_decoder": true, 4 | "model_type": "PyTorch", 5 | "model_params": { 6 | "sequence_length":36, 7 | "input_channels": 3, 8 | "prediction_length": 10, 9 | "output_channels": 1 10 | 11 | }, 12 | "n_targets": 4, 13 | "dataset_params": 14 | { "class": "default", 15 | "training_path": "tests/test_data/keag_small.csv", 16 | "validation_path": "tests/test_data/keag_small.csv", 17 | "test_path": "tests/test_data/keag_small.csv", 18 | "batch_size":10, 19 | "forecast_history":36, 20 | "forecast_length":10, 21 | "train_start": 1, 22 | "train_end": 300, 23 | "valid_start":302, 24 | "valid_end": 401, 25 | "test_start":50, 26 | "test_end": 450, 27 | "target_col": ["cfs"], 28 | "relevant_cols": ["cfs", "precip", "temp"], 29 | "scaler": "StandardScaler", 30 | "interpolate": false 31 | }, 32 | "training_params": 33 | { 34 | "criterion":"MSE", 35 | "optimizer": "Adam", 36 | "optim_params": 37 | { 38 | 39 | }, 40 | "lr": 0.03, 41 | "epochs": 1, 42 | "batch_size":4 43 | 44 | }, 45 | "GCS": false, 46 | 47 | "wandb": { 48 | "name": "flood_forecast_circleci", 49 | "tags": ["dummy_run", "circleci"], 50 | "project": "repo-flood_forecast" 51 | }, 52 | "forward_params":{}, 53 | "metrics":["MSE"], 54 | "inference_params": 55 | { 56 | "datetime_start":"2016-05-31", 57 | "hours_to_forecast":334, 58 | "test_csv_path":"tests/test_data/keag_small.csv", 59 | "decoder_params":{ 60 | "decoder_function": "simple_decode", 61 | "unsqueeze_dim": 1} 62 | , 63 | "dataset_params":{ 64 | "file_path": "tests/test_data/keag_small.csv", 65 | "forecast_history":36, 66 | "forecast_length":10, 67 | "relevant_cols": ["cfs", "precip", "temp"], 68 | 
"target_col": ["cfs"], 69 | "scaling": "StandardScaler", 70 | "interpolate_param": false 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /tests/usgs_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import pandas as pd 4 | from flood_forecast.preprocessing.process_usgs import process_intermediate_csv 5 | from flood_forecast.preprocessing.interpolate_preprocess import fix_timezones 6 | from flood_forecast.preprocessing.interpolate_preprocess import interpolate_missing_values 7 | 8 | 9 | class DataQualityTests(unittest.TestCase): 10 | def setUp(self): 11 | # These are historical tests. 12 | self.test_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data") 13 | 14 | def test_intermediate_csv(self): 15 | df = pd.read_csv(os.path.join(self.test_data_path, "big_black_test_small.csv"), sep="\t") 16 | result_df, max_flow, min_flow = process_intermediate_csv(df) 17 | self.assertEqual(result_df.iloc[1]['datetime'].hour, 6) 18 | self.assertGreater(max_flow, 2640) 19 | self.assertLess(min_flow, 1600) 20 | 21 | def test_tz_interpolate_fix(self): 22 | """Additional function to test data interpolation.""" 23 | file_path = os.path.join(self.test_data_path, "river_test_sm.csv") 24 | test_d = pd.read_csv(file_path) 25 | revised_df = fix_timezones(test_d) 26 | self.assertEqual(revised_df.iloc[0]['cfs'], 0.0) 27 | self.assertEqual(revised_df.iloc[1]['tmpf'], 19.94) 28 | revised_df = interpolate_missing_values(revised_df) 29 | self.assertEqual(0, sum(pd.isnull(revised_df['cfs']))) 30 | self.assertEqual(0, sum(pd.isnull(revised_df['precip']))) 31 | 32 | if __name__ == '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /tests/validation_loop_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from flood_forecast.basic.linear_regression import SimpleLinearModel, handle_gaussian_loss 4 | from flood_forecast.meta_models.basic_ae import AE 5 | from flood_forecast.basic.base_line_methods import NaiveBase 6 | from flood_forecast.custom.custom_activation import _roll_last 7 | 8 | 9 | class TestBasicMethodVal(unittest.TestCase): 10 | 11 | def test_simple_linear_prob(self): 12 | s = SimpleLinearModel(9, 3, 1, True) 13 | r = s(torch.rand(4, 9, 3)) 14 | self.assertIsInstance(r, torch.distributions.Normal) 15 | 16 | def test_handle_gaussian_loss(self): 17 | result = handle_gaussian_loss((torch.rand(10, 2), torch.rand(10, 2))) 18 | print(result) 19 | 20 | def test_hano_scaling(self): 21 | n = NaiveBase(20, 10, 1) 22 | e = n(torch.rand(4, 20, 10)) 23 | self.assertEqual(e.shape[1], 1) 24 | 25 | def test_ae(self): 26 | ae = AE(9, 128) 27 | rep = ae.generate_representation(torch.rand(4, 9)) 28 | self.assertEqual(rep.shape, (4, 128)) 29 | 30 | def new_test(self): 31 | _roll_last(torch.rand(43, 4), 1) 32 | 33 | if __name__ == '__main__': 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /tests/variable_autoencoderl.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "CustomTransformerDecoder", 3 | "model_type": "PyTorch", 4 | "model_params": { 5 | "n_time_series":3, 6 | "seq_length":6, 7 | "output_seq_length": 6, 8 | "output_dim": 3, 9 | "n_layers_encoder": 2, 10 | "squashed_embedding": true 11 | }, 12 | 
"dataset_params": 13 | { "class": "VariableSequenceLength", 14 | "task": "auto", 15 | "n_classes": 9, 16 | "pad_len": 6, 17 | "training_path": "tests/test_data2/test_csv.csv", 18 | "validation_path": "tests/test_data2/test_csv.csv", 19 | "test_path": "tests/test_data2/test_csv.csv", 20 | "forecast_length":6, 21 | "forecast_history":6, 22 | "train_end": 301, 23 | "valid_start":0, 24 | "valid_end": 300, 25 | "test_end": 303, 26 | "target_col": ["playId", "yardlineNumber", "yardsToGo"], 27 | "relevant_cols": ["playId", "yardlineNumber", "yardsToGo"], 28 | "series_marker_column":"playId", 29 | "scaler": "StandardScaler", 30 | "interpolate": false 31 | }, 32 | "n_targets":3, 33 | "training_params": 34 | { 35 | "criterion":"MSE", 36 | "batch_size":20, 37 | "optimizer": "Adam", 38 | "optim_params": 39 | {}, 40 | "lr": 0.01, 41 | "epochs": 3 42 | }, 43 | "GCS": false, 44 | 45 | "wandb": { 46 | "name": "flood_forecast_circleci", 47 | "tags": ["dummy_run", "circleci", "multi_head", "classification"], 48 | "project": "repo-flood_forecast" 49 | }, 50 | "forward_params":{}, 51 | "metrics":["MSE"] 52 | } 53 | --------------------------------------------------------------------------------