├── requirements.txt
├── AR-Net_paper.pdf
├── .gitattributes
├── scripts
├── pre_commit.bash
├── pre_push.bash
├── install_hooks.bash
└── arnet_dev_setup
├── .gitignore
├── pyproject.toml
├── setup.py
├── LICENSE
├── arnet
├── __init__.py
├── utils.py
├── plotting.py
├── utils_data.py
├── ar_net_legacy.py
├── fastai_mods.py
├── create_ar_data.py
└── ar_net.py
├── v0_1
├── model.py
├── example.py
├── data_loader.py
├── utils.py
└── training.py
├── README.md
├── tests
├── test_unit.py
├── test_integration.py
└── test_legacy.py
└── example_notebooks
├── create_ar_data.ipynb
├── make_dataset.ipynb
├── 01_fit_arnet.ipynb
└── legacy_run_experiments.ipynb
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastai>=2.1.4
2 | statsmodels>=0.12.1
3 | seaborn>=0.11.0
--------------------------------------------------------------------------------
/AR-Net_paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ourownstory/AR-Net/HEAD/AR-Net_paper.pdf
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Jupyter notebook language stats
2 |
3 | # For text count
4 | # *.ipynb text
5 |
6 | # To ignore it use below
7 | *.ipynb linguist-documentation
8 |
--------------------------------------------------------------------------------
/scripts/pre_commit.bash:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Pre-commit hook: run black on every staged Python file and stage the result.
# NOTE: `git add` stages the whole file, including hunks that were not staged before.

# Stop at the first failing command; pipefail also surfaces a failing `git diff`.
set -eo pipefail

