├── .gitignore
├── LICENSE
├── README.md
├── demos
├── __init__.py
├── basics
│ ├── kalman-filter.ipynb
│ ├── lofi-diagonal.ipynb
│ └── subspace-neural-net.ipynb
├── collas
│ ├── README.md
│ ├── __init__.py
│ ├── classification
│ │ ├── __init__.py
│ │ ├── configs
│ │ │ ├── permuted
│ │ │ │ ├── cifar10
│ │ │ │ │ └── mlp
│ │ │ │ │ │ └── nll
│ │ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ │ ├── fdekf.json
│ │ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ │ └── vdekf.json
│ │ │ │ └── fashion_mnist
│ │ │ │ │ └── mlp
│ │ │ │ │ └── nll
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── lofi-20.json
│ │ │ │ │ ├── lofi-5.json
│ │ │ │ │ ├── lofi-50.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ ├── rotated
│ │ │ │ └── fashion_mnist
│ │ │ │ │ └── mlp
│ │ │ │ │ └── nll
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ ├── split
│ │ │ │ └── fashion_mnist
│ │ │ │ │ └── mlp
│ │ │ │ │ └── nll
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ └── stationary
│ │ │ │ ├── cifar10
│ │ │ │ └── mlp
│ │ │ │ │ └── nll
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── lofi-2.json
│ │ │ │ │ ├── lofi-20.json
│ │ │ │ │ ├── lofi-5.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ │ ├── cifar10_10k
│ │ │ │ ├── cnn
│ │ │ │ │ └── nll
│ │ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ │ ├── fdekf.json
│ │ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ │ └── vdekf.json
│ │ │ │ ├── dnn
│ │ │ │ │ └── nll
│ │ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ │ ├── fdekf.json
│ │ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ │ └── vdekf.json
│ │ │ │ └── mlp
│ │ │ │ │ └── nll
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ │ └── fashion_mnist
│ │ │ │ ├── cnn
│ │ │ │ └── nll
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── lofi-20.json
│ │ │ │ │ ├── lofi-5.json
│ │ │ │ │ ├── lofi-50.json
│ │ │ │ │ ├── lofi-sph-1.json
│ │ │ │ │ ├── lofi-sph-10.json
│ │ │ │ │ ├── lofi-sph-20.json
│ │ │ │ │ ├── lofi-sph-5.json
│ │ │ │ │ ├── lofi-sph-50.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ │ └── mlp
│ │ │ │ ├── nll
│ │ │ │ ├── adam-rb-1.json
│ │ │ │ ├── adam-rb-10.json
│ │ │ │ ├── fdekf-it.json
│ │ │ │ ├── fdekf.json
│ │ │ │ ├── lofi-1.json
│ │ │ │ ├── lofi-10.json
│ │ │ │ ├── sgd-rb-1.json
│ │ │ │ ├── sgd-rb-10.json
│ │ │ │ ├── vdekf-it.json
│ │ │ │ └── vdekf.json
│ │ │ │ └── nlpd-linearized
│ │ │ │ ├── adam-rb-1.json
│ │ │ │ ├── adam-rb-10.json
│ │ │ │ ├── fdekf.json
│ │ │ │ ├── lofi-1.json
│ │ │ │ ├── lofi-10.json
│ │ │ │ ├── sgd-rb-1.json
│ │ │ │ ├── sgd-rb-10.json
│ │ │ │ └── vdekf.json
│ │ ├── outputs
│ │ │ ├── generate_cifar10_plots.ipynb
│ │ │ ├── generate_permuted_clf_plots.ipynb
│ │ │ ├── generate_rotated_clf_plots.ipynb
│ │ │ ├── generate_split_clf_plots.ipynb
│ │ │ └── generate_stationary_clf_plots.ipynb
│ │ ├── probe.ipynb
│ │ └── spectral_decay.ipynb
│ ├── datasets
│ │ ├── __init__.py
│ │ └── dataloaders.py
│ ├── hparam_tune.py
│ ├── regression
│ │ ├── __init__.py
│ │ ├── configs
│ │ │ ├── iid
│ │ │ │ └── fashion_mnist
│ │ │ │ │ └── mlp
│ │ │ │ │ ├── nll
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ │ │ ├── nlpd-linearized
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ │ │ └── nlpd-mc
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1-temp-0.01.json
│ │ │ │ │ ├── lofi-1-temp-0.05.json
│ │ │ │ │ ├── lofi-1-temp-0.1.json
│ │ │ │ │ ├── lofi-1-temp-0.5.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ ├── permuted
│ │ │ │ └── fashion_mnist
│ │ │ │ │ └── mlp
│ │ │ │ │ └── nll
│ │ │ │ │ ├── adam-rb-1.json
│ │ │ │ │ ├── adam-rb-10.json
│ │ │ │ │ ├── fdekf.json
│ │ │ │ │ ├── lofi-1.json
│ │ │ │ │ ├── lofi-10.json
│ │ │ │ │ ├── sgd-rb-1.json
│ │ │ │ │ ├── sgd-rb-10.json
│ │ │ │ │ └── vdekf.json
│ │ │ └── random-walk
│ │ │ │ └── fashion_mnist
│ │ │ │ └── mlp
│ │ │ │ └── nll
│ │ │ │ ├── adam-rb-1.json
│ │ │ │ ├── adam-rb-10.json
│ │ │ │ ├── fdekf.json
│ │ │ │ ├── lofi-1.json
│ │ │ │ ├── lofi-10.json
│ │ │ │ ├── sgd-rb-1.json
│ │ │ │ ├── sgd-rb-10.json
│ │ │ │ └── vdekf.json
│ │ ├── outputs
│ │ │ ├── generate_iid_reg_plots.ipynb
│ │ │ ├── generate_permuted_reg_plots.ipynb
│ │ │ └── generate_rw_reg_plots.ipynb
│ │ ├── probe.ipynb
│ │ ├── rotating-mnist-gradual.ipynb
│ │ ├── rotating-mnist-unsorted.ipynb
│ │ └── rotating_mnist_unsorted.py
│ ├── run_classification_experiments.py
│ ├── run_regression_experiments.py
│ └── train_utils.py
├── dekf_demos
│ ├── __init__.py
│ ├── diagonal_demo.ipynb
│ ├── half_moons_demo.ipynb
│ ├── in_between_uncertainty.ipynb
│ ├── mnist_evaluation.ipynb
│ └── spiral_classif_demo.py
├── figures
│ ├── regression_plot_1d_lofi.pdf
│ ├── regression_plot_1d_lofi.png
│ ├── regression_plot_1d_vcl.pdf
│ ├── regression_plot_1d_vcl.png
│ ├── rmse_vs_time.pdf
│ ├── rmse_vs_time.png
│ ├── rmse_vs_time_log.pdf
│ ├── rmse_vs_time_log.png
│ ├── rmse_vs_time_vdekf.pdf
│ └── rmse_vs_time_vdekf.png
├── gradually-rotating
│ ├── cfg_main.py
│ ├── damped-rotating-mnist-classification.ipynb
│ ├── damped-rotating-mnist.ipynb
│ ├── eval-clf-lofi-rsgd.ipynb
│ ├── eval-gradually-increasing-angle.ipynb
│ ├── eval-lofi-hparam-clf.ipynb
│ ├── run_clf_lofi_rsgd.py
│ ├── run_gradually_increasing_angle.py
│ ├── run_lofi_hparam_clf.py
│ └── run_main.py
├── misc
│ ├── BBB-1d-regression.ipynb
│ ├── BBB-rmnist.ipynb
│ ├── VCL-rmnist.ipynb
│ ├── bnn-incremental-hmc.ipynb
│ ├── bnn-rmnist-hmc-trenches.ipynb
│ ├── ensemble-binned-hetdnn-rmnist-trenches.ipynb
│ ├── ensemble-dnn-rmnist-trenches.ipynb
│ ├── ensemble-hetdnn-rmnist-trenches.ipynb
│ ├── l-rvga-flax.ipynb
│ ├── l-rvga-linreg.ipynb
│ ├── lofi-diagonal-fourier.ipynb
│ ├── lofi-v-subspace-classification-changing.ipynb
│ ├── lofi-v-subspace-classification.ipynb
│ ├── mekf-mlofi-nlinreg.ipynb
│ ├── nlm_1d_regression.ipynb
│ ├── nonstat-1d-regression.ipynb
│ ├── online-laplace-rb.ipynb
│ ├── r-vga-lr.ipynb
│ ├── rotating-moons-dataset.ipynb
│ ├── sgd-replay-mnist.ipynb
│ ├── sgd-replay-rotating-mnist.ipynb
│ ├── slowly-changing-rotation.ipynb
│ ├── tyxe_rebayes.ipynb
│ └── xp-lrvga-linear-regression.ipynb
├── nonstat-1d-regression.ipynb
├── param_initialization
│ └── zero_lecun_demo.ipynb
├── showdown
│ ├── __init__.py
│ ├── build_uci_method_df.py
│ ├── classification.ipynb
│ ├── classification
│ │ ├── __init__.py
│ │ ├── ablation_inflation_types.py
│ │ ├── classification_train.py
│ │ ├── hparam_tune_clf.py
│ │ ├── permuted-mnist-classification-experiment.ipynb
│ │ └── permuted-mnist-classification.ipynb
│ ├── linreg-comparison.ipynb
│ ├── logistic-regression.ipynb
│ ├── mnist.ipynb
│ ├── nonstationary
│ │ ├── __init__.py
│ │ ├── hparam_tune_clf.py
│ │ ├── r-moons-classification.ipynb
│ │ └── split-mnist-classification.ipynb
│ ├── plots-xval-passes.ipynb
│ ├── plots-xval.ipynb
│ ├── regression-rotating-mnist.ipynb
│ ├── regression
│ │ ├── __init__.py
│ │ ├── hparam_tune_clf.py
│ │ ├── hparam_tune_ekf.py
│ │ ├── hparam_tune_lofi.py
│ │ ├── hparam_tune_sgd.py
│ │ ├── peter
│ │ │ ├── __init__.py
│ │ │ └── hparam_tune_reg.py
│ │ ├── regression-uci.ipynb
│ │ ├── regression_train.py
│ │ └── rotating_mnist.py
│ ├── split-mnist.ipynb
│ └── time-analysis.ipynb
├── stitching
│ ├── __init__.py
│ ├── bayes_glue_demo.ipynb
│ ├── configs
│ │ ├── mlp_[50, 50, 50]_n_train_200_fcekf.json
│ │ └── mlp_[50, 50, 50]_n_train_200_sgd.json
│ ├── stitch_reg.py
│ └── stitching_demo.ipynb
└── testbed.ipynb
├── misc
├── README.md
├── kfac.ipynb
├── nando-efk-mlp
│ ├── data.txt
│ ├── ekfdemo1.m
│ ├── mlpekf.m
│ └── mlpekfQ.m
├── rotated_digits_load.py
├── rotated_digits_matlab.mat
└── torch_scan.py
├── rebayes
├── __init__.py
├── base.py
├── datasets
│ ├── __init__.py
│ ├── classification_data.py
│ ├── data_utils.py
│ ├── datasets.py
│ ├── moons_data.py
│ ├── nonstat_1d_data.py
│ ├── rotating_mnist_data.py
│ ├── rotating_permuted_mnist_data.py
│ ├── uci_regression_data.py
│ └── uci_uncertainty_data.py
├── deprecated
│ ├── __init__.py
│ ├── ekf_old.py
│ ├── linreg_demo.py
│ ├── old_base.py
│ ├── optax_optimizer.py
│ ├── optimizer.py
│ └── simple_base.py
├── dual_base.py
├── extended_kalman_filter
│ ├── __init__.py
│ ├── demos
│ │ ├── dual_ekf_demo.ipynb
│ │ ├── nfekf_demo.ipynb
│ │ ├── ocl_demo.ipynb
│ │ └── sw_ekf_demo.ipynb
│ ├── dual_ekf.py
│ ├── ekf.py
│ ├── ekf_core.py
│ ├── enkf.py
│ ├── replay_ekf.py
│ ├── sw_ekf.py
│ ├── test_dual_ekf.py
│ ├── test_ekf.py
│ └── viking.py
├── ivon
│ ├── __init__.py
│ └── ivon.py
├── linear_filter
│ └── kf.py
├── low_rank_filter
│ ├── __init__.py
│ ├── cold_posterior_lofi.py
│ ├── demos
│ │ ├── LOFIScoreMatching.ipynb
│ │ ├── generalized_orfit_mnist.ipynb
│ │ └── orfit_paper_reproduction.ipynb
│ ├── dual_lofi.py
│ ├── ggt.py
│ ├── lofi.py
│ ├── lofi_core.py
│ ├── lrvga.py
│ ├── orfit.py
│ ├── replay_lofi.py
│ ├── slang.py
│ ├── subspace_filter.py
│ └── test_dual_lofi.py
├── mcmc_filter
│ └── hamiltonian_monte_carlo.py
├── sgd_filter
│ ├── __init__.py
│ ├── replay_sgd.py
│ └── sgd.py
├── utils
│ ├── __init__.py
│ ├── callbacks.py
│ ├── models.py
│ ├── normalizing_flows.py
│ ├── preprocessing.py
│ ├── sampling.py
│ ├── split_mnist_data_test.py
│ └── utils.py
└── vi_filter
│ ├── bayes_by_backprop.py
│ └── variational_continual_learning.py
├── setup.py
└── tests
├── __init__.py
├── test_base.py
├── test_base_dl.py
├── test_dual_base.py
├── test_lofi.py
└── test_orfit.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | docs/_build/
3 | build/
4 | dist/
5 | 
6 | *egg-info
7 | *.ipynb_checkpoints
8 | *.pyc
9 | *.pkl
10 | *-ubyte*
11 | *.log
12 | *.mp4
13 | 
14 | # ignore figures unless manually added
15 | *.png
16 | *.jpg
17 | *.pdf
18 | *-dot
19 | *.DS_Store
20 | .vscode/
21 | .venv
22 | */data/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Probabilistic machine learning
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReBayes = Recursive Bayesian inference for latent states
2 | 
3 | 📝 Paper: [Low-rank extended Kalman filtering for online learning of neural networks from streaming data](https://arxiv.org/abs/2305.19535)
4 | 
5 | We provide code for online (recursive) Bayesian inference in state-space models;
6 | in contrast to [dynamax](https://github.com/probml/dynamax), we do not assume the entire observation sequence is available in advance,
7 | so the ReBayes API can be used in an interactive loop (e.g., for Bayesian optimization).
8 | We assume the dynamics model is linear-Gaussian (or constant),
9 | but the observation model can be non-linear and non-Gaussian.
10 | 
11 | This is work in progress; a stable version will be released in late Spring 2023.
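A minimal sketch of the recursive setting, using online Bayesian linear regression (the special case where the Kalman filter is exact). This is illustrative only and does not use the rebayes API itself; see `rebayes/base.py` and the notebooks under `demos/` for the actual interface.

```python
# Online Bayesian linear regression: the simplest instance of the recursive
# setting that ReBayes generalizes. Data arrives one observation at a time and
# the posterior is updated after each one -- no batch dataset is materialized.
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
w_true = jnp.array([1.0, -2.0])        # latent state (here: static weights)
mu, Sigma = jnp.zeros(2), jnp.eye(2)   # Gaussian prior over the latent state
obs_var = 0.1 ** 2

for t in range(100):
    key, kx, ky = jr.split(key, 3)
    x = jr.normal(kx, (2,))                 # input may be chosen interactively
    y = x @ w_true + 0.1 * jr.normal(ky)    # noisy scalar observation
    # Recursive Bayesian (Kalman) measurement update:
    s = x @ Sigma @ x + obs_var             # predictive variance of y
    k_gain = Sigma @ x / s                  # Kalman gain
    mu = mu + k_gain * (y - x @ mu)
    Sigma = Sigma - jnp.outer(k_gain, x @ Sigma)

print(mu)  # converges towards w_true
```

ReBayes generalizes this loop to non-linear, non-Gaussian observation models (e.g., neural networks) via extended Kalman filtering and low-rank (LoFi) posterior approximations.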
12 | 
13 | ![flipbook-lofi-fourier](https://user-images.githubusercontent.com/4108759/230786889-9fabdada-20d4-49fc-b9ee-c67d4db90d4b.png)
14 | 
--------------------------------------------------------------------------------
/demos/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/__init__.py
--------------------------------------------------------------------------------
/demos/collas/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/collas/README.md
--------------------------------------------------------------------------------
/demos/collas/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/demos/collas/classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/collas/classification/__init__.py
--------------------------------------------------------------------------------
/demos/collas/classification/configs/permuted/cifar10/mlp/nll/adam-rb-1.json:
--------------------------------------------------------------------------------
1 | {"learning_rate": 4.539993096841499e-05}
--------------------------------------------------------------------------------
/demos/collas/classification/configs/permuted/cifar10/mlp/nll/adam-rb-10.json:
--------------------------------------------------------------------------------
1 | {"learning_rate": 7.184404967119917e-05}
--------------------------------------------------------------------------------
/demos/collas/classification/configs/permuted/cifar10/mlp/nll/fdekf.json:
--------------------------------------------------------------------------------
1 | {"initial_covariance": 4.539993096841499e-05, "dynamics_covariance": 5.346815214579692e-06, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights_or_function": 0.9999998274234372}
--------------------------------------------------------------------------------
/demos/collas/classification/configs/permuted/cifar10/mlp/nll/lofi-1.json:
--------------------------------------------------------------------------------
1 | {"initial_covariance": 4.539993096841499e-05, "dynamics_covariance": 4.2663443309720606e-05, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 0.9999104929956957}
--------------------------------------------------------------------------------
/demos/collas/classification/configs/permuted/cifar10/mlp/nll/lofi-10.json:
--------------------------------------------------------------------------------
1 | {"initial_covariance": 0.1045946329832077, "dynamics_covariance": 2.6182844781175163e-10, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 0.9999994108294459}
--------------------------------------------------------------------------------
/demos/collas/classification/configs/permuted/cifar10/mlp/nll/sgd-rb-1.json:
--------------------------------------------------------------------------------
1 | {"learning_rate": 0.0021004891023039818}
--------------------------------------------------------------------------------
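The `configs/` tree above stores one JSON file of tuned hyperparameters per (task, dataset, model, metric, method) combination; the directory level `nll` / `nlpd-linearized` / `nlpd-mc` presumably names the objective the hyperparameters were tuned against (negative log-likelihood vs. negative log predictive density, linearized or Monte Carlo). A hedged sketch of how such a file could be read back (the `load_config` helper below is hypothetical; the actual plumbing lives in `demos/collas/hparam_tune.py` and the `run_*_experiments.py` scripts and may differ):

```python
# Hypothetical loader for the tuned-hyperparameter JSONs in this directory;
# the real experiment scripts may organize this differently.
import json
from pathlib import Path

CONFIG_ROOT = Path("demos/collas/classification/configs")

def load_config(task: str, dataset: str, model: str, metric: str, method: str) -> dict:
    """Example: load_config("permuted", "cifar10", "mlp", "nll", "lofi-10")."""
    path = CONFIG_ROOT / task / dataset / model / metric / f"{method}.json"
    with open(path) as f:
        return json.load(f)

hp = load_config("permuted", "cifar10", "mlp", "nll", "lofi-10")
# SGD/Adam replay-buffer baselines store only {"learning_rate": ...}, whereas
# the EKF/LoFi configs store state-space hyperparameters:
print(hp["initial_covariance"], hp["dynamics_covariance"],
      hp["dynamics_covariance_inflation_factor"], hp["dynamics_weights"])
```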
/demos/collas/classification/configs/permuted/cifar10/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 5.8919140428770334e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/cifar10/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 4.539993096841499e-05, "dynamics_covariance": 4.652425104723079e-06, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights_or_function": 0.9999994731969082} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0002056889934465289} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00021378239034675062} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801734983921, "dynamics_covariance": 2.383402488703723e-06, "dynamics_covariance_inflation_factor": 8.208662904962694e-08, "dynamics_weights_or_function": 0.9999999999998356} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801734983921, "dynamics_covariance": 2.383402488703723e-06, "dynamics_covariance_inflation_factor": 8.208662904962694e-08, "dynamics_weights": 0.9999999999998356} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801734983921, "dynamics_covariance": 2.383402488703723e-06, "dynamics_covariance_inflation_factor": 8.208662904962694e-08, "dynamics_weights": 0.9999999999998356} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/lofi-20.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.010554869659245014, "dynamics_covariance": 6.675282747892197e-06, "dynamics_covariance_inflation_factor": 0.00019464772776700556, "dynamics_weights": 0.9999986769818179} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/lofi-5.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.020954398438334465, "dynamics_covariance": 1.587801989444415e-06, "dynamics_covariance_inflation_factor": 1.188868009194266e-05, "dynamics_weights": 0.9999981887524427} -------------------------------------------------------------------------------- 
/demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/lofi-50.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801548719406, "dynamics_covariance": 1.6442786531217085e-13, "dynamics_covariance_inflation_factor": 2.3834011244616704e-06, "dynamics_weights": 0.999999917913442} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.010977965779602528} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.003972665872424841} -------------------------------------------------------------------------------- /demos/collas/classification/configs/permuted/fashion_mnist/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801734983921, "dynamics_covariance": 2.383402488703723e-06, "dynamics_covariance_inflation_factor": 8.208662904962694e-08, "dynamics_weights_or_function": 0.9999999999998356} -------------------------------------------------------------------------------- /demos/collas/classification/configs/rotated/fashion_mnist/mlp/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00023749712272547185} -------------------------------------------------------------------------------- /demos/collas/classification/configs/rotated/fashion_mnist/mlp/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00016051571583375335} -------------------------------------------------------------------------------- /demos/collas/classification/configs/rotated/fashion_mnist/mlp/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801734983921, "dynamics_covariance": 1.6442760781415489e-13, "dynamics_covariance_inflation_factor": 2.383402488703723e-06, "dynamics_weights_or_function": 0.999999917913371} -------------------------------------------------------------------------------- /demos/collas/classification/configs/rotated/fashion_mnist/mlp/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.035826001316308975, "dynamics_covariance": 8.595070727195064e-13, "dynamics_covariance_inflation_factor": 4.5447773544537995e-08, "dynamics_weights": 0.9999991010890881} -------------------------------------------------------------------------------- /demos/collas/classification/configs/rotated/fashion_mnist/mlp/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.0244983471930027, "dynamics_covariance": 2.157323798768207e-13, "dynamics_covariance_inflation_factor": 1.7873260276246583e-06, "dynamics_weights": 0.9999999235599333} -------------------------------------------------------------------------------- /demos/collas/classification/configs/rotated/fashion_mnist/mlp/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | 
{"learning_rate": 0.010977965779602528} -------------------------------------------------------------------------------- /demos/collas/classification/configs/rotated/fashion_mnist/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0060435328632593155} -------------------------------------------------------------------------------- /demos/collas/classification/configs/rotated/fashion_mnist/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801734983921, "dynamics_covariance": 1.6442760781415489e-13, "dynamics_covariance_inflation_factor": 2.383402488703723e-06, "dynamics_weights_or_function": 0.999999917913371} -------------------------------------------------------------------------------- /demos/collas/classification/configs/split/fashion_mnist/mlp/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00013753461826127023} -------------------------------------------------------------------------------- /demos/collas/classification/configs/split/fashion_mnist/mlp/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 5.557282565860078e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/split/fashion_mnist/mlp/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801734983921, "dynamics_covariance": 1.6442760781415489e-13, "dynamics_covariance_inflation_factor": 2.383402488703723e-06, "dynamics_weights_or_function": 0.999999917913371} -------------------------------------------------------------------------------- /demos/collas/classification/configs/split/fashion_mnist/mlp/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.037291429936885834, "dynamics_covariance": 7.683067815378308e-05, "dynamics_covariance_inflation_factor": 9.597619872242155e-14, "dynamics_weights": 0.9999999998978196} -------------------------------------------------------------------------------- /demos/collas/classification/configs/split/fashion_mnist/mlp/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.13974709808826447, "dynamics_covariance": 8.463967374194858e-12, "dynamics_covariance_inflation_factor": 2.677482484614302e-07, "dynamics_weights": 0.9999998583443812} -------------------------------------------------------------------------------- /demos/collas/classification/configs/split/fashion_mnist/mlp/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.01882750913500786} -------------------------------------------------------------------------------- /demos/collas/classification/configs/split/fashion_mnist/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.01598750613629818} -------------------------------------------------------------------------------- /demos/collas/classification/configs/split/fashion_mnist/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801734983921, 
"dynamics_covariance": 1.6442760781415489e-13, "dynamics_covariance_inflation_factor": 2.383402488703723e-06, "dynamics_weights_or_function": 0.999999917913371} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 4.539993096841499e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 4.539993096841499e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.002340520266443491, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.002816260326653719, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.004929000046104193, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/lofi-2.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.002388906432315707, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 0.0, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/lofi-20.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.004712103866040707, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/lofi-5.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.002952559618279338, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00148484215606004} 
-------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 9.882110316539183e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.0018995467107743025, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/cnn/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00015973432164173573} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/cnn/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.041510991752147675, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/cnn/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.05918506905436516, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/cnn/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.041510991752147675, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/cnn/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.010977965779602528} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/cnn/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 4.539993096841499e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/cnn/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.041510991752147675, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/dnn/nll/adam-rb-1.json: 
-------------------------------------------------------------------------------- 1 | {"learning_rate": 4.539993096841499e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/dnn/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 4.539993096841499e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/dnn/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.010575706139206886, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/dnn/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.008839590474963188, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/dnn/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.008973372168838978, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/dnn/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.010554869659245014} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/dnn/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00019998291099909693} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/dnn/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.005629824474453926, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 8.194008692231508e-40, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/mlp/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 4.539986184681766e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/mlp/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 4.539986184681766e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/mlp/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | 
{"initial_covariance": 0.0018242523074150085, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 0.0, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/mlp/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.0018427588511258364, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 0.0, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/mlp/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.003770754672586918, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 0.0, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/mlp/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0016171499155461788} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00011783037916757166} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/cifar10_10k/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.0019540241919457912, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 0.0, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0012686057016253471} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0008061544504016638} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.10459447652101517, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 2.3644252777899055e-08, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.15452848374843597, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 3.945350712797335e-08, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-10.json: 
-------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.10459447652101517, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 2.3644252777899055e-08, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-20.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.1045946329832077, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 2.364432916124315e-08, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-5.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.1045946329832077, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 2.364432916124315e-08, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-50.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.04151087626814842, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 1.0211630979028996e-05, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-sph-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.06014932692050934, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 3.5177106383343926e-06, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-sph-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.029274344444274902, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 0.009429787285625935, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-sph-20.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.041510991752147675, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 1.0211657354375347e-05, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-sph-5.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.047011375427246094, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 1.6782491002231836e-05, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/lofi-sph-50.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.04151087626814842, 
"dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 1.0211630979028996e-05, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.040685247629880905} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.018828852102160454} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/cnn/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.10459447652101517, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 2.3644252777899055e-08, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0002499785041436553} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00013749791833106428} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/fdekf-it.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 4.539993096841499e-05, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 0.0, "dynamics_weights_or_function": 1.0, "learning_rate": 4.539993096841499e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801548719406, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 2.3834011244616704e-06, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.020097902044653893, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 3.351170718701724e-08, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02162063680589199, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 0.0010188906453549862, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- 
/demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.01330578327178955} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0056780665181577206} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/vdekf-it.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 4.539993096841499e-05, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 0.0, "dynamics_weights_or_function": 1.0, "learning_rate": 4.539993096841499e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801548719406, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 2.3834011244616704e-06, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nlpd-linearized/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00029352190904319286, "initial_covariance": 0.0005571987712755799} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nlpd-linearized/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.00010850580292753875, "initial_covariance": 9.237635822501034e-05} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nlpd-linearized/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.008995318785309792, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 9.267486333897068e-09, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nlpd-linearized/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.011571381241083145, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 3.8317190046655014e-05, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nlpd-linearized/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.01802695356309414, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 3.101140177985684e-10, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- 
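A note on the four hyperparameters that recur in the EKF/LoFi configs. The sketch below spells out the linear-Gaussian predict step they parameterize; it is consistent with the paper linked in the README, but the authoritative code is `rebayes/extended_kalman_filter/ekf_core.py` and `rebayes/low_rank_filter/lofi_core.py`, and the exact way the inflation factor enters is an assumption here.

```python
import jax.numpy as jnp

def predict_step(mean, cov_diag, dynamics_weights, dynamics_covariance,
                 inflation_factor):
    """Sketch of the predict step implied by the config fields.

    Dynamics model: theta_t = gamma * theta_{t-1} + N(0, q * I), with
    gamma = dynamics_weights and q = dynamics_covariance.
    """
    gamma, q = dynamics_weights, dynamics_covariance
    pred_mean = gamma * mean                  # shrink the mean when gamma < 1
    pred_cov = gamma ** 2 * cov_diag + q      # propagate and add process noise
    # Assumed form of covariance inflation (guards against overconfidence):
    pred_cov = pred_cov * (1.0 + inflation_factor)
    return pred_mean, pred_cov

# "initial_covariance" sets cov_diag at t = 0. In the stationary configs above,
# dynamics_weights = 1.0 and dynamics_covariance is (numerically) 0, i.e. a
# static-parameter prior, while the nonstationary permuted/rotated/split
# configs use gamma slightly below 1 and q > 0, gradually forgetting old data.
```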
/demos/collas/classification/configs/stationary/fashion_mnist/mlp/nlpd-linearized/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.028982624411582947, "initial_covariance": 0.00314025254920125} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nlpd-linearized/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.004897463135421276, "initial_covariance": 0.0026078808587044477} -------------------------------------------------------------------------------- /demos/collas/classification/configs/stationary/fashion_mnist/mlp/nlpd-linearized/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.008995318785309792, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 9.267486333897068e-09, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/collas/datasets/__init__.py -------------------------------------------------------------------------------- /demos/collas/regression/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0007870193221606314} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.011167056858539581} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.041510991752147675, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 1.0211657354375347e-05, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.15742580592632294, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 9.172334766716084e-12, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.2725648880004883, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 0.006595255341380835, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- 
/demos/collas/regression/configs/iid/fashion_mnist/mlp/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.010554869659245014} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.028982624411582947} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.041510991752147675, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 1.0211657354375347e-05, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-linearized/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0021004891023039818, "initial_covariance": 0.6953274011611938} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-linearized/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.0021004891023039818, "initial_covariance": 0.6953274011611938} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-linearized/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.1045946329832077, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 2.364432916124315e-08, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-linearized/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.1045946329832077, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 2.364432916124315e-08, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-linearized/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.1457505226135254, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 9.521994002170686e-07, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-linearized/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.008995318785309792, "initial_covariance": 0.12458717823028564} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-linearized/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.03477505221962929, "initial_covariance": 0.05207415670156479} 
-------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-linearized/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.3388254940509796, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 2.434518319205381e-05, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.00015868346963543445, "learning_rate": 0.0013185807038098574} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.0026695493143051863, "learning_rate": 0.004370334558188915} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.004089664202183485, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 5.6327884522033855e-05, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/lofi-1-temp-0.01.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.06085390970110893, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 3.0561628250325157e-07, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/lofi-1-temp-0.05.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.028987614437937737, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 1.4494013989008181e-08, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/lofi-1-temp-0.1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.020122380927205086, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 5.7168941566487774e-05, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/lofi-1-temp-0.5.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.007819163613021374, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 6.24967151452438e-06, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.008293074555695057, 
"dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 1.357267024104658e-06, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.007017523515969515, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 3.900650540344941e-07, "dynamics_weights": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 4.610669566318393e-05, "learning_rate": 0.010105673223733902} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.00314025254920125, "learning_rate": 0.028982624411582947} -------------------------------------------------------------------------------- /demos/collas/regression/configs/iid/fashion_mnist/mlp/nlpd-mc/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.004089662339538336, "dynamics_covariance": 8.194008692231508e-40, "dynamics_covariance_inflation_factor": 5.632793909171596e-05, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/collas/regression/configs/permuted/fashion_mnist/mlp/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.006653312128037214, "initial_covariance": 0.0001603368582436815} -------------------------------------------------------------------------------- /demos/collas/regression/configs/permuted/fashion_mnist/mlp/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.006613842211663723, "initial_covariance": 0.0001595370122231543} -------------------------------------------------------------------------------- /demos/collas/regression/configs/permuted/fashion_mnist/mlp/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.036486055701971054, "dynamics_covariance": 9.379550901158357e-14, "dynamics_covariance_inflation_factor": 4.952033805238898e-07, "dynamics_weights_or_function": 0.9999999751830639} -------------------------------------------------------------------------------- /demos/collas/regression/configs/permuted/fashion_mnist/mlp/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.07452990859746933, "dynamics_covariance": 9.357622912219837e-14, "dynamics_covariance_inflation_factor": 5.306188973451531e-13, "dynamics_weights": 0.9999999999880127} -------------------------------------------------------------------------------- /demos/collas/regression/configs/permuted/fashion_mnist/mlp/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.2725648880004883, "dynamics_covariance": 0.001287185470573604, 
"dynamics_covariance_inflation_factor": 0.006595255341380835, "dynamics_weights": 0.9999999999998284} -------------------------------------------------------------------------------- /demos/collas/regression/configs/permuted/fashion_mnist/mlp/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.010554869659245014, "initial_covariance": 0.018828824162483215} -------------------------------------------------------------------------------- /demos/collas/regression/configs/permuted/fashion_mnist/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.028982624411582947, "initial_covariance": 0.00314025254920125} -------------------------------------------------------------------------------- /demos/collas/regression/configs/permuted/fashion_mnist/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801548719406, "dynamics_covariance": 1.6442786531217085e-13, "dynamics_covariance_inflation_factor": 2.3834011244616704e-06, "dynamics_weights_or_function": 0.999999917913442} -------------------------------------------------------------------------------- /demos/collas/regression/configs/random-walk/fashion_mnist/mlp/nll/adam-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.000859695952385664, "initial_covariance": 0.00019775258260779083} -------------------------------------------------------------------------------- /demos/collas/regression/configs/random-walk/fashion_mnist/mlp/nll/adam-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.008995318785309792, "initial_covariance": 0.12458717823028564} -------------------------------------------------------------------------------- /demos/collas/regression/configs/random-walk/fashion_mnist/mlp/nll/fdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.03655003011226654, "dynamics_covariance": 9.362049845215367e-14, "dynamics_covariance_inflation_factor": 4.954197834194929e-07, "dynamics_weights_or_function": 0.9999999751567845} -------------------------------------------------------------------------------- /demos/collas/regression/configs/random-walk/fashion_mnist/mlp/nll/lofi-1.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.1045946329832077, "dynamics_covariance": 2.6182844781175163e-10, "dynamics_covariance_inflation_factor": 2.364432916124315e-08, "dynamics_weights": 0.9999994108294459} -------------------------------------------------------------------------------- /demos/collas/regression/configs/random-walk/fashion_mnist/mlp/nll/lofi-10.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.21394824981689453, "dynamics_covariance": 1.4342944421128628e-13, "dynamics_covariance_inflation_factor": 9.9949432411095e-13, "dynamics_weights": 0.9999999999741871} -------------------------------------------------------------------------------- /demos/collas/regression/configs/random-walk/fashion_mnist/mlp/nll/sgd-rb-1.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.010554869659245014, "initial_covariance": 0.018828824162483215} 
-------------------------------------------------------------------------------- /demos/collas/regression/configs/random-walk/fashion_mnist/mlp/nll/sgd-rb-10.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.028982624411582947, "initial_covariance": 0.00314025254920125} -------------------------------------------------------------------------------- /demos/collas/regression/configs/random-walk/fashion_mnist/mlp/nll/vdekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.02184801734983921, "dynamics_covariance": 1.6442760781415489e-13, "dynamics_covariance_inflation_factor": 2.383402488703723e-06, "dynamics_weights_or_function": 0.999999917913371} -------------------------------------------------------------------------------- /demos/dekf_demos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/dekf_demos/__init__.py -------------------------------------------------------------------------------- /demos/figures/regression_plot_1d_lofi.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/regression_plot_1d_lofi.pdf -------------------------------------------------------------------------------- /demos/figures/regression_plot_1d_lofi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/regression_plot_1d_lofi.png -------------------------------------------------------------------------------- /demos/figures/regression_plot_1d_vcl.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/regression_plot_1d_vcl.pdf -------------------------------------------------------------------------------- /demos/figures/regression_plot_1d_vcl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/regression_plot_1d_vcl.png -------------------------------------------------------------------------------- /demos/figures/rmse_vs_time.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/rmse_vs_time.pdf -------------------------------------------------------------------------------- /demos/figures/rmse_vs_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/rmse_vs_time.png -------------------------------------------------------------------------------- /demos/figures/rmse_vs_time_log.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/rmse_vs_time_log.pdf -------------------------------------------------------------------------------- /demos/figures/rmse_vs_time_log.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/rmse_vs_time_log.png -------------------------------------------------------------------------------- /demos/figures/rmse_vs_time_vdekf.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/rmse_vs_time_vdekf.pdf -------------------------------------------------------------------------------- /demos/figures/rmse_vs_time_vdekf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/figures/rmse_vs_time_vdekf.png -------------------------------------------------------------------------------- /demos/gradually-rotating/cfg_main.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | def get_config(): 4 |     cfg = ml_collections.ConfigDict() 5 | 6 |     cfg.memory = 10 7 |     cfg.forecast = 10 8 | 9 |     # LoFi parameters 10 |     cfg.lofi = ml_collections.ConfigDict() 11 |     cfg.lofi.dynamics_weight = 1.0 12 |     cfg.lofi.dynamics_covariance = 1e-3 13 |     cfg.lofi.initial_covariance = 1.0 14 | 15 |     # Replay-buffer SGD parameters 16 |     cfg.rsgd = ml_collections.ConfigDict() 17 |     cfg.rsgd.learning_rate = 5e-4 18 |     cfg.rsgd.n_inner = 1 19 | 20 |     return cfg 21 | 22 | if __name__ == "__main__": 23 |     cfg = get_config() 24 |     print(cfg) 25 | -------------------------------------------------------------------------------- /demos/gradually-rotating/run_clf_lofi_rsgd.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import jax.numpy as jnp 4 | 5 | import run_main as ev 6 | from functools import partial 7 | from cfg_main import get_config 8 | from rebayes.utils.utils import tree_to_cpu 9 | from rebayes.utils.callbacks import cb_clf_sup 10 | from rebayes.sgd_filter import replay_sgd as rsgd 11 | from rebayes.datasets import rotating_mnist_data as data 12 | 13 | 14 | def damp_angle(n_configs, minangle, maxangle): 15 |     t = np.linspace(-0.5, 1.5, n_configs) 16 | 17 |     # angles = np.exp(t) * np.sin(55 * t)  # dead assignment: immediately overwritten below 18 |     angles = np.sin(55 * t) 19 | 20 |     angles = (angles + 1) / 2 * (maxangle - minangle) + minangle + np.random.randn(n_configs) * 2 21 | 22 |     return angles 23 | 24 | 25 | def categorise(labels): 26 |     """ 27 |     Labels is taken to be a list of ordinal numbers 28 |     """ 29 |     # One-hot-encoded 30 |     n_classes = max(labels) + 1 31 | 32 |     ohed = jax.nn.one_hot(labels, n_classes) 33 |     filter_columns = ~(ohed == 0).all(axis=0) 34 |     ohed = ohed[:, filter_columns] 35 |     return ohed 36 | 37 | 38 | def emission_cov_function(w, x, fn_mean): 39 |     """ 40 |     Compute the covariance matrix of the emission distribution.
41 |     fn_mean: emission mean function 42 |     """ 43 |     ps = fn_mean(w, x) 44 |     n_classes = len(ps) 45 |     I = jnp.eye(n_classes) 46 |     return jnp.diag(ps) - jnp.outer(ps, ps) + 1e-3 * I 47 | 48 | 49 | if __name__ == "__main__": 50 |     target_digits = [2, 3] 51 |     n_classes = len(target_digits) 52 |     num_train = 6_000 53 | 54 |     data = data.load_and_transform(damp_angle, target_digits, num_train, sort_by_angle=False) 55 |     X_train, signal_train, labels_train = data["dataset"]["train"] 56 |     X_test, signal_test, labels_test = data["dataset"]["test"] 57 |     Y_train = categorise(labels_train) 58 |     Y_test = categorise(labels_test) 59 | 60 |     cfg = get_config() 61 |     cfg.lofi.dynamics_covariance = 0.0 62 |     cfg.lofi.dynamics_weight = 1.0  # NB: the field defined in cfg_main (and read by load_lofi_agent) is `dynamics_weight` 63 |     cfg.lofi.initial_covariance = 0.1 64 | 65 |     _, dim_in = X_train.shape 66 |     model, tree_params, flat_params, recfn = ev.make_bnn_flax(dim_in, n_classes) 67 |     apply_fn_flat = partial(ev.apply_fn_flat, model=model, recfn=recfn) 68 |     apply_fn_tree = partial(ev.apply_fn_unflat, model=model) 69 |     def emission_mean_fn(w, x): return jax.nn.softmax(apply_fn_flat(w, x)) 70 |     emission_cov_fn = partial(emission_cov_function, fn_mean=emission_mean_fn) 71 | 72 | 73 |     _, dim_in = data["dataset"]["train"][0].shape 74 | 75 |     callback_part = partial(cb_clf_sup, 76 |                             X_test=X_train, y_test=Y_train, 77 |     ) 78 | 79 |     ### Lofi---load and train 80 |     agent = ev.load_lofi_agent(cfg, flat_params, emission_mean_fn, emission_cov_fn) 81 |     callback_lofi = partial(callback_part, recfn=recfn, apply_fn=apply_fn_flat) 82 |     bel, outputs_lofi = agent.scan(X_train, Y_train, progress_bar=True, callback=callback_lofi) 83 |     bel = jax.block_until_ready(bel) 84 |     outputs_lofi = tree_to_cpu(outputs_lofi) 85 | 86 |     ### RSGD---load and train 87 |     apply_fn = apply_fn_tree  # `model` is already bound in the partial above 88 |     callback_rsgd = partial(callback_part, recfn=lambda x: x, apply_fn=apply_fn) 89 |     agent = ev.load_rsgd_agent(cfg, tree_params, apply_fn, rsgd.lossfn_xentropy, dim_in, n_classes) 90 |     bel, outputs_rsgd = agent.scan(X_train, Y_train, progress_bar=True, callback=callback_rsgd) 91 |     bel = jax.block_until_ready(bel) 92 |     outputs_rsgd = tree_to_cpu(outputs_rsgd) 93 | -------------------------------------------------------------------------------- /demos/gradually-rotating/run_gradually_increasing_angle.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this script, we consider the gradually-rotating MNIST 3 | problem for regression. The angle of rotation is a growing 4 | sinusoid
5 | """ 6 | 7 | import jax 8 | import optax 9 | import numpy as np 10 | 11 | import run_main as ev 12 | from functools import partial 13 | from cfg_main import get_config 14 | from rebayes.utils.utils import tree_to_cpu 15 | from rebayes.utils.callbacks import cb_reg_sup 16 | from rebayes.sgd_filter import replay_sgd as rsgd 17 | from rebayes.datasets import rotating_mnist_data as data 18 | 19 | 20 | def damp_angle(n_configs, minangle, maxangle): 21 | t = np.linspace(0, 1.5, n_configs) 22 | angles = np.exp(t) * np.sin(35 * t) 23 | angles = (angles + 1) / 2 * (maxangle - minangle) + minangle + np.random.randn(n_configs) * 2 24 | return angles 25 | 26 | 27 | if __name__ == "__main__": 28 | cfg = get_config() 29 | 30 | num_train = None 31 | frac_train = 1.0 32 | target_digit = 2 33 | data = data.load_and_transform( 34 | damp_angle, target_digit, num_train, frac_train, sort_by_angle=False 35 | ) 36 | 37 | X_train, Y_train, labels_train = data["dataset"]["train"] 38 | 39 | 40 | # TODO: Refactor into LoFi regression 41 | dim_out = 1 42 | _, dim_in = X_train.shape 43 | model, tree_params, flat_params, recfn = ev.make_bnn_flax(dim_in, dim_out) 44 | apply_fn_flat = partial(ev.apply_fn_flat, model=model, recfn=recfn) 45 | apply_fn_tree = partial(ev.apply_fn_unflat, model=model) 46 | def emission_mean_fn(w, x): return apply_fn_flat(w, x) 47 | def emission_cov_fn(w, x): return 0.02 48 | 49 | ymean, ystd = data["ymean"], data["ystd"] 50 | callback = partial(cb_reg_sup, 51 | X_test=X_train, y_test=Y_train, 52 | ymean=ymean, ystd=ystd, 53 | ) 54 | 55 | callback_lofi = partial(callback, apply_fn=emission_mean_fn) 56 | callback_rsgd = partial(callback, apply_fn=apply_fn_tree) 57 | 58 | ### LoFi---load and train 59 | agent = ev.load_lofi_agent(cfg, flat_params, emission_mean_fn, emission_cov_fn) 60 | bel, output_lofi = agent.scan(X_train, Y_train, progress_bar=True, callback=callback_lofi) 61 | bel = jax.block_until_ready(bel) 62 | bel = tree_to_cpu(bel) 63 | output_lofi = tree_to_cpu(output_lofi) 64 | 65 | 66 | ### RSGD---load and train 67 | lr = 1e-2 68 | tx = optax.sgd(lr) 69 | agent = ev.load_rsgd_agent(cfg, tree_params, apply_fn_tree, rsgd.lossfn_rmse, dim_in, dim_out, tx=tx) 70 | bel, output_rsgd = agent.scan(X_train, Y_train, progress_bar=True, callback=callback_rsgd) 71 | bel = jax.block_until_ready(bel) 72 | output_rsgd = tree_to_cpu(output_rsgd) 73 | 74 | 75 | ### RSGD (ADAM)---load and train 76 | lr = 5e-3 77 | tx = optax.adam(lr) 78 | agent = ev.load_rsgd_agent(cfg, tree_params, apply_fn_tree, rsgd.lossfn_rmse, dim_in, dim_out, tx=tx) 79 | bel, output_rsgd_adam = agent.scan(X_train, Y_train, progress_bar=True, callback=callback_rsgd) 80 | bel = jax.block_until_ready(bel) 81 | output_rsgd_adam = tree_to_cpu(output_rsgd_adam) 82 | -------------------------------------------------------------------------------- /demos/gradually-rotating/run_lofi_hparam_clf.py: -------------------------------------------------------------------------------- 1 | """ 2 | We evaluate the performance of LoFi on the 3 | rotating MNIST dataset for classification 4 | with sinusoidal angle of rotation. 
5 | """ 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | import flax.linen as nn 11 | from copy import deepcopy 12 | from tqdm.auto import tqdm 13 | from itertools import product 14 | from functools import partial 15 | from cfg_main import get_config 16 | 17 | import run_main as ev 18 | from rebayes.utils.utils import tree_to_cpu 19 | 20 | 21 | def emission_cov_function(w, x, fn_mean): 22 | """ 23 | Compute the covariance matrix of the emission distribution. 24 | fn_mean: emission mean function 25 | """ 26 | ps = fn_mean(w, x) 27 | n_classes = len(ps) 28 | I = jnp.eye(n_classes) 29 | return jnp.diag(ps) - jnp.outer(ps, ps) + 1e-3 * I 30 | 31 | 32 | target_digits = [0, 1, 2, 3, 4, 5] 33 | n_classes = len(target_digits) 34 | num_train = 6_000 35 | 36 | data = ev.load_data(ev.damp_angle, target_digits, num_train, sort_by_angle=False) 37 | X_train, signal_train, labels_train = data["dataset"]["train"] 38 | X_test, signal_test, labels_test = data["dataset"]["test"] 39 | Y_train = ev.categorise(labels_train) 40 | Y_test = ev.categorise(labels_test) 41 | 42 | cfg = get_config() 43 | 44 | # TODO: Refactor into LoFi classification 45 | _, dim_in = X_train.shape 46 | model, tree_params, flat_params, recfn = ev.make_bnn_flax(dim_in, n_classes) 47 | apply_fn = partial(ev.apply_fn_flat, model=model, recfn=recfn) 48 | def emission_mean_fn(w, x): return nn.softmax(apply_fn(w, x)) 49 | emission_cov_fn = partial(emission_cov_function, fn_mean=emission_mean_fn) 50 | 51 | _, dim_in = data["dataset"]["train"][0].shape 52 | 53 | callback_lofi = partial(ev.callback, 54 | apply_fn=emission_mean_fn, 55 | X_test=X_train, y_test=Y_train, 56 | recfn=recfn, 57 | ) 58 | 59 | outputs_all = [] 60 | 61 | # eps = np.array([0, 1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4]) 62 | eps = np.array([0, 1e-6, 1e-5, 1e-4]) 63 | list_dynamics_weights = 1 - eps 64 | list_dynamics_covariance = eps.copy() 65 | 66 | elements = list(product(list_dynamics_weights, list_dynamics_covariance)) 67 | 68 | for dynamics_weight, dynamics_covariance in tqdm(elements): 69 | dynamics_weight = float(dynamics_weight) 70 | dynamics_covariance = float(dynamics_covariance) 71 | cfg_lofi = deepcopy(cfg) 72 | 73 | cfg_lofi.lofi.dynamics_weight = dynamics_weight 74 | cfg_lofi.lofi.dynamics_covariance = dynamics_covariance 75 | 76 | agent = ev.load_lofi_agent(cfg_lofi, flat_params, emission_mean_fn, emission_cov_fn) 77 | bel, output = agent.scan(X_train, Y_train, progress_bar=False, callback=callback_lofi) 78 | bel = jax.block_until_ready(bel) 79 | 80 | bel = tree_to_cpu(bel) 81 | output = tree_to_cpu(output) 82 | 83 | res = { 84 | "config": cfg_lofi, 85 | "outputs": output, 86 | "bel": bel, 87 | } 88 | 89 | outputs_all.append(res) 90 | -------------------------------------------------------------------------------- /demos/gradually-rotating/run_main.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this script, we consider the gradually-rotating 3 | MNIST problem for classification. We analyse the 4 | effect of the dynamics_weights (gamma) parameter 5 | and the dynamics_covariance (Q) parameter. 
6 | We take an inflation factor of 0.0 7 | """ 8 | 9 | import jax 10 | import optax 11 | import numpy as np 12 | import jax.numpy as jnp 13 | import flax.linen as nn 14 | from functools import partial 15 | from typing import Callable 16 | from jax.flatten_util import ravel_pytree 17 | 18 | from rebayes import base 19 | from rebayes.low_rank_filter import lofi 20 | from rebayes.sgd_filter import replay_sgd as rsgd 21 | 22 | 23 | class MLP(nn.Module): 24 |     n_out: int = 1 25 |     n_hidden: int = 100 26 |     activation: Callable = nn.elu 27 | 28 |     @nn.compact 29 |     def __call__(self, x): 30 |         x = nn.Dense(self.n_hidden)(x) 31 |         x = self.activation(x) 32 |         x = nn.Dense(self.n_hidden)(x) 33 |         x = self.activation(x) 34 |         x = nn.Dense(self.n_hidden)(x) 35 |         x = self.activation(x) 36 |         x = nn.Dense(self.n_out, name="last-layer")(x) 37 |         return x 38 | 39 | 40 | def make_bnn_flax(dim_in, dim_out, nhidden=50): 41 |     key = jax.random.PRNGKey(314) 42 |     model = MLP(dim_out, nhidden) 43 |     params = model.init(key, jnp.ones((1, dim_in))) 44 |     flat_params, recfn = ravel_pytree(params) 45 |     return model, params, flat_params, recfn 46 | 47 | 48 | def apply_fn_unflat(params, x, model): 49 |     return model.apply(params, x) 50 | 51 | 52 | def apply_fn_flat(flat_params, x, model, recfn): 53 |     return model.apply(recfn(flat_params), x) 54 | 55 | 56 | def load_lofi_agent( 57 |     cfg, 58 |     mean_init, 59 |     emission_mean_fn, 60 |     emission_cov_fn, 61 | ): 62 |     ssm_params = base.RebayesParams( 63 |         initial_mean=mean_init, 64 |         initial_covariance=cfg.lofi.initial_covariance, 65 |         dynamics_weights=cfg.lofi.dynamics_weight, 66 |         dynamics_covariance=cfg.lofi.dynamics_covariance, 67 |         emission_mean_function=emission_mean_fn, 68 |         emission_cov_function=emission_cov_fn, 69 |         dynamics_covariance_inflation_factor=0.0 70 |     ) 71 | 72 |     lofi_params = lofi.LoFiParams(memory_size=cfg.memory, steady_state=False, inflation="hybrid") 73 | 74 |     agent = lofi.RebayesLoFiDiagonal(ssm_params, lofi_params) 75 |     return agent 76 | 77 | 78 | def load_rsgd_agent( 79 |     cfg, 80 |     mean_init, 81 |     apply_fn, 82 |     lossfn, 83 |     dim_in, 84 |     dim_out, 85 |     tx=None 86 | ): 87 |     if tx is None: 88 |         tx = optax.adam(learning_rate=cfg.rsgd.learning_rate) 89 | 90 |     agent = rsgd.FifoSGD(lossfn, 91 |                          apply_fn=apply_fn, 92 |                          init_params=mean_init, 93 |                          tx=tx, 94 |                          buffer_size=cfg.memory, 95 |                          dim_features=dim_in, 96 |                          dim_output=dim_out, 97 |                          n_inner=cfg.rsgd.n_inner 98 |     ) 99 | 100 |     return agent 101 | -------------------------------------------------------------------------------- /demos/showdown/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /demos/showdown/build_uci_method_df.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import jax 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | from rebayes.datasets import uci_uncertainty_data 8 | 9 | def get_subtree(tree, key): 10 |     return jax.tree_map(lambda x: x[key], tree, is_leaf=lambda x: key in x) 11 | 12 | 13 | def extract_data(files, base_path): 14 |     regexp = re.compile("rank([0-9]+).pkl") 15 |     data_all = {} 16 |     for file in files: 17 |         m = regexp.findall(file) 18 |         if len(m) == 0: 19 |             # no rank suffix in the filename; fall back to the default rank 20 |             rank = 50 21 |         else: 22 |             rank = int(m[0]) 23 | 24 |         file_path = os.path.join(base_path, file) 25 |         with open(file_path, "rb") as f: 26 |             data = pickle.load(f) 27 |         data_all[rank] = data 28 |     return data_all 29
| 30 | 31 | 32 | def extract_filenames(dataset, base_path): 33 |     files = os.listdir(base_path) 34 |     files_target = [file for file in files if (dataset in file) and ("pass" not in file)] 35 |     return files_target 36 | 37 | 38 | def build_df_summary(data, dataset_name): 39 |     """ 40 |     Summary over the last-observed value 41 |     """ 42 |     agent_last = jax.tree_map(lambda x: x[:, -1], data) 43 |     df_summary = [] 44 |     for key in agent_last: 45 |         piece = pd.DataFrame(agent_last[key]) 46 | 47 |         if key != 1: 48 |             drop_cols = ["fdekf", "vdekf"] 49 |             piece = piece.drop(drop_cols, axis=1) 50 |         if key != 2: 51 |             drop_cols = ["fcekf"] 52 |             piece = piece.drop(drop_cols, axis=1) 53 | 54 | 55 |         piece = piece.melt() 56 |         piece["rank"] = key 57 |         df_summary.append(piece) 58 |     df_summary = pd.concat(df_summary).dropna(axis=0) 59 |     df_summary = df_summary.query("variable != 'lofi_orth'") 60 | 61 |     df_summary.loc[df_summary["variable"] == "fcekf", "rank"] = "full" 62 |     df_summary.loc[df_summary["variable"] == "fdekf", "rank"] = 0 63 |     df_summary.loc[df_summary["variable"] == "vdekf", "rank"] = 0 64 |     df_summary = df_summary.assign(dataset=dataset_name) 65 |     return df_summary 66 | 67 | 68 | if __name__ == "__main__": 69 |     path = "./output/cross-validation" 70 |     dataset_path = "/home/gerardoduran/documents/external/DropoutUncertaintyExps/UCI_Datasets" 71 |     all_files = os.listdir(path) 72 |     datasets = list(set([f.split("_")[0].split(".")[0] for f in all_files])) 73 | 74 |     methods_eval = ["lrvga", "sgd-rb", "lofi"] 75 |     void_datasets = ["protein-tertiary-structure"] 76 |     datasets = [d for d in datasets if d not in void_datasets] 77 | 78 |     df_all = [] 79 |     for dataset in datasets: 80 |         files_target = extract_filenames(dataset, path) 81 |         data_dataset = extract_data(files_target, path) 82 | 83 | 84 |         data = jax.tree_map( 85 |             lambda x: np.atleast_2d(x).mean(axis=0)[-1], data_dataset 86 |         ) 87 | 88 |         df = [] 89 |         for mem, sub in data.items(): 90 |             df_part = pd.DataFrame.from_dict(sub, orient="index") 91 |             df_part["memory"] = mem 92 |             df.append(df_part) 93 |         df = pd.concat(df) 94 |         df.index.name = "model" 95 | 96 |         df = df.reset_index() 97 |         df = df.query("model in @methods_eval") 98 | 99 |         df = df.assign( 100 |             metric=df["test"] / df["running_time"] 101 |         ) 102 | 103 |         rmin, rmax = df["test"].min(), df["test"].max() 104 | 105 |         df["std_test"] = (df["test"] - rmin) / (rmax - rmin) 106 | 107 |         df["dataset"] = dataset 108 | 109 |         ix = 0 110 |         data_path = os.path.join(dataset_path, dataset, "data") 111 |         res = uci_uncertainty_data.load_data(data_path, ix) 112 | 113 |         n_obs, *_ = res["dataset"]["train"][1].shape 114 | 115 | 116 |         # ~Seconds per datapoint 117 |         df["log_running_time_dp"] = np.log(df["running_time"] / n_obs) 118 |         df_all.append(df) 119 |     df_all = pd.concat(df_all) 120 |     df_all.to_pickle("uci-models-results.pkl") 121 |     print("Done!") 122 | -------------------------------------------------------------------------------- /demos/showdown/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/showdown/classification/__init__.py -------------------------------------------------------------------------------- /demos/showdown/classification/ablation_inflation_types.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from rebayes.low_rank_filter.lofi import LoFiParams 5 | from
demos.showdown.classification import classification_train as benchmark 6 | from demos.collas.classification.stationary_mnist_clf import train_agent 7 | 8 | 9 | INFLATION_TYPES = ("bayesian", "hybrid", "simple") 10 | 11 | 12 | def train_lofi_ablation_type(infl_type="bayesian"): 13 |     if infl_type not in INFLATION_TYPES: 14 |         raise ValueError(f"Unknown inflation type: {infl_type}") 15 | 16 |     lofi_ranks = (1, 5,) 17 |     lofi_agents = { 18 |         f'lofi-{rank}': { 19 |             'lofi_params': LoFiParams(memory_size=rank, diagonal_covariance=True), 20 |             'inflation': infl_type, 21 |         } for rank in lofi_ranks 22 |     } 23 | 24 |     results = {} 25 |     for agent, kwargs in lofi_agents.items(): 26 |         miscl = train_agent(model_dict, dataset, agent_type=agent, **kwargs) 27 |         results[f'{agent}_miscl'] = miscl 28 | 29 |     return results 30 | 31 | 32 | if __name__ == "__main__": 33 |     fashion = True 34 | 35 |     output_path = os.environ.get("REBAYES_OUTPUT") 36 |     if output_path is None: 37 |         dataset_name = "mnist" if not fashion else "f-mnist" 38 |         output_path = Path(Path.cwd(), "output", "ablation", "stationary") 39 |         output_path.mkdir(parents=True, exist_ok=True) 40 |     print(f"Output path: {output_path}") 41 | 42 |     dataset = benchmark.load_mnist_dataset(fashion=fashion) # load data 43 |     model_dict = benchmark.init_model(type='mlp', features=(100, 100, 10)) # initialize model 44 | 45 |     for infl_type in INFLATION_TYPES: 46 |         curr_output_path = Path(output_path, infl_type) 47 |         curr_output_path.mkdir(parents=True, exist_ok=True) 48 | 49 |         results = train_lofi_ablation_type(infl_type) 50 |         for key, val in results.items(): 51 |             benchmark.store_results(val, key, curr_output_path)  # key already ends in "_miscl" -------------------------------------------------------------------------------- /demos/showdown/nonstationary/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/showdown/nonstationary/__init__.py -------------------------------------------------------------------------------- /demos/showdown/regression/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /demos/showdown/regression/hparam_tune_clf.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import jax.numpy as jnp 4 | from rebayes import base 5 | from functools import partial 6 | from jax.flatten_util import ravel_pytree 7 | from bayes_opt import BayesianOptimization 8 | from rebayes.extended_kalman_filter import ekf 9 | from rebayes.low_rank_filter import lofi 10 | 11 | 12 | def apply(flat_params, x, model, unflatten_fn): 13 |     return model.apply(unflatten_fn(flat_params), x) 14 | 15 | 16 | def bbf_lofi( 17 |     log_init_cov, 18 |     dynamics_weights, 19 |     # Specify before running 20 |     emission_cov_fn, 21 |     train, 22 |     test, 23 |     flat_params, 24 |     callback, 25 |     apply_fn, 26 |     params_lofi, 27 |     method, 28 | ): 29 |     """ 30 |     Black-box function for Bayesian optimization.
31 | """ 32 | X_train, y_train = train 33 | X_test, y_test = test 34 | 35 | dynamics_covariance = None 36 | initial_covariance = jnp.exp(log_init_cov).item() 37 | 38 | test_callback_kwargs = {"X_test": X_test, "y_test": y_test, "apply_fn": apply_fn} 39 | params_rebayes = base.RebayesParams( 40 | initial_mean=flat_params, 41 | initial_covariance=initial_covariance, 42 | dynamics_weights=dynamics_weights, 43 | dynamics_covariance=dynamics_covariance, 44 | emission_mean_function=apply_fn, 45 | emission_cov_function=emission_cov_fn, 46 | ) 47 | 48 | estimator = lofi.RebayesLoFi(params_rebayes, params_lofi, method=method) 49 | 50 | bel, _ = estimator.scan(X_train, y_train, progress_bar=False) 51 | metric = callback(bel, **test_callback_kwargs)["test"].item() 52 | return metric 53 | 54 | 55 | 56 | def bbf_ekf( 57 | log_init_cov, 58 | dynamics_weights, 59 | # Specify before running 60 | emission_cov_fn, 61 | train, 62 | test, 63 | flat_params, 64 | callback, 65 | apply_fn, 66 | method="fdekf", 67 | ): 68 | """ 69 | Black-box function for Bayesian optimization. 70 | """ 71 | X_train, y_train = train 72 | X_test, y_test = test 73 | 74 | dynamics_covariance = None 75 | initial_covariance = jnp.exp(log_init_cov).item() 76 | 77 | test_callback_kwargs = {"X_test": X_test, "y_test": y_test, "apply_fn": apply_fn} 78 | params_rebayes = base.RebayesParams( 79 | initial_mean=flat_params, 80 | initial_covariance=initial_covariance, 81 | dynamics_weights=dynamics_weights, 82 | dynamics_covariance=dynamics_covariance, 83 | emission_mean_function=apply_fn, 84 | emission_cov_function=emission_cov_fn, 85 | ) 86 | 87 | estimator = ekf.RebayesEKF(params_rebayes, method=method) 88 | 89 | bel, _ = estimator.scan(X_train, y_train, progress_bar=False) 90 | metric = callback(bel, **test_callback_kwargs)["test"].item() 91 | return metric 92 | 93 | 94 | def create_optimizer( 95 | model, 96 | bounds, 97 | random_state, 98 | train, 99 | test, 100 | callback=None, 101 | method="fdekf", 102 | **kwargs 103 | ): 104 | key = jax.random.PRNGKey(random_state) 105 | X_train, _ = train 106 | _, n_features = X_train.shape 107 | 108 | batch_init = jnp.ones((1, n_features)) 109 | params_init = model.init(key, batch_init) 110 | flat_params, recfn = ravel_pytree(params_init) 111 | 112 | apply_fn = partial(apply, model=model, unflatten_fn=recfn) 113 | def emission_cov_fn(w, x): 114 | return apply_fn(w, x) * (1 - apply_fn(w, x)) 115 | 116 | if "ekf" in method: 117 | bbf = bbf_ekf 118 | elif "lofi" in method: 119 | bbf = bbf_lofi 120 | 121 | bbf_partial = partial( 122 | bbf, 123 | train=train, 124 | test=test, 125 | flat_params=flat_params, 126 | callback=callback, 127 | apply_fn=apply_fn, 128 | emission_cov_fn=emission_cov_fn, 129 | method=method, 130 | **kwargs # Must include params_lofi if method is lofi 131 | ) 132 | 133 | optimizer = BayesianOptimization( 134 | f=bbf_partial, 135 | pbounds=bounds, 136 | random_state=random_state, 137 | ) 138 | 139 | return optimizer, apply_fn, n_features 140 | 141 | 142 | def get_best_params(num_params, optimizer, method): 143 | max_params = optimizer.max["params"].copy() 144 | 145 | dynamics_covariance = None 146 | initial_covariance = np.exp(max_params["log_init_cov"]) 147 | dynamics_weights = max_params["dynamics_weights"] 148 | 149 | hparams = { 150 | "initial_covariance": initial_covariance, 151 | "dynamics_covariance": dynamics_covariance, 152 | "dynamics_weights": dynamics_weights, 153 | } 154 | 155 | return hparams 156 | 157 | 158 | def build_estimator(init_mean, hparams, apply_fn, method, 
**kwargs): 159 |     """ 160 |     kwargs must include params_lofi when method is a lofi variant. 161 |     """ 162 |     def emission_cov_fn(w, x): 163 |         return apply_fn(w, x) * (1 - apply_fn(w, x)) 164 | 165 |     params = base.RebayesParams( 166 |         initial_mean=init_mean, 167 |         emission_mean_function=apply_fn, 168 |         emission_cov_function=emission_cov_fn, 169 |         **hparams, 170 |     ) 171 | 172 |     if "ekf" in method: 173 |         estimator = ekf.RebayesEKF(params, method=method) 174 |     elif "lofi" in method: 175 |         estimator = lofi.RebayesLoFi(params, method=method, **kwargs) 176 |     return estimator 177 | -------------------------------------------------------------------------------- /demos/showdown/regression/hparam_tune_ekf.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import jax.numpy as jnp 4 | from rebayes import base 5 | from functools import partial 6 | from jax.flatten_util import ravel_pytree 7 | from bayes_opt import BayesianOptimization 8 | from rebayes.extended_kalman_filter import ekf 9 | 10 | 11 | def bbf( 12 |     log_init_cov, 13 |     dynamics_weights, 14 |     log_emission_cov, 15 |     dynamics_log_cov, 16 |     # Specify before running 17 |     train, 18 |     test, 19 |     flat_params, 20 |     callback, 21 |     apply_fn, 22 |     method="fdekf", 23 |     emission_mean_function=None, 24 |     emission_cov_function=None, 25 | ): 26 |     """ 27 |     Black-box function for Bayesian optimization. 28 |     """ 29 |     X_train, y_train = train 30 |     X_test, y_test = test 31 | 32 |     dynamics_covariance = jnp.exp(dynamics_log_cov).item() 33 |     initial_covariance = jnp.exp(log_init_cov).item() 34 |     if emission_mean_function is None: 35 |         emission_mean_function = apply_fn 36 |     if emission_cov_function is None: 37 |         def emission_cov_function(w, x): return jnp.exp(log_emission_cov) 38 | 39 |     test_callback_kwargs = {"X_test": X_test, "y_test": y_test, "apply_fn": apply_fn} 40 |     params_rebayes = base.RebayesParams( 41 |         initial_mean=flat_params, 42 |         initial_covariance=initial_covariance, 43 |         dynamics_weights=dynamics_weights, 44 |         dynamics_covariance=dynamics_covariance, 45 |         emission_mean_function=emission_mean_function, 46 |         emission_cov_function=emission_cov_function, 47 |     ) 48 | 49 |     estimator = ekf.RebayesEKF(params_rebayes, method=method) 50 | 51 |     bel, _ = estimator.scan(X_train, y_train, progress_bar=False) 52 |     metric = callback(bel, **test_callback_kwargs)["test"].item() 53 |     return -metric 54 | 55 | 56 | def apply(flat_params, x, model, unflatten_fn): 57 |     return model.apply(unflatten_fn(flat_params), x) 58 | 59 | 60 | def create_optimizer( 61 |     model, 62 |     bounds, 63 |     random_state, 64 |     train, 65 |     test, 66 |     callback=None, 67 |     method="fdekf", 68 |     emission_mean_function=None, 69 |     emission_cov_function=None, 70 | ): 71 |     key = jax.random.PRNGKey(random_state) 72 |     X_train, _ = train 73 |     _, *n_features = X_train.shape 74 | 75 |     batch_init = jnp.ones((1, *n_features)) 76 |     params_init = model.init(key, batch_init) 77 |     flat_params, recfn = ravel_pytree(params_init) 78 | 79 |     apply_fn = partial(apply, model=model, unflatten_fn=recfn) 80 |     bbf_partial = partial( 81 |         bbf, 82 |         train=train, 83 |         test=test, 84 |         flat_params=flat_params, 85 |         callback=callback, 86 |         apply_fn=apply_fn, 87 |         method=method, 88 |         emission_mean_function=emission_mean_function, 89 |         emission_cov_function=emission_cov_function, 90 |     ) 91 | 92 |     optimizer = BayesianOptimization( 93 |         f=bbf_partial, 94 |         pbounds=bounds, 95 |         random_state=random_state, 96 |         allow_duplicate_points=True, 97 |     ) 98 | 99 |     return optimizer,
apply_fn, n_features 100 | 101 | 102 | def get_best_params(num_params, optimizer, method="fdekf"): 103 |     if type(optimizer) is dict: 104 |         max_params = optimizer.copy() 105 |     else: 106 |         max_params = optimizer.max["params"].copy() 107 | 108 |     dynamics_covariance = 0.0 109 |     initial_covariance = np.exp(max_params["log_init_cov"]) 110 |     dynamics_weights = max_params["dynamics_weights"] 111 |     emission_cov = np.exp(max_params.get("log_emission_cov", 0.0)) 112 | 113 | 114 |     def emission_cov_function(w, x): return emission_cov 115 |     hparams = { 116 |         "initial_covariance": initial_covariance, 117 |         "dynamics_covariance": dynamics_covariance, 118 |         "dynamics_weights": dynamics_weights, 119 |         "emission_cov_function": emission_cov_function, 120 |     } 121 | 122 |     return hparams 123 | 124 | def build_estimator(init_mean, hparams, _, apply_fn, method="fdekf", 125 |                     emission_mean_function=None, emission_cov_function=None): 126 |     """ 127 |     _ is a dummy parameter for compatibility with lofi 128 |     """ 129 |     if emission_mean_function is None: 130 |         emission_mean_function = apply_fn 131 | 132 |     # Copy hparams so the constant emission covariance it carries can be 133 |     # overridden; passing emission_cov_function both explicitly and via 134 |     # **hparams would raise a duplicate-keyword TypeError. 135 |     hparams = dict(hparams) 136 |     if emission_cov_function is not None: 137 |         hparams["emission_cov_function"] = emission_cov_function 138 | 139 |     params = base.RebayesParams( 140 |         initial_mean=init_mean, 141 |         emission_mean_function=emission_mean_function, 142 |         **hparams, 143 |     ) 144 | 145 |     estimator = ekf.RebayesEKF(params, method=method) 146 |     return estimator 147 | -------------------------------------------------------------------------------- /demos/showdown/regression/hparam_tune_lofi.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import jax.numpy as jnp 4 | from rebayes import base 5 | from functools import partial 6 | from jax.flatten_util import ravel_pytree 7 | from bayes_opt import BayesianOptimization 8 | from rebayes.low_rank_filter import lofi 9 | 10 | 11 | def bbf( 12 |     log_init_cov, 13 |     log_dynamics_weights, 14 |     dynamics_log_cov, 15 |     log_emission_cov, 16 |     log_inflation, 17 |     # Specify before running 18 |     train, 19 |     test, 20 |     flat_params, 21 |     callback, 22 |     apply_fn, 23 |     params_lofi, 24 |     method="full_svd_lofi",  # TODO: Deprecate this 25 |     emission_mean_function=None, 26 |     emission_cov_function=None, 27 | ): 28 |     """ 29 |     Black-box function for Bayesian optimization.
30 | """ 31 | X_train, y_train = train 32 | X_test, y_test = test 33 | 34 | initial_covariance = jnp.exp(log_init_cov).item() 35 | inflation = jnp.exp(log_inflation) 36 | dynamics_weights = jnp.exp(log_dynamics_weights) 37 | dynamics_covariance = jnp.exp(dynamics_log_cov) 38 | if emission_mean_function is None: 39 | emission_mean_function = apply_fn 40 | if emission_cov_function is None: 41 | def emission_cov_function(w, x): return jnp.exp(log_emission_cov) 42 | 43 | 44 | test_callback_kwargs = {"X_test": X_test, "y_test": y_test, "apply_fn": apply_fn} 45 | params_rebayes = base.RebayesParams( 46 | initial_mean=flat_params, 47 | initial_covariance=initial_covariance, 48 | dynamics_weights=dynamics_weights, 49 | dynamics_covariance=dynamics_covariance, 50 | emission_mean_function=emission_mean_function, 51 | emission_cov_function=emission_cov_function, 52 | dynamics_covariance_inflation_factor=inflation, 53 | ) 54 | 55 | estimator = lofi.RebayesLoFiDiagonal(params_rebayes, params_lofi) 56 | 57 | bel, _ = estimator.scan(X_train, y_train, progress_bar=False) 58 | metric = callback(bel, **test_callback_kwargs)["test"].item() 59 | isna = np.isnan(metric) 60 | metric = 10 if isna else metric 61 | return -metric 62 | 63 | 64 | def apply(flat_params, x, model, unflatten_fn): 65 | return model.apply(unflatten_fn(flat_params), x) 66 | 67 | 68 | def create_optimizer( 69 | model, 70 | bounds, 71 | random_state, 72 | train, 73 | test, 74 | params_lofi, 75 | callback=None, 76 | method="full_svd_lofi", 77 | emission_mean_function=None, 78 | emission_cov_function=None, 79 | ): 80 | bounds = bounds.copy() 81 | key = jax.random.PRNGKey(random_state) 82 | X_train, _ = train 83 | _, *n_params = X_train.shape 84 | 85 | batch_init = jnp.ones((1, *n_params)) 86 | params_init = model.init(key, batch_init) 87 | flat_params, recfn = ravel_pytree(params_init) 88 | 89 | kwargs = {} 90 | if bounds.get("log_inflation") is None: 91 | bounds.pop("log_inflation", None) 92 | kwargs["log_inflation"] = -np.inf 93 | 94 | 95 | apply_fn = partial(apply, model=model, unflatten_fn=recfn) 96 | bbf_partial = partial( 97 | bbf, 98 | train=train, 99 | test=test, 100 | flat_params=flat_params, 101 | callback=callback, 102 | apply_fn=apply_fn, 103 | params_lofi=params_lofi, 104 | method=method, 105 | emission_mean_function=emission_mean_function, 106 | emission_cov_function=emission_cov_function, 107 | **kwargs, 108 | ) 109 | 110 | # Fix log-emission-covariance to dummy if adaptive 111 | if emission_cov_function is not None: 112 | bbf_partial = partial( 113 | bbf_partial, 114 | log_emission_cov=0.0, 115 | ) 116 | 117 | 118 | 119 | optimizer = BayesianOptimization( 120 | f=bbf_partial, 121 | pbounds=bounds, 122 | random_state=random_state, 123 | allow_duplicate_points=True, 124 | ) 125 | 126 | return optimizer, apply_fn, n_params 127 | 128 | 129 | def get_best_params(n_params, optimizer): 130 | if type(optimizer) is dict: 131 | max_params = optimizer.copy() 132 | else: 133 | max_params = optimizer.max["params"].copy() 134 | 135 | init_cov = np.exp(max_params["log_init_cov"]).item() 136 | emission_cov = np.exp(max_params.get("log_emission_cov", 0.0)) 137 | dynamics_weights = np.exp(max_params["log_dynamics_weights"]) 138 | dynamics_cov = np.exp(max_params.get("dynamics_log_cov")) 139 | inflation = np.exp(max_params.get("log_inflation", -np.inf)) 140 | 141 | def emission_cov_function(w, x): return emission_cov 142 | hparams = { 143 | "initial_covariance": init_cov, 144 | "dynamics_covariance": dynamics_cov, 145 | "dynamics_weights": 
dynamics_weights, 146 |         "emission_cov_function": emission_cov_function, 147 |         "dynamics_covariance_inflation_factor": inflation, 148 |     } 149 | 150 |     return hparams 151 | 152 | 153 | def build_estimator(init_mean, hparams, params_lofi, apply_fn, method="full_svd_lofi", 154 |                     emission_mean_function=None, emission_cov_function=None): 155 |     if emission_mean_function is None: 156 |         emission_mean_function = apply_fn 157 | 158 |     # Copy hparams so the constant emission covariance it carries can be 159 |     # overridden; passing emission_cov_function both explicitly and via 160 |     # **hparams would raise a duplicate-keyword TypeError. 161 |     hparams = dict(hparams) 162 |     if emission_cov_function is not None: 163 |         hparams["emission_cov_function"] = emission_cov_function 164 | 165 |     params = base.RebayesParams( 166 |         initial_mean=init_mean, 167 |         emission_mean_function=emission_mean_function, 168 |         **hparams, 169 |     ) 170 | 171 |     estimator = lofi.RebayesLoFiDiagonal(params, params_lofi) 172 |     return estimator 173 | -------------------------------------------------------------------------------- /demos/showdown/regression/hparam_tune_sgd.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import optax 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from functools import partial 6 | from rebayes.sgd_filter import replay_sgd as rsgd 7 | from bayes_opt import BayesianOptimization 8 | 9 | def bbf( 10 |     learning_rate, 11 |     n_inner, 12 |     # specify before training 13 |     train, 14 |     test, 15 |     rank, 16 |     params, 17 |     callback, 18 |     apply_fn, 19 |     lossfn, 20 |     dim_in, 21 |     dim_out=1, 22 | ): 23 |     n_inner = round(n_inner) 24 |     X_train, y_train = train 25 |     X_test, y_test = test 26 | 27 |     test_callback_kwargs = {"X_test": X_test, "y_test": y_test, "apply_fn": apply_fn} 28 |     agent = rsgd.FifoSGD( 29 |         lossfn, 30 |         apply_fn=apply_fn, 31 |         init_params=params, 32 |         tx=optax.adam(learning_rate), 33 |         buffer_size=rank, 34 |         dim_features=dim_in, 35 |         dim_output=dim_out, 36 |         n_inner=n_inner, 37 |     ) 38 | 39 |     bel, _ = agent.scan(X_train, y_train, progress_bar=False) 40 |     metric = callback(bel, **test_callback_kwargs)["test"].item() 41 | 42 |     isna = np.isnan(metric) 43 |     metric = 10 if isna else metric 44 |     return -metric 45 | 46 | 47 | def create_optimizer( 48 |     model, 49 |     bounds, 50 |     random_state, 51 |     train, 52 |     test, 53 |     rank, 54 |     lossfn, 55 |     callback=None, 56 | ): 57 |     key = jax.random.PRNGKey(random_state) 58 |     apply_fn = model.apply 59 | 60 |     _, dim_in = train[0].shape 61 |     dim_out = train[1].shape 62 |     if len(dim_out) > 1: 63 |         dim_out = dim_out[1] 64 |     else: 65 |         dim_out = 1 66 | 67 |     batch_init = jnp.ones((1, dim_in)) 68 |     params_init = model.init(key, batch_init) 69 | 70 |     bbf_partial = partial(bbf, 71 |                           train=train, 72 |                           test=test, 73 |                           rank=rank, 74 |                           params=params_init, 75 |                           apply_fn=apply_fn, 76 |                           callback=callback, 77 |                           lossfn=lossfn, 78 |                           dim_in=dim_in, 79 |                           dim_out=dim_out, 80 |     ) 81 | 82 |     optimizer = BayesianOptimization( 83 |         f=bbf_partial, 84 |         pbounds=bounds, 85 |         random_state=random_state, 86 |         allow_duplicate_points=True, 87 |     ) 88 | 89 |     return optimizer 90 | -------------------------------------------------------------------------------- /demos/showdown/regression/peter/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /demos/stitching/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/demos/stitching/__init__.py -------------------------------------------------------------------------------- /demos/stitching/configs/mlp_[50, 50, 50]_n_train_200_fcekf.json: -------------------------------------------------------------------------------- 1 | {"initial_covariance": 0.00036470420309342444, "dynamics_covariance": 0.0, "dynamics_covariance_inflation_factor": 0.0, "dynamics_weights_or_function": 1.0} -------------------------------------------------------------------------------- /demos/stitching/configs/mlp_[50, 50, 50]_n_train_200_sgd.json: -------------------------------------------------------------------------------- 1 | {"learning_rate": 0.014740299433469772} -------------------------------------------------------------------------------- /misc/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/misc/README.md -------------------------------------------------------------------------------- /misc/nando-efk-mlp/data.txt: -------------------------------------------------------------------------------- 1 | 0.1,0.047044 2 | 5.5175,1.8392 3 | 1.2816,-0.50255 4 | 5.5888,1.8584 5 | 7.8471,3.2154 6 | 14.583,9.9644 7 | 13.856,9.3423 8 | 4.7291,1.3493 9 | -0.44297,2.7057 10 | -11.03,8.3824 11 | -1.0856,-1.1632 12 | -6.5568,5.1177 13 | -9.0599,4.8644 14 | -15.167,11.431 15 | -12.947,9.2256 16 | -3.1342,0.2365 17 | -1.0905,-0.10209 18 | -13.062,10.551 19 | -15.606,14.163 20 | -14.887,13.147 21 | -5.622,2.5892 22 | 0.69406,-1.8398 23 | 14.39,11.487 24 | 2.647,2.9821 25 | 2.7109,1.171 26 | 10.54,7.2814 27 | 15.482,13.213 28 | 13.657,8.8076 29 | 4.0515,1.326 30 | 0.14126,-1.3592 31 | 2.542,1.8615 32 | 16.92,12.329 33 | 16.147,11.191 34 | 6.9578,1.0302 35 | -0.9684,-4.9711 36 | -16.272,15.665 37 | -4.1244,1.392 38 | -0.38244,-1.23 39 | -8.9074,6.1719 40 | -14.815,8.2786 41 | -14.117,9.8078 42 | -4.8886,0.83392 43 | 0.6099,0.48131 44 | 13.416,9.4378 45 | 2.0836,-0.94849 46 | 4.1834,0.83622 47 | 9.4381,4.2507 48 | 15.185,12.264 49 | 13.322,10.082 50 | 3.3323,1.7051 51 | 0.92687,-0.79128 52 | 12.244,7.5654 53 | 15.518,11.048 54 | 15.203,10.733 55 | 6.1432,1.8824 56 | -0.96817,0.93613 57 | -15.926,12.299 58 | -3.428,0.74144 59 | -1.5767,0.049897 60 | -13.239,9.0367 61 | -16.227,12.994 62 | -14.44,10.428 63 | -4.5674,1.035 64 | 0.39851,-0.10334 65 | 10.228,4.9414 66 | 0.60364,-0.0052594 67 | 5.0884,1.8291 68 | 9.5676,4.9055 69 | 15.425,12.454 70 | 12.876,6.8609 71 | 3.0174,0.8787 72 | 1.5227,-0.45745 73 | 12.267,7.6865 74 | 15.733,11.561 75 | 14.604,11.421 76 | 5.3707,2.9283 77 | -0.90763,1.6441 78 | -15.121,10.851 79 | -2.7624,0.60374 80 | -2.4892,0.56679 81 | -11.437,4.6022 82 | -15.792,12.056 83 | -13.71,9.1079 84 | -3.9214,-1.6279 85 | -0.20428,1.2491 86 | -4.1518,-0.24357 87 | -14.983,9.8727 88 | -15.146,11.871 89 | -6.6364,1.8413 90 | 1.1121,2.0358 91 | 15.927,14.646 92 | 3.5845,0.46194 93 | 1.0142,-1.1971 94 | 13.45,4.5675 95 | 16.22,14.309 96 | 14.546,10.256 97 | 5.1441,1.4646 98 | -0.5276,3.3033 99 | -12.497,8.5475 100 | -1.5027,3.0937 101 | -5.9083,0.37199 102 | -9.0622,5.5302 103 | -15.181,11.366 104 | -12.941,9.2064 105 | -3.2211,0.067224 106 | -0.98007,0.91342 107 | -12.76,7.4585 108 | -15.646,11.251 109 | -14.952,8.9837 110 | -5.8415,1.9434 111 | 0.99819,2.5285 112 | 15.53,12.978 113 | 3.1122,-0.90401 114 | 1.6807,1.1639 115 | 13.134,7.779 116 | 16.188,13.003 117 
| 14.166,8.9889 118 | 4.3839,2.2308 119 | -0.16843,0.11341 120 | -5.405,2.0453 121 | -0.25234,-1.1607 122 | 0.117,-0.25968 123 | 0.60613,-0.2988 124 | 3.3108,0.70037 125 | 5.3503,1.4582 126 | 12.815,8.2032 127 | 15.68,12.421 128 | 9.0604,4.0485 129 | -0.31941,0.20686 130 | -12.622,7.2012 131 | -4.653,0.92093 132 | 0.49917,-0.43195 133 | 12.14,6.3893 134 | 1.6007,0.48431 135 | 5.3708,1.6621 136 | 8.7456,3.853 137 | 15.06,10.1 138 | 13.287,9.9567 139 | 3.3856,0.94825 140 | 0.93502,-0.2966 141 | 12.581,7.9412 142 | 15.578,11.804 143 | 15.187,9.2222 144 | 6.2845,1.5821 145 | -1.0809,-1.1268 146 | -15.621,10.759 147 | -3.3446,-1.1946 148 | -1.3253,-0.74272 149 | -13.569,6.0166 150 | -16.294,14.841 151 | -14.519,11.399 152 | -4.7656,1.1019 153 | 0.39509,-0.051172 154 | 10.096,3.7302 155 | 0.60569,1.7732 156 | 4.9721,1.0057 157 | 9.3229,3.1083 158 | 15.337,14.099 159 | 12.877,7.9031 160 | 3.0815,-0.53458 161 | 1.4285,-0.39702 162 | 12.445,6.3198 163 | 15.503,10.162 164 | 14.706,14.929 165 | 5.5174,4.1676 166 | -0.81795,0.51401 167 | -14.844,9.1025 168 | -2.8172,-0.88425 169 | -2.2939,0.01006 170 | -11.589,7.8102 171 | -15.653,10.481 172 | -13.751,6.4905 173 | -4.0644,-0.92988 174 | -0.23871,0.38582 175 | -4.8245,1.5873 176 | -14.37,10.782 177 | -14.995,11.12 178 | -6.5248,2.2881 179 | 0.95993,-0.33094 180 | 15.969,13.365 181 | 3.7506,-0.16004 182 | 0.8091,0.28422 183 | 13.108,8.1922 184 | 16.059,12.765 185 | 14.686,10.95 186 | 5.0891,1.519 187 | -0.45804,-0.13413 188 | -11.439,6.5967 189 | -1.0813,0.029629 190 | -6.4963,2.119 191 | -8.8139,3.9264 192 | -15.159,11.556 193 | -13.046,8.6278 194 | -3.2585,0.52003 195 | -1.2713,0.052423 196 | -12.528,7.3227 197 | -15.697,11.229 198 | -15.079,10.999 199 | -5.856,1.8968 200 | 0.88509,-0.54993 201 | 15.537,11.885 202 | 3.3205,1.6662 203 | 1.4272,-0.98869 204 | 13.512,8.7564 205 | 16.284,13.086 206 | 14.439,9.7387 207 | 4.5219,0.66071 208 | -0.12176,1.3507 209 | -4.3507,1.1868 210 | -0.66135,0.24441 211 | -5.5819,2.3894 212 | -9.7105,2.0262 213 | -15.482,13.412 214 | -12.694,8.8584 215 | -2.7972,1.5874 216 | -2.0341,2.2019 217 | -11.127,6.5197 218 | -15.276,12.141 219 | -14.66,8.0983 220 | -5.2569,1.8131 221 | 0.67282,1.8632 222 | 14.188,7.3031 223 | 2.3006,-1.8484 224 | 3.5517,1.0202 225 | 9.6866,7.2162 226 | 15.2,12.286 227 | 13.344,9.157 228 | 3.7196,0.25219 229 | 0.53685,0.42085 230 | 10.003,4.3308 231 | 14.607,10.091 232 | 14.997,13.028 233 | 6.3522,3.3202 234 | -0.9714,1.0253 235 | -15.835,13.179 236 | -3.5591,1.0184 237 | -1.2083,0.058818 238 | -13.494,12.347 239 | -16.333,13.795 240 | -14.63,9.5465 241 | -------------------------------------------------------------------------------- /misc/nando-efk-mlp/ekfdemo1.m: -------------------------------------------------------------------------------- 1 | % PURPOSE : To estimate the input-output mapping with inputs x 2 | % and outputs y generated by the following nonlinear, 3 | % nonstationary state space model: 4 | % 5 | % x(t+1) = 0.5x(t) + [25x(t)]/[(1+x(t))^(2)] 6 | % + 8cos(1.2t) + process noise 7 | % y(t) = x(t)^(2) / 20 + measurement noise with time 8 | % varying covariance. 9 | % 10 | % using a multi-layer perceptron (MLP) and both the EKF and 11 | % EKF with evidence maximisation algorithm. 12 | 13 | % AUTHOR : Nando de Freitas - Thanks for the acknowledgement :-) 14 | % DATE : 08-09-98 15 | 16 | clear; 17 | echo off; 18 | 19 | % INITIALISATION AND PARAMETERS: 20 | % ============================= 21 | 22 | N = 240; % Number of time steps. 23 | t = 1:1:N; % Time. 24 | x0 = 0.1; % Initial input. 
25 | x = zeros(N,1); % Input observation. 26 | y = zeros(N,1); % Output observation. 27 | x(1,1) = x0; % Initial input. 28 | actualR = 3; % Measurement noise variance. 29 | actualQ = .01; % Process noise variance. 30 | s1=10; % Neurons in the hidden layer. 31 | s2=1; % Neurons in the output layer - only one in this implementation. 32 | initVar= 10; % Variance of prior for weights. 33 | KalmanR = 3; % Kalman filter measurement noise covariance hyperparameter; 34 | KalmanQ = 1e-5; % Kalman filter process noise covariance hyperparameter; 35 | KalmanP = 10; % Kalman filter initial weights covariance hyperparameter; 36 | window = 10; % Length of moving window to estimate the time covariance. 37 | 38 | % GENERATE PROCESS AND MEASUREMENT NOISE: 39 | % ====================================== 40 | 41 | v = sqrt(actualR)*sin(0.05*t)'.*randn(N,1); 42 | w = sqrt(actualQ)*randn(N,1); 43 | 44 | 45 | 46 | figure(1) 47 | clf; 48 | subplot(221) 49 | plot(v); 50 | ylabel('Measurement Noise','fontsize',15); 51 | xlabel('Time'); 52 | subplot(222) 53 | plot(w); 54 | ylabel('Process Noise','fontsize',15); 55 | xlabel('Time'); 56 | 57 | % GENERATE INPUT-OUTPUT DATA: 58 | % ========================== 59 | y(1,1) = (x(1,1)^(2))./20 + v(1,1); 60 | for t=2:N, 61 | x(t,1) = 0.5*x(t-1,1) + 25*x(t-1,1)/(1+x(t-1,1)^(2)) + 8*cos(1.2*(t-1)) + w(t,1); 62 | y(t,1) = (x(t,1).^(2))./20 + v(t,1) ; 63 | end; 64 | 65 | subplot(223) 66 | plot(x) 67 | ylabel('Input','fontsize',15); 68 | xlabel('Time','fontsize',15); 69 | subplot(224) 70 | plot(y) 71 | ylabel('Output','fontsize',15); 72 | xlabel('Time','fontsize',15); 73 | fprintf('Press a key to continue') 74 | pause; 75 | fprintf('\n') 76 | fprintf('Training the MLP with EKF') 77 | fprintf('\n') 78 | 79 | % PERFORM EXTENDED KALMAN FILTERING TO TRAIN MLP: 80 | % ============================================== 81 | 82 | tInit=clock; 83 | [ekfp,theta,thetaR,PR,innovations] = mlpekf(x',y',s1,s2,KalmanR,KalmanQ,KalmanP,initVar,N); 84 | durationekf=etime(clock,tInit); 85 | errorekf=norm(ekfp(N/3:N)'-y(N/3:N)); 86 | 87 | % PERFORM SEQUENTIAL EKF WITH Q UPDATE: 88 | % ==================================== 89 | 90 | fprintf('Training the MLP with EKF and Q updating') 91 | fprintf('\n') 92 | tInit=clock; [ekfp2,theta2,thetaR2,PR2,innovations2,qplot] = mlpekfQ(x',y',s1,s2,KalmanR,KalmanQ,KalmanP,initVar,window,N); % reset timer so this duration excludes the EKF run 93 | durationekfmax=etime(clock,tInit); 94 | errorekfmax=norm(ekfp2(N/3:N)'-y(N/3:N)); 95 | 96 | fprintf('EKF error = %f',errorekf) 97 | fprintf('\n') 98 | fprintf('EKFMAX error = %f',errorekfmax) 99 | fprintf('\n') 100 | fprintf('EKF duration = %f seconds.\n',durationekf) 101 | fprintf('EKFMAX duration = %f seconds.\n',durationekfmax) 102 | 103 | % PLOT RESULTS: 104 | % ============ 105 | 106 | figure(1) 107 | clf; 108 | subplot(221) 109 | plot(1:length(x),y,'g',1:length(x),ekfp2,'r',1:length(x),ekfp,'b') 110 | legend('True value','EKFMAX estimate','EKF estimate'); 111 | ylabel('One-step-ahead prediction','fontsize',15) 112 | xlabel('Time','fontsize',15) 113 | subplot(222) 114 | plot(1:length(x),qplot) 115 | ylabel('Q parameter','fontsize',15) 116 | xlabel('Time','fontsize',15) 117 | subplot(223) 118 | plot(1:length(x),innovations2) 119 | ylabel('Innovations variance for EKFMAX','fontsize',15) 120 | xlabel('Time','fontsize',15) 121 | subplot(224); 122 | plot(1:length(x),innovations) 123 | ylabel('Innovations variance for EKF','fontsize',15) 124 | xlabel('Time','fontsize',15) 125 | 126 | 127 | 128 | --------------------------------------------------------------------------------
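For readers working in this repo's Python/JAX stack rather than MATLAB, the benchmark that ekfdemo1.m simulates can be reproduced in a few lines of NumPy (a minimal sketch; the function name generate_ungm_data is illustrative and not part of the repo):

import numpy as np

def generate_ungm_data(N=240, Q=0.01, R=3.0, seed=0):
    # Nonlinear, nonstationary state-space model from the header of ekfdemo1.m:
    #   x(t) = 0.5 x(t-1) + 25 x(t-1) / (1 + x(t-1)^2) + 8 cos(1.2 (t-1)) + w(t)
    #   y(t) = x(t)^2 / 20 + v(t), with time-varying measurement noise v
    rng = np.random.default_rng(seed)
    t = np.arange(1, N + 1)
    v = np.sqrt(R) * np.sin(0.05 * t) * rng.standard_normal(N)  # measurement noise
    w = np.sqrt(Q) * rng.standard_normal(N)                     # process noise
    x = np.zeros(N)
    y = np.zeros(N)
    x[0] = 0.1
    y[0] = x[0] ** 2 / 20 + v[0]
    for k in range(1, N):
        x[k] = 0.5 * x[k - 1] + 25 * x[k - 1] / (1 + x[k - 1] ** 2) + 8 * np.cos(1.2 * k) + w[k]
        y[k] = x[k] ** 2 / 20 + v[k]
    return x, y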
/misc/nando-efk-mlp/mlpekf.m: -------------------------------------------------------------------------------- 1 | function [y,theta,thetaRecord,PRecord,OutputVariance] = mlpekf(x,d,s1,s2,Rparameter,Qparameter,KalmanP,initVar,tsteps) 2 | % PURPOSE: To simulate a standard EKF-MLP training algorithm. 3 | % INPUTS : - x = The network input. 4 | % - d = The network target vector. 5 | % - s1 = Number of neurons in the hidden layer. 6 | % - s2 = Number of neurons in the output layer (1). 7 | % - Rparameter = EKF measurement noise hyperparameter. 8 | % - Qparameter = EKF process noise hyperparameter. 9 | % - KalmanP = initial EKF covariance. 10 | % - initVar = prior variance of the weights. 11 | % - tsteps = Number of time steps (input error checking). 12 | % OUTPUTS : - y = The network output. 13 | % - theta = The final weights. 14 | % - thetaRecord = The weights at each time step. 15 | % - PRecord = The EKF covariance at each time step. 16 | % - OutputVariance = The innovations covariance. 17 | 18 | % AUTHOR : Nando de Freitas - Thanks for the acknowledgement :-) 19 | % DATE : 08-09-98 20 | 21 | 22 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% CHECKING %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 23 | if nargin < 9, error('Not enough input arguments.'); end 24 | 25 | % Check that the size of input (x) is N by L, where N is the dimension 26 | % of the input and L is the length of the data (number of samples). 27 | [N,L] = size(x); 28 | [D,L] = size(d); 29 | if (L ~= tsteps), error('d must be of size 1x(time steps).'), end 30 | 31 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% INITIALISATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 32 | 33 | T = s2*(s1+1) + s1*(N+1); % The 1 is for the bias terms. 34 | theta = sqrt(initVar)*(randn(T,1)); % Parameter vector. 35 | H = zeros(T,D); % Jacobian Matrix. 36 | K = zeros(T,D); % Kalman Gain matrix. 37 | P = sqrt(KalmanP)*eye(T,T); % Weight covariance matrix. 38 | R = Rparameter*eye(D); % Measurement noise covariance. 39 | Q = Qparameter*eye(T,T); % Process noise covariance. 
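% Note on the formulation: the flattened MLP weights theta are treated as the
% hidden state of a state-space model with random-walk dynamics
% (theta_t = theta_{t-1} + process noise) and nonlinear measurements
% y_t = MLP(x_t; theta_t) + measurement noise. H below stores the Jacobian of
% the network output with respect to theta, so the Kalman equations at the end
% of the main loop perform one EKF update of the weights per sample.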
40 | o1 = zeros(s1,1); 41 | y = zeros(s2,L); 42 | w2 = zeros(s2,s1+1); 43 | w1 = zeros(s1,N+1); 44 | thetaRecord=zeros(T,L); 45 | PRecord=zeros(T,T,L); 46 | OutputVariance=zeros(1,L); 47 | 48 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%% MAIN SAMPLES LOOP %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 49 | for samples = 1:L, 50 | 51 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% FEED FORWARD %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 52 | % fill in weight matrices using the parameter vector: 53 | 54 | for i = 1:s2, 55 | for j = 1:(s1+1), 56 | w2(i,j)= theta(i*(s1+1)+j-(s1+1),1); 57 | end; 58 | end; 59 | for i = 1:s1, 60 | for j = 1:(N+1), 61 | w1(i,j)= theta(s2*(s1+1) +i*(N+1)+j-(N+1),1); 62 | end; 63 | end; 64 | 65 | % Compute the network outputs for each layer: 66 | u1 = w1*[1 ; x(:,samples)]; 67 | o1 = 1./(1+exp(-u1)); 68 | u2 = w2*[1 ; o1]; 69 | y(:,samples)=u2; 70 | 71 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% FILL H %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 72 | % output layer: 73 | for i = 1:s2, 74 | for j = 1:(s1+1), 75 | if j==1 76 | H(i*(s1+1) + j - (s1+1) ,1)= 1; 77 | else 78 | H(i*(s1+1) + j - (s1+1) ,1)= o1(j-1,1); 79 | end; 80 | end; 81 | end; 82 | 83 | % Second layer: 84 | for i = 1:s1, 85 | for j = 1:(N+1), 86 | rhs = w2(1,i+1)*o1(i,1)*(1-o1(i,1)); 87 | if j==1 88 | H(s2*(s1+1) + i*(N+1) + j - (N+1) ,1) = rhs; 89 | else 90 | H(s2*(s1+1) + i*(N+1) + j - (N+1) ,1)= rhs * x(j-1,samples); 91 | end; 92 | end; 93 | end; 94 | 95 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%% KALMAN EQUATIONS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 96 | 97 | 98 | K = (P+Q) *H * ((R + H'*(P+Q)*H)^(-1)); 99 | theta = theta + K * (d(:,samples) - y(:,samples)); 100 | P = P - K*H'*(P+Q) + Q; 101 | 102 | OutputVariance(1,samples) = R + H'*(P)*H; 103 | thetaRecord(:,samples)=theta; 104 | PRecord(:,:,samples)=P; 105 | end; 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /misc/nando-efk-mlp/mlpekfQ.m: -------------------------------------------------------------------------------- 1 | function [y,theta,thetaRecord,PRecord,OutputVariance,qplot] = mlpekfQ(x,d,s1,s2,Rparameter,Qparameter,KalmanP,initVar,window,tsteps) 2 | % PURPOSE: To simulate an EKF-MLP training algorithm with sequential updating of the process noise covariance Q. 3 | % INPUTS : - x = The network input. 4 | % - d = The network target vector. 5 | % - s1 = Number of neurons in the hidden layer. 6 | % - s2 = Number of neurons in the output layer (1). 7 | % - Rparameter = EKF measurement noise hyperparameter. 8 | % - Qparameter = EKF process noise hyperparameter. 9 | % - KalmanP = Initial EKF covariance. 10 | % - initVar = Prior variance of the weights. 11 | % - window = Window length to compute time covariance. 12 | % - tsteps = Number of time steps (input error checking). 13 | % OUTPUTS : - y = The network output. 14 | % - theta = The final weights. 15 | % - thetaRecord = The weights at each time step. 16 | % - PRecord = The EKF covariance at each time step. 17 | % - OutputVariance = The innovations covariance. 18 | % - qplot = Qparameter at each time step. 19 | 20 | % AUTHOR : Nando de Freitas - Thanks for the acknowledgement :-) 21 | % DATE : 08-09-98 22 | 23 | 24 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% CHECKING %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 25 | if nargin < 10, error('Not enough input arguments.'); end 26 | 27 | % Check that the size of input (x) is N by L, where N is the dimension 28 | % of the input and L is the length of the data (number of samples).
29 | [N,L] = size(x); 30 | [D,L] = size(d); 31 | if (L ~= tsteps), error('d must be of size 1x(time steps).'), end 32 | 33 | 34 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% INITIALISATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 35 | 36 | [N,L] = size(x); 37 | [D,L] = size(d); 38 | 39 | T = s2*(s1+1) + s1*(N+1); % No weights. The 1 is for the bias terms. 40 | theta = sqrt(initVar)*(randn(T,1)); % Parameter vector. 41 | H = zeros(T,D); % Jacobian Matrix. 42 | K = zeros(T,D); % Kalman Gain matrix. 43 | P = sqrt(KalmanP)*eye(T,T); % Weight covariance matrix. 44 | R = Rparameter*eye(D); % Measurement noise covariance. 45 | Q = Qparameter*eye(T,T); % Process noise covariance. 46 | o1 = zeros(s1,1); 47 | y = zeros(s2,L); 48 | w2 = zeros(s2,s1+1); 49 | w1 = zeros(s1,N+1); 50 | 51 | thetaRecord=zeros(T,L); 52 | PRecord=zeros(T,T,L); 53 | r = zeros(s2,L); 54 | qplot=zeros(1,L); 55 | Htime=zeros(L,T); 56 | OutputVariance=zeros(1,L); 57 | 58 | 59 | 60 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%% MAIN SAMPLES LOOP %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 61 | for samples = 1:L, 62 | 63 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% FEED FORWARD %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 64 | % fill in weight matrices using the parameter vector: 65 | 66 | for i = 1:s2, 67 | for j = 1:(s1+1), 68 | w2(i,j)= theta(i*(s1+1)+j-(s1+1),1); 69 | end; 70 | end; 71 | for i = 1:s1, 72 | for j = 1:(N+1), 73 | w1(i,j)= theta(s2*(s1+1) +i*(N+1)+j-(N+1),1); 74 | end; 75 | end; 76 | 77 | % Compute the network outputs for each layer: 78 | u1 = w1*[1 ; x(:,samples)]; 79 | o1 = 1./(1+exp(-u1)); 80 | u2 = w2*[1 ; o1]; 81 | y(:,samples)=u2; 82 | 83 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% FILL H %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 84 | % output layer: 85 | for i = 1:s2, 86 | for j = 1:(s1+1), 87 | if j==1 88 | H(i*(s1+1) + j - (s1+1) ,1)= 1; 89 | else 90 | H(i*(s1+1) + j - (s1+1) ,1)= o1(j-1,1); 91 | end; 92 | end; 93 | end; 94 | 95 | % Second layer: 96 | for i = 1:s1, 97 | for j = 1:(N+1), 98 | rhs = w2(1,i+1)*o1(i,1)*(1-o1(i,1)); 99 | if j==1 100 | H(s2*(s1+1) + i*(N+1) + j - (N+1) ,1) = rhs; 101 | else 102 | H(s2*(s1+1) + i*(N+1) + j - (N+1) ,1)= rhs * x(j-1,samples); 103 | end; 104 | end; 105 | end; 106 | 107 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%% E - KALMAN EQUATIONS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 108 | 109 | Pold=P; 110 | Htime(samples,:)=H(:,1)'; 111 | r = d-y; 112 | 113 | K = (P+Q) *H * ((R + H'*(P+Q)*H)^(-1)); 114 | theta = theta + K * (d(:,samples) - y(:,samples)); 115 | P = P - K*H'*(P+Q) + Q; 116 | 117 | thetaRecord(:,samples)=theta; 118 | PRecord(:,:,samples)=P; 119 | 120 | OutputVariance(1,samples) = R + H'*(P)*H; 121 | 122 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% UPDATE Q %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 123 | 124 | if samples > window 125 | S = zeros(1,T); 126 | sumS = 0; 127 | for l = 1:window, 128 | if l==1, 129 | S = (1/window) * (Htime(samples,:)) ./ R.^(1/2); 130 | else 131 | S = (1/window) * sum( (Htime(samples-(l-1):samples,:)) ./ R.^(1/2) ); 132 | end; 133 | sumS=sumS + S*S'; 134 | end; 135 | CovTime=((1/window)*sum(r(:,samples-window+1:samples) ./ R.^(1/2))).^(2); 136 | CovEnsemble=S*Pold*S' + (1/window); 137 | CovDifference = CovTime - CovEnsemble; 138 | 139 | if CovDifference > 0 140 | Qparameter = CovDifference / sumS; 141 | else 142 | Qparameter = 0; 143 | end; 144 | Q=Qparameter*eye(T,T); 145 | end; 146 | qplot(1,samples)=Qparameter; 147 | 148 | end; 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /misc/rotated_digits_load.py: 
-------------------------------------------------------------------------------- 1 | # data from https://www.mathworks.com/help/deeplearning/ug/train-a-convolutional-neural-network-for-regression.html 2 | # save('rotated_digits.mat', 'XTrain', 'YTrain', 'XValidation', 'YValidation') 3 | 4 | import scipy.io 5 | dat = scipy.io.loadmat('/Users/kpmurphy/github/rebayes/misc/rotated_digits_matlab.mat') 6 | XTrain, YTrain, XValidation, YValidation = dat['XTrain'], dat['YTrain'], dat['XValidation'], dat['YValidation'] 7 | print(XTrain.shape, YTrain.shape, XValidation.shape, YValidation.shape) 8 | print(YTrain[:10]) -------------------------------------------------------------------------------- /misc/rotated_digits_matlab.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/misc/rotated_digits_matlab.mat -------------------------------------------------------------------------------- /misc/torch_scan.py: -------------------------------------------------------------------------------- 1 | # https://pastebin.com/VzJTcsuv 2 | # https://github.com/pytorch/pytorch/issues/50688 3 | 4 | import torch 5 | 6 | def bench_triton(f, name=None, iters=1000, warmup=5, display=True, profile=False): 7 | import time 8 | from triton.testing import do_bench 9 | 10 | for _ in range(warmup): 11 | f() 12 | if profile: 13 | with torch.profiler.profile() as prof: 14 | f() 15 | prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json") 16 | 17 | 18 | us_per_iter = do_bench(lambda: f())[0]*iters 19 | 20 | if name is None: 21 | res = us_per_iter 22 | else: 23 | res = f"{name}: {us_per_iter:.3f}us" 24 | 25 | if display: 26 | print(res) 27 | return res 28 | 29 | def bench(f, name='', iters=100, warmup=5, display=True, profile=False): 30 | import time 31 | for i in range(warmup): 32 | f() 33 | start = time.time() 34 | for i in range(iters): 35 | f() 36 | end = time.time() 37 | res = end-start 38 | print('time for {name}={res:0.2f}'.format(name=name, res=res)) 39 | return res 40 | 41 | def scan(alphas, betas, h): 42 | """Loop over a simple RNN.
43 | 44 | Args: 45 | alphas (torch.tensor): shape [B, T, C] 46 | betas (torch.tensor): shape [B, T, C] 47 | h (torch.tensor): shape [B, C] 48 | """ 49 | T = betas.shape[-2] 50 | hs = torch.zeros_like(betas) 51 | exp_alphas = torch.exp(alphas) 52 | for t in range(T): 53 | h = exp_alphas[:, t] * h + betas[:, t] 54 | hs[:, t] = h 55 | return hs 56 | 57 | T = 128 58 | B = 64 59 | C = 256 60 | if torch.cuda.is_available(): 61 | device = torch.device('cuda') 62 | else: 63 | device = torch.device('cpu') 64 | 65 | 66 | h = torch.randn(B, C, device=device) 67 | alphas = torch.randn(B, T, C, device=device) 68 | betas = torch.randn(B, T, C, device=device) 69 | 70 | bench(lambda: scan(alphas, betas, h), name='vanilla') 71 | opt_scan = torch.compile(scan) 72 | bench(lambda: opt_scan(alphas, betas, h), name='torch.compile') 73 | scan_jit = torch.jit.trace(scan, (alphas, betas, h)) 74 | bench(lambda: scan_jit(alphas, betas, h), name='torch.jit.trace') 75 | -------------------------------------------------------------------------------- /rebayes/__init__.py: -------------------------------------------------------------------------------- 1 | import jax 2 | 3 | # Add matmul precision to avoid matmul precision error 4 | jax.config.update("jax_default_matmul_precision", "float32") 5 | -------------------------------------------------------------------------------- /rebayes/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/rebayes/datasets/__init__.py -------------------------------------------------------------------------------- /rebayes/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | from typing import Callable, Tuple, Union 3 | 4 | from augmax.geometric import GeometricTransformation, LazyCoordinates 5 | import numpy as np 6 | import jax.numpy as jnp 7 | import jax.random as jr 8 | 9 | 10 | class DataAugmentationFactory: 11 | """This is a base library to process/transform the elements of a numpy 12 | array according to a given function. 
13 | """ 14 | def __init__( 15 | self, 16 | processor: Callable, 17 | ) -> None: 18 | self.processor = processor 19 | 20 | def __call__( 21 | self, 22 | img: np.ndarray, 23 | configs: Union[dict, list], 24 | n_processes: int=90, 25 | ) -> np.ndarray: 26 | img_processed = \ 27 | self.process_multiple_multiprocessing(img, configs, n_processes) 28 | 29 | return img_processed 30 | 31 | def process_single( 32 | self, 33 | img: np.ndarray, 34 | *args: list, 35 | **kwargs: dict, 36 | ) -> np.ndarray: 37 | img_processed = self.processor(img, *args, **kwargs) 38 | 39 | return img_processed 40 | 41 | def process_multiple( 42 | self, 43 | imgs: np.ndarray, 44 | configs: Union[dict, list], 45 | ) -> np.ndarray: 46 | imgs_processed = [] 47 | for X, config in zip(imgs, configs): 48 | X_processed = self.process_single(X, **config) 49 | imgs_processed.append(X_processed) 50 | imgs_processed = np.stack(imgs_processed, axis=0) 51 | 52 | return imgs_processed 53 | 54 | def process_multiple_multiprocessing( 55 | self, 56 | imgs: np.ndarray, 57 | configs: Union[dict, list], 58 | n_processes: int, 59 | ) -> np.ndarray: 60 | num_elements = len(imgs) 61 | if isinstance(configs, dict): 62 | configs = [configs] * num_elements 63 | 64 | if n_processes == 1: 65 | imgs_processed = self.process_multiple(imgs, configs) 66 | imgs_processed = imgs_processed.reshape(num_elements, -1) 67 | 68 | return imgs_processed 69 | 70 | imgs_processed = np.array_split(imgs, n_processes) 71 | config_split = np.array_split(configs, n_processes) 72 | elements = zip(imgs_processed, config_split) 73 | 74 | with Pool(processes=n_processes) as pool: 75 | imgs_processed = pool.starmap(self.process_multiple, elements) 76 | imgs_processed = np.concatenate(imgs_processed, axis=0) 77 | # NB: starmap blocks until all workers finish; calling pool.join() here would raise, since the pool has not been closed yet 78 | imgs_processed = imgs_processed.reshape(num_elements, -1) 79 | 80 | return imgs_processed 81 | 82 | 83 | class Rotate(GeometricTransformation): 84 | """Rotates the image by a random arbitrary angle. 85 | Adapted from https://github.com/khdlr/augmax/. 86 | """ 87 | def __init__( 88 | self, 89 | angle_range: Union[Tuple[float, float], float]=(-30, 30), 90 | prob: float = 1.0 91 | ): 92 | super().__init__() 93 | if not hasattr(angle_range, '__iter__'): 94 | angle_range = (-angle_range, angle_range) 95 | self.theta_min, self.theta_max = map(jnp.radians, angle_range) 96 | self.probability = prob 97 | 98 | def transform_coordinates( 99 | self, 100 | rng: jnp.ndarray, 101 | coordinates: LazyCoordinates, 102 | invert=False 103 | ): 104 | do_apply = jr.bernoulli(rng, self.probability) 105 | theta = do_apply * jr.uniform(rng, minval=self.theta_min, maxval=self.theta_max) 106 | 107 | if invert: 108 | theta = -theta 109 | 110 | transform = jnp.array([ 111 | [ jnp.cos(theta), jnp.sin(theta), 0], 112 | [-jnp.sin(theta), jnp.cos(theta), 0], 113 | [0, 0, 1] 114 | ]) 115 | coordinates.push_transform(transform) 116 | -------------------------------------------------------------------------------- /rebayes/datasets/moons_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocessing and data augmentation for the datasets.
3 | """ 4 | import re 5 | import io 6 | import os 7 | import jax 8 | import chex 9 | import zipfile 10 | import requests 11 | import numpy as np 12 | import pandas as pd 13 | import jax.numpy as jnp 14 | import jax.random as jr 15 | from jax import vmap 16 | 17 | from typing import Union 18 | from jaxtyping import Float, Array 19 | 20 | from sklearn.datasets import make_moons 21 | 22 | 23 | def make_showdown_moons(n_train, n_test, n_train_warmup, n_test_warmup, noise, seed=314): 24 | np.random.seed(seed) 25 | train = make_moons(n_samples=n_train, noise=noise) 26 | test = make_moons(n_samples=n_test, noise=noise) 27 | warmup_train = make_moons(n_samples=n_train_warmup, noise=noise) 28 | warmup_test = make_moons(n_samples=n_test_warmup, noise=noise) 29 | 30 | train = jax.tree_map(jnp.array, train) 31 | test = jax.tree_map(jnp.array, test) 32 | warmup_train = jax.tree_map(jnp.array, warmup_train) 33 | warmup_test = jax.tree_map(jnp.array, warmup_test) 34 | 35 | return train, test, warmup_train, warmup_test 36 | 37 | def _rotation_matrix(angle): 38 | """ 39 | Create a rotation matrix that rotates the 40 | space 'angle'-radians. 41 | """ 42 | R = np.array([ 43 | [np.cos(angle), -np.sin(angle)], 44 | [np.sin(angle), np.cos(angle)] 45 | ]) 46 | return R 47 | 48 | 49 | def _make_rotating_moons(radians, n_samples=100, **kwargs): 50 | """ 51 | Make two interleaving half circles rotated by 'radians' radians 52 | 53 | Parameters 54 | ---------- 55 | radians: float 56 | Angle of rotation 57 | n_samples: int 58 | Number of samples 59 | **kwargs: 60 | Extra arguments passed to the `make_moons` function 61 | """ 62 | X, y = make_moons(n_samples=n_samples, **kwargs) 63 | X = jnp.einsum("nm,mk->nk", X, _rotation_matrix(radians)) 64 | return X, y 65 | 66 | 67 | def make_rotating_moons(n_train, n_test, n_rotations, min_angle=0, max_angle=360, seed=314, **kwargs): 68 | """ 69 | n_train: int 70 | Number of training samples per rotation 71 | n_test: int 72 | Number of test samples per rotation 73 | n_rotations: int 74 | Number of rotations 75 | """ 76 | np.random.seed(seed) 77 | n_samples = n_train + n_test 78 | min_rad = np.deg2rad(min_angle) 79 | max_rad = np.deg2rad(max_angle) 80 | 81 | radians = np.linspace(min_rad, max_rad, n_rotations) 82 | X_train_all, y_train_all, rads_train_all = [], [], [] 83 | X_test_all, y_test_all, rads_test_all = [], [], [] 84 | for rad in radians: 85 | X, y = _make_rotating_moons(rad, n_samples=n_samples, **kwargs) 86 | rads = jnp.ones(n_samples) * rad 87 | 88 | X_train = X[:n_train] 89 | y_train = y[:n_train] 90 | rad_train = rads[:n_train] 91 | 92 | X_test = X[n_train:] 93 | y_test = y[n_train:] 94 | rad_test = rads[n_train:] 95 | 96 | X_train_all.append(X_train) 97 | y_train_all.append(y_train) 98 | rads_train_all.append(rad_train) 99 | 100 | X_test_all.append(X_test) 101 | y_test_all.append(y_test) 102 | rads_test_all.append(rad_test) 103 | 104 | X_train_all = jnp.concatenate(X_train_all, axis=0) 105 | y_train_all = jnp.concatenate(y_train_all, axis=0) 106 | rads_train_all = jnp.concatenate(rads_train_all, axis=0) 107 | X_test_all = jnp.concatenate(X_test_all, axis=0) 108 | y_test_all = jnp.concatenate(y_test_all, axis=0) 109 | rads_test_all = jnp.concatenate(rads_test_all, axis=0) 110 | 111 | train = (X_train_all, y_train_all, rads_train_all) 112 | test = (X_test_all, y_test_all, rads_test_all) 113 | 114 | return train, test 115 | 116 | -------------------------------------------------------------------------------- /rebayes/datasets/nonstat_1d_data.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Nonstationary 1D data generation 3 | """ 4 | import jax 5 | import jax.numpy as jnp 6 | from copy import deepcopy 7 | 8 | def make_1d_regression( 9 | key, n_train=100, n_test=100, sort_data=False, coef0=2.0, coef1=3.0, coef2=1.0 10 | ): 11 | key_train, key_test = jax.random.split(key) 12 | keys_train = jax.random.split(key_train, n_train) 13 | keys_test = jax.random.split(key_test, n_test) 14 | minval, maxval = -0.5, 0.5 15 | 16 | def f(x): 17 | y = coef2 * x + 0.3 * jnp.sin(coef0 + coef1 * jnp.pi * x) # coef0 is the phase (default 2.0) 18 | return y 19 | 20 | @jax.vmap 21 | def gen(key): 22 | key_x, key_y = jax.random.split(key) 23 | x = jax.random.uniform(key_x, shape=(1,), minval=minval, maxval=maxval) 24 | if sort_data: 25 | x = jnp.sort(x) 26 | 27 | noise = jax.random.normal(key_y) * 0.02 28 | y = f(x) + noise 29 | return x, y 30 | 31 | X_train, y_train = gen(keys_train) 32 | X_test, y_test = gen(keys_test) 33 | 34 | X_eval = jnp.linspace(minval, maxval, 500) 35 | y_eval = f(X_eval) 36 | 37 | # Standardize dataset 38 | if True: 39 | X_train = (X_train - X_train.mean()) / X_train.std() 40 | y_train = (y_train - y_train.mean()) / y_train.std() 41 | X_test = (X_test - X_test.mean()) / X_test.std() 42 | y_test = (y_test - y_test.mean()) / y_test.std() 43 | X_eval = (X_eval - X_eval.mean()) / X_eval.std() 44 | y_eval = (y_eval - y_eval.mean()) / y_eval.std() 45 | 46 | train = (X_train, y_train) 47 | test = (X_test, y_test) 48 | eval = (X_eval, y_eval) 49 | return train, test, eval 50 | 51 | 52 | def make_coefs(key, n_dist): 53 | """ 54 | Sample (c0, c1, slope) coefficients, one triple per task 55 | """ 56 | key_slope, key_distort = jax.random.split(key) 57 | coefs = jax.random.uniform(key_distort, shape=(n_dist, 2), minval=-5, maxval=5) 58 | coef_slope = jax.random.uniform(key_slope, shape=(n_dist, 1), minval=-1.0, maxval=1.0) 59 | 60 | coefs = jnp.c_[coefs, coef_slope] 61 | return coefs 62 | 63 | 64 | def sample_1d_regression_sequence(key, n_dist, n_train=100, n_test=100): 65 | key_coef, key_dataset = jax.random.split(key) 66 | keys_dataset = jax.random.split(key_dataset, n_dist) 67 | coefs = make_coefs(key_coef, n_dist) 68 | 69 | @jax.vmap 70 | def vsample_dataset(key, coefs): 71 | train, test, eval = make_1d_regression( 72 | key, n_train, n_test, coef0=coefs[0], coef1=coefs[1], coef2=coefs[2] 73 | ) 74 | return train, test, eval 75 | 76 | *collection, eval_dataset = vsample_dataset(keys_dataset, coefs) 77 | collection_flat = jax.tree_map(lambda x: x.reshape(-1, 1), collection) 78 | collection_train, collection_test = collection_flat 79 | 80 | train_id_seq = jnp.repeat(jnp.arange(n_dist), n_train) 81 | test_id_seq = jnp.repeat(jnp.arange(n_dist), n_test) 82 | 83 | collection_flat = { 84 | "train": { 85 | "X": collection_train[0], 86 | "y": collection_train[1], 87 | "id_seq": train_id_seq 88 | }, 89 | "test": { 90 | "X": collection_test[0], 91 | "y": collection_test[1], 92 | "id_seq": test_id_seq 93 | } 94 | } 95 | 96 | collection_train, collection_test = collection 97 | collection_task = { 98 | "train": { 99 | "X": collection_train[0], 100 | "y": collection_train[1] 101 | }, 102 | "test": { 103 | "X": collection_test[0], 104 | "y": collection_test[1], 105 | }, 106 | "eval": { 107 | "X": eval_dataset[0], 108 | "y": eval_dataset[1] 109 | } 110 | } 111 | 112 | return collection_flat, collection_task 113 | 114 | 115 | def slice_tasks(datasets, task): 116 | datasets = deepcopy(datasets) 117 | train_seq = datasets["train"].pop("id_seq") == task 118 | test_seq =
datasets["test"].pop("id_seq") == task 119 | 120 | train = datasets["train"] 121 | test = datasets["test"] 122 | 123 | train = jax.tree_map(lambda x: x[train_seq], train) 124 | test = jax.tree_map(lambda x: x[test_seq], test) 125 | 126 | datasets = { 127 | "train": train, 128 | "test": test 129 | } 130 | 131 | return datasets 132 | -------------------------------------------------------------------------------- /rebayes/datasets/rotating_permuted_mnist_data.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import vmap 3 | import jax.numpy as jnp 4 | import jax.random as jr 5 | 6 | from rebayes.datasets import classification_data as clf_data 7 | from rebayes.datasets import rotating_mnist_data as rmnist_data 8 | 9 | 10 | def generate_random_angles(n_tasks, min_angle=0, max_angle=180, key=0): 11 | if isinstance(key, int): 12 | key = jax.random.PRNGKey(key) 13 | angles = jr.uniform(key, (n_tasks,), minval=min_angle, maxval=max_angle) 14 | return angles 15 | 16 | 17 | def rotate_mnist_dataset(X, angles): 18 | X_rotated = vmap(rmnist_data.rotate_mnist)(X, angles) 19 | 20 | return X_rotated 21 | 22 | 23 | def generate_rotating_mnist_dataset(X=None, min_angle=0, max_angle=180, key=0, target_digit=None): 24 | if isinstance(key, int): 25 | key = jr.PRNGKey(key) 26 | if X is None: 27 | (X, y), _ = rmnist_data.load_mnist() 28 | if target_digit is not None: 29 | X = X[y == target_digit] 30 | random_angles = generate_random_angles(len(X), min_angle, max_angle, key) 31 | X_rotated = rotate_mnist_dataset(X, random_angles) 32 | 33 | return X_rotated, random_angles 34 | 35 | 36 | def generate_rotating_permuted_mnist_regression_dataset( 37 | n_tasks, ntrain_per_task, nval_per_task, ntest_per_task, min_angle=0, 38 | max_angle=180, key=0, fashion=True, mnist_dataset = None 39 | ): 40 | if mnist_dataset is None: 41 | mnist_dataset = clf_data.load_permuted_mnist_dataset(n_tasks, ntrain_per_task, 42 | nval_per_task, ntest_per_task, 43 | key, fashion) 44 | dataset = { 45 | k: generate_rotating_mnist_dataset(mnist_dataset[k][0], min_angle, max_angle, i) 46 | for i, k in enumerate(('train', 'val', 'test')) 47 | } 48 | 49 | return dataset 50 | -------------------------------------------------------------------------------- /rebayes/datasets/uci_uncertainty_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data taken for the UCI Uncertainty Benchmark 3 | repo: https://github.com/yaringal/DropoutUncertaintyExps.git 4 | """ 5 | 6 | import os 7 | import jax 8 | import numpy as np 9 | import jax.numpy as jnp 10 | 11 | 12 | def load_raw_data(path): 13 | data_path = os.path.join(path, "data.txt") 14 | data = np.loadtxt(data_path) 15 | data = jax.tree_map(jnp.array, data) 16 | data = jax.tree_map(jnp.nan_to_num, data) 17 | return data 18 | 19 | 20 | def load_train_test_ixs(path, ix): 21 | ix_train_path = os.path.join(path, f"index_train_{ix}.txt") 22 | ix_test_path = os.path.join(path, f"index_test_{ix}.txt") 23 | ix_train = np.loadtxt(ix_train_path, dtype=int) 24 | ix_test = np.loadtxt(ix_test_path, dtype=int) 25 | 26 | return ix_train, ix_test 27 | 28 | 29 | def load_features_target_ixs(path): 30 | features_path = os.path.join(path, "index_features.txt") 31 | target_path = os.path.join(path, "index_target.txt") 32 | 33 | features_ixs = np.loadtxt(features_path, dtype=int) 34 | target_ixs = np.loadtxt(target_path, dtype=int) 35 | 36 | return features_ixs, target_ixs 37 | 38 | 39 | def normalise_features(X, 
ix_train, ix_test): 40 | X_train = X[ix_train] 41 | X_test = X[ix_test] 42 | 43 | xmean, xstd = np.nanmean(X_train, axis=0, keepdims=True), np.nanstd(X_train, axis=0, keepdims=True) 44 | X_train = (X_train - xmean) / xstd 45 | X_test = (X_test - xmean) / xstd 46 | 47 | return X_train, X_test, (xmean, xstd) 48 | 49 | 50 | def normalise_targets(y, ix_train, ix_test): 51 | y_train = y[ix_train] 52 | y_test = y[ix_test] 53 | 54 | ymean, ystd = np.nanmean(y_train), np.nanstd(y_train) 55 | y_train = (y_train - ymean) / ystd 56 | y_test = (y_test - ymean) / ystd 57 | 58 | return y_train, y_test, (ymean, ystd) 59 | 60 | 61 | def load_full_data(path): 62 | data = load_raw_data(path) 63 | features_ixs, target_ixs = load_features_target_ixs(path) 64 | X = data[:, features_ixs] 65 | y = data[:, target_ixs] 66 | return X, y 67 | 68 | 69 | def load_folds_data(path, n_partitions=20): 70 | """ 71 | Load data from all available folds 72 | """ 73 | X, y = load_full_data(path) 74 | 75 | X_train_all = [] 76 | y_train_all = [] 77 | X_test_all = [] 78 | y_test_all = [] 79 | coefs_all = [] 80 | 81 | for ix in range(n_partitions): 82 | ix_train, ix_test = load_train_test_ixs(path, ix) 83 | X_train, X_test, _ = normalise_features(X, ix_train, ix_test) 84 | y_train, y_test, (ymean, ystd) = normalise_targets(y, ix_train, ix_test) 85 | 86 | X_test_all.append(X_test) 87 | y_test_all.append(y_test) 88 | 89 | X_train_all.append(X_train) 90 | y_train_all.append(y_train) 91 | coefs = {"ymean": ymean.item(), "ystd": ystd.item()} 92 | coefs_all.append(coefs) 93 | 94 | X_train_all = jnp.stack(X_train_all, axis=0) 95 | y_train_all = jnp.stack(y_train_all, axis=0) 96 | 97 | X_test_all = jnp.stack(X_test_all, axis=0) 98 | y_test_all = jnp.stack(y_test_all, axis=0) 99 | 100 | train_all = (X_train_all, y_train_all) 101 | test_all = (X_test_all, y_test_all) 102 | 103 | struct_out = jax.tree_util.tree_structure([0 for e in coefs_all]) 104 | struct_in = jax.tree_util.tree_structure(coefs_all[0]) 105 | coefs_all = jax.tree_util.tree_transpose(struct_out, struct_in, coefs_all) 106 | coefs_all = jax.tree_map(jnp.array, coefs_all, is_leaf=lambda x: type(x) == list) 107 | 108 | return train_all, test_all, coefs_all 109 | 110 | 111 | def load_data(path, index): 112 | data = load_raw_data(path) 113 | train_ixs, test_ixs = load_train_test_ixs(path, index) 114 | features_ixs, target_ixs = load_features_target_ixs(path) 115 | 116 | X, y = data[:, features_ixs], data[:, target_ixs] 117 | 118 | X_train, X_test, _ = normalise_features(X, train_ixs, test_ixs) 119 | y_train, y_test, (ymean, ystd) = normalise_targets(y, train_ixs, test_ixs) 120 | 121 | train = (X_train, y_train.ravel()) 122 | test = (X_test, y_test.ravel()) 123 | 124 | dataset = { 125 | "train": train, 126 | "test": test, 127 | } 128 | 129 | res = { 130 | "dataset": dataset, 131 | "ymean": ymean, 132 | "ystd": ystd, 133 | } 134 | 135 | return res 136 | 137 | 138 | if __name__ == "__main__": 139 | path = ( 140 | "/home/gerardoduran/documents/external" 141 | "/DropoutUncertaintyExps/UCI_Datasets" 142 | "/kin8nm" 143 | ) 144 | path = os.path.join(path, "data") 145 | data = load_data(path, 0) 146 | train, test = data["dataset"]["train"], data["dataset"]["test"] 147 | print(train[0].shape) 148 | -------------------------------------------------------------------------------- /rebayes/deprecated/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rebayes/deprecated/old_base.py:
-------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from abc import abstractmethod 3 | from functools import partial 4 | 5 | import jax.numpy as jnp 6 | from jax import jacrev, jit 7 | from jax.lax import scan 8 | from jaxtyping import Float, Array 9 | from typing import Callable, NamedTuple, Union, Tuple, Any 10 | import chex 11 | from jax_tqdm import scan_tqdm 12 | 13 | _jacrev_2d = lambda f, x: jnp.atleast_2d(jacrev(f)(x)) 14 | 15 | import tensorflow_probability.substrates.jax as tfp 16 | tfd = tfp.distributions 17 | MVN = tfd.MultivariateNormalFullCovariance 18 | 19 | 20 | FnStateToState = Callable[ [Float[Array, "state_dim"]], Float[Array, "state_dim"]] 21 | FnStateAndInputToState = Callable[ [Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "state_dim"]] 22 | FnStateToEmission = Callable[ [Float[Array, "state_dim"]], Float[Array, "emission_dim"]] 23 | FnStateAndInputToEmission = Callable[ [Float[Array, "state_dim"], Float[Array, "input_dim"] ], Float[Array, "emission_dim"]] 24 | 25 | FnStateToEmission2 = Callable[[Float[Array, "state_dim"]], Float[Array, "emission_dim emission_dim"]] 26 | FnStateAndInputToEmission2 = Callable[[Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "emission_dim emission_dim"]] 27 | EmissionDistFn = Callable[ [Float[Array, "state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution] 28 | 29 | CovMat = Union[float, Float[Array, "dim"], Float[Array, "dim dim"]] 30 | 31 | class GaussianBroken(NamedTuple): 32 | mean: Float[Array, "state_dim"] 33 | cov: Union[Float[Array, "state_dim state_dim"], Float[Array, "state_dim"]] 34 | 35 | @chex.dataclass 36 | class Gaussian: 37 | mean: chex.Array 38 | cov: chex.Array 39 | 40 | Belief = Gaussian # Can be overridden by other representations (e.g., MCMC samples or memory buffer) 41 | 42 | class RebayesParams(NamedTuple): 43 | initial_mean: Float[Array, "state_dim"] 44 | initial_covariance: CovMat 45 | dynamics_weights: CovMat 46 | dynamics_covariance: CovMat 47 | #emission_function: FnStateAndInputToEmission 48 | #emission_covariance: CovMat 49 | emission_mean_function: FnStateAndInputToEmission 50 | emission_cov_function: FnStateAndInputToEmission2 51 | emission_dist: EmissionDistFn = lambda mean, cov: MVN(loc=mean, covariance_matrix=cov) 52 | #emission_dist=lambda mu, Sigma: tfd.Poisson(log_rate = jnp.log(mu)) 53 | adaptive_emission_cov: bool=False 54 | dynamics_covariance_inflation_factor: float=0.0 55 | 56 | 57 | 58 | class Rebayes(ABC): 59 | def __init__( 60 | self, 61 | params: RebayesParams, 62 | ): 63 | self.params = params 64 | #self.emission_mean_function = lambda z, u: self.emission_function(z, u) 65 | #self.emission_cov_function = lambda z, u: self.params.emission_covariance 66 | 67 | def init_bel(self): 68 | return Gaussian(mean=self.params.initial_mean, cov=self.params.initial_covariance) 69 | 70 | @partial(jit, static_argnums=(0,)) 71 | def predict_state( 72 | self, 73 | bel: Gaussian 74 | ) -> Gaussian: 75 | """Given bel(t-1|t-1) = p(z(t-1) | D(1:t-1)), return bel(t|t-1) = p(z(t) | D(1:t-1)). 76 | This is cheap, since the dynamics model is linear-Gaussian.
77 | """ 78 | m, P = bel.mean, bel.cov 79 | F = self.params.dynamics_weights 80 | Q = self.params.dynamics_covariance 81 | pred_mean = F @ m 82 | pred_cov = F @ P @ F.T + Q 83 | return Gaussian(mean=pred_mean, cov=pred_cov) 84 | 85 | @partial(jit, static_argnums=(0,)) 86 | def predict_obs( 87 | self, 88 | bel: Gaussian, 89 | u: Float[Array, "input_dim"] 90 | ) -> Gaussian: # TODO: replace output with emission_dist 91 | """Given bel(t|t-1) = p(z(t) | D(1:t-1)), return obs(t|t-1) = p(y(t) | u(t), D(1:t-1))""" 92 | prior_mean, prior_cov = bel.mean, bel.cov # p(z(t) | y(1:t-1)) 93 | # Partially apply fn to the input u so it just depends on hidden state z 94 | m_Y = lambda z: self.params.emission_mean_function(z, u) 95 | Cov_Y = lambda z: self.params.emission_cov_function(z, u) 96 | 97 | yhat = jnp.atleast_1d(m_Y(prior_mean)) 98 | R = jnp.atleast_2d(Cov_Y(prior_mean)) 99 | H = _jacrev_2d(m_Y, prior_mean) 100 | 101 | Sigma_obs = H @ prior_cov @ H.T + R 102 | return Gaussian(mean=yhat, cov=Sigma_obs) 103 | 104 | @abstractmethod 105 | def update_state( 106 | self, 107 | bel: Gaussian, 108 | u: Float[Array, "input_dim"], 109 | y: Float[Array, "obs_dim"] 110 | ) -> Gaussian: 111 | """Return bel(t|t) = p(z(t) | u(t), y(t), D(1:t-1)) using bel(t|t-1)""" 112 | raise NotImplementedError 113 | 114 | def scan( 115 | self, 116 | X: Float[Array, "ntime input_dim"], 117 | Y: Float[Array, "ntime emission_dim"], 118 | callback=None, 119 | bel=None, 120 | progress_bar=True, 121 | **kwargs 122 | ) -> Tuple[Gaussian, Any]: 123 | """Apply filtering to entire sequence of data. Return final belief state and outputs from callback.""" 124 | num_timesteps = X.shape[0] 125 | def step(bel, t): 126 | bel = self.predict_state(bel) 127 | pred_obs = self.predict_obs(bel, X[t]) 128 | bel = self.update_state(bel, X[t], Y[t]) 129 | out = None 130 | if callback is not None: 131 | out = callback(bel, pred_obs, t, X[t], Y[t], **kwargs) 132 | return bel, out 133 | carry = bel 134 | if bel is None: 135 | carry = self.init_bel() 136 | 137 | if progress_bar: 138 | step = scan_tqdm(num_timesteps)(step) 139 | 140 | bel, outputs = scan(step, carry, jnp.arange(num_timesteps)) 141 | return bel, outputs 142 | -------------------------------------------------------------------------------- /rebayes/deprecated/optax_optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | from jax.flatten_util import ravel_pytree 4 | from jax import jit 5 | import jax.numpy as jnp 6 | from optax._src import base 7 | 8 | from dynamax.rebayes.diagonal_inference import ( 9 | _full_covariance_condition_on, 10 | _fully_decoupled_ekf_condition_on, 11 | _variational_diagonal_ekf_condition_on, 12 | ) 13 | 14 | 15 | class EKFState(NamedTuple): 16 | """ 17 | Lightweight container for EKF parameters. 18 | """ 19 | mean: base.Updates 20 | cov: base.Updates 21 | 22 | 23 | def make_ekf_optimizer( 24 | pred_mean_fn, 25 | pred_cov_fn, 26 | init_var = 1.0, 27 | ekf_type = 'fcekf', 28 | num_iter = 1 29 | ) -> base.GradientTransformation: 30 | """Generate optax optimizer object for EKF. 31 | 32 | Args: 33 | pred_mean_fn (Callable): Emission mean function for EKF. 34 | pred_cov_fn (Callable): Emission covariance function for EKF. 35 | init_var (float, optional): Initial covariance factor. Defaults to 1.0. 36 | ekf_type (str, optional): One of ['fcekf', 'fdekf', 'vdekf']. Defaults to 'fcekf'. 37 | num_iter (int, optional): Number of posterior linearizations to perform. Defaults to 1. 
38 | Returns: 39 | base.GradientTransformation: Optax optimizer object for EKF. 40 | """ 41 | if ekf_type not in ['fcekf', 'fdekf', 'vdekf']: 42 | raise ValueError("'ekf_type' must be one of ['fcekf', 'fdekf', 'vdekf']") 43 | 44 | def init_fn(params): 45 | flat_params, _ = ravel_pytree(params) 46 | if ekf_type == 'fcekf': 47 | cov = init_var * jnp.eye(flat_params.shape[0]) 48 | else: 49 | cov = init_var * jnp.ones_like(flat_params) 50 | return EKFState(mean=params, cov=cov) 51 | 52 | @jit 53 | def update_fn(updates, state, params=None): 54 | # Updates are new set of data points 55 | x, y = updates 56 | flat_mean, unflatten_fn = ravel_pytree(state.mean) 57 | if ekf_type == 'fcekf': 58 | mean, cov = _full_covariance_condition_on( 59 | flat_mean, state.cov, pred_mean_fn, pred_cov_fn, x, y, num_iter 60 | ) 61 | elif ekf_type == 'fdekf': 62 | mean, cov = _fully_decoupled_ekf_condition_on( 63 | flat_mean, state.cov, pred_mean_fn, pred_cov_fn, x, y, num_iter 64 | ) 65 | else: 66 | mean, cov = _variational_diagonal_ekf_condition_on( 67 | flat_mean, state.cov, pred_mean_fn, pred_cov_fn, x, y, num_iter 68 | ) 69 | updates = unflatten_fn(mean - flat_mean) 70 | return updates, EKFState(mean=unflatten_fn(mean), cov=cov) 71 | 72 | return base.GradientTransformation(init_fn, update_fn) -------------------------------------------------------------------------------- /rebayes/deprecated/simple_base.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import namedtuple 3 | from typing import Any, Tuple, Union 4 | 5 | import jax 6 | import chex 7 | from jax.lax import scan 8 | import jax.numpy as jnp 9 | from jaxtyping import Float, Array 10 | from tqdm import trange 11 | 12 | import tensorflow_probability.substrates.jax as tfp 13 | tfd = tfp.distributions 14 | tfb = tfp.bijectors 15 | MVN = tfd.MultivariateNormalFullCovariance 16 | MVND = tfd.MultivariateNormalDiag 17 | 18 | 19 | import optax 20 | 21 | 22 | def predict_update_batch( 23 | params, 24 | bel, 25 | emission_mean_fn, is_classification, 26 | predict_state_fn, 27 | update_state_fn, 28 | Xs, 29 | Ys 30 | ): 31 | num_timesteps = Xs.shape[0] 32 | # is_classification is passed in explicitly to match the call in dual_rebayes_simple below 33 | def step(bel, t): 34 | # bel is the scan carry; t indexes the current observation 35 | pred_bel = predict_state_fn(params['dyn_noise'], params['dyn_weights'], params['inflation'], bel) 36 | x = Xs[t] 37 | yhat = emission_mean_fn(pred_bel.mean, x) 38 | ytrue = Ys[t] 39 | if is_classification: 40 | #obs_cov = emission_cov_fn(pred_bel.mean) 41 | ps = yhat # probabilities 42 | obs_cov = jnp.diag(ps) - jnp.outer(ps, ps) + 1e-3 * jnp.eye(len(ps)) # Add diagonal to avoid singularity 43 | else: 44 | obs_cov = params['obs_noise'] 45 | bel = update_state_fn(pred_bel, x, ytrue, emission_mean_fn, obs_cov) 46 | if is_classification: 47 | logits = jnp.log(yhat) 48 | loss = optax.softmax_cross_entropy_with_integer_labels(logits, ytrue) 49 | else: 50 | #loss = jnp.square(Ytr - yhat).mean() 51 | pdf = MVN(yhat, obs_cov) 52 | loss = -pdf.log_prob(ytrue) 53 | return bel, loss 54 | carry, losses = scan(step, bel, jnp.arange(num_timesteps)) 55 | loss = jnp.sum(losses) 56 | return carry, loss 57 | 58 | #apply_fn = lambda w, x: cnn.apply({'params': unflatten_fn(w)}, x).ravel() 59 | #emission_mean_function=lambda w, x: jax.nn.softmax(apply_fn(w, x)) 60 | 61 | def dual_rebayes_simple( 62 | init_fn, 63 | nstates, 64 | emission_mean_fn, 65 | is_classification, 66 | predict_state_fn, 67 | update_state_fn, 68 | constrain_params_fn, 69 | data_loader, 70 | optimizer 71 | ): 72 | X, Y =
next(iter(data_loader)) 73 | emission_dim = 1 if len(Y.shape) == 1 else Y.shape[1] 74 | params_unc, bel = init_fn(nstates, emission_dim, is_classification) 75 | opt_state = optimizer.init(params_unc) 76 | def step(carry, b): 77 | params_unc, bel, opt_state = carry 78 | batch = data_loader[b] 79 | Xs, Ys = batch[0], batch[1] 80 | def lossfn(params_unc): 81 | params_con = constrain_params_fn(params_unc) 82 | bel_post, loss = predict_update_batch( 83 | params_con, bel, 84 | emission_mean_fn, is_classification, 85 | predict_state_fn, update_state_fn, 86 | Xs, Ys) 87 | return loss, bel_post 88 | (loss_value, bel), grads = jax.value_and_grad(lossfn, has_aux=True)(params_unc) 89 | param_updates, opt_state = optimizer.update(grads, opt_state, params_unc) 90 | params_unc = optax.apply_updates(params_unc, param_updates) 91 | return (params_unc, bel, opt_state), loss_value 92 | carry, losses = scan(step, (params_unc, bel, opt_state), jnp.arange(len(data_loader))) 93 | return carry, losses 94 | -------------------------------------------------------------------------------- /rebayes/extended_kalman_filter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/rebayes/extended_kalman_filter/__init__.py -------------------------------------------------------------------------------- /rebayes/extended_kalman_filter/replay_ekf.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import chex 4 | from functools import partial 5 | import jax 6 | from jax import jit 7 | import jax.numpy as jnp 8 | from jaxtyping import Array, Float 9 | import tensorflow_probability.substrates.jax as tfp 10 | 11 | from rebayes.base import ( 12 | CovMat, 13 | EmissionDistFn, 14 | FnStateAndInputToEmission, 15 | FnStateAndInputToEmission2, 16 | FnStateToEmission, 17 | FnStateToEmission2, 18 | FnStateToState, 19 | ) 20 | from rebayes.extended_kalman_filter.ekf import RebayesEKF 21 | 22 | 23 | tfd = tfp.distributions 24 | MVN = tfd.MultivariateNormalTriL 25 | MVD = tfd.MultivariateNormalDiag 26 | 27 | 28 | @chex.dataclass 29 | class ReplayEKFBel: 30 | mean: chex.Array 31 | cov: chex.Array 32 | 33 | 34 | class RebayesReplayEKF(RebayesEKF): 35 | def __init__( 36 | self, 37 | dynamics_weights_or_function: Union[float, FnStateToState], 38 | dynamics_covariance: CovMat, 39 | emission_mean_function: Union[FnStateToEmission, FnStateAndInputToEmission], 40 | emission_cov_function: Union[FnStateToEmission2, FnStateAndInputToEmission2], 41 | emission_dist: EmissionDistFn = \ 42 | lambda mean, cov: MVN(loc=mean, scale_tril=jnp.linalg.cholesky(cov)), 43 | dynamics_covariance_inflation_factor: float = 0.0, 44 | method: str="fcekf", 45 | n_replay: int=10, 46 | learning_rate: float=0.01, 47 | ): 48 | super().__init__( 49 | dynamics_weights_or_function, dynamics_covariance, 50 | emission_mean_function, emission_cov_function, emission_dist, False, 51 | dynamics_covariance_inflation_factor, method 52 | ) 53 | 54 | self.log_likelihood = lambda params, x, y: \ 55 | jnp.sum( 56 | emission_dist(self.emission_mean_function(params, x), 57 | self.emission_cov_function(params, x)).log_prob(y) 58 | ) 59 | self.n_replay = n_replay 60 | self.learning_rate = learning_rate 61 | 62 | def _update_mean( 63 | self, 64 | bel: ReplayEKFBel, 65 | m_prev: Float[Array, "state_dim"], 66 | x: Float[Array, "input_dim"], 67 | y: Float[Array, "output_dim"], 68 | ) -> ReplayEKFBel: 
69 | m, P = bel.mean, bel.cov 70 | gll = -jax.grad(self.log_likelihood, argnums=0)(m, x, y) 71 | additive_term = P @ gll if self.method == "fcekf" else P * gll 72 | m_cond = m - self.learning_rate * (m - m_prev + additive_term) 73 | bel_cond = bel.replace(mean=m_cond) 74 | 75 | return bel_cond 76 | 77 | def _update_cov( 78 | self, 79 | bel: ReplayEKFBel, 80 | x: Float[Array, "input_dim"], 81 | y: Float[Array, "output_dim"], 82 | ) -> ReplayEKFBel: 83 | m_prev, P_prev = bel.mean, bel.cov 84 | _, P_cond = self.update_fn(m_prev, P_prev, self.emission_mean_function, 85 | self.emission_cov_function, x, y, 86 | 1, False, 0.0) 87 | bel_cond = bel.replace(cov=P_cond) 88 | 89 | return bel_cond 90 | 91 | @partial(jit, static_argnums=(0,)) 92 | def update_state( 93 | self, 94 | bel: ReplayEKFBel, 95 | x: Float[Array, "input_dim"], 96 | y: Float[Array, "output_dim"], 97 | ) -> ReplayEKFBel: 98 | m_prev = bel.mean 99 | def partial_step(_, bel): 100 | bel = self._update_mean(bel, m_prev, x, y) 101 | return bel 102 | bel = jax.lax.fori_loop(0, self.n_replay-1, partial_step, bel) 103 | bel = self._update_mean(bel, m_prev, x, y) 104 | bel_cond = self._update_cov(bel, x, y) 105 | 106 | return bel_cond 107 | -------------------------------------------------------------------------------- /rebayes/extended_kalman_filter/test_dual_ekf.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN 3 | 4 | from rebayes.dual_base import dual_rebayes_scan, DualRebayesParams, ObsModel 5 | from rebayes.extended_kalman_filter.ekf import RebayesEKF 6 | from rebayes.extended_kalman_filter.dual_ekf import make_dual_ekf_estimator, EKFParams 7 | from rebayes.extended_kalman_filter.test_ekf import make_linreg_rebayes_params, run_kalman, make_linreg_data, make_linreg_prior 8 | from rebayes.utils.utils import get_mlp_flattened_params 9 | 10 | 11 | def allclose(u, v): 12 | return jnp.allclose(u, v, atol=1e-3) 13 | 14 | def make_linreg_dual_params(nfeatures): 15 | (obs_var, mu0, Sigma0) = make_linreg_prior() 16 | 17 | # Define Linear Regression as MLP with no hidden layers 18 | input_dim, hidden_dims, output_dim = nfeatures, [], 1 19 | model_dims = [input_dim, *hidden_dims, output_dim] 20 | *_, apply_fn = get_mlp_flattened_params(model_dims) 21 | 22 | params = DualRebayesParams( 23 | mu0=mu0, 24 | eta0=1/Sigma0[0,0], 25 | dynamics_scale_factor = 1.0, 26 | dynamics_noise = 0.0, 27 | obs_noise = obs_var, 28 | cov_inflation_factor = 0, 29 | ) 30 | obs_model = ObsModel( 31 | emission_mean_function = lambda w, x: apply_fn(w, x), 32 | emission_cov_function = lambda w, x: obs_var 33 | ) 34 | 35 | return params, obs_model 36 | 37 | 38 | def test_linreg(): 39 | # check that dual estimator matches KF for linear regression 40 | (X, Y) = make_linreg_data() 41 | lgssm_posterior = run_kalman(X, Y) 42 | mu_kf = lgssm_posterior.filtered_means 43 | cov_kf = lgssm_posterior.filtered_covariances 44 | ll_kf = lgssm_posterior.marginal_loglik 45 | 46 | N,D = X.shape 47 | params, obs_model = make_linreg_dual_params(D) 48 | ekf_params = EKFParams(method="fcekf") 49 | estimator = make_dual_ekf_estimator(params, obs_model, ekf_params) 50 | 51 | def callback(params, bel, pred_obs, t, u, y, bel_pred): 52 | m = pred_obs # estimator.predict_obs(params, bel_pred, u) 53 | P = estimator.predict_obs_cov(params, bel_pred, u) 54 | ll = MVN(m, P).log_prob(jnp.atleast_1d(y)) 55 | return ll 56 | 57 | carry, 
lls = dual_rebayes_scan(estimator, X, Y, callback) 58 | params, final_bel = carry 59 | # print(carry) 60 | T = mu_kf.shape[0] 61 | assert allclose(final_bel.mean, mu_kf[T-1]) 62 | assert allclose(final_bel.cov, cov_kf[T-1]) 63 | ll = jnp.sum(lls) 64 | assert jnp.allclose(ll, ll_kf, atol=1e-1) 65 | 66 | 67 | def test_adaptive_backwards_compatibility(): 68 | # check that we estimate the same obs noise as Peter's EKF code (for certain settings) 69 | (X, Y) = make_linreg_data() 70 | 71 | # old estimator 72 | N, D = X.shape 73 | params = make_linreg_rebayes_params(D) 74 | params.adaptive_emission_cov = True 75 | estimator = RebayesEKF(params, method='fcekf') 76 | final_bel, lls = estimator.scan(X, Y) 77 | obs_noise_ekf = jnp.atleast_1d(final_bel.obs_noise_var).ravel() 78 | # print(obs_noise_ekf) 79 | 80 | params, obs_model = make_linreg_dual_params(D) 81 | # if we use the post-update estimator, initialized with q=0 and lr=1/N(t), we should match peter's code 82 | params.obs_noise = 0.0 * jnp.eye(1) 83 | ekf_params = EKFParams(method="fcekf", obs_noise_estimator = "post", obs_noise_lr_fn= lambda t: 1.0/(t+1)) 84 | 85 | estimator = make_dual_ekf_estimator(params, obs_model, ekf_params) 86 | carry, lls = dual_rebayes_scan(estimator, X, Y) 87 | params, final_bel = carry 88 | obs_noise_dual = jnp.atleast_1d(params.obs_noise).ravel() 89 | # print(obs_noise_dual) 90 | assert jnp.allclose(obs_noise_dual, obs_noise_ekf) 91 | 92 | 93 | def test_adaptive(): 94 | (X, Y) = make_linreg_data() 95 | N, D = X.shape 96 | params, obs_model = make_linreg_dual_params(D) 97 | init_R = 0.1*jnp.std(Y) * jnp.eye(1) 98 | lr = 0.01 99 | 100 | params.obs_noise = init_R 101 | ekf_params = EKFParams(method="fcekf", obs_noise_estimator = "post", obs_noise_lr_fn=lambda t: lr) 102 | estimator = make_dual_ekf_estimator(params, obs_model, ekf_params) 103 | (params, final_bel), lls = dual_rebayes_scan(estimator, X, Y) 104 | obs_noise_post = params.obs_noise 105 | 106 | params.obs_noise = init_R 107 | ekf_params = EKFParams(method="fcekf", obs_noise_estimator = "pre", obs_noise_lr_fn= lambda t: lr) 108 | estimator = make_dual_ekf_estimator(params, obs_model, ekf_params) 109 | (params,final_bel), lls = dual_rebayes_scan(estimator, X, Y) 110 | obs_noise_pre = params.obs_noise 111 | 112 | # print("post ", obs_noise_post, "pre ", obs_noise_pre) 113 | assert True -------------------------------------------------------------------------------- /rebayes/extended_kalman_filter/viking.py: -------------------------------------------------------------------------------- 1 | import chex 2 | import jax 3 | import jax.numpy as jnp 4 | from jax.lax import scan 5 | 6 | 7 | @chex.dataclass 8 | class VikingBel: 9 | mean: chex.Array 10 | cov: chex.Array 11 | dynamics_noise_mean: chex.Array 12 | dynamics_noise_cov: chex.Array 13 | emission_noise_mean: chex.Array 14 | emission_noise_cov: chex.Array 15 | 16 | 17 | class RebayesViking: 18 | def __init__( 19 | self, 20 | dynamics_weights, # K 21 | initial_dynamics_noise_mean, # b_hat 22 | initial_dynamics_noise_cov, # Sigma 23 | initial_emission_noise_mean, # a_hat 24 | initial_emission_noise_cov, # s 25 | n_iter, 26 | dynamics_noise_transition_cov=0.0, # rho_b 27 | emission_noise_transition_cov=0.0, # rho_a 28 | learn_dynamics_noise_cov=True, 29 | learn_emission_noise_cov=True, 30 | ): 31 | self.K = dynamics_weights 32 | self.b_hat0 = initial_dynamics_noise_mean 33 | self.Sigma0 = initial_dynamics_noise_cov 34 | self.a_hat0 = initial_emission_noise_mean 35 | self.s0 = initial_emission_noise_cov 36 | 
self.n_iter = n_iter 37 | self.rho_b = dynamics_noise_transition_cov 38 | self.rho_a = emission_noise_transition_cov 39 | self.learn_dynamics_noise_cov = learn_dynamics_noise_cov 40 | self.learn_emission_noise_cov = learn_emission_noise_cov 41 | 42 | def init_bel( 43 | self, 44 | initial_mean, 45 | initial_covariance, 46 | ) -> VikingBel: 47 | return VikingBel( 48 | mean=initial_mean, 49 | cov=initial_covariance, 50 | dynamics_noise_mean=self.b_hat0, 51 | dynamics_noise_cov=self.Sigma0, 52 | emission_noise_mean=self.a_hat0, 53 | emission_noise_cov=self.s0, 54 | ) 55 | 56 | def update_bel( 57 | self, 58 | bel, 59 | x, 60 | y, 61 | ): 62 | mean, cov, b_hat, Sigma, a_hat, s = \ 63 | bel.mean, bel.cov, bel.dynamics_noise_mean, bel.dynamics_noise_cov, \ 64 | bel.emission_noise_mean, bel.emission_noise_cov 65 | a_hat_cond = a_hat 66 | s_cond = s + self.rho_a 67 | b_hat_cond = b_hat 68 | Sigma_cond = Sigma + self.rho_b 69 | y_pred = mean @ x 70 | 71 | f = jnp.exp(b_hat_cond) 72 | C_inv = jnp.linalg.pinv(self.K @ cov @ self.K.T + f) 73 | # Quadrature approximation to the expectation 74 | A = C_inv - f * Sigma_cond * C_inv @ C_inv / 2 + \ 75 | f**2 * Sigma_cond * C_inv @ C_inv @ C_inv 76 | A_inv = jnp.linalg.pinv((A+A.T)/2) 77 | 78 | # Update the mean and covariance 79 | v = jnp.exp(a_hat_cond - s_cond/2) 80 | cov_cond = A_inv - (A_inv @ x)[:, None] @ x[None, :] \ 81 | @ (A_inv / (x @ A_inv @ x + v)) 82 | mean_cond = self.K @ mean + A_inv @ x / (x @ A_inv @ x + v) \ 83 | * (y - (self.K @ mean) @ x) 84 | 85 | # Update a_hat and s 86 | if self.learn_emission_noise_cov: 87 | c = (y - mean_cond @ x)**2 + x @ cov_cond @ x 88 | s_cond = 1/(1/(s + self.rho_a) + 89 | 0.5*c*jnp.exp(-a_hat_cond + (s - self.rho_a)/2)) 90 | s_cond = jax.lax.max(s_cond, s - self.rho_a) 91 | M = 100*self.rho_a 92 | diff = 1/2*(1/(s + self.rho_a) + c/2 * 93 | jnp.exp(-a_hat_cond + s_cond/2 + M)) * \ 94 | (c * jnp.exp(-a_hat_cond + s_cond/2) - 1) 95 | a_hat_cond = a_hat_cond + jax.lax.max(jax.lax.min(diff, M), -M) 96 | 97 | # Update b_hat and Sigma 98 | if self.learn_dynamics_noise_cov: 99 | d = mean.shape[0] 100 | mean_term = mean_cond - self.K @ mean 101 | B = cov_cond + jnp.outer(mean_term, mean_term) 102 | C_inv = jnp.linalg.pinv(self.K @ cov @ self.K.T + f) 103 | g = jnp.sum(jnp.diag(C_inv @ (jnp.eye(d) - B @ C_inv)))*f 104 | 105 | # Approximation and no upper bound 106 | C_inv = jnp.linalg.pinv(self.K @ cov @ self.K.T + f) 107 | H = jnp.sum(jnp.diag(C_inv @ (jnp.eye(d) - B @ C_inv)))*f + \ 108 | 2*jnp.sum(jnp.diag(C_inv @ C_inv @ (B @ C_inv - jnp.eye(d)/2)))*f**2 109 | 110 | # Update b_hat and Sigma 111 | Sigma_cond = 1/(1/(Sigma + self.rho_b) + H/2) 112 | b_hat_cond = b_hat_cond - Sigma_cond*g/2 113 | 114 | bel_cond = VikingBel( 115 | mean=mean_cond, 116 | cov=cov_cond, 117 | dynamics_noise_mean=b_hat_cond, 118 | dynamics_noise_cov=Sigma_cond, 119 | emission_noise_mean=a_hat_cond, 120 | emission_noise_cov=s_cond, 121 | ) 122 | 123 | return bel_cond 124 | 125 | def scan( 126 | self, 127 | initial_mean, 128 | initial_covariance, 129 | X, 130 | Y, 131 | ): 132 | X, Y = jnp.array(X), jnp.array(Y) 133 | num_timesteps = X.shape[0] 134 | bel = self.init_bel(initial_mean, initial_covariance) 135 | def step(bel, t): 136 | x, y = X[t], Y[t] 137 | def _step(bel, _): 138 | bel = self.update_bel(bel, x, y) 139 | return bel, None 140 | bel, _ = scan(_step, bel, jnp.arange(self.n_iter)) 141 | return bel, None 142 | 143 | bel, _ = scan(step, bel, jnp.arange(num_timesteps)) 144 | 145 | return bel 146 | 
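# Usage sketch (illustrative addition, not part of the original file): run
# VIKING on a toy streaming linear-regression problem. The hyperparameter
# values below are assumptions for demonstration, not tuned settings.
if __name__ == "__main__":
    import numpy as np

    d = 3
    rng = np.random.default_rng(0)
    X = rng.standard_normal((200, d))
    w_true = np.array([1.0, -2.0, 0.5])
    Y = X @ w_true + 0.1 * rng.standard_normal(200)

    agent = RebayesViking(
        dynamics_weights=jnp.eye(d),                # K: static weight dynamics
        initial_dynamics_noise_mean=jnp.log(1e-4),  # b_hat (log process noise)
        initial_dynamics_noise_cov=1.0,             # Sigma
        initial_emission_noise_mean=jnp.log(1e-2),  # a_hat (log observation noise)
        initial_emission_noise_cov=1.0,             # s
        n_iter=2,
        dynamics_noise_transition_cov=1e-3,         # rho_b
        emission_noise_transition_cov=1e-3,         # rho_a
    )
    bel = agent.scan(jnp.zeros(d), jnp.eye(d), X, Y)
    print(bel.mean)  # should approach w_true as data accumulates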
-------------------------------------------------------------------------------- /rebayes/ivon/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/rebayes/ivon/__init__.py -------------------------------------------------------------------------------- /rebayes/linear_filter/kf.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import chex 3 | import jax.numpy as jnp 4 | from rebayes.base import Rebayes 5 | from functools import partial 6 | 7 | @chex.dataclass 8 | class KFBel: 9 | mean: chex.Array 10 | cov: chex.Array 11 | t: int=0 12 | 13 | 14 | class KalmanFilter(Rebayes): 15 | def __init__( 16 | self, 17 | transition_matrix, 18 | system_noise, 19 | observation_noise, 20 | observation_matrix=None, 21 | ): 22 | self.transition_matrix = transition_matrix 23 | self.observation_matrix = observation_matrix 24 | self.system_noise = system_noise 25 | self.observation_noise = jnp.atleast_2d(observation_noise) 26 | 27 | def get_trans_mat_of(self, t: int): 28 | if callable(self.transition_matrix): 29 | return self.transition_matrix(t) 30 | else: 31 | return self.transition_matrix 32 | 33 | def get_obs_mat_of(self, t: int): 34 | if callable(self.observation_matrix): 35 | return self.observation_matrix(t) 36 | else: 37 | return self.observation_matrix 38 | 39 | def get_system_noise_of(self, t: int): 40 | if callable(self.system_noise): 41 | return self.system_noise(t) 42 | else: 43 | return self.system_noise 44 | 45 | def get_observation_noise_of(self, t: int): 46 | if callable(self.observation_noise): 47 | return self.observation_noise(t) 48 | else: 49 | return self.observation_noise 50 | 51 | 52 | def init_bel(self, Xinit=None, Yinit=None): 53 | # TODO: implement 54 | raise NotImplementedError 55 | 56 | @partial(jax.jit, static_argnums=(0,)) 57 | def predict_state(self, bel): 58 | A = self.get_trans_mat_of(bel.t) 59 | Q = self.get_system_noise_of(bel.t) 60 | 61 | Sigma_pred = A @ bel.cov @ A.T + Q 62 | mean_pred = A @ bel.mean 63 | 64 | bel = bel.replace( 65 | mean=mean_pred, 66 | cov=Sigma_pred 67 | ) 68 | return bel 69 | 70 | @partial(jax.jit, static_argnums=(0,)) 71 | def update_state(self, bel, C, y): 72 | # C = self.get_obs_mat_of(bel.t) 73 | C = jnp.atleast_2d(C) 74 | R = self.get_observation_noise_of(bel.t) 75 | S = C @ bel.cov @ C.T + R 76 | K = jnp.linalg.solve(S, C @ bel.cov).T 77 | 78 | pred_obs = C @ bel.mean 79 | innovation = y - pred_obs 80 | mean = bel.mean + K @ innovation 81 | 82 | I = jnp.eye(len(mean)) 83 | tmp = I - K @ C 84 | cov = tmp @ bel.cov @ tmp.T + K @ R @ K.T 85 | 86 | bel = bel.replace( 87 | mean=mean, 88 | cov=cov, 89 | t=bel.t + 1 90 | ) 91 | return bel 92 | 93 | @partial(jax.jit, static_argnums=(0,)) 94 | def predict_obs(self, bel, x): 95 | bel = self.predict_state(bel) 96 | y_pred = jnp.einsum("i,...i->...", bel.mean, x) 97 | return y_pred 98 | -------------------------------------------------------------------------------- /rebayes/low_rank_filter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/rebayes/low_rank_filter/__init__.py -------------------------------------------------------------------------------- /rebayes/low_rank_filter/ggt.py: -------------------------------------------------------------------------------- 1 | from functools import 
partial 2 | from typing import Any, Callable, Union 3 | 4 | import chex 5 | from jax import jit, grad 6 | import jax.numpy as jnp 7 | from jaxtyping import Array, Float, Int 8 | 9 | from rebayes.base import Rebayes 10 | 11 | 12 | _vec_pinv = lambda v: jnp.where(v != 0, 1/jnp.array(v), 0) # Vector pseudo-inverse 13 | 14 | 15 | @chex.dataclass 16 | class GGTBel: 17 | mean: chex.Array 18 | gradients: chex.Array 19 | prev_grad: chex.Array 20 | num_obs: int = 0 21 | 22 | def update_buffer(self, g, beta): 23 | _, r = self.gradients.shape 24 | ix_buffer = self.num_obs % r 25 | gradients = beta * self.gradients 26 | gradients = gradients.at[:, ix_buffer].set(g) 27 | 28 | return self.replace( 29 | gradients = gradients, 30 | num_obs = self.num_obs + 1, 31 | ) 32 | 33 | 34 | @chex.dataclass 35 | class GGTParams: 36 | initial_mean: Float[Array, "state_dim"] 37 | apply_fn: Callable 38 | loss_fn: Callable 39 | memory_size: int 40 | learning_rate: float 41 | beta1: float = 0.9 # momentum term 42 | beta2: float = 1.0 # forgetting term 43 | eps: float = 0 # regularization term 44 | 45 | 46 | class RebayesGGT(Rebayes): 47 | def __init__( 48 | self, 49 | params: GGTParams, 50 | ): 51 | super().__init__(params) 52 | 53 | def init_bel(self) -> GGTBel: 54 | m0 = self.params.initial_mean 55 | d, r = m0.shape[0], self.params.memory_size 56 | 57 | bel = GGTBel( 58 | mean = m0, 59 | gradients = jnp.zeros((d, r)), 60 | prev_grad = jnp.zeros(d), 61 | ) 62 | 63 | return bel 64 | 65 | @partial(jit, static_argnums=(0,)) 66 | def predict_obs( 67 | self, 68 | bel: GGTBel, 69 | X: Float[Array, "input_dim"] 70 | ) -> Union[Float[Array, "output_dim"], Any]: 71 | y_pred = jnp.atleast_1d(self.params.apply_fn(bel.mean, X)) 72 | 73 | return y_pred 74 | 75 | @partial(jit, static_argnums=(0,)) 76 | def update_state( 77 | self, 78 | bel: GGTBel, 79 | X: Float[Array, "input_dim"], 80 | y: Float[Array, "output_dim"], 81 | ) -> GGTBel: 82 | g_prev = bel.prev_grad 83 | beta1, beta2, eps, eta = \ 84 | self.params.beta1, self.params.beta2, self.params.eps, self.params.learning_rate 85 | 86 | g = beta1 * g_prev + (1-beta1) * grad(self.params.loss_fn)(bel.mean, X, y) 87 | bel = bel.update_buffer(g, beta2) 88 | G = bel.gradients 89 | 90 | V, S, _ = jnp.linalg.svd(G.T @ G, full_matrices=False, hermitian=True) 91 | Sig = jnp.sqrt(S) 92 | U = G @ (V * _vec_pinv(Sig)) 93 | update = _vec_pinv(eps) * g + (U * (_vec_pinv(Sig + eps) - _vec_pinv(eps))) @ (U.T @ g) 94 | 95 | bel_post = bel.replace( 96 | mean = bel.mean - eta * update, 97 | prev_grad = g, 98 | ) 99 | 100 | return bel_post -------------------------------------------------------------------------------- /rebayes/low_rank_filter/slang.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any 3 | 4 | import chex 5 | from jax import jit, grad 6 | from jax.lax import scan 7 | import jax.numpy as jnp 8 | import jax.random as jr 9 | from jaxtyping import Float, Array 10 | from tensorflow_probability.substrates.jax.distributions import ( 11 | MultivariateNormalDiag as MVND, 12 | MultivariateNormalFullCovariance as MVN, 13 | ) 14 | 15 | from rebayes.base import ( 16 | FnStateAndInputToEmission, 17 | FnStateAndInputToEmission2, 18 | Rebayes, 19 | ) 20 | 21 | 22 | @chex.dataclass 23 | class SLANGBel: 24 | mean: Float[Array, "state_dim"] 25 | cov_diag: Float[Array, "state_dim"] 26 | cov_lr: Float[Array, "state_dim memory_size"] 27 | step: int = 0 28 | 29 | 30 | @chex.dataclass 31 | class SLANGParams: 32 | 
"""Lightweight container for SLANG parameters. 33 | """ 34 | initial_mean: Float[Array, "state_dim"] 35 | initial_cov_diag: Float[Array, "state_dim"] 36 | initial_cov_lr: Float[Array, "state_dim memory_size"] 37 | emission_mean_function: FnStateAndInputToEmission 38 | emission_cov_function: FnStateAndInputToEmission2 39 | lamb: float 40 | alpha: float 41 | beta: float 42 | batch_size: int 43 | n_train: int 44 | n_eig: int = 10 45 | likelihood_dist: Any = lambda m, sigma : MVN(loc=m, covariance_matrix=sigma) 46 | 47 | 48 | class SLANG(Rebayes): 49 | def __init__( 50 | self, 51 | params: SLANGParams, 52 | ): 53 | self.params = params 54 | self.log_lik = lambda mu, x, y: \ 55 | self.params.likelihood_dist( 56 | self.params.emission_mean_function(mu, x), 57 | self.params.emission_cov_function(mu, x) 58 | ).log_prob(y) 59 | self.grad_log_lik = jit(grad(self.log_lik, argnums=(0))) 60 | 61 | def init_bel( 62 | self 63 | ) -> SLANGBel: 64 | bel = SLANGBel( 65 | mean=self.params.initial_mean, 66 | cov_diag=self.params.initial_cov_diag, 67 | cov_lr=self.params.initial_cov_lr, 68 | ) 69 | 70 | return bel 71 | 72 | @partial(jit, static_argnums=(0,)) 73 | def predict_obs( 74 | self, 75 | bel: SLANGBel, 76 | x: Float[Array, "input_dim"] 77 | ) -> Float[Array, "output_dim"]: 78 | y_pred = self.params.emission_mean_function(bel.mean, x) 79 | 80 | return y_pred 81 | 82 | @partial(jit, static_argnums=(0,)) 83 | def update_state( 84 | self, 85 | bel: SLANGBel, 86 | x: Float[Array, "input_dim"], 87 | y: Float[Array, "output_dim"], 88 | ) -> SLANGBel: 89 | """Update the belief state given an input and output. 90 | """ 91 | m, U, d = bel.mean, bel.cov_lr, bel.cov_diag 92 | D, L = U.shape 93 | alpha, beta, lamb, n_eig = \ 94 | self.params.alpha, self.params.beta, self.params.lamb, self.params.n_eig 95 | 96 | theta = self._fast_sample(bel) 97 | g = self.grad_log_lik(theta, x, y).reshape((D, -1)) 98 | V = self._fast_eig(bel, g, beta, n_eig) 99 | diag_corr = (1-beta) * (U**2).sum(axis=1) + beta * (g**2).ravel() - (V**2).sum(axis=1) 100 | 101 | U_post = V 102 | d_post = (1-beta) * d + diag_corr + lamb * jnp.ones(D) 103 | 104 | ghat = g + lamb*m.reshape((D, -1)) 105 | m_post = m - alpha*self._fast_inverse(SLANGBel(mean=m, cov_lr=U_post, cov_diag=d_post), ghat) 106 | 107 | # Construct the new belief 108 | bel = SLANGBel( 109 | mean=m_post, 110 | cov_lr=U_post, 111 | cov_diag=d_post, 112 | ) 113 | 114 | return bel 115 | 116 | def _fast_sample( 117 | self, 118 | bel: SLANGBel, 119 | ): 120 | mu, U, d = bel.mean, bel.cov_lr, bel.cov_diag 121 | key = jr.PRNGKey(bel.step) 122 | 123 | D, L = U.shape 124 | eps = MVND(loc=jnp.zeros(D,), scale_diag=jnp.ones(D,)).sample(seed=key) 125 | dd = 1/jnp.sqrt(d).reshape((d.shape[0], 1)) 126 | V = U * dd 127 | A = jnp.linalg.cholesky(V.T @ V) 128 | B = jnp.linalg.cholesky(jnp.eye(L) + V.T @ V) 129 | C = jnp.linalg.pinv(A.T) @ (B - jnp.eye(L)) @ jnp.linalg.pinv(A) 130 | K = jnp.linalg.pinv(jnp.linalg.pinv(C) + V.T @ V) 131 | y = dd.ravel() * eps - (V * dd) @ K @ (V.T @ eps) 132 | 133 | return mu + y 134 | 135 | def _fast_eig( 136 | self, 137 | bel, 138 | g, 139 | beta, 140 | n_iter, 141 | ): 142 | key = jr.PRNGKey(bel.step+1) 143 | U = bel.cov_lr 144 | D, L = U.shape 145 | K = L + 2 146 | Q = jr.uniform(key, shape=(D, K), minval=-1.0, maxval=1.0) 147 | 148 | def _orth_step(carry, i): 149 | Q = carry 150 | AQ = (1-beta) * U @ (U.T@Q) + beta * g @ (g.T @ Q) 151 | Q_orth, _ = jnp.linalg.qr(AQ) 152 | 153 | return Q_orth, Q_orth 154 | 155 | Q_orth, _ = scan(_orth_step, Q, jnp.arange(n_iter)) 
156 | V, *_ = jnp.linalg.svd(Q_orth, full_matrices=False) 157 | V = V[:, :L] 158 | 159 | return V 160 | 161 | def _fast_inverse( 162 | self, 163 | bel, 164 | g, 165 | ): 166 | U, d = bel.cov_lr, bel.cov_diag 167 | _, L = U.shape 168 | dinv = (1/d).reshape((d.shape[0], 1)) 169 | A = jnp.linalg.pinv(jnp.eye(L) + U.T @ (U*dinv)) 170 | y = dinv.ravel() * g.ravel() - ((U*dinv) @ A) @ ((U*dinv).T @ g).ravel() 171 | 172 | return y -------------------------------------------------------------------------------- /rebayes/low_rank_filter/subspace_filter.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | from typing import Callable 5 | from jax.flatten_util import ravel_pytree 6 | 7 | def subcify(cls): 8 | class SubspaceModule(nn.Module): 9 | dim_in: int 10 | dim_subspace: int 11 | init_normal: Callable = nn.initializers.normal() 12 | init_proj: Callable = nn.initializers.normal() 13 | 14 | def init(self, rngs, *args, **kwargs): 15 | r1, r2 = jax.random.split(rngs, 2) 16 | rngs_dict = {"params": r1, "fixed": r2} 17 | 18 | return nn.Module.init(self, rngs_dict, *args, **kwargs) 19 | 20 | def setup(self): 21 | key_dummy = jax.random.PRNGKey(0) 22 | params = cls().init(key_dummy, jnp.ones((1, self.dim_in))) 23 | params_all, reconstruct_fn = ravel_pytree(params) 24 | 25 | self.dim_full = len(params_all) 26 | self.reconstruct_fn = reconstruct_fn 27 | 28 | self.subspace = self.param( 29 | "subspace", 30 | self.init_proj, 31 | (self.dim_subspace,) 32 | ) 33 | 34 | shape = (self.dim_full, self.dim_subspace) 35 | init_fn = lambda shape: self.init_proj(self.make_rng("fixed"), shape) 36 | self.projection = self.variable("fixed", "P", init_fn, shape).value 37 | 38 | shape = (self.dim_full,) 39 | init_fn = lambda shape: self.init_proj(self.make_rng("fixed"), shape) 40 | self.bias = self.variable("fixed", "b", init_fn, shape).value 41 | 42 | @nn.compact 43 | def __call__(self, x): 44 | params = self.projection @ self.subspace + self.bias 45 | params = self.reconstruct_fn(params) 46 | return cls().apply(params, x) 47 | 48 | return SubspaceModule 49 | -------------------------------------------------------------------------------- /rebayes/low_rank_filter/test_dual_lofi.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import jax.numpy as jnp 4 | from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN 5 | 6 | from rebayes.dual_base import dual_rebayes_scan, DualRebayesParams, ObsModel 7 | from rebayes.low_rank_filter.lofi import ( 8 | LoFiParams, 9 | INFLATION_METHODS, 10 | RebayesLoFiOrthogonal, 11 | RebayesLoFiSpherical, 12 | RebayesLoFiDiagonal, 13 | ) 14 | from rebayes.low_rank_filter.dual_lofi import ( 15 | DualLoFiParams, 16 | make_dual_lofi_orthogonal_estimator, 17 | make_dual_lofi_spherical_estimator, 18 | make_dual_lofi_diagonal_estimator, 19 | ) 20 | from rebayes.extended_kalman_filter.test_ekf import make_linreg_rebayes_params, make_linreg_data, make_linreg_prior 21 | from rebayes.utils.utils import get_mlp_flattened_params 22 | 23 | 24 | RebayesLoFiEstimators = [RebayesLoFiOrthogonal, RebayesLoFiSpherical, RebayesLoFiDiagonal] 25 | make_dual_lofi_estimator_methods = [make_dual_lofi_orthogonal_estimator, make_dual_lofi_spherical_estimator, make_dual_lofi_diagonal_estimator] 26 | 27 | 28 | def allclose(u, v): 29 | return jnp.allclose(u, v, atol=1e-2) 30 | 31 | def 
make_linreg_dual_params(nfeatures): 32 | (obs_var, mu0, Sigma0) = make_linreg_prior() 33 | 34 | # Define Linear Regression as MLP with no hidden layers 35 | input_dim, hidden_dims, output_dim = nfeatures, [], 1 36 | model_dims = [input_dim, *hidden_dims, output_dim] 37 | *_, apply_fn = get_mlp_flattened_params(model_dims) 38 | 39 | params = DualRebayesParams( 40 | mu0=mu0, 41 | eta0=1/Sigma0[0,0], 42 | dynamics_scale_factor = 1.0, 43 | dynamics_noise = 0.0, 44 | obs_noise = obs_var, 45 | cov_inflation_factor = 0, 46 | ) 47 | obs_model = ObsModel( 48 | emission_mean_function = lambda w, x: apply_fn(w, x), 49 | emission_cov_function = lambda w, x: obs_var 50 | ) 51 | 52 | return params, obs_model 53 | 54 | 55 | @pytest.mark.parametrize( 56 | "steady_state, inflation_type, estimator_class, make_dual_estimator_method", 57 | [(ss, it, RebayesLoFiEstimators[i], make_dual_lofi_estimator_methods[i]) 58 | for ss in [True, False] for it in INFLATION_METHODS for i in range(len(RebayesLoFiEstimators))] 59 | ) 60 | def test_lofi_adaptive_backwards_compatibility(steady_state, inflation_type, estimator_class, make_dual_estimator_method): 61 | (X, Y) = make_linreg_data() 62 | 63 | # old estimator 64 | _, D = X.shape 65 | params = make_linreg_rebayes_params(D) 66 | params.adaptive_emission_cov = True 67 | lofi_params = LoFiParams(memory_size=1, steady_state=steady_state, inflation=inflation_type) 68 | estimator = estimator_class(params, lofi_params) 69 | final_bel, _ = estimator.scan(X, Y) 70 | obs_noise_lofi = final_bel.obs_noise_var 71 | 72 | params, obs_model = make_linreg_dual_params(D) 73 | params.obs_noise = 1.0 * jnp.eye(1) 74 | dual_lofi_params = DualLoFiParams( 75 | memory_size=1, 76 | inflation=inflation_type, 77 | steady_state=steady_state, 78 | obs_noise_estimator = "post", 79 | obs_noise_lr_fn= lambda t: 1.0/(t+1) 80 | ) 81 | 82 | dual_estimator = make_dual_estimator_method(params, obs_model, dual_lofi_params) 83 | carry, _ = dual_rebayes_scan(dual_estimator, X, Y) 84 | params, final_bel = carry 85 | obs_noise_dual = params.obs_noise 86 | 87 | assert allclose(obs_noise_dual, obs_noise_lofi) -------------------------------------------------------------------------------- /rebayes/mcmc_filter/hamiltonian_monte_carlo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hamiltonian Monte Carlo for Bayesian Neural Network 3 | """ 4 | 5 | import jax 6 | import distrax 7 | import blackjax 8 | import jax.numpy as jnp 9 | import flax.linen as nn 10 | from chex import dataclass 11 | from tqdm.auto import tqdm 12 | from functools import partial 13 | from jax_tqdm import scan_tqdm 14 | from typing import Callable, Union, List 15 | from jaxtyping import Float, Array, PyTree 16 | from jax.flatten_util import ravel_pytree 17 | 18 | @dataclass 19 | class PriorParam: 20 | scale_obs: float 21 | scale_weight: float 22 | 23 | 24 | def get_leaves(params): 25 | flat_params, _ = ravel_pytree(params) 26 | return flat_params 27 | 28 | 29 | def log_joint( 30 | params: nn.FrozenDict, 31 | X: Float[Array, "num_obs dim_obs"], 32 | y: Float[Array, "num_obs"], 33 | apply_fn: Callable[[PyTree[float], Float[Array, "num_obs dim_obs"]], Float[Array, "num_obs"]], 34 | priors: PriorParam, 35 | ): 36 | """ 37 | We sample from a BNN posterior assuming 38 | p(w{i}) = N(0, scale_prior) ∀ i 39 | P(y | w, X) = N(apply_fn(w, X), scale_obs) 40 | 41 | TODO: 42 | * Add more general way to compute observation-model log-probability 43 | """ 44 | scale_obs = priors.scale_obs 45 | scale_prior = 
priors.scale_weight 46 | 47 | params_flat = get_leaves(params) 48 | 49 | # Prior log probability (use initialised vals for mean?) 50 | logp_prior = distrax.Normal(loc=0.0, scale=scale_prior).log_prob(params_flat).sum() 51 | 52 | # Observation log-probability 53 | mu_obs = apply_fn(params, X).ravel() 54 | logp_obs = distrax.Normal(loc=mu_obs, scale=scale_obs).log_prob(y).sum() 55 | 56 | logprob = logp_prior + logp_obs 57 | return logprob 58 | 59 | 60 | def inference_loop(rng_key, kernel, initial_state, num_samples, tqdm=True): 61 | def one_step(state, num_step): 62 | key = jax.random.fold_in(rng_key, num_step) 63 | state, _ = kernel(key, state) 64 | return state, state 65 | 66 | if tqdm: 67 | one_step = scan_tqdm(num_samples)(one_step) 68 | 69 | steps = jnp.arange(num_samples) 70 | _, states = jax.lax.scan(one_step, initial_state, steps) 71 | 72 | return states 73 | 74 | 75 | def inference( 76 | key: jax.random.PRNGKey, 77 | apply_fn: Callable, 78 | log_joint: Callable, 79 | params_init: nn.FrozenDict, 80 | priors: PriorParam, 81 | X: Float[Array, "num_obs ..."], 82 | y: Float[Array, "num_obs"], 83 | num_warmup: int, 84 | num_steps: int, 85 | tqdm: bool = True, 86 | ): 87 | key_warmup, key_train = jax.random.split(key) 88 | potential = partial( 89 | log_joint, 90 | priors=priors, X=X, y=y, apply_fn=apply_fn 91 | ) 92 | 93 | adapt = blackjax.window_adaptation(blackjax.nuts, potential, num_warmup) 94 | final_state, kernel, _ = adapt.run(key_warmup, params_init) 95 | states = inference_loop(key_train, kernel, final_state, num_steps, tqdm) 96 | 97 | return states 98 | 99 | 100 | class RebayesHMC: 101 | def __init__(self, apply_fn, priors, log_joint, num_samples, num_warmup): 102 | self.apply_fn = apply_fn 103 | self.priors = priors 104 | self.log_joint = log_joint 105 | self.num_samples = num_samples 106 | self.num_warmup = num_warmup 107 | 108 | 109 | @partial(jax.jit, static_argnums=(0,)) 110 | def eval(self, bel, X): 111 | """ 112 | Evaluate the model at the given parameters 113 | """ 114 | yhat_samples = jax.vmap(self.apply_fn, (0, None))(bel, X) 115 | return yhat_samples 116 | 117 | def predict_obs(self, bel, X): 118 | """ 119 | Estimate posterior predictive 120 | """ 121 | yhat_samples = self.eval(bel, X) 122 | yhat = yhat_samples.mean(axis=0) 123 | return yhat 124 | 125 | def predict_state(self, bel, X): 126 | return bel 127 | 128 | def update_state(self, bel, X, y, key, tqdm=False, return_state=False): 129 | state = inference( 130 | key, self.apply_fn, self.log_joint, bel, self.priors, 131 | X, y, self.num_warmup, self.num_samples, 132 | tqdm=tqdm 133 | ) 134 | if return_state: 135 | return state 136 | else: 137 | return state.position 138 | 139 | def scan( 140 | self, 141 | key: jax.random.PRNGKey, 142 | params_init: nn.FrozenDict, 143 | X: Float[Array, "ntime ..."], 144 | y: Float[Array, "ntime emission_dim"], 145 | eval_steps: Union[List, None] = None, 146 | callback: Callable = None, 147 | ): 148 | num_samples = len(y) 149 | if eval_steps is None: 150 | eval_steps = list(range(num_samples)) 151 | 152 | params_hist = {} 153 | for n_eval in tqdm(eval_steps): 154 | X_eval = X[:n_eval] 155 | y_eval = y[:n_eval] 156 | 157 | bel_update = self.update_state(params_init, X_eval, y_eval, key) 158 | params_hist[n_eval] = bel_update 159 | 160 | if callback is not None: 161 | callback(bel_update, n_eval) 162 | 163 | return params_hist 164 | -------------------------------------------------------------------------------- /rebayes/sgd_filter/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/rebayes/sgd_filter/__init__.py -------------------------------------------------------------------------------- /rebayes/sgd_filter/sgd.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | from typing import Callable, Tuple, Union 5 | from jaxtyping import Float, Int, Array, PyTree 6 | from flax.training.train_state import TrainState 7 | 8 | @partial(jax.jit, static_argnames=("applyfn",)) 9 | def lossfn(params, X, y, applyfn): 10 | yhat = applyfn(params, X) 11 | mll = (y - yhat.ravel()) ** 2 # squared error, i.e. Gaussian minus log-likelihood up to scale 12 | return mll.mean() 13 | 14 | 15 | @partial(jax.jit, static_argnames=("applyfn",)) 16 | def rmae(params, X, y, applyfn): 17 | yhat = applyfn(params, X) 18 | err = jnp.abs(y - yhat.ravel()) 19 | return err.mean() 20 | 21 | 22 | @partial(jax.jit, static_argnames=("loss_grad",)) 23 | def train_step( 24 | X: Float[Array, "num_obs dim_obs"], 25 | y: Float[Array, "num_obs"], 26 | ixs: Int[Array, "batch_len"], 27 | state: TrainState, 28 | loss_grad: Callable, 29 | ) -> Tuple[float, TrainState]: 30 | """ 31 | Perform a single training step. 32 | The `loss_grad` function maps (params, X_batch, y_batch, apply_fn) to the loss and its gradients w.r.t. params, e.g. jax.value_and_grad(lossfn). 33 | """ 34 | X_batch = X[ixs] 35 | y_batch = y[ixs] 36 | loss, grads = loss_grad(state.params, X_batch, y_batch, state.apply_fn) 37 | state = state.apply_gradients(grads=grads) 38 | return loss, state 39 | 40 | 41 | @partial(jax.jit, static_argnums=(1,2)) 42 | def get_batch_train_ixs(key, num_samples, batch_size): 43 | """ 44 | Obtain the training indices to be used in an epoch of 45 | mini-batch optimisation. 46 | """ 47 | steps_per_epoch = num_samples // batch_size 48 | 49 | batch_ixs = jax.random.permutation(key, num_samples) 50 | batch_ixs = batch_ixs[:steps_per_epoch * batch_size] 51 | batch_ixs = batch_ixs.reshape(steps_per_epoch, batch_size) 52 | 53 | return batch_ixs 54 | 55 | 56 | def train_epoch( 57 | key: jax.random.PRNGKey, 58 | batch_size: int, 59 | state: TrainState, 60 | X: Float[Array, "num_obs dim_obs"], 61 | y: Float[Array, "num_obs"], 62 | loss_grad: Callable, 63 | ): 64 | num_train = X.shape[0] 65 | loss_epoch = 0.0 66 | train_ixs = get_batch_train_ixs(key, num_train, batch_size) 67 | for ixs in train_ixs: 68 | loss, state = train_step(X, y, ixs, state, loss_grad) 69 | loss_epoch += loss 70 | return loss_epoch, state 71 | 72 | 73 | def train_full( 74 | key: jax.random.PRNGKey, 75 | num_epochs: int, 76 | batch_size: int, 77 | state: TrainState, 78 | X: Float[Array, "num_obs dim_obs"], 79 | y: Float[Array, "num_obs"], 80 | loss: Callable[[PyTree, Float[Array, "num_obs dim_obs"], Float[Array, "num_obs"], Callable], float], 81 | X_test: Union[None, Float[Array, "num_obs_test dim_obs"]] = None, 82 | y_test: Union[None, Float[Array, "num_obs_test"]] = None, 83 | ): 84 | loss_grad = jax.value_and_grad(loss, 0) 85 | 86 | def epoch_step(state, t): 87 | keyt = jax.random.fold_in(key, t) 88 | loss_train, state = train_epoch(keyt, batch_size, state, X, y, loss_grad) 89 | 90 | if (X_test is not None) and (y_test is not None): 91 | loss_test = loss(state.params, X_test, y_test, state.apply_fn) 92 | else: 93 | loss_test = None 94 | 95 | losses = { 96 | "train": loss_train, 97 | "test": loss_test, 98 | } 99 | return state, losses 100 | steps = jnp.arange(num_epochs) 101 | state, losses = jax.lax.scan(epoch_step, state, steps) 102 | return state, losses 103 |
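# --- Usage sketch (annotation, not part of the repository) ---
# How the helpers in sgd.py compose: build a TrainState around a model, then
# hand it to train_full together with one of the loss functions above. The
# DemoMLP model, the optimiser choice, and the synthetic data are assumptions
# made up for this sketch.
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
import flax.linen as nn
from flax.training.train_state import TrainState

class DemoMLP(nn.Module):  # hypothetical model, not part of the repo
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(16)(x))
        return nn.Dense(1)(x)

key = jr.PRNGKey(0)
X = jr.normal(key, (256, 3))
y = jnp.sin(X.sum(axis=-1))

model = DemoMLP()
state = TrainState.create(
    apply_fn=model.apply,
    params=model.init(key, X[:1]),
    tx=optax.adam(1e-3),
)
# lossfn is the mean-squared-error loss defined at the top of this file
state, losses = train_full(key, num_epochs=5, batch_size=32,
                           state=state, X=X, y=y, loss=lossfn)
print(losses["train"])  # per-epoch training loss, shape (num_epochs,)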
-------------------------------------------------------------------------------- /rebayes/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/rebayes/utils/__init__.py -------------------------------------------------------------------------------- /rebayes/utils/normalizing_flows.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from flax import linen as nn 4 | import jax 5 | from jax import jacrev, jacfwd, jit, lax, vmap 6 | from jax.flatten_util import ravel_pytree 7 | import jax.numpy as jnp 8 | import jax.random as jr 9 | from matplotlib import animation 10 | import numpy as np 11 | import tensorflow_probability.substrates.jax as tfp 12 | from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf 13 | 14 | tfd = tfp.distributions 15 | tfb = tfp.bijectors 16 | 17 | 18 | class NF_MLP(nn.Module): 19 | n_units: int=128 20 | n_layers: int=2 21 | # create a Flax Module dataclass 22 | @nn.compact 23 | def __call__(self, x): 24 | x = x.ravel() 25 | z = x 26 | for _ in range(self.n_layers): 27 | z = nn.Dense(self.n_units)(z) 28 | z = nn.relu(z) 29 | z = nn.Dense( 30 | x.shape[0]*2, 31 | kernel_init=lambda key, shape, dtype: jnp.full(shape, 1e-4, dtype=dtype), 32 | bias_init=lambda key, shape, dtype: jnp.full(shape, 1e-4, dtype=dtype) 33 | )(z) # shape inference 34 | 35 | return z 36 | 37 | def _batch_apply(unflatten_fn, apply_fn, params, x): 38 | # Convert x to a batch if it's not already one 39 | if len(x.shape) == 1: 40 | result = apply_fn(unflatten_fn(params), x) 41 | else: 42 | result = vmap(apply_fn, (None, 0))(unflatten_fn(params), x) 43 | 44 | return result 45 | 46 | 47 | def generate_shift_and_log_scale_fn(apply_fn, params): 48 | def shift_and_log_scale_fn(x, *args, **kwargs): 49 | result = apply_fn(params, x) 50 | shift, log_scale = jnp.split(result, 2, axis=-1) 51 | 52 | return shift, log_scale 53 | 54 | return shift_and_log_scale_fn 55 | 56 | 57 | def construct_bijector(apply_fn, params, power): 58 | sl_fn = generate_shift_and_log_scale_fn(apply_fn, params) 59 | bijector = tfb.RealNVP( 60 | fraction_masked=0.5*(-1)**power, 61 | shift_and_log_scale_fn=sl_fn 62 | ) 63 | 64 | return bijector 65 | 66 | 67 | def construct_flow(apply_fn, params_stack): 68 | n, *_ = params_stack.shape 69 | bijector = [] 70 | for i in range(n): 71 | bijector.append(construct_bijector(apply_fn, params_stack[i], i)) 72 | bijector = tfb.Chain(bijector) 73 | 74 | return bijector 75 | 76 | 77 | def init_normalizing_flow(model, input_dim, n_layers=4, key=0): 78 | input_dim = int(input_dim) 79 | assert input_dim % 2 == 0 # Even number of input dimensions 80 | if isinstance(key, int): 81 | key = jr.PRNGKey(key) 82 | keys = jr.split(key, n_layers) 83 | 84 | params_stack, bijectors = [], [] 85 | 86 | for i in range(n_layers): 87 | input = jnp.zeros(input_dim//2) 88 | params = model.init(keys[i], input) 89 | flat_params, unflatten_fn = ravel_pytree(params) 90 | params_stack.append(flat_params) 91 | apply_fn = lambda w, x: \ 92 | _batch_apply(unflatten_fn, model.apply, w, x) 93 | sl_fn = generate_shift_and_log_scale_fn(apply_fn, flat_params) 94 | bijector = tfb.RealNVP( 95 | fraction_masked=0.5*(-1)**i, 96 | shift_and_log_scale_fn=sl_fn 97 | ) 98 | bijectors.append(bijector) 99 | 100 | params_stack = jnp.stack(params_stack) 101 | bijector = tfb.Chain(bijectors) 102 | 103 | apply_fn = 
lambda w, x: \ 104 | _batch_apply(unflatten_fn, model.apply, w, x) 105 | 106 | result = { 107 | "params": params_stack, 108 | "input_dim": input_dim, 109 | "n_layers": n_layers, 110 | "apply_fn": apply_fn, 111 | "bijector": bijector 112 | } 113 | 114 | return result -------------------------------------------------------------------------------- /rebayes/utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | def fourier_basis(domain, b_coef, include_bias=False): 4 | """ 5 | Example: 6 | num_basis, input_dim = 7, 1 7 | b = jax.random.normal(key_basis, (num_basis, input_dim)) * 0.8 8 | Xf = fourier_basis(X, b); # X.shape == (n_obs, input_dim) 9 | """ 10 | n_obs = len(domain) 11 | # We take aj=1 12 | elements = jnp.einsum("...m,bm->...b", domain, b_coef) 13 | elements = 2 * jnp.pi * elements 14 | cos_elements = jnp.cos(elements) 15 | sin_elements = jnp.sin(elements) 16 | 17 | elements = jnp.concatenate([cos_elements, sin_elements], axis=-1) 18 | if include_bias: 19 | ones_shape = elements.shape[:-1] 20 | ones = jnp.ones(ones_shape)[..., None] 21 | elements = jnp.append(elements, ones, axis=-1) 22 | return elements 23 | -------------------------------------------------------------------------------- /rebayes/utils/sampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for sampling from a distribution. 3 | """ 4 | import jax 5 | import numpy as np 6 | import jax.numpy as jnp 7 | import jax.random as jr 8 | from functools import partial 9 | 10 | 11 | def sample_dlr_single(key, W, diag, temperature=1.0): 12 | """ 13 | Sample from an MVG with diagonal + low-rank 14 | covariance matrix. See §4.2.2, Proposition 1 of 15 | L-RVGA paper 16 | """ 17 | key_x, key_eps = jax.random.split(key) 18 | diag_inv = (1 / diag).ravel() 19 | diag_inv_mod = diag_inv * temperature 20 | D, d = W.shape 21 | 22 | ID = jnp.eye(D) 23 | Id = jnp.eye(d) 24 | 25 | M = Id + jnp.einsum("ji,j,jk->ik", W, diag_inv, W) 26 | L = jnp.sqrt(temperature) * jnp.linalg.solve(M.T, jnp.einsum("ji,j->ij", W, diag_inv)).T 27 | 28 | x = jax.random.normal(key_x, (D,)) * jnp.sqrt(diag_inv_mod) 29 | eps = jax.random.normal(key_eps, (d,)) 30 | 31 | x_plus = jnp.einsum("ij,kj,k->i", L, W, x) 32 | x_plus = x - x_plus + jnp.einsum("ij,j->i", L, eps) 33 | 34 | return x_plus 35 | 36 | 37 | @partial(jax.jit, static_argnums=(4,)) 38 | def sample_dlr(key, W, diag, temperature=1.0, shape=None): 39 | shape = (1,) if shape is None else shape 40 | n_elements = np.prod(shape) 41 | keys = jax.random.split(key, n_elements) 42 | samples = jax.vmap(sample_dlr_single, in_axes=(0, None, None, None))(keys, W, diag, temperature) 43 | samples = samples.reshape(*shape, -1) 44 | 45 | return samples 46 | -------------------------------------------------------------------------------- /rebayes/utils/split_mnist_data_test.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | #pytest split_mnist_data_test.py -rP 4 | 5 | import pytest 6 | import haiku as hk 7 | import jax 8 | import jax.numpy as jnp 9 | import jax.random as jr 10 | import numpy as np 11 | 12 | from jaxtyping import Float, Array 13 | from typing import Callable, NamedTuple, Union, Tuple, Any 14 | 15 | import jax_dataloader.core as jdl 16 | 17 | from rebayes.utils.split_mnist_data import SplitMNIST 18 | from rebayes.base import Rebayes, RebayesParams, make_rebayes_params, Belief 19 | 20 | def test_split_mnist():
ntrain_per_task = 10 22 | ntest_per_task = 4 23 | split_mnist = SplitMNIST(ntrain_per_task, ntest_per_task) 24 | X_01, y_01, L_01 = split_mnist.test_set_by_digit_pair['01'] 25 | X_23, y_23, L_23 = split_mnist.test_set_by_digit_pair['23'] 26 | print(X_01.shape, X_23.shape) 27 | assert all(np.unique(L_23)==np.array([2, 3])) 28 | assert all(np.unique(y_23)==np.array([0, 1])) 29 | 30 | X, Y, L, I = split_mnist.get_test_data_for_task(1) 31 | X_23, y_23, L_23 = split_mnist.test_set_by_digit_pair['23'] 32 | print(X.shape, X_23.shape) 33 | assert jnp.allclose(X, X_23[:ntest_per_task]) 34 | print(Y.shape, y_23.shape) 35 | assert jnp.allclose(Y, y_23[:ntest_per_task]) 36 | 37 | class RebayesSum(Rebayes): 38 | """The belief state is the sum of all the input X_t values.""" 39 | def __init__( 40 | self, 41 | params: RebayesParams, 42 | shape_in, 43 | shape_out 44 | ): 45 | self.params = params 46 | self.shape_in = shape_in 47 | self.shape_out = shape_out 48 | 49 | def init_bel(self) -> Belief: 50 | bel = Belief(dummy = jnp.zeros(self.shape_in)) 51 | return bel 52 | 53 | def update_state( 54 | self, 55 | bel: Belief, 56 | X: Float[Array, "input_dim"], 57 | Y: Float[Array, "obs_dim"] 58 | ) -> Belief: 59 | return Belief(dummy = bel.dummy + X) 60 | 61 | def test_split_mnist_rebayes(): 62 | ntrain_per_task = 10 63 | ntest_per_task = 4 64 | split_mnist = SplitMNIST(ntrain_per_task, ntest_per_task) 65 | 66 | def callback_dl(b, bel_pre, bel_post, batch): 67 | jax.debug.print("callback on batch {b}", b=b) 68 | Xtr, Ytr, Ltr, Itr = batch 69 | jax.debug.print("Xtr shape {x1}, Ytr shape {y1}", x1=Xtr.shape, y1=Ytr.shape) 70 | task = int(Itr[0]) 71 | Xte, Yte, Lte, Ite = split_mnist.get_test_data_for_seen_tasks(task) 72 | jax.debug.print("Xte shape {x1}, Yte shape {y1}", x1=Xte.shape, y1=Yte.shape) 73 | print('train labels ', Ltr) 74 | print('test labels ', Lte) 75 | #plot_batch(Xte, Yte, ttl='batch {:d}'.format(b)) 76 | return (Ltr, Itr, Lte, Ite) 77 | 78 | Xtr, Ytr, Ltr, Itr = split_mnist.get_training_data_all_tasks() 79 | print(Xtr.shape, Ytr.shape, Ltr.shape, Itr.shape) 80 | 81 | train_ds = jdl.Dataset(Xtr, Ytr, Ltr, Itr) 82 | ntrain_per_batch = 5 83 | train_loader = jdl.DataLoaderJax(train_ds, batch_size=ntrain_per_batch, shuffle=False, drop_last=False) 84 | 85 | shape_in = Xtr.shape[1:] 86 | shape_out = 1 87 | estimator = RebayesSum(make_rebayes_params(), shape_in, shape_out) 88 | 89 | bel, outputs = estimator.scan_dataloader(train_loader, callback=callback_dl) 90 | Xsum = jnp.sum(Xtr, axis=0) 91 | assert(jnp.allclose(Xsum, bel.dummy, atol=1e-2)) 92 | 93 | for b in range(len(train_loader)): 94 | print('batch ', b) 95 | out = outputs[b] 96 | print(out) 97 | (Ltr, Itr, Lte, Ite) = out 98 | task = Itr[0] 99 | Xte, Yte, Lte_expected, Ite_expected = split_mnist.get_test_data_for_seen_tasks(task) 100 | assert jnp.allclose(Lte_expected, Lte) 101 | -------------------------------------------------------------------------------- /rebayes/vi_filter/bayes_by_backprop.py: -------------------------------------------------------------------------------- 1 | """ 2 | [2] Farquhar, S., Osborne, M., & Gal, Y. (2019). 3 | Radial Bayesian Neural Networks: Beyond Discrete Support 4 | In Large-Scale Bayesian Deep Learning.
doi:10.48550/ARXIV.1907.00865 5 | 6 | """ 7 | 8 | import jax 9 | import distrax 10 | import jax.numpy as jnp 11 | import flax.linen as nn 12 | from chex import dataclass 13 | from functools import partial 14 | from typing import Callable 15 | from jax.flatten_util import ravel_pytree 16 | from jaxtyping import Float, PyTree, Array 17 | 18 | 19 | @dataclass 20 | class BBBParams: 21 | mean: PyTree[Float] 22 | rho: PyTree[Float] 23 | 24 | 25 | def init_bbb_params(key, model, batch_init): 26 | key_mean, key_rho = jax.random.split(key) 27 | 28 | params_mean = model.init(key_mean, batch_init) 29 | flat_params, reconstruct_fn = ravel_pytree(params_mean) 30 | num_params = len(flat_params) 31 | 32 | params_rho = jax.random.normal(key_rho, (num_params,)) 33 | params_rho = reconstruct_fn(params_rho) 34 | 35 | bbb_params = BBBParams( 36 | mean=params_mean, 37 | rho=params_rho, 38 | ) 39 | 40 | return bbb_params, (reconstruct_fn, num_params) 41 | 42 | 43 | def transform(eps, mean, rho): 44 | std = jnp.log(1 + jnp.exp(rho)) 45 | weight = mean + std * eps 46 | return weight 47 | 48 | 49 | def sample_gauss_params(key, state:BBBParams, reconstruct_fn:Callable): 50 | """ 51 | Sample from a Gaussian distribution 52 | """ 53 | num_params = len(get_leaves(state.mean)) 54 | eps = jax.random.normal(key, (num_params,)) 55 | eps = reconstruct_fn(eps) 56 | 57 | params = jax.tree_map(transform, eps, state.mean, state.rho) 58 | return params 59 | 60 | 61 | def sample_rbnn_params(key, state:BBBParams, reconstruct_fn:Callable, scale:float=1.0): 62 | """ 63 | Sample from a radial Bayesian neural network 64 | radial BNN of [2]. We modify the definition of the 65 | RBNN to include a scale parameter, which allows us 66 | to control the prior uncertainty over the posterior predictive. 67 | """ 68 | key_eps, key_rho = jax.random.split(key) 69 | num_params = len(get_leaves(state.mean)) 70 | 71 | # The radial dimension. 72 | r = jax.random.normal(key_rho) * scale 73 | 74 | eps = jax.random.normal(key_eps, (num_params,)) 75 | eps = eps / jnp.linalg.norm(eps) * r 76 | eps = reconstruct_fn(eps) 77 | 78 | 79 | params = jax.tree_map(transform, eps, state.mean, state.rho) 80 | return params 81 | 82 | 83 | def get_leaves(params): 84 | flat_params, _ = ravel_pytree(params) 85 | return flat_params 86 | 87 | 88 | @partial(jax.jit, static_argnames=("num_samples", "batch_size")) 89 | def get_batch_train_ixs(key, num_samples, batch_size): 90 | """ 91 | Obtain the training indices to be used in an epoch of 92 | mini-batch optimisation. 
93 | """ 94 | steps_per_epoch = num_samples // batch_size 95 | 96 | batch_ixs = jax.random.permutation(key, num_samples) 97 | batch_ixs = batch_ixs[:steps_per_epoch * batch_size] 98 | batch_ixs = batch_ixs.reshape(steps_per_epoch, batch_size) 99 | 100 | return batch_ixs 101 | 102 | 103 | def index_values_batch(X, y, ixs): 104 | """ 105 | Index values of a batch of observations 106 | """ 107 | X_batch = X[ixs] 108 | y_batch = y[ixs] 109 | return X_batch, y_batch 110 | 111 | 112 | def train_step(key, opt_state, X, y, lossfn, model, reconstruct_fn): 113 | params = opt_state.params 114 | apply_fn = opt_state.apply_fn 115 | 116 | loss, grads = jax.value_and_grad(lossfn, 1)(key, params, X, y, model, reconstruct_fn) 117 | opt_state_new = opt_state.apply_gradients(grads=grads) 118 | return opt_state_new, loss 119 | 120 | 121 | @partial(jax.jit, static_argnames=("lossfn", "model", "reconstruct_fn")) 122 | def split_and_train_step(key, opt_state, X, y, ixs, lossfn, model, reconstruct_fn): 123 | X_batch, y_batch = index_values_batch(X, y, ixs) 124 | opt_state, loss = train_step(key, opt_state, X_batch, y_batch, lossfn, model, reconstruct_fn) 125 | return opt_state, loss 126 | 127 | 128 | def train_epoch(key, state, X, y, batch_size, lossfn, model, reconstruct_fn): 129 | num_samples = len(X) 130 | key_batch, keys_train = jax.random.split(key) 131 | batch_ixs = get_batch_train_ixs(key_batch, num_samples, batch_size) 132 | 133 | num_batches = len(batch_ixs) 134 | keys_train = jax.random.split(keys_train, num_batches) 135 | 136 | total_loss = 0 137 | for key_step, batch_ix in zip(keys_train, batch_ixs): 138 | state, loss = split_and_train_step(key_step, state, X, y, batch_ix, lossfn, model, reconstruct_fn) 139 | total_loss += loss 140 | 141 | return total_loss.item(), state 142 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | 4 | setup( 5 | name="rebayes", 6 | packages=find_packages(), 7 | install_requires=[] 8 | ) 9 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/rebayes/3b880724541e913d7a7b2d06ee2e0407e66307ab/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_base.py: -------------------------------------------------------------------------------- 1 | 2 | #pytest test_base.py -rP 3 | 4 | import pytest 5 | import numpy as np 6 | import jax 7 | import jax.numpy as jnp 8 | import jax.random as jr 9 | import time 10 | from functools import partial 11 | 12 | from jax import jit 13 | from jax.lax import scan 14 | from jaxtyping import Float, Array 15 | from typing import Callable, NamedTuple, Union, Tuple, Any 16 | import chex 17 | 18 | import haiku as hk 19 | 20 | from rebayes.base import RebayesParams, Rebayes, Belief, make_rebayes_params 21 | 22 | 23 | class RebayesSum(Rebayes): 24 | """The belief state is the sum of all the input X_t values.""" 25 | def __init__( 26 | self, 27 | params: RebayesParams, 28 | ndim_in: int, 29 | ndim_out: int 30 | ): 31 | self.params = params 32 | self.ndim_in = ndim_in 33 | self.ndim_out = ndim_out 34 | 35 | def init_bel(self) -> Belief: 36 | bel = Belief(dummy = jnp.zeros((self.ndim_in,))) 37 | return bel 38 | 39 | def update_state( 40 | 
self, 41 | bel: Belief, 42 | X: Float[Array, "input_dim"], 43 | Y: Float[Array, "obs_dim"] 44 | ) -> Belief: 45 | return Belief(dummy = bel.dummy + X) 46 | 47 | 48 | def make_data(): 49 | keys = hk.PRNGSequence(42) 50 | ndim_in = 5 51 | nclasses = 10 52 | ntime = 12 53 | #X = jnp.arange(ntime).reshape((ntime, 1)) # 1d 54 | X = jr.normal(next(keys), (ntime, ndim_in)) 55 | labels = jr.randint(next(keys), (ntime,), 0, nclasses-1) 56 | Y = jax.nn.one_hot(labels, nclasses) 57 | return X, Y 58 | 59 | 60 | def callback_scan(bel, pred_obs, t, X, Y, bel_pred, **kwargs): 61 | jax.debug.print("callback with t={t}", t=t) 62 | return t 63 | 64 | def test_scan(): 65 | print('test scan') 66 | X, Y = make_data() 67 | ndim_in = X.shape[1] 68 | ndim_out = Y.shape[1] 69 | estimator = RebayesSum(make_rebayes_params(), ndim_in, ndim_out) 70 | bel, outputs = estimator.scan(X, Y, callback=callback_scan, progress_bar=False) 71 | print('final belief ', bel) 72 | print('outputs ', outputs) 73 | Xsum = jnp.sum(X, axis=0) 74 | assert jnp.allclose(bel.dummy, Xsum) 75 | 76 | def test_update_batch(): 77 | print('test update batch') 78 | X, Y = make_data() 79 | ndim_in = X.shape[1] 80 | ndim_out = Y.shape[1] 81 | estimator = RebayesSum(make_rebayes_params(), ndim_in, ndim_out) 82 | Xsum = jnp.sum(X, axis=0) 83 | 84 | bel = estimator.init_bel() 85 | bel = estimator.update_state_batch(bel, X, Y) 86 | assert(jnp.allclose(bel.dummy, Xsum)) 87 | 88 | bel = estimator.init_bel() 89 | N = X.shape[0] 90 | for n in range(N): 91 | bel = estimator.predict_state(bel) 92 | bel = estimator.update_state(bel, X[n], Y[n]) 93 | assert(jnp.allclose(bel.dummy, Xsum)) 94 | 95 | 96 | -------------------------------------------------------------------------------- /tests/test_base_dl.py: -------------------------------------------------------------------------------- 1 | 2 | #pytest test_base_dl.py -rP 3 | # Test the dataloader version 4 | 5 | import pytest 6 | import numpy as np 7 | import jax 8 | import jax.numpy as jnp 9 | import jax.random as jr 10 | import time 11 | from functools import partial 12 | 13 | from jax import jit 14 | from jax.lax import scan 15 | from jaxtyping import Float, Array 16 | from typing import Callable, NamedTuple, Union, Tuple, Any 17 | import chex 18 | 19 | import haiku as hk 20 | 21 | 22 | import jax_dataloader.core as jdl 23 | 24 | import torch 25 | from torch.utils.data import DataLoader, TensorDataset 26 | import torchvision 27 | import torchvision.datasets as datasets 28 | import torchvision.transforms as T 29 | 30 | from rebayes.base import RebayesParams, Rebayes, Belief, make_rebayes_params 31 | 32 | 33 | class RebayesSum(Rebayes): 34 | """The belief state is the sum of all the input X_t values.""" 35 | def __init__( 36 | self, 37 | params: RebayesParams, 38 | shape_in, 39 | shape_out 40 | ): 41 | self.params = params 42 | self.shape_in = shape_in 43 | self.shape_out = shape_out 44 | 45 | def init_bel(self) -> Belief: 46 | bel = Belief(dummy = jnp.zeros(self.shape_in)) 47 | return bel 48 | 49 | def update_state( 50 | self, 51 | bel: Belief, 52 | X: Float[Array, "input_dim"], 53 | Y: Float[Array, "obs_dim"] 54 | ) -> Belief: 55 | return Belief(dummy = bel.dummy + X) 56 | 57 | 58 | 59 | def callback_dl(b, bel_pre, bel, batch): 60 | jax.debug.print("callback on batch {b}", b=b) 61 | Xtr, Ytr = batch 62 | jax.debug.print("Xtr shape {x1}, Ytr shape {y1}", x1=Xtr.shape, y1=Ytr.shape) 63 | return b 64 | 65 | 66 | def make_data(): 67 | keys = hk.PRNGSequence(42) 68 | ndim_in = 5 69 | nclasses = 10 70 | ntime = 12 
71 | #X = jnp.arange(ntime).reshape((ntime, 1)) # 1d 72 | X = jr.normal(next(keys), (ntime, ndim_in)) 73 | labels = jr.randint(next(keys), (ntime,), 0, nclasses-1) 74 | Y = jax.nn.one_hot(labels, nclasses) 75 | return X, Y 76 | 77 | def test_scan_dataloader(): 78 | print('test scan dataloaders') 79 | Xtr, Ytr = make_data() 80 | train_ds = jdl.Dataset(Xtr, Ytr) 81 | train_loader = jdl.DataLoaderJax(train_ds, batch_size=5, shuffle=False, drop_last=False) 82 | shape_in = Xtr.shape[1:] 83 | shape_out = Ytr.shape[1:] 84 | estimator = RebayesSum(make_rebayes_params(), shape_in, shape_out) 85 | bel, outputs = estimator.scan_dataloader(train_loader, callback=callback_dl) 86 | Xsum = jnp.sum(Xtr, axis=0) 87 | assert(jnp.allclose(Xsum, bel.dummy, atol=1e-2)) 88 | 89 | def test_scan_dataloader_batch1(): 90 | print('test scan dataloader batch1') 91 | # when batchsize=1, scan_dataloader == scan 92 | Xtr, Ytr = make_data() 93 | train_ds = jdl.Dataset(Xtr, Ytr) 94 | train_loader = jdl.DataLoaderJax(train_ds, batch_size=1, shuffle=False, drop_last=False) 95 | shape_in = Xtr.shape[1:] 96 | shape_out = Ytr.shape[1:] 97 | estimator = RebayesSum(make_rebayes_params(), shape_in, shape_out) 98 | bel, outputs = estimator.scan_dataloader(train_loader) 99 | bel2, outputs2 = estimator.scan(Xtr, Ytr) 100 | assert(jnp.allclose(bel.dummy, bel2.dummy)) 101 | 102 | def make_mnist_data(): 103 | # convert PIL to pytorch tensor, flatten (1,28,28) to (784), standardize values 104 | # using mean and std deviation of the MNIST dataset. 105 | transform=T.Compose([T.ToTensor(), 106 | T.Normalize((0.1307,), (0.3081,)), 107 | T.Lambda(torch.flatten)] 108 | ) 109 | 110 | train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform) 111 | test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform) 112 | 113 | # convert full dataset to numpy 114 | Xtr, Ytr = train_set.data.numpy(), train_set.targets.numpy() 115 | Xte, Yte = test_set.data.numpy(), test_set.targets.numpy() 116 | 117 | # extract small subset 118 | ntrain, ntest = 100, 500 119 | train_ndx, test_ndx = jnp.arange(ntrain), jnp.arange(ntest) 120 | 121 | return Xtr[train_ndx], Ytr[train_ndx], Xte[test_ndx], Yte[test_ndx] 122 | 123 | 124 | def test_mnist(): 125 | print('test mnist') 126 | Xtr, Ytr, Xte, Yte = make_mnist_data() 127 | train_ds = jdl.Dataset(Xtr, Ytr) 128 | train_loader = jdl.DataLoaderJax(train_ds, batch_size=50, shuffle=False, drop_last=False) 129 | shape_in = Xtr.shape[1:] 130 | shape_out = Ytr.shape[1:] 131 | estimator = RebayesSum(make_rebayes_params(), shape_in, shape_out) 132 | bel, outputs = estimator.scan_dataloader(train_loader, callback=callback_dl) 133 | Xsum = jnp.sum(Xtr, axis=0) 134 | assert(jnp.allclose(Xsum, bel.dummy, atol=1e-2)) 135 | 136 | -------------------------------------------------------------------------------- /tests/test_dual_base.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import jax.random as jr 6 | 7 | from rebayes.base import Belief 8 | from rebayes.dual_base import ( 9 | dual_rebayes_scan, 10 | RebayesEstimator, 11 | DualRebayesParams, 12 | ObsModel, 13 | make_dual_rebayes_params, 14 | ) 15 | 16 | 17 | def make_my_estimator(params: DualRebayesParams, obs: ObsModel, est_params: Any): 18 | """The belief state is the sum of all the scaled input X_t values. 
19 | The parameter update sets the dynamics noise at time t to t.""" 20 | 21 | del obs # ignored 22 | ndim_in, ndim_out, scale_factor = est_params 23 | 24 | def init(): 25 | bel = Belief(dummy = jnp.zeros((ndim_in,))) 26 | return params, bel 27 | 28 | def predict_state(params, bel): 29 | return bel 30 | 31 | def update_state(params, bel, X, Y): 32 | return Belief(dummy = bel.dummy + scale_factor * X) 33 | 34 | def predict_obs(params, bel, X): 35 | return None 36 | 37 | def predict_obs_cov(params, bel, X): 38 | return None 39 | 40 | def update_params(params, t, X, Y, Yhat, bel): 41 | #jax.debug.print("t={t}", t=t) 42 | params.dynamics_noise = t*1.0 # arbitrary update 43 | return params 44 | 45 | return RebayesEstimator(init, predict_state, update_state, predict_obs, predict_obs_cov, update_params) 46 | 47 | 48 | def make_data(): 49 | keys = jr.split(jr.PRNGKey(0), 2) 50 | ndim_in = 5 51 | nclasses = 10 52 | ntime = 12 53 | X = jr.normal(keys[0], (ntime, ndim_in)) 54 | labels = jr.randint(keys[1], (ntime,), 0, nclasses-1) 55 | Y = jax.nn.one_hot(labels, nclasses) 56 | return X, Y 57 | 58 | 59 | def test_scan(): 60 | X, Y = make_data() 61 | ntime = X.shape[0] 62 | ndim_in = X.shape[1] 63 | ndim_out = Y.shape[1] 64 | 65 | scale_factor = 2 66 | est_params = (ndim_in, ndim_out, scale_factor) 67 | params, obs = make_dual_rebayes_params() 68 | params.dynamics_noise = 0 69 | estimator = make_my_estimator(params, obs, est_params) 70 | 71 | carry, outputs = dual_rebayes_scan(estimator, X, Y,) 72 | params, bel = carry 73 | # print('final belief ', bel) 74 | # print('final params ', params) 75 | # print('outputs ', outputs) 76 | Xsum = jnp.sum(X, axis=0) 77 | assert jnp.allclose(bel.dummy, Xsum*scale_factor) 78 | assert jnp.allclose(params.dynamics_noise, ntime-1) -------------------------------------------------------------------------------- /tests/test_lofi.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import jax.numpy as jnp 4 | import jax.random as jr 5 | 6 | from rebayes.low_rank_filter.lofi import ( 7 | INFLATION_METHODS, 8 | RebayesLoFiOrthogonal, 9 | RebayesLoFiSpherical, 10 | RebayesLoFiDiagonal, 11 | ) 12 | from rebayes.low_rank_filter.lofi_core import _fast_svd 13 | from rebayes.datasets import rotating_permuted_mnist_data 14 | from rebayes.utils.utils import get_mlp_flattened_params 15 | 16 | 17 | RebayesLoFiEstimators = [RebayesLoFiOrthogonal, RebayesLoFiSpherical, RebayesLoFiDiagonal] 18 | 19 | 20 | def setup_lofi(memory_size, steady_state, inflation_type, estimator_class): 21 | input_dim, hidden_dims, output_dim = 784, [2, 2], 1 22 | model_dims = [input_dim, *hidden_dims, output_dim] 23 | _, flat_params, _, apply_fn = get_mlp_flattened_params(model_dims) 24 | initial_mean, initial_covariance = flat_params, 1e-1 25 | estimator = estimator_class( 26 | dynamics_weights=1.0, 27 | dynamics_covariance=1e-1, 28 | emission_mean_function=apply_fn, 29 | emission_cov_function=None, 30 | adaptive_emission_cov=True, 31 | dynamics_covariance_inflation_factor=1e-5, 32 | memory_size=memory_size, 33 | steady_state=steady_state, 34 | inflation=inflation_type, 35 | ) 36 | 37 | return initial_mean, initial_covariance, estimator 38 | 39 | 40 | def test_fast_svd(): 41 | for i in [10, 100, 1_000, 1_000_000]: 42 | print(i) 43 | A = jr.normal(jr.PRNGKey(i), (i, 10)) 44 | u_n, s_n, _ = jnp.linalg.svd(A, full_matrices=False) 45 | u_s, s_s = _fast_svd(A) 46 | 47 | assert jnp.allclose(jnp.abs(u_n), jnp.abs(u_s), atol=1e-2) and jnp.allclose(s_n,
s_s, atol=1e-2) 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "memory_size, steady_state, inflation_type, estimator_class", 52 | [(10, ss, it, ec) for ss in [True, False] for it in INFLATION_METHODS for ec in RebayesLoFiEstimators] 53 | ) 54 | def test_lofi(memory_size, steady_state, inflation_type, estimator_class): 55 | # Load rotated MNIST dataset 56 | n_train = 200 57 | X_train, y_train = rotating_permuted_mnist_data.generate_rotating_mnist_dataset() 58 | X_train, y_train = X_train[:n_train], y_train[:n_train] 59 | 60 | # Define mean callback function 61 | def callback(bel, *args, **kwargs): 62 | return bel.mean 63 | 64 | # Test if run without error 65 | initial_mean, initial_cov, lofi_estimator = \ 66 | setup_lofi(memory_size, steady_state, inflation_type, estimator_class) 67 | 68 | _ = lofi_estimator.scan(initial_mean, initial_cov, X_train, y_train, callback) 69 | 70 | assert True -------------------------------------------------------------------------------- /tests/test_orfit.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | 6 | from rebayes.extended_kalman_filter.ekf import RebayesEKF 7 | from rebayes.low_rank_filter.orfit import RebayesORFit 8 | from rebayes.datasets import rotating_permuted_mnist_data 9 | from rebayes.utils.utils import get_mlp_flattened_params 10 | 11 | 12 | def allclose(u, v): 13 | return jnp.allclose(u, v, atol=1e-3) 14 | 15 | 16 | def uniform_angles(n_configs, minangle, maxangle): 17 | angles = np.random.uniform(minangle, maxangle, size=n_configs) 18 | return angles 19 | 20 | 21 | def load_rmnist_data(num_train=100): 22 | # Load rotated MNIST dataset 23 | np.random.seed(314) 24 | 25 | X_train, y_train = rotating_permuted_mnist_data.generate_rotating_mnist_dataset(target_digit=2) 26 | 27 | X_train = jnp.array(X_train)[:num_train] 28 | y_train = jnp.array(y_train)[:num_train] 29 | 30 | X_train = (X_train - X_train.mean()) / X_train.std() 31 | 32 | return X_train, y_train 33 | 34 | 35 | def setup_orfit(memory_size): 36 | # Define Linear Regression as single layer perceptron 37 | input_dim, hidden_dims, output_dim = 784, [], 1 38 | model_dims = [input_dim, *hidden_dims, output_dim] 39 | _, flat_params, _, apply_fn = get_mlp_flattened_params(model_dims) 40 | initial_mean, initial_covariance = flat_params, None 41 | 42 | # Define ORFit parameters 43 | estimator = RebayesORFit( 44 | dynamics_weights = None, 45 | dynamics_covariance = None, 46 | emission_mean_function = apply_fn, 47 | emission_cov_function = None, 48 | memory_size = memory_size, 49 | ) 50 | 51 | return initial_mean, initial_covariance, estimator 52 | 53 | 54 | def setup_ekf(): 55 | # Define Linear Regression as single layer perceptron 56 | input_dim, hidden_dims, output_dim = 784, [], 1 57 | model_dims = [input_dim, *hidden_dims, output_dim] 58 | _, flat_params, _, apply_fn = get_mlp_flattened_params(model_dims) 59 | initial_mean, initial_covariance = flat_params, jnp.eye(len(flat_params)), 60 | 61 | # Define Kalman Filter parameters 62 | estimator = RebayesEKF( 63 | dynamics_weights_or_function = 1.0, 64 | dynamics_covariance = 0.0, 65 | emission_mean_function = apply_fn, 66 | emission_cov_function = lambda w, x: jnp.array([0.]), 67 | method = "fcekf" 68 | ) 69 | 70 | return initial_mean, initial_covariance, estimator 71 | 72 | 73 | def test_rebayes_orfit_loop(): 74 | # Load rotated MNIST dataset 75 | n_train = 200 76 | X_train, y_train = load_rmnist_data(n_train) 77 | 78 | # Run 
Infinite-memory ORFit 79 | orfit_init_mean, orfit_init_cov, orfit_estimator = setup_orfit(n_train) 80 | orfit_before_time = time.time() 81 | orfit_bel = orfit_estimator.init_bel(orfit_init_mean, orfit_init_cov) 82 | for i in range(n_train): 83 | orfit_bel = orfit_estimator.update_state(orfit_bel, X_train[i], y_train[i]) 84 | orfit_after_time = time.time() 85 | print(f"Looped ORFit took {orfit_after_time - orfit_before_time} seconds.") 86 | 87 | # Run Kalman Filter 88 | ekf_init_mean, ekf_init_cov, ekf_estimator = setup_ekf() 89 | ekf_before_time = time.time() 90 | ekf_bel = ekf_estimator.init_bel(ekf_init_mean, ekf_init_cov) 91 | for i in range(n_train): 92 | ekf_bel = ekf_estimator.update_state(ekf_bel, X_train[i], y_train[i]) 93 | ekf_after_time = time.time() 94 | print(f"Kalman Filter took {ekf_after_time - ekf_before_time} seconds.") 95 | 96 | assert allclose(orfit_bel.mean, ekf_bel.mean) 97 | 98 | 99 | def test_rebayes_orfit_scan(): 100 | # Load rotated MNIST dataset 101 | n_train = 200 102 | X_train, y_train = load_rmnist_data(n_train) 103 | 104 | # Define mean callback function 105 | def callback(bel, *args, **kwargs): 106 | return bel.mean 107 | 108 | # Run Infinite-memory ORFit 109 | orfit_init_mean, orfit_init_cov, orfit_estimator = setup_orfit(n_train) 110 | orfit_before_time = time.time() 111 | _, orfit_outputs = orfit_estimator.scan(orfit_init_mean, orfit_init_cov, X_train, y_train, callback) 112 | orfit_after_time = time.time() 113 | print(f"Scanned ORFit took {orfit_after_time - orfit_before_time} seconds.") 114 | 115 | # Run Kalman Filter 116 | ekf_init_mean, ekf_init_cov, ekf_estimator = setup_ekf() 117 | ekf_before_time = time.time() 118 | _, ekf_outputs = ekf_estimator.scan(ekf_init_mean, ekf_init_cov, X_train, y_train, callback) 119 | ekf_after_time = time.time() 120 | print(f"Kalman Filter took {ekf_after_time - ekf_before_time} seconds.") 121 | 122 | assert allclose(orfit_outputs, ekf_outputs) 123 | --------------------------------------------------------------------------------
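# --- Annotation (not part of the repository) ---
# The tests above exploit the callback protocol of the estimators' scan: the
# callback is invoked once per time step and its outputs are stacked across
# time, as with jax.lax.scan. A sketch reusing the setup helpers from
# tests/test_orfit.py; the argument names follow callback_scan in
# tests/test_base.py, and returning a tuple is an assumption about how scan
# stacks pytree outputs.
def collect_callback(bel, pred_obs, t, X, Y, bel_pred, **kwargs):
    # record both the updated posterior mean and the one-step-ahead prediction
    return bel.mean, pred_obs

ekf_init_mean, ekf_init_cov, ekf_estimator = setup_ekf()
X_train, y_train = load_rmnist_data(200)
_, (means, preds) = ekf_estimator.scan(
    ekf_init_mean, ekf_init_cov, X_train, y_train, collect_callback
)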