├── .gitignore ├── .idea ├── .gitignore ├── csv-plugin.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── other.xml ├── sysid-transfer-functions-pytorch.iml └── vcs.xml ├── LICENSE ├── README.md ├── doc ├── paper │ ├── fig │ │ ├── PWH_timetrace.pdf │ │ ├── WH_H.pdf │ │ ├── WH_timetrace.pdf │ │ ├── WH_timetrace_zoom.pdf │ │ ├── backprop_tf_ab.pdf │ │ ├── backprop_tf_ab.svg │ │ ├── backprop_tf_ab_circle.pdf │ │ ├── backprop_tf_ab_circle.svg │ │ ├── backprop_tf_ab_stable.pdf │ │ ├── backprop_tf_ab_stable.svg │ │ ├── backprop_tf_ab_stable_mod.pdf │ │ ├── backprop_tf_ab_stable_mod.svg │ │ ├── coeff_2ndorder.svg │ │ ├── dynonet_quant.pdf │ │ ├── dynonet_quant.svg │ │ ├── dynonet_quant_old.svg │ │ ├── generalized_HW.pdf │ │ ├── hammer.pdf │ │ ├── hammerstein_wiener.pdf │ │ ├── neural_PEM.pdf │ │ ├── neural_PEM.svg │ │ ├── neural_PEM_old.svg │ │ ├── parallel_WH.pdf │ │ ├── rho.svg │ │ ├── rhopsi.png │ │ ├── stable_2ndorder.pdf │ │ ├── sym │ │ │ ├── F.svg │ │ │ ├── F1.svg │ │ │ ├── G1.svg │ │ │ ├── G2.svg │ │ │ ├── GH.svg │ │ │ ├── GW.svg │ │ │ ├── Gz.svg │ │ │ ├── H_hat.svg │ │ │ ├── H_inv.svg │ │ │ ├── H_inv_check.svg │ │ │ ├── H_inv_par.svg │ │ │ ├── H_star.svg │ │ │ ├── L_k.svg │ │ │ ├── L_k_theta.svg │ │ │ ├── L_t_theta.svg │ │ │ ├── Lcal.svg │ │ │ ├── M_par.svg │ │ │ ├── NN.svg │ │ │ ├── U_k.svg │ │ │ ├── U_t.svg │ │ │ ├── ab.svg │ │ │ ├── dots.svg │ │ │ ├── eps_t_theta.svg │ │ │ ├── epsilon.svg │ │ │ ├── epsilon_hat.svg │ │ │ ├── epsilon_k_theta.svg │ │ │ ├── epsilon_t_theta_comma.svg │ │ │ ├── epsilon_t_theta_vec.png │ │ │ ├── et.svg │ │ │ ├── minus.svg │ │ │ ├── plus.svg │ │ │ ├── sigma_e.svg │ │ │ ├── thetabar.svg │ │ │ ├── ut.svg │ │ │ ├── uvec.svg │ │ │ ├── uvecbar.svg │ │ │ ├── v_hat.svg │ │ │ ├── v_meas.svg │ │ │ ├── wt.svg │ │ │ ├── xt.svg │ │ │ ├── y_hat.svg │ │ │ ├── y_ol.svg │ │ │ ├── y_sim.svg │ │ │ ├── y_sim_t.svg │ │ │ ├── y_t.svg │ │ │ ├── y_t_vec.svg │ │ │ ├── yt.svg │ │ │ ├── yvec.svg │ │ 
│ ├── yvecbar.svg │ │ │ ├── z_k.svg │ │ │ ├── z_t.svg │ │ │ └── zt.svg │ │ ├── transf.svg │ │ ├── wiener.pdf │ │ └── wiener_hammerstein.pdf │ ├── ms.bib │ ├── ms.pdf │ └── ms.tex └── slides │ ├── preamble.tex │ ├── presentation_main.pdf │ └── presentation_main.tex ├── examples ├── ParWH │ ├── __init__.py │ ├── models.py │ ├── models │ │ └── PWH │ │ │ └── PWH.pt │ ├── parWH_plot_signals.py │ ├── parWH_test.py │ ├── parWH_train_NLSQ.py │ └── parWH_train_quant_ML.py ├── WH2009 │ ├── WH2009_test.py │ ├── WH2009_train.py │ ├── WH2009_train_colored_noise_PEM.py │ ├── WH2009_train_process_noise.py │ ├── WH2009_train_quantized.py │ └── __init__.py └── __init__.py ├── fig ├── dynonet_quant.png ├── dynonet_quant.svg ├── neural_PEM.png └── neural_PEM.svg ├── sphinx ├── Makefile ├── make.bat └── source │ ├── code.rst │ ├── conf.py │ └── index.rst ├── torchid ├── __init__.py ├── functional │ ├── __init__.py │ └── lti.py └── module │ ├── __init__.py │ ├── lti.py │ └── static.py ├── torchid_nb ├── __init__.py ├── functional │ ├── __init__.py │ └── lti.py └── module │ ├── __init__.py │ ├── lti.py │ └── static.py └── util ├── __init__.py ├── filtering.py └── metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | /examples/WH2009/data/ 2 | /examples/WH2009/models/ 3 | /examples/WH2009/fig/ 4 | /examples/ParWH/data/ 5 | /examples/ParWH/models 6 | /examples/ParWH/fig 7 | /sphinx/build/ 8 | 9 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/csv-plugin.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 36 | 37 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 10 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | -------------------------------------------------------------------------------- /.idea/sysid-transfer-functions-pytorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | 12 | 14 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Marco Forgione 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files 
(the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep learning with transfer functions: New applications in system identification 2 | 3 | 4 | This repository contains the Python code to reproduce the results of the paper [Deep learning with transfer functions: new applications in system identification](https://arxiv.org/abs/2104.09839) by Dario Piga, Marco Forgione, and Manas Mejari. 5 | 6 | We present a linear transfer function block, endowed with a well-defined and efficient back-propagation behavior for 7 | automatic derivatives computation. In the dynoNet architecture (already introduced [here](https://github.com/forgi86/dynonet)), linear dynamical operators are combined with static (i.e., memoryless) non-linearities which can be either elementary 8 | activation functions applied channel-wise; fully connected feed-forward neural networks; or other differentiable operators. 
9 | 10 | In this work, we use the differentiable transfer function operator to tackle 11 | other challenging problems in system identification. In particular, we consider the problems of: 12 | 13 | 1. Learning of neural dynamical models in the presence of colored noise (prediction error minimization method) 14 | 1. Learning of dynoNet models from quantized output observations (maximum likelihood estimation method) 15 | 16 |
17 | Problem 1. is tackled by extending the prediction error minimization method to deep learning models. A trainable linear transfer function block 18 | is used to describe the power spectrum of the noise: 19 |
Neural PEM
20 | 21 |
22 | Problem 2. is tackled by training a dynoNet model with a loss function corresponding to the log-likelihood of quantized observations: 23 | ML quantized measurements 24 | 25 | # Folders: 26 | * [torchid_nb](torchid_nb): PyTorch implementation of the linear dynamical operator (aka G-block in the paper) used in dynoNet 27 | * [examples](examples): examples using dynoNet for system identification 28 | * [util](util): definition of metrics R-square, RMSE, fit index 29 | 30 | Two [examples](examples) discussed in the paper are: 31 | 32 | * [WH2009](examples/WH2009): A circuit with Wiener-Hammerstein structure. Experimental dataset from http://www.nonlinearbenchmark.org 33 | * [Parallel Wiener-Hammerstein](examples/ParWH): A circuit with a two-branch parallel Wiener-Hammerstein structure. Experimental dataset from http://www.nonlinearbenchmark.org 34 | 35 | 36 | For the [WH2009](examples/WH2009) example, the main scripts are: 37 | 38 | * ``WH2009_train_colored_noise_PEM.py``: Training of a dynoNet model with the prediction error method in the presence of colored noise 39 | * ``WH2009_test.py``: Evaluation of the dynoNet model on the original test dataset, computation of metrics, plots. 40 | 41 | For the [Parallel Wiener-Hammerstein](examples/ParWH) example, the main scripts are: 42 | 43 | * ``parWH_train_quant_ML.py``: Training of a dynoNet model with maximum likelihood in the presence of quantized measurements 44 | * ``parWH_test.py``: Evaluation of the dynoNet model on the original test dataset, computation of metrics, plots. 45 | 46 | 47 | NOTE: the original data sets are not included in this project. They have to be manually downloaded from 48 | http://www.nonlinearbenchmark.org and copied into the data sub-folder of the example.
49 | # Software requirements: 50 | Simulations were performed on a Python 3.7 conda environment with 51 | 52 | * numpy 53 | * scipy 54 | * matplotlib 55 | * pandas 56 | * numba 57 | * pytorch (version 1.6) 58 | 59 | These dependencies may be installed through the commands: 60 | 61 | ``` 62 | conda install numpy scipy pandas numba matplotlib 63 | conda install pytorch torchvision cudatoolkit=10.2 -c pytorch 64 | ``` 65 | 66 | # Citing 67 | 68 | If you find this project useful, we encourage you to 69 | 70 | * Star this repository :star: 71 | * Cite the [paper](https://arxiv.org/abs/2104.09839) 72 | ``` 73 | @inproceedings{piga2021a, 74 | title={Deep learning with transfer functions: new applications in system identification}, 75 | author={Piga, D. and Forgione, M. and Mejari, M.}, 76 | booktitle={Proc. of the 19th IFAC Symposium System Identification: learning models for decision and control}, 77 | year={2021} 78 | } 79 | ``` 80 | -------------------------------------------------------------------------------- /doc/paper/fig/PWH_timetrace.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/PWH_timetrace.pdf -------------------------------------------------------------------------------- /doc/paper/fig/WH_H.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/WH_H.pdf -------------------------------------------------------------------------------- /doc/paper/fig/WH_timetrace.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/WH_timetrace.pdf
-------------------------------------------------------------------------------- /doc/paper/fig/WH_timetrace_zoom.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/WH_timetrace_zoom.pdf -------------------------------------------------------------------------------- /doc/paper/fig/backprop_tf_ab.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/backprop_tf_ab.pdf -------------------------------------------------------------------------------- /doc/paper/fig/backprop_tf_ab_circle.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/backprop_tf_ab_circle.pdf -------------------------------------------------------------------------------- /doc/paper/fig/backprop_tf_ab_stable.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/backprop_tf_ab_stable.pdf -------------------------------------------------------------------------------- /doc/paper/fig/backprop_tf_ab_stable_mod.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/backprop_tf_ab_stable_mod.pdf -------------------------------------------------------------------------------- /doc/paper/fig/coeff_2ndorder.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /doc/paper/fig/dynonet_quant.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/dynonet_quant.pdf -------------------------------------------------------------------------------- /doc/paper/fig/generalized_HW.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/generalized_HW.pdf -------------------------------------------------------------------------------- /doc/paper/fig/hammer.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/hammer.pdf -------------------------------------------------------------------------------- /doc/paper/fig/hammerstein_wiener.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/hammerstein_wiener.pdf -------------------------------------------------------------------------------- /doc/paper/fig/neural_PEM.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/neural_PEM.pdf -------------------------------------------------------------------------------- /doc/paper/fig/parallel_WH.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/parallel_WH.pdf -------------------------------------------------------------------------------- /doc/paper/fig/rho.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /doc/paper/fig/rhopsi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/rhopsi.png -------------------------------------------------------------------------------- /doc/paper/fig/stable_2ndorder.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/stable_2ndorder.pdf -------------------------------------------------------------------------------- /doc/paper/fig/sym/F.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/F1.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/G1.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 
-------------------------------------------------------------------------------- /doc/paper/fig/sym/G2.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/Gz.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/H_hat.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/H_inv.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/H_star.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/L_k.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/L_k_theta.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/L_t_theta.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/Lcal.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/NN.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/U_k.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/U_t.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/ab.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/dots.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/eps_t_theta.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 
-------------------------------------------------------------------------------- /doc/paper/fig/sym/epsilon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/epsilon_hat.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/epsilon_k_theta.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/epsilon_t_theta_comma.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/epsilon_t_theta_vec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/sym/epsilon_t_theta_vec.png -------------------------------------------------------------------------------- /doc/paper/fig/sym/et.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/minus.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 
-------------------------------------------------------------------------------- /doc/paper/fig/sym/plus.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/sigma_e.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/thetabar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/ut.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/uvec.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/uvecbar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/v_hat.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/v_meas.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 
-------------------------------------------------------------------------------- /doc/paper/fig/sym/wt.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/xt.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/y_hat.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/y_ol.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/y_sim.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/y_sim_t.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/y_t.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/y_t_vec.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/yt.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/yvec.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/yvecbar.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/z_k.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/z_t.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /doc/paper/fig/sym/zt.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /doc/paper/fig/transf.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /doc/paper/fig/wiener.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/wiener.pdf -------------------------------------------------------------------------------- /doc/paper/fig/wiener_hammerstein.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/fig/wiener_hammerstein.pdf -------------------------------------------------------------------------------- /doc/paper/ms.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/paper/ms.pdf -------------------------------------------------------------------------------- /doc/slides/preamble.tex: -------------------------------------------------------------------------------- 1 | % does not look nice, try deleting the line with the fontenc. 
2 | \usepackage[english]{babel} 3 | \usepackage{amsmath} 4 | \usepackage[latin1]{inputenc} 5 | \usepackage{units} 6 | \usepackage{colortbl} 7 | \usepackage{multimedia} 8 | \usepackage{bm} 9 | \usepackage{subcaption} 10 | \usepackage{algorithm2e} 11 | \usepackage{algorithmic} 12 | 13 | \mode 14 | { 15 | \usetheme{Boadilla} 16 | \useoutertheme{infolines} 17 | \setbeamercovered{transparent} 18 | } 19 | 20 | 21 | \title[Dynamical systems \& deep learning]{{Deep learning with transfer functions: New applications in system identification}} 22 | 23 | 24 | \author[]{Dario Piga, \underline{Marco Forgione}, Manas Mejari} 25 | 26 | \institute[IDSIA]{IDSIA Dalle Molle Institute for Artificial Intelligence USI-SUPSI, Lugano, Switzerland} 27 | 28 | 29 | \date[]{19th IFAC symposium System Identification: learning models for decision and control} 30 | 31 | 32 | \subject{System Identification, Deep Learning, Machine Learning, Regularization} 33 | 34 | 35 | %% MATH DEFINITIONS %% 36 | \newcommand{\So}{S_o} % true system 37 | \newcommand{\hidden}[1]{\overline{#1}} 38 | \newcommand{\nsamp}{N} 39 | \newcommand{\Yid}{Y} 40 | \newcommand{\Uid}{U} 41 | \newcommand{\Did}{{\mathcal{D}}} 42 | \newcommand{\tens}[1]{\bm{#1}} 43 | 44 | \newcommand{\batchsize}{q} 45 | \newcommand{\seqlen}{m} 46 | \newcommand{\nin}{n_u} 47 | \newcommand{\ny}{n_y} 48 | \newcommand{\nx}{n_x} 49 | 50 | \newcommand{\NN}{\mathcal{N}} % a feedforward neural network 51 | 52 | \newcommand{\norm}[1]{\left \lVert #1 \right \rVert} 53 | \DeclareMathOperator*\argmin{arg \, min} 54 | \newcommand{\Name}{\emph{dynoNet}} 55 | 56 | 57 | %% DYNONET MATH DEFINITIONS %% 58 | \newcommand{\q}{q} % shift operator 59 | \newcommand{\A}{A} % autoregressive polynomial 60 | \newcommand{\ac}{a} % autoregressive polynomial coefficient 61 | \newcommand{\B}{B} % exogenous polynomial 62 | \newcommand{\bb}{b} % exogenous polynomial coefficient 63 | \newcommand{\Gmat}{\mathbb{G}} % transfer function operator in matrix form 64 | 
\newcommand{\tvec}[1]{\bm{#1}} 65 | \newcommand{\mat}[1]{\bm{#1}} 66 | \newcommand{\sens}[1]{\tilde{#1}} 67 | \newcommand{\adjoint}[1]{\overline{#1}} 68 | \newcommand{\loss}{\mathcal{L}} 69 | \newcommand{\pdiff}[2]{\frac{\partial #1}{\partial #2}} 70 | %\newcommand{\nsamp}{T} 71 | 72 | \newcommand{\conv}{*} 73 | \newcommand{\ccorr}{\star} 74 | \definecolor{orange}{RGB}{204, 85, 0} 75 | %% DYNONET %% 76 | -------------------------------------------------------------------------------- /doc/slides/presentation_main.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/doc/slides/presentation_main.pdf -------------------------------------------------------------------------------- /examples/ParWH/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/examples/ParWH/__init__.py -------------------------------------------------------------------------------- /examples/ParWH/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchid_nb.module.lti import MimoLinearDynamicalOperator, SisoLinearDynamicalOperator 3 | from torchid_nb.module.static import MimoStaticNonLinearity, MimoChannelWiseNonLinearity 4 | 5 | 6 | class ParallelWHNet(torch.nn.Module): 7 | def __init__(self, nb_1=12, na_1=12, nb_2=13, na_2=12): 8 | super(ParallelWHNet, self).__init__() 9 | self.nb_1 = nb_1 10 | self.na_1 = na_1 11 | self.nb_2 = nb_2 12 | self.na_2 = na_2 13 | self.G1 = MimoLinearDynamicalOperator(1, 2, n_b=self.nb_1, n_a=self.na_1, n_k=1) 14 | self.F_nl = MimoChannelWiseNonLinearity(2, n_hidden=10) 15 | #self.F_nl = MimoStaticNonLinearity(2, 2, n_hidden=10) 16 | self.G2 = MimoLinearDynamicalOperator(2, 1, n_b=self.nb_2, 
n_a=self.na_2, n_k=0) 17 | #self.G3 = SisoLinearDynamicalOperator(n_b=3, n_a=3, n_k=1) 18 | 19 | def forward(self, u): 20 | y1_lin = self.G1(u) 21 | y1_nl = self.F_nl(y1_lin) # B, T, C1 22 | y2_lin = self.G2(y1_nl) # B, T, C2 23 | 24 | return y2_lin #+ self.G3(u) 25 | 26 | 27 | class ParallelWHNetVar(torch.nn.Module): 28 | def __init__(self): 29 | super(ParallelWHNetVar, self).__init__() 30 | self.nb_1 = 3 31 | self.na_1 = 3 32 | self.nb_2 = 3 33 | self.na_2 = 3 34 | self.G1 = MimoLinearDynamicalOperator(1, 16, n_b=self.nb_1, n_a=self.na_1, n_k=1) 35 | self.F_nl = MimoStaticNonLinearity(16, 16) #MimoChannelWiseNonLinearity(16, n_hidden=10) 36 | self.G2 = MimoLinearDynamicalOperator(16, 1, n_b=self.nb_2, n_a=self.na_2, n_k=1) 37 | 38 | def forward(self, u): 39 | y1_lin = self.G1(u) 40 | y1_nl = self.F_nl(y1_lin) # B, T, C1 41 | y2_lin = self.G2(y1_nl) # B, T, C2 42 | 43 | return y2_lin 44 | 45 | 46 | class ParallelWHResNet(torch.nn.Module): 47 | def __init__(self): 48 | super(ParallelWHResNet, self).__init__() 49 | self.nb_1 = 4 50 | self.na_1 = 4 51 | self.nb_2 = 4 52 | self.na_2 = 4 53 | self.G1 = MimoLinearDynamicalOperator(1, 2, n_b=self.nb_1, n_a=self.na_1, n_k=1) 54 | self.F_nl = MimoChannelWiseNonLinearity(2, n_hidden=10) 55 | self.G2 = MimoLinearDynamicalOperator(2, 1, n_b=self.nb_2, n_a=self.na_2, n_k=1) 56 | self.G3 = SisoLinearDynamicalOperator(n_b=6, n_a=6, n_k=1) 57 | 58 | def forward(self, u): 59 | y1_lin = self.G1(u) 60 | y1_nl = self.F_nl(y1_lin) # B, T, C1 61 | y2_lin = self.G2(y1_nl) # B, T, C2 62 | 63 | return y2_lin + self.G3(u) -------------------------------------------------------------------------------- /examples/ParWH/models/PWH/PWH.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/examples/ParWH/models/PWH/PWH.pt 
-------------------------------------------------------------------------------- /examples/ParWH/parWH_plot_signals.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | import matplotlib.pyplot as plt 5 | 6 | if __name__ == '__main__': 7 | 8 | N = 16384 # number of samples per period 9 | M = 20 # number of random phase multisine realizations 10 | P = 2 # number of periods 11 | nAmp = 5 # number of different amplitudes 12 | 13 | # Column names in the dataset 14 | COL_F = ['fs'] 15 | TAG_U = 'u' 16 | TAG_Y = 'y' 17 | 18 | # Load dataset 19 | #df_X = pd.read_csv(os.path.join("data", "WH_CombinedZeroMultisineSinesweep.csv")) 20 | df_X = pd.read_csv(os.path.join("data", "ParWHData_Estimation_Level2.csv")) 21 | df_X.columns = ['amplitude', 'fs', 'lines'] + [TAG_U + str(i) for i in range(M)] + [TAG_Y + str(i) for i in range(M)] + ['?'] 22 | 23 | # Extract data 24 | y = np.array(df_X['y0'], dtype=np.float32) 25 | u = np.array(df_X['u0'], dtype=np.float32) 26 | fs = np.array(df_X[COL_F].iloc[0], dtype = np.float32) 27 | N = y.size 28 | ts = 1/fs 29 | t = np.arange(N)*ts 30 | 31 | 32 | # In[Plot] 33 | fig, ax = plt.subplots(2, 1, sharex=True) 34 | ax[0].plot(t, y, 'k', label="$y$") 35 | ax[0].legend() 36 | ax[0].grid() 37 | 38 | ax[1].plot(t, u, 'k', label="$u$") 39 | ax[1].legend() 40 | ax[1].grid() 41 | 42 | -------------------------------------------------------------------------------- /examples/ParWH/parWH_test.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import util.metrics 8 | from examples.ParWH.models import ParallelWHNet 9 | 10 | 11 | if __name__ == '__main__': 12 | 13 | matplotlib.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica'], 'size': 11}) 14 | 15 | model_name = 
"PWH_quant" 16 | 17 | # Dataset constants 18 | amplitudes = 5 # number of different amplitudes 19 | realizations = 20 # number of random phase multisine realizations 20 | samp_per_period = 16384 # number of samples per period 21 | n_skip = 1000 22 | periods = 1 # number of periods 23 | seq_len = samp_per_period * periods # data points per realization 24 | 25 | # Column names in the dataset 26 | TAG_U = 'u' 27 | TAG_Y = 'y' 28 | 29 | # test_signal = "100mV" 30 | # test_signal = "325mV" 31 | # test_signal = "550mV" 32 | # test_signal = "775mV" 33 | # test_signal = "1000mV" 34 | test_signal = "ramp" 35 | 36 | #test_signal = "1000mV" #"ramp" #ramp"#"320mV" #"1000mV"#"ramp" 37 | plot_input = False 38 | 39 | # In[Load dataset] 40 | 41 | dict_test = {"100mV": 0, "325mV": 1, "550mV": 2, "775mV": 3, "1000mV": 4, "ramp": 5} 42 | dataset_list_level = ['ParWHData_Validation_Level' + str(i) for i in range(1, amplitudes + 1)] 43 | dataset_list = dataset_list_level + ['ParWHData_ValidationArrow'] 44 | 45 | df_X_lst = [] 46 | for dataset_name in dataset_list: 47 | dataset_filename = dataset_name + '.csv' 48 | df_Xi = pd.read_csv(os.path.join("data", dataset_filename)) 49 | df_X_lst.append(df_Xi) 50 | 51 | 52 | df_X = df_X_lst[dict_test[test_signal]] # first 53 | 54 | # Extract data 55 | y_meas = np.array(df_X['y'], dtype=np.float32) 56 | u = np.array(df_X['u'], dtype=np.float32) 57 | fs = np.array(df_X['fs'].iloc[0], dtype=np.float32) 58 | N = y_meas.size 59 | ts = 1/fs 60 | t = np.arange(N)*ts 61 | 62 | # In[Set-up model] 63 | 64 | net = ParallelWHNet() 65 | model_folder = os.path.join("models", model_name) 66 | net.load_state_dict(torch.load(os.path.join(model_folder, f"{model_name}.pt"))) 67 | #log_sigma_hat = torch.load(os.path.join(model_folder, "log_sigma_hat.pt")) 68 | #sigma_hat = torch.exp(log_sigma_hat) + 1e-3 69 | # In[Predict] 70 | u_torch = torch.tensor(u[None, :, None], dtype=torch.float, requires_grad=False) 71 | 72 | with torch.no_grad(): 73 | y_hat = net(u_torch) 
74 | 75 | # In[Detach] 76 | 77 | y_hat = y_hat.detach().numpy()[0, :, 0] 78 | 79 | # In[Plot] 80 | if plot_input: 81 | fig, ax = plt.subplots(2, 1, sharex=True) 82 | ax[0].plot(t, y_meas, 'k', label="$\mathbf{y}$") 83 | ax[0].plot(t, y_hat, 'b', label=r"$\mathbf{y}^{\rm sim}$") 84 | ax[0].plot(t, y_meas - y_hat, 'r', label="$\mathbf{e}$") 85 | ax[0].legend(loc="upper right") 86 | ax[0].set_ylabel("Voltage (V)") 87 | ax[0].grid() 88 | 89 | ax[1].plot(t, u, 'k', label="$u$") 90 | ax[1].legend(loc="upper right") 91 | ax[1].set_ylabel("Voltage (V)") 92 | ax[1].set_xlabel("Time (s)") 93 | ax[1].grid() 94 | else: 95 | fig, ax = plt.subplots(1, 1, figsize=(6, 3)) 96 | ax.plot(t, y_meas, 'k', label="$\mathbf{y}$") 97 | ax.plot(t, y_hat, 'b', label=r"$\mathbf{y}^{\rm sim}$") 98 | ax.plot(t, y_meas - y_hat, 'r', label="$\mathbf{e}$") 99 | if test_signal == "ramp": 100 | ax.legend(loc="upper left") 101 | else: 102 | ax.legend(loc="upper right") 103 | ax.set_ylabel("Voltage (V)") 104 | ax.set_xlabel("Time (s)") 105 | ax.grid() 106 | 107 | if test_signal == "ramp": 108 | ax.set_xlim([0.0, 0.21]) 109 | 110 | fig.tight_layout() 111 | fig_folder = "fig" 112 | if not os.path.exists(fig_folder): 113 | os.makedirs(fig_folder) 114 | fig.savefig(os.path.join(fig_folder, f"{model_name}_timetrace.pdf")) 115 | 116 | 117 | # In[Metrics] 118 | 119 | idx_test = range(n_skip, N) 120 | 121 | e_rms = 1000*util.metrics.error_rmse(y_meas[idx_test], y_hat[idx_test]) 122 | mae = 1000 * util.metrics.error_mae(y_meas[idx_test], y_hat[idx_test]) 123 | fit_idx = util.metrics.fit_index(y_meas[idx_test], y_hat[idx_test]) 124 | r_sq = util.metrics.r_squared(y_meas[idx_test], y_hat[idx_test]) 125 | u_rms = 1000*util.metrics.error_rmse(u, 0) 126 | 127 | print(f"RMSE: {e_rms:.2f}mV\nMAE: {mae:.2f}mV\nFIT: {fit_idx:.1f}%\nR_sq: {r_sq:.1f}\nRMSU: {u_rms:.2f}mV") -------------------------------------------------------------------------------- /examples/WH2009/WH2009_train.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | import os 5 | from torchid.module.lti import SisoLinearDynamicalOperator 6 | from torchid.module.static import SisoStaticNonLinearity 7 | import matplotlib.pyplot as plt 8 | import time 9 | import torch.nn as nn 10 | 11 | 12 | import util.metrics 13 | 14 | # In[Main] 15 | if __name__ == '__main__': 16 | 17 | # In[Set seed for reproducibility] 18 | np.random.seed(0) 19 | torch.manual_seed(0) 20 | 21 | # In[Settings] 22 | lr_ADAM = 2e-4 23 | lr_BFGS = 1e0 24 | num_iter_ADAM = 40000 # ADAM iterations 20000 25 | num_iter_BFGS = 0 # final BFGS iterations 26 | msg_freq = 100 27 | n_skip = 5000 28 | n_fit = 20000 29 | decimate = 1 30 | n_batch = 1 31 | n_b = 8 32 | n_a = 8 33 | model_name = "model_WH" 34 | 35 | num_iter = num_iter_ADAM + num_iter_BFGS 36 | 37 | # In[Column names in the dataset] 38 | COL_F = ['fs'] 39 | COL_U = ['uBenchMark'] 40 | COL_Y = ['yBenchMark'] 41 | 42 | # In[Load dataset] 43 | df_X = pd.read_csv(os.path.join("data", "WienerHammerBenchmark.csv")) 44 | 45 | # Extract data 46 | y = np.array(df_X[COL_Y], dtype=np.float32) # batch, time, channel 47 | u = np.array(df_X[COL_U], dtype=np.float32) 48 | fs = np.array(df_X[COL_F].iloc[0], dtype=np.float32) 49 | N = y.size 50 | ts = 1/fs 51 | t = np.arange(N)*ts 52 | 53 | # In[Fit data] 54 | y_fit = y[0:n_fit:decimate] 55 | u_fit = u[0:n_fit:decimate] 56 | t_fit = t[0:n_fit:decimate] 57 | 58 | # In[Prepare training tensors] 59 | u_fit_torch = torch.tensor(u_fit[None, :, :], dtype=torch.float, requires_grad=False) 60 | y_fit_torch = torch.tensor(y_fit[None, :, :], dtype=torch.float) 61 | 62 | # In[Prepare model] 63 | G1 = SisoLinearDynamicalOperator(n_b, n_a, n_k=1) 64 | F_nl = SisoStaticNonLinearity(n_hidden=10, activation='tanh') 65 | G2 = SisoLinearDynamicalOperator(n_b+1, n_a, n_k=0) 66 | 67 | def model(u_in): 68 | y1_lin = G1(u_fit_torch) 69 | y1_nl = 
F_nl(y1_lin) 70 | y_hat = G2(y1_nl) 71 | return y_hat, y1_nl, y1_lin 72 | 73 | # In[Setup optimizer] 74 | optimizer_ADAM = torch.optim.Adam([ 75 | {'params': G1.parameters(), 'lr': lr_ADAM}, 76 | {'params': G2.parameters(), 'lr': lr_ADAM}, 77 | {'params': F_nl.parameters(), 'lr': lr_ADAM}, 78 | ], lr=lr_ADAM) 79 | 80 | optimizer_LBFGS = torch.optim.LBFGS(list(G1.parameters()) + list(G2.parameters()) + list(F_nl.parameters()), lr=lr_BFGS) 81 | 82 | 83 | def closure(): 84 | optimizer_LBFGS.zero_grad() 85 | 86 | # Simulate 87 | y_hat, y1_nl, y1_lin = model(u_fit_torch) 88 | 89 | # Compute fit loss 90 | err_fit = y_fit_torch[:, n_skip:, :] - y_hat[:, n_skip:, :] 91 | loss = torch.mean(err_fit**2)*1000 92 | 93 | # Backward pas 94 | loss.backward() 95 | return loss 96 | 97 | 98 | # In[Train] 99 | LOSS = [] 100 | start_time = time.time() 101 | for itr in range(0, num_iter): 102 | 103 | if itr < num_iter_ADAM: 104 | msg_freq = 10 105 | loss_train = optimizer_ADAM.step(closure) 106 | else: 107 | msg_freq = 10 108 | loss_train = optimizer_LBFGS.step(closure) 109 | 110 | LOSS.append(loss_train.item()) 111 | if itr % msg_freq == 0: 112 | with torch.no_grad(): 113 | RMSE = torch.sqrt(loss_train) 114 | print(f'Iter {itr} | Fit Loss {loss_train:.6f} | RMSE:{RMSE:.4f}') 115 | 116 | train_time = time.time() - start_time 117 | print(f"\nTrain time: {train_time:.2f}") 118 | 119 | # In[Save model] 120 | model_folder = os.path.join("models", model_name) 121 | if not os.path.exists(model_folder): 122 | os.makedirs(model_folder) 123 | 124 | torch.save(G1.state_dict(), os.path.join(model_folder, "G1.pt")) 125 | torch.save(F_nl.state_dict(), os.path.join(model_folder, "F_nl.pt")) 126 | torch.save(G2.state_dict(), os.path.join(model_folder, "G2.pt")) 127 | 128 | 129 | # In[Simulate one more time] 130 | with torch.no_grad(): 131 | y_hat, y1_nl, y1_lin = model(u_fit_torch) 132 | 133 | # In[Detach] 134 | y_hat = y_hat.detach().numpy()[0, :, :] 135 | y1_lin = y1_lin.detach().numpy()[0, :, :] 
136 | y1_nl = y1_nl.detach().numpy()[0, :, :] 137 | 138 | # In[Plot] 139 | plt.figure() 140 | plt.plot(t_fit, y_fit, 'k', label="$y$") 141 | plt.plot(t_fit, y_hat, 'b', label="$\hat y$") 142 | plt.legend() 143 | 144 | # In[Plot loss] 145 | plt.figure() 146 | plt.plot(LOSS) 147 | plt.grid(True) 148 | 149 | # In[Plot static non-linearity] 150 | 151 | y1_lin_min = np.min(y1_lin) 152 | y1_lin_max = np.max(y1_lin) 153 | 154 | in_nl = np.arange(y1_lin_min, y1_lin_max, (y1_lin_max- y1_lin_min)/1000).astype(np.float32).reshape(-1, 1) 155 | 156 | with torch.no_grad(): 157 | out_nl = F_nl(torch.as_tensor(in_nl)) 158 | 159 | plt.figure() 160 | plt.plot(in_nl, out_nl, 'b') 161 | plt.plot(in_nl, out_nl, 'b') 162 | #plt.plot(y1_lin, y1_nl, 'b*') 163 | plt.xlabel('Static non-linearity input (-)') 164 | plt.ylabel('Static non-linearity input (-)') 165 | plt.grid(True) 166 | 167 | # In[Plot] 168 | e_rms = util.metrics.error_rmse(y_hat, y_fit)[0] 169 | print(f"RMSE: {e_rms:.2f}") # target: 1mv 170 | -------------------------------------------------------------------------------- /examples/WH2009/WH2009_train_quantized.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | import os 5 | from torchid_nb.module.lti import SisoLinearDynamicalOperator 6 | from torchid_nb.module.static import SisoStaticNonLinearity 7 | import matplotlib.pyplot as plt 8 | import time 9 | import util.metrics 10 | 11 | 12 | def normal_standard_cdf(val): 13 | """Returns the value of the cumulative distribution function for a standard normal variable""" 14 | return 1/2 * (1 + torch.erf(val/np.sqrt(2))) 15 | 16 | 17 | # In[Main] 18 | if __name__ == '__main__': 19 | 20 | # In[Set seed for reproducibility] 21 | np.random.seed(0) 22 | torch.manual_seed(0) 23 | 24 | # In[Settings] 25 | lr = 1e-4 26 | num_iter = 200000 27 | msg_freq = 100 28 | n_skip = 5000 29 | n_fit = 20000 30 | decimate = 1 31 | n_batch = 1 32 | n_b 
= 3 33 | n_a = 3 34 | 35 | meas_intervals = np.array([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0], dtype=np.float32) 36 | meas_intervals_full = np.r_[-1000, meas_intervals, 1000] 37 | 38 | model_name = "model_WH_digit" 39 | 40 | # In[Column names in the dataset] 41 | COL_F = ['fs'] 42 | COL_U = ['uBenchMark'] 43 | COL_Y = ['yBenchMark'] 44 | 45 | # In[Load dataset] 46 | df_X = pd.read_csv(os.path.join("data", "WienerHammerBenchmark.csv")) 47 | 48 | # Extract data 49 | y = np.array(df_X[COL_Y], dtype=np.float32) # batch, time, channel 50 | u = np.array(df_X[COL_U], dtype=np.float32) 51 | fs = np.array(df_X[COL_F].iloc[0], dtype=np.float32) 52 | N = y.size 53 | ts = 1/fs 54 | t = np.arange(N)*ts 55 | 56 | # In[Compute v signal] 57 | v = np.digitize(y, bins=meas_intervals) 58 | bins = meas_intervals_full[np.c_[v, v+1]] # bins of the measurement 59 | 60 | # In[Fit data] 61 | bins_fit = bins[0:n_fit:decimate, :] 62 | v_fit = v[0:n_fit:decimate] 63 | y_fit = y[0:n_fit:decimate] 64 | u_fit = u[0:n_fit:decimate] 65 | t_fit = t[0:n_fit:decimate] 66 | 67 | # In[Prepare training tensors] 68 | u_fit_torch = torch.tensor(u_fit[None, :, :], dtype=torch.float, requires_grad=False) 69 | bins_fit_torch = torch.tensor(bins_fit[None, :, :], dtype=torch.float, requires_grad=False) 70 | v_fit_torch = torch.tensor(v_fit[None, :, :], dtype=torch.float) 71 | 72 | # In[Prepare model] 73 | G1 = SisoLinearDynamicalOperator(n_b, n_a, n_k=1) 74 | F_nl = SisoStaticNonLinearity(n_hidden=10, activation='tanh') 75 | G2 = SisoLinearDynamicalOperator(n_b, n_a) 76 | 77 | log_sigma_hat = torch.tensor(np.log(1.0), requires_grad=True) # torch.randn(1, requires_grad = True) 78 | 79 | def model(u_in): 80 | y1_lin = G1(u_fit_torch) 81 | y1_nl = F_nl(y1_lin) 82 | y_hat = G2(y1_nl) 83 | return y_hat, y1_nl, y1_lin 84 | 85 | # In[Setup optimizer] 86 | optimizer = torch.optim.Adam([ 87 | {'params': G1.parameters(), 'lr': lr}, 88 | {'params': G2.parameters(), 'lr': lr}, 89 | {'params': F_nl.parameters(), 
'lr': lr}, 90 | {'params': log_sigma_hat, 'lr': 2e-5}, 91 | ], lr=lr) 92 | 93 | 94 | # In[Train] 95 | LOSS = [] 96 | SIGMA = [] 97 | start_time = time.time() 98 | #num_iter = 20 99 | for itr in range(0, num_iter): 100 | 101 | optimizer.zero_grad() 102 | 103 | sigma_hat = torch.exp(log_sigma_hat) 104 | y_hat, y1_nl, y1_lin = model(u_fit_torch) 105 | Phi_hat = normal_standard_cdf((bins_fit_torch - y_hat)/(sigma_hat + 1e-6)) 106 | y_Phi_hat = Phi_hat[..., [1]] - Phi_hat[..., [0]] 107 | y_hat_log = y_Phi_hat.log() 108 | loss_train = - y_hat_log.mean() 109 | 110 | LOSS.append(loss_train.item()) 111 | SIGMA.append(sigma_hat.item()) 112 | 113 | if itr % msg_freq == 0: 114 | with torch.no_grad(): 115 | pass 116 | #RMSE = torch.sqrt(loss_train) 117 | print(f'Iter {itr} | Fit Loss {loss_train:.5f} sigma_hat:{sigma_hat:.5f}') 118 | 119 | loss_train.backward() 120 | optimizer.step() 121 | 122 | train_time = time.time() - start_time 123 | print(f"\nTrain time: {train_time:.2f}") 124 | 125 | # In[Save model] 126 | model_folder = os.path.join("models", model_name) 127 | if not os.path.exists(model_folder): 128 | os.makedirs(model_folder) 129 | 130 | torch.save(G1.state_dict(), os.path.join(model_folder, "G1.pt")) 131 | torch.save(F_nl.state_dict(), os.path.join(model_folder, "F_nl.pt")) 132 | torch.save(G2.state_dict(), os.path.join(model_folder, "G2.pt")) 133 | 134 | 135 | # In[Simulate one more time] 136 | with torch.no_grad(): 137 | y_hat, y1_nl, y1_lin = model(u_fit_torch) 138 | 139 | # In[Detach] 140 | y_hat = y_hat.detach().numpy()[0, :, :] 141 | y1_lin = y1_lin.detach().numpy()[0, :, :] 142 | y1_nl = y1_nl.detach().numpy()[0, :, :] 143 | 144 | # In[Plot] 145 | plt.figure() 146 | plt.plot(t_fit, y_fit, 'k', label="$y$") 147 | plt.plot(t_fit, y_hat, 'b', label="$\hat y$") 148 | plt.legend() 149 | 150 | # In[Plot loss] 151 | plt.figure() 152 | plt.plot(LOSS) 153 | plt.grid(True) 154 | 155 | # In[Plot sigma] 156 | plt.figure() 157 | plt.plot(SIGMA) 158 | plt.grid(True) 159 | 
160 | # In[Plot static non-linearity] 161 | 162 | y1_lin_min = np.min(y1_lin) 163 | y1_lin_max = np.max(y1_lin) 164 | 165 | in_nl = np.arange(y1_lin_min, y1_lin_max, (y1_lin_max- y1_lin_min)/1000).astype(np.float32).reshape(-1, 1) 166 | 167 | with torch.no_grad(): 168 | out_nl = F_nl(torch.as_tensor(in_nl)) 169 | 170 | plt.figure() 171 | plt.plot(in_nl, out_nl, 'b') 172 | plt.plot(in_nl, out_nl, 'b') 173 | #plt.plot(y1_lin, y1_nl, 'b*') 174 | plt.xlabel('Static non-linearity input (-)') 175 | plt.ylabel('Static non-linearity input (-)') 176 | plt.grid(True) 177 | 178 | # In[Plot] 179 | e_rms = util.metrics.error_rmse(y_hat, y_fit)[0] 180 | print(f"RMSE: {e_rms:.2f}") # target: 1mv 181 | 182 | 183 | 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /examples/WH2009/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/examples/WH2009/__init__.py -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/examples/__init__.py -------------------------------------------------------------------------------- /fig/dynonet_quant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/fig/dynonet_quant.png -------------------------------------------------------------------------------- /fig/neural_PEM.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/fig/neural_PEM.png -------------------------------------------------------------------------------- /sphinx/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /sphinx/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /sphinx/source/code.rst: -------------------------------------------------------------------------------- 1 | dynoNet API 2 | ========================== 3 | --------------- 4 | LTI blocks 5 | --------------- 6 | 7 | .. automodule:: torchid.module.lti 8 | :members: 9 | :special-members: 10 | :member-order: bysource 11 | :exclude-members: __init__, forward 12 | 13 | --------------- 14 | Static blocks 15 | --------------- 16 | 17 | .. automodule:: torchid.module.static 18 | :members: 19 | :special-members: 20 | :member-order: bysource 21 | :exclude-members: __init__ -------------------------------------------------------------------------------- /sphinx/source/index.rst: -------------------------------------------------------------------------------- 1 | .. torchid documentation master file, created by 2 | sphinx-quickstart on Fri Apr 10 01:50:34 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to the dynoNet documentation! 7 | ===================================== 8 | 9 | ..
toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | code 14 | 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | -------------------------------------------------------------------------------- /torchid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid/__init__.py -------------------------------------------------------------------------------- /torchid/functional/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid/functional/__init__.py -------------------------------------------------------------------------------- /torchid/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid/module/__init__.py -------------------------------------------------------------------------------- /torchid/module/static.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MimoStaticNonLinearity(nn.Module): 6 | r"""Applies a Static MIMO non-linearity. 7 | The non-linearity is implemented as a feed-forward neural network. 8 | 9 | Args: 10 | in_channels (int): Number of input channels 11 | out_channels (int): Number of output channels 12 | n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20 13 | activation (str): Activation function. Either 'tanh', 'relu', or 'sigmoid'. 
Default: 'tanh' 14 | 15 | Shape: 16 | - Input: (..., in_channels) 17 | - Output: (..., out_channels) 18 | 19 | Examples:: 20 | 21 | >>> in_channels, out_channels = 2, 4 22 | >>> F = MimoStaticNonLinearity(in_channels, out_channels) 23 | >>> batch_size, seq_len = 32, 100 24 | >>> u_in = torch.ones((batch_size, seq_len, in_channels)) 25 | >>> y_out = F(u_in) # shape: (batch_size, seq_len, out_channels) 26 | """ 27 | 28 | def __init__(self, in_channels, out_channels, n_hidden=20, activation='tanh'): 29 | super(MimoStaticNonLinearity, self).__init__() 30 | 31 | activation_dict = {'tanh': nn.Tanh, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid} 32 | 33 | self.net = nn.Sequential( 34 | nn.Linear(in_channels, n_hidden), 35 | activation_dict[activation](), #nn.Tanh(), 36 | nn.Linear(n_hidden, out_channels) 37 | ) 38 | 39 | def forward(self, u_lin): 40 | y_nl = self.net(u_lin) 41 | return y_nl 42 | 43 | 44 | class SisoStaticNonLinearity(MimoStaticNonLinearity): 45 | r"""Applies a Static SISO non-linearity. 46 | The non-linearity is implemented as a feed-forward neural network. 47 | 48 | Args: 49 | n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20 50 | activation (str): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh' 51 | 52 | Shape: 53 | - Input: (..., 1) 54 | - Output: (..., 1) 55 | 56 | Examples:: 57 | 58 | >>> F = SisoStaticNonLinearity(n_hidden=20) 59 | >>> batch_size, seq_len = 32, 100 60 | >>> u_in = torch.ones((batch_size, seq_len, 1)) 61 | >>> y_out = F(u_in) # shape: (batch_size, seq_len, 1) 62 | """ 63 | def __init__(self, n_hidden=20, activation='tanh'): 64 | super(SisoStaticNonLinearity, self).__init__(in_channels=1, out_channels=1, n_hidden=n_hidden, activation=activation) 65 | 66 | 67 | class MimoChannelWiseNonLinearity(nn.Module): 68 | r"""Applies a Channel-wise non-linearity.
69 | The non-linearity is implemented as a set of feed-forward neural networks (each one operating on a different channel). Each network has a single hidden layer with ReLU activation. 70 | 71 | Args: 72 | channels (int): Number of both input and output channels 73 | n_hidden (int, optional): Number of nodes in the hidden layer of each network. Default: 10 74 | 75 | Shape: 76 | - Input: (..., channels) 77 | - Output: (..., channels) 78 | 79 | Examples:: 80 | 81 | >>> channels = 4 82 | >>> F = MimoChannelWiseNonLinearity(channels) 83 | >>> batch_size, seq_len = 32, 100 84 | >>> u_in = torch.ones((batch_size, seq_len, channels)) 85 | >>> y_out = F(u_in) # shape: (batch_size, seq_len, channels) 86 | 87 | """ 88 | 89 | def __init__(self, channels, n_hidden=10): 90 | super(MimoChannelWiseNonLinearity, self).__init__() 91 | 92 | self.net = nn.ModuleList() 93 | for channel_idx in range(channels): 94 | channel_net = nn.Sequential( 95 | nn.Linear(1, n_hidden), # one scalar input channel -> n_hidden units 96 | nn.ReLU(), 97 | nn.Linear(n_hidden, 1) 98 | ) 99 | self.net.append(channel_net) 100 | 101 | def forward(self, u_lin): 102 | 103 | y_nl = [] 104 | for channel_idx, u_channel in enumerate(u_lin.split(1, dim=-1)): # split over the last dimension (input channel) 105 | y_nl_channel = self.net[channel_idx](u_channel) # Process blocks individually 106 | y_nl.append(y_nl_channel) 107 | 108 | y_nl = torch.cat(y_nl, -1) # concatenate all output channels 109 | return y_nl 110 | 111 | 112 | if __name__ == '__main__': 113 | 114 | channels = 4 115 | nn1 = MimoChannelWiseNonLinearity(channels) 116 | in_data = torch.randn(100, 10, channels) 117 | xx = net_out = nn1(in_data) -------------------------------------------------------------------------------- /torchid_nb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid_nb/__init__.py
# --------------------------------------------------------------------------------
# /torchid_nb/functional/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid_nb/functional/__init__.py
# --------------------------------------------------------------------------------
# /torchid_nb/module/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/torchid_nb/module/__init__.py
# --------------------------------------------------------------------------------
# /torchid_nb/module/static.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn


# Activation names accepted by the static blocks of this module, mapped to their torch.nn classes.
_ACTIVATIONS = {'tanh': nn.Tanh, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid}


class MimoStaticNonLinearity(nn.Module):
    r"""Applies a static MIMO non-linearity.

    The non-linearity is implemented as a feed-forward neural network with one hidden layer.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20
        activation (str, optional): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh'

    Raises:
        KeyError: if ``activation`` is not one of the supported names.

    Shape:
        - Input: (..., in_channels)
        - Output: (..., out_channels)

    Examples::

        >>> in_channels, out_channels = 2, 4
        >>> F = MimoStaticNonLinearity(in_channels, out_channels)
        >>> batch_size, seq_len = 32, 100
        >>> u_in = torch.ones((batch_size, seq_len, in_channels))
        >>> y_out = F(u_in)  # shape: (batch_size, seq_len, out_channels)
    """

    def __init__(self, in_channels, out_channels, n_hidden=20, activation='tanh'):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(in_channels, n_hidden),
            _ACTIVATIONS[activation](),
            nn.Linear(n_hidden, out_channels),
        )

    def forward(self, u_lin):
        # nn.Linear acts on the last dimension, so any leading (batch, time, ...) dimensions pass through.
        y_nl = self.net(u_lin)
        return y_nl


class SisoStaticNonLinearity(MimoStaticNonLinearity):
    r"""Applies a static SISO non-linearity.

    The non-linearity is implemented as a feed-forward neural network with a single input
    and a single output channel.

    Args:
        n_hidden (int, optional): Number of nodes in the hidden layer. Default: 20
        activation (str, optional): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh'

    Shape:
        - Input: (..., 1)
        - Output: (..., 1)

    Examples::

        >>> F = SisoStaticNonLinearity(n_hidden=20)
        >>> batch_size, seq_len = 32, 100
        >>> u_in = torch.ones((batch_size, seq_len, 1))
        >>> y_out = F(u_in)  # shape: (batch_size, seq_len, 1)
    """

    def __init__(self, n_hidden=20, activation='tanh'):
        super().__init__(in_channels=1, out_channels=1, n_hidden=n_hidden, activation=activation)


class MimoChannelWiseNonLinearity(nn.Module):
    r"""Applies a channel-wise non-linearity.

    The non-linearity is implemented as a set of independent feed-forward neural networks,
    one per channel: network ``i`` maps input channel ``i`` to output channel ``i`` only.

    Args:
        channels (int): Number of both input and output channels
        n_hidden (int, optional): Number of nodes in the hidden layer of each network. Default: 10
        activation (str, optional): Activation function. Either 'tanh', 'relu', or 'sigmoid'. Default: 'tanh'

    Raises:
        KeyError: if ``activation`` is not one of the supported names.

    Shape:
        - Input: (..., channels)
        - Output: (..., channels)

    Examples::

        >>> channels = 4
        >>> F = MimoChannelWiseNonLinearity(channels)
        >>> batch_size, seq_len = 32, 100
        >>> u_in = torch.ones((batch_size, seq_len, channels))
        >>> y_out = F(u_in)  # shape: (batch_size, seq_len, channels)
    """

    def __init__(self, channels, n_hidden=10, activation='tanh'):
        super().__init__()

        activation_cls = _ACTIVATIONS[activation]  # look up once, so a bad name fails even for channels=0
        self.net = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1, n_hidden),  # one scalar input (a single channel)
                activation_cls(),
                nn.Linear(n_hidden, 1),  # one scalar output
            )
            for _ in range(channels)
        ])

    def forward(self, u_lin):
        # split(1, dim=-1) keeps the channel dimension, so each slice is (..., 1) as nn.Linear(1, .) expects.
        u_channels = u_lin.split(1, dim=-1)
        # Each channel goes through its own network. Indexing (rather than zip) keeps the original behavior
        # of raising IndexError when the input has more channels than networks.
        y_channels = [self.net[idx](u_ch) for idx, u_ch in enumerate(u_channels)]
        return torch.cat(y_channels, dim=-1)  # re-assemble the output channels


