├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── csv-plugin.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── sysid-transfer-functions-pytorch.iml
│   └── vcs.xml
├── LICENSE
├── README.md
├── doc
│   ├── paper
│   │   ├── fig
│   │   │   ├── PWH_timetrace.pdf
│   │   │   ├── WH_H.pdf
│   │   │   ├── WH_timetrace.pdf
│   │   │   ├── WH_timetrace_zoom.pdf
│   │   │   ├── backprop_tf_ab.pdf
│   │   │   ├── backprop_tf_ab.svg
│   │   │   ├── backprop_tf_ab_circle.pdf
│   │   │   ├── backprop_tf_ab_circle.svg
│   │   │   ├── backprop_tf_ab_stable.pdf
│   │   │   ├── backprop_tf_ab_stable.svg
│   │   │   ├── backprop_tf_ab_stable_mod.pdf
│   │   │   ├── backprop_tf_ab_stable_mod.svg
│   │   │   ├── coeff_2ndorder.svg
│   │   │   ├── dynonet_quant.pdf
│   │   │   ├── dynonet_quant.svg
│   │   │   ├── dynonet_quant_old.svg
│   │   │   ├── generalized_HW.pdf
│   │   │   ├── hammer.pdf
│   │   │   ├── hammerstein_wiener.pdf
│   │   │   ├── neural_PEM.pdf
│   │   │   ├── neural_PEM.svg
│   │   │   ├── neural_PEM_old.svg
│   │   │   ├── parallel_WH.pdf
│   │   │   ├── rho.svg
│   │   │   ├── rhopsi.png
│   │   │   ├── stable_2ndorder.pdf
│   │   │   ├── sym
│   │   │   │   ├── F.svg
│   │   │   │   ├── F1.svg
│   │   │   │   ├── G1.svg
│   │   │   │   ├── G2.svg
│   │   │   │   ├── GH.svg
│   │   │   │   ├── GW.svg
│   │   │   │   ├── Gz.svg
│   │   │   │   ├── H_hat.svg
│   │   │   │   ├── H_inv.svg
│   │   │   │   ├── H_inv_check.svg
│   │   │   │   ├── H_inv_par.svg
│   │   │   │   ├── H_star.svg
│   │   │   │   ├── L_k.svg
│   │   │   │   ├── L_k_theta.svg
│   │   │   │   ├── L_t_theta.svg
│   │   │   │   ├── Lcal.svg
│   │   │   │   ├── M_par.svg
│   │   │   │   ├── NN.svg
│   │   │   │   ├── U_k.svg
│   │   │   │   ├── U_t.svg
│   │   │   │   ├── ab.svg
│   │   │   │   ├── dots.svg
│   │   │   │   ├── eps_t_theta.svg
│   │   │   │   ├── epsilon.svg
│   │   │   │   ├── epsilon_hat.svg
│   │   │   │   ├── epsilon_k_theta.svg
│   │   │   │   ├── epsilon_t_theta_comma.svg
│   │   │   │   ├── epsilon_t_theta_vec.png
│   │   │   │   ├── et.svg
│   │   │   │   ├── minus.svg
│   │   │   │   ├── plus.svg
│   │   │   │   ├── sigma_e.svg
│   │   │   │   ├── thetabar.svg
│   │   │   │   ├── ut.svg
│   │   │   │   ├── uvec.svg
│   │   │   │   ├── uvecbar.svg
│   │   │   │   ├── v_hat.svg
│   │   │   │   ├── v_meas.svg
│   │   │   │   ├── wt.svg
│   │   │   │   ├── xt.svg
│   │   │   │   ├── y_hat.svg
│   │   │   │   ├── y_ol.svg
│   │   │   │   ├── y_sim.svg
│   │   │   │   ├── y_sim_t.svg
│   │   │   │   ├── y_t.svg
│   │   │   │   ├── y_t_vec.svg
│   │   │   │   ├── yt.svg
│   │   │   │   ├── yvec.svg
│   │   │   │   ├── yvecbar.svg
│   │   │   │   ├── z_k.svg
│   │   │   │   ├── z_t.svg
│   │   │   │   └── zt.svg
│   │   │   ├── transf.svg
│   │   │   ├── wiener.pdf
│   │   │   └── wiener_hammerstein.pdf
│   │   ├── ms.bib
│   │   ├── ms.pdf
│   │   └── ms.tex
│   └── slides
│       ├── preamble.tex
│       ├── presentation_main.pdf
│       └── presentation_main.tex
├── examples
│   ├── ParWH
│   │   ├── __init__.py
│   │   ├── models.py
│   │   ├── models
│   │   │   └── PWH
│   │   │       └── PWH.pt
│   │   ├── parWH_plot_signals.py
│   │   ├── parWH_test.py
│   │   ├── parWH_train_NLSQ.py
│   │   └── parWH_train_quant_ML.py
│   ├── WH2009
│   │   ├── WH2009_test.py
│   │   ├── WH2009_train.py
│   │   ├── WH2009_train_colored_noise_PEM.py
│   │   ├── WH2009_train_process_noise.py
│   │   ├── WH2009_train_quantized.py
│   │   └── __init__.py
│   └── __init__.py
├── fig
│   ├── dynonet_quant.png
│   ├── dynonet_quant.svg
│   ├── neural_PEM.png
│   └── neural_PEM.svg
├── sphinx
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── code.rst
│       ├── conf.py
│       └── index.rst
├── torchid
│   ├── __init__.py
│   ├── functional
│   │   ├── __init__.py
│   │   └── lti.py
│   └── module
│       ├── __init__.py
│       ├── lti.py
│       └── static.py
├── torchid_nb
│   ├── __init__.py
│   ├── functional
│   │   ├── __init__.py
│   │   └── lti.py
│   └── module
│       ├── __init__.py
│       ├── lti.py
│       └── static.py
└── util
    ├── __init__.py
    ├── filtering.py
    └── metrics.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /examples/WH2009/data/
2 | /examples/WH2009/models/
3 | /examples/WH2009/fig/
4 | /examples/ParWH/data/
5 | /examples/ParWH/models
6 | /examples/ParWH/fig
7 | /sphinx/build/
8 |
9 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Marco Forgione
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep learning with transfer functions: New applications in system identification

This repository contains the Python code to reproduce the results of the paper [Deep learning with transfer functions: new applications in system identification](https://arxiv.org/abs/2104.09839) by Dario Piga, Marco Forgione, and Manas Mejari.

We present a linear transfer function block endowed with a well-defined and efficient back-propagation behavior for automatic derivative computation. In the dynoNet architecture (already introduced [here](https://github.com/forgi86/dynonet)), linear dynamical operators are combined with static (i.e., memoryless) non-linearities, which can be elementary activation functions applied channel-wise, fully connected feed-forward neural networks, or other differentiable operators.
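
As a minimal sketch of the architecture in the repo's own building blocks (the block orders below are illustrative, cf. ``examples/WH2009/WH2009_train.py``):

```
# dynoNet-style Wiener-Hammerstein model: linear dynamics, static
# non-linearity, linear dynamics (orders chosen for illustration only)
import torch
from torchid.module.lti import SisoLinearDynamicalOperator
from torchid.module.static import SisoStaticNonLinearity

G1 = SisoLinearDynamicalOperator(n_b=8, n_a=8, n_k=1)          # linear block
F_nl = SisoStaticNonLinearity(n_hidden=10, activation='tanh')  # static block
G2 = SisoLinearDynamicalOperator(n_b=8, n_a=8, n_k=0)          # linear block

u = torch.randn(1, 1000, 1)  # (batch, time, channel)
y_hat = G2(F_nl(G1(u)))      # differentiable end to end
```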

In this work, we use the differentiable transfer function operator to tackle
other challenging problems in system identification. In particular, we consider the problems of:

1. Learning of neural dynamical models in the presence of colored noise (prediction error minimization method)
1. Learning of dynoNet models from quantized output observations (maximum likelihood estimation method)

Problem 1. is tackled by extending the prediction error minimization method to deep learning models. A trainable linear transfer function block
is used to describe the power spectrum of the noise:

![neural_PEM](fig/neural_PEM.png)
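
A hedged sketch of the training loss follows (the monic parametrization of the inverse noise model is an assumption made here for illustration; the actual script is ``WH2009_train_colored_noise_PEM.py``):

```
# PEM with a trainable noise model: the residual y - G(q)u is filtered
# through H^-1(q) = 1 + H_res(q) (assumed monic; H_res is strictly proper
# thanks to n_k=1), so the minimized prediction error is roughly white.
import torch
from torchid.module.lti import SisoLinearDynamicalOperator

G = SisoLinearDynamicalOperator(n_b=8, n_a=8, n_k=1)      # plant model
H_res = SisoLinearDynamicalOperator(n_b=2, n_a=2, n_k=1)  # trainable noise shaping

u = torch.randn(1, 2000, 1)  # placeholder data, (batch, time, channel)
y = torch.randn(1, 2000, 1)

v = y - G(u)        # residual (colored noise if H != 1)
eps = v + H_res(v)  # prediction error, approximately whitened
loss = torch.mean(eps ** 2)
loss.backward()     # gradients flow through both G and H_res
```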

Problem 2. is tackled by training a dynoNet model with a loss function corresponding to the log-likelihood of quantized observations:

![dynonet_quant](fig/dynonet_quant.png)
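
The likelihood of a quantized sample is the Gaussian probability mass of its quantization bin, as in ``WH2009_train_quantized.py`` (the helper below mirrors that script; the small tolerance term is added here for numerical safety):

```
# Negative log-likelihood of quantized observations: the probability that
# the noisy output falls in the bin [b_lo, b_hi] is
# Phi((b_hi - y_hat)/sigma) - Phi((b_lo - y_hat)/sigma).
import torch

def normal_standard_cdf(val):
    """Cumulative distribution function of a standard normal variable."""
    return 0.5 * (1 + torch.erf(val / 2 ** 0.5))

def quantized_nll(y_hat, bin_lo, bin_hi, sigma):
    prob = (normal_standard_cdf((bin_hi - y_hat) / sigma)
            - normal_standard_cdf((bin_lo - y_hat) / sigma))
    return -(prob + 1e-12).log().mean()
```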

# Folders:
* [torchid](torchid_nb): PyTorch implementation of the linear dynamical operator (aka G-block in the paper) used in dynoNet
* [examples](examples): examples using dynoNet for system identification
* [util](util): definition of the metrics R-squared, RMSE, and fit index (a sketch of the fit index follows this list)
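
The fit index has the standard definition below (``util/metrics.py`` itself is not reproduced in this dump, so the exact signature is an assumption):

```
# Percentage fit index: 100 * (1 - ||y - y_hat|| / ||y - mean(y)||)
import numpy as np

def fit_index(y, y_hat):
    return 100 * (1 - np.linalg.norm(y - y_hat) / np.linalg.norm(y - np.mean(y)))
```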

Two [examples](examples) discussed in the paper are:

* [WH2009](examples/WH2009): A circuit with a Wiener-Hammerstein structure. Experimental dataset from http://www.nonlinearbenchmark.org
* [Parallel Wiener-Hammerstein](examples/ParWH): A circuit with a two-branch parallel Wiener-Hammerstein structure. Experimental dataset from http://www.nonlinearbenchmark.org

For the [WH2009](examples/WH2009) example, the main scripts are:

* ``WH2009_train_colored_noise_PEM.py``: Training of a dynoNet model with the prediction error method in the presence of colored noise
* ``WH2009_test.py``: Evaluation of the dynoNet model on the original test dataset, computation of metrics, plots

For the [Parallel Wiener-Hammerstein](examples/ParWH) example, the main scripts are:

* ``parWH_train_quant_ML.py``: Training of a dynoNet model with maximum likelihood in the presence of quantized measurements
* ``parWH_test.py``: Evaluation of the dynoNet model on the original test dataset, computation of metrics, plots

NOTE: the original datasets are not included in this project. They have to be downloaded manually from
http://www.nonlinearbenchmark.org and copied into the data sub-folder of the example.

# Software requirements:
Simulations were performed on a Python 3.7 conda environment with

* numpy
* scipy
* matplotlib
* pandas
* numba
* pytorch (version 1.6)

These dependencies may be installed through the commands:

```
conda install numpy scipy pandas numba matplotlib
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
```

# Citing

If you find this project useful, we encourage you to

* Star this repository :star:
* Cite the [paper](https://onlinelibrary.wiley.com/doi/abs/10.1002/acs.3216)
```
@inproceedings{piga2021a,
  title={Deep learning with transfer functions: new applications in system identification},
  author={Piga, D. and Forgione, M. and Mejari, M.},
  booktitle={Proc. of the 19th IFAC Symposium System Identification: learning models for decision and control},
  year={2021}
}
```
--------------------------------------------------------------------------------
/doc/paper/fig/PWH_timetrace.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/PWH_timetrace.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/WH_H.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/WH_H.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/WH_timetrace.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/WH_timetrace.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/WH_timetrace_zoom.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/WH_timetrace_zoom.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/backprop_tf_ab.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/backprop_tf_ab.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/backprop_tf_ab_circle.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/backprop_tf_ab_circle.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/backprop_tf_ab_stable.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/backprop_tf_ab_stable.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/backprop_tf_ab_stable_mod.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/backprop_tf_ab_stable_mod.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/dynonet_quant.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/dynonet_quant.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/generalized_HW.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/generalized_HW.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/hammer.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/hammer.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/hammerstein_wiener.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/hammerstein_wiener.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/neural_PEM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/neural_PEM.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/parallel_WH.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/parallel_WH.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/rhopsi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/rhopsi.png
--------------------------------------------------------------------------------
/doc/paper/fig/stable_2ndorder.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/stable_2ndorder.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/sym/epsilon_t_theta_vec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/sym/epsilon_t_theta_vec.png
--------------------------------------------------------------------------------
/doc/paper/fig/wiener.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/wiener.pdf
--------------------------------------------------------------------------------
/doc/paper/fig/wiener_hammerstein.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/wiener_hammerstein.pdf
--------------------------------------------------------------------------------
/doc/paper/ms.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/ms.pdf
--------------------------------------------------------------------------------
/doc/slides/preamble.tex:
--------------------------------------------------------------------------------
1 | % If the output does not look nice, try deleting the line with the fontenc.
2 | \usepackage[english]{babel}
3 | \usepackage{amsmath}
4 | \usepackage[latin1]{inputenc}
5 | \usepackage{units}
6 | \usepackage{colortbl}
7 | \usepackage{multimedia}
8 | \usepackage{bm}
9 | \usepackage{subcaption}
10 | \usepackage{algorithm2e}
11 | \usepackage{algorithmic}
12 |
13 | \mode<presentation>
14 | {
15 | \usetheme{Boadilla}
16 | \useoutertheme{infolines}
17 | \setbeamercovered{transparent}
18 | }
19 |
20 |
21 | \title[Dynamical systems \& deep learning]{{Deep learning with transfer functions: New applications in system identification}}
22 |
23 |
24 | \author[]{Dario Piga, \underline{Marco Forgione}, Manas Mejari}
25 |
26 | \institute[IDSIA]{IDSIA Dalle Molle Institute for Artificial Intelligence USI-SUPSI, Lugano, Switzerland}
27 |
28 |
29 | \date[]{19th IFAC Symposium System Identification: learning models for decision and control}
30 |
31 |
32 | \subject{System Identification, Deep Learning, Machine Learning, Regularization}
33 |
34 |
35 | %% MATH DEFINITIONS %%
36 | \newcommand{\So}{S_o} % true system
37 | \newcommand{\hidden}[1]{\overline{#1}}
38 | \newcommand{\nsamp}{N}
39 | \newcommand{\Yid}{Y}
40 | \newcommand{\Uid}{U}
41 | \newcommand{\Did}{{\mathcal{D}}}
42 | \newcommand{\tens}[1]{\bm{#1}}
43 |
44 | \newcommand{\batchsize}{q}
45 | \newcommand{\seqlen}{m}
46 | \newcommand{\nin}{n_u}
47 | \newcommand{\ny}{n_y}
48 | \newcommand{\nx}{n_x}
49 |
50 | \newcommand{\NN}{\mathcal{N}} % a feedforward neural network
51 |
52 | \newcommand{\norm}[1]{\left \lVert #1 \right \rVert}
53 | \DeclareMathOperator*\argmin{arg \, min}
54 | \newcommand{\Name}{\emph{dynoNet}}
55 |
56 |
57 | %% DYNONET MATH DEFINITIONS %%
58 | \newcommand{\q}{q} % shift operator
59 | \newcommand{\A}{A} % autoregressive polynomial
60 | \newcommand{\ac}{a} % autoregressive polynomial coefficient
61 | \newcommand{\B}{B} % exogenous polynomial
62 | \newcommand{\bb}{b} % exogenous polynomial coefficient
63 | \newcommand{\Gmat}{\mathbb{G}} % transfer function operator in matrix form
64 | \newcommand{\tvec}[1]{\bm{#1}}
65 | \newcommand{\mat}[1]{\bm{#1}}
66 | \newcommand{\sens}[1]{\tilde{#1}}
67 | \newcommand{\adjoint}[1]{\overline{#1}}
68 | \newcommand{\loss}{\mathcal{L}}
69 | \newcommand{\pdiff}[2]{\frac{\partial #1}{\partial #2}}
70 | %\newcommand{\nsamp}{T}
71 |
72 | \newcommand{\conv}{*}
73 | \newcommand{\ccorr}{\star}
74 | \definecolor{orange}{RGB}{204, 85, 0}
75 | %% DYNONET %%
76 |
--------------------------------------------------------------------------------
/doc/slides/presentation_main.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/slides/presentation_main.pdf
--------------------------------------------------------------------------------
/examples/ParWH/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/examples/ParWH/__init__.py
--------------------------------------------------------------------------------
/examples/ParWH/models.py:
--------------------------------------------------------------------------------
import torch
from torchid_nb.module.lti import MimoLinearDynamicalOperator, SisoLinearDynamicalOperator
from torchid_nb.module.static import MimoStaticNonLinearity, MimoChannelWiseNonLinearity


class ParallelWHNet(torch.nn.Module):
    def __init__(self, nb_1=12, na_1=12, nb_2=13, na_2=12):
        super(ParallelWHNet, self).__init__()
        self.nb_1 = nb_1
        self.na_1 = na_1
        self.nb_2 = nb_2
        self.na_2 = na_2
        self.G1 = MimoLinearDynamicalOperator(1, 2, n_b=self.nb_1, n_a=self.na_1, n_k=1)
        self.F_nl = MimoChannelWiseNonLinearity(2, n_hidden=10)
        #self.F_nl = MimoStaticNonLinearity(2, 2, n_hidden=10)
        self.G2 = MimoLinearDynamicalOperator(2, 1, n_b=self.nb_2, n_a=self.na_2, n_k=0)
        #self.G3 = SisoLinearDynamicalOperator(n_b=3, n_a=3, n_k=1)

    def forward(self, u):
        y1_lin = self.G1(u)
        y1_nl = self.F_nl(y1_lin)  # B, T, C1
        y2_lin = self.G2(y1_nl)  # B, T, C2

        return y2_lin  #+ self.G3(u)


class ParallelWHNetVar(torch.nn.Module):
    def __init__(self):
        super(ParallelWHNetVar, self).__init__()
        self.nb_1 = 3
        self.na_1 = 3
        self.nb_2 = 3
        self.na_2 = 3
        self.G1 = MimoLinearDynamicalOperator(1, 16, n_b=self.nb_1, n_a=self.na_1, n_k=1)
        self.F_nl = MimoStaticNonLinearity(16, 16)  #MimoChannelWiseNonLinearity(16, n_hidden=10)
        self.G2 = MimoLinearDynamicalOperator(16, 1, n_b=self.nb_2, n_a=self.na_2, n_k=1)

    def forward(self, u):
        y1_lin = self.G1(u)
        y1_nl = self.F_nl(y1_lin)  # B, T, C1
        y2_lin = self.G2(y1_nl)  # B, T, C2

        return y2_lin


class ParallelWHResNet(torch.nn.Module):
    def __init__(self):
        super(ParallelWHResNet, self).__init__()
        self.nb_1 = 4
        self.na_1 = 4
        self.nb_2 = 4
        self.na_2 = 4
        self.G1 = MimoLinearDynamicalOperator(1, 2, n_b=self.nb_1, n_a=self.na_1, n_k=1)
        self.F_nl = MimoChannelWiseNonLinearity(2, n_hidden=10)
        self.G2 = MimoLinearDynamicalOperator(2, 1, n_b=self.nb_2, n_a=self.na_2, n_k=1)
        self.G3 = SisoLinearDynamicalOperator(n_b=6, n_a=6, n_k=1)

    def forward(self, u):
        y1_lin = self.G1(u)
        y1_nl = self.F_nl(y1_lin)  # B, T, C1
        y2_lin = self.G2(y1_nl)  # B, T, C2

        return y2_lin + self.G3(u)
/examples/ParWH/models/PWH/PWH.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/examples/ParWH/models/PWH/PWH.pt
--------------------------------------------------------------------------------
/examples/ParWH/parWH_plot_signals.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt

if __name__ == '__main__':

    N = 16384  # number of samples per period
    M = 20  # number of random phase multisine realizations
    P = 2  # number of periods
    nAmp = 5  # number of different amplitudes

    # Column names in the dataset
    COL_F = ['fs']
    TAG_U = 'u'
    TAG_Y = 'y'

    # Load dataset
    #df_X = pd.read_csv(os.path.join("data", "WH_CombinedZeroMultisineSinesweep.csv"))
    df_X = pd.read_csv(os.path.join("data", "ParWHData_Estimation_Level2.csv"))
    df_X.columns = ['amplitude', 'fs', 'lines'] + [TAG_U + str(i) for i in range(M)] + [TAG_Y + str(i) for i in range(M)] + ['?']

    # Extract data
    y = np.array(df_X['y0'], dtype=np.float32)
    u = np.array(df_X['u0'], dtype=np.float32)
    fs = np.array(df_X[COL_F].iloc[0], dtype=np.float32)
    N = y.size  # actual record length, overrides the constant above
    ts = 1/fs
    t = np.arange(N)*ts

    # In[Plot]
    fig, ax = plt.subplots(2, 1, sharex=True)
    ax[0].plot(t, y, 'k', label="$y$")
    ax[0].legend()
    ax[0].grid()

    ax[1].plot(t, u, 'k', label="$u$")
    ax[1].legend()
    ax[1].grid()
--------------------------------------------------------------------------------
/examples/ParWH/parWH_test.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np
import os
import matplotlib
import matplotlib.pyplot as plt
import torch
import util.metrics
from examples.ParWH.models import ParallelWHNet


if __name__ == '__main__':

    matplotlib.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica'], 'size': 11})

    model_name = "PWH_quant"

    # Dataset constants
    amplitudes = 5  # number of different amplitudes
    realizations = 20  # number of random phase multisine realizations
    samp_per_period = 16384  # number of samples per period
    n_skip = 1000
    periods = 1  # number of periods
    seq_len = samp_per_period * periods  # data points per realization

    # Column names in the dataset
    TAG_U = 'u'
    TAG_Y = 'y'

    # test_signal = "100mV"
    # test_signal = "325mV"
    # test_signal = "550mV"
    # test_signal = "775mV"
    # test_signal = "1000mV"
    test_signal = "ramp"

    plot_input = False

    # In[Load dataset]

    dict_test = {"100mV": 0, "325mV": 1, "550mV": 2, "775mV": 3, "1000mV": 4, "ramp": 5}
    dataset_list_level = ['ParWHData_Validation_Level' + str(i) for i in range(1, amplitudes + 1)]
    dataset_list = dataset_list_level + ['ParWHData_ValidationArrow']

    df_X_lst = []
    for dataset_name in dataset_list:
        dataset_filename = dataset_name + '.csv'
        df_Xi = pd.read_csv(os.path.join("data", dataset_filename))
        df_X_lst.append(df_Xi)

    df_X = df_X_lst[dict_test[test_signal]]  # dataset of the selected test signal

    # Extract data
    y_meas = np.array(df_X['y'], dtype=np.float32)
    u = np.array(df_X['u'], dtype=np.float32)
    fs = np.array(df_X['fs'].iloc[0], dtype=np.float32)
    N = y_meas.size
    ts = 1/fs
    t = np.arange(N)*ts

    # In[Set-up model]

    net = ParallelWHNet()
    model_folder = os.path.join("models", model_name)
    net.load_state_dict(torch.load(os.path.join(model_folder, f"{model_name}.pt")))
    #log_sigma_hat = torch.load(os.path.join(model_folder, "log_sigma_hat.pt"))
    #sigma_hat = torch.exp(log_sigma_hat) + 1e-3

    # In[Predict]
    u_torch = torch.tensor(u[None, :, None], dtype=torch.float, requires_grad=False)

    with torch.no_grad():
        y_hat = net(u_torch)

    # In[Detach]

    y_hat = y_hat.detach().numpy()[0, :, 0]

    # In[Plot]
    if plot_input:
        fig, ax = plt.subplots(2, 1, sharex=True)
        ax[0].plot(t, y_meas, 'k', label=r"$\mathbf{y}$")
        ax[0].plot(t, y_hat, 'b', label=r"$\mathbf{y}^{\rm sim}$")
        ax[0].plot(t, y_meas - y_hat, 'r', label=r"$\mathbf{e}$")
        ax[0].legend(loc="upper right")
        ax[0].set_ylabel("Voltage (V)")
        ax[0].grid()

        ax[1].plot(t, u, 'k', label="$u$")
        ax[1].legend(loc="upper right")
        ax[1].set_ylabel("Voltage (V)")
        ax[1].set_xlabel("Time (s)")
        ax[1].grid()
    else:
        fig, ax = plt.subplots(1, 1, figsize=(6, 3))
        ax.plot(t, y_meas, 'k', label=r"$\mathbf{y}$")
        ax.plot(t, y_hat, 'b', label=r"$\mathbf{y}^{\rm sim}$")
        ax.plot(t, y_meas - y_hat, 'r', label=r"$\mathbf{e}$")
        if test_signal == "ramp":
            ax.legend(loc="upper left")
        else:
            ax.legend(loc="upper right")
        ax.set_ylabel("Voltage (V)")
        ax.set_xlabel("Time (s)")
        ax.grid()

        if test_signal == "ramp":
            ax.set_xlim([0.0, 0.21])

    fig.tight_layout()
    fig_folder = "fig"
    if not os.path.exists(fig_folder):
        os.makedirs(fig_folder)
    fig.savefig(os.path.join(fig_folder, f"{model_name}_timetrace.pdf"))

    # In[Metrics]

    idx_test = range(n_skip, N)

    e_rms = 1000*util.metrics.error_rmse(y_meas[idx_test], y_hat[idx_test])
    mae = 1000*util.metrics.error_mae(y_meas[idx_test], y_hat[idx_test])
    fit_idx = util.metrics.fit_index(y_meas[idx_test], y_hat[idx_test])
    r_sq = util.metrics.r_squared(y_meas[idx_test], y_hat[idx_test])
    u_rms = 1000*util.metrics.error_rmse(u, 0)

    print(f"RMSE: {e_rms:.2f}mV\nMAE: {mae:.2f}mV\nFIT: {fit_idx:.1f}%\nR_sq: {r_sq:.1f}\nRMSU: {u_rms:.2f}mV")
/examples/WH2009/WH2009_train.py:
--------------------------------------------------------------------------------
import torch
import pandas as pd
import numpy as np
import os
from torchid.module.lti import SisoLinearDynamicalOperator
from torchid.module.static import SisoStaticNonLinearity
import matplotlib.pyplot as plt
import time
import torch.nn as nn

import util.metrics

# In[Main]
if __name__ == '__main__':

    # In[Set seed for reproducibility]
    np.random.seed(0)
    torch.manual_seed(0)

    # In[Settings]
    lr_ADAM = 2e-4
    lr_BFGS = 1e0
    num_iter_ADAM = 40000  # ADAM iterations
    num_iter_BFGS = 0  # final BFGS iterations
    msg_freq = 100
    n_skip = 5000
    n_fit = 20000
    decimate = 1
    n_batch = 1
    n_b = 8
    n_a = 8
    model_name = "model_WH"

    num_iter = num_iter_ADAM + num_iter_BFGS

    # In[Column names in the dataset]
    COL_F = ['fs']
    COL_U = ['uBenchMark']
    COL_Y = ['yBenchMark']

    # In[Load dataset]
    df_X = pd.read_csv(os.path.join("data", "WienerHammerBenchmark.csv"))

    # Extract data
    y = np.array(df_X[COL_Y], dtype=np.float32)  # batch, time, channel
    u = np.array(df_X[COL_U], dtype=np.float32)
    fs = np.array(df_X[COL_F].iloc[0], dtype=np.float32)
    N = y.size
    ts = 1/fs
    t = np.arange(N)*ts

    # In[Fit data]
    y_fit = y[0:n_fit:decimate]
    u_fit = u[0:n_fit:decimate]
    t_fit = t[0:n_fit:decimate]

    # In[Prepare training tensors]
    u_fit_torch = torch.tensor(u_fit[None, :, :], dtype=torch.float, requires_grad=False)
    y_fit_torch = torch.tensor(y_fit[None, :, :], dtype=torch.float)

    # In[Prepare model]
    G1 = SisoLinearDynamicalOperator(n_b, n_a, n_k=1)
    F_nl = SisoStaticNonLinearity(n_hidden=10, activation='tanh')
    G2 = SisoLinearDynamicalOperator(n_b+1, n_a, n_k=0)

    def model(u_in):
        y1_lin = G1(u_in)
        y1_nl = F_nl(y1_lin)
        y_hat = G2(y1_nl)
        return y_hat, y1_nl, y1_lin

    # In[Setup optimizer]
    optimizer_ADAM = torch.optim.Adam([
        {'params': G1.parameters(), 'lr': lr_ADAM},
        {'params': G2.parameters(), 'lr': lr_ADAM},
        {'params': F_nl.parameters(), 'lr': lr_ADAM},
    ], lr=lr_ADAM)

    optimizer_LBFGS = torch.optim.LBFGS(list(G1.parameters()) + list(G2.parameters()) + list(F_nl.parameters()), lr=lr_BFGS)

    def closure():
        optimizer_LBFGS.zero_grad()

        # Simulate
        y_hat, y1_nl, y1_lin = model(u_fit_torch)

        # Compute fit loss
        err_fit = y_fit_torch[:, n_skip:, :] - y_hat[:, n_skip:, :]
        loss = torch.mean(err_fit**2)*1000

        # Backward pass
        loss.backward()
        return loss

    # In[Train]
    LOSS = []
    start_time = time.time()
    for itr in range(0, num_iter):

        msg_freq = 10
        if itr < num_iter_ADAM:
            loss_train = optimizer_ADAM.step(closure)
        else:
            loss_train = optimizer_LBFGS.step(closure)

        LOSS.append(loss_train.item())
        if itr % msg_freq == 0:
            with torch.no_grad():
                RMSE = torch.sqrt(loss_train)
            print(f'Iter {itr} | Fit Loss {loss_train:.6f} | RMSE:{RMSE:.4f}')

    train_time = time.time() - start_time
    print(f"\nTrain time: {train_time:.2f}")

    # In[Save model]
    model_folder = os.path.join("models", model_name)
    if not os.path.exists(model_folder):
        os.makedirs(model_folder)

    torch.save(G1.state_dict(), os.path.join(model_folder, "G1.pt"))
    torch.save(F_nl.state_dict(), os.path.join(model_folder, "F_nl.pt"))
    torch.save(G2.state_dict(), os.path.join(model_folder, "G2.pt"))

    # In[Simulate one more time]
    with torch.no_grad():
        y_hat, y1_nl, y1_lin = model(u_fit_torch)

    # In[Detach]
    y_hat = y_hat.detach().numpy()[0, :, :]
    y1_lin = y1_lin.detach().numpy()[0, :, :]
    y1_nl = y1_nl.detach().numpy()[0, :, :]

    # In[Plot]
    plt.figure()
    plt.plot(t_fit, y_fit, 'k', label="$y$")
    plt.plot(t_fit, y_hat, 'b', label=r"$\hat y$")
    plt.legend()

    # In[Plot loss]
    plt.figure()
    plt.plot(LOSS)
    plt.grid(True)

    # In[Plot static non-linearity]
    y1_lin_min = np.min(y1_lin)
    y1_lin_max = np.max(y1_lin)

    in_nl = np.arange(y1_lin_min, y1_lin_max, (y1_lin_max - y1_lin_min)/1000).astype(np.float32).reshape(-1, 1)

    with torch.no_grad():
        out_nl = F_nl(torch.as_tensor(in_nl))

    plt.figure()
    plt.plot(in_nl, out_nl, 'b')
    #plt.plot(y1_lin, y1_nl, 'b*')
    plt.xlabel('Static non-linearity input (-)')
    plt.ylabel('Static non-linearity output (-)')
    plt.grid(True)

    # In[Metrics]
    e_rms = util.metrics.error_rmse(y_hat, y_fit)[0]
    print(f"RMSE: {e_rms:.2f}")  # target: 1 mV
--------------------------------------------------------------------------------
/examples/WH2009/WH2009_train_quantized.py:
--------------------------------------------------------------------------------
import torch
import pandas as pd
import numpy as np
import os
from torchid_nb.module.lti import SisoLinearDynamicalOperator
from torchid_nb.module.static import SisoStaticNonLinearity
import matplotlib.pyplot as plt
import time
import util.metrics


def normal_standard_cdf(val):
    """Returns the value of the cumulative distribution function for a standard normal variable"""
    return 1/2 * (1 + torch.erf(val/np.sqrt(2)))


# In[Main]
if __name__ == '__main__':

    # In[Set seed for reproducibility]
    np.random.seed(0)
    torch.manual_seed(0)

    # In[Settings]
    lr = 1e-4
    num_iter = 200000
    msg_freq = 100
    n_skip = 5000
    n_fit = 20000
    decimate = 1
    n_batch = 1
    n_b = 3
    n_a = 3

    meas_intervals = np.array([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0], dtype=np.float32)
    meas_intervals_full = np.r_[-1000, meas_intervals, 1000]  # unbounded outer bins, approximated by +/-1000

    model_name = "model_WH_digit"

    # In[Column names in the dataset]
    COL_F = ['fs']
    COL_U = ['uBenchMark']
    COL_Y = ['yBenchMark']

    # In[Load dataset]
    df_X = pd.read_csv(os.path.join("data", "WienerHammerBenchmark.csv"))

    # Extract data
    y = np.array(df_X[COL_Y], dtype=np.float32)  # batch, time, channel
    u = np.array(df_X[COL_U], dtype=np.float32)
    fs = np.array(df_X[COL_F].iloc[0], dtype=np.float32)
    N = y.size
    ts = 1/fs
    t = np.arange(N)*ts

    # In[Compute v signal]
    v = np.digitize(y, bins=meas_intervals)  # quantized measurement index
    bins = meas_intervals_full[np.c_[v, v+1]]  # bins of the measurement

    # In[Fit data]
    bins_fit = bins[0:n_fit:decimate, :]
    v_fit = v[0:n_fit:decimate]
    y_fit = y[0:n_fit:decimate]
    u_fit = u[0:n_fit:decimate]
    t_fit = t[0:n_fit:decimate]

    # In[Prepare training tensors]
    u_fit_torch = torch.tensor(u_fit[None, :, :], dtype=torch.float, requires_grad=False)
    bins_fit_torch = torch.tensor(bins_fit[None, :, :], dtype=torch.float, requires_grad=False)
    v_fit_torch = torch.tensor(v_fit[None, :, :], dtype=torch.float)

    # In[Prepare model]
    G1 = SisoLinearDynamicalOperator(n_b, n_a, n_k=1)
    F_nl = SisoStaticNonLinearity(n_hidden=10, activation='tanh')
    G2 = SisoLinearDynamicalOperator(n_b, n_a)

    log_sigma_hat = torch.tensor(np.log(1.0), requires_grad=True)  # log of the noise standard deviation

    def model(u_in):
        y1_lin = G1(u_in)
        y1_nl = F_nl(y1_lin)
        y_hat = G2(y1_nl)
        return y_hat, y1_nl, y1_lin

    # In[Setup optimizer]
    optimizer = torch.optim.Adam([
        {'params': G1.parameters(), 'lr': lr},
        {'params': G2.parameters(), 'lr': lr},
        {'params': F_nl.parameters(), 'lr': lr},
        {'params': log_sigma_hat, 'lr': 2e-5},
    ], lr=lr)

    # In[Train]
    LOSS = []
    SIGMA = []
    start_time = time.time()
    for itr in range(0, num_iter):

        optimizer.zero_grad()

        sigma_hat = torch.exp(log_sigma_hat)
        y_hat, y1_nl, y1_lin = model(u_fit_torch)
        Phi_hat = normal_standard_cdf((bins_fit_torch - y_hat)/(sigma_hat + 1e-6))
        y_Phi_hat = Phi_hat[..., [1]] - Phi_hat[..., [0]]  # probability mass of the observed bin
        y_hat_log = y_Phi_hat.log()
        loss_train = -y_hat_log.mean()  # negative log-likelihood of the quantized observations

        LOSS.append(loss_train.item())
        SIGMA.append(sigma_hat.item())

        if itr % msg_freq == 0:
            print(f'Iter {itr} | Fit Loss {loss_train:.5f} sigma_hat: {sigma_hat:.5f}')

        loss_train.backward()
        optimizer.step()

    train_time = time.time() - start_time
    print(f"\nTrain time: {train_time:.2f}")

    # In[Save model]
    model_folder = os.path.join("models", model_name)
    if not os.path.exists(model_folder):
        os.makedirs(model_folder)

    torch.save(G1.state_dict(), os.path.join(model_folder, "G1.pt"))
    torch.save(F_nl.state_dict(), os.path.join(model_folder, "F_nl.pt"))
    torch.save(G2.state_dict(), os.path.join(model_folder, "G2.pt"))

    # In[Simulate one more time]
    with torch.no_grad():
        y_hat, y1_nl, y1_lin = model(u_fit_torch)

    # In[Detach]
    y_hat = y_hat.detach().numpy()[0, :, :]
    y1_lin = y1_lin.detach().numpy()[0, :, :]
    y1_nl = y1_nl.detach().numpy()[0, :, :]

    # In[Plot]
    plt.figure()
    plt.plot(t_fit, y_fit, 'k', label="$y$")
    plt.plot(t_fit, y_hat, 'b', label=r"$\hat y$")
    plt.legend()

    # In[Plot loss]
    plt.figure()
    plt.plot(LOSS)
    plt.grid(True)

    # In[Plot sigma]
    plt.figure()
    plt.plot(SIGMA)
    plt.grid(True)

    # In[Plot static non-linearity]
    y1_lin_min = np.min(y1_lin)
    y1_lin_max = np.max(y1_lin)

    in_nl = np.arange(y1_lin_min, y1_lin_max, (y1_lin_max - y1_lin_min)/1000).astype(np.float32).reshape(-1, 1)

    with torch.no_grad():
        out_nl = F_nl(torch.as_tensor(in_nl))

    plt.figure()
    plt.plot(in_nl, out_nl, 'b')
    #plt.plot(y1_lin, y1_nl, 'b*')
    plt.xlabel('Static non-linearity input (-)')
    plt.ylabel('Static non-linearity output (-)')
    plt.grid(True)

    # In[Metrics]
    e_rms = util.metrics.error_rmse(y_hat, y_fit)[0]
    print(f"RMSE: {e_rms:.2f}")  # target: 1 mV
--------------------------------------------------------------------------------
/examples/WH2009/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/examples/WH2009/__init__.py
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/examples/__init__.py
--------------------------------------------------------------------------------
/fig/dynonet_quant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/fig/dynonet_quant.png
--------------------------------------------------------------------------------
/fig/neural_PEM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/fig/neural_PEM.png
--------------------------------------------------------------------------------
/sphinx/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS  =
SPHINXBUILD = sphinx-build
SOURCEDIR   = source
BUILDDIR    = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/sphinx/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/sphinx/source/code.rst:
--------------------------------------------------------------------------------
dynoNet API
==========================

---------------
LTI blocks
---------------

.. automodule:: torchid.module.lti
   :members:
   :special-members:
   :member-order: bysource
   :exclude-members: __init__, forward

---------------
Static blocks
---------------

.. automodule:: torchid.module.static
   :members:
   :special-members:
   :member-order: bysource
   :exclude-members: __init__
--------------------------------------------------------------------------------
/sphinx/source/index.rst:
--------------------------------------------------------------------------------
1 | .. torchid documentation master file, created by
2 | sphinx-quickstart on Fri Apr 10 01:50:34 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to the dynoNet documentation!
7 | ===================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | code
14 |
15 |
16 | Indices and tables
17 | ==================
18 |
19 | * :ref:`genindex`
20 | * :ref:`modindex`
21 | * :ref:`search`
22 |
--------------------------------------------------------------------------------
/torchid/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid/__init__.py
--------------------------------------------------------------------------------
/torchid/functional/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid/functional/__init__.py
--------------------------------------------------------------------------------
/torchid/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid/module/__init__.py
--------------------------------------------------------------------------------
/torchid/module/static.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MimoStaticNonLinearity(nn.Module):
6 | r"""Applies a Static MIMO non-linearity.
7 | The non-linearity is implemented as a feed-forward neural network.
8 |
9 | Args:
10 | in_channels (int): Number of input channels
11 | out_channels (int): Number of output channels
12 | n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20
13 | activation (str): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh'
14 |
15 | Shape:
16 | - Input: (..., in_channels)
17 | - Output: (..., out_channels)
18 |
19 | Examples::
20 |
21 | >>> in_channels, out_channels = 2, 4
22 | >>> F = MimoStaticNonLinearity(in_channels, out_channels)
23 | >>> batch_size, seq_len = 32, 100
24 | >>> u_in = torch.ones((batch_size, seq_len, in_channels))
25 | >>> y_out = F(u_in)  # shape: (batch_size, seq_len, out_channels)
26 | """
27 |
28 | def __init__(self, in_channels, out_channels, n_hidden=20, activation='tanh'):
29 | super(MimoStaticNonLinearity, self).__init__()
30 |
31 | activation_dict = {'tanh': nn.Tanh, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid}
32 |
33 | self.net = nn.Sequential(
34 | nn.Linear(in_channels, n_hidden),
35 | activation_dict[activation](),
36 | nn.Linear(n_hidden, out_channels)
37 | )
38 |
39 | def forward(self, u_lin):
40 | y_nl = self.net(u_lin)
41 | return y_nl
42 |
43 |
44 | class SisoStaticNonLinearity(MimoStaticNonLinearity):
45 | r"""Applies a Static SISO non-linearity.
46 | The non-linearity is implemented as a feed-forward neural network.
47 |
48 | Args:
49 | n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20
50 | activation (str): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh'
51 |
52 | Shape:
53 | - Input: (..., 1)
54 | - Output: (..., 1)
55 |
56 | Examples::
57 |
58 | >>> F = SisoStaticNonLinearity(n_hidden=20)
59 | >>> batch_size, seq_len = 32, 100
60 | >>> u_in = torch.ones((batch_size, seq_len, 1))
61 | >>> y_out = F(u_in)  # shape: (batch_size, seq_len, 1)
62 | """
63 | def __init__(self, n_hidden=20, activation='tanh'):
64 | super(SisoStaticNonLinearity, self).__init__(in_channels=1, out_channels=1, n_hidden=n_hidden, activation=activation)
65 |
66 |
67 | class MimoChannelWiseNonLinearity(nn.Module):
68 | r"""Applies a Channel-wise non-linearity.
69 | The non-linearity is implemented as a set of feed-forward neural networks (each one operating on a different channel).
70 |
71 | Args:
72 | channels (int): Number of both input and output channels
73 | n_hidden (int, optional): Number of nodes in the hidden layer of each network. Default: 10
74 |
75 | Shape:
76 | - Input: (..., channels)
77 | - Output: (..., channels)
78 |
79 | Examples::
80 |
81 | >>> channels = 4
82 | >>> F = MimoChannelWiseNonLinearity(channels)
83 | >>> batch_size, seq_len = 32, 100
84 | >>> u_in = torch.ones((batch_size, seq_len, channels))
85 | >>> y_out = F(u_in)  # shape: (batch_size, seq_len, channels)
86 |
87 | """
88 |
89 | def __init__(self, channels, n_hidden=10):
90 | super(MimoChannelWiseNonLinearity, self).__init__()
91 |
92 | self.net = nn.ModuleList()
93 | for channel_idx in range(channels):
94 | channel_net = nn.Sequential(
95 | nn.Linear(1, n_hidden),  # each channel network maps a scalar input
96 | nn.ReLU(),
97 | nn.Linear(n_hidden, 1)
98 | )
99 | self.net.append(channel_net)
100 |
101 | def forward(self, u_lin):
102 |
103 | y_nl = []
104 | for channel_idx, u_channel in enumerate(u_lin.split(1, dim=-1)): # split over the last dimension (input channel)
105 | y_nl_channel = self.net[channel_idx](u_channel) # Process blocks individually
106 | y_nl.append(y_nl_channel)
107 |
108 | y_nl = torch.cat(y_nl, -1) # concatenate all output channels
109 | return y_nl
110 |
111 |
112 | if __name__ == '__main__':
113 |
114 | channels = 4
115 | nn1 = MimoChannelWiseNonLinearity(channels)
116 | in_data = torch.randn(100, 10, channels)
117 | net_out = nn1(in_data)  # shape: (100, 10, 4)
--------------------------------------------------------------------------------
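The static blocks above act on the last tensor dimension only, so they compose like any other torch.nn.Module. A minimal usage sketch, assuming only the classes defined in torchid/module/static.py above (names and sizes are illustrative):

import torch
from torchid.module.static import MimoStaticNonLinearity, MimoChannelWiseNonLinearity

batch_size, seq_len = 32, 100

# 2-in/4-out static MIMO map, followed by a channel-wise map on the 4 channels
F1 = MimoStaticNonLinearity(in_channels=2, out_channels=4, n_hidden=20, activation='tanh')
F2 = MimoChannelWiseNonLinearity(channels=4, n_hidden=10)

u = torch.randn(batch_size, seq_len, 2)  # (..., in_channels)
y = F2(F1(u))                            # (batch_size, seq_len, 4)
print(y.shape)                           # torch.Size([32, 100, 4])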
/torchid_nb/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid_nb/__init__.py
--------------------------------------------------------------------------------
/torchid_nb/functional/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid_nb/functional/__init__.py
--------------------------------------------------------------------------------
/torchid_nb/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid_nb/module/__init__.py
--------------------------------------------------------------------------------
/torchid_nb/module/static.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MimoStaticNonLinearity(nn.Module):
6 | r"""Applies a Static MIMO non-linearity.
7 | The non-linearity is implemented as a feed-forward neural network.
8 |
9 | Args:
10 | in_channels (int): Number of input channels
11 | out_channels (int): Number of output channels
12 | n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20
13 | activation (str): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh'
14 |
15 | Shape:
16 | - Input: (..., in_channels)
17 | - Output: (..., out_channels)
18 |
19 | Examples::
20 |
21 | >>> in_channels, out_channels = 2, 4
22 | >>> F = MimoStaticNonLinearity(in_channels, out_channels)
23 | >>> batch_size, seq_len = 32, 100
24 | >>> u_in = torch.ones((batch_size, seq_len, in_channels))
25 | >>> y_out = F(u_in)  # shape: (batch_size, seq_len, out_channels)
26 | """
27 |
28 | def __init__(self, in_channels, out_channels, n_hidden=20, activation='tanh'):
29 | super(MimoStaticNonLinearity, self).__init__()
30 |
31 | activation_dict = {'tanh': nn.Tanh, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid}
32 |
33 | self.net = nn.Sequential(
34 | nn.Linear(in_channels, n_hidden),
35 | activation_dict[activation](),
36 | nn.Linear(n_hidden, out_channels)
37 | )
38 |
39 | def forward(self, u_lin):
40 | y_nl = self.net(u_lin)
41 | return y_nl
42 |
43 |
44 | class SisoStaticNonLinearity(MimoStaticNonLinearity):
45 | r"""Applies a Static SISO non-linearity.
46 | The non-linearity is implemented as a feed-forward neural network.
47 |
48 | Args:
49 | n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20
50 | activation (str): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh'
51 |
52 | Shape:
53 | - Input: (..., 1)
54 | - Output: (..., 1)
55 |
56 | Examples::
57 |
58 | >>> F = SisoStaticNonLinearity(n_hidden=20)
59 | >>> batch_size, seq_len = 32, 100
60 | >>> u_in = torch.ones((batch_size, seq_len, 1))
61 | >>> y_out = F(u_in)  # shape: (batch_size, seq_len, 1)
62 | """
63 | def __init__(self, n_hidden=20, activation='tanh'):
64 | super(SisoStaticNonLinearity, self).__init__(in_channels=1, out_channels=1, n_hidden=n_hidden, activation=activation)
65 |
66 |
67 | class MimoChannelWiseNonLinearity(nn.Module):
68 | r"""Applies a Channel-wise non-linearity.
69 | The non-linearity is implemented as a set of feed-forward neural networks (each one operating on a different channel).
70 |
71 | Args:
72 | channels (int): Number of both input and output channels
73 | n_hidden (int, optional): Number of nodes in the hidden layer of each network. Default: 10
74 |
75 | Shape:
76 | - Input: (..., channels)
77 | - Output: (..., channels)
78 |
79 | Examples::
80 |
81 | >>> channels = 4
82 | >>> F = MimoChannelWiseNonLinearity(channels)
83 | >>> batch_size, seq_len = 32, 100
84 | >>> u_in = torch.ones((batch_size, seq_len, channels))
85 | >>> y_out = F(u_in)  # shape: (batch_size, seq_len, channels)
86 |
87 | """
88 |
89 | def __init__(self, channels, n_hidden=10):
90 | super(MimoChannelWiseNonLinearity, self).__init__()
91 |
92 | self.net = nn.ModuleList()
93 | for channel_idx in range(channels):
94 | channel_net = nn.Sequential(
95 | nn.Linear(1, n_hidden),  # each channel network maps a scalar input
96 | nn.Tanh(),
97 | nn.Linear(n_hidden, 1)
98 | )
99 | self.net.append(channel_net)
100 |
101 | def forward(self, u_lin):
102 |
103 | y_nl = []
104 | for channel_idx, u_channel in enumerate(u_lin.split(1, dim=-1)): # split over the last dimension (input channel)
105 | y_nl_channel = self.net[channel_idx](u_channel) # Process blocks individually
106 | y_nl.append(y_nl_channel)
107 |
108 | y_nl = torch.cat(y_nl, -1) # concatenate all output channels
109 | return y_nl
110 |
111 |
112 | if __name__ == '__main__':
113 |
114 | channels = 4
115 | nn1 = MimoChannelWiseNonLinearity(channels)
116 | in_data = torch.randn(100, 10, channels)
117 | net_out = nn1(in_data)  # shape: (100, 10, 4)
--------------------------------------------------------------------------------
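Since these blocks are ordinary differentiable modules, their weights can be fitted with any torch optimizer. A minimal sketch, assuming the SisoStaticNonLinearity defined above; the target curve, seed, and hyper-parameters are illustrative only:

import torch
from torchid_nb.module.static import SisoStaticNonLinearity

torch.manual_seed(0)
F = SisoStaticNonLinearity(n_hidden=20, activation='tanh')
opt = torch.optim.Adam(F.parameters(), lr=1e-3)

u = torch.linspace(-3, 3, 500).unsqueeze(-1)  # input grid, shape (500, 1)
y = torch.tanh(2.0 * u)                       # known static characteristic to fit

for _ in range(2000):
    opt.zero_grad()
    loss = torch.mean((F(u) - y) ** 2)  # mean squared error
    loss.backward()
    opt.step()

print(f"final MSE: {loss.item():.2e}")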
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/util/__init__.py
--------------------------------------------------------------------------------
/util/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 |
5 | def r_squared(y_true, y_pred, time_axis=0):
6 | """ Computes the R-square index.
7 |
8 | The R-squared index is computed separately on each channel.
9 |
10 | Parameters
11 | ----------
12 | y_true : np.array
13 | Array of true values. It must be at least 2D.
14 | y_pred : np.array
15 | Array of predicted values. It must have the same shape as y_true.
16 | time_axis : int
17 | Time axis. All other axes define separate channels.
18 |
19 | Returns
20 | -------
21 | r_squared_val : np.array
22 | Array of R-squared values, one entry per channel.
23 | """
24 |
25 | SSE = np.sum((y_pred - y_true)**2, axis=time_axis)
26 | y_mean = np.mean(y_true, axis=time_axis, keepdims=True)
27 | SST = np.sum((y_true - y_mean)**2, axis=time_axis)
28 |
29 | return 1.0 - SSE/SST
30 |
31 |
32 | def error_rmse(y_true, y_pred, time_axis=0):
33 | """ Computes the Root Mean Square Error (RMSE).
34 |
35 | The RMSE index is computed separately on each channel.
36 |
37 | Parameters
38 | ----------
39 | y_true : np.array
40 | Array of true values. It must be at least 2D.
41 | y_pred : np.array
42 | Array of predicted values. It must have the same shape as y_true.
43 | time_axis : int
44 | Time axis. All other axes define separate channels.
45 |
46 | Returns
47 | -------
48 | RMSE : np.array
49 | Array of RMSE values, one entry per channel.
50 |
51 | """
52 |
53 | MSE = np.mean((y_pred - y_true)**2, axis=time_axis)
54 | RMSE = np.sqrt(MSE)
55 | return RMSE
56 |
57 |
58 | def error_mean(y_true, y_pred, time_axis=0):
59 | """ Computes the error mean value.
60 |
61 | The error mean is computed separately on each channel.
62 |
63 | Parameters
64 | ----------
65 | y_true : np.array
66 | Array of true values. It must be at least 2D.
67 | y_pred : np.array
68 | Array of predicted values. It must have the same shape as y_true.
69 | time_axis : int
70 | Time axis. All other axes define separate channels.
71 |
72 | Returns
73 | -------
74 | e_mean : np.array
75 | Array of error means.
76 | """
77 |
78 | e_mean = np.mean(y_true - y_pred, axis=time_axis)
79 | return e_mean
80 |
81 |
82 | def error_mae(y_true, y_pred, time_axis=0):
83 | """ Computes the error Mean Absolute Value (MAE)
84 |
85 | The MAE is computed separately on each channel.
86 |
87 | Parameters
88 | ----------
89 | y_true : np.array
90 | Array of true values. It must be at least 2D.
91 | y_pred : np.array
92 | Array of predicted values. It must have the same shape as y_true.
93 | time_axis : int
94 | Time axis. All other axes define separate channels.
95 |
96 | Returns
97 | -------
98 | e_mean : np.array
99 | Array of error mean absolute values.
100 | """
101 |
102 | e_mean = np.mean(np.abs(y_true - y_pred), axis=time_axis)
103 | return e_mean
104 |
105 | def fit_index(y_true, y_pred, time_axis=0):
106 | """ Computes the per-channel fit index.
107 |
108 | The fit index is commonly used in System Identification. See the definition in the System Identification Toolbox
109 | or in the paper 'Nonlinear System Identification: A User-Oriented Road Map',
110 | https://arxiv.org/abs/1902.00683, page 31.
111 | The fit index is computed separately on each channel.
112 |
113 | Parameters
114 | ----------
115 | y_true : np.array
116 | Array of true values. It must be at least 2D.
117 | y_pred : np.array
118 | Array of predicted values. It must have the same shape as y_true.
119 | time_axis : int
120 | Time axis. All other axes define separate channels.
121 |
122 | Returns
123 | -------
124 | fit_val : np.array
125 | Array of fit index values, one entry per channel.
126 |
127 | """
128 |
129 | err_norm = np.linalg.norm(y_true - y_pred, axis=time_axis, ord=2) # || y - y_pred ||
130 | y_mean = np.mean(y_true, axis=time_axis, keepdims=True)
131 | err_mean_norm = np.linalg.norm(y_true - y_mean, axis=time_axis, ord=2)  # || y - y_mean ||, per channel
132 | fit_val = 100*(1 - err_norm/err_mean_norm)
133 |
134 | return fit_val
135 |
136 |
137 | if __name__ == '__main__':
138 | N = 20
139 | ny = 2
140 | SNR = 10
141 | y_true = SNR*np.random.randn(N, ny)
142 | y_pred = np.copy(y_true) + np.random.randn(N, ny)
143 | err_rmse_val = error_rmse(y_true, y_pred)
144 | r_squared_val = r_squared(y_true, y_pred)
145 | fit_val = fit_index(y_true, y_pred)
146 |
147 | print(f"RMSE: {err_rmse_val}")
148 | print(f"R-squared: {r_squared_val}")
149 | print(f"fit index: {fit_val}")
150 |
--------------------------------------------------------------------------------
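Since fit_index returns 100*(1 - ||e|| / ||y - y_mean||) while r_squared returns 1 - ||e||^2 / ||y - y_mean||^2, the two metrics are related channel-wise by fit = 100*(1 - sqrt(1 - R^2)). A minimal sketch checking this numerically with the per-channel fit_index above (synthetic data, illustrative only):

import numpy as np
from util.metrics import r_squared, fit_index

np.random.seed(0)
y_true = np.random.randn(200, 2)
y_pred = y_true + 0.1 * np.random.randn(200, 2)

R2 = r_squared(y_true, y_pred)
fit = fit_index(y_true, y_pred)

# fit = 100 * (1 - sqrt(1 - R2)) holds for each channel
assert np.allclose(fit, 100.0 * (1.0 - np.sqrt(1.0 - R2)))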