# --diff-filter=d leaves out deleted files. -z prints NUL-terminated names so
# paths containing spaces or glob characters are passed on unchanged.
git diff --staged --name-only --diff-filter=d -z -- "*.py" |
  while IFS= read -r -d '' file; do
    black "$file"
    git add "$file"
  done
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __py*
2 | *.egg*
3 | *.idea*
4 | *.log
5 | **/*.pyc
6 | **/__pycache__/
7 | **/*.gz
8 | **/*.whl
9 | **/*egg-info/
10 | site/
11 | *.ipynb_checkpoints*
12 | ar_data/
13 | results/
14 | */models/*
15 |
--------------------------------------------------------------------------------
/scripts/pre_push.bash:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Pre-push hook: run the unit tests found in ./tests and block the push if they fail.

echo "Running pre-push hook: unittests"
python3 -m unittest discover -s tests || {
  echo "Failed tests. Unittests must pass before push!"
  exit 1
}
--------------------------------------------------------------------------------
/scripts/install_hooks.bash:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Install the repository's git hooks by symlinking the scripts into the git hooks directory.

# Abort on the first failing command so that "Done!" is only printed on success.
set -e

echo "Installing hooks..."
GIT_DIR=$(git rev-parse --git-dir)
mkdir -p "$GIT_DIR"/hooks
# create symlink to our pre-commit and pre-push scripts
# (-f replaces an already installed hook, so the script can be re-run)
ln -sf ../../scripts/pre_commit.bash "$GIT_DIR"/hooks/pre-commit
ln -sf ../../scripts/pre_push.bash "$GIT_DIR"/hooks/pre-push
# make the hooks executable
chmod a+rwx "$GIT_DIR"/hooks/pre-commit
chmod a+rwx "$GIT_DIR"/hooks/pre-push
echo "Done!"
12 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 | target-version = ['py38']
4 | include = '\.pyi?$'
5 | exclude = '''
6 |
7 | (
8 | /(
9 | \.eggs # exclude a few common directories in the
10 | | \.git # root of the project
11 | | \.hg
12 | | \.mypy_cache
13 | | \.tox
14 | | \.venv
15 | | _build
16 | | buck-out
17 | | build
18 | | dist
19 | | site
20 | | ar_data # exclude our project directories that are not current source code.
21 | | v0_1
22 | | v1_0
23 | | results
24 | )/
25 | | .gitignore
26 | )
27 | '''
--------------------------------------------------------------------------------
/scripts/arnet_dev_setup:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import subprocess
5 |
6 |
def install_hooks():
    """Make the hook scripts executable and run ``install_hooks.bash``.

    The scripts are looked up in the directory that contains this file.

    Raises:
        subprocess.CalledProcessError: if ``chmod`` or the install script exits
            with a non-zero status.
    """
    dir_scripts = os.path.abspath(os.path.dirname(__file__))
    script_files = [
        "install_hooks.bash",
        "pre_commit.bash",
        "pre_push.bash",
    ]
    for script_f in script_files:
        file = os.path.join(dir_scripts, script_f)
        subprocess.check_call(["chmod", "a+rwx", file])
    # Pass an argument list instead of a shell string: a scripts directory whose
    # path contains spaces or shell metacharacters is then handled correctly,
    # and a failing installer is reported instead of being silently ignored.
    subprocess.check_call([os.path.join(dir_scripts, "install_hooks.bash")])
18 |
19 |
20 | if __name__ == "__main__":
21 | install_hooks()
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import setuptools
3 |
dir_repo = os.path.abspath(os.path.dirname(__file__))
# read the contents of REQUIREMENTS file; blank lines and full-line comments are
# dropped because they are not valid requirement specifiers for install_requires
with open(os.path.join(dir_repo, "requirements.txt"), "r", encoding="utf-8") as f:
    requirements = [
        line.strip() for line in f.read().splitlines() if line.strip() and not line.strip().startswith("#")
    ]
# read the contents of README file
with open(os.path.join(dir_repo, "README.md"), encoding="utf-8") as f:
    readme = f.read()

setuptools.setup(
    name="arnet",
    version="1.2.0",
    description="A simple auto-regressive Neural Network for time-series",
    author="Oskar Triebe",
    url="https://github.com/ourownstory/AR-Net",
    packages=setuptools.find_packages(),
    python_requires=">=3.8",
    install_requires=requirements,
    extras_require={
        "dev": ["black"],
    },
    # setup_requires=[""],
    scripts=["scripts/arnet_dev_setup"],
    long_description=readme,
    long_description_content_type="text/markdown",
)
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Oskar Triebe
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/arnet/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
# Package-wide logger; child loggers such as "ARNet.test" inherit from it.
log = logging.getLogger("ARNet")
log.setLevel("INFO")
# Create handlers
c_handler = logging.StreamHandler()
# delay=True postpones opening the file until the first record is emitted.
# Without it, merely importing the package created/truncated "logs.log" in the
# current working directory even though this handler is not attached below.
f_handler = logging.FileHandler("logs.log", "w+", delay=True)
# c_handler.setLevel("WARNING")
# f_handler.setLevel("INFO")
# Create formatters and add it to handlers
c_format = logging.Formatter("%(levelname)s: %(name)s - %(funcName)s: %(message)s")
c_handler.setFormatter(c_format)
log.addHandler(c_handler)

# uncomment for a log file
# f_format = logging.Formatter("%(asctime)s; %(levelname)s; %(name)s; %(funcName)s; %(message)s")
# f_handler.setFormatter(f_format)
# log.addHandler(f_handler)
19 |
20 | # lazy imports ala fastai2 style (for nice print functionality)
21 | from fastai.basics import *
22 | from fastai.tabular.all import *
23 |
24 | from .ar_net import ARNet
25 |
26 | # from .ar_net_legacy import init_ar_learner
27 | from .utils_data import load_from_file, tabularize_univariate, estimate_noise, split_by_p_valid
28 | from .utils import pad_ar_params, nice_print_list, compute_sTPE, coeff_from_model
29 | from .plotting import plot_weights, plot_prediction_sample, plot_error_scatter
30 |
--------------------------------------------------------------------------------
/v0_1/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
class DAR(nn.Module):
    """
    A simple, general purpose, fully connected network

    Args:
        ar: size of the input layer (number of input features per sample).
        num_layers: total number of linear layers. With 1, the network is a
            single linear map from ``ar`` inputs to one output.
        d_hidden: width of the hidden layers; defaults to ``ar`` when
            ``num_layers > 1`` and no value is given.
    """

    def __init__(self, ar, num_layers=1, d_hidden=None):
        # Perform initialization of the pytorch superclass
        super().__init__()
        # Define network layer dimensions
        d_in, d_out = ar, 1
        self.ar = ar
        self.num_layers = num_layers
        if d_hidden is None and num_layers > 1:
            d_hidden = d_in
        if self.num_layers == 1:
            self.layer_1 = nn.Linear(d_in, d_out, bias=True)
        else:
            self.layer_1 = nn.Linear(d_in, d_hidden, bias=True)
            # nn.ModuleList (not a plain list) registers the hidden layers as
            # sub-modules, so they show up in parameters() / state_dict() and
            # follow .to(device); otherwise they would never be trained or saved.
            self.mid_layers = nn.ModuleList(
                [nn.Linear(d_hidden, d_hidden, bias=True) for _ in range(self.num_layers - 2)]
            )
            self.layer_out = nn.Linear(d_hidden, d_out, bias=True)

    def forward(self, x):
        """
        This method defines the network layering and activation functions

        ReLU is applied after every layer except the output layer.
        """
        activation = F.relu
        x = self.layer_1(x)
        if self.num_layers > 1:
            x = activation(x)
            for layer in self.mid_layers:
                x = layer(x)
                x = activation(x)
            x = self.layer_out(x)
        return x
41 |
42 |
def main():
    """Placeholder entry point; it intentionally does nothing."""
    return None
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
49 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
2 |
3 | # AR-Net
4 | A simple auto-regressive Neural Network for time-series ([link to paper](https://arxiv.org/abs/1911.12436)).
5 |
6 | ## Install
7 | After downloading the code repository (via `git clone`), change to the repository directory (`cd AR-Net`)
8 | and install arnet as python package with `pip install .`
9 |
10 | ## Use
11 | View the notebook [`example_notebooks/01_fit_arnet.ipynb`](example_notebooks/01_fit_arnet.ipynb) for an example of how to use the model.
12 |
13 | ## Versions
14 | ### Current (1.2)
15 | Version 1.0 made the model easier to use with your own datasets and requires fewer hyperparameters
16 | for a simpler training procedure. It is built on the fastai library.
17 |
18 | Changes (1.1 -> 1.2):
19 | * simplified UI with ARNet as object
20 | * GPU support
21 | * robustified training
22 | * added test cases
23 | * updated example notebooks
24 |
25 | Changes (1.0 -> 1.1):
26 | * port [beta fastai2](https://github.com/fastai/fastai2) to its current [stable release](https://github.com/fastai/fastai)
27 | * make install as pip package possible
28 | * add black code formatter (and git pre-commit hook)
29 | * add unittests (and git pre-push hook)
30 | * fix issues with new fastai api
31 | * remove old code fragments
32 |
33 | ### Pure PyTorch (0.1)
34 | Version 0.1 was based on PyTorch and you can still use it if you do not want to use fastai.
35 |
36 | See file [`v0_1/example.py`](v0_1/example.py) for how to use the v0.1 model.
37 |
38 | ### Now also part of NeuralProphet
39 | AR-Net is now part of a more comprehensive package [NeuralProphet](https://github.com/ourownstory/neural_prophet).
40 |
41 | I strongly recommend using it instead of the standalone version, unless you specifically want to use AR-Net,
42 | which may make sense if you need to model a highly-autoregressive time-series with sparse long-range dependencies.
43 |
--------------------------------------------------------------------------------
/arnet/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
def pad_ar_params(ar_params, n_lags, n_forecasts=1):
    """
    Right-pad each list of AR coefficients with zeros up to ``n_lags`` entries.

    ar_params: list of length n_forecasts whose elements are lists of AR coeffs
    n_lags: length to which each coefficient list is padded
    n_forecasts: number of forecast targets; must equal len(ar_params)

    Returns a list with one padded coefficient list per forecast target.
    Raises NotImplementedError if n_forecasts != 1 and an element is not a list.
    """
    assert n_forecasts == len(ar_params)
    if n_forecasts == 1:
        coeffs = ar_params[0]
        n_missing = n_lags - len(coeffs)
        return [coeffs + [0.0] * n_missing]
    for coeffs in ar_params:
        if not isinstance(coeffs, list):
            raise NotImplementedError("AR Coeff for each of the forecast targets are needed")
    # pad every target on its own by reusing the single-target case
    return [pad_ar_params([coeffs], n_lags, 1)[0] for coeffs in ar_params]
18 |
19 |
def nice_print_list(data):
    """
    Format numbers as strings with three decimals.

    If every element of ``data`` is a list, the formatting is applied
    recursively to each element; otherwise the elements are formatted directly.
    """
    for item in data:
        if not isinstance(item, list):
            # at least one plain value: format this level directly
            return ["{:.3f}".format(value) for value in data]
    return [nice_print_list(sublist) for sublist in data]
25 |
26 |
def compute_sTPE(est, real):
    """
    Return 100 * sum(|est - real|) / (sum(|est| + |real|) + 10e-9).

    est, real: array-likes of equal (or broadcastable) shape.
    The small constant in the denominator avoids a division by zero when both
    inputs are all zeros.
    """
    est_arr = np.array(est)
    real_arr = np.array(real)
    total_abs_diff = np.abs(est_arr - real_arr).sum()
    total_abs = (np.abs(est_arr) + np.abs(real_arr)).sum()
    return 100.0 * total_abs_diff / (10e-9 + total_abs)
32 |
33 |
def coeff_from_model(model, reversed_weights=True):
    """
    Return the weight matrix of the first ``torch.nn.Linear`` found in ``model``.

    The result is a list with one list of weights per output row; each row is
    reversed when ``reversed_weights`` is True. Only the first linear layer
    yielded by ``model.modules()`` is used. Returns None if there is none.
    """
    first_linear = next(
        (module for module in model.modules() if isinstance(module, torch.nn.Linear)),
        None,
    )
    if first_linear is None:
        return None
    rows = first_linear.weight.detach().cpu().numpy()
    step = -1 if reversed_weights else 1
    return [list(row[::step]) for row in rows]
39 |
40 |
def set_logger_level(logger, log_level=None, include_handlers=False):
    """
    Set the level of ``logger`` and, optionally, of all handlers attached to it.

    logger: the logging.Logger to configure
    log_level: one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' or the
        numeric values 10, 20, 30, 40, 50
    include_handlers: if True, the level is also set on each of logger.handlers

    A missing or invalid level is reported through the logger itself and leaves
    the configuration unchanged.
    """
    valid_levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", 10, 20, 30, 40, 50)
    if log_level is None:
        logger.warning("Failed to set log_level to None.")
        return
    if log_level not in valid_levels:
        message = (
            "Failed to set log_level to {}."
            "Please specify a valid log level from: "
            "'DEBUG', 'INFO', 'WARNING', 'ERROR' or 'CRITICAL'"
            ""
        )
        logger.error(message.format(log_level))
        return
    logger.setLevel(log_level)
    handlers = logger.handlers if include_handlers else []
    for handler in handlers:
        handler.setLevel(log_level)
    logger.debug("Set log level to {}".format(log_level))
57 |
--------------------------------------------------------------------------------
/arnet/plotting.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | import os
5 |
6 |
def plot_weights(ar_val, weights, ar=None, model_name="AR-Net", save=False, savedir="results", figsize=(10, 4)):
    """
    Draw a bar plot of the model weights per lag, optionally next to the true AR coefficients.

    ar_val: number of lags; one bar (group) is drawn for lags 1..ar_val
    weights: model weights, one per lag
    ar: optional true AR coefficients, one per lag; if given they are plotted
        alongside the weights as "AR-Process (True)"
    model_name: label of the model in the legend and in the saved file name
    save: if True the figure is written to savedir, otherwise it is shown
    savedir: target directory for the saved figure (created if missing)
    figsize: figure size passed to plt.figure
    """
    x_col, hue_col, y_col = "AR-coefficient (lag number)", "model", "value (weight)"
    lags = list(range(1, ar_val + 1))
    if ar is not None:
        df = pd.DataFrame(
            zip(
                lags * 2,
                ["AR-Process (True)"] * ar_val + [model_name] * ar_val,
                list(ar) + list(weights),
            ),
            columns=[x_col, hue_col, y_col],
        )
        plt.figure(figsize=figsize)
        palette = {"Classic-AR": "C0", "AR-Net": "C1", "AR-Process (True)": "k"}
        # A dict palette must contain every hue level; give custom model names a
        # color as well so the plot does not fail for model_name != "AR-Net".
        palette.setdefault(model_name, "C1")
        sns.barplot(data=df, palette=palette, x=x_col, hue=hue_col, y=y_col)
    else:
        df = pd.DataFrame(
            zip(
                lags,
                [model_name] * ar_val,
                list(weights),
            ),
            columns=[x_col, hue_col, y_col],
        )
        plt.figure(figsize=figsize)
        sns.barplot(data=df, x=x_col, hue=hue_col, y=y_col)
    if save:
        if not os.path.exists(savedir):
            os.makedirs(savedir)
        figname = "weights_{}_{}.png".format(ar_val, model_name)
        plt.savefig(os.path.join(savedir, figname), dpi=300, bbox_inches="tight")
    else:
        plt.show()
38 |
39 |
def plot_prediction_sample(predicted, actual, model_name="AR-Net", save=False, savedir="results"):
    """
    Plot the actual series and the prediction as two lines in one 10x6 inch figure.

    save: if True the figure is written to savedir as "prediction_<model_name>.png"
        (the directory is created if missing); otherwise the figure is shown.
    """
    figure = plt.figure()
    figure.set_size_inches(10, 6)
    plt.plot(actual)
    plt.plot(predicted)
    plt.legend(["Actual Time-Series", f"{model_name}-Prediction"])
    if not save:
        plt.show()
        return
    if not os.path.exists(savedir):
        os.makedirs(savedir)
    target = os.path.join(savedir, f"prediction_{model_name}.png")
    plt.savefig(target, dpi=600, bbox_inches="tight")
53 |
54 |
def plot_error_scatter(predicted, actual, model_name="AR-Net", save=False, savedir="results"):
    """
    Scatter the prediction error (predicted - actual) against the actual values in a 6x6 inch figure.

    save: if True the figure is written to savedir as "scatter_<model_name>.png"
        (the directory is created if missing); otherwise the figure is shown.
    """
    figure = plt.figure()
    figure.set_size_inches(6, 6)
    error = predicted - actual
    plt.scatter(actual, error, marker="o", s=10, alpha=0.3)
    plt.legend([f"{model_name}-Error"])
    if not save:
        plt.show()
        return
    if not os.path.exists(savedir):
        os.makedirs(savedir)
    target = os.path.join(savedir, f"scatter_{model_name}.png")
    plt.savefig(target, dpi=600, bbox_inches="tight")
68 |
--------------------------------------------------------------------------------
/tests/test_unit.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import unittest
4 | import os
5 | import pathlib
6 | import shutil
7 | import logging
8 | import random
9 | import pandas as pd
10 | import numpy as np
11 | import warnings
12 |
13 | warnings.filterwarnings("ignore", message=".*nonzero.*", category=UserWarning)
14 |
15 | import arnet
16 | from arnet.create_ar_data import generate_armaprocess_data
17 | from arnet.create_ar_data import save_to_file, load_from_file
18 |
19 | log = logging.getLogger("ARNet.test")
20 | log.setLevel("DEBUG")
21 | log.parent.setLevel("WARNING")
22 |
23 | DIR = pathlib.Path(__file__).parent.absolute()
24 | data_path = os.path.join(DIR, "ar_data")
25 | results_path = os.path.join(DIR, "results")
26 | EPOCHS = 2
27 |
28 |
class UnitTests(unittest.TestCase):
    """Tests for generating ARMA data and passing it through save_to_file / load_from_file."""

    # when True, the generated series is also saved and loaded back
    save = True

    def test_create_data_random(self):
        """Generate data with randomly drawn AR parameters (params=None)."""
        data_config = {
            "samples": 1000,
            "noise_std": 0.1,
            "ar_order": 3,
            "ma_order": 0,
            "params": None,  # for randomly generated AR params
        }
        log.debug("{}".format(data_config))

        # Generate data
        series, ar_params, ma_params = generate_armaprocess_data(**data_config)
        data_config["ar_params"] = ar_params
        data_config["ma_params"] = ma_params

        if not self.save:
            return
        del data_config["params"]
        data_name = save_to_file(data_path, series, data_config)
        # just to test:
        df, loaded_config = load_from_file(data_path, data_name, load_config=True)
        log.debug("loaded from saved files:")
        log.debug("{}".format(loaded_config))
        log.debug("{}".format(df.head()))

    def test_create_data_manual(self):
        """Generate data with manually defined AR parameters."""
        manual_params = ([0.2, 0.3, -0.5], [])
        # sparse alternative:
        # manual_params = ([0.2, 0, 0.3, 0, 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [])
        data_config = {
            "samples": 1000,
            "noise_std": 0.1,
            "params": manual_params,
        }
        # the orders are the numbers of non-zero coefficients
        data_config["ar_order"] = int(sum(np.array(manual_params[0]) != 0.0))
        data_config["ma_order"] = int(sum(np.array(manual_params[1]) != 0.0))
        log.debug("{}".format(data_config))

        # Generate data
        series, ar_params, ma_params = generate_armaprocess_data(**data_config)
        data_config["ar_params"] = ar_params
        data_config["ma_params"] = ma_params

        if not self.save:
            return
        del data_config["params"]
        data_name = save_to_file(data_path, series, data_config)
        # just to test:
        df, loaded_config = load_from_file(data_path, data_name, load_config=True)
        log.debug("loaded from saved files:")
        log.debug("{}".format(loaded_config))
        log.debug("{}".format(df.head()))
79 |
--------------------------------------------------------------------------------
/arnet/utils_data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import logging
4 |
5 | from arnet.create_ar_data import load_from_file
6 |
7 | log = logging.getLogger("ARNet")
8 |
9 |
def estimate_noise(series):
    """Return the mean absolute difference between consecutive observations as a float."""
    values = series.values
    step_sizes = np.abs(values[:-1] - values[1:])
    return float(np.mean(step_sizes))
12 |
13 |
def split_by_p_valid(valid_p, n_sample, verbose=False):
    """
    Split the indices 0..n_sample-1 into a leading training part and a trailing validation part.

    valid_p: fraction of samples assigned to the validation part
    n_sample: total number of samples
    verbose: if True, print the split index and the sizes of both parts

    Returns [train_indices, valid_indices] as two lists of ints.
    """
    split_idx = int(n_sample * (1 - valid_p))
    train_idx = list(range(split_idx))
    valid_idx = list(range(split_idx, n_sample))
    if verbose:
        print("split on idx: ", split_idx)
        print("split sizes: ", [len(train_idx), len(valid_idx)])
    return [train_idx, valid_idx]
21 |
22 |
def tabularize_univariate(series, n_lags, n_forecasts=1, nested_list=False):
    """
    Create a tabular dataset of lagged inputs and forecast targets for supervised forecasting
    Arguments:
        series: Sequence of observations as a Pandas DataFrame (the first column is used)
        n_lags: Number of lag observations as input (X).
        n_forecasts: Number of observations as output (y).
        nested_list: if True, return two columns holding lists; otherwise one
            column per lag and per forecast step.
    Returns:
        df: Pandas DataFrame with len(series) - n_lags + 1 - n_forecasts rows.
            nested_list=True:  columns "x" (list of n_lags) and "y" (list of n_forecasts)
            nested_list=False: columns "x_0".."x_{n_lags-1}", "y_0".."y_{n_forecasts-1}"
    """
    n_samples = len(series) - n_lags + 1 - n_forecasts

    inputs, targets = [], []
    for start in range(n_samples):
        split = start + n_lags
        inputs.append(series.iloc[start:split, 0].values)
        targets.append(series.iloc[split : split + n_forecasts, 0].values)
    x = pd.DataFrame(inputs)
    y = pd.DataFrame(targets)

    if nested_list:
        df = pd.concat([x.apply(list, axis=1), y.apply(list, axis=1)], axis=1)
        df.columns = ["x", "y"]
        return df

    df = pd.concat([x, y], axis=1)
    x_names = ["x_{}".format(num) for num in range(len(x.columns))]
    y_names = ["y_{}".format(num) for num in range(len(y.columns))]
    df.columns = x_names + y_names
    return df
49 |
50 |
def main():
    """Demo: load a generated AR dataset from "ar_data", tabularize it and print shapes and a preview."""
    verbose = True

    ## if created AR data with create_ar_data, we can use the helper function:
    df, data_config = load_from_file("ar_data", "ar_3_ma_0_noise_0.100_len_10000", load_config=True)
    n_lags = data_config["ar_order"]

    ## else we can manually load any file that stores a time series, for example:
    # df = pd.read_csv(os.path.join(data_path, data_name + '.csv'), header=None, index_col=False)
    # n_lags = 3

    if verbose:
        print(data_config)
        print(df.shape)

    ## create a tabularized dataset from time series
    df_tab = tabularize_univariate(df, n_lags=n_lags, n_forecasts=1, nested_list=False)

    if verbose:
        for item in ("tabularized df", df_tab.shape, df_tab.head()):
            print(item)
83 |
84 |
85 | if __name__ == "__main__":
86 | main()
87 |
--------------------------------------------------------------------------------
/arnet/ar_net_legacy.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | import logging
3 |
4 | import fastai
5 |
6 | # lazy imports ala fastai2 style (for nice print functionality)
7 | from fastai.basics import *
8 | from fastai.tabular.all import *
9 |
10 | # explicit imports (for reference)
11 | # from fastai.data.core import DataLoaders
12 | # from fastai.metrics import mse, mae
13 | # from fastai.tabular.core import TabularPandas, TabDataLoader
14 | # from fastai.tabular.learner import tabular_learner
15 | # from fastai.data.transforms import Normalize
16 |
17 |
18 | ## import arnet
19 | from arnet.utils_data import load_from_file, tabularize_univariate, estimate_noise, split_by_p_valid
20 | from arnet.utils import pad_ar_params, nice_print_list, compute_sTPE, coeff_from_model
21 | from arnet import utils, plotting
22 | from arnet.fastai_mods import SparsifyAR, sTPE
23 |
24 | log = logging.getLogger("ARNet")
25 |
26 |
def init_ar_learner(
    series,
    ar_order,
    n_forecasts=1,
    valid_p=0.1,
    sparsity=None,
    ar_params=None,
    train_bs=32,
    valid_bs=128,
    verbose=False,
):
    """
    Build a fastai tabular learner (no hidden layers, no batch norm) for an AR model of ``series``.

    series: time series, passed to tabularize_univariate and estimate_noise
    ar_order: number of lags used as model inputs
    n_forecasts: number of steps predicted as model outputs
    valid_p: fraction of the (trailing) samples used for validation
    sparsity: if not None and not 1.0, a SparsifyAR callback is added with this value
    ar_params: if given, an sTPE metric against these coefficients is added
    train_bs, valid_bs: batch sizes of the training / validation loader
    verbose: print intermediate information

    Returns the learner created by tabular_learner.
    """
    # a sparsity of exactly 1.0 is handled like "no sparsity requested"
    if sparsity == 1.0:
        sparsity = None
    df_all = tabularize_univariate(series, ar_order, n_forecasts)
    est_noise = estimate_noise(series)

    if verbose:
        print("tabularized df")
        print("df columns", list(df_all.columns))
        print("df shape", df_all.shape)
        print("estimated noise of series", est_noise)

    ## split
    splits = split_by_p_valid(valid_p, len(df_all), verbose)

    columns = list(df_all.columns)
    cont_names = [name for name in columns if name[:2] == "x_"]
    target_names = [name for name in columns if name[:2] == "y_"]

    ## no preprocessing (e.g. Normalize) is applied
    tp = TabularPandas(
        df_all,
        procs=[],
        cat_names=None,
        cont_names=cont_names,
        y_names=target_names,
        splits=splits,
    )
    if verbose:
        print("cont var num", len(tp.cont_names), tp.cont_names)

    ### next: data loader, learner
    trn_dl = TabDataLoader(tp.train, bs=train_bs, shuffle=True, drop_last=True)
    val_dl = TabDataLoader(tp.valid, bs=valid_bs)
    dls = DataLoaders(trn_dl, val_dl)

    callbacks = []
    if sparsity is not None:
        regularizer = SparsifyAR(sparsity, est_noise)
        callbacks.append(regularizer)
        if verbose:
            print("reg lam: ", regularizer.lam)

    metrics = [mae]
    if ar_params is not None:
        metrics.append(sTPE(ar_params, at_epoch_end=False))

    learn = tabular_learner(
        dls,
        layers=[],  # Note: None defaults to [200, 100]
        config={"use_bn": False, "bn_final": False, "bn_cont": False},  # None calls tabular_config()
        n_out=len(target_names),  # None calls get_c(dls)
        train_bn=False,  # passed to Learner
        metrics=metrics,  # passed on to TabularLearner, to parent Learner
        loss_func=mse,
        cbs=callbacks,
    )
    if verbose:
        print(learn.model)
    return learn
99 |
--------------------------------------------------------------------------------
/tests/test_integration.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import unittest
4 | import os
5 | import pathlib
6 | import shutil
7 | import logging
8 | import random
9 | import pandas as pd
10 | import warnings
11 |
12 | warnings.filterwarnings("ignore", message=".*nonzero.*", category=UserWarning)
13 | import arnet
14 |
15 | log = logging.getLogger("ARNet.test")
16 | log.setLevel("WARNING")
17 | log.parent.setLevel("WARNING")
18 |
19 | DIR = pathlib.Path(__file__).parent.absolute()
20 | data_path = os.path.join(DIR, "ar_data")
21 | results_path = os.path.join(DIR, "results_test")
22 | AR_FILE = "ar_3_ma_0_noise_0.100_len_1000"
23 | EPOCHS = 2
24 |
25 |
class IntegrationTests(unittest.TestCase):
    """End-to-end tests that fit arnet.ARNet on small series."""

    # when True, test_plot additionally plots the recorded loss
    plot = False
    save = False

    def test_random_data(self):
        """Run the individual ARNet steps on 1000 samples of Gaussian noise."""
        df = pd.DataFrame({"x": [random.gauss(0.0, 1.0) for i in range(1000)]})
        m = arnet.ARNet(ar_order=3, n_epoch=3)
        m.tabularize(df)
        m.make_datasets()
        m.create_regularizer(sparsity=0.3)
        m.create_learner()
        m.find_lr(plot=False)
        m.fit_one_cycle(cycles=2, plot=False)
        log.info("coeff of random data: {}".format(m.coeff))

    def test_plot(self):
        """Fit on Gaussian noise and save the plots into results_path, which is removed afterwards."""
        if not os.path.exists(results_path):
            os.makedirs(results_path)

        df = pd.DataFrame({"x": [random.gauss(0.0, 1.0) for i in range(1000)]})
        m = arnet.ARNet(ar_order=3, n_epoch=3)
        m.fit(series=df)
        if self.plot:
            m.learn.recorder.plot_loss()
        m.plot_weights(save=True, savedir=results_path)
        m.plot_fitted_obs(num_obs=100, save=True, savedir=results_path)
        m.plot_errors(save=True, savedir=results_path)

        # NOTE(review): not reached if a step above raises, so the directory may be left behind
        shutil.rmtree(results_path)

    def test_save_load(self):
        """Fit on Gaussian noise, save the model into results_path and load it back."""
        if not os.path.exists(results_path):
            os.makedirs(results_path)

        df = pd.DataFrame({"x": [random.gauss(0.0, 1.0) for i in range(1000)]})
        m = arnet.ARNet(ar_order=3, n_epoch=3)
        m.fit(series=df)

        # Optional:save and create inference learner
        # a missing sparsity is written as 1.0 in the file name
        sparsity = 1.0 if m.sparsity is None else m.sparsity
        model_name = "ar{}_sparse_{:.3f}_ahead_{}_epoch_{}.pkl".format(m.ar_order, sparsity, m.n_forecasts, m.n_epoch)
        m.save_model(results_path=results_path, model_name=model_name)
        # can be loaded like this
        m.load_model(results_path, model_name)
        # can unfreeze the model and fine_tune
        log.info("loaded coeff: {}".format(m.coeff))

        # NOTE(review): not reached if a step above raises, so the directory may be left behind
        shutil.rmtree(results_path)

    def test_ar_data(self):
        """Fit a sparse model on the stored AR dataset AR_FILE and log true vs. learned coefficients."""
        # presumably requires AR_FILE to exist in data_path beforehand - TODO confirm how it is created
        df, data_config = arnet.load_from_file(data_path, AR_FILE, load_config=True, verbose=False)
        df = df[:1000]

        # Hyperparameters
        sparsity = 0.3
        ar_order = int(1 / sparsity * data_config["ar_order"])  # sparse AR: (for non-sparse, set sparsity to 1.0)
        ar_params = arnet.pad_ar_params([data_config["ar_params"]], ar_order, 1)  # to compute stats

        # run
        m = arnet.ARNet(
            ar_order=ar_order,
            n_epoch=EPOCHS,
            sparsity=sparsity,
            ar_params=ar_params,
        )
        m.fit(series=df)

        # Look at Coeff
        log.info("ar params: {}".format(arnet.nice_print_list(ar_params)))
        log.info("model weights: {}".format(arnet.nice_print_list(m.coeff)))
96 |
--------------------------------------------------------------------------------
/v0_1/example.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from data_loader import load_data
3 | from training import run as run_training
4 |
5 | from v0_1_pure_pytorch import utils
6 |
7 |
def load_config(verbose=False, random=True):
    """
    Build the data, model and training configuration dicts of the example.

    verbose: if True, print the three dicts
    random: if True, "ar_params" stays None (randomly generated AR params);
        if False, a fixed sparse coefficient list is used and "ar_val" / "pad_to"
        are derived from it.

    Returns (data_config, model_config, train_config).
    """
    #### Data settings ####
    data_config = {
        "type": 'AR',
        "ar_val": 3,
        "pad_to": 10,  # set to >ar_val for sparse AR estimation
        "ar_params": None,  # for randomly generated AR params
        "noise_std": 1.0,
        "test": 0.2,
        "n_samples": int(1.25e5),  # for 1e5 train size
    }

    # OR manually define AR params:
    if not random:
        # non-sparse alternative: [0.2, 0.3, -0.5]
        manual_params = [0.2, 0, 0.3, 0, 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        data_config["ar_params"] = manual_params
        # correct settings if manually set
        data_config["ar_val"] = sum(np.array(manual_params) != 0.0)
        data_config["pad_to"] = int(len(manual_params))

    ar_val = data_config["ar_val"]
    pad_to = data_config["pad_to"]
    # sparse estimation: the model gets more inputs than there are true coefficients
    is_sparse = pad_to is not None and pad_to > ar_val

    #### Model settings ####
    model_config = {
        "ar": pad_to if is_sparse else ar_val,
        "ma": 0,
        "num_layers": 1,
        "d_hidden": None,
    }

    #### Train settings ####
    train_config = {
        "lr": 2e-4,
        "lr_decay": 0.9,
        "epochs": 10,
        "batch": 128,
        # 0 = fully sparse, 1 = not sparse; derived from the data settings when sparse
        "est_sparsity": ar_val / (1.0 * pad_to) if is_sparse else 1,
        "lambda_delay": 10,  # delays start of regularization by lambda_delay epochs
    }

    # Note: find the right learning rate range with a learning rate range test
    # e.g. a LR range test on random AR data (with 5e5 data, batch 64, pad_to 100) led to
    # ---> min 5e-7, max 5e-4

    if verbose:
        print("data_config\n", data_config)
        print("model_config\n", model_config)
        print("train_config\n", train_config)

    return data_config, model_config, train_config
65 |
66 |
def main(verbose=False, plot=False, save=False, random_ar_param=True):
    """
    Run the example: build the configs, load the data, train, then optionally print and plot.

    verbose: forwarded to the loading/training steps; also prints the returned stats
    plot: forwarded to load_data; also draws the loss curve, weights and results
    save: forwarded to the plotting helpers
    random_ar_param: forwarded to load_config as its ``random`` argument
    """
    # load configuration dicts. Could be implemented to load from JSON instead.
    data_config, model_config, train_config = load_config(verbose, random_ar_param)
    # loads randomly generated data. Could be implemented to load a specific dataset instead.
    data = load_data(data_config, verbose, plot)
    # runs training and testing.
    results, stats = run_training(data, model_config, train_config, verbose)

    # optional printing
    if verbose:
        print(stats)

    # optional plotting
    if not plot:
        return
    utils.plot_loss_curve(
        losses=results["losses"],
        test_loss=results["test_mse"],
        epoch_losses=results["epoch_losses"],
        show=False,
        save=save,
    )
    utils.plot_weights(model_config["ar"], results["weights"], data["ar"], model_name="AR-Net", save=save)
    utils.plot_results(results, model_name="AR-Net", save=save)
100 |
101 |
if __name__ == "__main__":
    # Script entry point: run the example verbosely with plotting and saving
    # enabled and with random_ar_param switched off.
    main(verbose=True, plot=True, save=True, random_ar_param=False)
104 |
--------------------------------------------------------------------------------
/arnet/fastai_mods.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import torch.nn.functional as F
4 | import fastai
5 | from fastai.torch_core import flatten_check
6 | from fastai.metrics import mse, mae
7 | from fastai.basics import Callback
8 | from fastai.learner import Metric
9 |
10 | from arnet import utils
11 |
12 | log = logging.getLogger("ARNet.fastai_mods")
13 |
14 |
def huber(inp, targ):
    """Return the smooth-L1 (Huber) loss of `inp` against `targ`.

    Both tensors are first passed through `flatten_check`; its two results are
    handed to `F.smooth_l1_loss` in that order.
    """
    flat_inp, flat_targ = flatten_check(inp, targ)
    return F.smooth_l1_loss(flat_inp, flat_targ)
18 |
19 |
class SparsifyAR(Callback):
    """Callback that adds regularization of first linear layer according to AR-Net paper.

    The penalty added to the training loss is ``lam * mean(reg)`` where
    ``reg = 2 / (1 + exp(-c1 * |w| ** (1 / c2))) - 1`` is evaluated on the
    weights of the first ``torch.nn.Linear`` module found in the model.

    ``lam`` follows a schedule over training progress (fraction of all epochs
    completed): it is 0 up to ``start_pct``, grows quadratically to ``lam_max``
    between ``start_pct`` and ``full_pct`` and stays at ``lam_max`` afterwards.

    Args:
        est_sparsity: estimated share of non-zero weights, in (0, 1]. ``None``
            disables the regularization (``lam_max`` stays 0).
        est_noise: scales ``lam_max``; ``None`` is treated as 1.0.
        reg_strength: scales ``lam_max``.
        start_pct: training progress at which the penalty starts to ramp up.
        full_pct: training progress at which the penalty reaches ``lam_max``.
        c1, c2: shape parameters of the penalty function.
        **kwargs: forwarded to ``Callback.__init__``.
    """

    def __init__(
        self,
        est_sparsity,
        est_noise=1.0,
        reg_strength=0.01,
        start_pct=0.0,
        full_pct=0.5,
        c1=2.0,
        c2=2.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.lam_max = 0.0
        est_noise = 1.0 if est_noise is None else est_noise
        if est_sparsity is not None:
            assert 1 >= est_sparsity > 0
            # The 1e-9 terms keep the ratio finite; est_sparsity == 1 gives lam_max == 0.
            self.lam_max = reg_strength * est_noise * ((1.0 + 1e-9) / (est_sparsity + 1e-9) - 1.0)
        self.start_pct = start_pct
        self.full_pct = full_pct
        # note: implementation in paper used c1 = 3.0, c2 = 3.0
        self.c1 = c1
        self.c2 = c2
        # Most recently scheduled regularization strength; None until the first training batch.
        self.lam = None

    def after_loss(self):
        """Add the scheduled sparsity penalty to ``learn.loss`` during training."""
        if not self.training:
            return
        # Only the maximum strength may short-circuit here. The current `lam`
        # must NOT: it is legitimately 0.0 before `start_pct`, and returning
        # early on it would freeze the schedule at zero for the rest of training.
        if self.lam_max == 0.0:
            return

        first_linear = None
        for layer in self.learn.model.modules():
            if isinstance(layer, torch.nn.Linear):
                first_linear = layer
                break
        if first_linear is None:
            raise NotImplementedError("weight regualarization only implemented for model with Linear layer")

        # Overall training progress in (0, 1]: position within the epoch plus
        # the number of finished epochs, over the total number of epochs.
        progress_iter = (1.0 + self.learn.iter) / (1.0 * self.learn.n_iter)
        progress = (progress_iter + self.learn.epoch) / (1.0 * self.learn.n_epoch)
        ramp = self.full_pct - self.start_pct
        if ramp > 0:
            progress = (progress - self.start_pct) / ramp
        else:
            # Degenerate ramp (full_pct <= start_pct): switch on at start_pct
            # instead of dividing by zero.
            progress = 1.0 if progress >= self.start_pct else 0.0

        if progress <= 0:
            self.lam = 0.0
        elif progress < 1:
            self.lam = self.lam_max * progress ** 2
        else:
            self.lam = self.lam_max

        if self.lam == 0.0:
            # Nothing to add yet; skipping also avoids 0 * inf gradients of the
            # fractional power at weights that are exactly zero.
            return

        abs_weights = torch.abs(first_linear.weight)
        reg = torch.div(2.0, 1.0 + torch.exp(-self.c1 * abs_weights.pow(1.0 / self.c2))) - 1.0
        # NOTE(review): newer fastai releases appear to backpropagate
        # `learn.loss_grad` rather than `learn.loss` -- confirm against the
        # fastai version this project pins.
        self.learn.loss += self.lam * torch.mean(reg)

    _docs = dict(after_loss="Add regularization of first linear layer")
74 |
75 |
class sTPE(Metric):
    """
    Symmetrical Total Percentage Error of learned weights compared to underlying AR coefficients.

    Every `accumulate` call takes one snapshot of the error. `value` is the
    mean over all snapshots since the last `reset`, or only the most recent
    snapshot when `at_epoch_end` is true.
    """

    def __init__(self, ar_params, at_epoch_end=False):
        self.ar_params = ar_params
        self.at_epoch_end = at_epoch_end
        # Start from the same clean state that `reset` establishes.
        self.reset()

    def reset(self):
        """Clear the running sum, the snapshot counter and the last snapshot."""
        self.total, self.count, self.sTPE = 0.0, 0, None

    def accumulate(self, learn):
        """Take a snapshot of the sTPE of `learn.model` and add it to the running sum."""
        estimated = utils.coeff_from_model(model=learn.model, reversed_weights=True)
        snapshot = utils.compute_sTPE(est=estimated, real=self.ar_params)
        self.sTPE = fastai.torch_core.to_detach(snapshot)
        self.total += self.sTPE
        self.count += 1

    @property
    def value(self):
        """Last snapshot if `at_epoch_end`, else the mean snapshot (None without snapshots)."""
        if self.at_epoch_end:
            return self.sTPE
        if self.count == 0:
            return None
        return self.total / self.count

    @property
    def name(self):
        return "sTPE of AR coeff"
113 |
114 |
def get_loss_func(loss_func):
    """Resolve a loss given by name to the corresponding loss function.

    Args:
        loss_func: either a loss name (case-insensitive) or any other object.
            Recognized names are "mse"; "huber", "smooth_l1", "smoothl1";
            "mae", "l1".

    Returns:
        The matching loss function for a recognized name, ``None`` (after
        logging an error) for an unrecognized name, and `loss_func` itself
        unchanged when it is not a string.
    """
    # isinstance (rather than an exact type comparison) also accepts str subclasses.
    if not isinstance(loss_func, str):
        return loss_func

    key = loss_func.lower()
    if key == "mse":
        return mse
    if key in ("huber", "smooth_l1", "smoothl1"):
        return huber
    if key in ("mae", "l1"):
        return mae

    log.error("loss {} not defined".format(loss_func))
    return None
127 |
--------------------------------------------------------------------------------
/example_notebooks/create_ar_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "from statsmodels.tsa.arima_process import ArmaProcess"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "from arnet.create_ar_data import generate_armaprocess_data\n",
20 | "from arnet.create_ar_data import save_to_file, load_from_file"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 3,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "#### Notebook settings ####\n",
30 | "random_ar_params = False\n",
31 | "save = True\n",
32 | "save_path = '../ar_data'"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 4,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "#### Data settings ####\n",
42 | "# option 1: Randomly generated AR parameters\n",
43 | "data_config_random = {\n",
44 | " \"samples\": 10000,\n",
45 | " \"noise_std\": 0.1,\n",
46 | " \"ar_order\": 3,\n",
47 | " \"ma_order\": 0,\n",
48 | " \"params\": None, # for randomly generated AR params\n",
49 | "}\n",
50 | "\n",
51 | "# option 2: Manually define AR parameters\n",
52 | "data_config_manual = {\n",
53 | " \"samples\": 10000,\n",
54 | " \"noise_std\": 0.1,\n",
55 | " \"params\": ([0.2, 0.3, -0.5], []), \n",
56 | "# \"params\": ([0.2, 0, 0.3, 0, 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], []), \n",
57 | "}\n",
58 | "data_config_manual[\"ar_order\"] = int(sum(np.array(data_config_manual[\"params\"][0]) != 0.0))\n",
59 | "data_config_manual[\"ma_order\"] = int(sum(np.array(data_config_manual[\"params\"][1]) != 0.0))"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 5,
65 | "metadata": {},
66 | "outputs": [
67 | {
68 | "name": "stdout",
69 | "output_type": "stream",
70 | "text": [
71 | "{'samples': 10000, 'noise_std': 0.1, 'params': ([0.2, 0.3, -0.5], []), 'ar_order': 3, 'ma_order': 0}\n"
72 | ]
73 | }
74 | ],
75 | "source": [
76 | "## Select config\n",
77 | "data_config = data_config_random if random_ar_params else data_config_manual\n",
78 | "print(data_config)"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 6,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "## Generate data\n",
88 | "series, data_config[\"ar_params\"], data_config[\"ma_params\"] = generate_armaprocess_data(**data_config)\n",
89 | "del data_config[\"params\"]"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 7,
95 | "metadata": {},
96 | "outputs": [
97 | {
98 | "name": "stdout",
99 | "output_type": "stream",
100 | "text": [
101 | "loaded from saved files:\n",
102 | "{'samples': 10000, 'noise_std': 0.1, 'ar_order': 3, 'ma_order': 0, 'ar_params': [0.2, 0.3, -0.5], 'ma_params': []}\n",
103 | " 0\n",
104 | "0 0.114610\n",
105 | "1 0.027092\n",
106 | "2 -0.015263\n",
107 | "3 -0.125719\n",
108 | "4 -0.279195\n"
109 | ]
110 | }
111 | ],
112 | "source": [
113 | "if save:\n",
114 | " data_name = save_to_file(save_path, series, data_config)\n",
115 | " \n",
116 | " # just to test:\n",
117 | " df, data_config2 = load_from_file(save_path, data_name, load_config=True)\n",
118 | " print(\"loaded from saved files:\")\n",
119 | " print(data_config2)\n",
120 | " print(df.head())"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": []
129 | }
130 | ],
131 | "metadata": {
132 | "kernelspec": {
133 | "display_name": "Python 3",
134 | "language": "python",
135 | "name": "python3"
136 | },
137 | "language_info": {
138 | "codemirror_mode": {
139 | "name": "ipython",
140 | "version": 3
141 | },
142 | "file_extension": ".py",
143 | "mimetype": "text/x-python",
144 | "name": "python",
145 | "nbconvert_exporter": "python",
146 | "pygments_lexer": "ipython3",
147 | "version": "3.8.6"
148 | }
149 | },
150 | "nbformat": 4,
151 | "nbformat_minor": 4
152 | }
153 |
--------------------------------------------------------------------------------
/arnet/create_ar_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import pandas as pd
5 | from statsmodels.tsa.arima_process import ArmaProcess
6 |
7 |
8 | def _get_config(noise_std=0.1, n_samples=10000, random_ar_params=True):
9 | # option 1: Randomly generated AR parameters
10 | data_config_random = {"noise_std": noise_std, "ar_order": 3, "ma_order": 0, "params": None, "samples": n_samples}
11 | # option 2: Manually define AR parameters
12 | data_config_manual = {"noise_std": noise_std, "params": ([0.2, 0.3, -0.5], [])}
13 | data_config_manual["ar_order"] = int(sum(np.array(data_config_manual["params"][0]) != 0.0))
14 | data_config_manual["ma_order"] = int(sum(np.array(data_config_manual["params"][1]) != 0.0))
15 | data_config_manual["samples"] = n_samples # + int(data_config_manual["ar_order"])
16 | return data_config_random if random_ar_params else data_config_manual
17 |
18 |
def _generate_random_arparams(ar_order, ma_order, limit_abs_sum=True, maxiter=100):
    """Draw random AR and MA coefficients until the ARMA process is stationary.

    Each attempt draws every coefficient uniformly from [-1, 1). With
    `limit_abs_sum`, a coefficient vector whose absolute sum exceeds 1 is
    normalized and then scaled by a random factor in [0.5, 1). Finally each
    non-empty vector is centered to zero mean before the stationarity check.

    Args:
        ar_order: number of AR coefficients to draw.
        ma_order: number of MA coefficients to draw.
        limit_abs_sum: whether to shrink vectors whose absolute sum exceeds 1.
        maxiter: maximum number of attempts.

    Returns:
        Tuple ``(arparams, maparams)`` of numpy arrays.

    Raises:
        RuntimeError: if no stationary coefficients are found within `maxiter` attempts.
    """
    for _ in range(maxiter):
        # Generate random parameters (AR draws first, then MA draws).
        arparams = np.array([2 * np.random.random() - 1 for _ in range(ar_order)])
        maparams = np.array([2 * np.random.random() - 1 for _ in range(ma_order)])
        if limit_abs_sum:
            ar_abssum = sum(np.abs(arparams))
            ma_abssum = sum(np.abs(maparams))
            if ar_abssum > 1:
                arparams = arparams / (ar_abssum + 10e-6)
                arparams = arparams * (0.5 + 0.5 * np.random.random())
            if ma_abssum > 1:
                maparams = maparams / (ma_abssum + 10e-6)
                maparams = maparams * (0.5 + 0.5 * np.random.random())

        # Center only non-empty vectors: the mean of an empty array is NaN and
        # numpy emits a RuntimeWarning for it (e.g. on every call with ma_order == 0).
        if arparams.size:
            arparams = arparams - np.mean(arparams)
        if maparams.size:
            maparams = maparams - np.mean(maparams)

        arma_process = ArmaProcess.from_coeffs(arparams, maparams, nobs=100)
        if arma_process.isstationary:
            return arparams, maparams

    raise RuntimeError("failed to find stationary coefficients")
52 |
53 |
def generate_armaprocess_data(samples, ar_order, ma_order, noise_std, params=None):
    """Sample a zero-mean series from an ARMA process.

    Args:
        samples: number of observations to generate.
        ar_order, ma_order: orders used when coefficients are drawn randomly.
        noise_std: passed to ``generate_sample`` as ``scale``.
        params: optional ``(arparams, maparams)`` pair to use as-is (make sure
            they sum up to 1 or less); when None, stationary coefficients are
            drawn randomly.

    Returns:
        Tuple ``(series, arparams, maparams)`` with both coefficient sets as lists.
    """
    if params is None:
        # iterate to find random coefficients that are stationary
        arparams, maparams = _generate_random_arparams(ar_order, ma_order)
    else:
        arparams, maparams = params

    process = ArmaProcess.from_coeffs(arparams, maparams, nobs=samples)
    raw_series = process.generate_sample(samples, scale=noise_std)
    # shift the sampled series to zero mean
    centered = raw_series - np.mean(raw_series)
    return centered, list(arparams), list(maparams)
67 |
68 |
def save_to_file(save_path, series, data_config):
    """Write a series as CSV and its configuration as JSON into `save_path`.

    The base file name is derived from the "ar_order", "ma_order", "noise_std"
    and "samples" entries of `data_config`. The series goes to ``<name>.csv``
    and the config to ``info_<name>.json``.

    Args:
        save_path: target directory; created (including parents) if missing.
        series: data written with ``np.savetxt``.
        data_config: dict dumped with ``json.dump``.

    Returns:
        The base file name (without directory or extension).
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_path, exist_ok=True)
    file_data = "ar_{}_ma_{}_noise_{:.3f}_len_{}".format(
        data_config["ar_order"], data_config["ma_order"], data_config["noise_std"], data_config["samples"]
    )
    np.savetxt(os.path.join(save_path, file_data + ".csv"), series, delimiter=",")
    with open(os.path.join(save_path, "info_" + file_data + ".json"), "w") as f:
        json.dump(data_config, f)
    return file_data
81 |
82 |
def load_from_file(data_path, data_name, load_config=True, verbose=False):
    """Read a series (and optionally its config) stored under `data_path`.

    Args:
        data_path: directory holding the files.
        data_name: base file name; the series is read from ``<data_name>.csv``
            (no header row) and the config from ``info_<data_name>.json``.
        load_config: whether to read the JSON config.
        verbose: print the config, the frame's shape and its first rows.

    Returns:
        Tuple ``(df, data_config)``; ``data_config`` is None when `load_config` is false.
    """
    csv_file = os.path.join(data_path, data_name + ".csv")
    df = pd.read_csv(csv_file, header=None, index_col=False)

    data_config = None
    if load_config:
        info_file = os.path.join(data_path, "info_" + data_name + ".json")
        with open(info_file, "r") as f:
            data_config = json.load(f)

    if verbose:
        print("loaded series from file")
        print("data_config", data_config)
        print(df.shape)
        print(df.head())
    return df, data_config
96 |
97 |
def main():
    """Generate an AR series with the manual config, save it and read it back."""
    verbose, use_random, save = True, False, True
    save_path = "ar_data"

    data_config = _get_config(random_ar_params=use_random)
    if verbose:
        print(data_config)

    generated = generate_armaprocess_data(**data_config)
    series, data_config["ar_params"], data_config["ma_params"] = generated
    # "params" is superseded by the explicit ar/ma entries stored above.
    del data_config["params"]

    if not save:
        return

    data_name = save_to_file(save_path, series, data_config)

    # just to test: read back what was written
    df, data_config2 = load_from_file(save_path, data_name, load_config=True)
    if verbose:
        print("loaded from saved files:")
        print(data_config2)
        print(df.head())
120 |
121 |
if __name__ == "__main__":
    # Script entry point: generate, save and reload the example series.
    main()
124 |
--------------------------------------------------------------------------------
/tests/test_legacy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import unittest
4 | import os
5 | import pathlib
6 | import shutil
7 | import logging
8 | import warnings
9 |
10 | warnings.filterwarnings("ignore", message=".*nonzero.*", category=UserWarning)
11 |
12 | import fastai
13 |
14 | ## lazy imports ala fastai2 style (needed for nice print functionality)
15 | from fastai.basics import *
16 | from fastai.tabular.all import *
17 |
18 | import arnet
19 | from arnet.ar_net_legacy import init_ar_learner
20 |
21 | log = logging.getLogger("ARNet.test_legacy")
22 | log.setLevel("WARNING")
23 | log.parent.setLevel("WARNING")
24 |
25 | DIR = pathlib.Path(__file__).parent.parent.absolute()
26 | data_path = os.path.join(DIR, "ar_data")
27 | results_path = os.path.join(data_path, "results_test")
28 | EPOCHS = 2
29 |
30 |
class LegacyTests(unittest.TestCase):
    """Smoke tests for the legacy learner built by `init_ar_learner`."""

    # Class-level switches read by the tests below.
    verbose = False
    plot = False
    save = False

    def test_legacy_ar(self):
        """Fit on a stored AR series, then plot, export, reload and fine-tune the learner."""
        self.save = True
        if self.save:
            os.makedirs(results_path, exist_ok=True)
            # Registered up front so the directory is removed even when a later
            # step of this test fails (previously it was only removed on success).
            self.addCleanup(shutil.rmtree, results_path, ignore_errors=True)
        # Hyperparameters
        n_epoch = 3  # NOTE(review): only used in the exported file name; training runs EPOCHS epochs
        valid_p = 0.2
        n_forecasts = 1  # Note: if more than one, must have a list of ar_param for each forecast target.
        sparsity = 0.3  # guesstimate
        data_name = "ar_3_ma_0_noise_0.100_len_10000"
        df, data_config = arnet.load_from_file(data_path, data_name, load_config=True, verbose=self.verbose)
        # keep the test fast: only the first 1000 rows are used
        df = df[:1000]

        # sparse AR: (for non-sparse, set sparsity to 1.0)
        ar_order = int(1 / sparsity * data_config["ar_order"])
        # to compute stats
        ar_params = arnet.pad_ar_params([data_config["ar_params"]], ar_order, n_forecasts)

        learn = init_ar_learner(
            series=df,
            ar_order=ar_order,
            n_forecasts=n_forecasts,
            valid_p=valid_p,
            sparsity=sparsity,
            ar_params=ar_params,
            verbose=self.verbose,
        )

        lr_at_min, _ = learn.lr_find(start_lr=1e-6, end_lr=1e2, num_it=400)
        log.info("lr at minimum: {}".format(lr_at_min))

        # Run Model
        # if you know the best learning rate:
        # learn.fit(n_epoch, 1e-2)
        # else use onecycle
        learn.fit_one_cycle(n_epoch=EPOCHS, lr_max=lr_at_min / 10)

        # Look at Coeff
        coeff = arnet.coeff_from_model(learn.model)
        log.info("ar params: {}".format(arnet.nice_print_list(ar_params)))
        log.info("model weights: {}".format(arnet.nice_print_list(coeff)))
        # should be [0.20, 0.30, -0.50, ...]

        preds, y = learn.get_preds()
        if self.plot or self.save:
            if self.plot:
                learn.recorder.plot_loss()
            arnet.plot_weights(
                ar_val=len(ar_params[0]), weights=coeff[0], ar=ar_params[0], save=not self.plot, savedir=results_path
            )
            arnet.plot_prediction_sample(preds[:100], y[:100], save=not self.plot, savedir=results_path)
            arnet.plot_error_scatter(preds, y, save=not self.plot, savedir=results_path)

        if self.save:
            # Optional:save and create inference learner
            learn.freeze()
            model_name = "ar{}_sparse_{:.3f}_ahead_{}_epoch_{}.pkl".format(ar_order, sparsity, n_forecasts, n_epoch)
            learn.export(fname=os.path.join(results_path, model_name))
            # can be loaded like this
            infer = load_learner(fname=os.path.join(results_path, model_name), cpu=True)
            # can unfreeze the model and fine_tune
            learn.unfreeze()
            learn.fit_one_cycle(1, lr_at_min / 100)

            coeff2 = arnet.coeff_from_model(learn.model)
            log.info("ar params: {}".format(arnet.nice_print_list(ar_params)))
            log.info("model weights: {}".format(arnet.nice_print_list(coeff)))
            log.info("model weights2: {}".format(arnet.nice_print_list(coeff2)))

            if self.plot or self.save:
                if self.plot:
                    learn.recorder.plot_loss()
                arnet.plot_weights(
                    ar_val=len(ar_params[0]), weights=coeff2[0], ar=ar_params[0], save=not self.plot, savedir=results_path
                )

    def test_legacy_random(self):
        """Fit the legacy learner on Gaussian noise and read back its coefficients."""
        df = pd.DataFrame({"x": [random.gauss(0.0, 1.0) for i in range(1000)]})
        learn = init_ar_learner(
            series=df,
            ar_order=3,
            n_forecasts=1,
            valid_p=0.1,
            sparsity=0.3,
            train_bs=32,
            valid_bs=1024,
            verbose=False,
        )
        # find Learning Rate
        lr_at_min, lr_steep = learn.lr_find(start_lr=1e-6, end_lr=1, num_it=1000, show_plot=self.plot)
        if self.plot:
            plt.show()
        print("lr at minimum: {}; steeptes lr: {}".format(lr_at_min, lr_steep))
        learn.fit_one_cycle(n_epoch=EPOCHS, lr_max=lr_at_min / 10)
        if self.plot:
            learn.recorder.plot_loss()
            plt.show()
        # record Coeff
        coeff = arnet.coeff_from_model(learn.model)
138 |
--------------------------------------------------------------------------------
/v0_1/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from statsmodels.tsa.arima_process import ArmaProcess
4 | from torch.utils.data.dataset import Dataset
5 | import torch
6 | import copy
7 |
8 |
def sample(y, offset, sample_inp_size, sample_out_size):
    """Cut one (input window, output window) pair out of `y`.

    The input window covers ``sample_inp_size`` positions starting at `offset`;
    the output window covers the ``sample_out_size`` positions right after it.
    `y` is indexed with integer index arrays.
    """
    inp_start = offset
    out_start = offset + sample_inp_size
    out_idx = np.arange(out_start, out_start + sample_out_size)
    inp_idx = np.arange(inp_start, out_start)
    target = y[out_idx]
    window = y[inp_idx]
    return window, target
15 |
16 |
def create_dataset(series, n_samples=None, sample_inp_size=51, sample_out_size=1, test=None, verbose=False, plot=False):
    """Turn a series into sliding-window train/test datasets.

    Sample ``i`` uses ``series[i : i + sample_inp_size]`` as input and the
    following ``sample_out_size`` values as target.

    Args:
        series: the series to window (indexed with integer index arrays by `sample`).
        n_samples: number of windows to build. Defaults to the number of
            complete windows that fit into `series`.
        sample_inp_size: length of each input window.
        sample_out_size: length of each target window.
        test: optional fraction in (0, 1) of the samples held out for testing.
            When None, train and test are the same full dataset.
        verbose: print the sizes of both datasets.
        plot: plot the first 200 values of `series`.

    Returns:
        Tuple ``(dataset_train, dataset_test, series_train, series_test)``.
    """
    if n_samples is None:
        # Each window needs sample_inp_size + sample_out_size consecutive
        # values, so only this many start offsets are valid. (Defaulting to
        # len(series) would make the loop below index past the series end.)
        n_samples = max(0, len(series) - sample_inp_size - sample_out_size + 1)
    data_inp = np.zeros((n_samples, sample_inp_size))
    data_out = np.zeros((n_samples, sample_out_size))

    for i in range(n_samples):
        sample_inp, sample_out = sample(series, i, sample_inp_size, sample_out_size)
        data_inp[i, :] = sample_inp
        data_out[i, :] = sample_out

    if test is not None:
        assert 0 < test < 1
        # Chronological split: the last `test` share of the samples is held out.
        split = int(n_samples * (1 - test))
        train_inp, train_out = data_inp[:split], data_out[:split]
        test_inp, test_out = data_inp[split:], data_out[split:]
        series_train = series[:split]
        series_test = series[split:]
    else:
        train_inp, train_out = data_inp, data_out
        test_inp, test_out = data_inp, data_out
        series_train = series
        series_test = series

    dataset_train = LocalDataset(x=train_inp, y=train_out)
    dataset_test = LocalDataset(x=test_inp, y=test_out)

    if verbose:
        print("Train set size: ", dataset_train.length)
        print("Test set size: ", dataset_test.length)

    if plot:
        # Plot generated process.
        plt.plot(np.array(series)[:200])
        plt.show()
    return dataset_train, dataset_test, series_train, series_test
52 |
53 |
class LocalDataset(Dataset):
    """In-memory dataset of (input, target) pairs held as float tensors."""

    def __init__(self, x, y):
        # Both inputs and targets are stored as 32-bit floats (for MSE or L1 loss).
        self.length = x.shape[0]
        self.x_data = torch.from_numpy(x).float()
        self.y_data = torch.from_numpy(y).float()

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair at `index`."""
        features = self.x_data[index]
        target = self.y_data[index]
        return features, target

    def __len__(self):
        return self.length
69 |
70 |
def generate_armaprocess_data(samples, noise_std, random_order=None, params=None, limit_abs_sum=True):
    """Sample a zero-mean series from a given or randomly drawn ARMA process.

    Args:
        samples: number of observations to generate.
        noise_std: passed to ``generate_sample`` as ``scale``.
        random_order: ``(ar_order, ma_order)`` pair; required when `params` is None.
        params: optional ``(arparams, maparams)`` pair used as-is (make sure
            they sum up to 1 or less).
        limit_abs_sum: when drawing randomly, shrink coefficient vectors whose
            absolute sum exceeds 1.

    Returns:
        Tuple ``(series, arparams, maparams)``.

    Raises:
        RuntimeError: if no stationary random coefficients are found in 100 attempts.
    """
    if params is not None:
        # use specified params, make sure to sum up to 1 or less
        arparams, maparams = params
        arma_process = ArmaProcess.from_coeffs(arparams, maparams, nobs=samples)
    else:
        # Loop-invariant, so unpacked once up front.
        ar_order, ma_order = random_order
        for _ in range(100):
            # Generate random parameters, each uniform in [-1, 1).
            arparams = np.array([2 * np.random.random() - 1 for _ in range(ar_order)])
            maparams = np.array([2 * np.random.random() - 1 for _ in range(ma_order)])
            if limit_abs_sum:
                ar_abssum = sum(np.abs(arparams))
                ma_abssum = sum(np.abs(maparams))
                if ar_abssum > 1:
                    arparams = arparams / (ar_abssum + 10e-6)
                    arparams = arparams * (0.5 + 0.5 * np.random.random())
                if ma_abssum > 1:
                    maparams = maparams / (ma_abssum + 10e-6)
                    maparams = maparams * (0.5 + 0.5 * np.random.random())

            # Center only non-empty vectors: the mean of an empty array is NaN
            # and numpy emits a RuntimeWarning for it (e.g. whenever ma_order == 0).
            if arparams.size:
                arparams = arparams - np.mean(arparams)
            if maparams.size:
                maparams = maparams - np.mean(maparams)

            arma_process = ArmaProcess.from_coeffs(arparams, maparams, nobs=samples)
            if arma_process.isstationary:
                break
        else:
            raise RuntimeError("failed to find stationary coefficients")

    # sample output from ARMA Process
    series = arma_process.generate_sample(samples, scale=noise_std)
    # make zero-mean:
    series = series - np.mean(series)
    return series, arparams, maparams
116 |
117 |
def init_ar_dataset(n_samples, ar_val, ar_params=None, noise_std=1.0, plot=False, verbose=False, test=None, pad_to=None):
    """Generate an AR series and wrap it into train/test datasets.

    Args:
        n_samples: number of (input, target) samples to build.
        ar_val: AR order used for random coefficients; replaced by
            ``len(ar_params)`` when `ar_params` is given.
        ar_params: optional explicit AR coefficients.
        noise_std: noise scale of the generated process.
        plot, verbose, test: forwarded to `create_dataset`; `verbose` also
            prints the (padded) AR coefficients.
        pad_to: optional input window length; when given, the returned
            coefficients are zero-padded up to this length.

    Returns:
        Tuple ``(dataset_train, dataset_test, series_train, series_test, ar)``.
    """
    # AR process definition: explicit coefficients win over a random draw.
    params = None
    if ar_params is not None:
        ar_val = len(ar_params)
        params = (ar_params, [])

    inp_size = ar_val if pad_to is None else pad_to

    # inp_size extra observations are generated on top of n_samples.
    series, ar, _ = generate_armaprocess_data(
        samples=n_samples + inp_size,
        noise_std=noise_std,
        random_order=(ar_val, 0),
        params=params,
    )

    if pad_to is not None:
        n_pad = max(0, pad_to - ar_val)
        ar = list(ar) + [0.0] * n_pad

    if verbose:
        print("AR params: ")
        print(ar)

    # Initialize data for DAR
    train_set, test_set, series_train, series_test = create_dataset(
        series=series,
        n_samples=n_samples,
        sample_inp_size=inp_size,
        sample_out_size=1,
        verbose=verbose,
        plot=plot,
        test=test,
    )
    return train_set, test_set, series_train, series_test, ar
159 |
160 |
def load_data(data_config_in, verbose=False, plot=False):
    """Build the data dict for the requested data type.

    Args:
        data_config_in: config dict with a "type" entry; all remaining entries
            are forwarded as keyword arguments to `init_ar_dataset`. The dict
            itself is not modified.
        verbose, plot: forwarded to `init_ar_dataset`.

    Returns:
        Dict with the keys "type", "train", "test", "series_train",
        "series_test", "ar" and "pad_to".

    Raises:
        NotImplementedError: for any data type other than 'AR'.
    """
    # Work on a copy so popping "type" does not alter the caller's dict.
    data_config = copy.deepcopy(data_config_in)
    data_type = data_config.pop("type")
    if data_type != 'AR':
        raise NotImplementedError(
            "data type {!r} is not supported; only 'AR' is implemented".format(data_type)
        )

    data = {
        "type": data_type
    }
    data["train"], data["test"], data["series_train"], data["series_test"], data["ar"] = init_ar_dataset(
        **data_config,
        verbose=verbose,
        plot=plot,
    )
    # "pad_to" is optional for init_ar_dataset (default None), so a config
    # without it must not fail here with a KeyError.
    data["pad_to"] = data_config_in.get("pad_to")
    return data
177 |
--------------------------------------------------------------------------------
/v0_1/utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import seaborn as sns
5 | import os
6 |
7 |
def compute_stats_ar(results, ar_params, verbose=False):
    """Compare learned weights with the true AR coefficients and score predictions.

    Args:
        results: dict with the entries "weights", "predicted" and "actual".
        ar_params: the true AR coefficients, compared element-wise with the weights.
        verbose: print the stats, the coefficients and the weights.

    Returns:
        Dict with "sMAPE (AR-coefficients)", "sTPE (AR-coefficients)" (both in
        percent) and "MSE" of the predictions.
    """
    weights = results["weights"]
    residual = results["predicted"] - results["actual"]

    coeff_diff = np.abs(weights - ar_params)
    coeff_scale = np.abs(weights) + np.abs(ar_params)

    # The small constant keeps the ratios defined when both coefficients are zero.
    smape = 100 * np.mean(coeff_diff / (10e-9 + coeff_scale))
    stpe = 100 * np.sum(coeff_diff) / (10e-9 + np.sum(coeff_scale))

    stats = {
        "sMAPE (AR-coefficients)": smape,
        "sTPE (AR-coefficients)": stpe,
        # prediction error
        "MSE": np.mean(residual ** 2),
    }

    if verbose:
        print("MSE: {}".format(stats["MSE"]))
        print("sMAPE (AR-coefficients): {:6.3f}".format(stats["sMAPE (AR-coefficients)"]))
        print("sTPE (AR-coefficients): {:6.3f}".format(stats["sTPE (AR-coefficients)"]))

        print("AR params: ")
        print(ar_params)

        print("Weights: ")
        print(weights)
    return stats
37 |
38 |
def plot_loss_curve(losses, test_loss=None, epoch_losses=None, show=False, save=False):
    """Plot the per-iteration training loss, optionally with epoch means and test loss.

    Args:
        losses: per-iteration losses, drawn as a faint line.
        test_loss: optional value drawn as a horizontal line across the plot.
        epoch_losses: optional per-epoch losses, drawn at the middle of each epoch.
        show: whether to display the figure right away. When false the figure
            is left open, so a later ``plt.show()`` still displays it.
        save: whether to write the figure to ``results/loss_DAR.png``.
    """
    fig = plt.figure()
    fig.set_size_inches(12, 6)
    ax = plt.axes()
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    n_iter = len(losses)
    plt.plot(list(range(n_iter)), losses, 'b', alpha=0.3)
    # An empty epoch list is skipped: it carries nothing to draw and would
    # otherwise divide by zero below.
    if epoch_losses is not None and len(epoch_losses) > 0:
        iter_per_epoch = n_iter // len(epoch_losses)
        # x-position of each epoch's loss: the middle of that epoch
        epoch_mids = iter_per_epoch // 2 + iter_per_epoch * np.arange(len(epoch_losses))
        plt.plot(epoch_mids, epoch_losses, 'b')
    if test_loss is not None and n_iter > 0:
        plt.hlines(test_loss, xmin=0, xmax=n_iter - 1)
    if save:
        os.makedirs('results', exist_ok=True)
        figname = 'results/loss_DAR.png'
        plt.savefig(figname, dpi=600, bbox_inches='tight')
    # Honor the `show` flag; it used to be ignored and the figure always shown.
    if show:
        plt.show()
60 |
61 |
def plot_prediction_sample(predicted, actual, num_obs=100, model_name="AR-Net", save=False):
    """Plot the first `num_obs` actual values against the model's predictions.

    With `save` the figure is written to ``results/prediction_<model_name>.png``
    before it is shown.
    """
    figure = plt.figure()
    figure.set_size_inches(10, 6)
    # Actual series first so it matches the first legend entry.
    for curve in (actual, predicted):
        plt.plot(curve[0:num_obs])
    plt.legend(["Actual Time-Series", "{}-Prediction".format(model_name)])
    if save:
        out_dir = 'results'
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        plt.savefig('results/prediction_{}.png'.format(model_name), dpi=600, bbox_inches='tight')
    plt.show()
74 |
75 |
def plot_error_scatter(predicted, actual, model_name="AR-Net", save=False):
    """Scatter the prediction error (predicted - actual) against the actual values.

    With `save` the figure is written to ``results/scatter_<model_name>.png``
    before it is shown.
    """
    figure = plt.figure()
    figure.set_size_inches(6, 6)
    residual = predicted - actual
    plt.scatter(actual, residual, marker='o', s=10, alpha=0.3)
    plt.legend(["{}-Error".format(model_name)])
    if save:
        out_dir = 'results'
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        plt.savefig('results/scatter_{}.png'.format(model_name), dpi=600, bbox_inches='tight')
    plt.show()
88 |
89 |
def plot_weights(ar_val, weights, ar, model_name="AR-Net", save=False):
    """Bar-plot the true AR coefficients next to the model's learned weights.

    Args:
        ar_val: number of lags; `ar` and `weights` are expected to hold this many values.
        weights: learned weights, one per lag.
        ar: true AR coefficients, one per lag.
        model_name: label of the model's bars, also used in the saved file name.
        save: whether to write the figure to ``results/weights_<ar_val>_<model_name>.png``.
    """
    true_label = "AR-Process (True)"
    df = pd.DataFrame(
        zip(
            list(range(1, ar_val + 1)) * 2,
            [true_label] * ar_val + [model_name] * ar_val,
            list(ar) + list(weights)
        ),
        columns=["AR-coefficient (lag number)", "model", "value (weight)"]
    )
    plt.figure(figsize=(10, 6))
    # Fixed colors per model. The palette used to be handed to str.format as an
    # unused keyword and never reached the plot. Unknown model names fall back
    # to "C1" so every hue level has a color.
    known_colors = {"Classic-AR": "C0", "AR-Net": "C1"}
    palette = {true_label: "k", model_name: known_colors.get(model_name, "C1")}
    sns.barplot(x="AR-coefficient (lag number)", hue="model", y="value (weight)", data=df, palette=palette)
    if save:
        os.makedirs('results', exist_ok=True)
        figname = 'results/weights_{}_{}.png'.format(ar_val, model_name)
        plt.savefig(figname, dpi=600, bbox_inches='tight')

    plt.show()
109 |
110 |
def plot_results(results, model_name="MODEL", save=False):
    """Draw the prediction sample plot and the error scatter plot for `results`.

    `results` must provide the entries "predicted" and "actual".
    """
    plot_prediction_sample(results["predicted"], results["actual"], num_obs=100, model_name=model_name, save=save)
    predicted, actual = results["predicted"], results["actual"]
    plot_error_scatter(predicted, actual, model_name=model_name, save=save)
114 |
115 |
def jsonize(results):
    """Replace every number in `results` by its ``{:8.5f}`` string, in place.

    Values may be numbers, flat lists of numbers or lists of lists of numbers.
    Whether a list is treated as nested is decided by its first element. An
    empty list stays an empty list.

    Args:
        results: dict to convert; it is modified in place.

    Returns:
        The same `results` dict.
    """
    def fmt(number):
        return "{:8.5f}".format(number)

    # Only existing keys are reassigned, so iterating over items() stays valid.
    for key, value in results.items():
        if isinstance(value, list):
            # The `value and` guard keeps an empty list from failing on value[0].
            if value and isinstance(value[0], list):
                results[key] = [[fmt(xy) for xy in x] for x in value]
            else:
                results[key] = [fmt(x) for x in value]
        else:
            results[key] = fmt(value)
    return results
126 |
127 |
def list_of_dicts_2_dict_of_lists(sources):
    """Transpose a list of dicts into a dict of lists.

    The keys are taken from the first dict; every dict in `sources` must
    provide all of them.
    """
    return {key: [entry[key] for entry in sources] for key in sources[0].keys()}
134 |
135 |
def list_of_dicts_2_dict_of_means(sources):
    """Average each key over a list of dicts.

    The keys are taken from the first dict; the result maps every key to the
    ``np.mean`` of its values across `sources`.
    """
    return {key: np.mean([entry[key] for entry in sources]) for key in sources[0].keys()}
142 |
143 |
def list_of_dicts_2_dict_of_means_minmax(sources):
    """Summarize each key over a list of dicts as ``(mean, min, max)``.

    The keys are taken from the first dict.
    """
    def summarize(values):
        return (np.mean(values), min(values), max(values))

    return {key: summarize([entry[key] for entry in sources]) for key in sources[0].keys()}
151 |
152 |
def get_json_filenames(values, subdir=None):
    """Return the result file names for the "AR" and the "DAR" model, in that order."""
    return tuple(get_json_filenames_type(model_type, values, subdir) for model_type in ("AR", "DAR"))
157 |
158 |
def get_json_filenames_type(model_type, values, subdir=None):
    """Build ``results/[<subdir>/]<model_type>_<v1>-<v2>-....json``.

    `values` are converted with ``str`` and joined with dashes.
    """
    prefix = "" if subdir is None else subdir + "/"
    joined_values = "-".join(map(str, values))
    return "results/" + prefix + "{}_{}.json".format(model_type, joined_values)
166 |
167 |
def intelligent_regularization(sparsity):
    """Derive the regularization strength from the estimated sparsity.

    Returns ``0.02 * (1 / sparsity - 1)``, or 0.0 when `sparsity` is None.
    """
    if sparsity is None:
        return 0.0
    # Factor 0.02 is the one in use; 0.01 and 0.05 were tried as alternatives,
    # as were log- and sqrt-shaped functions of the sparsity.
    return 0.02 * (1.0 / sparsity - 1.0)
181 |
182 |
183 |
--------------------------------------------------------------------------------
/v0_1/training.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from model import DAR
7 | from torch import optim
8 | from torch.utils.data import DataLoader
9 |
10 | from v0_1_pure_pytorch import utils
11 |
12 |
13 | # Architecture, batching etc of DARMA
def train_batch(model, x, y, optimizer, loss_fn, lambda_value=None):
    """Run one optimization step on a single batch and return its loss.

    Args:
        model: the network; with regularization enabled it must expose
            ``num_layers == 1`` and a ``layer_1`` module with a ``weight``.
        x: batch of inputs.
        y: batch of targets.
        optimizer: optimizer holding the model's parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning the loss.
        lambda_value: regularization strength; None disables the penalty.

    Returns:
        The batch loss (including the penalty, if any) as a Python float.

    Raises:
        NotImplementedError: if regularization is requested for a model with
            more than one layer.
    """
    # Forward pass through __call__ so registered hooks run as usual.
    y_predict = model(x)

    # Compute loss.
    loss = loss_fn(y_predict, y)

    # regularize
    if lambda_value is not None:
        if model.num_layers != 1:
            raise NotImplementedError("L1 Norm for deeper models not implemented")
        abs_weights = torch.abs(model.layer_1.weight)
        # Saturating penalty 2 / (1 + exp(-3 * |w| ** (1/3))) - 1, averaged
        # over all weights. Alternatives that were tried: plain L1
        # (mean |w|), mean sqrt(|w|), and the same form with exponent 0.4 and
        # factor 5.
        reg = torch.div(2.0, 1.0 + torch.exp(-3.0 * abs_weights.pow(1.0 / 3.0))) - 1.0
        # Added directly to the loss: no auxiliary zero tensor is needed, and
        # the loss stays a scalar on the model's device.
        loss = loss + lambda_value * torch.mean(reg)

    optimizer.zero_grad()
    # The graph is used for exactly one backward pass, so it need not be retained.
    loss.backward()
    optimizer.step()

    return loss.item()
53 |
54 |
def train(model, loader, loss_fn, lr, epochs, lr_decay, est_sparsity, lambda_delay=None, verbose=False):
    """Train ``model`` with Adam for a fixed number of epochs.

    The learning rate is multiplied by ``lr_decay`` after every epoch
    (``StepLR`` with ``step_size=1``). The regularization strength comes from
    ``utils.intelligent_regularization(est_sparsity)`` and, when
    ``lambda_delay`` is set, is ramped up linearly from 0 over the first
    ``lambda_delay`` epochs.

    Args:
        model: module whose parameters are optimised.
        loader: iterable yielding ``(x, y)`` batches.
        loss_fn: loss callable forwarded to ``train_batch``.
        lr: initial Adam learning rate.
        epochs: number of passes over ``loader``.
        lr_decay: per-epoch multiplicative learning-rate factor.
        est_sparsity: estimated sparsity used to derive the penalty weight.
        lambda_delay: number of warm-up epochs for the penalty, or ``None``.
        verbose: print the per-epoch average loss and the batch count.

    Returns:
        Tuple ``(losses, avg_losses)``: every batch loss in order, and the
        mean batch loss of each epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=lr_decay)
    lambda_full = utils.intelligent_regularization(est_sparsity)

    losses = []
    avg_losses = []
    total_batches = 0

    for epoch in range(epochs):
        # Linear warm-up of the penalty until the lambda_delay-th epoch.
        warming_up = lambda_delay is not None and epoch < lambda_delay
        l_factor = epoch / (1.0 * lambda_delay) if warming_up else 1.0

        batch_losses = [
            train_batch(model=model, x=x, y=y, optimizer=optimizer,
                        loss_fn=loss_fn, lambda_value=l_factor * lambda_full)
            for x, y in loader
        ]
        total_batches += len(batch_losses)

        scheduler.step()
        losses.extend(batch_losses)
        epoch_mean = np.mean(batch_losses)
        avg_losses.append(epoch_mean)
        if verbose:
            print("{}. Epoch Avg Loss: {:10.2f}".format(epoch + 1, epoch_mean))

    if verbose:
        print("Total Batches: ", total_batches)

    return losses, avg_losses
91 |
92 |
93 | def test_batch(model, x, y, loss_fn):
94 | # run forward calculation
95 | y_predict = model.forward(x)
96 | loss = loss_fn(y_predict, y)
97 |
98 | return y_predict, loss
99 |
100 |
def test(model, loader, loss_fn):
    """Evaluate ``model`` on every batch of ``loader``.

    Args:
        model: module passed to ``test_batch``.
        loader: iterable yielding ``(x, y)`` batches; must yield at least one
            batch, because the results are concatenated.
        loss_fn: loss callable passed to ``test_batch``.

    Returns:
        Tuple ``(y_predict_vector, losses, mse)``: the concatenated
        predictions, an array with one loss per batch, and the mean squared
        error between the concatenated predictions and targets.
    """
    losses = []
    targets = []
    predictions = []

    # Evaluation needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        for x, y in loader:
            y_predict, loss = test_batch(model=model, x=x, y=y, loss_fn=loss_fn)

            # detach().cpu() also covers tensors that live on another device.
            losses.append(loss.detach().cpu().numpy())
            targets.append(y.detach().cpu().numpy())
            predictions.append(y_predict.detach().cpu().numpy())

    losses = np.array(losses)
    y_predict_vector = np.concatenate(predictions)
    mse = np.mean((y_predict_vector - np.concatenate(targets)) ** 2)

    return y_predict_vector, losses, mse
121 |
122 |
def run_train_test(dataset_train, dataset_test, model_config, train_config, verbose=False):
    """Train a ``DAR`` model on one dataset and evaluate it on another.

    Args:
        dataset_train: dataset fed to the shuffled training ``DataLoader``.
        dataset_test: dataset evaluated as a single batch; must expose
            ``y_data``.
        model_config: dict with an ``"ma"`` entry plus the keyword arguments
            for ``DAR``. It is not modified.
        train_config: dict with a ``"batch"`` entry (training batch size)
            plus the keyword arguments for ``train``. It is not modified.
        verbose: forwarded to ``train``.

    Returns:
        Tuple ``(predicted, actual, losses, weights_rereversed, test_mse,
        avg_losses)``.

    Raises:
        NotImplementedError: if ``model_config["ma"] > 0``.
    """
    # Work on shallow copies: the original code deleted "ma" and "batch" from
    # the caller's dicts, which broke any second run with the same configs.
    model_kwargs = dict(model_config)
    train_kwargs = dict(train_config)

    batch_size = train_kwargs.pop("batch")
    data_loader_train = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
    data_loader_test = DataLoader(dataset=dataset_test, batch_size=len(dataset_test), shuffle=False)

    if model_kwargs.pop("ma") > 0:
        # TODO: implement DARMA
        raise NotImplementedError
    model = DAR(**model_kwargs)

    # Define the loss function
    loss_fn = nn.MSELoss()  # mean squared error

    # Train and get the resulting loss per iteration
    losses, avg_losses = train(
        model=model,
        loader=data_loader_train,
        loss_fn=loss_fn,
        **train_kwargs,
        verbose=verbose,
    )

    # Test and get the resulting predicted y values
    y_predict, test_losses, test_mse = test(model=model, loader=data_loader_test, loss_fn=loss_fn)

    actual = np.concatenate(np.array(dataset_test.y_data))
    predicted = np.concatenate(y_predict)
    # First row of the layer-1 weight matrix, in reversed order.
    weights_rereversed = model.layer_1.weight.detach().numpy()[0, ::-1]

    return predicted, actual, np.array(losses), weights_rereversed, test_mse, avg_losses
158 |
159 |
def run(data, model_config, train_config, verbose=False):
    """Train and test AR-Net on ``data`` and collect results and statistics.

    Args:
        data: dict with ``"train"`` and ``"test"`` datasets, a ``"type"``
            entry and, for type ``'AR'``, the ``"ar"`` parameters.
        model_config: forwarded to ``run_train_test``.
        train_config: forwarded to ``run_train_test``.
        verbose: print a banner, the duration and the final losses.

    Returns:
        Tuple ``(results, stats)``: ``results`` holds the weights,
        predictions, targets and losses; ``stats`` is the output of
        ``utils.compute_stats_ar`` with an added ``"Time (s)"`` entry.

    Raises:
        NotImplementedError: if ``data["type"]`` is not ``'AR'`` (raised
            after training has finished).
    """
    if verbose:
        print("################ Model: AR-Net ################")

    started = time.time()
    predicted, actual, losses, weights, test_mse, epoch_losses = run_train_test(
        dataset_train=data["train"],
        dataset_test=data["test"],
        model_config=model_config,
        train_config=train_config,
        verbose=verbose,
    )
    duration = time.time() - started

    if verbose:
        print("Time: {:8.4f}".format(duration))
        print("Final train epoch loss: {:10.2f}".format(epoch_losses[-1]))
        print("Test MSEs: {:10.2f}".format(test_mse))

    results = {
        "weights": weights,
        "predicted": predicted,
        "actual": actual,
        "test_mse": test_mse,
        "losses": losses,
        "epoch_losses": epoch_losses,
    }

    # Only pure AR data has a statistics routine.
    if data["type"] != 'AR':
        raise NotImplementedError
    stats = utils.compute_stats_ar(results, ar_params=data["ar"], verbose=verbose)
    stats["Time (s)"] = duration

    return results, stats
193 |
194 |
def main():
    """Script entry point; only reports that this module is deprecated."""
    notice = "deprecated"
    print(notice)
197 |
198 |
199 | if __name__ == "__main__":
200 | main()
201 |
--------------------------------------------------------------------------------
/example_notebooks/make_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import json\n",
11 | "import pandas as pd\n",
12 | "import numpy as np"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "from arnet.create_ar_data import load_from_file\n",
22 | "from arnet.make_dataset import tabularize_univariate"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 3,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "## Notebook settings\n",
32 | "data_path = '../ar_data'\n",
33 | "data_name = 'ar_3_ma_0_noise_0.100_len_10000'"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 4,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "name": "stdout",
43 | "output_type": "stream",
44 | "text": [
45 | "{'samples': 10000, 'noise_std': 0.1, 'ar_order': 3, 'ma_order': 0, 'ar_params': [0.2, 0.3, -0.5], 'ma_params': []}\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "## Load data\n",
51 | "## if created AR data with create_ar_data, we can use the helper function:\n",
52 | "df, data_config = load_from_file(data_path, data_name, load_config=True)\n",
53 | "print(data_config)\n",
54 | "n_lags = data_config[\"ar_order\"]"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 5,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "## Load data\n",
64 | "## else we can manually load any file that stores a time series, for example:\n",
65 | "# df = pd.read_csv(os.path.join(data_path, data_name + '.csv'), header=None, index_col=False)\n",
66 | "# n_lags = 3"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 6,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "name": "stdout",
76 | "output_type": "stream",
77 | "text": [
78 | "(10000, 1)\n"
79 | ]
80 | },
81 | {
82 | "data": {
83 | "text/html": [
84 | "
\n",
85 | "\n",
98 | "
\n",
99 | " \n",
100 | " \n",
101 | " | \n",
102 | " 0 | \n",
103 | "
\n",
104 | " \n",
105 | " \n",
106 | " \n",
107 | " | 0 | \n",
108 | " 0.114610 | \n",
109 | "
\n",
110 | " \n",
111 | " | 1 | \n",
112 | " 0.027092 | \n",
113 | "
\n",
114 | " \n",
115 | " | 2 | \n",
116 | " -0.015263 | \n",
117 | "
\n",
118 | " \n",
119 | " | 3 | \n",
120 | " -0.125719 | \n",
121 | "
\n",
122 | " \n",
123 | " | 4 | \n",
124 | " -0.279195 | \n",
125 | "
\n",
126 | " \n",
127 | "
\n",
128 | "
"
129 | ],
130 | "text/plain": [
131 | " 0\n",
132 | "0 0.114610\n",
133 | "1 0.027092\n",
134 | "2 -0.015263\n",
135 | "3 -0.125719\n",
136 | "4 -0.279195"
137 | ]
138 | },
139 | "execution_count": 6,
140 | "metadata": {},
141 | "output_type": "execute_result"
142 | }
143 | ],
144 | "source": [
145 | "print(df.shape)\n",
146 | "df.head()"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 7,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "## create a tabularized dataset from time series\n",
156 | "df_tab = tabularize_univariate(df, n_lags, n_forecasts=1)"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 8,
162 | "metadata": {},
163 | "outputs": [
164 | {
165 | "name": "stdout",
166 | "output_type": "stream",
167 | "text": [
168 | "(9997, 4)\n"
169 | ]
170 | },
171 | {
172 | "data": {
173 | "text/html": [
174 | "\n",
175 | "\n",
188 | "
\n",
189 | " \n",
190 | " \n",
191 | " | \n",
192 | " x_0 | \n",
193 | " x_1 | \n",
194 | " x_2 | \n",
195 | " y_0 | \n",
196 | "
\n",
197 | " \n",
198 | " \n",
199 | " \n",
200 | " | 0 | \n",
201 | " 0.114610 | \n",
202 | " 0.027092 | \n",
203 | " -0.015263 | \n",
204 | " -0.125719 | \n",
205 | "
\n",
206 | " \n",
207 | " | 1 | \n",
208 | " 0.027092 | \n",
209 | " -0.015263 | \n",
210 | " -0.125719 | \n",
211 | " -0.279195 | \n",
212 | "
\n",
213 | " \n",
214 | " | 2 | \n",
215 | " -0.015263 | \n",
216 | " -0.125719 | \n",
217 | " -0.279195 | \n",
218 | " -0.133535 | \n",
219 | "
\n",
220 | " \n",
221 | " | 3 | \n",
222 | " -0.125719 | \n",
223 | " -0.279195 | \n",
224 | " -0.133535 | \n",
225 | " -0.254239 | \n",
226 | "
\n",
227 | " \n",
228 | " | 4 | \n",
229 | " -0.279195 | \n",
230 | " -0.133535 | \n",
231 | " -0.254239 | \n",
232 | " 0.122992 | \n",
233 | "
\n",
234 | " \n",
235 | "
\n",
236 | "
"
237 | ],
238 | "text/plain": [
239 | " x_0 x_1 x_2 y_0\n",
240 | "0 0.114610 0.027092 -0.015263 -0.125719\n",
241 | "1 0.027092 -0.015263 -0.125719 -0.279195\n",
242 | "2 -0.015263 -0.125719 -0.279195 -0.133535\n",
243 | "3 -0.125719 -0.279195 -0.133535 -0.254239\n",
244 | "4 -0.279195 -0.133535 -0.254239 0.122992"
245 | ]
246 | },
247 | "execution_count": 8,
248 | "metadata": {},
249 | "output_type": "execute_result"
250 | }
251 | ],
252 | "source": [
253 | "print(df_tab.shape)\n",
254 | "df_tab.head()"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": []
263 | }
264 | ],
265 | "metadata": {
266 | "kernelspec": {
267 | "display_name": "Python 3",
268 | "language": "python",
269 | "name": "python3"
270 | },
271 | "language_info": {
272 | "codemirror_mode": {
273 | "name": "ipython",
274 | "version": 3
275 | },
276 | "file_extension": ".py",
277 | "mimetype": "text/x-python",
278 | "name": "python",
279 | "nbconvert_exporter": "python",
280 | "pygments_lexer": "ipython3",
281 | "version": "3.8.6"
282 | }
283 | },
284 | "nbformat": 4,
285 | "nbformat_minor": 4
286 | }
287 |
--------------------------------------------------------------------------------
/arnet/ar_net.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | import logging
3 | import os
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 |
7 | import torch
8 | from fastai.data.core import DataLoaders
9 | from fastai.tabular.core import TabularPandas, TabDataLoader
10 | from fastai.tabular.learner import tabular_learner, TabularLearner
11 | from fastai.data.transforms import Normalize
12 | from fastai.learner import load_learner
13 | from fastai.distributed import ParallelTrainer
14 |
15 | from arnet import utils, utils_data, plotting
16 | from arnet.fastai_mods import SparsifyAR, huber, sTPE, get_loss_func
17 |
18 | log = logging.getLogger("ARNet")
19 |
20 |
@dataclass
class ARNet:
    """Fluent wrapper that fits a sparse auto-regressive model with fastai.

    Typical use is ``ARNet(ar_order=...).fit(series)``, which tabularizes the
    series, builds the data loaders, creates the learner and trains it. Each
    step is also available as its own method; most methods return ``self``
    so that calls can be chained.

    Init fields:
        ar_order: number of lags passed to the tabularization step.
        sparsity: passed to the ``SparsifyAR`` regularizer; ``None`` means no
            regularizer is created automatically.
        n_forecasts: number of forecast targets / model outputs.
        n_epoch: default epochs per one-cycle run.
        lr: learning rate; found with ``find_lr`` when ``None``.
        est_noise: noise estimate; computed from the series when ``None``.
        start_reg_pct, full_reg_pct: passed to ``SparsifyAR`` as
            ``start_pct`` / ``full_pct``.
        use_reg_noise: whether ``est_noise`` is handed to the regularizer.
        reg_c1, reg_c2: passed to ``SparsifyAR`` as ``c1`` / ``c2``.
        loss_func: loss name, resolved through ``get_loss_func``.
        train_bs, valid_bs: batch sizes of the train / validation loaders.
        valid_p: validation share passed to ``split_by_p_valid``.
        normalize: add fastai's ``Normalize`` proc when true.
        ar_params: known AR parameters; used for the sTPE metric and plots.
        log_level: optional level applied to the "ARNet" logger.
        callbacks: extra learner callbacks.
        metrics: metric names; defaults to ``["MSE", "MAE"]``.
        use_gpu: use CUDA when available.

    The remaining fields (``dls``, ``learn``, ``coeff``, ``df``,
    ``regularizer``) are populated by the methods below.
    """

    ar_order: int
    sparsity: float = None
    n_forecasts: int = 1
    n_epoch: int = 20
    lr: float = None
    est_noise: float = None
    start_reg_pct: float = 0.0
    full_reg_pct: float = 0.5
    use_reg_noise: bool = False
    reg_c1: float = 2.0
    reg_c2: float = 2.0
    loss_func: str = "huber"
    train_bs: int = 32
    valid_bs: int = 1024
    valid_p: float = 0.1
    normalize: bool = False
    ar_params: list = None
    log_level: str = None
    callbacks: list = None
    metrics: list = None
    use_gpu: bool = False
    dls: DataLoaders = field(init=False, default=None)
    learn: TabularLearner = field(init=False, default=None)
    coeff: list = field(init=False, default=None)
    df: pd.DataFrame = field(init=False, default=None)
    regularizer: SparsifyAR = field(init=False, default=None)

    def __post_init__(self):
        """Apply the log level, resolve the loss function and pick a device."""
        if self.log_level is not None:
            utils.set_logger_level(log, self.log_level)
        # From here on ``loss_func`` holds whatever ``get_loss_func`` returns
        # for the configured name.
        self.loss_func = get_loss_func(self.loss_func)
        if self.use_gpu:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
                # torch.cuda.set_device(0)
            else:
                log.error("CUDA is not available. defaulting to CPU")
                self.device = torch.device("cpu")
        else:
            self.device = torch.device("cpu")

    def tabularize(self, series):
        """Turn ``series`` into a lagged table stored in ``self.df``.

        Also estimates the noise of the series when ``est_noise`` is unset.

        Returns:
            ``self``.
        """
        if self.est_noise is None:
            self.est_noise = utils_data.estimate_noise(series)
            log.info("estimated noise of series: {}".format(self.est_noise))
        df_all = utils_data.tabularize_univariate(series, self.ar_order, self.n_forecasts)
        log.debug("tabularized df")
        log.debug("df columns: {}".format(list(df_all.columns)))
        log.debug("df shape: {}".format(df_all.shape))
        # log.debug("df head(3): {}".format(df_all.head(3)))
        self.df = df_all
        return self

    def make_datasets(
        self,
        series=None,
        valid_p=None,
        train_bs=None,
        valid_bs=None,
        normalize=None,
    ):
        """Build train/validation data loaders and store them in ``self.dls``.

        Args:
            series: series to tabularize first; when ``None`` a previously
                tabularized ``self.df`` is reused.
            valid_p, train_bs, valid_bs, normalize: per-call overrides of the
                instance settings of the same name.

        Returns:
            ``self``.

        Raises:
            ValueError: if neither ``series`` nor ``self.df`` is available.
        """
        if series is None:
            if self.df is None:
                raise ValueError("must pass a series.")
        else:
            self.tabularize(series)
        valid_p = self.valid_p if valid_p is None else valid_p
        train_bs = self.train_bs if train_bs is None else train_bs
        valid_bs = self.valid_bs if valid_bs is None else valid_bs
        normalize = self.normalize if normalize is None else normalize

        procs = []
        if normalize:
            procs.append(Normalize)

        df_all = self.df
        splits = utils_data.split_by_p_valid(valid_p, len(df_all))
        # Columns are recognised by prefix: "x_" inputs and "y_" targets.
        cont_names = [col for col in list(df_all.columns) if "x_" == col[:2]]
        target_names = [col for col in list(df_all.columns) if "y_" == col[:2]]
        tp = TabularPandas(
            df_all,
            procs=procs,
            cat_names=None,
            cont_names=cont_names,
            y_names=target_names,
            splits=splits,
        )
        log.debug("cont var num: {}, names: {}".format(len(tp.cont_names), tp.cont_names))

        trn_dl = TabDataLoader(tp.train, bs=train_bs, shuffle=True, drop_last=True, device=self.device)
        val_dl = TabDataLoader(tp.valid, bs=valid_bs, device=self.device)
        self.dls = DataLoaders(trn_dl, val_dl, device=self.device)
        log.debug("showing batch")
        log.debug("{}".format(self.dls.show_batch(show=False)))
        return self

    def create_regularizer(
        self,
        sparsity=None,
        start_reg_pct=None,
        full_reg_pct=None,
        est_noise=None,
        use_reg_noise=None,
        reg_c1=None,
        reg_c2=None,
    ):
        """Create the ``SparsifyAR`` callback and store it in ``self.regularizer``.

        Every argument overrides the instance setting of the same name when
        it is not ``None``. The noise estimate is only handed to the
        regularizer when ``use_reg_noise`` is true.

        Returns:
            ``self``.
        """
        sparsity = self.sparsity if sparsity is None else sparsity
        start_reg_pct = self.start_reg_pct if start_reg_pct is None else start_reg_pct
        full_reg_pct = self.full_reg_pct if full_reg_pct is None else full_reg_pct
        est_noise = self.est_noise if est_noise is None else est_noise
        use_reg_noise = self.use_reg_noise if use_reg_noise is None else use_reg_noise
        reg_c1 = self.reg_c1 if reg_c1 is None else reg_c1
        reg_c2 = self.reg_c2 if reg_c2 is None else reg_c2

        self.regularizer = SparsifyAR(
            sparsity,
            est_noise=est_noise if use_reg_noise else None,
            start_pct=start_reg_pct,
            full_pct=full_reg_pct,
            c1=reg_c1,
            c2=reg_c2,
        )
        log.info("reg lam (max): {}".format(self.regularizer.lam_max))
        return self

    def create_learner(
        self,
        loss_func=None,
        metrics=None,
        ar_params=None,
        callbacks=None,
    ):
        """Create the fastai tabular learner and store it in ``self.learn``.

        The network has no hidden layers and no batch norm, i.e. a single
        linear layer from the lags to the forecasts.

        Args:
            loss_func: loss name overriding the instance loss function.
            metrics: metric names; falls back to ``self.metrics`` and then to
                ``["MSE", "MAE"]``.
            ar_params: known AR parameters; adds an sTPE metric when given.
            callbacks: extra callbacks; falls back to ``self.callbacks``.

        Returns:
            ``self``.
        """
        loss_func = self.loss_func if loss_func is None else get_loss_func(loss_func)
        metrics = self.metrics if metrics is None else metrics
        ar_params = self.ar_params if ar_params is None else ar_params
        callbacks = self.callbacks if callbacks is None else callbacks

        if metrics is None:
            metrics = ["MSE", "MAE"]
        metrics = [get_loss_func(m) for m in metrics]
        if ar_params is not None:
            metrics.append(sTPE(ar_params, at_epoch_end=False))

        # Copy before appending: otherwise the regularizer would be added to
        # ``self.callbacks`` (or the caller's list) again on every call.
        callbacks = [] if callbacks is None else list(callbacks)
        if self.sparsity is not None and self.regularizer is None:
            self.create_regularizer()
        if self.regularizer is not None:
            callbacks.append(self.regularizer)

        self.learn = tabular_learner(
            self.dls,
            layers=[],  # Note: None defaults to [200, 100]
            config={"use_bn": False, "bn_final": False, "bn_cont": False},
            n_out=self.n_forecasts,  # None calls get_c(dls)
            train_bn=False,  # passed to Learner
            metrics=metrics,  # passed to Learner
            loss_func=loss_func,
            cbs=callbacks,
        )
        log.debug("{}".format(self.learn.model))
        return self

    def find_lr(self, plot=True):
        """Run fastai's learning-rate finder and store the result in ``self.lr``.

        The rate at the loss minimum is used; the steepest-slope rate is only
        logged.

        Returns:
            ``self``.

        Raises:
            ValueError: if no learner has been created yet.
        """
        if self.learn is None:
            raise ValueError("create learner first.")
        # NOTE(review): the two-value unpacking assumes lr_find returns
        # (lr_min, lr_steep) — confirm against the installed fastai version.
        lr_at_min, lr_steep = self.learn.lr_find(start_lr=1e-6, end_lr=10, num_it=300, show_plot=plot)
        if plot:
            plt.show()
        log.debug("lr at minimum: {}; (steepest lr: {})".format(lr_at_min, lr_steep))
        lr = lr_at_min
        log.info("Optimal learning rate: {}".format(lr))
        self.lr = lr
        return self

    def fit_one_cycle(self, n_epoch=None, lr=None, cycles=1, plot=True):
        """Train with the one-cycle policy and record the fitted coefficients.

        Args:
            n_epoch: epochs per cycle; defaults to ``self.n_epoch``.
            lr: maximum learning rate; defaults to ``self.lr`` and is searched
                with ``find_lr`` when that is unset too.
            cycles: number of consecutive cycles; the learning rate is divided
                by 10 after each one.
            plot: plot the loss curve after every cycle.

        Returns:
            ``self``.

        Raises:
            ValueError: if no learner has been created yet.
        """
        # Same guard as find_lr, so an explicit lr fails just as clearly.
        if self.learn is None:
            raise ValueError("create learner first.")
        n_epoch = self.n_epoch if n_epoch is None else n_epoch
        lr = self.lr if lr is None else lr

        if lr is None:
            self.find_lr(plot=plot)
            lr = self.lr
        for i in range(0, cycles):
            self.learn.fit_one_cycle(n_epoch=n_epoch, lr_max=lr, div=25.0, div_final=10000.0, pct_start=0.25)
            lr = lr / 10
            if plot:
                self.learn.recorder.plot_loss(skip_start=20)
        if plot:
            plt.show()
        # record Coeff
        self.coeff = utils.coeff_from_model(self.learn.model)
        return self

    def fit(self, series, plot=False):
        """Tabularize ``series``, create the learner and train it.

        Returns:
            ``self``.
        """
        self.make_datasets(series)
        self.create_learner()
        self.fit_one_cycle(plot=plot)
        return self

    def plot_weights(self, **kwargs):
        """Plot the first recorded coefficient set; requires a fitted model."""
        plotting.plot_weights(
            ar_val=self.ar_order,
            weights=self.coeff[0],
            ar=self.ar_params,
            **kwargs,
        )

    def plot_fitted_obs(self, num_obs=100, **kwargs):
        """Plot predictions against targets for the first ``num_obs`` samples.

        Pass ``num_obs=None`` to plot everything ``learn.get_preds()`` returns.
        """
        preds, y = self.learn.get_preds()
        if num_obs is not None:
            y = y[0:num_obs]
            preds = preds[0:num_obs]
        plotting.plot_prediction_sample(preds, y, **kwargs)

    def plot_errors(self, **kwargs):
        """Scatter-plot the predictions from ``learn.get_preds()`` against targets."""
        preds, y = self.learn.get_preds()
        plotting.plot_error_scatter(preds, y, **kwargs)

    def save_model(self, results_path="results", model_name=None):
        """Export the learner to ``results_path/model_name``.

        When ``model_name`` is ``None`` a name is derived from the AR order,
        sparsity (1.0 if unset), number of forecasts and epochs.

        Returns:
            ``self``.
        """
        # self.learn.freeze()
        sparsity = 1.0 if self.sparsity is None else self.sparsity
        if model_name is None:
            model_name = "ar{}_sparse_{:.3f}_ahead_{}_epoch_{}.pkl".format(
                self.ar_order, sparsity, self.n_forecasts, self.n_epoch
            )
        self.learn.export(fname=os.path.join(results_path, model_name))
        return self

    def load_model(self, results_path="results", model_name=None, cpu=True):
        """Load an exported learner into ``self.learn`` and unfreeze it.

        ``model_name`` goes straight to ``os.path.join``, so a name must be
        supplied despite the ``None`` default.

        Returns:
            ``self``.
        """
        self.learn = load_learner(fname=os.path.join(results_path, model_name), cpu=cpu)
        # can unfreeze the model and fine_tune
        self.learn.unfreeze()
        return self
256 |
--------------------------------------------------------------------------------
/example_notebooks/01_fit_arnet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import pandas as pd\n",
11 | "import os\n",
12 | "import torch\n",
13 | "import arnet\n",
14 | "from arnet import ARNet\n",
15 | "import matplotlib.pyplot as plt\n",
16 | "%matplotlib inline"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "DIR = os.path.dirname(os.path.abspath(''))\n",
26 | "data_path = os.path.join(DIR, 'ar_data')\n",
27 | "name = 'ar_3_ma_0_noise_0.100_len_10000'\n",
28 | "df = pd.read_csv(os.path.join(data_path, name + '.csv'), header=None, index_col=False)"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 3,
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "name": "stderr",
38 | "output_type": "stream",
39 | "text": [
40 | "INFO: ARNet - tabularize: estimated noise of series: 0.12853080551811422\n",
41 | "INFO: ARNet - create_regularizer: reg lam (max): 0.009999999980000004\n"
42 | ]
43 | },
44 | {
45 | "data": {
46 | "text/html": [],
47 | "text/plain": [
48 | ""
49 | ]
50 | },
51 | "metadata": {},
52 | "output_type": "display_data"
53 | },
54 | {
55 | "data": {
56 | "image/png": "\n",
57 | "text/plain": [
58 | ""
59 | ]
60 | },
61 | "metadata": {
62 | "needs_background": "light"
63 | },
64 | "output_type": "display_data"
65 | },
66 | {
67 | "name": "stderr",
68 | "output_type": "stream",
69 | "text": [
70 | "INFO: ARNet - find_lr: Optimal learning rate: 0.1369830012321472\n"
71 | ]
72 | },
73 | {
74 | "data": {
75 | "text/html": [
76 | "\n",
77 | " \n",
78 | " \n",
79 | " | epoch | \n",
80 | " train_loss | \n",
81 | " valid_loss | \n",
82 | " mse | \n",
83 | " mae | \n",
84 | " time | \n",
85 | "
\n",
86 | " \n",
87 | " \n",
88 | " \n",
89 | " | 0 | \n",
90 | " 0.010955 | \n",
91 | " 0.010912 | \n",
92 | " 0.010912 | \n",
93 | " 0.084099 | \n",
94 | " 00:01 | \n",
95 | "
\n",
96 | " \n",
97 | " | 1 | \n",
98 | " 0.012936 | \n",
99 | " 0.010850 | \n",
100 | " 0.010850 | \n",
101 | " 0.083003 | \n",
102 | " 00:01 | \n",
103 | "
\n",
104 | " \n",
105 | " | 2 | \n",
106 | " 0.013433 | \n",
107 | " 0.010741 | \n",
108 | " 0.010741 | \n",
109 | " 0.082137 | \n",
110 | " 00:01 | \n",
111 | "
\n",
112 | " \n",
113 | " | 3 | \n",
114 | " 0.014620 | \n",
115 | " 0.010933 | \n",
116 | " 0.010933 | \n",
117 | " 0.082773 | \n",
118 | " 00:01 | \n",
119 | "
\n",
120 | " \n",
121 | " | 4 | \n",
122 | " 0.015200 | \n",
123 | " 0.010846 | \n",
124 | " 0.010846 | \n",
125 | " 0.083539 | \n",
126 | " 00:01 | \n",
127 | "
\n",
128 | " \n",
129 | " | 5 | \n",
130 | " 0.014374 | \n",
131 | " 0.011287 | \n",
132 | " 0.011287 | \n",
133 | " 0.083522 | \n",
134 | " 00:01 | \n",
135 | "
\n",
136 | " \n",
137 | " | 6 | \n",
138 | " 0.013678 | \n",
139 | " 0.010433 | \n",
140 | " 0.010433 | \n",
141 | " 0.081839 | \n",
142 | " 00:01 | \n",
143 | "
\n",
144 | " \n",
145 | " | 7 | \n",
146 | " 0.013177 | \n",
147 | " 0.010253 | \n",
148 | " 0.010253 | \n",
149 | " 0.080798 | \n",
150 | " 00:01 | \n",
151 | "
\n",
152 | " \n",
153 | " | 8 | \n",
154 | " 0.012035 | \n",
155 | " 0.010105 | \n",
156 | " 0.010105 | \n",
157 | " 0.080134 | \n",
158 | " 00:01 | \n",
159 | "
\n",
160 | " \n",
161 | " | 9 | \n",
162 | " 0.011887 | \n",
163 | " 0.010102 | \n",
164 | " 0.010102 | \n",
165 | " 0.080020 | \n",
166 | " 00:01 | \n",
167 | "
\n",
168 | " \n",
169 | "
"
170 | ],
171 | "text/plain": [
172 | ""
173 | ]
174 | },
175 | "metadata": {},
176 | "output_type": "display_data"
177 | },
178 | {
179 | "data": {
180 | "image/png": "\n",
181 | "text/plain": [
182 | ""
183 | ]
184 | },
185 | "metadata": {
186 | "needs_background": "light"
187 | },
188 | "output_type": "display_data"
189 | }
190 | ],
191 | "source": [
192 | "m = arnet.ARNet(\n",
193 | " ar_order=10,\n",
194 | " sparsity=0.5,\n",
195 | " n_epoch=10,\n",
196 | " loss_func=\"MSE\",\n",
197 | " valid_p=0.1,\n",
198 | " use_gpu=False,\n",
199 | ")\n",
200 | "m = m.fit(df, plot=True)"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 4,
206 | "metadata": {},
207 | "outputs": [
208 | {
209 | "data": {
210 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmkAAAEKCAYAAABJ+cK7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhwUlEQVR4nO3de5xeZX33+8+XcIgotJAgIiEmj6IcSgw6IHLSchI3j4BbqlDloNLIFiyI1tLtsxVbt9BKaxHpQ9ngIx4IKlKJTz1BEDlUkATDMSioAUIRQlAOVdTAb/9xryTDOJMMmbln3TP5vF+ved3rdK/rtyaafFnrWteVqkKSJEm9ZYO2C5AkSdIfMqRJkiT1IEOaJElSDzKkSZIk9SBDmiRJUg8ypEmSJPWgVkNakoOT/DjJPUlOG2T/CUluS7IoyXVJdmqjTkmSpLGWtsZJSzIJ+AlwILAUuAk4qqru7HfM5lX1eLN8KPDeqjq4jXolSZLGUpt30nYH7qmqn1XV74BLgMP6H7AyoDWeDzjyriRJWi9s2GLb2wL391tfCrxm4EFJTgROBTYG9lvbSadOnVozZswYpRIlSZK6Z+HChY9U1VaD7WszpA1LVZ0LnJvkz4H/ARw78Jgkc4A5ANOnT2fBggVjW6QkSdI6SHLvUPvafNz5ALBdv/VpzbahXAIcPtiOqjq/qvqqqm+rrQYNo5IkSeNKmyHtJmD7JDOTbAwcCczrf0CS7futHgLcPYb1SZIktaa1x51VtSLJScB3gEnAZ6vqjiR/CyyoqnnASUkOAH4P/JJBHnVKkiRNRK32SauqbwLfHLDtI/2WTx7zoiRJ0oj8/ve/Z+nSpTz11FNtl9IzJk+ezLRp09hoo42G/Z2ef3FAkiSNL0uXLmWzzTZjxowZJGm7nNZVFcuXL2fp0qXMnDlz2N9zWihJkjSqnnrqKaZMmWJAayRhypQpz/nOoiFNkiSNOgPas63L78OQJkmS1IPskzZBvPqvPt92Cc/Jwk8e03YJkqT11IwZM1iwYAFTp04d0THd5p00SZKkHmRIkyRJPW/JkiXssMMOHHfccbz85S/n7W9/O1deeSV77bUX22+/PT/84Q959NFHOfzww5k1axZ77LEHt956KwDLly/noIMOYuedd+b444+nqlad94tf/CK77747s2fP5j3veQ9PP/10W5f4BwxpkiRpXLjnnnv4wAc+wF133cVdd93FxRdfzHXXXcdZZ53FJz7xCT760Y+y6667cuutt/KJT3yCY47pdK352Mc+xt57780dd9zBm9/8Zu677z4AFi9ezJe//GWuv/56Fi1axKRJk/jSl77U5iU+i33SJEnSuDBz5kx22WUXAHbeeWf2339/krDLLruwZMkS7r33Xr72ta8BsN9++7F8+XIef/xxrrnmGi677DIADjnkELbYYgsA5s+fz8KFC9ltt90A+M1vfsMLX/jCFq5scIY0SZI0LmyyySarljfYYINV6xtssAErVqx4TqP5Q2eQ2WOPPZYzzjhjVOscLT7ulCRJE8I+++yz6nHl1VdfzdSpU9l8883Zd999ufjiiwH41re+xS9/+UsA9t9/fy699FIefvhhAB599FHuvffedoofhHfSJEnShHD66afzrne9i1mzZrHpppty0UUXAfDRj36Uo446ip133pk999yT6dOnA7DTTjvx8Y9/nIMOOohnnnmGjTbaiHPPPZeXvOQlbV7GKun/hsNE0NfXVwsWLGi7jDHnOGmSpF6xePFidtxxx7bL6DmD/V6SLKyqvsGO93GnJElSDzKkSZIk9SBDmiRJUg8ypEmSJPUgQ5okSVIPMqRJkiT1oFbHSUtyMHA2MAm4oKrOHLD/VOB4YAWwDHhXVfXOKHOSJGmtRnuYqOEO4/T1r3+dN7/5zSxevJgddtiBJUuWsOOOO/KKV7yC3/3ud/T19XHhhRcOOlNBEk499VT+8R//EYCzzjqLJ598
ktNPP33I9q6++mo23nhj9txzz3W6roFau5OWZBJwLvBGYCfgqCQ7DTjsR0BfVc0CLgX+YWyrlCRJ49XcuXPZe++9mTt37qptL33pS1m0aBG33XYbS5cu5Stf+cqg391kk0247LLLeOSRR4bd3tVXX81//Md/jLjuldp83Lk7cE9V/ayqfgdcAhzW/4Cq+l5V/bpZvQGYNsY1SpKkcejJJ5/kuuuu48ILL+SSSy75g/2TJk1i991354EHHhj0+xtuuCFz5szhU5/61B/sW7ZsGW95y1vYbbfd2G233bj++utZsmQJ5513Hp/61KeYPXs211577Yivoc3HndsC9/dbXwq8Zg3Hvxv4VlcrkiRJE8Lll1/OwQcfzMtf/nKmTJnCwoULmTJlyqr9Tz31FDfeeCNnn332kOc48cQTmTVrFh/60Ieetf3kk0/m/e9/P3vvvTf33Xcfb3jDG1i8eDEnnHACL3jBC/jgBz84KtcwLubuTPIOoA943RD75wBzgFXzcUmSpPXX3LlzOfnkkwE48sgjmTt3LieddBI//elPmT17Nj//+c855JBDmDVr1pDn2HzzzTnmmGP49Kc/zfOe97xV26+88kruvPPOVeuPP/44Tz755KhfQ5sh7QFgu37r05ptz5LkAODDwOuq6reDnaiqzgfOh87cnaNfqiRJGi8effRRrrrqKm677TaS8PTTT5OEE088cVWftEceeYS99tqLefPmseuuu/KmN70JgBNOOIETTjhh1blOOeUUXvWqV/HOd75z1bZnnnmGG264gcmTJ3f1Otrsk3YTsH2SmUk2Bo4E5vU/IMmuwL8Ch1bVwy3UKEmSxplLL72Uo48+mnvvvZclS5Zw//33M3PmTO6/f3Uvq6lTp3LmmWdyxhlnsN1227Fo0SIWLVr0rIAGsOWWW/LWt76VCy+8cNW2gw46iHPOOWfV+qJFiwDYbLPNeOKJJ0btOlq7k1ZVK5KcBHyHzhAcn62qO5L8LbCgquYBnwReAHw1CcB9VXVoWzVLkqTnbrhDZoyWuXPn8td//dfP2vaWt7yFM84441nbDj/8cE4//XSuvfZa9tlnnyHP94EPfIDPfOYzq9Y//elPr+qvtmLFCvbdd1/OO+883vSmN3HEEUdw+eWXc84556zxnMORqon1dLCvr68WLFjQdhljbrTHoOm2sf4/rCRp7CxevJgdd9yx7TJ6zmC/lyQLq6pvsOOdcUCSJKkHGdIkSZJ6kCFNkiSNuonWnWqk1uX3YUiTJEmjavLkySxfvtyg1qgqli9f/pyH7BgXg9lKkqTxY9q0aSxdupRly5a1XUrPmDx5MtOmPbfZLQ1pkiRpVG200UbMnDmz7TLGPR93SpIk9SBDmiRJUg8ypEmSJPUgQ5okSVIPWm9eHBhv0yaBUydJkrQ+806aJElSDzKkSZIk9SBDmiRJUg8ypEmSJPUgQ5okSVIPMqRJkiT1IEOaJElSDzKkSZIk9aBWQ1qSg5P8OMk9SU4bZP++SW5OsiLJEW3UKEmS1IbWQlqSScC5wBuBnYCjkuw04LD7gOOAi8e2OkmSpHa1OS3U7sA9VfUzgCSXAIcBd648oKqWNPueaaNASZKktrT5uHNb4P5+60ubbZIkSeu9CfHiQJI5SRYkWbBs2bK2y5EkSRqxNkPaA8B2/danNdues6o6v6r6qqpvq622GpXiJEmS2tRmSLsJ2D7JzCQbA0cC81qsR5IkqWe0FtKqagVwEvAdYDHwlaq6I8nfJjkUIMluSZYCfwb8a5I72qpXkiRpLLX5didV9U3gmwO2faTf8k10HoNKkiStVybEiwOSJEkTjSFNkiSpBxnSJEmSepAhTZIkqQcZ0iRJknqQIU2SJKkHGdIkSZJ6kCFNkiSpBxnSJEmSepAhTZIkqQcZ0iRJknqQIU2SJKkHDWuC9SRbAC8GfgMsqapnulqVJEnSem7IkJbkj4ATgaOAjYFlwGRg6yQ3AP9SVd8bkyolSZLWM2u6
k3Yp8Hlgn6r6Vf8dSfqAdyT5b1V1YRfrkyRJWi8NGdKq6sA17FsALOhKRZIkSVr7iwNJ5g9nmyRJkkbPmvqkTQY2BaY2Lw6k2bU5sO0Y1CZJkrTeWlOftPcAp9B5q3Mhq0Pa48BnuluWJEnS+m3Ix51VdXZVzQQ+WFX/rapmNj+vrKpRCWlJDk7y4yT3JDltkP2bJPlys//GJDNGo11JkqRet9Zx0qrqnCR7AjP6H19Vnx9Jw0kmAecCBwJLgZuSzKuqO/sd9m7gl1X1siRHAn8PvG0k7UqSJI0Haw1pSb4AvBRYBDzdbC46w3OMxO7APVX1s6adS4DDgP4h7TDg9Gb5UuAzSVJVNcK2JUmSetpwZhzoA3bqQjDaFri/3/pS4DVDHVNVK5I8BkwBHhnlWiRpnbz6r0b636tjb+Enj2m7BEnDMJyQdjvwIuDBLteyzpLMAeYATJ8+fdBjJvpfShP9+sbbP4QT/c9Dq/lnLalb1jQExzfoPNbcDLgzyQ+B367cX1WHjrDtB4Dt+q1Pa7YNdszSJBsCfwQsH3iiqjofOB+gr6/PR6GSJGncW9OdtLO63PZNwPZJZtIJY0cCfz7gmHnAscAPgCOAq+yPJkmS1gdrmhbq+91suOljdhLwHWAS8NmquiPJ3wILqmoecCHwhST3AI/SCXKSJEkT3nDe7nyCzmPP/h6jM3fnB1a+nbkuquqbwDcHbPtIv+WngD9b1/NLkiSNV8N5ceCf6bx5eTGdWQeOpDMkx83AZ4HXd6k2SZKk9dZaJ1gHDq2qf62qJ6rq8aaT/huq6svAFl2uT5Ikab00nJD26yRvTbJB8/NW4Klmn534JUmSumA4Ie3twNHAw8BDzfI7kjwPOKmLtUmSJK23hjN358+ANw2x+7rRLUeSJEmw5sFsP1RV/5DkHAZ5rFlVf9nVyiRJktZja7qTtrj5XDAWhUiSJGm1NQ1m+43m8yKAJJtW1a/HqjBJkqT12VpfHEjy2iR3Anc1669M8i9dr0ySJGk9Npy3O/8ZeAPNxOZVdQuwbxdrkiRJWu8NJ6RRVfcP2PR0F2qRJElSYzjTQt2fZE+gkmwEnMzqlwokSZLUBcO5k3YCcCKwLfAAMLtZlyRJUpcM507ak1X19q5XIkmSpFWGE9JuT/IQcG3zc11VPdbdsiRJktZva33cWVUvA44CbgMOAW5JsqjLdUmSJK3X1nonLck0YC9gH+CVwB04Z6ckSVJXDedx533ATcAnquqELtcjSZIkhvd2567A54E/T/KDJJ9P8u4u1yVJkrReG06ftFuAi4D/BVwFvA74yEgaTbJlkiuS3N18bjHEcd9O8qsk/3sk7UmSJI03w5m7cwHwA+DNdAax3beqXjLCdk8D5lfV9sD8Zn0wnwSOHmFbkiRJ485w+qS9saqWjXK7hwGvb5YvAq4G/nrgQVU1P8nrB26XJEma6Ia8k5bkHUk2GCqgJXlpkr3Xsd2tq+rBZvkXwNbreB5JkqQJaU130qYAP0qyEFgILAMmAy+j0y/tEYZ+TEmSK4EXDbLrw/1XqqqS1HOse2Bbc4A5ANOnTx/JqSRJknrCkCGtqs5O8hlgPzrjpM0CfkOnX9rRVXXfmk5cVQcMtS/JQ0m2qaoHk2wDPLxO1a9u63zgfIC+vr4RBT5JkqResMY+aVX1NHBF8zOa5gHHAmc2n5eP8vklSZLGteGMk9YNZwIHJrkbOKBZJ0lfkgtWHpTkWuCrwP5JliZ5QyvVSpIkjbHhvN056qpqObD/INsXAMf3W99nLOuSJEnqFW3dSZMkSdIaDGcw262TXJjkW836Tk4LJUmS1F3DuZP2OeA7wIub9Z8Ap3SpHkmSJDG8kDa1qr4CPANQVSuAp7talSRJ0npuOCHtv5JMAQogyR7AY12tSpIkaT03nLc7T6UzrtlLk1wPbAUc0dWqJEmS1nNrDWlVdXOS1wGvAAL8uKp+3/XKJEmS1mNrDWlJjhmw6VVJqKrPd6kmSZKk
9d5wHnfu1m95Mp1BaG8GDGmSJEldMpzHne/rv57kj4FLulWQJEmS1m3Ggf8CZo52IZIkSVptOH3SvkEz/AadULcT8JVuFiVJkrS+G06ftLP6La8A7q2qpV2qR5IkSQyvT9r3x6IQSZIkrTZkSEvyBKsfcz5rF1BVtXnXqpIkSVrPDRnSqmqzsSxEkiRJqw2nTxoASV5IZ5w0AKrqvq5UJEmSpLUPwZHk0CR3Az8Hvg8sAb7V5bokSZLWa8MZJ+3vgD2An1TVTDozDtzQ1aokSZLWc8MJab+vquXABkk2qKrvAX0jaTTJlkmuSHJ387nFIMfMTvKDJHckuTXJ20bSpiRJ0ngynJD2qyQvAK4BvpTkbDqzDozEacD8qtoemN+sD/Rr4Jiq2hk4GPjnZkoqSZKkCW84Ie0wOoHp/cC3gZ8Cbxphu4cBFzXLFwGHDzygqn5SVXc3y/8JPAxsNcJ2JUmSxoXhvN35HuDLVfUAq4PVSG1dVQ82y78Atl7TwUl2BzamExAlSZImvOGEtM2A7yZ5FPgy8NWqemhtX0pyJfCiQXZ9uP9KVVWSwQbNXXmebYAvAMdW1TNDHDMHmAMwffr0tZUmSZLU84YzLdTHgI8lmQW8Dfh+kqVVdcBavjfk/iQPJdmmqh5sQtjDQxy3OfDvwIerasg3SqvqfOB8gL6+viEDnyRJ0ngxnD5pKz1M59HkcuCFI2x3HnBss3wscPnAA5JsDPwb8PmqunSE7UmSJI0rwxnM9r1JrqbzFuYU4C+qatYI2z0TOLAZJPeAZp0kfUkuaI55K7AvcFySRc3P7BG2K0mSNC4Mp0/adsApVbVotBptxl3bf5DtC4Djm+UvAl8crTYlSZLGk+H0SfubsShEkiRJqz2XPmmSJEkaI4Y0SZKkHmRIkyRJ6kGGNEmSpB5kSJMkSepBhjRJkqQeZEiTJEnqQYY0SZKkHmRIkyRJ6kGGNEmSpB5kSJMkSepBhjRJkqQeZEiTJEnqQYY0SZKkHmRIkyRJ6kGGNEmSpB5kSJMkSepBhjRJkqQe1EpIS7JlkiuS3N18bjHIMS9JcnOSRUnuSHJCG7VKkiS1oa07aacB86tqe2B+sz7Qg8Brq2o28BrgtCQvHrsSJUmS2tNWSDsMuKhZvgg4fOABVfW7qvpts7oJPpqVJEnrkbaCz9ZV9WCz/Atg68EOSrJdkluB+4G/r6r/HKsCJUmS2rRht06c5ErgRYPs+nD/laqqJDXYOarqfmBW85jz60kuraqHBmlrDjAHYPr06SOuXZIkqW1dC2lVdcBQ+5I8lGSbqnowyTbAw2s5138muR3YB7h0kP3nA+cD9PX1DRr4JEmSxpO2HnfOA45tlo8FLh94QJJpSZ7XLG8B7A38eMwqlCRJalFbIe1M4MAkdwMHNOsk6UtyQXPMjsCNSW4Bvg+cVVW3tVKtJEnSGOva4841qarlwP6DbF8AHN8sXwHMGuPSJEmSeoLDWkiSJPUgQ5okSVIPMqRJkiT1IEOaJElSDzKkSZIk9SBDmiRJUg8ypEmSJPUgQ5okSVIPMqRJkiT1IEOaJElSDzKkSZIk9SBDmiRJUg8ypEmSJPUgQ5okSVIPMqRJkiT1IEOaJElSDzKkSZIk9SBDmiRJUg8ypEmSJPWgVkJaki2TXJHk7uZzizUcu3mSpUk+M5Y1SpIktamtO2mnAfOrantgfrM+lL8DrhmTqiRJknpEWyHtMOCiZvki4PDBDkryamBr4LtjU5YkSVJvaCukbV1VDzbLv6ATxJ4lyQbAPwIfHMvCJEmSesGG3TpxkiuBFw2y68P9V6qqktQgx70X+GZVLU2ytrbmAHMApk+fvm4FS5Ik9ZCuhbSqOmCofUkeSrJNVT2YZBvg4UEOey2wT5L3Ai8ANk7yZFX9Qf+1qjofOB+gr69vsMAnSZI0rnQtpK3FPOBY4Mzm8/KBB1TV21cuJzkO6BssoEmSJE1EbfVJOxM4MMndwAHNOkn6klzQUk2S
JEk9o5U7aVW1HNh/kO0LgOMH2f454HNdL0ySJKlHOOOAJElSDzKkSZIk9SBDmiRJUg8ypEmSJPUgQ5okSVIPMqRJkiT1oLYGs5Wek4WfPKbtEiRJGlPeSZMkSepBhjRJkqQeZEiTJEnqQYY0SZKkHmRIkyRJ6kGGNEmSpB5kSJMkSepBhjRJkqQeZEiTJEnqQamqtmsYVUmWAfeOYZNTgUfGsL2x5vWNb17f+DWRrw28vvHO6xs9L6mqrQbbMeFC2lhLsqCq+tquo1u8vvHN6xu/JvK1gdc33nl9Y8PHnZIkST3IkCZJktSDDGkjd37bBXSZ1ze+eX3j10S+NvD6xjuvbwzYJ02SJKkHeSdNkiSpBxnS1lGSzyZ5OMntbdfSDUm2S/K9JHcmuSPJyW3XNJqSTE7ywyS3NNf3sbZrGm1JJiX5UZL/3XYtoy3JkiS3JVmUZEHb9Yy2JH+c5NIkdyVZnOS1bdc0WpK8ovlzW/nzeJJT2q5rtCR5f/N3yu1J5iaZ3HZNoynJyc213TER/twG+7c8yZZJrkhyd/O5RVv1GdLW3eeAg9suootWAB+oqp2APYATk+zUck2j6bfAflX1SmA2cHCSPdotadSdDCxuu4gu+tOqmt0Lr8l3wdnAt6tqB+CVTKA/x6r6cfPnNht4NfBr4N/arWp0JNkW+Eugr6r+BJgEHNluVaMnyZ8AfwHsTud/l/89ycvarWrEPscf/lt+GjC/qrYH5jfrrTCkraOqugZ4tO06uqWqHqyqm5vlJ+j8I7Ftu1WNnup4slndqPmZMB00k0wDDgEuaLsWPTdJ/gjYF7gQoKp+V1W/arWo7tkf+GlVjeUA5N22IfC8JBsCmwL/2XI9o2lH4Maq+nVVrQC+D/yfLdc0IkP8W34YcFGzfBFw+FjW1J8hTWuVZAawK3Bjy6WMquZx4CLgYeCKqppI1/fPwIeAZ1quo1sK+G6ShUnmtF3MKJsJLAP+V/O4+oIkz2+7qC45EpjbdhGjpaoeAM4C7gMeBB6rqu+2W9Wouh3YJ8mUJJsC/wewXcs1dcPWVfVgs/wLYOu2CjGkaY2SvAD4GnBKVT3edj2jqaqebh65TAN2b27lj3tJ/jvwcFUtbLuWLtq7ql4FvJHOo/h92y5oFG0IvAr4n1W1K/BftPi4pVuSbAwcCny17VpGS9N36TA6QfvFwPOTvKPdqkZPVS0G/h74LvBtYBHwdJs1dVt1hsBo7SmLIU1DSrIRnYD2paq6rO16uqV5lPQ9Jk4fw72AQ5MsAS4B9kvyxXZLGl3NHQuq6mE6/Zl2b7eiUbUUWNrvzu6ldELbRPNG4OaqeqjtQkbRAcDPq2pZVf0euAzYs+WaRlVVXVhVr66qfYFfAj9pu6YueCjJNgDN58NtFWJI06CShE6fmMVV9U9t1zPakmyV5I+b5ecBBwJ3tVrUKKmqv6mqaVU1g87jpKuqasL813yS5yfZbOUycBCdxzATQlX9Arg/ySuaTfsDd7ZYUrccxQR61Nm4D9gjyabN36H7M4Fe+gBI8sLmczqd/mgXt1tRV8wDjm2WjwUub6uQDdtqeLxLMhd4PTA1yVLgo1V1YbtVjaq9gKOB25p+WwD/d1V9s72SRtU2wEVJJtH5j5WvVNWEG6pigtoa+LfOv4FsCFxcVd9ut6RR9z7gS80jwZ8B72y5nlHVhOsDgfe0Xctoqqobk1wK3EznDfkf0SMj14+iryWZAvweOHG8v9Qy2L/lwJnAV5K8G7gXeGtr9TnjgCRJUu/xcackSVIPMqRJkiT1IEOaJElSDzKkSZIk9SBDmiRJUg8ypElaoySHJ6kkO/TbNiPJb5IsSnJnks83gx93q4ZNklzZtPe2JPskuaNZ37YZ9mBN378gyU7r2Pbrkww5IGnz+/lIs3x6kg+uSztjJcnVSdZ5UvpmjMGJNuSJ1JMMaZLW5ijguuazv58202rtQmdqrW6OJbQrQFXNrqovA28HzmjWH6iqI9b05ao6
vqrWdUDY17PmUeM/BPzLOp57XEmyYVUtAx5Mslfb9UgTnSFN0pCauVv3Bt5NZ/aCP1BVTwM/BLYd4hzHJLk1yS1JvtBsm5Hkqmb7/Gb08pV3ab6W5KbmZ69mhPMvArs1d87eQycQ/l2SLzXnur35/qQkZyW5vTn3+5rtq+4eJTkoyQ+S3Jzkq801kmRJko81229LskOSGcAJwPubtvcZcG0vB35bVY8Mct1/0VzDLc01bdpsf2mSG5o2Pp7kyUG+OyPJ4iT/X3PH8LvNzBgDr2VqM/0XSY5L8vUkVzTXclKSU9OZpP2GJFv2a+Lo5npuT7J78/3nJ/lskh823zms33nnJbkKmN98/+t0grKkLjKkSVqTw4BvV9VPgOVJXj3wgCSTgdfQmXB54L6dgf8B7FdVrwRObnadA1xUVbOALwGfbrafDXyqqnYD3gJc0MzPeTxwbXPn7F/pTNvyV1U1MCjMAWYAs/udu389U5t6DmgmaF8AnNrvkEea7f8T+GBVLQHOa2qaXVXXDmhvLzqjyw/msqrarbnuxXSC7sprPLuqdqEzT+dQtgfOraqdgV81v4+1+RM6U/XsBvy/wK+bSdp/ABzT77hNm7ug7wU+22z7MJ0pxHYH/hT4ZDMzAHTmDj2iql7XrC8AnhVYJY0+p4WStCZH0QkV0Jms/ShgYbP+0nSmDJsJ/HtV3TrI9/cDvrryTlNVPdpsfy2dMAHwBeAfmuUDgJ3SmfIJYPOVd7qG6QDgvKpaMaC9lfYAdgKub9rYmE6AWemy5nNhv/rWZBtg2RD7/iTJx4E/Bl4AfKfZ/lrg8Gb5YuCsIb7/86pa1K+eGcOo53tV9QTwRJLHgG80228DZvU7bi5AVV2TZPN05rE9CDi0X5+6ycD0ZvmKAb/Lh4EXD6MeSSNgSJM0qObx2H7ALkkKmARUkr9qDvlpVc1u7k5dn+RQOnMVrgwG561DsxsAe1TVUwNqWadrGEToBI6B/etW+m3z+TTD+/vxN8AfDbHvc8DhVXVLkuPo9G17Ln7bb/lp4HnN8gpWPwWZvIbvPNNv/RmefT0D5wMsOr+bt1TVj/vvSPIa4L8GHD+ZzrVL6iIfd0oayhHAF6rqJVU1o6q2A37OgMdczV2y04C/qar7m8eCs6vqPOAq4M/SmZB5ZfAD+A9W93F7O7DyMeJ36UwuTnP87OdY8xXAe5JsOKC9lW4A9krysmb/85t+ZWvyBLDZEPsWAy8bYt9mdDrYb8Sz+2/dwOpHl4P281uLJcDKx85rfGFiDd4GkGRv4LGqeozOnb73pUnESXZdw/dfDty+jm1LGiZDmqShHAX824BtX+MP3/KETkfyTQd2rK+qO+j0jfp+kluAf2p2vQ94Z5JbgaNZ3VftL4G+ptP/nXQ67T8XFwD3Abc27f35gHqWAccBc5u2fwDsMPAkA3wDePNgLw4A1wC7ZvBbff8PcCNwPXBXv+2nAKc27b8MeGwY19XfWcD/leRHwNTn+N2Vnmq+fx6r+8r9HbARnd/dHc36UP4U+Pd1bFvSMKVq4F1vSdJwJTkb+EZVXTnM4zcFflNVleRI4KiqOqyrRY6yJNcAh1XVL9uuRZrI7JMmSSPzCTpvtw7Xq4HPNHfffgW8qxtFdUuSrYB/MqBJ3eedNEmSpB5knzRJkqQeZEiTJEnqQYY0SZKkHmRIkyRJ6kGGNEmSpB5kSJMkSepB/z+6OpguWFsORQAAAABJRU5ErkJggg==\n",
211 | "text/plain": [
212 | ""
213 | ]
214 | },
215 | "metadata": {
216 | "needs_background": "light"
217 | },
218 | "output_type": "display_data"
219 | }
220 | ],
221 | "source": [
222 | "m.plot_weights()"
223 | ]
224 | }
225 | ],
226 | "metadata": {
227 | "kernelspec": {
228 | "display_name": "light-matter",
229 | "language": "python",
230 | "name": "light-matter"
231 | },
232 | "language_info": {
233 | "codemirror_mode": {
234 | "name": "ipython",
235 | "version": 3
236 | },
237 | "file_extension": ".py",
238 | "mimetype": "text/x-python",
239 | "name": "python",
240 | "nbconvert_exporter": "python",
241 | "pygments_lexer": "ipython3",
242 | "version": "3.8.6"
243 | }
244 | },
245 | "nbformat": 4,
246 | "nbformat_minor": 4
247 | }
248 |
--------------------------------------------------------------------------------
/example_notebooks/legacy_run_experiments.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import pandas as pd\n",
11 | "import os\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "\n",
14 | "import arnet\n",
15 | "from arnet.ar_net_legacy import init_ar_learner\n",
16 | "\n",
17 | "%matplotlib inline"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 3,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "# hyperparameters\n",
27 | "save = True\n",
28 | "sparsity = 0.5 # guesstimate\n",
29 | "ar_order = 10 # guesstimate\n",
30 | "\n",
31 | "n_epoch = 10\n",
32 | "valid_p = 0.1"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "# data\n",
42 | "DIR = os.path.dirname(os.path.abspath(''))\n",
43 | "data_path = os.path.join(DIR, 'ar_data')\n",
44 | "results_path = os.path.join(DIR, 'results')\n",
45 | "models_path = os.path.join(DIR, 'models')\n",
46 | "if not os.path.exists(results_path): \n",
47 | " os.makedirs(results_path) \n",
48 | "if not os.path.exists(models_path): \n",
49 | " os.makedirs(models_path)\n",
50 | " \n",
51 | "data_names = ['ar_3_ma_0_noise_0.100_len_10000'] "
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 4,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "fitting: ar_3_ma_0_noise_0.100_len_10000\n"
64 | ]
65 | },
66 | {
67 | "data": {
68 | "text/html": [],
69 | "text/plain": [
70 | ""
71 | ]
72 | },
73 | "metadata": {},
74 | "output_type": "display_data"
75 | },
76 | {
77 | "data": {
78 | "image/png": "\n",
79 | "text/plain": [
80 | ""
81 | ]
82 | },
83 | "metadata": {
84 | "needs_background": "light"
85 | },
86 | "output_type": "display_data"
87 | },
88 | {
89 | "name": "stdout",
90 | "output_type": "stream",
91 | "text": [
92 | "lr at minimum: 0.05623413324356079; steeptes lr: 0.020417379215359688\n"
93 | ]
94 | },
95 | {
96 | "data": {
97 | "text/html": [
98 | "\n",
99 | " \n",
100 | " \n",
101 | " | epoch | \n",
102 | " train_loss | \n",
103 | " valid_loss | \n",
104 | " mae | \n",
105 | " time | \n",
106 | "
\n",
107 | " \n",
108 | " \n",
109 | " \n",
110 | " | 0 | \n",
111 | " 0.020248 | \n",
112 | " 0.013306 | \n",
113 | " 0.091418 | \n",
114 | " 00:01 | \n",
115 | "
\n",
116 | " \n",
117 | " | 1 | \n",
118 | " 0.010328 | \n",
119 | " 0.010245 | \n",
120 | " 0.080722 | \n",
121 | " 00:01 | \n",
122 | "
\n",
123 | " \n",
124 | " | 2 | \n",
125 | " 0.010258 | \n",
126 | " 0.010097 | \n",
127 | " 0.080104 | \n",
128 | " 00:01 | \n",
129 | "
\n",
130 | " \n",
131 | " | 3 | \n",
132 | " 0.010552 | \n",
133 | " 0.010240 | \n",
134 | " 0.080945 | \n",
135 | " 00:01 | \n",
136 | "
\n",
137 | " \n",
138 | " | 4 | \n",
139 | " 0.010190 | \n",
140 | " 0.010257 | \n",
141 | " 0.080975 | \n",
142 | " 00:01 | \n",
143 | "
\n",
144 | " \n",
145 | " | 5 | \n",
146 | " 0.010604 | \n",
147 | " 0.010100 | \n",
148 | " 0.080252 | \n",
149 | " 00:01 | \n",
150 | "
\n",
151 | " \n",
152 | " | 6 | \n",
153 | " 0.010374 | \n",
154 | " 0.010058 | \n",
155 | " 0.079831 | \n",
156 | " 00:01 | \n",
157 | "
\n",
158 | " \n",
159 | " | 7 | \n",
160 | " 0.010370 | \n",
161 | " 0.010076 | \n",
162 | " 0.079970 | \n",
163 | " 00:01 | \n",
164 | "
\n",
165 | " \n",
166 | " | 8 | \n",
167 | " 0.010277 | \n",
168 | " 0.010079 | \n",
169 | " 0.079972 | \n",
170 | " 00:01 | \n",
171 | "
\n",
172 | " \n",
173 | " | 9 | \n",
174 | " 0.010468 | \n",
175 | " 0.010070 | \n",
176 | " 0.079924 | \n",
177 | " 00:01 | \n",
178 | "
\n",
179 | " \n",
180 | "
"
181 | ],
182 | "text/plain": [
183 | ""
184 | ]
185 | },
186 | "metadata": {},
187 | "output_type": "display_data"
188 | },
189 | {
190 | "data": {
191 | "image/png": "\n",
192 | "text/plain": [
193 | ""
194 | ]
195 | },
196 | "metadata": {
197 | "needs_background": "light"
198 | },
199 | "output_type": "display_data"
200 | }
201 | ],
202 | "source": [
203 | "# fit and collect AR coefficients\n",
204 | "coeff_list = []\n",
205 | "for name in data_names:\n",
206 | " print(\"fitting: {}\".format(name))\n",
207 | " df = pd.read_csv(os.path.join(data_path, name + '.csv'), header=None, index_col=False)\n",
208 | " # Init Model\n",
209 | " learn = init_ar_learner(\n",
210 | " series=df,\n",
211 | " ar_order=ar_order,\n",
212 | " n_forecasts=1,\n",
213 | " valid_p=valid_p,\n",
214 | " sparsity=sparsity,\n",
215 | " train_bs=32,\n",
216 | " valid_bs=1024,\n",
217 | " verbose=False,\n",
218 | " )\n",
219 | " # find Learning Rate\n",
220 | " lr_at_min, lr_steep = learn.lr_find(start_lr=1e-7, end_lr=100, num_it=100, show_plot=True)\n",
221 | " plt.show()\n",
222 | " print(\"lr at minimum: {}; steepest lr: {}\".format(lr_at_min, lr_steep))\n",
223 | " lr_max = lr_at_min/10\n",
224 | "\n",
225 | " # Run Model\n",
226 | " learn.fit_one_cycle(n_epoch=n_epoch, lr_max=lr_max)\n",
227 | " learn.recorder.plot_loss()\n",
228 | " plt.show()\n",
229 | " # record Coeff\n",
230 | " coeff = arnet.coeff_from_model(learn.model) \n",
231 | " coeff_list.append({\"name\": name, \"coeff\": coeff[0]})\n",
232 | "df_coeff = pd.DataFrame(coeff_list)\n"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": 5,
238 | "metadata": {},
239 | "outputs": [
240 | {
241 | "data": {
242 | "text/html": [
243 | "\n",
244 | "\n",
257 | "
\n",
258 | " \n",
259 | " \n",
260 | " | \n",
261 | " name | \n",
262 | " coeff | \n",
263 | "
\n",
264 | " \n",
265 | " \n",
266 | " \n",
267 | " | 0 | \n",
268 | " ar_3_ma_0_noise_0.100_len_10000 | \n",
269 | " [0.19604042, 0.2887145, -0.4777012, 0.0003959917, 1.7987082e-05, 3.091372e-05, 0.00032146755, -1.1136071e-05, 2.8788853e-07, -2.827955e-06] | \n",
270 | "
\n",
271 | " \n",
272 | "
\n",
273 | "
"
274 | ],
275 | "text/plain": [
276 | " name \\\n",
277 | "0 ar_3_ma_0_noise_0.100_len_10000 \n",
278 | "\n",
279 | " coeff \n",
280 | "0 [0.19604042, 0.2887145, -0.4777012, 0.0003959917, 1.7987082e-05, 3.091372e-05, 0.00032146755, -1.1136071e-05, 2.8788853e-07, -2.827955e-06] "
281 | ]
282 | },
283 | "execution_count": 5,
284 | "metadata": {},
285 | "output_type": "execute_result"
286 | }
287 | ],
288 | "source": [
289 | "df_coeff"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": 6,
295 | "metadata": {},
296 | "outputs": [
297 | {
298 | "data": {
299 | "image/png": "\n",
300 | "text/plain": [
301 | ""
302 | ]
303 | },
304 | "metadata": {
305 | "needs_background": "light"
306 | },
307 | "output_type": "display_data"
308 | }
309 | ],
310 | "source": [
311 | "if save:\n",
312 | " df_coeff.to_csv(\n",
313 | " os.path.join(results_path, \"coeff_ar-{}_spar-{}.csv\".format(ar_order, sparsity)),\n",
314 | " index=False,\n",
315 | " )\n",
316 | " for index, row in df_coeff.iterrows():\n",
317 | " arnet.plot_weights(\n",
318 | " ar_val=ar_order,\n",
319 | " weights=row[\"coeff\"],\n",
320 | " model_name=row[\"name\"],\n",
321 | " save=True,\n",
322 | " savedir=results_path,\n",
323 | " figsize=(10,3),\n",
324 | " )"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": null,
330 | "metadata": {},
331 | "outputs": [],
332 | "source": []
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": null,
337 | "metadata": {},
338 | "outputs": [],
339 | "source": []
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": null,
344 | "metadata": {},
345 | "outputs": [],
346 | "source": []
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": null,
351 | "metadata": {},
352 | "outputs": [],
353 | "source": []
354 | }
355 | ],
356 | "metadata": {
357 | "kernelspec": {
358 | "display_name": "light-matter",
359 | "language": "python",
360 | "name": "light-matter"
361 | },
362 | "language_info": {
363 | "codemirror_mode": {
364 | "name": "ipython",
365 | "version": 3
366 | },
367 | "file_extension": ".py",
368 | "mimetype": "text/x-python",
369 | "name": "python",
370 | "nbconvert_exporter": "python",
371 | "pygments_lexer": "ipython3",
372 | "version": "3.8.6"
373 | }
374 | },
375 | "nbformat": 4,
376 | "nbformat_minor": 4
377 | }
378 |
--------------------------------------------------------------------------------