if __name__ == '__main__':

    channels = 4
    nn1 = MimoChannelWiseNonLinearity(channels)
    in_data = torch.randn(100, 10, channels)
    net_out = nn1(in_data)
    print(net_out.shape)
# --------------------------------------------------------------------------------
# /util/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/forgi86/sysid-transfer-functions-pytorch/e761082a4f1780f88138769aba4124b3ed9fa9d2/util/__init__.py
# --------------------------------------------------------------------------------
# /util/metrics.py:
# --------------------------------------------------------------------------------
import numpy as np


def r_squared(y_true, y_pred, time_axis=0):
    """ Computes the R-squared index.

    The R-squared index is computed separately on each channel.

    Parameters
    ----------
    y_true : np.array
        Array of true values. It should be at least 2D.
    y_pred : np.array
        Array of predicted values. It must be compatible with y_true.
    time_axis : int
        Time axis. All other axes define separate channels.

    Returns
    -------
    r_squared_val : np.array
        Array of R-squared values, one per channel.
    """

    SSE = np.sum((y_pred - y_true) ** 2, axis=time_axis)
    # keepdims=True so the mean broadcasts against y_true whatever time_axis is
    y_mean = np.mean(y_true, axis=time_axis, keepdims=True)
    SST = np.sum((y_true - y_mean) ** 2, axis=time_axis)

    return 1.0 - SSE / SST


def error_rmse(y_true, y_pred, time_axis=0):
    """ Computes the Root Mean Square Error (RMSE).

    The RMSE index is computed separately on each channel.

    Parameters
    ----------
    y_true : np.array
        Array of true values. It should be at least 2D.
    y_pred : np.array
        Array of predicted values. It must be compatible with y_true.
    time_axis : int
        Time axis. All other axes define separate channels.

    Returns
    -------
    RMSE : np.array
        Array of RMSE values, one per channel.
    """

    MSE = np.mean((y_pred - y_true) ** 2, axis=time_axis)
    RMSE = np.sqrt(MSE)
    return RMSE


def error_mean(y_true, y_pred, time_axis=0):
    """ Computes the error mean value.

    The error mean (y_true - y_pred, averaged over time) is computed separately on each channel.

    Parameters
    ----------
    y_true : np.array
        Array of true values. It should be at least 2D.
    y_pred : np.array
        Array of predicted values. It must be compatible with y_true.
    time_axis : int
        Time axis. All other axes define separate channels.

    Returns
    -------
    e_mean : np.array
        Array of error means, one per channel.
    """

    e_mean = np.mean(y_true - y_pred, axis=time_axis)
    return e_mean


def error_mae(y_true, y_pred, time_axis=0):
    """ Computes the error Mean Absolute Value (MAE).

    The MAE index is computed separately on each channel.

    Parameters
    ----------
    y_true : np.array
        Array of true values. It should be at least 2D.
    y_pred : np.array
        Array of predicted values. It must be compatible with y_true.
    time_axis : int
        Time axis. All other axes define separate channels.

    Returns
    -------
    e_mean : np.array
        Array of error mean absolute values, one per channel.
    """

    e_mean = np.mean(np.abs(y_true - y_pred), axis=time_axis)
    return e_mean


def fit_index(y_true, y_pred, time_axis=0):
    """ Computes the per-channel fit index.

    The fit index is commonly used in System Identification. See the definition in the System Identification Toolbox
    or in the paper 'Nonlinear System Identification: A User-Oriented Road Map',
    https://arxiv.org/abs/1902.00683, page 31.
    The fit index is computed separately on each channel as 100 * (1 - ||y - y_pred|| / ||y - mean(y)||).

    Parameters
    ----------
    y_true : np.array
        Array of true values. It should be at least 2D.
    y_pred : np.array
        Array of predicted values. It must be compatible with y_true.
    time_axis : int
        Time axis. All other axes define separate channels.

    Returns
    -------
    fit_val : np.array
        Array of fit index values (in percent), one per channel.
    """

    err_norm = np.linalg.norm(y_true - y_pred, axis=time_axis, ord=2)  # || y - y_pred ||
    # keepdims=True so the mean broadcasts against y_true for any time_axis (not only 0)
    y_mean = np.mean(y_true, axis=time_axis, keepdims=True)
    # axis=time_axis gives a per-channel vector 2-norm; without it a 2D input would get the
    # matrix spectral norm of all channels at once
    err_mean_norm = np.linalg.norm(y_true - y_mean, axis=time_axis, ord=2)  # || y - y_mean ||
    fit_val = 100 * (1 - err_norm / err_mean_norm)

    return fit_val


if __name__ == '__main__':
    N = 20
    ny = 2
    SNR = 10
    y_true = SNR * np.random.randn(N, ny)
    y_pred = np.copy(y_true) + np.random.randn(N, ny)
    err_rmse_val = error_rmse(y_true, y_pred)
    r_squared_val = r_squared(y_true, y_pred)
    fit_val = fit_index(y_true, y_pred)

    print(f"RMSE: {err_rmse_val}")
    print(f"R-squared: {r_squared_val}")
    print(f"fit index: {fit_val}")
# --------------------------------------------------------------------------------