├── .github ├── ISSUE_TEMPLATE │ ├── 01_BUG_REPORT.md │ ├── 02_FEATURE_REQUEST.md │ ├── 03_DOCUMENTATION.md │ ├── 04_CODEBASE_IMPROVEMENT.md │ └── config.yml ├── pull_request_template.md └── workflows │ ├── documentation.yml │ ├── lint.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.rst ├── SECURITY.md ├── benchmarks ├── calibration │ └── temp_scaling │ │ └── breast_cancer_temp_scaling.py ├── confbayes │ ├── README.md │ ├── example_2moons_CB.py │ └── example_MNIST_CB.py ├── focal_vs_cross_entropy.ipynb ├── hallucination │ └── mmlu │ │ └── run.py ├── multivalid │ ├── breast_cancer_multicalibrate.py │ ├── mnist_top_label_multicalibrate.py │ └── two_moons_multicalibrate.py ├── tabular │ ├── analysis.py │ ├── dataset.py │ └── run.py ├── tabular_regressions.ipynb └── transformers │ ├── masked_language_modeling.py │ ├── prob_model_text_classification.py │ └── sagemaker_entrypoints │ └── prob_model_text_classification_config │ ├── default.yaml │ ├── hyperparams │ ├── cyclical_sgld_ll.yaml │ └── sghmc_ll.yaml │ ├── method │ ├── cyclical_sgld_ll.yaml │ └── sgmcmc_ll.yaml │ ├── model │ ├── bert.yaml │ └── roberta.yaml │ └── task │ ├── mnli.yaml │ └── sentiment.yaml ├── docs ├── Makefile ├── README.md ├── make.bat └── source │ ├── _static │ ├── fortuna_symbol.png │ ├── fortuna_symbol2.png │ ├── fortuna_symbol_white.png │ └── pipeline.png │ ├── conf.py │ ├── index.rst │ ├── installation.rst │ ├── landingpage.rst │ ├── license.rst │ ├── methods.rst │ ├── quickstart.rst │ ├── references │ ├── conformal.rst │ ├── data_loader.rst │ ├── likelihood.rst │ ├── metric.rst │ ├── model │ │ ├── builtin_models.rst │ │ ├── cnn.rst │ │ ├── constant.rst │ │ ├── hyper.rst │ │ ├── linen_module.rst │ │ ├── mlp.rst │ │ ├── model.rst │ │ ├── model_manager.rst │ │ ├── resnet.rst │ │ ├── scalar_constant.rst │ │ ├── scalar_hyper.rst │ │ ├── utils.rst │ │ ├── utils │ │ │ ├── random_features.rst │ │ │ └── spectral_norm.rst │ │ └── wideresnet.rst │ ├── ood_classifier.rst │ ├── output_calib_model │ │ ├── config.rst │ │ ├── output_calib_model.rst │ │ └── predictive.rst │ ├── output_calibrator.rst │ ├── plot.rst │ ├── prob_model │ │ ├── callbacks.rst │ │ ├── fit_config.rst │ │ ├── joint.rst │ │ ├── model_editor.rst │ │ ├── posterior │ │ │ ├── advi.rst │ │ │ ├── deep_ensemble.rst │ │ │ ├── laplace.rst │ │ │ ├── map.rst │ │ │ ├── posterior.rst │ │ │ ├── sgmcmc.rst │ │ │ ├── sngp.rst │ │ │ └── swag.rst │ │ ├── predictive.rst │ │ ├── prior.rst │ │ ├── prob_calib_config.rst │ │ └── prob_model.rst │ ├── prob_output_layer.rst │ ├── references.rst │ ├── sagemaker.rst │ ├── typing.rst │ └── utils.rst │ └── usage_modes │ ├── flax_models.rst │ ├── model_outputs.rst │ ├── uncertainty_estimates.rst │ └── usage_modes.rst ├── examples ├── adaptive_conformal_inference.pct.py ├── bring_in_your_own.pct.py ├── enbpi_ts_regression.pct.py ├── index.rst ├── jackknifeplus_regression.pct.py ├── mnist_classification.pct.py ├── mnist_classification_sghmc.pct.py ├── multivalid_coverage.pct.py ├── scaling_up_bayesian_inference.pct.py ├── sentiment_analysis.pct.py ├── sgmcmc_diagnostics.pct.py ├── sinusoidal_regression.pct.py ├── subnet_calibration.pct.py ├── two_moons_classification.pct.py └── two_moons_classification_ood.pct.py ├── fortuna ├── __init__.py ├── calib_model │ ├── __init__.py │ ├── base.py │ ├── calib_mixin.py │ ├── calib_model_calibrator.py │ ├── calib_state_repository.py │ ├── classification.py │ ├── config │ │ ├── __init__.py │ │ 
├── base.py │ │ ├── callback.py │ │ ├── checkpointer.py │ │ ├── hyperparameters.py │ │ ├── monitor.py │ │ ├── optimizer.py │ │ └── processor.py │ ├── loss.py │ ├── predictive │ │ ├── __init__.py │ │ ├── base.py │ │ ├── classification.py │ │ └── regression.py │ ├── regression.py │ └── state.py ├── calibration │ ├── __init__.py │ ├── binary_classification │ │ ├── __init__.py │ │ └── temp_scaling │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── bias_binary_temp_scaling.py │ │ │ ├── brier_binary_temp_scaling.py │ │ │ ├── crossentropy_binary_temp_scaling.py │ │ │ └── f1_temp_scaling.py │ └── classification │ │ ├── __init__.py │ │ └── temp_scaling │ │ ├── __init__.py │ │ └── base.py ├── conformal │ ├── __init__.py │ ├── classification │ │ ├── __init__.py │ │ ├── adaptive_conformal_classifier.py │ │ ├── adaptive_prediction.py │ │ ├── base.py │ │ ├── maxcovfixprec_binary_classfication.py │ │ └── simple_prediction.py │ ├── multivalid │ │ ├── __init__.py │ │ ├── base.py │ │ ├── iterative │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── batch_mvp.py │ │ │ ├── classification │ │ │ │ ├── __init__.py │ │ │ │ ├── binary_multicalibrator.py │ │ │ │ └── top_label_multicalibrator.py │ │ │ ├── multicalibrator.py │ │ │ └── regression │ │ │ │ ├── __init__.py │ │ │ │ └── batch_mvp.py │ │ ├── mixins │ │ │ ├── __init__.py │ │ │ ├── batchmvp.py │ │ │ ├── classification │ │ │ │ ├── __init__.py │ │ │ │ ├── binary_multicalibrator.py │ │ │ │ └── top_label_multicalibrator.py │ │ │ ├── multicalibrator.py │ │ │ └── regression │ │ │ │ └── __init__.py │ │ └── one_shot │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── classification │ │ │ ├── __init__.py │ │ │ ├── binary_multicalibrator.py │ │ │ └── top_label_multicalibrator.py │ │ │ └── multicalibrator.py │ └── regression │ │ ├── __init__.py │ │ ├── adaptive_conformal_regressor.py │ │ ├── base.py │ │ ├── batch_mvp.py │ │ ├── cvplus.py │ │ ├── enbpi.py │ │ ├── jackknife_minmax.py │ │ ├── jackknifeplus.py │ │ ├── onedim_uncertainty.py │ │ └── quantile.py ├── data │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ ├── data_collator.py │ │ └── huggingface_datasets.py │ └── loader │ │ ├── __init__.py │ │ ├── array_loaders.py │ │ ├── base.py │ │ ├── huggingface_loaders.py │ │ └── utils.py ├── distribution │ ├── __init__.py │ ├── base.py │ └── gaussian.py ├── docker │ ├── Dockerfile │ └── build_and_push.sh ├── hallucination │ ├── __init__.py │ ├── base.py │ ├── grouping │ │ ├── __init__.py │ │ └── clustering │ │ │ ├── __init__.py │ │ │ └── base.py │ ├── scoring │ │ └── inv_perplexity.py │ └── utils │ │ ├── __init__.py │ │ └── string.py ├── kernel_regression │ ├── __init__.py │ ├── kernels │ │ └── gaussian.py │ └── nadaraya_watson.py ├── likelihood │ ├── __init__.py │ ├── base.py │ ├── classification.py │ └── regression.py ├── loss │ ├── __init__.py │ ├── classification │ │ ├── __init__.py │ │ ├── cross_entropy.py │ │ └── focal_loss.py │ └── regression │ │ ├── __init__.py │ │ └── scaled_mse.py ├── metric │ ├── __init__.py │ ├── classification.py │ └── regression.py ├── model │ ├── __init__.py │ ├── cnn.py │ ├── constant.py │ ├── hyper.py │ ├── lenet.py │ ├── linear.py │ ├── mlp.py │ ├── model_manager │ │ ├── __init__.py │ │ ├── base.py │ │ ├── classification.py │ │ ├── name_to_model_manager.py │ │ ├── regression.py │ │ ├── state.py │ │ └── transformers │ │ │ ├── __init__.py │ │ │ └── classification.py │ ├── resnet.py │ ├── scalar_constant.py │ ├── scalar_hyper.py │ ├── utils │ │ ├── __init__.py │ │ ├── random_features.py │ │ └── spectral_norm.py │ └── wideresnet.py ├── 
model_editor │ ├── __init__.py │ ├── base.py │ └── probit.py ├── ood_detection │ ├── __init__.py │ ├── base.py │ ├── ddu.py │ └── mahalanobis.py ├── output_calib_model │ ├── __init__.py │ ├── base.py │ ├── classification.py │ ├── config │ │ ├── __init__.py │ │ ├── base.py │ │ ├── checkpointer.py │ │ ├── monitor.py │ │ ├── optimizer.py │ │ └── processor.py │ ├── loss.py │ ├── output_calib_mixin.py │ ├── output_calib_model_calibrator.py │ ├── output_calib_state_repository.py │ ├── predictive │ │ ├── __init__.py │ │ ├── base.py │ │ ├── classification.py │ │ └── regression.py │ ├── regression.py │ └── state.py ├── output_calibrator │ ├── __init__.py │ ├── classification.py │ ├── output_calib_manager │ │ ├── __init__.py │ │ ├── base.py │ │ └── state.py │ └── regression.py ├── plot.py ├── prob_model │ ├── __init__.py │ ├── base.py │ ├── calib_config │ │ ├── __init__.py │ │ ├── base.py │ │ ├── checkpointer.py │ │ ├── monitor.py │ │ ├── optimizer.py │ │ └── processor.py │ ├── classification.py │ ├── fit_config │ │ ├── __init__.py │ │ ├── base.py │ │ ├── callback.py │ │ ├── checkpointer.py │ │ ├── hyperparameters.py │ │ ├── monitor.py │ │ ├── optimizer.py │ │ └── processor.py │ ├── joint │ │ ├── __init__.py │ │ ├── base.py │ │ └── state.py │ ├── posterior │ │ ├── __init__.py │ │ ├── base.py │ │ ├── deep_ensemble │ │ │ ├── __init__.py │ │ │ ├── deep_ensemble_approximator.py │ │ │ ├── deep_ensemble_posterior.py │ │ │ └── deep_ensemble_state.py │ │ ├── laplace │ │ │ ├── __init__.py │ │ │ ├── laplace_approximator.py │ │ │ ├── laplace_posterior.py │ │ │ └── laplace_state.py │ │ ├── map │ │ │ ├── __init__.py │ │ │ ├── map_approximator.py │ │ │ ├── map_posterior.py │ │ │ ├── map_state.py │ │ │ └── map_trainer.py │ │ ├── name_to_posterior_state.py │ │ ├── normalizing_flow │ │ │ ├── __init__.py │ │ │ ├── advi │ │ │ │ ├── __init__.py │ │ │ │ ├── advi_approximator.py │ │ │ │ ├── advi_architecture.py │ │ │ │ ├── advi_posterior.py │ │ │ │ ├── advi_state.py │ │ │ │ └── advi_trainer.py │ │ │ ├── normalizing_flow_state.py │ │ │ └── normalizing_flow_trainer.py │ │ ├── posterior_approximations.py │ │ ├── posterior_mixin.py │ │ ├── posterior_multi_state_repository.py │ │ ├── posterior_state_repository.py │ │ ├── posterior_trainer.py │ │ ├── run_preliminary_map.py │ │ ├── sgmcmc │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── cyclical_sgld │ │ │ │ ├── __init__.py │ │ │ │ ├── cyclical_sgld_approximator.py │ │ │ │ ├── cyclical_sgld_callback.py │ │ │ │ ├── cyclical_sgld_integrator.py │ │ │ │ ├── cyclical_sgld_posterior.py │ │ │ │ └── cyclical_sgld_state.py │ │ │ ├── sghmc │ │ │ │ ├── __init__.py │ │ │ │ ├── sghmc_approximator.py │ │ │ │ ├── sghmc_callback.py │ │ │ │ ├── sghmc_integrator.py │ │ │ │ ├── sghmc_posterior.py │ │ │ │ └── sghmc_state.py │ │ │ ├── sgmcmc_diagnostic.py │ │ │ ├── sgmcmc_posterior.py │ │ │ ├── sgmcmc_posterior_state_repository.py │ │ │ ├── sgmcmc_preconditioner.py │ │ │ ├── sgmcmc_sampling_callback.py │ │ │ └── sgmcmc_step_schedule.py │ │ ├── sngp │ │ │ ├── __init__.py │ │ │ ├── sngp_approximator.py │ │ │ ├── sngp_callback.py │ │ │ ├── sngp_posterior.py │ │ │ └── transformers │ │ │ │ ├── __init__.py │ │ │ │ ├── auto_factory.py │ │ │ │ ├── modeling_flax_auto.py │ │ │ │ └── models │ │ │ │ ├── __init__.py │ │ │ │ ├── modeling_flax_bert.py │ │ │ │ ├── modeling_flax_distilbert.py │ │ │ │ └── modeling_flax_roberta.py │ │ ├── state.py │ │ └── swag │ │ │ ├── __init__.py │ │ │ ├── swag_approximator.py │ │ │ ├── swag_posterior.py │ │ │ ├── swag_state.py │ │ │ └── swag_trainer.py │ ├── predictive │ │ ├── 
__init__.py │ │ ├── base.py │ │ ├── classification.py │ │ └── regression.py │ ├── prior │ │ ├── __init__.py │ │ ├── base.py │ │ └── gaussian.py │ ├── prob_model_calibrator.py │ ├── regression.py │ └── state.py ├── prob_output_layer │ ├── __init__.py │ ├── base.py │ ├── classification.py │ └── regression.py ├── sagemaker │ ├── __init__.py │ ├── base.py │ └── utils.py ├── training │ ├── __init__.py │ ├── callback.py │ ├── mixin.py │ ├── name_to_train_state.py │ ├── output_calibrator.py │ ├── train_state.py │ ├── train_state_repository.py │ └── trainer.py ├── typing.py └── utils │ ├── __init__.py │ ├── builtins.py │ ├── data.py │ ├── device.py │ ├── freeze.py │ ├── grad.py │ ├── nested_dicts.py │ ├── optimizer.py │ ├── probit.py │ ├── random.py │ ├── strings.py │ └── training.py ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── fortuna ├── __init__.py ├── calib_model │ ├── test_calib_model.py │ └── test_output_calib_model.py ├── hallucination │ ├── grouping.py │ └── scoring.py ├── prob_model │ ├── __init__.py │ ├── test_diagnostic.py │ ├── test_joint.py │ ├── test_likelihood.py │ ├── test_preconditioner.py │ ├── test_prob_model.py │ ├── test_step_schedule.py │ └── test_train.py ├── test_builtins.py ├── test_conformal_methods.py ├── test_data.py ├── test_kernel_regression.py ├── test_metric.py ├── test_mixin.py ├── test_model.py ├── test_output_maker.py ├── test_plot.py ├── test_predictive.py ├── test_prior.py ├── test_prob_output_layer.py ├── test_state.py ├── test_temp_scaling.py ├── test_trainer.py └── utils │ └── freeze.py ├── make_data.py └── make_model.py /.github/ISSUE_TEMPLATE/01_BUG_REPORT.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a report to help Fortuna to improve 4 | title: "bug: " 5 | labels: "bug" 6 | assignees: "" 7 | --- 8 | 9 | # Bug Report 10 | 11 | **Fortuna version:** 12 | 13 | 14 | 15 | **Current behavior:** 16 | 17 | 18 | 19 | **Expected behavior:** 20 | 21 | 22 | 23 | **Steps to reproduce:** 24 | 25 | 26 | 27 | **Related code:** 28 | 29 | 30 | 31 | ``` 32 | insert short code snippets here 33 | ``` 34 | 35 | **Other information:** 36 | 37 | 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea for this project 4 | title: "feat: " 5 | labels: "enhancement" 6 | assignees: "" 7 | --- 8 | 9 | # Feature Request 10 | 11 | **Describe the Feature Request** 12 | 13 | 14 | 15 | **Describe Preferred Solution** 16 | 17 | 18 | 19 | 20 | **Related Code** 21 | 22 | 23 | 24 | **Additional Context** 25 | 26 | 27 | 28 | **If the feature request is approved, would you be willing to submit a PR?** 29 | Yes / No _(Help can be provided if you need assistance submitting a PR)_ 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/03_DOCUMENTATION.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation 3 | about: Suggest an improvement for our documentation 4 | title: "docs: " 5 | labels: "documentation" 6 | assignees: "" 7 | --- 8 | 9 | # Documentation Request 10 | 11 | **Link to relevant documentation:** 12 | 13 | **Suggested improvement** 14 | 15 | 20 | 21 | **Additional Context** 22 | 23 | 24 | 25 | **If the request is approved, would you be willing to 
submit a PR?** 26 | Yes / No _(Help can be provided if you need assistance submitting a PR)_ 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/04_CODEBASE_IMPROVEMENT.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Codebase improvement 3 | about: Provide your feedback for the existing codebase. Suggest a better solution for algorithms, development tools, etc. 4 | title: "dev: " 5 | labels: "enhancement" 6 | assignees: "" 7 | --- 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | --- 2 | blank_issues_enabled: false 3 | contact_links: 4 | - name: Fortuna Community Support 5 | url: https://github.com/awslabs/fortuna/discussions 6 | about: Please ask and answer questions here. 7 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Pull request type 4 | 5 | 6 | 7 | Please check the type of change your PR introduces: 8 | 9 | - [ ] Bugfix 10 | - [ ] Feature 11 | - [ ] Code style update (formatting, renaming) 12 | - [ ] Refactoring (no functional changes, no API changes) 13 | - [ ] Build related changes 14 | - [ ] Documentation content changes 15 | - [ ] Other (please describe): 16 | 17 | ## What is the current behavior? 18 | 19 | 20 | 21 | Issue Number: N/A 22 | 23 | ## What is the new behavior? 24 | 25 | 26 | 27 | - 28 | - 29 | - 30 | 31 | ## Other information 32 | 33 | 34 | -------------------------------------------------------------------------------- /.github/workflows/documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build the documentation 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | build: 11 | # Functionality for testing documentation builds on multiple OSes and Python versions 12 | name: Build docs (${{ matrix.python-version }}, ${{ matrix.os }}) 13 | runs-on: 14 | group: fortuna 15 | labels: fortuna_ubuntu-latest_32-core 16 | defaults: 17 | run: 18 | shell: bash -l {0} 19 | strategy: 20 | matrix: 21 | os: ["ubuntu-latest"] 22 | python-version: ["3.11"] 23 | 24 | steps: 25 | # Grab the latest commit from the branch 26 | - name: Checkout the branch 27 | uses: actions/checkout@v2.3.1 28 | with: 29 | persist-credentials: false 30 | 31 | # Add this step to set up Python version from matrix 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v2 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | 37 | # Install Poetry and build the documentation 38 | - name: Install and configure Poetry 39 | uses: snok/install-poetry@v1 40 | with: 41 | version: 1.8.3 42 | virtualenvs-create: true 43 | virtualenvs-in-project: false 44 | installer-parallel: true 45 | env: 46 | POETRY_VIRTUALENVS_PREFER_ACTIVE_PYTHON: "true" 47 | 48 | - name: Build the documentation with Sphinx 49 | run: | 50 | poetry env use ${{ matrix.python-version }} 51 | poetry install --all-extras 52 | sudo apt install pandoc 53 | pip install pandoc 54 | cd docs 55 | poetry run sphinx-build -b html source build/html 56 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: 
-------------------------------------------------------------------------------- 1 | name: Python linting 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | push: 7 | branches: [main] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - name: Set up Python 3.9 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.9 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install black 25 | 26 | - name: Run Black 27 | run: black --check --diff --verbose fortuna 28 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | jobs: 8 | test: 9 | name: Run Tests 10 | runs-on: 11 | group: fortuna 12 | labels: fortuna_ubuntu-latest_32-core 13 | strategy: 14 | matrix: 15 | # Select the Python versions to test against 16 | python-version: ["3.9", "3.10", "3.11"] 17 | steps: 18 | - name: Check out the code 19 | uses: actions/checkout@v3 20 | with: 21 | fetch-depth: 1 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | # Install Poetry 28 | - name: Install Poetry 29 | uses: snok/install-poetry@v1.3.3 30 | with: 31 | version: 1.8.3 32 | 33 | # Configure Poetry to use the virtual environment in the project 34 | - name: Setup Poetry 35 | run: | 36 | poetry config virtualenvs.in-project true 37 | 38 | # Install the dependencies 39 | - name: Install Package 40 | run: | 41 | poetry install --with test 42 | 43 | # Run the unit tests and build the coverage report 44 | - name: Run Tests 45 | run: poetry run pytest --cov=fortuna --cov-report=term-missing --cov-report=xml 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .AppleDouble 3 | .LSOverride 4 | *.ipynb 5 | **/__pycache__/ 6 | 7 | 8 | docs/build/ 9 | docs/source/examples/ 10 | **/*.tar.gz 11 | **/*.txt 12 | **/*.txt~ 13 | **/*.csv 14 | **/*.zip 15 | **/*.xlsx 16 | **/*.xls 17 | **/*.data 18 | **/*.ods 19 | **/*._.DS_Store 20 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | minimum_pre_commit_version: 2.15.0 2 | ci: 3 | autofix_prs: false 4 | repos: 5 | - repo: https://github.com/psf/black 6 | rev: 23.3.0 7 | hooks: 8 | - id: black 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v4.4.0 11 | hooks: 12 | - id: trailing-whitespace 13 | - id: end-of-file-fixer 14 | exclude_types: [json, binary] 15 | - repo: https://github.com/codespell-project/codespell 16 | rev: v2.2.4 17 | hooks: 18 | - id: codespell 19 | types_or: [python, markdown, rst] 20 | additional_dependencies: [tomli] 21 | - repo: https://github.com/PyCQA/isort 22 | rev: 5.12.0 23 | hooks: 24 | - id: isort 25 | exclude: examples/ 26 | - repo: local 27 | hooks: 28 | - id: commitizen 29 | name: commitizen 30 | entry: cz check 31 | args: [--commit-msg-file] 32 | require_serial: true 33 | language: system 34 | stages: [commit-msg] 35 | - id: absolufy-imports 36 | name: absolufy-imports 37 | entry: absolufy-imports 38 | require_serial: true 39 | 
language: system 40 | types: [python] 41 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | # Set the version of Python and other tools you might need 4 | build: 5 | os: ubuntu-20.04 6 | tools: 7 | python: '3.11' 8 | jobs: 9 | post_create_environment: 10 | # Install poetry 11 | - pip install poetry==1.3.2 12 | # Tell poetry to not use a virtual environment 13 | - poetry config virtualenvs.create false 14 | post_install: 15 | - poetry install -E docs 16 | 17 | # Build documentation in the docs/ directory with Sphinx 18 | sphinx: 19 | configuration: docs/source/conf.py 20 | fail_on_warning: false 21 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Reporting a Vulnerability 2 | 3 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security 4 | via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/) or directly via email to aws-security@amazon.com. 5 | Please do **not** create a public GitHub issue. 
6 | -------------------------------------------------------------------------------- /benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task/sentiment 3 | - model/roberta 4 | - method/sgmcmc_ll 5 | - hyperparams/sghmc_ll 6 | 7 | dataset: 8 | base_data_path: ~ 9 | train_relative_path: "" 10 | test_relative_path: "" 11 | validation_relative_path: "" 12 | 13 | 14 | model: 15 | hparams: 16 | tokenizer_max_length: 512 17 | max_grad_norm: 1 18 | adam_eps: 0.00000001 19 | adam_b2: 0.999 20 | gradient_checkpointing: "true" 21 | save_every_n_steps: 20000 22 | keep_top_n_checkpoints: 1 23 | seed: 42 24 | disable_jit: False 25 | devices: -1 26 | 27 | sagemaker: 28 | account_id: ~ 29 | iam_role: ~ 30 | entrypoint: "benchmarks/transformers//prob_model_text_classification.py" 31 | instance_type: "ml.g5.2xlarge" 32 | profile: "default" 33 | region: "us-east-1" 34 | job_name_suffix: ~ 35 | metrics: 36 | - {Name: "train_loss_step", Regex: 'loss: ([-+]?(\d+(\.\d*)?|\.\d+))'} 37 | - {Name: "train_accuracy_step", Regex: 'accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} 38 | - {Name: "val_loss", Regex: 'val_loss: ([-+]?(\d+(\.\d*)?|\.\d+))'} 39 | - {Name: "val_accuracy", Regex: 'val_accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} 40 | - {Name: "ind_accuracy", Regex: 'IND Test accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} 41 | - {Name: "ind_ece", Regex: 'IND ECE: ([-+]?(\d+(\.\d*)?|\.\d+))'} 42 | - {Name: "ood_accuracy", Regex: 'OOD Test accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} 43 | - {Name: "ood_ece", Regex: 'OOD ECE: ([-+]?(\d+(\.\d*)?|\.\d+))'} 44 | 45 | output_data_path: ~ 46 | -------------------------------------------------------------------------------- /benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/hyperparams/cyclical_sgld_ll.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | hparams: 3 | sghmc_momentum_decay: 0 4 | 5 | sagemaker: 6 | tuner: 7 | hyperparameter_ranges: 8 | sgmcmc_preconditioner: 9 | _target_: sagemaker.tuner.CategoricalParameter 10 | values: 11 | - true 12 | - false 13 | sgmcmc_step_schedule: 14 | _target_: sagemaker.tuner.CategoricalParameter 15 | values: 16 | - constant 17 | - cosine 18 | sgmcmc_init_step_size: 19 | _target_: sagemaker.tuner.ContinuousParameter 20 | min_value: 5e-6 21 | max_value: 0.1 22 | scaling_type: Logarithmic 23 | sgmcmc_n_thinning: 24 | _target_: sagemaker.tuner.IntegerParameter 25 | min_value: 100 26 | max_value: 500 27 | scaling_type: Auto 28 | objective_metric_name: ind_ece 29 | objective_type: Minimize 30 | max_parallel_jobs: 2 31 | max_jobs: 50 32 | early_stopping_type: Auto 33 | -------------------------------------------------------------------------------- /benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/hyperparams/sghmc_ll.yaml: -------------------------------------------------------------------------------- 1 | sagemaker: 2 | tuner: 3 | hyperparameter_ranges: 4 | sghmc_momentum_decay: 5 | _target_: sagemaker.tuner.ContinuousParameter 6 | min_value: 0.0001 7 | max_value: 0.005 8 | scaling_type: Logarithmic 9 | sgmcmc_preconditioner: 10 | _target_: sagemaker.tuner.CategoricalParameter 11 | values: 12 | - true 13 | - false 14 | sgmcmc_step_schedule: 15 | _target_: sagemaker.tuner.CategoricalParameter 16 | values: 17 | - constant 18 | - cosine 19 | sgmcmc_init_step_size: 20 | _target_: 
sagemaker.tuner.ContinuousParameter 21 | min_value: 5e-6 22 | max_value: 0.1 23 | scaling_type: Logarithmic 24 | sgmcmc_n_thinning: 25 | _target_: sagemaker.tuner.IntegerParameter 26 | min_value: 100 27 | max_value: 500 28 | scaling_type: Auto 29 | objective_metric_name: ind_ece 30 | objective_type: Minimize 31 | max_parallel_jobs: 2 32 | max_jobs: 50 33 | early_stopping_type: Auto 34 | -------------------------------------------------------------------------------- /benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/method/cyclical_sgld_ll.yaml: -------------------------------------------------------------------------------- 1 | name: "cyclical_sgld" 2 | hparams: 3 | posterior_approximator_name: "cyclical_sgld" 4 | last_layer_only: "true" 5 | num_train_epochs: 6 6 | n_posterior_samples: 15 7 | -------------------------------------------------------------------------------- /benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/method/sgmcmc_ll.yaml: -------------------------------------------------------------------------------- 1 | name: "sghmc" 2 | hparams: 3 | posterior_approximator_name: "sghmc" 4 | last_layer_only: "true" 5 | num_train_epochs: 6 6 | n_posterior_samples: 15 7 | -------------------------------------------------------------------------------- /benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/model/bert.yaml: -------------------------------------------------------------------------------- 1 | name: "bert" 2 | hparams: 3 | model_name_or_path: "bert-base-cased" 4 | per_device_eval_batch_size: 32 5 | per_device_train_batch_size: 32 6 | learning_rate: 2e-05 7 | num_warmup_steps: 500 8 | prior_log_var: 100.0 9 | weight_decay: 0.01 10 | -------------------------------------------------------------------------------- /benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/model/roberta.yaml: -------------------------------------------------------------------------------- 1 | name: "roberta" 2 | hparams: 3 | model_name_or_path: "roberta-base" 4 | per_device_train_batch_size: 16 5 | per_device_eval_batch_size: 16 6 | learning_rate: 5e-5 7 | num_warmup_steps: 500 8 | prior_log_var: 20. 
9 | weight_decay: 0.0 10 | -------------------------------------------------------------------------------- /benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/task/mnli.yaml: -------------------------------------------------------------------------------- 1 | name: "mnli" 2 | hparams: 3 | dataset_name: "glue" 4 | text_columns: "premise,hypothesis" 5 | num_labels: 3 6 | train_split: "train" 7 | validation_split: "validation_matched[:50%]" 8 | test_split: "validation_matched[-50%:]" 9 | ood_dataset_name: "snli" 10 | ood_text_columns: "premise,hypothesis" 11 | ood_test_split: "test[:50%]" 12 | task_name: "mnli" 13 | -------------------------------------------------------------------------------- /benchmarks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/task/sentiment.yaml: -------------------------------------------------------------------------------- 1 | name: "sentiment" 2 | hparams: 3 | dataset_name: "imdb" 4 | text_columns: "text" 5 | num_labels: 2 6 | train_split: "train" 7 | validation_split: "test[:25%]+test[-25%:]" 8 | test_split: "test[25%:75%]" 9 | ood_dataset_name: "yelp_polarity" 10 | ood_text_columns: "text" 11 | ood_test_split: "test[25%:75%]" 12 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Fortuna Documentation 2 | 3 | The documentation for Fortuna can be found [here](https://aws-fortuna.readthedocs.io/en/latest/). 4 | 5 | ## Build the documentation 6 | 7 | The documentation for Fortuna consists of a series of notebooks. To serve these locally, there are two steps to take, which we outline below. 8 | 9 | ### Prerequisites 10 | 11 | To build the documentation, first install Fortuna and its dependencies by following the [installation instructions](https://github.com/awslabs/fortuna#installation). Next, install the documentation requirements through 12 | ```bash 13 | poetry install -E docs 14 | ``` 15 | 16 | ### Notebooks 17 | 18 | For easier version control, the notebooks are stored as `.pct.py` files. To convert these to `.ipynb` files, run the following command from the root of the repository: 19 | 20 | ```bash 21 | jupytext --to notebook examples/*pct.py 22 | ``` 23 | 24 | This will create a corresponding notebook file for each `.pct.py` file that can be opened in Jupyter. 
25 | 26 | ### Building the documentation 27 | 28 | From the root directory, documentation can be built by running the following commands: 29 | 30 | ```bash 31 | cd docs 32 | poetry run make html 33 | ``` 34 | 35 | Documentation will then be available in the `docs/build/html` directory. 36 | 37 | The above process can be slow as it executes each notebook one-by-one. To build the notebooks in parallel, run the following command: 38 | 39 | ```bash 40 | cd docs 41 | sphinx-build -b html -j auto source build/html 42 | ``` 43 | 44 | 45 | ### Additional Information 46 | 47 | For [VSCode](https://code.visualstudio.com/) users, we recommend installing the [Jupytext extension](https://marketplace.visualstudio.com/items?itemName=congyiwu.vscode-jupytext) to automatically render `.pct.py` as Jupyter notebooks when opened in VSCode. 48 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/fortuna_symbol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/docs/source/_static/fortuna_symbol.png -------------------------------------------------------------------------------- /docs/source/_static/fortuna_symbol2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/docs/source/_static/fortuna_symbol2.png -------------------------------------------------------------------------------- /docs/source/_static/fortuna_symbol_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/docs/source/_static/fortuna_symbol_white.png -------------------------------------------------------------------------------- /docs/source/_static/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/docs/source/_static/pipeline.png -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: landingpage.rst 2 | 3 | .. 
toctree:: 4 | :maxdepth: 1 5 | :hidden: 6 | :caption: CONTENTS: 7 | 8 | quickstart 9 | installation 10 | examples/index 11 | usage_modes/usage_modes 12 | references/references 13 | methods 14 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | **NOTE:** Before installing Fortuna, you are required to `install JAX `_ in your virtual environment. 4 | 5 | You can install Fortuna by typing 6 | 7 | .. code-block:: 8 | 9 | pip install aws-fortuna 10 | 11 | Alternatively, you can build the package using `Poetry `_. 12 | If you choose this route, first install Poetry and add it to your PATH 13 | (see `here `_). Then type 14 | 15 | .. code-block:: 16 | 17 | poetry install 18 | 19 | All the dependencies will be installed at their required versions. 20 | If you also want to install the optional Sphinx dependencies to build the documentation, 21 | add the flag :code:`-E docs` to the command above. 22 | Finally, you can either access the virtualenv that Poetry created by typing :code:`poetry shell`, 23 | or execute commands within the virtualenv using the :code:`run` command, e.g. :code:`poetry run python`. 24 | -------------------------------------------------------------------------------- /docs/source/landingpage.rst: -------------------------------------------------------------------------------- 1 | Fortuna 2 | ####### 3 | 4 | .. image:: https://img.shields.io/pypi/status/Fortuna 5 | :target: https://img.shields.io/pypi/status/Fortuna 6 | :alt: PyPI - Status 7 | .. image:: https://img.shields.io/pypi/dm/aws-fortuna 8 | :target: https://pypistats.org/packages/aws-fortuna 9 | :alt: PyPI - Downloads 10 | .. image:: https://img.shields.io/pypi/v/aws-fortuna 11 | :target: https://img.shields.io/pypi/v/aws-fortuna 12 | :alt: PyPI - Version 13 | .. image:: https://img.shields.io/github/license/awslabs/Fortuna 14 | :target: https://github.com/awslabs/Fortuna/blob/main/LICENSE 15 | :alt: License 16 | .. image:: https://readthedocs.org/projects/aws-fortuna/badge/?version=latest 17 | :target: https://aws-fortuna.readthedocs.io 18 | :alt: Documentation Status 19 | 20 | A Library for Uncertainty Quantification 21 | ======================================== 22 | Proper estimation of predictive uncertainty is fundamental in applications that involve critical decisions. 23 | Uncertainty can be used to assess reliability of model predictions, trigger human intervention, 24 | or decide whether a model can be safely deployed in the wild. 25 | 26 | Fortuna provides calibrated uncertainty estimates of model predictions, in classification and regression. 27 | It is designed to be easy to use, 28 | and to promote effortless estimation of uncertainty in production systems. 29 | 30 | .. include:: quickstart.rst 31 | 32 | .. include:: installation.rst 33 | 34 | .. include:: license.rst 35 | -------------------------------------------------------------------------------- /docs/source/license.rst: -------------------------------------------------------------------------------- 1 | License 2 | ======= 3 | This project is licensed under the Apache-2.0 License. 4 | -------------------------------------------------------------------------------- /docs/source/references/conformal.rst: -------------------------------------------------------------------------------- 1 | Conformal prediction 2 | ==================== 3 | Conformal prediction methods are a family of calibration methods that, starting from uncertainty estimates, 4 | provide *conformal sets*, i.e. rigorous sets of predictions with a user-chosen level of probability. 5 | We support conformal methods for both 6 | :ref:`classification <conformal_classification>` 7 | and :ref:`regression <conformal_regression>`. 8 | 9 | .. _conformal_classification: 10 | 11 | .. automodule:: fortuna.conformal.classification.adaptive_prediction 12 | 13 | .. automodule:: fortuna.conformal.classification.simple_prediction 14 | 15 | .. automodule:: fortuna.conformal.classification.adaptive_conformal_classifier 16 | 17 | .. automodule:: fortuna.conformal.classification.batch_mvp 18 | 19 | .. automodule:: fortuna.conformal.classification.multicalibrator 20 | 21 | .. _conformal_regression: 22 | 23 | .. automodule:: fortuna.conformal.regression.quantile 24 | 25 | .. automodule:: fortuna.conformal.regression.onedim_uncertainty 26 | 27 | .. automodule:: fortuna.conformal.regression.cvplus 28 | 29 | .. automodule:: fortuna.conformal.regression.jackknifeplus 30 | 31 | .. automodule:: fortuna.conformal.regression.jackknife_minmax 32 | 33 | .. automodule:: fortuna.conformal.regression.enbpi 34 | 35 | .. automodule:: fortuna.conformal.regression.adaptive_conformal_regressor 36 | 37 | .. automodule:: fortuna.conformal.regression.batch_mvp 38 | 39 | .. automodule:: fortuna.conformal.regression.multicalibrator
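 40 | 41 | To make the idea concrete, here is a minimal NumPy sketch of split conformal classification with adaptive prediction sets. It illustrates the method only and is not Fortuna's API; the classes above provide the library implementation. 42 | 43 | .. code-block:: python 44 | 45 | import numpy as np 46 | 47 | def aps_threshold(val_probs, val_targets, error=0.1): 48 | # Score each validation point by the probability mass needed to reach its true label. 49 | order = np.argsort(-val_probs, axis=1) 50 | cumsum = np.cumsum(np.take_along_axis(val_probs, order, axis=1), axis=1) 51 | ranks = np.argmax(order == val_targets[:, None], axis=1) 52 | scores = np.take_along_axis(cumsum, ranks[:, None], axis=1).squeeze(1) 53 | n = len(scores) 54 | # Conformal quantile with a finite-sample correction. 55 | return np.quantile(scores, min(1.0, np.ceil((n + 1) * (1 - error)) / n), method="higher") 56 | 57 | def aps_sets(test_probs, threshold): 58 | # Grow each prediction set in decreasing-probability order until it covers the threshold mass. 59 | order = np.argsort(-test_probs, axis=1) 60 | cumsum = np.cumsum(np.take_along_axis(test_probs, order, axis=1), axis=1) 61 | sizes = 1 + (cumsum < threshold).sum(axis=1) 62 | return [o[:k].tolist() for o, k in zip(order, sizes)]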
 -------------------------------------------------------------------------------- /docs/source/references/data_loader.rst: -------------------------------------------------------------------------------- 1 | Data loader 2 | =========== 3 | This section describes Fortuna's data loader functionalities. A :class:`~fortuna.data.loader.DataLoader` object 4 | is an iterable of two-component tuples of arrays (either `NumPy `__-arrays or `JAX-NumPy `__-arrays), 5 | where the first components are input variables and the second components are target variables. If you have a data loader 6 | of `TensorFlow `__ or `PyTorch `__ tensors, or other formats, you can convert it into something digestible by Fortuna using 7 | the appropriate :class:`~fortuna.data.loader.DataLoader` functionality 8 | (check :meth:`~fortuna.data.loader.DataLoader.from_tensorflow_data_loader`, :meth:`~fortuna.data.loader.DataLoader.from_torch_data_loader`). 9 | 10 | The data :class:`~fortuna.data.loader.DataLoader` also allows you to generate an :class:`~fortuna.data.loader.InputsLoader` or a 11 | :class:`~fortuna.data.loader.TargetsLoader`, i.e. data loaders of only the inputs and only the targets, respectively 12 | (check :meth:`~fortuna.data.loader.DataLoader.to_inputs_loader` and :meth:`~fortuna.data.loader.DataLoader.to_targets_loader`). 13 | Additionally, you can convert a data loader into an array of inputs, an array of targets, or a tuple of input and target 14 | arrays (check :meth:`~fortuna.data.loader.DataLoader.to_array_inputs`, :meth:`~fortuna.data.loader.DataLoader.to_array_targets` and :meth:`~fortuna.data.loader.DataLoader.to_array_data`). 15 | 16 | .. _data_loader: 17 | 18 | .. autoclass:: fortuna.data.loader.DataLoader 19 | 20 | .. _inputs_loader: 21 | 22 | .. autoclass:: fortuna.data.loader.InputsLoader 23 | 24 | .. _targets_loader: 25 | 26 | .. autoclass:: fortuna.data.loader.TargetsLoader
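 27 | 28 | As a quick sketch of these conversions (the method names are the ones documented above, but the exact signatures are not shown here, so treat this as an outline): 29 | 30 | .. code-block:: python 31 | 32 | from fortuna.data.loader import DataLoader 33 | 34 | # Assuming `torch_loader` is a torch.utils.data.DataLoader yielding (inputs, targets) batches. 35 | data_loader = DataLoader.from_torch_data_loader(torch_loader) 36 | 37 | inputs_loader = data_loader.to_inputs_loader() # loader over inputs only 38 | targets = data_loader.to_array_targets() # all targets, gathered into a single array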
 -------------------------------------------------------------------------------- /docs/source/references/likelihood.rst: -------------------------------------------------------------------------------- 1 | Likelihood function 2 | =================== 3 | The likelihood function models the target data given the model and the input data. 4 | We support a :ref:`classification likelihood <likelihood_classification>` for classification and a 5 | :ref:`regression likelihood <likelihood_regression>` for regression. Please find their references below. 6 | 7 | .. _likelihood_classification: 8 | 9 | .. automodule:: fortuna.likelihood.classification 10 | 11 | .. _likelihood_regression: 12 | 13 | .. automodule:: fortuna.likelihood.regression 14 | 15 | .. _likelihood: 16 | 17 | .. automodule:: fortuna.likelihood.base 18 | -------------------------------------------------------------------------------- /docs/source/references/metric.rst: -------------------------------------------------------------------------------- 1 | Metric 2 | =========== 3 | We support several metrics for both 4 | :ref:`classification <metric_classification>` 5 | and :ref:`regression <metric_regression>`. 6 | Metrics are `NumPy `__-compatible, so feel free to bring your own and 7 | apply them to Fortuna's predictions. 8 | 9 | .. _metric_classification: 10 | 11 | .. automodule:: fortuna.metric.classification 12 | :exclude-members: compute_counts_confs_accs 13 | 14 | .. _metric_regression: 15 | 16 | .. automodule:: fortuna.metric.regression
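 17 | 18 | Since a metric here is just a NumPy-compatible function of predictions and targets, bringing your own is straightforward — a minimal sketch (the function below is illustrative, not part of Fortuna): 19 | 20 | .. code-block:: python 21 | 22 | import numpy as np 23 | 24 | def balanced_accuracy(preds: np.ndarray, targets: np.ndarray) -> float: 25 | # Average per-class recall; any callable of this shape can serve as a custom metric. 26 | classes = np.unique(targets) 27 | recalls = [(preds[targets == c] == c).mean() for c in classes] 28 | return float(np.mean(recalls))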
 -------------------------------------------------------------------------------- /docs/source/references/model/builtin_models.rst: -------------------------------------------------------------------------------- 1 | .. _builtin_models: 2 | 3 | Built-in models 4 | =============== 5 | We support several built-in models: 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | mlp 11 | cnn 12 | resnet 13 | wideresnet 14 | constant 15 | scalar_constant 16 | hyper 17 | scalar_hyper 18 | -------------------------------------------------------------------------------- /docs/source/references/model/cnn.rst: -------------------------------------------------------------------------------- 1 | CNN 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.lenet.LeNet5 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | :no-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/source/references/model/constant.rst: -------------------------------------------------------------------------------- 1 | Constant model 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.constant.ConstantModel 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | :no-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/source/references/model/hyper.rst: -------------------------------------------------------------------------------- 1 | Hyperparameter model 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.hyper.HyperparameterModel 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | :no-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/source/references/model/linen_module.rst: -------------------------------------------------------------------------------- 1 | Flax Linen Module 2 | ================= 3 | .. automodule:: flax.linen.Module 4 | -------------------------------------------------------------------------------- /docs/source/references/model/mlp.rst: -------------------------------------------------------------------------------- 1 | Multi-Layer Perceptron (MLP) 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.mlp.MLP 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | :no-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/source/references/model/model.rst: -------------------------------------------------------------------------------- 1 | Model 2 | ============================ 3 | A deterministic model from inputs to outputs. 4 | You can choose among several :ref:`built-in models <builtin_models>`, 5 | or bring in your own model by subclassing :mod:`~flax.linen.Module`. 6 | The model forward pass is orchestrated by a :ref:`model manager <model_manager>`. 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | :caption: Contents: 11 | 12 | builtin_models 13 | linen_module 14 | model_manager 15 | utils
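 16 | 17 | For example, a custom model is nothing more than a Flax module from inputs to outputs. A minimal sketch follows; the ``output_dim`` attribute and the ``train`` flag mirror the built-in models' interface and are assumptions here: 18 | 19 | .. code-block:: python 20 | 21 | import flax.linen as nn 22 | import jax.numpy as jnp 23 | 24 | class TwoLayerNet(nn.Module): 25 | output_dim: int 26 | 27 | @nn.compact 28 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray: 29 | x = x.reshape(x.shape[0], -1) # flatten everything but the batch axis 30 | x = nn.relu(nn.Dense(128)(x)) 31 | return nn.Dense(self.output_dim)(x)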
 -------------------------------------------------------------------------------- /docs/source/references/model/model_manager.rst: -------------------------------------------------------------------------------- 1 | .. _model_manager: 2 | 3 | Model manager 4 | ============= 5 | The model manager is responsible for the orchestration of the forward pass. 6 | We support a :ref:`classification model manager <model_manager_classification>` for classification 7 | and a :ref:`regression model manager <model_manager_regression>` for regression. 8 | 9 | .. _model_manager_classification: 10 | 11 | .. autoclass:: fortuna.model.model_manager.classification.ClassificationModelManager 12 | :no-inherited-members: 13 | 14 | .. _model_manager_regression: 15 | 16 | .. automodule:: fortuna.model.model_manager.regression 17 | 18 | .. automodule:: fortuna.model.model_manager.base 19 | 20 | .. autoclass:: fortuna.model.model_manager.state.ModelManagerState 21 | :no-inherited-members: 22 | :exclude-members: params, mutable 23 | -------------------------------------------------------------------------------- /docs/source/references/model/resnet.rst: -------------------------------------------------------------------------------- 1 | ResNet 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.resnet.ResNet 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | :no-members: 8 | :show-inheritance: 9 | 10 | .. autoclass:: fortuna.model.resnet.ResNet18 11 | 12 | .. autoclass:: fortuna.model.resnet.ResNet34 13 | 14 | .. autoclass:: fortuna.model.resnet.ResNet50 15 | 16 | .. autoclass:: fortuna.model.resnet.ResNet101 17 | 18 | .. autoclass:: fortuna.model.resnet.ResNet152 19 | 20 | .. autoclass:: fortuna.model.resnet.ResNet200 21 | -------------------------------------------------------------------------------- /docs/source/references/model/scalar_constant.rst: -------------------------------------------------------------------------------- 1 | Scalar constant model 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.scalar_constant.ScalarConstantModel 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | :no-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/source/references/model/scalar_hyper.rst: -------------------------------------------------------------------------------- 1 | Scalar hyperparameter model 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.scalar_hyper.ScalarHyperparameterModel 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | :no-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/source/references/model/utils.rst: -------------------------------------------------------------------------------- 1 | .. _model_utils: 2 | 3 | Model Utils 4 | =============== 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | utils/spectral_norm 10 | utils/random_features 11 | -------------------------------------------------------------------------------- /docs/source/references/model/utils/random_features.rst: -------------------------------------------------------------------------------- 1 | Random Features 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.utils.random_features.RandomFeatureGaussianProcess 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | :no-members: 8 | :show-inheritance: 9 | 10 | .. autoclass:: fortuna.model.utils.random_features.RandomFourierFeatures 11 | :no-undoc-members: 12 | :no-inherited-members: 13 | :no-members: 14 | :show-inheritance: 15 | 16 | .. autoclass:: fortuna.model.utils.random_features.LaplaceRandomFeatureCovariance 17 | :no-undoc-members: 18 | :no-inherited-members: 19 | :no-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/source/references/model/utils/spectral_norm.rst: -------------------------------------------------------------------------------- 1 | Spectral Normalization 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.utils.spectral_norm.SpectralNormalization 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | 8 | .. autoclass:: fortuna.model.utils.spectral_norm.SpectralNormalizationConv2D 9 | :no-undoc-members: 10 | :no-inherited-members: 11 | -------------------------------------------------------------------------------- /docs/source/references/model/wideresnet.rst: -------------------------------------------------------------------------------- 1 | WideResNet 2 | ============================ 3 | 4 | .. autoclass:: fortuna.model.wideresnet.WideResNet 5 | :no-undoc-members: 6 | :no-inherited-members: 7 | :no-members: 8 | :show-inheritance: 9 | 10 | .. autoclass:: fortuna.model.wideresnet.WideResNet28_10 11 | -------------------------------------------------------------------------------- /docs/source/references/ood_classifier.rst: -------------------------------------------------------------------------------- 1 | .. _ood_detection: 2 | 3 | Out-Of-Distribution (OOD) detection 4 | =================================== 5 | Starting from a trained neural classifier, it is possible to fit one of the models below 6 | to help distinguish between in-distribution and out-of-distribution inputs. 7 | 8 | .. autoclass:: fortuna.ood_detection.mahalanobis.MalahanobisOODClassifier 9 | 10 | .. autoclass:: fortuna.ood_detection.ddu.DeepDeterministicUncertaintyOODClassifier
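 11 | 12 | The Mahalanobis approach fits per-class means and a shared covariance on in-distribution features, and scores an input by its distance to the closest class centroid. Here is a NumPy sketch of the idea — an illustration only, not the classes' API; see the references above for the library implementation, including feature extraction: 13 | 14 | .. code-block:: python 15 | 16 | import numpy as np 17 | 18 | def fit_mahalanobis(feats, labels): 19 | # Per-class means and a shared precision matrix from in-distribution features. 20 | classes = np.unique(labels) 21 | means = np.stack([feats[labels == c].mean(0) for c in classes]) 22 | centered = feats - means[np.searchsorted(classes, labels)] 23 | precision = np.linalg.pinv(np.cov(centered, rowvar=False)) 24 | return means, precision 25 | 26 | def ood_score(feats, means, precision): 27 | # Squared Mahalanobis distance to the closest centroid; larger means more likely OOD. 28 | diffs = feats[:, None, :] - means[None, :, :] 29 | return np.einsum("nkd,de,nke->nk", diffs, precision, diffs).min(axis=1)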
 -------------------------------------------------------------------------------- /docs/source/references/output_calib_model/config.rst: -------------------------------------------------------------------------------- 1 | Calibration configuration 2 | ========================= 3 | This section describes :class:`~fortuna.output_calib_model.config.base.Config`, 4 | an object that configures the calibration process of the output calibration model. It is made of several objects: 5 | 6 | - :class:`~fortuna.output_calib_model.config.optimizer.Optimizer`: to configure the optimization process; 7 | 8 | - :class:`~fortuna.output_calib_model.config.checkpointer.Checkpointer`: to save and restore checkpoints; 9 | 10 | - :class:`~fortuna.output_calib_model.config.monitor.Monitor`: to monitor the process and trigger early stopping; 11 | 12 | - :class:`~fortuna.output_calib_model.config.processor.Processor`: to decide how and where the computation is processed. 13 | 14 | .. _output_calib_model_config: 15 | 16 | .. autoclass:: fortuna.output_calib_model.config.base.Config 17 | 18 | .. _output_calib_model_calib_optimizer: 19 | 20 | .. autoclass:: fortuna.output_calib_model.config.optimizer.Optimizer 21 | 22 | .. _output_calib_model_calib_checkpointer: 23 | 24 | .. autoclass:: fortuna.output_calib_model.config.checkpointer.Checkpointer 25 | 26 | .. _output_calib_model_calib_monitor: 27 | 28 | .. autoclass:: fortuna.output_calib_model.config.monitor.Monitor 29 | 30 | .. _output_calib_model_calib_processor: 31 | 32 | .. autoclass:: fortuna.output_calib_model.config.processor.Processor 33 | -------------------------------------------------------------------------------- /docs/source/references/output_calib_model/output_calib_model.rst: -------------------------------------------------------------------------------- 1 | Output calibration model 2 | ======================== 3 | We support a :ref:`calibration classifier <output_calib_classifier>` for classification 4 | and a :ref:`calibration regressor <output_calib_regressor>` for regression. 5 | Please find their references below. 6 | 7 | .. _output_calib_classifier: 8 | 9 | .. automodule:: fortuna.output_calib_model.classification 10 | :members: 11 | :exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint 12 | 13 | .. _output_calib_regressor: 14 | 15 | .. automodule:: fortuna.output_calib_model.regression 16 | :members: 17 | :exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint 18 | 19 | .. _output_calib_base: 20 | 21 | .. automodule:: fortuna.output_calib_model.base 22 | :members: 23 | :exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint 24 | 25 | .. toctree:: 26 | :maxdepth: 1 27 | :hidden: 28 | :caption: Contents: 29 | 30 | predictive 31 | config 32 | -------------------------------------------------------------------------------- /docs/source/references/output_calib_model/predictive.rst: -------------------------------------------------------------------------------- 1 | Predictive distribution 2 | ============================ 3 | The predictive distribution is the component of the output calibration model responsible for the computation of predictive 4 | statistics. We support a :ref:`classification predictive <output_calib_predictive_classification>` for classification and a 5 | :ref:`regression predictive <output_calib_predictive_regression>` for regression. Please find their references below. 6 | 7 | .. _output_calib_predictive: 8 | 9 | .. _output_calib_predictive_classification: 10 | 11 | .. automodule:: fortuna.output_calib_model.predictive.classification 12 | 13 | .. _output_calib_predictive_regression: 14 | 15 | .. automodule:: fortuna.output_calib_model.predictive.regression 16 | -------------------------------------------------------------------------------- /docs/source/references/output_calibrator.rst: -------------------------------------------------------------------------------- 1 | .. _output_calibrator: 2 | 3 | Output calibrator 4 | ================== 5 | The output calibrator calibrates the model outputs. We explicitly support a 6 | :ref:`temperature scaling output calibrator for classification <output_calibrator_classification>`, 7 | and a :ref:`temperature scaling output calibrator for regression <output_calibrator_regression>`. 8 | 9 | Alternatively, you can bring in your own output calibrator by subclassing :mod:`~flax.linen.Module`. 10 | 11 | .. _output_calibrator_classification: 12 | 13 | .. automodule:: fortuna.output_calibrator.classification 14 | :no-inherited-members: 15 | 16 | .. _output_calibrator_regression: 17 | 18 | .. automodule:: fortuna.output_calibrator.regression 19 | :no-inherited-members: 20 | 21 | .. autoclass:: fortuna.output_calib_model.state.OutputCalibState 22 | :no-inherited-members: 23 | :exclude-members: params, mutable, encoded_name, replace
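 24 | 25 | As a sketch of what a custom output calibrator can look like, the module below divides the outputs by a single learnable temperature; the exact ``__call__`` signature Fortuna expects may include further arguments, so treat this as an outline: 26 | 27 | .. code-block:: python 28 | 29 | import flax.linen as nn 30 | import jax.numpy as jnp 31 | 32 | class TempScaler(nn.Module): 33 | @nn.compact 34 | def __call__(self, outputs: jnp.ndarray) -> jnp.ndarray: 35 | # A learnable log-temperature keeps the temperature positive by construction. 36 | log_temp = self.param("log_temp", nn.initializers.zeros, (1,)) 37 | return outputs / jnp.exp(log_temp)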
 -------------------------------------------------------------------------------- /docs/source/references/plot.rst: -------------------------------------------------------------------------------- 1 | Plot 2 | =========== 3 | This section includes the plotting functionalities currently supported by Fortuna. 4 | 5 | .. autofunction:: fortuna.plot.plot_reliability_diagram 6 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/callbacks.rst: -------------------------------------------------------------------------------- 1 | Training Callbacks 2 | =============================== 3 | This section describes :class:`~fortuna.prob_model.fit_config.callback.FitCallback`, 4 | which allows users to add custom actions at different stages of the training loop. 5 | Callbacks can be used while training a :class:`~fortuna.prob_model.base.ProbModel`. 6 | 7 | To use callbacks, the user has to: 8 | 9 | - Define their own callbacks by subclassing :class:`~fortuna.prob_model.fit_config.callback.FitCallback` and overriding the methods of interest. 10 | - When calling the train method of a :class:`~fortuna.prob_model.base.ProbModel` instance, 11 | add a list of callbacks to the configuration object :class:`~fortuna.prob_model.fit_config.base.FitConfig`. 12 | 13 | The following example outlines the usage of :class:`~fortuna.prob_model.fit_config.callback.FitCallback`. 14 | It assumes that the user already obtained an instance of :class:`~fortuna.prob_model.base.ProbModel`: 15 | 16 | .. code-block:: python 17 | 18 | import logging 19 | 20 | from jax.flatten_util import ravel_pytree 21 | import optax 22 | 23 | from fortuna.training.train_state import TrainState 24 | from fortuna.prob_model.fit_config import FitConfig, FitMonitor, FitOptimizer, FitCallback 25 | from fortuna.metric.classification import accuracy 26 | 27 | logger = logging.getLogger(__name__) # the callback below reports through this logger 28 | 29 | # Define custom callback 30 | class CountParamsCallback(FitCallback): 31 | def training_epoch_start(self, state: TrainState) -> TrainState: 32 | params, unravel = ravel_pytree(state.params) 33 | logger.info(f"num params: {len(params)}") 34 | return state 35 | 36 | # Add a list of callbacks containing CountParamsCallback to FitConfig 37 | status = prob_model.train( 38 | train_data_loader=train_data_loader, 39 | val_data_loader=val_data_loader, 40 | calib_data_loader=val_data_loader, 41 | fit_config=FitConfig( 42 | optimizer=FitOptimizer(method=optax.adam(1e-4), n_epochs=100), 43 | callbacks=[ 44 | CountParamsCallback() 45 | ] 46 | ) 47 | ) 48 | 49 | 50 | .. _callbacks: 51 | 52 | .. 
autoclass:: fortuna.training.callback.Callback 49 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/fit_config.rst: -------------------------------------------------------------------------------- 1 | Posterior fitting configuration 2 | =============================== 3 | This section describes :class:`~fortuna.prob_model.fit_config.base.FitConfig`, 4 | an object that configures the posterior fitting process. It is made of several objects: 5 | 6 | - :class:`~fortuna.prob_model.fit_config.optimizer.FitOptimizer`: to configure the optimization process; 7 | - :class:`~fortuna.prob_model.fit_config.checkpointer.FitCheckpointer`: to save and restore checkpoints; 8 | - :class:`~fortuna.prob_model.fit_config.monitor.FitMonitor`: to monitor the process and trigger early stopping; 9 | - :class:`~fortuna.prob_model.fit_config.processor.FitProcessor`: to decide how and where the computation is processed; 10 | - List[:class:`~fortuna.prob_model.fit_config.callback.Callback`]: to allow the user to perform custom actions at different stages of the training process. 11 | 12 | .. _fit_config: 13 | 14 | .. autoclass:: fortuna.prob_model.fit_config.base.FitConfig 15 | 16 | .. autoclass:: fortuna.prob_model.fit_config.optimizer.FitOptimizer 17 | 18 | .. autoclass:: fortuna.prob_model.fit_config.monitor.FitMonitor 19 | 20 | .. autoclass:: fortuna.prob_model.fit_config.checkpointer.FitCheckpointer 21 | 22 | .. autoclass:: fortuna.prob_model.fit_config.processor.FitProcessor 23 | 24 | .. autoclass:: fortuna.prob_model.fit_config.callback.Callback 25 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/joint.rst: -------------------------------------------------------------------------------- 1 | Joint distribution 2 | =================== 3 | The :ref:`joint <joint>` distribution of the target data and the model parameters given the input data. 4 | Please find its reference below. 5 | 6 | .. _joint: 7 | 8 | .. automodule:: fortuna.prob_model.joint.base 9 | 10 | .. autoclass:: fortuna.prob_model.joint.state.JointState 11 | :show-inheritance: 12 | :no-inherited-members: 13 | :exclude-members: params, mutable, calib_params, calib_mutable 14 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/model_editor.rst: -------------------------------------------------------------------------------- 1 | Model editor 2 | =================== 3 | A model editor is an object that takes the forward pass as a function of model parameters and a batch of input data 4 | points, and can modify its outputs in a custom way. 5 | 6 | .. automodule:: fortuna.model_editor.base 7 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/posterior/advi.rst: -------------------------------------------------------------------------------- 1 | Automatic Differentiation Variational Inference (ADVI) 2 | ------------------------------------------------------ 3 | 4 | .. autoclass:: fortuna.prob_model.posterior.normalizing_flow.advi.advi_approximator.ADVIPosteriorApproximator 5 | 6 | .. autoclass:: fortuna.prob_model.posterior.normalizing_flow.advi.advi_posterior.ADVIPosterior 7 | :show-inheritance: 8 | :no-inherited-members: 9 | :exclude-members: state 10 | :members: fit, sample, load_state, save_state 11 | 12 | ..
autoclass:: fortuna.prob_model.posterior.normalizing_flow.advi.advi_state.ADVIState 13 | :show-inheritance: 14 | :no-inherited-members: 15 | :inherited-members: init, init_from_dict 16 | :members: convert_from_map_state 17 | :exclude-members: params, mutable, calib_params, calib_mutable, replace, apply_gradients, encoded_name, create 18 | :no-undoc-members: 19 | :no-special-members: 20 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/posterior/deep_ensemble.rst: -------------------------------------------------------------------------------- 1 | Deep Ensemble 2 | -------------- 3 | 4 | .. autoclass:: fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_approximator.DeepEnsemblePosteriorApproximator 5 | 6 | .. autoclass:: fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_posterior.DeepEnsemblePosterior 7 | :show-inheritance: 8 | :no-inherited-members: 9 | :exclude-members: state 10 | :members: fit, sample, load_state, save_state 11 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/posterior/laplace.rst: -------------------------------------------------------------------------------- 1 | Laplace approximation 2 | --------------------- 3 | 4 | .. autoclass:: fortuna.prob_model.posterior.laplace.laplace_approximator.LaplacePosteriorApproximator 5 | 6 | .. autoclass:: fortuna.prob_model.posterior.laplace.laplace_posterior.LaplacePosterior 7 | :show-inheritance: 8 | :no-inherited-members: 9 | :exclude-members: state 10 | :members: fit, sample, load_state, save_state 11 | 12 | .. autoclass:: fortuna.prob_model.posterior.laplace.laplace_state.LaplaceState 13 | :show-inheritance: 14 | :no-inherited-members: 15 | :inherited-members: init, init_from_dict 16 | :members: convert_from_map_state 17 | :exclude-members: params, mutable, calib_params, calib_mutable, replace, apply_gradients, encoded_name, create 18 | :no-undoc-members: 19 | :no-special-members: 20 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/posterior/map.rst: -------------------------------------------------------------------------------- 1 | Maximum-A-Posteriori (MAP) 2 | -------------------------- 3 | 4 | .. autoclass:: fortuna.prob_model.posterior.map.map_approximator.MAPPosteriorApproximator 5 | 6 | .. autoclass:: fortuna.prob_model.posterior.map.map_posterior.MAPPosterior 7 | :show-inheritance: 8 | :no-inherited-members: 9 | :exclude-members: state 10 | :members: fit, sample, load_state, save_state 11 | 12 | .. autoclass:: fortuna.prob_model.posterior.map.map_state.MAPState 13 | :show-inheritance: 14 | :no-inherited-members: 15 | :inherited-members: init, init_from_dict 16 | :exclude-members: params, mutable, calib_params, calib_mutable, replace, create, apply_gradients, encoded_name 17 | :no-undoc-members: 18 | :no-special-members: 19 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/posterior/posterior.rst: -------------------------------------------------------------------------------- 1 | Posterior 2 | =================== 3 | The :ref:`posterior <posterior>` distribution of the model parameters given the training data and the 4 | calibration parameters. We support several posterior approximations: 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | map 10 | advi 11 | deep_ensemble 12 | laplace 13 | swag 14 | sngp 15 | sgmcmc 16 | 17 | .. _posterior: 18 | 19 | ..
autoclass:: fortuna.prob_model.posterior.base.Posterior 20 | :no-inherited-members: 21 | :exclude-members: state 22 | :members: fit, sample, load_state, save_state 23 | 24 | .. autoclass:: fortuna.prob_model.posterior.base.PosteriorApproximator 25 | 26 | .. autoclass:: fortuna.prob_model.posterior.state.PosteriorState 27 | :no-inherited-members: 28 | :exclude-members: params, mutable, calib_params, calib_mutable, replace, encoded_name 29 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/posterior/sngp.rst: -------------------------------------------------------------------------------- 1 | Spectral-normalized Neural Gaussian Process (SNGP) 2 | -------------------------------------------------- 3 | 4 | .. autoclass:: fortuna.prob_model.posterior.sngp.sngp_approximator.SNGPPosteriorApproximator 5 | 6 | .. autoclass:: fortuna.prob_model.posterior.sngp.sngp_posterior.SNGPPosterior 7 | :show-inheritance: 8 | :no-inherited-members: 9 | :exclude-members: state 10 | :members: fit, sample, load_state, save_state 11 | 12 | 13 | .. autoclass:: fortuna.model.model_manager.classification.SNGPClassificationModelManager 14 | 15 | .. autoclass:: fortuna.prob_model.posterior.sngp.sngp_callback.ResetCovarianceCallback 16 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/posterior/swag.rst: -------------------------------------------------------------------------------- 1 | SWAG 2 | ----- 3 | 4 | .. autoclass:: fortuna.prob_model.posterior.swag.swag_approximator.SWAGPosteriorApproximator 5 | 6 | .. autoclass:: fortuna.prob_model.posterior.swag.swag_posterior.SWAGPosterior 7 | :show-inheritance: 8 | :no-inherited-members: 9 | :exclude-members: state 10 | :members: fit, sample, load_state, save_state 11 | 12 | .. autoclass:: fortuna.prob_model.posterior.swag.swag_state.SWAGState 13 | :show-inheritance: 14 | :no-inherited-members: 15 | :inherited-members: init, init_from_dict 16 | :members: convert_from_map_state, update 17 | :exclude-members: params, mutable, calib_params, calib_mutable, replace, mean, std, dev, create, apply_gradients, 18 | encoded_name 19 | :no-undoc-members: 20 | :no-special-members: 21 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/predictive.rst: -------------------------------------------------------------------------------- 1 | Predictive distribution 2 | ======================= 3 | The predictive distribution is the component of the probabilistic model responsible for the computation of predictive 4 | statistics. We support a :ref:`classification predictive <predictive_classification>` for classification and a 5 | :ref:`regression predictive <predictive_regression>` for regression. Please find their references below. 6 | 7 | .. _predictive: 8 | 9 | .. _predictive_classification: 10 | 11 | .. automodule:: fortuna.prob_model.predictive.classification 12 | 13 | .. _predictive_regression: 14 | 15 | .. automodule:: fortuna.prob_model.predictive.regression 16 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/prior.rst: -------------------------------------------------------------------------------- 1 | Prior distribution 2 | ================== 3 | We support Gaussian prior distributions, 4 | specifically :class:`~fortuna.prob_model.prior.gaussian.IsotropicGaussianPrior` 5 | and :class:`~fortuna.prob_model.prior.gaussian.DiagonalGaussianPrior`.
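For instance, a prior can be passed to a probabilistic model at construction time. Below is a minimal sketch; the model and its output dimension are illustrative, and the default prior constructor is used to avoid assuming undocumented arguments.

.. code-block:: python

    from fortuna.model.mlp import MLP
    from fortuna.prob_model.classification import ProbClassifier
    from fortuna.prob_model.prior.gaussian import IsotropicGaussianPrior

    # a hypothetical classifier with an isotropic Gaussian prior over its parameters
    prob_model = ProbClassifier(
        model=MLP(output_dim=2),
        prior=IsotropicGaussianPrior(),
    )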
6 | 7 | Alternatively, you can bring your own prior distribution by subclassing the abstract class 8 | :class:`~fortuna.prob_model.prior.base.Prior`. Please find the references below. 9 | 10 | .. automodule:: fortuna.prob_model.prior.base 11 | 12 | .. automodule:: fortuna.prob_model.prior.gaussian 13 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/prob_calib_config.rst: -------------------------------------------------------------------------------- 1 | Calibration configuration 2 | ========================= 3 | This section describes :class:`~fortuna.prob_model.calib_config.base.CalibConfig`, 4 | an object that configures the calibration process of the probabilistic model. It is made of several objects: 5 | 6 | - :class:`~fortuna.prob_model.calib_config.optimizer.CalibOptimizer`: to configure the optimization process; 7 | 8 | - :class:`~fortuna.prob_model.calib_config.checkpointer.CalibCheckpointer`: to save and restore checkpoints; 9 | 10 | - :class:`~fortuna.prob_model.calib_config.monitor.CalibMonitor`: to monitor the process and trigger early stopping; 11 | 12 | - :class:`~fortuna.prob_model.calib_config.processor.CalibProcessor`: to decide how and where the computation is processed. 13 | 14 | .. _prob_calib_config: 15 | 16 | .. autoclass:: fortuna.prob_model.calib_config.base.CalibConfig 17 | 18 | .. autoclass:: fortuna.prob_model.calib_config.optimizer.CalibOptimizer 19 | 20 | .. autoclass:: fortuna.prob_model.calib_config.checkpointer.CalibCheckpointer 21 | 22 | .. autoclass:: fortuna.prob_model.calib_config.monitor.CalibMonitor 23 | 24 | .. autoclass:: fortuna.prob_model.calib_config.processor.CalibProcessor 25 | -------------------------------------------------------------------------------- /docs/source/references/prob_model/prob_model.rst: -------------------------------------------------------------------------------- 1 | Probabilistic model 2 | =================== 3 | We support a :ref:`probabilistic classifier <prob_classifier>` for classification 4 | and a :ref:`probabilistic regressor <prob_regressor>` for regression. 5 | Please find their references below. 6 | 7 | .. _prob_classifier: 8 | 9 | .. automodule:: fortuna.prob_model.classification 10 | :members: 11 | 12 | .. _prob_regressor: 13 | 14 | .. automodule:: fortuna.prob_model.regression 15 | :members: 16 | 17 | .. _prob_base: 18 | 19 | .. automodule:: fortuna.prob_model.base 20 | :members: 21 | 22 | .. toctree:: 23 | :maxdepth: 1 24 | :hidden: 25 | :caption: Contents: 26 | 27 | predictive 28 | posterior/posterior 29 | joint 30 | prior 31 | model_editor 32 | fit_config 33 | prob_calib_config 34 | callbacks 35 | -------------------------------------------------------------------------------- /docs/source/references/prob_output_layer.rst: -------------------------------------------------------------------------------- 1 | .. _prob_output_layer: 2 | 3 | Probabilistic output layer 4 | ========================== 5 | 6 | We support a :ref:`classification probabilistic output layer <prob_output_layer_classification>` for classification 7 | and a :ref:`regression probabilistic output layer <prob_output_layer_regression>` for regression. 8 | Please find their references below. 9 | 10 | .. automodule:: fortuna.prob_output_layer.base 11 | 12 | .. _prob_output_layer_classification: 13 | 14 | .. automodule:: fortuna.prob_output_layer.classification 15 | 16 | .. _prob_output_layer_regression: 17 | 18 | ..
automodule:: fortuna.prob_output_layer.regression 19 | -------------------------------------------------------------------------------- /docs/source/references/references.rst: -------------------------------------------------------------------------------- 1 | API References 2 | =================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :caption: Contents: 7 | 8 | prob_model/prob_model 9 | output_calib_model/output_calib_model 10 | likelihood 11 | model/model 12 | output_calibrator 13 | prob_output_layer 14 | conformal 15 | ood_classifier 16 | data_loader 17 | metric 18 | utils 19 | plot 20 | sagemaker 21 | typing 22 | -------------------------------------------------------------------------------- /docs/source/references/typing.rst: -------------------------------------------------------------------------------- 1 | Typing 2 | =========== 3 | This section contains Fortuna's custom typings. 4 | 5 | .. autoclass:: fortuna.typing.Array 6 | .. autoclass:: fortuna.typing.Batch 7 | .. autoclass:: fortuna.typing.Params 8 | .. autoclass:: fortuna.typing.Mutable 9 | .. autoclass:: fortuna.typing.Path 10 | .. autoclass:: fortuna.typing.OptaxOptimizer 11 | .. autoclass:: fortuna.typing.CalibParams 12 | .. autoclass:: fortuna.typing.CalibMutable 13 | .. autoclass:: fortuna.typing.Status 14 | -------------------------------------------------------------------------------- /docs/source/references/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ========== 3 | The probabilistic model automatically handles random number generator updates via a 4 | :ref:`random number generator <random_number_generator>` object. 5 | Please find its reference below. 6 | 7 | .. _random_number_generator: 8 | 9 | .. autoclass:: fortuna.utils.random.RandomNumberGenerator 10 | -------------------------------------------------------------------------------- /docs/source/usage_modes/usage_modes.rst: -------------------------------------------------------------------------------- 1 | .. _usage_modes: 2 | 3 | Usage modes 4 | ###################### 5 | This section explains some of Fortuna's main usage modes. 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | :caption: CONTENTS: 10 | 11 | uncertainty_estimates 12 | model_outputs 13 | flax_models 14 | -------------------------------------------------------------------------------- /examples/index.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | In this section we show some examples of how to use Fortuna in classification and regression tasks. 4 | 5 | 6 | ..
toctree:: 7 | :glob: 8 | :maxdepth: 2 9 | :caption: Examples: 10 | 11 | adaptive_conformal_inference 12 | bring_in_your_own 13 | enbpi_ts_regression 14 | jackknifeplus_regression 15 | mnist_classification 16 | multivalid_coverage 17 | sinusoidal_regression 18 | two_moons_classification 19 | two_moons_classification_ood 20 | subnet_calibration 21 | scaling_up_bayesian_inference 22 | mnist_classification_sghmc 23 | sgmcmc_diagnostics 24 | sentiment_analysis 25 | -------------------------------------------------------------------------------- /fortuna/__init__.py: -------------------------------------------------------------------------------- 1 | import flax 2 | 3 | flax.config.update("flax_use_orbax_checkpointing", False) 4 | -------------------------------------------------------------------------------- /fortuna/calib_model/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.calib_model.classification import CalibClassifier 2 | from fortuna.calib_model.config.base import Config 3 | from fortuna.calib_model.config.checkpointer import Checkpointer 4 | from fortuna.calib_model.config.hyperparameters import Hyperparameters 5 | from fortuna.calib_model.config.monitor import Monitor 6 | from fortuna.calib_model.config.optimizer import Optimizer 7 | from fortuna.calib_model.config.processor import Processor 8 | from fortuna.calib_model.regression import CalibRegressor 9 | -------------------------------------------------------------------------------- /fortuna/calib_model/calib_mixin.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from flax.training import checkpoints 5 | 6 | from fortuna.calib_model.state import CalibState 7 | from fortuna.training.mixin import WithCheckpointingMixin 8 | from fortuna.typing import ( 9 | OptaxOptimizer, 10 | Path, 11 | ) 12 | 13 | 14 | class WithCalibCheckpointingMixin(WithCheckpointingMixin): 15 | def restore_checkpoint( 16 | self, 17 | restore_checkpoint_path: Path, 18 | optimizer: Optional[OptaxOptimizer] = None, 19 | prefix: str = "checkpoint_", 20 | **kwargs, 21 | ) -> CalibState: 22 | if not os.path.isdir(restore_checkpoint_path) and not os.path.isfile( 23 | restore_checkpoint_path 24 | ): 25 | raise ValueError( 26 | f"`restore_checkpoint_path={restore_checkpoint_path}` was not found." 27 | ) 28 | d = checkpoints.restore_checkpoint( 29 | ckpt_dir=str(restore_checkpoint_path), 30 | target=None, 31 | step=None, 32 | prefix=prefix, 33 | parallel=True, 34 | ) 35 | if d is None: 36 | raise ValueError( 37 | f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`." 
38 | ) 39 | 40 | return CalibState.init_from_dict(d, optimizer, **kwargs) 41 | -------------------------------------------------------------------------------- /fortuna/calib_model/calib_state_repository.py: -------------------------------------------------------------------------------- 1 | from fortuna.calib_model.calib_mixin import WithCalibCheckpointingMixin 2 | from fortuna.training.train_state_repository import TrainStateRepository 3 | 4 | 5 | class CalibStateRepository(WithCalibCheckpointingMixin, TrainStateRepository): 6 | pass 7 | -------------------------------------------------------------------------------- /fortuna/calib_model/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/calib_model/config/__init__.py -------------------------------------------------------------------------------- /fortuna/calib_model/config/base.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | List, 3 | Optional, 4 | ) 5 | 6 | from fortuna.calib_model.config.callback import Callback 7 | from fortuna.calib_model.config.checkpointer import Checkpointer 8 | from fortuna.calib_model.config.hyperparameters import Hyperparameters 9 | from fortuna.calib_model.config.monitor import Monitor 10 | from fortuna.calib_model.config.optimizer import Optimizer 11 | from fortuna.calib_model.config.processor import Processor 12 | 13 | 14 | class Config: 15 | def __init__( 16 | self, 17 | optimizer: Optimizer = Optimizer(), 18 | checkpointer: Checkpointer = Checkpointer(), 19 | monitor: Monitor = Monitor(), 20 | processor: Processor = Processor(), 21 | hyperparameters: Hyperparameters = Hyperparameters(), 22 | callbacks: Optional[List[Callback]] = None, 23 | ): 24 | """ 25 | Configure the calibration process. 26 | 27 | Parameters 28 | ---------- 29 | optimizer: Optimizer 30 | It defines the optimization specifics. 31 | checkpointer: Checkpointer 32 | It handles saving and restoring checkpoints. 33 | monitor: Monitor 34 | It monitors training progress and may trigger early stopping. 35 | processor: Processor 36 | It defines how and where the computation is processed. 37 | hyperparameters: Hyperparameters 38 | It defines other hyperparameters that may be needed during the model's training. 39 | callbacks: Optional[List[Callback]] 40 | A list of user-defined callbacks to be called during training. 41 | Callbacks run sequentially in the order defined by the user.
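Examples
--------
A minimal configuration sketch; the optimizer settings and checkpoint
directory below are illustrative choices, not defaults.

.. code-block:: python

    import optax

    from fortuna.calib_model import Checkpointer, Config, Optimizer

    config = Config(
        optimizer=Optimizer(method=optax.adam(1e-3), n_epochs=50),
        checkpointer=Checkpointer(save_checkpoint_dir="./ckpts"),
    )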
42 | """ 43 | self.optimizer = optimizer 44 | self.checkpointer = checkpointer 45 | self.monitor = monitor 46 | self.processor = processor 47 | self.hyperparameters = hyperparameters 48 | self.callbacks = callbacks 49 | -------------------------------------------------------------------------------- /fortuna/calib_model/config/callback.py: -------------------------------------------------------------------------------- 1 | from fortuna.training.callback import Callback 2 | -------------------------------------------------------------------------------- /fortuna/calib_model/config/checkpointer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fortuna.typing import Path 4 | 5 | 6 | class Checkpointer: 7 | def __init__( 8 | self, 9 | save_checkpoint_dir: Optional[Path] = None, 10 | restore_checkpoint_path: Optional[Path] = None, 11 | start_from_current_state: bool = False, 12 | save_every_n_steps: Optional[int] = None, 13 | keep_top_n_checkpoints: Optional[int] = 2, 14 | dump_state: bool = False, 15 | ): 16 | """ 17 | An object to configure saving and restoring of checkpoints during the calibration process. 18 | 19 | Parameters 20 | ---------- 21 | save_checkpoint_dir: Optional[Path] = None 22 | Save directory location. 23 | restore_checkpoint_path: Optional[Path] 24 | Path to checkpoint file or directory to restore. 25 | start_from_current_state: bool = False 26 | If True, the optimization will start from the current state. If `restore_checkpoint_path` is given, then 27 | `start_from_current_state` is ignored. 28 | save_every_n_steps: int 29 | Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no 30 | checkpoint will be saved during training). 31 | keep_top_n_checkpoints: int 32 | Number of past checkpoint files to keep. 33 | dump_state: bool 34 | Dump the fitted calibration state as a checkpoint in `save_checkpoint_dir`. 35 | Any future call to the state will internally involve restoring it from memory. 36 | """ 37 | self.save_checkpoint_dir = save_checkpoint_dir 38 | self.save_every_n_steps = save_every_n_steps 39 | self.restore_checkpoint_path = restore_checkpoint_path 40 | self.start_from_current_state = start_from_current_state 41 | self.keep_top_n_checkpoints = keep_top_n_checkpoints 42 | self.dump_state = dump_state 43 | -------------------------------------------------------------------------------- /fortuna/calib_model/config/hyperparameters.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class Hyperparameters: 5 | def __init__( 6 | self, 7 | max_grad_norm: Optional[float] = None, 8 | gradient_accumulation_steps: Optional[int] = None, 9 | ): 10 | """ 11 | An object to configure additional arguments that may be needed during the posterior fitting. 12 | 13 | Parameters 14 | ---------- 15 | max_grad_norm: Optional[Path] 16 | Maximum gradient norm. If `max_grad_norm > 0`, gradient clipping is performed. 17 | gradient_accumulation_steps: Optional[Path] 18 | Number of forward passes to perform before doing a backward pass. 
19 | """ 20 | self.max_grad_norm = max_grad_norm 21 | self.gradient_accumulation_steps = gradient_accumulation_steps 22 | -------------------------------------------------------------------------------- /fortuna/calib_model/config/optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Callable, 3 | Optional, 4 | Tuple, 5 | ) 6 | 7 | import optax 8 | 9 | from fortuna.typing import ( 10 | AnyKey, 11 | Array, 12 | OptaxOptimizer, 13 | ) 14 | 15 | 16 | class Optimizer: 17 | def __init__( 18 | self, 19 | method: Optional[OptaxOptimizer] = optax.adam(1e-2), 20 | n_epochs: int = 100, 21 | freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None, 22 | ): 23 | """ 24 | An object to configure the optimization in the calibration process. 25 | 26 | Parameters 27 | ---------- 28 | method: OptaxOptimizer 29 | An Optax optimizer. 30 | n_epochs: int 31 | Maximum number of epochs to run the calibration for. 32 | freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] 33 | A callable taking in input a path in the nested dictionary of parameters, as well as the corresponding 34 | array of parameters, and returns "trainable" or "freeze", according to whether the corresponding parameter 35 | should be optimized or not. 36 | 37 | Examples 38 | -------- 39 | .. code-block:: python 40 | def freeze_fun(path: Tuple[str], v: Array) -> str: 41 | path = [p[:6] for p in path] # take only the first 6 characters of each key" 42 | return 'trainable' if "Dense" in path else 'frozen'` 43 | """ 44 | self.method = method 45 | self.n_epochs = n_epochs 46 | self.freeze_fun = freeze_fun 47 | -------------------------------------------------------------------------------- /fortuna/calib_model/config/processor.py: -------------------------------------------------------------------------------- 1 | class Processor: 2 | def __init__( 3 | self, 4 | devices: int = -1, 5 | disable_jit: bool = False, 6 | ): 7 | """ 8 | An object to configure computational aspects of the calibration process. 9 | 10 | Parameters 11 | ---------- 12 | devices: int 13 | A list of devices to be used during training. 14 | At the moment two options are supported: use all devices (`devices=-1`) or use no device (`devices=0`). 15 | disable_jit: bool 16 | if True, no function within the calibration loop is jitted. 
17 | """ 18 | self.devices = devices 19 | self.disable_jit = disable_jit 20 | -------------------------------------------------------------------------------- /fortuna/calib_model/predictive/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/calib_model/predictive/__init__.py -------------------------------------------------------------------------------- /fortuna/calib_model/predictive/classification.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from fortuna.calib_model.predictive.base import Predictive 4 | from fortuna.data.loader import InputsLoader 5 | from fortuna.likelihood.classification import ClassificationLikelihood 6 | 7 | 8 | class ClassificationPredictive(Predictive): 9 | def __init__(self, likelihood: ClassificationLikelihood): 10 | super().__init__(likelihood=likelihood) 11 | 12 | def entropy( 13 | self, 14 | inputs_loader: InputsLoader, 15 | distribute: bool = True, 16 | ) -> jnp.ndarray: 17 | r""" 18 | Estimate the predictive entropy, that is 19 | 20 | .. math:: 21 | -\mathbb{E}_{Y|x, \mathcal{D}}[\log p(Y|x, \mathcal{D})], 22 | 23 | where: 24 | - :math:`x` is an observed input variable; 25 | - :math:`Y` is a random target variable; 26 | - :math:`\mathcal{D}` is the observed training data set; 27 | - :math:`W` denotes the random model parameters. 28 | 29 | Parameters 30 | ---------- 31 | inputs_loader : InputsLoader 32 | A loader of input data points. 33 | distribute: bool 34 | Whether to distribute computation over multiple devices, if available. 35 | 36 | Returns 37 | ------- 38 | jnp.ndarray 39 | An estimate of the predictive entropy for each input. 
40 | """ 41 | state = self.state.get() 42 | return self.likelihood.entropy( 43 | params=state.params, 44 | inputs_loader=inputs_loader, 45 | mutable=state.mutable, 46 | distribute=distribute, 47 | ) 48 | -------------------------------------------------------------------------------- /fortuna/calibration/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.calibration.binary_classification.temp_scaling.bias_binary_temp_scaling import ( 2 | BiasBinaryClassificationTemperatureScaling, 3 | ) 4 | from fortuna.calibration.binary_classification.temp_scaling.brier_binary_temp_scaling import ( 5 | BrierBinaryClassificationTemperatureScaling, 6 | ) 7 | from fortuna.calibration.binary_classification.temp_scaling.crossentropy_binary_temp_scaling import ( 8 | CrossEntropyBinaryClassificationTemperatureScaling, 9 | ) 10 | from fortuna.calibration.binary_classification.temp_scaling.f1_temp_scaling import ( 11 | F1BinaryClassificationTemperatureScaling, 12 | ) 13 | from fortuna.calibration.classification.temp_scaling.base import ( 14 | ClassificationTemperatureScaling, 15 | ) 16 | -------------------------------------------------------------------------------- /fortuna/calibration/binary_classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/calibration/binary_classification/__init__.py -------------------------------------------------------------------------------- /fortuna/calibration/binary_classification/temp_scaling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/calibration/binary_classification/temp_scaling/__init__.py -------------------------------------------------------------------------------- /fortuna/calibration/binary_classification/temp_scaling/bias_binary_temp_scaling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from fortuna.calibration.binary_classification.temp_scaling.base import ( 4 | BaseBinaryClassificationTemperatureScaling, 5 | ) 6 | 7 | 8 | class BiasBinaryClassificationTemperatureScaling( 9 | BaseBinaryClassificationTemperatureScaling 10 | ): 11 | """ 12 | A temperature scaling class for binary classification. 13 | It scales the probability that the target variables is positive with a single learnable parameters. 14 | The method minimizes the expected bias. 15 | """ 16 | 17 | def fit(self, probs: np.ndarray, targets: np.ndarray): 18 | self._check_probs(probs) 19 | self._check_targets(targets) 20 | self._temperature = np.mean(probs) / np.mean(targets) 21 | -------------------------------------------------------------------------------- /fortuna/calibration/binary_classification/temp_scaling/brier_binary_temp_scaling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from fortuna.calibration.binary_classification.temp_scaling.base import ( 4 | BaseBinaryClassificationTemperatureScaling, 5 | ) 6 | 7 | 8 | class BrierBinaryClassificationTemperatureScaling( 9 | BaseBinaryClassificationTemperatureScaling 10 | ): 11 | """ 12 | A temperature scaling class for binary classification. 13 | It scales the probability that the target variables is positive with a single learnable parameters. 
14 | The method attempts to minimize the MSE, or Brier score. 15 | """ 16 | 17 | def fit(self, probs: np.ndarray, targets: np.ndarray): 18 | self._check_probs(probs) 19 | self._check_targets(targets) 20 | self._temperature = np.mean(probs**2) / np.mean(probs * targets) 21 | -------------------------------------------------------------------------------- /fortuna/calibration/binary_classification/temp_scaling/crossentropy_binary_temp_scaling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import brute 3 | 4 | from fortuna.calibration.binary_classification.temp_scaling.base import ( 5 | BaseBinaryClassificationTemperatureScaling, 6 | ) 7 | 8 | 9 | class CrossEntropyBinaryClassificationTemperatureScaling( 10 | BaseBinaryClassificationTemperatureScaling 11 | ): 12 | """ 13 | A temperature scaling class for binary classification. 14 | It scales the probability that the target variable is positive with a single learnable parameter. 15 | The method minimizes the binary cross-entropy loss. 16 | """ 17 | 18 | def fit(self, probs: np.ndarray, targets: np.ndarray): 19 | self._check_probs(probs) 20 | self._check_targets(targets) 21 | 22 | def temp_scaling_fn(tau): 23 | temp_probs = np.clip(probs / tau, 1e-9, 1 - 1e-9) 24 | return -np.mean( 25 | targets * np.log(temp_probs) + (1 - targets) * np.log(1 - temp_probs) 26 | ) 27 | 28 | self._temperature = brute( 29 | temp_scaling_fn, ranges=[(np.min(probs), 10)], Ns=1000 30 | )[0] 31 | -------------------------------------------------------------------------------- /fortuna/calibration/binary_classification/temp_scaling/f1_temp_scaling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import brute 3 | 4 | from fortuna.calibration.binary_classification.temp_scaling.base import ( 5 | BaseBinaryClassificationTemperatureScaling, 6 | ) 7 | 8 | 9 | class F1BinaryClassificationTemperatureScaling( 10 | BaseBinaryClassificationTemperatureScaling 11 | ): 12 | """ 13 | A temperature scaling class for binary classification. 14 | It scales the probability that the target variable is positive with a single learnable parameter. 15 | The method attempts to maximize the F1 score.
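Examples
--------
An illustrative sketch on synthetic data; all values are made up:

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    probs = rng.uniform(size=1000)
    targets = (rng.uniform(size=1000) < probs).astype(int)

    scaler = F1BinaryClassificationTemperatureScaling()
    scaler.fit(probs=probs, targets=targets, threshold=0.5)
    preds = scaler.predict(probs=probs)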
16 | """ 17 | 18 | def __init__(self): 19 | super().__init__() 20 | self._threshold = None 21 | self._temperature = None 22 | 23 | def fit(self, probs: np.ndarray, targets: np.ndarray, threshold: float): 24 | self._check_probs(probs) 25 | self._check_targets(targets) 26 | 27 | self._threshold = threshold 28 | n_pos_targets = np.sum(targets) 29 | 30 | def loss_fn(tau): 31 | temp_preds = probs >= threshold * tau 32 | n_pos_preds = np.sum(temp_preds) 33 | n_joint = np.sum(targets * temp_preds) 34 | prec = n_joint / n_pos_preds if n_pos_preds > 0 else 0.0 35 | rec = n_joint / n_pos_targets 36 | if prec + rec == 0.0: 37 | return 0.0 38 | return -2 * prec * rec / (prec + rec) 39 | 40 | self._temperature = brute( 41 | loss_fn, ranges=[(np.min(probs), 1 / threshold)], Ns=1000 42 | )[0] 43 | 44 | def predict(self, probs: np.ndarray): 45 | self._check_probs(probs) 46 | return (self.predict_proba(probs) >= self._threshold).astype(int) 47 | 48 | @property 49 | def threshold(self): 50 | return self._threshold 51 | -------------------------------------------------------------------------------- /fortuna/calibration/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/calibration/classification/__init__.py -------------------------------------------------------------------------------- /fortuna/calibration/classification/temp_scaling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/calibration/classification/temp_scaling/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.conformal.classification.adaptive_conformal_classifier import ( 2 | AdaptiveConformalClassifier, 3 | ) 4 | from fortuna.conformal.classification.adaptive_prediction import ( 5 | AdaptivePredictionConformalClassifier, 6 | CVPlusAdaptivePredictionConformalClassifier, 7 | ) 8 | from fortuna.conformal.classification.maxcovfixprec_binary_classfication import ( 9 | MaxCoverageFixedPrecisionBinaryClassificationCalibrator, 10 | ) 11 | from fortuna.conformal.classification.simple_prediction import ( 12 | CVPlusSimplePredictionConformalClassifier, 13 | SimplePredictionConformalClassifier, 14 | ) 15 | from fortuna.conformal.multivalid.iterative.classification.binary_multicalibrator import ( 16 | BinaryClassificationMulticalibrator, 17 | ) 18 | from fortuna.conformal.multivalid.iterative.classification.top_label_multicalibrator import ( 19 | TopLabelMulticalibrator, 20 | ) 21 | from fortuna.conformal.multivalid.iterative.multicalibrator import Multicalibrator 22 | from fortuna.conformal.multivalid.iterative.regression.batch_mvp import ( 23 | BatchMVPConformalClassifier, 24 | ) 25 | from fortuna.conformal.multivalid.one_shot.classification.binary_multicalibrator import ( 26 | OneShotBinaryClassificationMulticalibrator, 27 | ) 28 | from fortuna.conformal.multivalid.one_shot.classification.top_label_multicalibrator import ( 29 | OneShotTopLabelMulticalibrator, 30 | ) 31 | from fortuna.conformal.multivalid.one_shot.multicalibrator import OneShotMulticalibrator 32 | from fortuna.conformal.regression.adaptive_conformal_regressor import ( 33 | AdaptiveConformalRegressor, 34 | ) 35 | from 
fortuna.conformal.regression.batch_mvp import BatchMVPConformalRegressor 36 | from fortuna.conformal.regression.cvplus import CVPlusConformalRegressor 37 | from fortuna.conformal.regression.enbpi import EnbPI 38 | from fortuna.conformal.regression.jackknife_minmax import ( 39 | JackknifeMinmaxConformalRegressor, 40 | ) 41 | from fortuna.conformal.regression.jackknifeplus import JackknifePlusConformalRegressor 42 | from fortuna.conformal.regression.onedim_uncertainty import ( 43 | OneDimensionalUncertaintyConformalRegressor, 44 | ) 45 | from fortuna.conformal.regression.quantile import QuantileConformalRegressor 46 | -------------------------------------------------------------------------------- /fortuna/conformal/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/classification/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/classification/adaptive_prediction.py: -------------------------------------------------------------------------------- 1 | from jax import vmap 2 | import jax.numpy as jnp 3 | 4 | from fortuna.conformal.classification.base import ( 5 | CVPlusConformalClassifier, 6 | SplitConformalClassifier, 7 | ) 8 | from fortuna.typing import Array 9 | 10 | 11 | @vmap 12 | def _score_fn(probs: Array, perm: Array, inv_perm: Array, targets: Array): 13 | return jnp.cumsum(probs[perm])[inv_perm[targets]] 14 | 15 | 16 | def score_fn( 17 | probs: Array, 18 | targets: Array, 19 | ): 20 | perms = jnp.argsort(probs, axis=1)[:, ::-1] 21 | inv_perms = jnp.argsort(perms, axis=1) 22 | return _score_fn(probs, perms, inv_perms, targets) 23 | 24 | 25 | class AdaptivePredictionConformalClassifier(SplitConformalClassifier): 26 | def score_fn( 27 | self, 28 | probs: Array, 29 | targets: Array, 30 | ): 31 | return score_fn(probs=probs, targets=targets) 32 | 33 | 34 | class CVPlusAdaptivePredictionConformalClassifier(CVPlusConformalClassifier): 35 | def score_fn( 36 | self, 37 | probs: Array, 38 | targets: Array, 39 | ): 40 | return score_fn(probs=probs, targets=targets) 41 | -------------------------------------------------------------------------------- /fortuna/conformal/classification/simple_prediction.py: -------------------------------------------------------------------------------- 1 | from jax import vmap 2 | 3 | from fortuna.conformal.classification.base import ( 4 | CVPlusConformalClassifier, 5 | SplitConformalClassifier, 6 | ) 7 | from fortuna.typing import Array 8 | 9 | 10 | @vmap 11 | def _score_fn(probs: Array, target: Array): 12 | return 1 - probs[target] 13 | 14 | 15 | def score_fn( 16 | probs: Array, 17 | targets: Array, 18 | ): 19 | return _score_fn(probs, targets) 20 | 21 | 22 | class SimplePredictionConformalClassifier(SplitConformalClassifier): 23 | def score_fn( 24 | self, 25 | probs: Array, 26 | targets: Array, 27 | ): 28 | return score_fn(probs=probs, targets=targets) 29 | 30 | 31 | class CVPlusSimplePredictionConformalClassifier(CVPlusConformalClassifier): 32 | def score_fn( 33 | self, 34 | probs: Array, 35 | targets: Array, 36 | ): 37 | return score_fn(probs=probs, targets=targets) 38 | -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/multivalid/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/iterative/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/multivalid/iterative/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/iterative/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/multivalid/iterative/classification/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/iterative/regression/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/multivalid/iterative/regression/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/mixins/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/multivalid/mixins/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/mixins/batchmvp.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jax.numpy as jnp 4 | 5 | from fortuna.typing import Array 6 | 7 | 8 | class BatchMVPMixin: 9 | def pinball_loss(self, values: Array, scores: Array, coverage: float) -> Array: 10 | """ 11 | The pinball loss between the model evaluations and the scores. 12 | 13 | Parameters 14 | ---------- 15 | values: Array 16 | The model evaluations. 17 | scores: Array 18 | The scores. 19 | coverage: float 20 | The target coverage. 21 | 22 | Returns 23 | ------- 24 | Array 25 | The pinball loss evaluation. 
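Examples
--------
A usage sketch; ``batchmvp`` is assumed to be an already-calibrated instance
of a class mixing in :class:`BatchMVPMixin`, such as
:class:`~fortuna.conformal.regression.batch_mvp.BatchMVPConformalRegressor`,
with ``values`` and ``scores`` the corresponding arrays.

.. code-block:: python

    loss = batchmvp.pinball_loss(values=values, scores=scores, coverage=0.95)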
26 | """ 27 | return self._loss_fn(values, scores, coverage=coverage) 28 | 29 | def _loss_fn( 30 | self, values: Array, scores: Array, coverage: Optional[float] = None 31 | ) -> Array: 32 | if scores.ndim == 2 and values.ndim == 1: 33 | scores = scores[:, 0] 34 | if coverage is None: 35 | coverage = self._coverage 36 | diff = scores - values 37 | return jnp.mean( 38 | diff * coverage * (scores > values) 39 | - diff * (1 - coverage) * (scores <= values) 40 | ) 41 | -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/mixins/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/multivalid/mixins/classification/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/mixins/classification/binary_multicalibrator.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jax.numpy as jnp 4 | 5 | from fortuna.conformal.multivalid.mixins.multicalibrator import MulticalibratorMixin 6 | from fortuna.typing import Array 7 | 8 | 9 | class BinaryClassificationMulticalibratorMixin(MulticalibratorMixin): 10 | def mean_squared_error(self, probs: Array, targets: Array) -> Array: 11 | return super().mean_squared_error(values=probs, scores=targets) 12 | 13 | @staticmethod 14 | def _maybe_check_values( 15 | values: Optional[Array], test_values: Optional[Array] = None 16 | ): 17 | if values is not None: 18 | if values.ndim != 1: 19 | raise ValueError( 20 | "`probs` must be a 1-dimensional array representing the probability that the " 21 | "target variable is 1." 22 | ) 23 | if jnp.any(values < 0) or jnp.any(values > 1): 24 | raise ValueError("All elements in `values` must be within [0, 1].") 25 | if test_values is not None: 26 | if test_values.ndim != 1: 27 | raise ValueError( 28 | "`test_probs` must be a 1-dimensional array representing the probability that the " 29 | "target variable is 1." 30 | ) 31 | if jnp.any(test_values < 0) or jnp.any(test_values > 1): 32 | raise ValueError("All elements in `test_values` must be within [0, 1].") 33 | 34 | @staticmethod 35 | def _check_scores(scores: Array): 36 | if scores.ndim != 1: 37 | raise ValueError("`targets` must be a 1-dimensional array of integers.") 38 | if set(jnp.unique(scores).tolist()) != {0, 1}: 39 | raise ValueError("All values in `targets` must be 0 or 1.") 40 | if scores.dtype not in ["int32", "int64"]: 41 | raise ValueError("All elements in `targets` must be integers") 42 | -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/mixins/multicalibrator.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from fortuna.typing import Array 4 | 5 | 6 | class MulticalibratorMixin: 7 | def mean_squared_error(self, values: Array, scores: Array) -> Array: 8 | """ 9 | The mean squared error between the model evaluations and the scores. 10 | This is supposed to decrease at every round of the algorithm. 11 | 12 | Parameters 13 | ---------- 14 | values: Array 15 | The model evaluations. 16 | scores: Array 17 | The scores. 18 | 19 | Returns 20 | ------- 21 | Array 22 | The mean-squared error. 
23 | """ 24 | return self._loss_fn(values, scores) 25 | 26 | @staticmethod 27 | def _loss_fn(values: Array, scores: Array) -> Array: 28 | if scores.ndim == 2 and values.ndim == 1: 29 | scores = scores[:, 0] 30 | return jnp.mean(jnp.sum((values - scores) ** 2, axis=-1)) 31 | return jnp.mean((values - scores) ** 2) 32 | -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/mixins/regression/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/multivalid/mixins/regression/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/one_shot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/multivalid/one_shot/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/one_shot/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/multivalid/one_shot/classification/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/one_shot/classification/binary_multicalibrator.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Optional, 3 | Union, 4 | ) 5 | 6 | from fortuna.conformal.multivalid.mixins.classification.binary_multicalibrator import ( 7 | BinaryClassificationMulticalibratorMixin, 8 | ) 9 | from fortuna.conformal.multivalid.one_shot.multicalibrator import OneShotMulticalibrator 10 | from fortuna.typing import Array 11 | 12 | 13 | class OneShotBinaryClassificationMulticalibrator( 14 | BinaryClassificationMulticalibratorMixin, OneShotMulticalibrator 15 | ): 16 | def __init__(self, seed: int = 0): 17 | super().__init__(seed=seed) 18 | 19 | def calibrate( 20 | self, 21 | targets: Array, 22 | probs: Optional[Array] = None, 23 | test_probs: Optional[Array] = None, 24 | n_buckets: int = 100, 25 | min_prob_b: Union[float, str] = "auto", 26 | ): 27 | return super().calibrate( 28 | scores=targets, 29 | values=probs, 30 | test_values=test_probs, 31 | n_buckets=n_buckets, 32 | min_prob_b=min_prob_b, 33 | ) 34 | 35 | def apply_patches( 36 | self, 37 | probs: Array, 38 | ) -> Array: 39 | return super().apply_patches(values=probs) 40 | -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/one_shot/classification/top_label_multicalibrator.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Optional, 3 | Union, 4 | ) 5 | 6 | import jax.numpy as jnp 7 | 8 | from fortuna.conformal.multivalid.mixins.classification.top_label_multicalibrator import ( 9 | TopLabelMulticalibratorMixin, 10 | ) 11 | from fortuna.conformal.multivalid.one_shot.multicalibrator import OneShotMulticalibrator 12 | from fortuna.typing import Array 13 | 14 | 15 | class OneShotTopLabelMulticalibrator( 16 | TopLabelMulticalibratorMixin, OneShotMulticalibrator 17 | ): 18 | def __init__(self, n_classes: int, seed: int = 0): 19 | super().__init__(n_classes=n_classes, 
seed=seed) 20 | 21 | def calibrate( 22 | self, 23 | targets: Array, 24 | probs: Optional[Array] = None, 25 | test_probs: Optional[Array] = None, 26 | n_buckets: int = 100, 27 | min_prob_b: Union[float, str] = "auto", 28 | ): 29 | return super().calibrate( 30 | scores=self._get_scores(targets), 31 | values=probs, 32 | test_values=test_probs, 33 | n_buckets=n_buckets, 34 | min_prob_b=min_prob_b, 35 | ) 36 | 37 | def apply_patches( 38 | self, 39 | probs: Array, 40 | ) -> Array: 41 | return super().apply_patches(values=probs) 42 | 43 | @staticmethod 44 | def _get_b( 45 | values: Array, 46 | v: Array, 47 | c: Optional[Array], 48 | n_buckets: int, 49 | ) -> Array: 50 | return (jnp.abs(values[:, c] - v) < 0.5 / n_buckets) * (values.argmax(1) == c) 51 | -------------------------------------------------------------------------------- /fortuna/conformal/multivalid/one_shot/multicalibrator.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from fortuna.conformal.multivalid.mixins.multicalibrator import MulticalibratorMixin 4 | from fortuna.conformal.multivalid.one_shot.base import OneShotMultivalidMethod 5 | from fortuna.typing import Array 6 | 7 | 8 | class OneShotMulticalibrator(MulticalibratorMixin, OneShotMultivalidMethod): 9 | def __init__(self, seed: int = 0): 10 | super().__init__(seed=seed) 11 | 12 | def _get_patch( 13 | self, v: Array, c: Array, scores: Array, values: Array, min_prob_b: float 14 | ) -> Array: 15 | return self._compute_expectation( 16 | v=v, c=c, scores=scores, values=values, min_prob_b=min_prob_b 17 | ) 18 | 19 | def _compute_expectation( 20 | self, v: Array, c: Array, scores: Array, values: Array, min_prob_b: float 21 | ): 22 | b = self._get_b(values=values, v=v, c=c, n_buckets=self.n_buckets) 23 | filtered_scores = scores * b 24 | prob_b = jnp.mean(b) 25 | mean = jnp.where(prob_b > min_prob_b, jnp.mean(filtered_scores) / prob_b, v) 26 | return mean 27 | -------------------------------------------------------------------------------- /fortuna/conformal/regression/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/conformal/regression/__init__.py -------------------------------------------------------------------------------- /fortuna/conformal/regression/base.py: -------------------------------------------------------------------------------- 1 | from fortuna.typing import Array 2 | 3 | 4 | class ConformalRegressor: 5 | """ 6 | A base conformal regressor class. 7 | """ 8 | 9 | def is_in(self, values: Array, conformal_intervals: Array) -> Array: 10 | """ 11 | Check whether the values lie within their respective conformal intervals. 12 | 13 | Parameters 14 | ---------- 15 | values: Array 16 | Values to check if they lie in the respective conformal intervals. 17 | conformal_intervals: Array 18 | A conformal interval for each input data point. 19 | 20 | Returns 21 | ------- 22 | Array 23 | An array of ones or zeros, indicating whether the values lie within their respective conformal intervals.
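Examples
--------
A self-contained sketch with made-up intervals:

.. code-block:: python

    import jax.numpy as jnp

    from fortuna.conformal.regression.base import ConformalRegressor

    conformal_intervals = jnp.array([[0.0, 1.0], [2.0, 3.0]])  # one [lo, hi] per input
    values = jnp.array([0.5, 5.0])  # the first lies inside, the second outside
    ConformalRegressor().is_in(values, conformal_intervals)  # -> [True, False]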
24 | """ 25 | return (values <= conformal_intervals[:, 1]) * ( 26 | values >= conformal_intervals[:, 0] 27 | ) 28 | -------------------------------------------------------------------------------- /fortuna/conformal/regression/batch_mvp.py: -------------------------------------------------------------------------------- 1 | from fortuna.conformal.multivalid.iterative.batch_mvp import BatchMVPConformalMethod 2 | from fortuna.conformal.regression.base import ConformalRegressor 3 | 4 | 5 | class BatchMVPConformalRegressor(BatchMVPConformalMethod, ConformalRegressor): 6 | def __init__(self, seed: int = 0): 7 | """ 8 | This class implements a classification version of BatchMVP 9 | `[Jung et al., 2022] `_, 10 | a multivalid conformal prediction method that satisfies coverage guarantees conditioned on group membership 11 | and non-conformity threshold. 12 | 13 | Parameters 14 | ---------- 15 | seed: int 16 | Random seed. 17 | """ 18 | super().__init__(seed=seed) 19 | -------------------------------------------------------------------------------- /fortuna/data/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.data.loader import ( 2 | BaseInputsLoader, 3 | BaseTargetsLoader, 4 | ConcatenatedLoader, 5 | DataLoader, 6 | DeviceDimensionAugmentedLoader, 7 | InputsLoader, 8 | TargetsLoader, 9 | ) 10 | -------------------------------------------------------------------------------- /fortuna/data/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/data/dataset/__init__.py -------------------------------------------------------------------------------- /fortuna/data/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.data.loader.array_loaders import ( 2 | DataLoader, 3 | InputsLoader, 4 | TargetsLoader, 5 | ) 6 | from fortuna.data.loader.base import ( 7 | BaseInputsLoader, 8 | BaseTargetsLoader, 9 | ConcatenatedLoader, 10 | DeviceDimensionAugmentedLoader, 11 | ) 12 | -------------------------------------------------------------------------------- /fortuna/distribution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/distribution/__init__.py -------------------------------------------------------------------------------- /fortuna/distribution/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from fortuna.utils.builtins import HashableMixin 4 | 5 | 6 | class Distribution(HashableMixin, abc.ABC): 7 | @abc.abstractmethod 8 | def sample(self, *args, **kwargs): 9 | pass 10 | 11 | def log_joint_prob(self, *args, **kwargs): 12 | pass 13 | -------------------------------------------------------------------------------- /fortuna/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG ACCOUNT_ID 2 | 3 | # Dockerfile for training models using JAX 4 | FROM $ACCOUNT_ID.dkr.ecr.us-east-1.amazonaws.com/fortuna:cuda-11.8.0-cudnn8-devel-ubuntu22.04 5 | 6 | # Install python3 7 | RUN apt update && apt install -y python3-pip 8 | 9 | RUN ln -sf /usr/bin/python3 /usr/bin/python && \ 10 | ln -sf /usr/bin/pip3 /usr/bin/pip 11 | 12 | RUN pip --no-cache-dir install --upgrade pip setuptools_rust 13 | 14 | # 
Install ML Packages built with CUDA11 support 15 | RUN ln -s /usr/lib/cuda /usr/local/cuda-11.8 16 | RUN pip --no-cache-dir install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 17 | RUN pip --no-cache-dir install "aws-fortuna[transformers]" 18 | RUN pip --no-cache-dir install sagemaker-training 19 | RUN pip --no-cache-dir install smdebug 20 | RUN pip --no-cache-dir install Jinja2 21 | 22 | # Setting some environment variables related to logging 23 | ENV PYTHONDONTWRITEBYTECODE=1 24 | ENV PYTHONUNBUFFERED=1 25 | -------------------------------------------------------------------------------- /fortuna/hallucination/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.hallucination.base import HallucinationMulticalibrator 2 | from fortuna.hallucination.grouping.clustering.base import GroupingModel 3 | from fortuna.hallucination.scoring.inv_perplexity import inv_perplexity 4 | -------------------------------------------------------------------------------- /fortuna/hallucination/grouping/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/hallucination/grouping/__init__.py -------------------------------------------------------------------------------- /fortuna/hallucination/grouping/clustering/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/hallucination/grouping/clustering/__init__.py -------------------------------------------------------------------------------- /fortuna/hallucination/scoring/inv_perplexity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import CrossEntropyLoss 3 | 4 | 5 | @torch.no_grad() 6 | def perplexity(logits: torch.Tensor, labels: torch.Tensor, n_final_tokens: int): 7 | loss_fct = CrossEntropyLoss(reduction="none") 8 | shift_logits = logits[..., :-1, :].contiguous() 9 | shift_labels = labels[..., 1:].contiguous() 10 | 11 | return torch.exp( 12 | loss_fct(shift_logits.transpose(1, 2), shift_labels)[:, -n_final_tokens:].mean( 13 | 1 14 | ) 15 | ) 16 | 17 | 18 | @torch.no_grad() 19 | def inv_perplexity(logits: torch.Tensor, labels: torch.Tensor, n_final_tokens: int): 20 | return 1 / perplexity(logits=logits, labels=labels, n_final_tokens=n_final_tokens) 21 | -------------------------------------------------------------------------------- /fortuna/hallucination/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.hallucination.utils.string import string_cleaner 2 | -------------------------------------------------------------------------------- /fortuna/hallucination/utils/string.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def string_cleaner(text: str) -> str: 5 | """ 6 | Clean a string of text. Remove possible spaces before punctuation and format for proper capitalization. 7 | 8 | Parameters 9 | ---------- 10 | text: str 11 | A string of text 12 | Returns 13 | ------- 14 | str 15 | Formatted string. 16 | """ 17 | text = re.sub(r'\s([?.,%!"](?:\s|$))', r"\1", text) 18 | 19 | text = ". ".join(map(lambda s: s.strip().capitalize(), text.split("."))) 20 | text = "? 
".join(map(lambda s: s.strip().capitalize(), text.split("?"))) 21 | text = "! ".join(map(lambda s: s.strip().capitalize(), text.split("!"))) 22 | text = " ' ".join(map(lambda s: s.strip(), text.split("'"))) 23 | text = text.replace(" ' ", "'") 24 | return text 25 | -------------------------------------------------------------------------------- /fortuna/kernel_regression/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.kernel_regression.nadaraya_watson import NadarayaWatsonKernelRegressor 2 | -------------------------------------------------------------------------------- /fortuna/kernel_regression/kernels/gaussian.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from fortuna.typing import Array 4 | 5 | 6 | def gaussian_kernel(x: Array) -> Array: 7 | return jnp.exp(-0.5 * x**2) / jnp.sqrt(2 * jnp.pi) 8 | -------------------------------------------------------------------------------- /fortuna/likelihood/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/likelihood/__init__.py -------------------------------------------------------------------------------- /fortuna/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/loss/__init__.py -------------------------------------------------------------------------------- /fortuna/loss/classification/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.loss.classification.cross_entropy import cross_entropy_loss_fn 2 | from fortuna.loss.classification.focal_loss import focal_loss_fn 3 | -------------------------------------------------------------------------------- /fortuna/loss/classification/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import jax.scipy as jsp 4 | 5 | from fortuna.typing import Array 6 | 7 | 8 | def cross_entropy_loss_fn( 9 | outputs: jnp.ndarray, 10 | targets: Array, 11 | ) -> jnp.ndarray: 12 | """ 13 | A cross-entropy loss function. Check `here `_ for reference. 14 | 15 | Parameters 16 | ---------- 17 | outputs: Array 18 | Model outputs to be passed to the loss. 19 | targets: Array 20 | Target data points. 21 | 22 | Returns 23 | ------- 24 | Tuple[jnp.ndarray, Any] 25 | The cross-entropy loss evaluation and auxiliary objects. 26 | """ 27 | targets = jax.nn.one_hot(targets, outputs.shape[-1]) 28 | return jnp.mean(jsp.special.logsumexp(outputs, -1) - jnp.sum(targets * outputs, -1)) 29 | -------------------------------------------------------------------------------- /fortuna/loss/classification/focal_loss.py: -------------------------------------------------------------------------------- 1 | from jax.nn import ( 2 | one_hot, 3 | softmax, 4 | ) 5 | import jax.numpy as jnp 6 | 7 | from fortuna.typing import Array 8 | 9 | 10 | def focal_loss_fn( 11 | outputs: Array, 12 | targets: Array, 13 | gamma: float = 2.0, 14 | ) -> jnp.ndarray: 15 | """ 16 | A focal loss function. See `[Mukhoti J. et a., 2020] `_ 17 | for reference. 18 | 19 | Parameters 20 | ---------- 21 | outputs: Array 22 | Model outputs to be passed to the loss. 
23 | targets: Array 24 | Target data points. 25 | gamma: float 26 | Hyper-parameter of the focal loss. 27 | 28 | Returns 29 | ------- 30 | jnp.ndarray 31 | The focal loss evaluation. 32 | """ 33 | probs = softmax(outputs, -1) 34 | targets = one_hot(targets, outputs.shape[-1]) 35 | probs = jnp.sum(probs * targets, -1) 36 | return -jnp.mean((1 - probs) ** gamma * jnp.log(probs)) 37 | -------------------------------------------------------------------------------- /fortuna/loss/regression/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.loss.regression.scaled_mse import scaled_mse_fn 2 | -------------------------------------------------------------------------------- /fortuna/loss/regression/scaled_mse.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from fortuna.typing import Array 4 | 5 | 6 | def scaled_mse_fn( 7 | outputs: Array, 8 | targets: Array, 9 | ) -> jnp.ndarray: 10 | """ 11 | Compute a variance-scaled mean-squared-error (MSE). 12 | 13 | Parameters 14 | ---------- 15 | outputs: Array 16 | Model outputs to be passed to the loss. 17 | targets: Array 18 | Target data points. 19 | 20 | Returns 21 | ------- 22 | jnp.ndarray 23 | The scaled MSE evaluation. 24 | """ 25 | means, log_vars = jnp.split(outputs, 2, axis=-1) 26 | return jnp.mean(jnp.sum(jnp.exp(-log_vars) * (targets - means) ** 2 + log_vars, -1)) 27 | -------------------------------------------------------------------------------- /fortuna/metric/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/metric/__init__.py -------------------------------------------------------------------------------- /fortuna/model/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.model.constant import ConstantModel 2 | from fortuna.model.hyper import HyperparameterModel 3 | from fortuna.model.lenet import LeNet5 4 | from fortuna.model.linear import LinearModel 5 | from fortuna.model.mlp import MLP 6 | from fortuna.model.resnet import ( 7 | ResNet18, 8 | ResNet34, 9 | ResNet50, 10 | ResNet101, 11 | ResNet152, 12 | ResNet200, 13 | ) 14 | from fortuna.model.scalar_constant import ScalarConstantModel 15 | from fortuna.model.scalar_hyper import ScalarHyperparameterModel 16 | from fortuna.model.wideresnet import WideResNet28_10 17 | -------------------------------------------------------------------------------- /fortuna/model/constant.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import flax.linen as nn 4 | from flax.linen.initializers import Initializer 5 | import jax.numpy as jnp 6 | 7 | from fortuna.typing import Array 8 | 9 | 10 | class ConstantModel(nn.Module): 11 | r""" 12 | A constant model, that is :math:`f(\theta, x) = \theta`. 13 | 14 | Parameters 15 | ---------- 16 | output_dim: int 17 | The output model dimension. 18 | initializer_fun: Optional[Initializer] 19 | Function to initialize the model parameters. 20 | This must be one of the available options in :code:`flax.linen.initializers`.
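Examples
--------
A minimal usage sketch; the batch and feature sizes below are illustrative assumptions.

>>> import jax
>>> import jax.numpy as jnp
>>> model = ConstantModel(output_dim=3)
>>> variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 5)))
>>> model.apply(variables, jnp.ones((2, 5))).shape
(2, 3)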
21 | """ 22 | 23 | output_dim: int 24 | initializer_fun: Optional[Initializer] = nn.initializers.zeros 25 | 26 | @nn.compact 27 | def __call__(self, x: Array, **kwargs) -> jnp.ndarray: 28 | constant = self.param("constant", self.initializer_fun, (self.output_dim,)) 29 | return jnp.broadcast_to(constant, shape=(x.shape[0], self.output_dim)) 30 | -------------------------------------------------------------------------------- /fortuna/model/hyper.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | from fortuna.typing import Array 5 | 6 | 7 | class HyperparameterModel(nn.Module): 8 | r""" 9 | A hyperparameter model. The value of the hyperparameter will not change during training. 10 | 11 | Parameters 12 | ---------- 13 | value: Union[float, Array] 14 | Value of the hyperparameter. 15 | """ 16 | 17 | value: Array 18 | 19 | def setup(self) -> None: 20 | if self.value.ndim != 1: 21 | raise ValueError( 22 | "`value` must be a one-dimensional array, with length equal to the output dimension of " 23 | "the model." 24 | ) 25 | dummy = self.param("none", nn.initializers.zeros, (0,)) 26 | 27 | def __call__(self, x: Array, **kwargs) -> jnp.ndarray: 28 | return jnp.broadcast_to(self.value, shape=(x.shape[0], len(self.value))) 29 | -------------------------------------------------------------------------------- /fortuna/model/linear.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | from fortuna.typing import Array 5 | 6 | 7 | class LinearModel(nn.Module): 8 | """ 9 | A linear model. 10 | 11 | Parameters 12 | ---------- 13 | output_dim: int 14 | The output model dimension. 
15 | """ 16 | 17 | output_dim: int 18 | 19 | @nn.compact 20 | def __call__(self, x: Array, **kwargs) -> jnp.ndarray: 21 | x = nn.Dense(self.output_dim, name="last")(x) 22 | return x 23 | -------------------------------------------------------------------------------- /fortuna/model/model_manager/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/model/model_manager/__init__.py -------------------------------------------------------------------------------- /fortuna/model/model_manager/name_to_model_manager.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from fortuna.model.model_manager.classification import ( 4 | ClassificationModelManager, 5 | SNGPClassificationModelManager, 6 | ) 7 | from fortuna.prob_model.posterior.deep_ensemble import DEEP_ENSEMBLE_NAME 8 | from fortuna.prob_model.posterior.laplace import LAPLACE_NAME 9 | from fortuna.prob_model.posterior.map import MAP_NAME 10 | from fortuna.prob_model.posterior.normalizing_flow.advi import ADVI_NAME 11 | from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld import CYCLICAL_SGLD_NAME 12 | from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME 13 | from fortuna.prob_model.posterior.sngp import SNGP_NAME 14 | from fortuna.prob_model.posterior.swag import SWAG_NAME 15 | 16 | 17 | class ClassificationModelManagers(enum.Enum): 18 | """Map approximator name to model manager classes""" 19 | 20 | vars()[MAP_NAME] = ClassificationModelManager 21 | vars()[ADVI_NAME] = ClassificationModelManager 22 | vars()[DEEP_ENSEMBLE_NAME] = ClassificationModelManager 23 | vars()[LAPLACE_NAME] = ClassificationModelManager 24 | vars()[SWAG_NAME] = ClassificationModelManager 25 | vars()[SNGP_NAME] = SNGPClassificationModelManager 26 | vars()[SGHMC_NAME] = ClassificationModelManager 27 | vars()[CYCLICAL_SGLD_NAME] = ClassificationModelManager 28 | -------------------------------------------------------------------------------- /fortuna/model/model_manager/state.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | Dict, 5 | Optional, 6 | Union, 7 | ) 8 | 9 | from flax.core import FrozenDict 10 | 11 | from fortuna.typing import ( 12 | Mutable, 13 | Params, 14 | ) 15 | 16 | 17 | class ModelManagerState: 18 | params: Params 19 | mutable: Optional[Mutable] = None 20 | 21 | def __init__(self, params: Params, mutable: Optional[Mutable] = None): 22 | """ 23 | A model manager state class. 24 | 25 | Parameters 26 | ---------- 27 | params : Params 28 | The random parameters of the probabilistic model. 29 | mutable : Optional[Mutable] 30 | The mutable objects used to evaluate the models. 31 | """ 32 | self.params = params 33 | self.mutable = mutable 34 | 35 | @classmethod 36 | def init_from_dict(cls, d: Union[Dict, FrozenDict]) -> ModelManagerState: 37 | """ 38 | Initialize the model manager state from a dictionary. This dictionary should be like the output of 39 | :func:`~fortuna.model.model_manager.base.ModelManager.init`. 40 | 41 | Parameters 42 | ---------- 43 | d : Union[Dict, FrozenDict] 44 | A dictionary like the output of :func:`~fortuna.model.model_manager.base.ModelManager.init`. 45 | 46 | Returns 47 | ------- 48 | ModelManagerState 49 | An model manager state. 
50 | """ 51 | params = FrozenDict( 52 | {k: FrozenDict({"params": v["params"]}) for k, v in d.items()} 53 | ) 54 | mutable = FrozenDict( 55 | { 56 | k: FrozenDict({_k: _v for _k, _v in v.items() if _k != "params"}) 57 | for k, v in d.items() 58 | } 59 | ) 60 | flag = 0 61 | for k, v in mutable.items(): 62 | if len(v) > 0: 63 | flag += 1 64 | if flag == 0: 65 | mutable = None 66 | return cls(params=params, mutable=mutable) 67 | -------------------------------------------------------------------------------- /fortuna/model/model_manager/transformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/model/model_manager/transformers/__init__.py -------------------------------------------------------------------------------- /fortuna/model/scalar_constant.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import flax.linen as nn 4 | from flax.linen.initializers import Initializer 5 | import jax.numpy as jnp 6 | 7 | from fortuna.typing import Array 8 | 9 | 10 | class ScalarConstantModel(nn.Module): 11 | r""" 12 | A scalar constant model, that is :math:`f(\theta, x) = \theta`, with :math:`\theta\in\mathbb{R}`. The scalar value 13 | will be broadcasted to the output dimension. 14 | 15 | Parameters 16 | ---------- 17 | output_dim: int 18 | The output model dimension. 19 | initializer_fun: Optional[Initializer] 20 | Function to initialize the model parameters. 21 | This must be one of the available options in :code:`flax.linen.initializers`. 22 | """ 23 | 24 | output_dim: int 25 | initializer_fun: Optional[Initializer] = nn.initializers.zeros 26 | 27 | @nn.compact 28 | def __call__(self, x: Array, **kwargs) -> jnp.ndarray: 29 | scalar = self.param("scalar", self.initializer_fun, (1,)) 30 | return jnp.broadcast_to(scalar, shape=(x.shape[0], self.output_dim)) 31 | -------------------------------------------------------------------------------- /fortuna/model/scalar_hyper.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | from fortuna.typing import Array 5 | 6 | 7 | class ScalarHyperparameterModel(nn.Module): 8 | r""" 9 | A scalar hyperparameter model. The scalar value of the hyperparameter will not change during training, and it will 10 | be broadcasted to the output dimension. 11 | 12 | Parameters 13 | ---------- 14 | output_dim: int 15 | The output model dimension. 16 | value: float 17 | Scalar value of the hyperparameter. 18 | """ 19 | 20 | output_dim: int 21 | value: float 22 | 23 | def setup(self) -> None: 24 | if type(self.value) != float: 25 | raise ValueError( 26 | f"`value` must be a float, but a {type(self.value)} was found instead." 
27 | ) 28 | dummy = self.param("none", nn.initializers.zeros, (0,)) 29 | 30 | def __call__(self, x: Array, **kwargs) -> jnp.ndarray: 31 | return jnp.broadcast_to(self.value, shape=(x.shape[0], self.output_dim)) 32 | -------------------------------------------------------------------------------- /fortuna/model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/model/utils/__init__.py -------------------------------------------------------------------------------- /fortuna/model_editor/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.model_editor.probit import ProbitModelEditor 2 | -------------------------------------------------------------------------------- /fortuna/model_editor/base.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Any, 3 | Callable, 4 | Dict, 5 | Tuple, 6 | Union, 7 | ) 8 | 9 | import flax.linen as nn 10 | import jax.numpy as jnp 11 | 12 | from fortuna.typing import InputData 13 | 14 | 15 | class ModelEditor(nn.Module): 16 | @nn.compact 17 | def __call__( 18 | self, 19 | apply_fn: Callable[ 20 | [Dict, InputData], Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]] 21 | ], 22 | model_params: Dict, 23 | x: Any, 24 | has_aux: bool, 25 | ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]: 26 | """ 27 | Apply a transformation to the forward pass. 28 | 29 | Parameters 30 | ---------- 31 | apply_fn: Callable[[Dict, InputData], Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]] 32 | The model forward pass. 33 | model_params: Dict 34 | The model parameters. 35 | x: Array 36 | Batch of inputs. 37 | has_aux: bool 38 | Whether the forward pass returns auxiliary objects. 39 | 40 | Returns 41 | ------- 42 | Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]] 43 | Return the transformed outputs, and auxiliary objects if available. 
44 | """ 45 | pass 46 | -------------------------------------------------------------------------------- /fortuna/model_editor/probit.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Any, 3 | Callable, 4 | Dict, 5 | Optional, 6 | Tuple, 7 | Union, 8 | ) 9 | 10 | import flax.linen as nn 11 | import jax.numpy as jnp 12 | 13 | from fortuna.model_editor.base import ModelEditor 14 | from fortuna.typing import ( 15 | AnyKey, 16 | Array, 17 | InputData, 18 | Params, 19 | ) 20 | from fortuna.utils.probit import sequential_probit_scaling 21 | 22 | 23 | class ProbitModelEditor(ModelEditor): 24 | freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None 25 | top_k: Optional[int] = None 26 | memory: Optional[int] = None 27 | n_final_tokens: Optional[int] = None 28 | init_log_var: float = -5.0 29 | stop_gradient: bool = False 30 | 31 | @nn.compact 32 | def __call__( 33 | self, 34 | apply_fn: Callable[ 35 | [Params, InputData], Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]] 36 | ], 37 | model_params: Params, 38 | x: Any, 39 | has_aux: bool, 40 | ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]: 41 | log_var = self.param( 42 | "log_var", nn.initializers.constant(self.init_log_var), (1,) 43 | ) 44 | outputs = sequential_probit_scaling( 45 | apply_fn, 46 | model_params, 47 | x, 48 | log_var=log_var, 49 | has_aux=has_aux, 50 | freeze_fun=self.freeze_fun, 51 | top_k=self.top_k, 52 | memory=self.memory, 53 | n_final_tokens=self.n_final_tokens, 54 | stop_gradient=self.stop_gradient, 55 | ) 56 | return outputs 57 | -------------------------------------------------------------------------------- /fortuna/ood_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.ood_detection.ddu import DeepDeterministicUncertaintyOODClassifier 2 | from fortuna.ood_detection.mahalanobis import MalahanobisOODClassifier 3 | -------------------------------------------------------------------------------- /fortuna/ood_detection/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from fortuna.typing import Array 4 | 5 | 6 | class NotFittedError(ValueError, AttributeError): 7 | """Exception class to raise if estimator is used before fitting.""" 8 | 9 | 10 | class OutOfDistributionClassifierABC: 11 | """ 12 | Post-training classifier that uses the training sample embeddings coming from the model 13 | to score a (new) test sample w.r.t. its chance of belonging to the original training distribution 14 | (i.e, it is in-distribution) or not (i.e., it is out of distribution). 15 | """ 16 | 17 | def __init__(self, num_classes: int): 18 | """ 19 | Parameters 20 | ---------- 21 | num_classes: int 22 | The number of classes for the in-distribution classification task. 
23 | """ 24 | self.num_classes = num_classes 25 | 26 | @abc.abstractmethod 27 | def fit(self, embeddings: Array, targets: Array) -> None: 28 | pass 29 | 30 | @abc.abstractmethod 31 | def score(self, embeddings: Array) -> Array: 32 | pass 33 | -------------------------------------------------------------------------------- /fortuna/output_calib_model/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.output_calib_model.classification import OutputCalibClassifier 2 | from fortuna.output_calib_model.config.base import Config 3 | from fortuna.output_calib_model.config.checkpointer import Checkpointer 4 | from fortuna.output_calib_model.config.monitor import Monitor 5 | from fortuna.output_calib_model.config.optimizer import Optimizer 6 | from fortuna.output_calib_model.config.processor import Processor 7 | from fortuna.output_calib_model.regression import OutputCalibRegressor 8 | -------------------------------------------------------------------------------- /fortuna/output_calib_model/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/output_calib_model/config/__init__.py -------------------------------------------------------------------------------- /fortuna/output_calib_model/config/base.py: -------------------------------------------------------------------------------- 1 | from fortuna.output_calib_model.config.checkpointer import Checkpointer 2 | from fortuna.output_calib_model.config.monitor import Monitor 3 | from fortuna.output_calib_model.config.optimizer import Optimizer 4 | from fortuna.output_calib_model.config.processor import Processor 5 | 6 | 7 | class Config: 8 | def __init__( 9 | self, 10 | optimizer: Optimizer = Optimizer(), 11 | checkpointer: Checkpointer = Checkpointer(), 12 | monitor: Monitor = Monitor(), 13 | processor: Processor = Processor(), 14 | ): 15 | """ 16 | Configure the calibration of the output calibration model. 17 | 18 | Parameters 19 | ---------- 20 | optimizer: Optimizer 21 | It defines the optimization specifics. 22 | checkpointer: Checkpointer 23 | It handles saving and restoring checkpoints. 24 | monitor: Monitor 25 | It monitors training progress and might induce early stopping. 26 | processor: Processor 27 | It processes where computation takes place. 28 | """ 29 | self.optimizer = optimizer 30 | self.checkpointer = checkpointer 31 | self.monitor = monitor 32 | self.processor = processor 33 | -------------------------------------------------------------------------------- /fortuna/output_calib_model/config/checkpointer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fortuna.typing import Path 4 | 5 | 6 | class Checkpointer: 7 | def __init__( 8 | self, 9 | save_checkpoint_dir: Optional[Path] = None, 10 | restore_checkpoint_path: Optional[Path] = None, 11 | save_every_n_steps: Optional[int] = None, 12 | keep_top_n_checkpoints: Optional[int] = 2, 13 | dump_state: bool = False, 14 | ): 15 | """ 16 | An object to configure saving and restoring of checkpoints during the calibration process. 17 | 18 | Parameters 19 | ---------- 20 | save_checkpoint_dir: Optional[Path] = None 21 | Save directory location. 22 | restore_checkpoint_path: Optional[Path] 23 | Path to checkpoint file or directory to restore. 
24 | save_every_n_steps: Optional[int] 25 | Number of training steps between checkpoints. To disable, set `save_every_n_steps` to None or 0 (no 26 | checkpoint will be saved during training). 27 | keep_top_n_checkpoints: Optional[int] 28 | Number of past checkpoint files to keep. 29 | dump_state: bool 30 | Dump the fitted calibration state as a checkpoint in `save_checkpoint_dir`. 31 | Any future call to the state will internally involve restoring it from memory. 32 | """ 33 | self.save_checkpoint_dir = save_checkpoint_dir 34 | self.save_every_n_steps = save_every_n_steps 35 | self.restore_checkpoint_path = restore_checkpoint_path 36 | self.keep_top_n_checkpoints = keep_top_n_checkpoints 37 | self.dump_state = dump_state 38 | -------------------------------------------------------------------------------- /fortuna/output_calib_model/config/optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import optax 4 | 5 | from fortuna.typing import OptaxOptimizer 6 | 7 | 8 | class Optimizer: 9 | def __init__( 10 | self, 11 | method: Optional[OptaxOptimizer] = optax.adam(1e-2), 12 | n_epochs: int = 100, 13 | ): 14 | """ 15 | An object to configure the optimization in the calibration process. 16 | 17 | Parameters 18 | ---------- 19 | method: OptaxOptimizer 20 | An Optax optimizer. 21 | n_epochs: int 22 | Maximum number of epochs to run the calibration for. 23 | """ 24 | self.method = method 25 | self.n_epochs = n_epochs 26 | -------------------------------------------------------------------------------- /fortuna/output_calib_model/config/processor.py: -------------------------------------------------------------------------------- 1 | class Processor: 2 | def __init__( 3 | self, 4 | devices: int = -1, 5 | disable_jit: bool = False, 6 | ): 7 | """ 8 | An object to configure computational aspects of the calibration process. 9 | 10 | Parameters 11 | ---------- 12 | devices: int 13 | An integer specifying the devices to be used during training. 14 | At the moment two options are supported: use all devices (`devices=-1`) or use no device (`devices=0`). 15 | disable_jit: bool 16 | If True, no function within the calibration loop is jitted. 17 | """ 18 | self.devices = devices 19 | self.disable_jit = disable_jit 20 | -------------------------------------------------------------------------------- /fortuna/output_calib_model/output_calib_mixin.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from flax.training import checkpoints 5 | 6 | from fortuna.output_calib_model.state import OutputCalibState 7 | from fortuna.training.mixin import WithCheckpointingMixin 8 | from fortuna.typing import ( 9 | OptaxOptimizer, 10 | Path, 11 | ) 12 | 13 | 14 | class WithOutputCalibCheckpointingMixin(WithCheckpointingMixin): 15 | def restore_checkpoint( 16 | self, 17 | restore_checkpoint_path: Path, 18 | optimizer: Optional[OptaxOptimizer] = None, 19 | prefix: str = "checkpoint_", 20 | **kwargs, 21 | ) -> OutputCalibState: 22 | if not os.path.isdir(restore_checkpoint_path) and not os.path.isfile( 23 | restore_checkpoint_path 24 | ): 25 | raise ValueError( 26 | f"`restore_checkpoint_path={restore_checkpoint_path}` was not found."
27 | ) 28 | d = checkpoints.restore_checkpoint( 29 | ckpt_dir=str(restore_checkpoint_path), 30 | target=None, 31 | step=None, 32 | prefix=prefix, 33 | parallel=True, 34 | ) 35 | if d is None: 36 | raise ValueError( 37 | f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`." 38 | ) 39 | 40 | return OutputCalibState.init_from_dict(d, optimizer, **kwargs) 41 | -------------------------------------------------------------------------------- /fortuna/output_calib_model/output_calib_state_repository.py: -------------------------------------------------------------------------------- 1 | from fortuna.output_calib_model.output_calib_mixin import ( 2 | WithOutputCalibCheckpointingMixin, 3 | ) 4 | from fortuna.training.train_state_repository import TrainStateRepository 5 | 6 | 7 | class OutputCalibStateRepository( 8 | WithOutputCalibCheckpointingMixin, TrainStateRepository 9 | ): 10 | pass 11 | -------------------------------------------------------------------------------- /fortuna/output_calib_model/predictive/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/output_calib_model/predictive/__init__.py -------------------------------------------------------------------------------- /fortuna/output_calibrator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/output_calibrator/__init__.py -------------------------------------------------------------------------------- /fortuna/output_calibrator/classification.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | from fortuna.typing import Array 5 | 6 | 7 | class ClassificationTemperatureScaler(nn.Module): 8 | r""" 9 | Classification temperature scaling. It scales the logits with a scalar temperature parameter. Let :math:`o` be 10 | output logits and :math:`\phi` be a scalar parameter. Then the scaling can be seen as 11 | :math:`g(\phi, o) = \exp(-\phi) o`. 12 | """ 13 | 14 | @nn.compact 15 | def __call__(self, x: Array, **kwargs) -> jnp.ndarray: 16 | log_temp = self.param("log_temp", nn.initializers.zeros, (1,)) 17 | return x * jnp.exp(-log_temp) 18 | -------------------------------------------------------------------------------- /fortuna/output_calibrator/output_calib_manager/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/output_calibrator/output_calib_manager/__init__.py -------------------------------------------------------------------------------- /fortuna/output_calibrator/output_calib_manager/state.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | Dict, 5 | Optional, 6 | Union, 7 | ) 8 | 9 | from flax.core import FrozenDict 10 | 11 | from fortuna.typing import ( 12 | CalibMutable, 13 | CalibParams, 14 | ) 15 | 16 | 17 | class OutputCalibManagerState: 18 | params: CalibParams 19 | mutable: Optional[CalibMutable] = None 20 | 21 | def __init__(self, params: CalibParams, mutable: Optional[CalibMutable] = None): 22 | """ 23 | An output calibration manager state class.
24 | 25 | Parameters 26 | ---------- 27 | params : CalibParams 28 | The parameters of the output calibrators. 29 | mutable : Optional[CalibMutable] 30 | The mutable objects used to evaluate the output calibrators. 31 | """ 32 | self.params = params 33 | self.mutable = mutable 34 | 35 | @classmethod 36 | def init_from_dict(cls, d: Union[Dict, FrozenDict]) -> OutputCalibManagerState: 37 | """ 38 | Initialize an output calibration manager state from a dictionary. 39 | 40 | Parameters 41 | ---------- 42 | d : Union[Dict, FrozenDict] 43 | A dictionary whose keys are the calibrators and whose values are their initializations. 44 | 45 | Returns 46 | ------- 47 | OutputCalibManagerState 48 | An output calibration manager state. 49 | """ 50 | params = FrozenDict( 51 | { 52 | k: FrozenDict({"params": v["params"] if v else None}) 53 | for k, v in d.items() 54 | } 55 | ) 56 | mutable = {calib_name: {} for calib_name in d} 57 | for name, variables in d.items(): 58 | if variables: 59 | for var_name, var_obj in variables.items(): 60 | if var_name != "params": 61 | mutable[name][var_name] = var_obj 62 | mutable = FrozenDict( 63 | {name: v if len(v) > 0 else None for name, v in mutable.items()} 64 | ) 65 | return cls(params=params, mutable=mutable) 66 | -------------------------------------------------------------------------------- /fortuna/output_calibrator/regression.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | from fortuna.typing import Array 5 | 6 | 7 | class RegressionTemperatureScaler(nn.Module): 8 | r""" 9 | Regression temperature scaling. It multiplies the variance with a scalar temperature parameter. Let :math:`v` be 10 | the variance outputs and :math:`\phi` be a scalar parameter. Then the scaling can be seen as 11 | :math:`g(\phi, v) = \exp(\phi) v`.
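Examples
--------
A minimal usage sketch; the outputs are assumed to concatenate means and log-variances along the last axis, and the sizes below are illustrative.

>>> import jax
>>> import jax.numpy as jnp
>>> scaler = RegressionTemperatureScaler()
>>> variables = scaler.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
>>> scaled = scaler.apply(variables, jnp.ones((2, 4)))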
12 | """ 13 | 14 | @nn.compact 15 | def __call__(self, x: Array, **kwargs) -> jnp.ndarray: 16 | log_temp = self.param("log_temp", nn.initializers.zeros, (1,)) 17 | mean, log_var = jnp.split(x, 2, axis=-1) 18 | log_var += log_temp 19 | return jnp.concatenate((mean, log_var), axis=-1) 20 | -------------------------------------------------------------------------------- /fortuna/prob_model/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.calib_config.base import CalibConfig 2 | from fortuna.prob_model.calib_config.checkpointer import CalibCheckpointer 3 | from fortuna.prob_model.calib_config.monitor import CalibMonitor 4 | from fortuna.prob_model.calib_config.optimizer import CalibOptimizer 5 | from fortuna.prob_model.calib_config.processor import CalibProcessor 6 | from fortuna.prob_model.classification import ProbClassifier 7 | from fortuna.prob_model.fit_config.base import FitConfig 8 | from fortuna.prob_model.fit_config.checkpointer import FitCheckpointer 9 | from fortuna.prob_model.fit_config.monitor import FitMonitor 10 | from fortuna.prob_model.fit_config.optimizer import FitOptimizer 11 | from fortuna.prob_model.fit_config.processor import FitProcessor 12 | from fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_approximator import ( 13 | DeepEnsemblePosteriorApproximator, 14 | ) 15 | from fortuna.prob_model.posterior.laplace.laplace_approximator import ( 16 | LaplacePosteriorApproximator, 17 | ) 18 | from fortuna.prob_model.posterior.map.map_posterior import MAPPosteriorApproximator 19 | from fortuna.prob_model.posterior.normalizing_flow.advi.advi_approximator import ( 20 | ADVIPosteriorApproximator, 21 | ) 22 | from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_approximator import ( 23 | CyclicalSGLDPosteriorApproximator, 24 | ) 25 | from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_approximator import ( 26 | SGHMCPosteriorApproximator, 27 | ) 28 | from fortuna.prob_model.posterior.sngp.sngp_approximator import ( 29 | SNGPPosteriorApproximator, 30 | ) 31 | from fortuna.prob_model.posterior.swag.swag_approximator import ( 32 | SWAGPosteriorApproximator, 33 | ) 34 | from fortuna.prob_model.regression import ProbRegressor 35 | -------------------------------------------------------------------------------- /fortuna/prob_model/calib_config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/prob_model/calib_config/__init__.py -------------------------------------------------------------------------------- /fortuna/prob_model/calib_config/base.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.calib_config.checkpointer import CalibCheckpointer 2 | from fortuna.prob_model.calib_config.monitor import CalibMonitor 3 | from fortuna.prob_model.calib_config.optimizer import CalibOptimizer 4 | from fortuna.prob_model.calib_config.processor import CalibProcessor 5 | 6 | 7 | class CalibConfig: 8 | def __init__( 9 | self, 10 | optimizer: CalibOptimizer = CalibOptimizer(), 11 | checkpointer: CalibCheckpointer = CalibCheckpointer(), 12 | monitor: CalibMonitor = CalibMonitor(), 13 | processor: CalibProcessor = CalibProcessor(), 14 | ): 15 | """ 16 | Configure the probabilistic model calibration. 
17 | 18 | Parameters 19 | ---------- 20 | optimizer: CalibOptimizer 21 | It defines the optimization specifics. 22 | checkpointer: CalibCheckpointer 23 | It handles saving and restoring checkpoints. 24 | monitor: CalibMonitor 25 | It monitors training progress and might induce early stopping. 26 | processor: CalibProcessor 27 | It defines where computation takes place. 28 | """ 29 | self.optimizer = optimizer 30 | self.checkpointer = checkpointer 31 | self.monitor = monitor 32 | self.processor = processor 33 | -------------------------------------------------------------------------------- /fortuna/prob_model/calib_config/checkpointer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fortuna.typing import Path 4 | 5 | 6 | class CalibCheckpointer: 7 | def __init__( 8 | self, 9 | save_checkpoint_dir: Optional[Path] = None, 10 | restore_checkpoint_path: Optional[Path] = None, 11 | save_every_n_steps: Optional[int] = None, 12 | keep_top_n_checkpoints: Optional[int] = 2, 13 | dump_state: bool = False, 14 | ): 15 | """ 16 | An object to configure saving and restoring of checkpoints during the calibration process. 17 | 18 | Parameters 19 | ---------- 20 | save_checkpoint_dir: Optional[Path] 21 | Save directory location. 22 | restore_checkpoint_path: Optional[Path] 23 | Path to checkpoint file or directory to restore. 24 | save_every_n_steps: Optional[int] 25 | Number of training steps between checkpoints. To disable, set `save_every_n_steps` to None or 0 (no 26 | checkpoint will be saved during training). 27 | keep_top_n_checkpoints: Optional[int] 28 | Number of past checkpoint files to keep. 29 | dump_state: bool 30 | Dump the fitted calibration state as a checkpoint in `save_checkpoint_dir`. 31 | Any future call to the state will internally involve restoring it from memory. 32 | """ 33 | self.save_checkpoint_dir = save_checkpoint_dir 34 | self.save_every_n_steps = save_every_n_steps 35 | self.restore_checkpoint_path = restore_checkpoint_path 36 | self.keep_top_n_checkpoints = keep_top_n_checkpoints 37 | self.dump_state = dump_state 38 | -------------------------------------------------------------------------------- /fortuna/prob_model/calib_config/optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import optax 4 | 5 | from fortuna.typing import OptaxOptimizer 6 | 7 | 8 | class CalibOptimizer: 9 | def __init__( 10 | self, 11 | method: Optional[OptaxOptimizer] = optax.adam(1e-2), 12 | n_epochs: int = 100, 13 | ): 14 | """ 15 | An object to configure the optimization in the calibration process. 16 | 17 | Parameters 18 | ---------- 19 | method: OptaxOptimizer 20 | An Optax optimizer. 21 | n_epochs: int 22 | Maximum number of epochs to run the calibration for. 23 | """ 24 | self.method = method 25 | self.n_epochs = n_epochs 26 | -------------------------------------------------------------------------------- /fortuna/prob_model/calib_config/processor.py: -------------------------------------------------------------------------------- 1 | class CalibProcessor: 2 | def __init__( 3 | self, 4 | devices: int = -1, 5 | disable_jit: bool = False, 6 | n_posterior_samples: int = 30, 7 | ): 8 | """ 9 | An object to configure computational aspects of the calibration process. 10 | 11 | Parameters 12 | ---------- 13 | devices: int 14 | An integer specifying the devices to be used during training.
15 | At the moment two options are supported: use all devices (`devices=-1`) or use no device (`devices=0`). 16 | disable_jit: bool 17 | If True, no function within the calibration loop is jitted. 18 | n_posterior_samples: int 19 | Number of posterior samples to draw from the posterior distribution for the calibration process. 20 | """ 21 | self.devices = devices 22 | self.disable_jit = disable_jit 23 | self.n_posterior_samples = n_posterior_samples 24 | -------------------------------------------------------------------------------- /fortuna/prob_model/fit_config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/prob_model/fit_config/__init__.py -------------------------------------------------------------------------------- /fortuna/prob_model/fit_config/base.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | List, 3 | Optional, 4 | ) 5 | 6 | from fortuna.prob_model.fit_config.callback import FitCallback 7 | from fortuna.prob_model.fit_config.checkpointer import FitCheckpointer 8 | from fortuna.prob_model.fit_config.hyperparameters import FitHyperparameters 9 | from fortuna.prob_model.fit_config.monitor import FitMonitor 10 | from fortuna.prob_model.fit_config.optimizer import FitOptimizer 11 | from fortuna.prob_model.fit_config.processor import FitProcessor 12 | 13 | 14 | class FitConfig: 15 | def __init__( 16 | self, 17 | optimizer: FitOptimizer = FitOptimizer(), 18 | checkpointer: FitCheckpointer = FitCheckpointer(), 19 | monitor: FitMonitor = FitMonitor(), 20 | processor: FitProcessor = FitProcessor(), 21 | hyperparameters: FitHyperparameters = FitHyperparameters(), 22 | callbacks: Optional[List[FitCallback]] = None, 23 | ): 24 | """ 25 | Configure the posterior distribution fitting. 26 | 27 | Parameters 28 | ---------- 29 | optimizer: FitOptimizer 30 | It defines the optimization specifics. 31 | checkpointer: FitCheckpointer 32 | It handles saving and restoring checkpoints. 33 | monitor: FitMonitor 34 | It monitors training progress and might induce early stopping. 35 | processor: FitProcessor 36 | It defines where computation takes place. 37 | hyperparameters: FitHyperparameters 38 | It defines other hyperparameters that may be needed during the model's training. 39 | callbacks: Optional[List[FitCallback]] 40 | A list of user-defined callbacks to be called during training. 41 | Callbacks run sequentially in the order defined by the user.
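Examples
--------
A minimal sketch of a typical configuration; the optimizer settings and checkpoint directory below are illustrative assumptions.

>>> import optax
>>> config = FitConfig(
...     optimizer=FitOptimizer(method=optax.adam(1e-3), n_epochs=10),
...     checkpointer=FitCheckpointer(save_checkpoint_dir="./checkpoints"),
... )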
42 | """ 43 | self.optimizer = optimizer 44 | self.checkpointer = checkpointer 45 | self.monitor = monitor 46 | self.processor = processor 47 | self.hyperparameters = hyperparameters 48 | self.callbacks = callbacks 49 | -------------------------------------------------------------------------------- /fortuna/prob_model/fit_config/callback.py: -------------------------------------------------------------------------------- 1 | from fortuna.training.callback import Callback 2 | 3 | 4 | class FitCallback(Callback): 5 | pass 6 | -------------------------------------------------------------------------------- /fortuna/prob_model/fit_config/checkpointer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fortuna.typing import Path 4 | 5 | 6 | class FitCheckpointer: 7 | def __init__( 8 | self, 9 | save_checkpoint_dir: Optional[Path] = None, 10 | restore_checkpoint_path: Optional[Path] = None, 11 | start_from_current_state: bool = False, 12 | save_every_n_steps: Optional[int] = None, 13 | keep_top_n_checkpoints: Optional[int] = 2, 14 | dump_state: bool = False, 15 | ): 16 | """ 17 | An object to configure saving and restoring of checkpoints during the posterior fitting. 18 | 19 | Parameters 20 | ---------- 21 | save_checkpoint_dir: Optional[Path] 22 | Save directory location. 23 | restore_checkpoint_path: Optional[Path] 24 | Path to checkpoint file or directory to restore. 25 | start_from_current_state: bool = False 26 | If True, the optimization will start from the current state. If `restore_checkpoint_path` is given, then 27 | `start_from_current_state` is ignored. 28 | save_every_n_steps: int 29 | Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no 30 | checkpoint will be saved during training). 31 | keep_top_n_checkpoints: int 32 | Number of past checkpoint files to keep. 33 | dump_state: bool 34 | Dump the fitted posterior state as a checkpoint in `save_checkpoint_dir`. Any future call to the state will 35 | internally involve restoring it from memory. 36 | """ 37 | self.save_checkpoint_dir = save_checkpoint_dir 38 | self.save_every_n_steps = save_every_n_steps 39 | self.restore_checkpoint_path = restore_checkpoint_path 40 | self.start_from_current_state = start_from_current_state 41 | self.keep_top_n_checkpoints = keep_top_n_checkpoints 42 | self.dump_state = dump_state 43 | -------------------------------------------------------------------------------- /fortuna/prob_model/fit_config/hyperparameters.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class FitHyperparameters: 5 | def __init__( 6 | self, 7 | max_grad_norm: Optional[float] = None, 8 | gradient_accumulation_steps: Optional[int] = None, 9 | ): 10 | """ 11 | An object to configure additional arguments that may be needed during the posterior fitting. 12 | 13 | Parameters 14 | ---------- 15 | max_grad_norm: Optional[Path] 16 | Maximum gradient norm. If `max_grad_norm > 0`, gradient clipping is performed. 17 | gradient_accumulation_steps: Optional[Path] 18 | Number of forward passes to perform before doing a backward pass. 
19 | """ 20 | self.max_grad_norm = max_grad_norm 21 | self.gradient_accumulation_steps = gradient_accumulation_steps 22 | -------------------------------------------------------------------------------- /fortuna/prob_model/fit_config/optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Callable, 3 | Optional, 4 | Tuple, 5 | ) 6 | 7 | import optax 8 | 9 | from fortuna.typing import ( 10 | AnyKey, 11 | Array, 12 | OptaxOptimizer, 13 | ) 14 | 15 | 16 | class FitOptimizer: 17 | def __init__( 18 | self, 19 | method: Optional[OptaxOptimizer] = optax.adam(1e-3), 20 | n_epochs: int = 100, 21 | freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None, 22 | ): 23 | """ 24 | An object to configure the optimization in the posterior fitting. 25 | 26 | Parameters 27 | ---------- 28 | method: OptaxOptimizer 29 | An Optax optimizer. 30 | n_epochs: int 31 | Maximum number of epochs to run the training for. 32 | freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] 33 | A callable taking in input a path in the nested dictionary of parameters, as well as the corresponding 34 | array of parameters, and returns "trainable" or "freeze", according to whether the corresponding parameter 35 | should be optimized or not. 36 | """ 37 | self.method = method 38 | self.n_epochs = n_epochs 39 | self.freeze_fun = freeze_fun 40 | -------------------------------------------------------------------------------- /fortuna/prob_model/fit_config/processor.py: -------------------------------------------------------------------------------- 1 | class FitProcessor: 2 | def __init__( 3 | self, 4 | devices: int = -1, 5 | disable_jit: bool = False, 6 | ): 7 | """ 8 | An object to configure computational aspects of the posterior fitting. 9 | 10 | Parameters 11 | ---------- 12 | devices: int 13 | A list of devices to be used during training. 14 | At the moment two options are supported: use all devices (`devices=-1`) or use no device (`devices=0`). 15 | disable_jit: bool 16 | if True, no function within the training loop is jitted. 
17 | """ 18 | self.devices = devices 19 | self.disable_jit = disable_jit 20 | -------------------------------------------------------------------------------- /fortuna/prob_model/joint/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/prob_model/joint/__init__.py -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/prob_model/posterior/__init__.py -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/deep_ensemble/__init__.py: -------------------------------------------------------------------------------- 1 | DEEP_ENSEMBLE_NAME = "deep_ensemble" 2 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_approximator.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.posterior.base import PosteriorApproximator 2 | from fortuna.prob_model.posterior.deep_ensemble import DEEP_ENSEMBLE_NAME 3 | 4 | 5 | class DeepEnsemblePosteriorApproximator(PosteriorApproximator): 6 | def __init__(self, ensemble_size: int = 5): 7 | """ 8 | Deep ensemble posterior approximator. It is responsible to define how the posterior distribution is 9 | approximated. 10 | 11 | Parameters 12 | ---------- 13 | ensemble_size : int 14 | The size of the ensemble. 15 | """ 16 | self.ensemble_size = ensemble_size 17 | 18 | def __str__(self): 19 | return DEEP_ENSEMBLE_NAME 20 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_state.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fortuna.prob_model.posterior.map.map_state import MAPState 4 | 5 | DeepEnsembleState = List[MAPState] 6 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/laplace/__init__.py: -------------------------------------------------------------------------------- 1 | LAPLACE_NAME = "laplace" 2 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/laplace/laplace_approximator.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.posterior.base import PosteriorApproximator 2 | from fortuna.prob_model.posterior.laplace import LAPLACE_NAME 3 | 4 | 5 | class LaplacePosteriorApproximator(PosteriorApproximator): 6 | def __init__(self, tune_prior_log_variance: bool = True): 7 | """ 8 | Laplace posterior approximator. 
9 | """ 10 | self.tune_prior_log_variance = tune_prior_log_variance 11 | 12 | def __str__(self): 13 | return LAPLACE_NAME 14 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/map/__init__.py: -------------------------------------------------------------------------------- 1 | MAP_NAME = "map" 2 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/map/map_approximator.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.posterior.base import PosteriorApproximator 2 | from fortuna.prob_model.posterior.map import MAP_NAME 3 | 4 | 5 | class MAPPosteriorApproximator(PosteriorApproximator): 6 | """Maximum-A-Posteriori posterior approximator. It is responsible to define how the posterior distribution is 7 | approximated.""" 8 | 9 | def __str__(self): 10 | return MAP_NAME 11 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/map/map_state.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from fortuna.prob_model.posterior.state import PosteriorState 4 | from fortuna.utils.strings import convert_string_to_tuple 5 | 6 | 7 | class MAPState(PosteriorState): 8 | """ 9 | Attributes 10 | ---------- 11 | encoded_name: jnp.ndarray 12 | MAP state name encoded as an array. 13 | """ 14 | 15 | encoded_name: tuple = convert_string_to_tuple("MAPState") 16 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/name_to_posterior_state.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from fortuna.output_calib_model.state import OutputCalibState 4 | from fortuna.prob_model.posterior.laplace.laplace_state import LaplaceState 5 | from fortuna.prob_model.posterior.map.map_state import MAPState 6 | from fortuna.prob_model.posterior.normalizing_flow.advi.advi_state import ADVIState 7 | from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_state import ( 8 | CyclicalSGLDState, 9 | ) 10 | from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_state import SGHMCState 11 | from fortuna.prob_model.posterior.state import PosteriorState 12 | from fortuna.prob_model.posterior.swag.swag_state import SWAGState 13 | 14 | 15 | class NameToPosteriorState(enum.Enum): 16 | vars()[OutputCalibState.__name__] = OutputCalibState 17 | vars()[PosteriorState.__name__] = PosteriorState 18 | vars()[MAPState.__name__] = MAPState 19 | vars()[ADVIState.__name__] = ADVIState 20 | vars()[LaplaceState.__name__] = LaplaceState 21 | vars()[SWAGState.__name__] = SWAGState 22 | vars()[SGHMCState.__name__] = SGHMCState 23 | vars()[CyclicalSGLDState.__name__] = CyclicalSGLDState 24 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/normalizing_flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/prob_model/posterior/normalizing_flow/__init__.py -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/normalizing_flow/advi/__init__.py: -------------------------------------------------------------------------------- 1 | ADVI_NAME = "advi" 2 | 
-------------------------------------------------------------------------------- /fortuna/prob_model/posterior/normalizing_flow/advi/advi_approximator.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.posterior.base import PosteriorApproximator 2 | from fortuna.prob_model.posterior.normalizing_flow.advi import ADVI_NAME 3 | 4 | 5 | class ADVIPosteriorApproximator(PosteriorApproximator): 6 | def __init__( 7 | self, 8 | std_init_params: float = 0.1, 9 | log_std_base: float = -2.3, 10 | n_loss_samples: int = 3, 11 | ): 12 | """ 13 | Automatic Differentiation Variational Inference (ADVI) approximator. It is responsible for defining how the 14 | posterior distribution is approximated. 15 | 16 | Parameters 17 | ---------- 18 | std_init_params : float 19 | The standard deviation of the Gaussian distribution used to initialize the parameters of the flow. 20 | log_std_base : float 21 | The normalizing flow transforms a base distribution into an approximation of the posterior. The base 22 | distribution is assumed to be an isotropic Gaussian, with this argument as the log-standard deviation. 23 | n_loss_samples : int 24 | The number of samples used to approximate the loss, that is, the KL divergence (or, equivalently, the negative ELBO). 25 | """ 26 | self.std_init_params = std_init_params 27 | self.log_std_base = log_std_base 28 | self.n_loss_samples = n_loss_samples 29 | 30 | def __str__(self): 31 | return ADVI_NAME 32 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/normalizing_flow/advi/advi_state.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | Dict, 5 | List, 6 | Optional, 7 | ) 8 | 9 | from fortuna.prob_model.posterior.normalizing_flow.normalizing_flow_state import ( 10 | NormalizingFlowState, 11 | ) 12 | from fortuna.typing import Array 13 | from fortuna.utils.strings import convert_string_to_tuple 14 | 15 | 16 | class ADVIState(NormalizingFlowState): 17 | """ 18 | Attributes 19 | ---------- 20 | encoded_name: jnp.ndarray 21 | ADVI state name encoded as an array. 
22 | """ 23 | 24 | encoded_name: tuple = convert_string_to_tuple("ADVIState") 25 | _encoded_which_params: Optional[Dict[str, List[Array]]] = None 26 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/normalizing_flow/normalizing_flow_state.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from fortuna.prob_model.posterior.state import PosteriorState 4 | 5 | 6 | class NormalizingFlowState(PosteriorState): 7 | pass 8 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/posterior_approximations.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from fortuna.prob_model.posterior.deep_ensemble import DEEP_ENSEMBLE_NAME 4 | from fortuna.prob_model.posterior.deep_ensemble.deep_ensemble_posterior import ( 5 | DeepEnsemblePosterior, 6 | ) 7 | from fortuna.prob_model.posterior.laplace import LAPLACE_NAME 8 | from fortuna.prob_model.posterior.laplace.laplace_posterior import LaplacePosterior 9 | from fortuna.prob_model.posterior.map import MAP_NAME 10 | from fortuna.prob_model.posterior.map.map_posterior import MAPPosterior 11 | from fortuna.prob_model.posterior.normalizing_flow.advi import ADVI_NAME 12 | from fortuna.prob_model.posterior.normalizing_flow.advi.advi_posterior import ( 13 | ADVIPosterior, 14 | ) 15 | from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld import CYCLICAL_SGLD_NAME 16 | from fortuna.prob_model.posterior.sgmcmc.cyclical_sgld.cyclical_sgld_posterior import ( 17 | CyclicalSGLDPosterior, 18 | ) 19 | from fortuna.prob_model.posterior.sgmcmc.sghmc import SGHMC_NAME 20 | from fortuna.prob_model.posterior.sgmcmc.sghmc.sghmc_posterior import SGHMCPosterior 21 | from fortuna.prob_model.posterior.sngp import SNGP_NAME 22 | from fortuna.prob_model.posterior.sngp.sngp_posterior import SNGPPosterior 23 | from fortuna.prob_model.posterior.swag import SWAG_NAME 24 | from fortuna.prob_model.posterior.swag.swag_posterior import SWAGPosterior 25 | 26 | 27 | class PosteriorApproximations(enum.Enum): 28 | """Map approximator names to posterior approximations.""" 29 | 30 | vars()[MAP_NAME] = MAPPosterior 31 | vars()[ADVI_NAME] = ADVIPosterior 32 | vars()[DEEP_ENSEMBLE_NAME] = DeepEnsemblePosterior 33 | vars()[LAPLACE_NAME] = LaplacePosterior 34 | vars()[SWAG_NAME] = SWAGPosterior 35 | vars()[SNGP_NAME] = SNGPPosterior 36 | vars()[SGHMC_NAME] = SGHMCPosterior 37 | vars()[CYCLICAL_SGLD_NAME] = CyclicalSGLDPosterior 38 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/posterior_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from fortuna.prob_model.posterior.name_to_posterior_state import NameToPosteriorState 4 | from fortuna.prob_model.posterior.state import PosteriorState 5 | from fortuna.training.mixin import WithCheckpointingMixin 6 | from fortuna.typing import ( 7 | OptaxOptimizer, 8 | Path, 9 | ) 10 | 11 | 12 | class WithPosteriorCheckpointingMixin(WithCheckpointingMixin): 13 | def restore_checkpoint( 14 | self, 15 | restore_checkpoint_path: Path, 16 | optimizer: Optional[OptaxOptimizer] = None, 17 | prefix: str = "checkpoint_", 18 | name_to_train_state: NameToPosteriorState = NameToPosteriorState, 19 | **kwargs, 20 | ) -> PosteriorState: 21 | return super().restore_checkpoint( 22 
| restore_checkpoint_path, 23 | optimizer, 24 | prefix, 25 | name_to_train_state=name_to_train_state, **kwargs, 26 | ) 27 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/posterior_state_repository.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Dict, 3 | Optional, 4 | ) 5 | 6 | from fortuna.prob_model.posterior.posterior_mixin import WithPosteriorCheckpointingMixin 7 | from fortuna.training.train_state_repository import TrainStateRepository 8 | from fortuna.typing import Path 9 | 10 | 11 | class PosteriorStateRepository(WithPosteriorCheckpointingMixin, TrainStateRepository): 12 | def extract_calib_keys( 13 | self, 14 | checkpoint_path: Optional[Path] = None, 15 | prefix: str = "checkpoint_", 16 | **kwargs, 17 | ) -> Dict: 18 | return super().extract( 19 | ["calib_params", "calib_mutable"], checkpoint_path, prefix, **kwargs 20 | ) 21 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/posterior_trainer.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.posterior.posterior_mixin import WithPosteriorCheckpointingMixin 2 | from fortuna.training.trainer import TrainerABC 3 | 4 | 5 | class PosteriorTrainerABC(WithPosteriorCheckpointingMixin, TrainerABC): 6 | pass 7 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/run_preliminary_map.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import ( 3 | Optional, 4 | Tuple, 5 | ) 6 | 7 | from fortuna.data.loader import DataLoader 8 | from fortuna.prob_model.fit_config.base import FitConfig 9 | from fortuna.prob_model.joint.base import Joint 10 | from fortuna.prob_model.posterior.map.map_approximator import MAPPosteriorApproximator 11 | from fortuna.prob_model.posterior.map.map_posterior import MAPPosterior 12 | from fortuna.prob_model.posterior.map.map_state import MAPState 13 | from fortuna.typing import Status 14 | from fortuna.utils.random import RandomNumberGenerator 15 | 16 | 17 | def run_preliminary_map( 18 | joint: Joint, 19 | train_data_loader: DataLoader, 20 | val_data_loader: DataLoader, 21 | map_fit_config: Optional[FitConfig], 22 | rng: RandomNumberGenerator, 23 | **kwargs, 24 | ) -> Tuple[MAPState, Status]: 25 | logging.info("Running a preliminary MAP fit.") 26 | map_posterior = MAPPosterior( 27 | joint, posterior_approximator=MAPPosteriorApproximator() 28 | ) 29 | map_posterior.rng = rng 30 | status = map_posterior.fit( 31 | rng=rng.get(), 32 | train_data_loader=train_data_loader, 33 | val_data_loader=val_data_loader, 34 | fit_config=map_fit_config, 35 | **kwargs, 36 | ) 37 | logging.info("Preliminary run with MAP completed.") 38 | return map_posterior.state.get(), status 39 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sgmcmc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/prob_model/posterior/sgmcmc/__init__.py -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sgmcmc/base.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.posterior.base 
import PosteriorApproximator 2 | from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import ( 3 | Preconditioner, 4 | identity_preconditioner, 5 | ) 6 | 7 | 8 | class SGMCMCPosteriorApproximator(PosteriorApproximator): 9 | def __init__( 10 | self, 11 | n_samples: int = 10, 12 | n_thinning: int = 1, 13 | preconditioner: Preconditioner = identity_preconditioner(), 14 | ) -> None: 15 | """ 16 | SGMCMC posterior approximator. It is responsible for defining how the posterior distribution is approximated. 17 | 18 | Parameters 19 | ---------- 20 | n_samples: int 21 | The desired number of posterior samples. 22 | n_thinning: int 23 | If `n_thinning` > 1, keep only every `n_thinning`-th sample during the sampling phase. 24 | preconditioner: Preconditioner 25 | A `Preconditioner` instance that preconditions the approximator with information about the posterior distribution, if available. 26 | 27 | """ 28 | self.n_samples = n_samples 29 | self.n_thinning = n_thinning 30 | self.preconditioner = preconditioner 31 | 32 | def __str__(self) -> str: 33 | raise NotImplementedError 34 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/__init__.py: -------------------------------------------------------------------------------- 1 | CYCLICAL_SGLD_NAME = "cyclical_sgld" 2 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | Dict, 5 | List, 6 | Optional, 7 | Tuple, 8 | ) 9 | 10 | from fortuna.prob_model.posterior.map.map_state import MAPState 11 | from fortuna.prob_model.posterior.state import PosteriorState 12 | from fortuna.typing import ( 13 | AnyKey, 14 | Array, 15 | OptaxOptimizer, 16 | ) 17 | from fortuna.utils.strings import ( 18 | convert_string_to_tuple, 19 | encode_tuple_of_lists_of_strings_to_numpy, 20 | ) 21 | 22 | 23 | class CyclicalSGLDState(PosteriorState): 24 | """ 25 | Attributes 26 | ---------- 27 | encoded_name: jnp.ndarray 28 | Cyclical SGLD state name encoded as an array. 29 | """ 30 | 31 | encoded_name: tuple = convert_string_to_tuple("CyclicalSGLDState") 32 | _encoded_which_params: Optional[Dict[str, List[Array]]] = None 33 | 34 | @classmethod 35 | def convert_from_map_state( 36 | cls, 37 | map_state: MAPState, 38 | optimizer: OptaxOptimizer, 39 | which_params: Tuple[List[AnyKey], ...], 40 | ) -> CyclicalSGLDState: 41 | """ 42 | Convert a MAP state into a Cyclical SGLD state. 43 | 44 | Parameters 45 | ---------- 46 | map_state: MAPState 47 | A MAP posterior state. 48 | optimizer: OptaxOptimizer 49 | An Optax optimizer. 50 | which_params: Tuple[List[AnyKey], ...] 51 | Sequences of keys pointing to the stochastic parameters. 52 | 53 | Returns 54 | ------- 55 | CyclicalSGLDState 56 | A Cyclical SGLD state. 
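        Example
        -------
        A sketch under stated assumptions: ``map_state`` comes from a preliminary MAP fit (see ``run_preliminary_map`` above), ``optimizer`` is any Optax optimizer, and the key path in ``which_params`` is an illustrative placeholder.

        .. code-block:: python

            state = CyclicalSGLDState.convert_from_map_state(
                map_state=map_state,
                optimizer=optimizer,
                which_params=(["model"],),  # placeholder key path into the params tree
            )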
57 | """ 58 | _encoded_which_params = encode_tuple_of_lists_of_strings_to_numpy(which_params) 59 | return cls.init( 60 | params=map_state.params, 61 | mutable=map_state.mutable, 62 | optimizer=optimizer, 63 | calib_params=map_state.calib_params, 64 | calib_mutable=map_state.calib_mutable, 65 | _encoded_which_params=_encoded_which_params, 66 | ) 67 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sgmcmc/sghmc/__init__.py: -------------------------------------------------------------------------------- 1 | SGHMC_NAME = "sghmc" 2 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | Dict, 5 | List, 6 | Optional, 7 | Tuple, 8 | ) 9 | 10 | from fortuna.prob_model.posterior.map.map_state import MAPState 11 | from fortuna.prob_model.posterior.state import PosteriorState 12 | from fortuna.typing import ( 13 | AnyKey, 14 | Array, 15 | OptaxOptimizer, 16 | ) 17 | from fortuna.utils.strings import ( 18 | convert_string_to_tuple, 19 | encode_tuple_of_lists_of_strings_to_numpy, 20 | ) 21 | 22 | 23 | class SGHMCState(PosteriorState): 24 | """ 25 | Attributes 26 | ---------- 27 | encoded_name: jnp.ndarray 28 | SGHMC state name encoded as an array. 29 | """ 30 | 31 | encoded_name: tuple = convert_string_to_tuple("SGHMCState") 32 | _encoded_which_params: Optional[Dict[str, List[Array]]] = None 33 | 34 | @classmethod 35 | def convert_from_map_state( 36 | cls, 37 | map_state: MAPState, 38 | optimizer: OptaxOptimizer, 39 | which_params: Tuple[List[AnyKey], ...], 40 | ) -> SGHMCState: 41 | """ 42 | Convert a MAP state into an SGHMC state. 43 | 44 | Parameters 45 | ---------- 46 | map_state: MAPState 47 | A MAP posterior state. 48 | optimizer: OptaxOptimizer 49 | An Optax optimizer. 50 | which_params: Tuple[List[AnyKey], ...] 51 | Sequences of keys pointing to the stochastic parameters. 52 | 53 | Returns 54 | ------- 55 | SGHMCState 56 | An SGHMC state. 57 | """ 58 | _encoded_which_params = encode_tuple_of_lists_of_strings_to_numpy(which_params) 59 | return cls.init( 60 | params=map_state.params, 61 | mutable=map_state.mutable, 62 | optimizer=optimizer, 63 | calib_params=map_state.calib_params, 64 | calib_mutable=map_state.calib_mutable, 65 | _encoded_which_params=_encoded_which_params, 66 | ) 67 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sgmcmc/sgmcmc_sampling_callback.py: -------------------------------------------------------------------------------- 1 | from fortuna.training.callback import Callback 2 | from fortuna.training.train_state import TrainState 3 | from fortuna.training.train_state_repository import TrainStateRepository 4 | from fortuna.training.trainer import TrainerABC 5 | 6 | 7 | class SGMCMCSamplingCallback(Callback): 8 | def __init__( 9 | self, 10 | trainer: TrainerABC, 11 | state_repository: TrainStateRepository, 12 | keep_top_n_checkpoints: int, 13 | ): 14 | """ 15 | Sampling callback that collects samples from the MCMC chain. 16 | 17 | Parameters 18 | ---------- 19 | trainer: TrainerABC 20 | An instance of the trainer class. 21 | state_repository: TrainStateRepository 22 | An instance of the state repository. 23 | keep_top_n_checkpoints: int 24 | Number of past checkpoint files to keep. 
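        Example
        -------
        A hedged sketch of a concrete sampling schedule; ``EveryKStepsSamplingCallback`` is illustrative, not a class shipped with Fortuna. Subclasses only decide, in ``_do_sample``, when to snapshot the chain; ``training_step_end`` below does the bookkeeping.

        .. code-block:: python

            class EveryKStepsSamplingCallback(SGMCMCSamplingCallback):
                def __init__(self, trainer, state_repository, keep_top_n_checkpoints, k):
                    super().__init__(trainer, state_repository, keep_top_n_checkpoints)
                    self._k = k  # snapshot period, in minibatch updates

                def _do_sample(self, current_step, samples_count):
                    # collect one posterior sample every k steps
                    return current_step % self._k == 0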
25 | """ 26 | self._trainer = trainer 27 | self._state_repository = state_repository 28 | self._keep_top_n_checkpoints = keep_top_n_checkpoints 29 | 30 | self._current_step = 0 31 | self._samples_count = 0 32 | 33 | def _do_sample(self, current_step, samples_count): 34 | raise NotImplementedError 35 | 36 | def training_step_end(self, state: TrainState) -> TrainState: 37 | self._current_step += 1 38 | 39 | if self._do_sample(self._current_step, self._samples_count): 40 | self._state_repository.put( 41 | state=self._trainer._sync_state(state), 42 | i=self._samples_count, 43 | keep=self._keep_top_n_checkpoints, 44 | ) 45 | self._samples_count += 1 46 | 47 | return state 48 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sngp/__init__.py: -------------------------------------------------------------------------------- 1 | SNGP_NAME = "sngp" 2 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sngp/sngp_callback.py: -------------------------------------------------------------------------------- 1 | from flax.core import FrozenDict 2 | import jax.numpy as jnp 3 | 4 | from fortuna.training.callback import Callback 5 | from fortuna.training.train_state import TrainState 6 | from fortuna.utils.nested_dicts import ( 7 | find_one_path_to_key, 8 | nested_get, 9 | nested_update, 10 | ) 11 | 12 | 13 | class ResetCovarianceCallback(Callback): 14 | """ 15 | Reset, at the beginning of each epoch, the covariance matrix estimated while training an SNGP model. 16 | """ 17 | 18 | def __init__(self, precision_matrix_key_name: str, ridge_penalty: float): 19 | self.precision_matrix_key_name = precision_matrix_key_name 20 | self.ridge_penalty = ridge_penalty 21 | 22 | def training_epoch_start(self, state: TrainState) -> TrainState: 23 | key_paths = find_one_path_to_key(state.mutable, self.precision_matrix_key_name) 24 | precision_matrix = nested_get(state.mutable, key_paths) 25 | if precision_matrix.ndim == 2: 26 | n, _ = precision_matrix.shape # rows, cols 27 | init_precision_matrix = ( 28 | jnp.eye(n, dtype=precision_matrix.dtype) * self.ridge_penalty 29 | ) 30 | elif precision_matrix.ndim == 3: 31 | d, n, _ = precision_matrix.shape # num_devices, rows, cols 32 | init_precision_matrix = ( 33 | jnp.eye(n, dtype=precision_matrix.dtype) * self.ridge_penalty 34 | ) 35 | init_precision_matrix = jnp.broadcast_to(init_precision_matrix, (d, n, n)) 36 | 37 | partially_updated_mutables = init_precision_matrix 38 | for key in reversed(key_paths): 39 | partially_updated_mutables = {key: partially_updated_mutables} 40 | mutables = nested_update(state.mutable.unfreeze(), partially_updated_mutables) 41 | mutables = FrozenDict(mutables) 42 | return state.replace(mutable=mutables) 43 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sngp/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.posterior.sngp.transformers.modeling_flax_auto import ( 2 | FlaxAutoSNGPModelForSequenceClassification, 3 | ) 4 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sngp/transformers/modeling_flax_auto.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES 4 | 5 | 
from fortuna.prob_model.posterior.sngp.transformers.auto_factory import ( 6 | _BaseAutoSNGPModelClass, 7 | _SNGPLazyAutoMapping, 8 | ) 9 | 10 | FLAX_SNGP_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( 11 | [ 12 | # Model for Sequence Classification with SNGP mapping 13 | ("bert", "FlaxSNGPBertExtractorForSequenceClassification"), 14 | ("distilbert", "FlaxSNGPDistilBertExtractorForSequenceClassification"), 15 | ("roberta", "FlaxSNGPRobertaExtractorForSequenceClassification"), 16 | ] 17 | ) 18 | 19 | FLAX_SNGP_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _SNGPLazyAutoMapping( 20 | CONFIG_MAPPING_NAMES, FLAX_SNGP_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES 21 | ) 22 | 23 | 24 | class FlaxAutoSNGPModelForSequenceClassification(_BaseAutoSNGPModelClass): 25 | _model_mapping = FLAX_SNGP_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING 26 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/sngp/transformers/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/prob_model/posterior/sngp/transformers/models/__init__.py -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/swag/__init__.py: -------------------------------------------------------------------------------- 1 | SWAG_NAME = "swag" 2 | -------------------------------------------------------------------------------- /fortuna/prob_model/posterior/swag/swag_approximator.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.posterior.base import PosteriorApproximator 2 | from fortuna.prob_model.posterior.swag import SWAG_NAME 3 | 4 | 5 | class SWAGPosteriorApproximator(PosteriorApproximator): 6 | def __init__(self, rank: int = 5): 7 | """ 8 | SWAG posterior approximator. It is responsible for defining how the posterior distribution is approximated. 9 | 10 | Parameters 11 | ---------- 12 | rank: int 13 | SWAG approximates the posterior with a Gaussian distribution. The Gaussian's covariance matrix is formed as the sum of 14 | a diagonal matrix and a low-rank empirical approximation. This argument defines the rank of the low-rank 15 | empirical covariance approximation. It must be at least 2. 
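        Example
        -------
        A brief construction sketch; any integer rank of at least 2 is accepted, and larger ranks retain more directions of the empirical covariance.

        .. code-block:: python

            approximator = SWAGPosteriorApproximator(rank=10)
            str(approximator)  # "swag"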
16 | """ 17 | self.rank = rank 18 | 19 | def __str__(self): 20 | return SWAG_NAME 21 | -------------------------------------------------------------------------------- /fortuna/prob_model/predictive/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/prob_model/predictive/__init__.py -------------------------------------------------------------------------------- /fortuna/prob_model/prior/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.prob_model.prior.gaussian import ( 2 | DiagonalGaussianPrior, 3 | IsotropicGaussianPrior, 4 | Prior, 5 | ) 6 | -------------------------------------------------------------------------------- /fortuna/prob_model/prior/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Optional 3 | 4 | import jax 5 | 6 | from fortuna.typing import Params 7 | from fortuna.utils.random import WithRNG 8 | 9 | 10 | class Prior(WithRNG, abc.ABC): 11 | """ 12 | Abstract prior distribution class. 13 | """ 14 | 15 | @abc.abstractmethod 16 | def log_joint_prob(self, params: Params) -> float: 17 | """ 18 | Evaluate the prior log-probability density function (a.k.a. log-pdf). 19 | 20 | Parameters 21 | ---------- 22 | params : PyTree 23 | The parameters at which to evaluate the log-pdf. 24 | 25 | Returns 26 | ------- 27 | float 28 | Evaluation of the prior log-pdf. 29 | """ 30 | pass 31 | 32 | @abc.abstractmethod 33 | def sample(self, params_like: Params, rng: Optional[jax.Array] = None) -> Params: 34 | """ 35 | Sample parameters from the prior distribution. 36 | 37 | Parameters 38 | ---------- 39 | params_like : PyTree 40 | A PyTree object with the same structure as the parameters to sample. 41 | rng: Optional[jax.Array] 42 | A random number generator. If not passed, this will be taken from the attributes of this class. 43 | 44 | Returns 45 | ------- 46 | PyTree 47 | A sample from the prior distribution. 48 | """ 49 | pass 50 | -------------------------------------------------------------------------------- /fortuna/prob_model/state.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | Dict, 5 | Optional, 6 | Union, 7 | ) 8 | 9 | from flax.core import FrozenDict 10 | 11 | from fortuna.typing import ( 12 | Mutable, 13 | Params, 14 | ) 15 | 16 | 17 | class ModelManagerState: 18 | params: Params 19 | mutable: Optional[Mutable] = None 20 | 21 | def __init__(self, params: Params, mutable: Optional[Mutable] = None): 22 | """ 23 | A model manager state class. 24 | 25 | Parameters 26 | ---------- 27 | params : Params 28 | The random parameters of the probabilistic model. 29 | mutable : Optional[Mutable] 30 | The mutable objects used to evaluate the models. 31 | """ 32 | self.params = params 33 | self.mutable = mutable 34 | 35 | @classmethod 36 | def init_from_dict(cls, d: Union[Dict, FrozenDict]) -> ModelManagerState: 37 | """ 38 | Initialize the model manager state from a dictionary. This dictionary should be like the output of 39 | :func:`~fortuna.model.model_manager.base.ModelManager.init`. 40 | 41 | Parameters 42 | ---------- 43 | d : Union[Dict, FrozenDict] 44 | A dictionary like the output of :func:`~fortuna.model.model_manager.base.ModelManager.init`. 
45 | 46 | Returns 47 | ------- 48 | ModelManagerState 49 | A model manager state. 50 | """ 51 | params = FrozenDict( 52 | {k: FrozenDict({"params": v["params"]}) for k, v in d.items()} 53 | ) 54 | mutable = FrozenDict( 55 | { 56 | k: FrozenDict({_k: _v for _k, _v in v.items() if _k != "params"}) 57 | for k, v in d.items() 58 | } 59 | ) 60 | flag = 0 61 | for k, v in mutable.items(): 62 | if len(v) > 0: 63 | flag += 1 64 | if flag == 0: 65 | mutable = None 66 | return cls(params=params, mutable=mutable) 67 | -------------------------------------------------------------------------------- /fortuna/prob_output_layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/prob_output_layer/__init__.py -------------------------------------------------------------------------------- /fortuna/sagemaker/__init__.py: -------------------------------------------------------------------------------- 1 | from fortuna.sagemaker.base import run_training_job 2 | -------------------------------------------------------------------------------- /fortuna/sagemaker/utils.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Any, 3 | Dict, 4 | ) 5 | 6 | from omegaconf import DictConfig 7 | 8 | 9 | def create_input_channels(cfg: DictConfig) -> Dict[str, str]: 10 | base_path = cfg.dataset.base_data_path 11 | base_path = base_path if base_path.endswith("/") else base_path + "/" 12 | 13 | if ( 14 | hasattr(cfg.dataset, "train_relative_path") 15 | and cfg.dataset.train_relative_path != "" 16 | ): 17 | channels = {"train": base_path + cfg.dataset.train_relative_path} 18 | else: 19 | channels = {"train": base_path} 20 | if ( 21 | hasattr(cfg.dataset, "test_relative_path") 22 | and cfg.dataset.test_relative_path != "" 23 | ): 24 | channels.update({"test": base_path + cfg.dataset.test_relative_path}) 25 | if ( 26 | hasattr(cfg.dataset, "validation_relative_path") 27 | and cfg.dataset.validation_relative_path != "" 28 | ): 29 | channels.update( 30 | {"validation": base_path + cfg.dataset.validation_relative_path} 31 | ) 32 | return channels 33 | 34 | 35 | def get_base_job_name(cfg: DictConfig) -> str: 36 | base_job_name = ( 37 | f"fortuna-{cfg.task.name}-{cfg.method.name}-{cfg.model.name}".replace("_", "-") 38 | ) 39 | if cfg.sagemaker.job_name_suffix is not None: 40 | base_job_name += f"-{cfg.sagemaker.job_name_suffix}".replace("_", "-") 41 | return base_job_name 42 | 43 | 44 | def get_hparams(cfg: DictConfig) -> Dict[str, Any]: 45 | task_hparams = {k: v for k, v in cfg.task.hparams.items()} 46 | model_hparams = {k: v for k, v in cfg.model.hparams.items()} 47 | method_hparams = {k: v for k, v in cfg.method.hparams.items()} 48 | 49 | hparams = dict(**task_hparams, **model_hparams, **method_hparams) 50 | 51 | return hparams 52 | -------------------------------------------------------------------------------- /fortuna/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/training/__init__.py -------------------------------------------------------------------------------- /fortuna/training/callback.py: -------------------------------------------------------------------------------- 1 | from fortuna.training.train_state import TrainState 2 | 3 | 4 | class Callback: 5 | """ 6 | Base 
class to define new callback functions. To define a new callback, create a child of this class and 7 | override the relevant methods. 8 | 9 | Example 10 | ------- 11 | The following is a custom callback that prints the number of the model's parameters at the start of each epoch. 12 | 13 | .. code-block:: python 14 | 15 | class CountParamsCallback(Callback): 16 | def training_epoch_start(self, state: TrainState) -> TrainState: 17 | params, unravel = ravel_pytree(state.params) 18 | logger.info(f"num params: {len(params)}") 19 | return state 20 | """ 21 | 22 | def training_epoch_start(self, state: TrainState) -> TrainState: 23 | """ 24 | Called at the beginning of every training epoch. 25 | 26 | Parameters 27 | ---------- 28 | state: TrainState 29 | The training state. 30 | 31 | Returns 32 | ------- 33 | TrainState 34 | The (possibly updated) training state. 35 | """ 36 | return state 37 | 38 | def training_epoch_end(self, state: TrainState) -> TrainState: 39 | """ 40 | Called at the end of every training epoch. 41 | 42 | Parameters 43 | ---------- 44 | state: TrainState 45 | The training state. 46 | 47 | Returns 48 | ------- 49 | TrainState 50 | The (possibly updated) training state. 51 | """ 52 | return state 53 | 54 | def training_step_end(self, state: TrainState) -> TrainState: 55 | """ 56 | Called after every minibatch update. 57 | 58 | Parameters 59 | ---------- 60 | state: TrainState 61 | The training state. 62 | 63 | Returns 64 | ------- 65 | TrainState 66 | The (possibly updated) training state. 67 | """ 68 | return state 69 | -------------------------------------------------------------------------------- /fortuna/training/name_to_train_state.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from fortuna.training.train_state import TrainState 4 | 5 | 6 | class NameToTrainState(enum.Enum): 7 | vars()[TrainState.__name__] = TrainState 8 | -------------------------------------------------------------------------------- /fortuna/training/train_state.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import ( 4 | Any, 5 | Optional, 6 | ) 7 | 8 | from flax.training import ( 9 | dynamic_scale, 10 | train_state, 11 | ) 12 | 13 | from fortuna.typing import Params 14 | from fortuna.utils.strings import convert_string_to_tuple 15 | 16 | 17 | class TrainState(train_state.TrainState): 18 | encoded_name: tuple = convert_string_to_tuple("TrainState") 19 | frozen_params: Optional[Params] = None 20 | dynamic_scale: Optional[dynamic_scale.DynamicScale] = None 21 | 22 | @classmethod 23 | def init(cls, *args, **kwargs) -> Any: 24 | pass 25 | 26 | @classmethod 27 | def init_from_dict(cls, *args, **kwargs) -> TrainState: 28 | pass 29 | -------------------------------------------------------------------------------- /fortuna/typing.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import ( 3 | Dict, 4 | Iterable, 5 | Tuple, 6 | Union, 7 | ) 8 | 9 | from flax.core import FrozenDict 10 | import jax.numpy as jnp 11 | import numpy as np 12 | from optax._src.base import ( 13 | GradientTransformation, 14 | PyTree, 15 | ) 16 | 17 | Params = FrozenDict[str, FrozenDict[str, PyTree]] 18 | Mutable = FrozenDict[str, FrozenDict[str, PyTree]] 19 | CalibParams = FrozenDict[str, PyTree] 20 | CalibMutable = FrozenDict[str, PyTree] 21 | OptaxOptimizer = GradientTransformation 22 | Array = Union[jnp.ndarray, 
np.ndarray] 23 | Status = Dict[str, Array] 24 | Path = Union[str, pathlib.Path] 25 | InputData = Union[Array, Dict[str, Array]] 26 | Targets = Array 27 | Batch = Tuple[InputData, Targets] 28 | Outputs = jnp.ndarray 29 | Uncertainties = jnp.ndarray 30 | Predictions = jnp.ndarray 31 | AnyKey = Union[str, int] 32 | Shape = Union[Iterable[int], Dict[str, Iterable[int]]] 33 | -------------------------------------------------------------------------------- /fortuna/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/fortuna/utils/__init__.py -------------------------------------------------------------------------------- /fortuna/utils/builtins.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Optional, 3 | Type, 4 | ) 5 | 6 | from flax.training.dynamic_scale import DynamicScale 7 | from jax import numpy as jnp 8 | 9 | 10 | class HashableMixin: 11 | def __hash__(self) -> int: 12 | return hash( 13 | tuple( 14 | [ 15 | getattr(self, k) 16 | for k in sorted(vars(self).keys()) 17 | if not k.startswith("_") 18 | ] 19 | ) 20 | ) 21 | 22 | def __eq__(self, other) -> bool: 23 | self_keys = [k for k in vars(self).keys() if not k.startswith("_")] 24 | other_keys = [k for k in vars(other).keys() if not k.startswith("_")] 25 | 26 | same_keys = self_keys == other_keys 27 | if same_keys and isinstance(other, self.__class__): 28 | same_vals = all( 29 | map(lambda k: getattr(self, k) == getattr(other, k), self_keys) 30 | ) 31 | return same_vals 32 | return False 33 | 34 | 35 | def get_dynamic_scale_instance_from_model_dtype(dtype: Type) -> Optional[DynamicScale]: 36 | if dtype in [jnp.float16, jnp.bfloat16]: 37 | return DynamicScale() 38 | -------------------------------------------------------------------------------- /fortuna/utils/data.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax.tree_util import tree_map 3 | import numpy as np 4 | 5 | from fortuna.data.loader import DataLoader 6 | from fortuna.typing import ( 7 | InputData, 8 | Shape, 9 | ) 10 | 11 | 12 | def check_data_loader_is_not_random(data_loader: DataLoader, max_iter: int = 3) -> None: 13 | flag = False 14 | for i, ((x1, y1), (x2, y2)) in enumerate(zip(data_loader, data_loader)): 15 | if i > max_iter: 16 | break 17 | if isinstance(x1, dict): 18 | if not all( 19 | [np.all(x1[key] == x2[key]) for key in x1.keys()] 20 | ) or not np.all(y1 == y2): 21 | flag = True 22 | break 23 | else: 24 | if not np.all(x1 == x2) or not np.all(y1 == y2): 25 | flag = True 26 | break 27 | 28 | if flag: 29 | raise ValueError( 30 | """The data loader randomizes at every iteration. 
To run this method, please provide a data loader that 31 | generates the same sequence of data when called multiple times.""" 32 | ) 33 | 34 | 35 | def get_input_shape(inputs: InputData) -> Shape: 36 | return tree_map(lambda x: x.shape[1:], inputs) 37 | 38 | 39 | def get_inputs_from_shape(input_shape: Shape) -> InputData: 40 | if isinstance(input_shape, tuple): 41 | inputs = jnp.zeros((1,) + input_shape) 42 | elif isinstance(input_shape, dict): 43 | inputs = {k: jnp.zeros((1,) + v) for k, v in input_shape.items()} 44 | else: 45 | raise ValueError("Data batch shapes must be of type fortuna.typing.Shape") 46 | return inputs 47 | -------------------------------------------------------------------------------- /fortuna/utils/device.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Type 3 | 4 | import jax 5 | 6 | 7 | def select_trainer_given_devices( 8 | devices: int, 9 | base_trainer_cls: Type, 10 | jitted_trainer_cls: Type, 11 | multi_device_trainer_cls: Type, 12 | disable_jit: bool, 13 | ) -> Type: 14 | if devices not in [0, -1]: 15 | raise NotImplementedError( 16 | "Currently, only two options are supported: use all available devices (`devices=-1`) or use only CPU (`devices=0`)." 17 | ) 18 | elif devices == -1 and disable_jit: 19 | logging.warning("Jit must be enabled when not training on a single CPU device.") 20 | 21 | if devices == -1: 22 | logging.info("Training on all available devices.") 23 | trainer_cls = ( 24 | multi_device_trainer_cls 25 | if len([d for d in jax.devices() if d.platform == "gpu"]) > 0 26 | else jitted_trainer_cls 27 | ) 28 | 29 | elif devices == 0 and disable_jit: 30 | logging.info("Training on CPU without jit.") 31 | trainer_cls = base_trainer_cls 32 | else: 33 | logging.info("Training on CPU.") 34 | trainer_cls = jitted_trainer_cls 35 | return trainer_cls 36 | -------------------------------------------------------------------------------- /fortuna/utils/random.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import random 3 | from jax.tree_util import ( 4 | tree_map, 5 | tree_structure, 6 | tree_unflatten, 7 | ) 8 | from optax._src.base import PyTree 9 | 10 | 11 | def generate_rng_like_tree(rng, target: PyTree): 12 | treedef = tree_structure(target) 13 | keys = random.split(rng, treedef.num_leaves) 14 | return tree_unflatten(treedef, keys) 15 | 16 | 17 | def generate_random_normal_like_tree(rng, target: PyTree): 18 | keys = generate_rng_like_tree(rng, target) 19 | return tree_map( 20 | lambda l, k: random.normal(k, l.shape, l.dtype), 21 | target, 22 | keys, 23 | ) 24 | 25 | 26 | class RandomNumberGenerator: 27 | def __init__(self, seed: int): 28 | """ 29 | A random number generator object. 30 | 31 | Parameters 32 | ---------- 33 | seed : int 34 | A random seed. 35 | """ 36 | self._rng = random.PRNGKey(seed) 37 | 38 | def get(self) -> jax.Array: 39 | """ 40 | Get the internal random number generator key. Whenever this function is called, the random number generator 41 | key is updated. 42 | 43 | Returns 44 | ------- 45 | jax.Array 46 | A random number generator key. 47 | """ 48 | self._rng = random.split(self._rng)[0] 49 | return self._rng 50 | 51 | 52 | class WithRNG: 53 | @property 54 | def rng(self) -> RandomNumberGenerator: 55 | """ 56 | Invoke the random number generator object. 57 | 58 | Returns 59 | ------- 60 | The random number generator object. 
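        Example
        -------
        A short sketch mirroring tests/fortuna/test_prior.py below: any ``WithRNG`` subclass (priors, posteriors, ...) exposes its generator through this property; once one is installed via the setter below, ``get()`` yields a fresh JAX PRNG key at every call by splitting the internal state.

        .. code-block:: python

            from fortuna.prob_model.prior import IsotropicGaussianPrior
            from fortuna.utils.random import RandomNumberGenerator

            prior = IsotropicGaussianPrior(log_var=0.0)
            prior.rng = RandomNumberGenerator(seed=0)
            key = prior.rng.get()  # the internal key advances; repeated calls differ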
61 | """ 62 | return self._rng 63 | 64 | @rng.setter 65 | def rng(self, rng: RandomNumberGenerator): 66 | """ 67 | Set a random number generator object. 68 | 69 | Parameters 70 | ---------- 71 | rng : RandomNumberGenerator 72 | A random number generator object. 73 | """ 74 | self._rng = rng 75 | -------------------------------------------------------------------------------- /fortuna/utils/strings.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Dict, 3 | List, 4 | Optional, 5 | Tuple, 6 | ) 7 | 8 | from jax.tree_util import tree_map 9 | import numpy as np 10 | 11 | from fortuna.typing import Array 12 | 13 | 14 | def convert_string_to_tuple(s: str) -> Tuple: 15 | return tuple([ord(c) for c in s]) 16 | 17 | 18 | def convert_string_to_np_array(s: str) -> np.ndarray: 19 | return np.array([ord(c) for c in s]) 20 | 21 | 22 | def encode_tuple_of_lists_of_strings_to_numpy( 23 | a: Optional[Tuple[List[str]]], 24 | ) -> Optional[Tuple[List[Array]]]: 25 | return ( 26 | tuple([[convert_string_to_np_array(s) for s in key_path] for key_path in a]) 27 | if a is not None 28 | else None 29 | ) 30 | 31 | 32 | def decode_encoded_tuple_of_lists_of_strings_to_array( 33 | encoded: Optional[Dict[str, List[Array]]], 34 | ) -> Optional[Tuple[List[str], ...]]: 35 | if encoded is None: 36 | return None 37 | encoded = tree_map(lambda v: "".join([chr(o) for o in v]), encoded) 38 | if isinstance(encoded, dict): 39 | return tuple([list(v.values()) for k, v in encoded.items()]) 40 | else: 41 | return encoded 42 | -------------------------------------------------------------------------------- /fortuna/utils/training.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax.tree_util import ( 3 | tree_map, 4 | tree_reduce, 5 | ) 6 | 7 | 8 | def clip_grandients_by_value(grad: jnp.ndarray, max_grad_val: float) -> jnp.ndarray: 9 | clip_fn = lambda z: jnp.clip(z, -max_grad_val, max_grad_val) 10 | grad = tree_map(clip_fn, grad) 11 | return grad 12 | 13 | 14 | def clip_grandients_by_norm(grad: jnp.ndarray, max_grad_norm: float) -> jnp.ndarray: 15 | grad_norm = jnp.sqrt( 16 | tree_reduce(lambda x, y: x + jnp.sum(y**2), grad, initializer=0) 17 | ) 18 | mult = jnp.minimum(1, max_grad_norm / (1e-7 + grad_norm)) 19 | grad = tree_map(lambda z: mult * z, grad) 20 | return grad 21 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/tests/__init__.py -------------------------------------------------------------------------------- /tests/fortuna/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/tests/fortuna/__init__.py -------------------------------------------------------------------------------- /tests/fortuna/hallucination/scoring.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from jax import random 4 | import numpy as np 5 | import torch 6 | 7 | from fortuna.data import InputsLoader 8 | from fortuna.hallucination.scoring.inv_perplexity import inv_perplexity 9 | 10 | 11 | class TestScoringModel(unittest.TestCase): 12 | def test_score(self): 13 | logits = torch.ones((5, 10, 3)) 14 | 
labels = torch.ones( 15 | ( 16 | 4, 17 | 10, 18 | ) 19 | ) 20 | 21 | assert inv_perplexity(logits=logits, labels=labels).shape == () 22 | assert inv_perplexity(logits=logits, labels=labels, init_pos=2).shape == () 23 | -------------------------------------------------------------------------------- /tests/fortuna/prob_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/fortuna/19c7d9c128440f6eabee387db60a85032270f33c/tests/fortuna/prob_model/__init__.py -------------------------------------------------------------------------------- /tests/fortuna/prob_model/test_preconditioner.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | 5 | from fortuna.prob_model.posterior.sgmcmc.sgmcmc_preconditioner import ( 6 | identity_preconditioner, 7 | rmsprop_preconditioner, 8 | ) 9 | 10 | 11 | class TestPreconditioner(unittest.TestCase): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.params = { 15 | "p1": jnp.zeros([1, 2], jnp.float32), 16 | "p2": jnp.zeros([2, 1], jnp.float32), 17 | } 18 | self.grad = { 19 | "p1": jnp.ones([1, 2], jnp.float32), 20 | "p2": jnp.ones([2, 1], jnp.float32), 21 | } 22 | 23 | def test_rmsprop(self): 24 | preconditioner = rmsprop_preconditioner() 25 | state = preconditioner.init(self.params) 26 | state = preconditioner.update_preconditioner(self.grad, state) 27 | result = preconditioner.multiply_by_m_inv(self.params, state) 28 | assert "p1" in result and "p2" in result 29 | result = preconditioner.multiply_by_m_sqrt(self.params, state) 30 | assert "p1" in result and "p2" in result 31 | result = preconditioner.multiply_by_m_sqrt_inv(self.params, state) 32 | assert "p1" in result and "p2" in result 33 | 34 | def test_identity(self): 35 | preconditioner = identity_preconditioner() 36 | state = preconditioner.init(self.params) 37 | state = preconditioner.update_preconditioner(self.grad, state) 38 | result = preconditioner.multiply_by_m_inv(self.params, state) 39 | assert "p1" in result and "p2" in result 40 | result = preconditioner.multiply_by_m_sqrt(self.params, state) 41 | assert "p1" in result and "p2" in result 42 | result = preconditioner.multiply_by_m_sqrt_inv(self.params, state) 43 | assert "p1" in result and "p2" in result 44 | -------------------------------------------------------------------------------- /tests/fortuna/prob_model/test_step_schedule.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | 5 | from fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule import ( 6 | constant_schedule, 7 | constant_schedule_with_cosine_burnin, 8 | cosine_schedule, 9 | cyclical_cosine_schedule_with_const_burnin, 10 | polynomial_schedule, 11 | ) 12 | 13 | 14 | class TestStepSchedule(unittest.TestCase): 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.count = jnp.zeros([], jnp.int32) 18 | 19 | def test_constant(self): 20 | schedule_fn = constant_schedule(init_step_size=1e-1) 21 | assert jnp.allclose(schedule_fn(self.count), 1e-1) 22 | assert jnp.allclose(schedule_fn(self.count + 1), 1e-1) 23 | 24 | def test_cosine(self): 25 | schedule_fn = cosine_schedule(init_step_size=1e-1, total_steps=10) 26 | assert jnp.allclose(schedule_fn(self.count), 1e-1) 27 | assert not jnp.allclose(schedule_fn(self.count + 1), schedule_fn(self.count)) 28 | 
assert jnp.allclose(schedule_fn(self.count + 10), 0) 29 | assert jnp.allclose(schedule_fn(self.count + 20), 1e-1) 30 | 31 | def test_polynomial(self): 32 | schedule_fn = polynomial_schedule() 33 | assert schedule_fn(self.count + 1) < schedule_fn(self.count) 34 | 35 | def test_cosine_burnin(self): 36 | schedule_fn = constant_schedule_with_cosine_burnin( 37 | init_step_size=1e-1, final_step_size=1e-2, burnin_steps=10 38 | ) 39 | assert jnp.allclose(schedule_fn(self.count), 1e-1) 40 | assert not jnp.allclose(schedule_fn(self.count + 1), schedule_fn(self.count)) 41 | assert jnp.allclose(schedule_fn(self.count + 10), 1e-2) 42 | assert jnp.allclose(schedule_fn(self.count + 11), 1e-2) 43 | 44 | def test_const_burnin(self): 45 | schedule_fn = cyclical_cosine_schedule_with_const_burnin( 46 | init_step_size=1e-1, burnin_steps=10, cycle_length=10 47 | ) 48 | assert jnp.allclose(schedule_fn(self.count), 1e-1) 49 | assert jnp.allclose(schedule_fn(self.count + 1), 1e-1) 50 | assert not jnp.allclose(schedule_fn(self.count + 12), 1e-1) 51 | -------------------------------------------------------------------------------- /tests/fortuna/test_kernel_regression.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from jax import random 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | from fortuna.kernel_regression.nadaraya_watson import NadarayaWatsonKernelRegressor 8 | 9 | 10 | class TestKernelRegression(unittest.TestCase): 11 | def test_nadaraya_watson(self): 12 | train_x = random.normal(random.PRNGKey(0), shape=(3,)) 13 | train_y = random.normal(random.PRNGKey(1), shape=(3,)) 14 | eval_x = random.normal(random.PRNGKey(2), shape=(4,)) 15 | 16 | kr = NadarayaWatsonKernelRegressor(train_inputs=train_x, train_targets=train_y) 17 | preds = kr.predict(inputs=eval_x) 18 | assert preds.shape == (4,) 19 | 20 | kr = NadarayaWatsonKernelRegressor( 21 | train_inputs=np.array(train_x), train_targets=np.array(train_y) 22 | ) 23 | preds = kr.predict(inputs=np.array(eval_x)) 24 | assert preds.shape == (4,) 25 | 26 | with self.assertRaises(ValueError): 27 | NadarayaWatsonKernelRegressor( 28 | train_inputs=train_x, train_targets=train_y[None] 29 | ) 30 | with self.assertRaises(ValueError): 31 | NadarayaWatsonKernelRegressor( 32 | train_inputs=train_x[None], train_targets=train_y 33 | ) 34 | with self.assertRaises(ValueError): 35 | NadarayaWatsonKernelRegressor( 36 | train_inputs=jnp.concatenate((train_x, train_y)), train_targets=train_y 37 | ) 38 | -------------------------------------------------------------------------------- /tests/fortuna/test_metric.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from fortuna.metric.classification import brier_score 6 | 7 | 8 | class TestBrierScore(unittest.TestCase): 9 | def test_brier_score(self): 10 | probs = np.random.normal(size=(10, 3)) 11 | targets = np.random.normal(size=10) 12 | assert np.atleast_1d(brier_score(probs, targets)).shape == (1,) 13 | -------------------------------------------------------------------------------- /tests/fortuna/test_output_maker.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from flax.core import FrozenDict 4 | from jax import random 5 | import jax.numpy as jnp 6 | 7 | from fortuna.model.mlp import MLP 8 | from fortuna.model.model_manager.classification import ClassificationModelManager 9 | from 
fortuna.model.model_manager.regression import RegressionModelManager 10 | from tests.make_data import make_array_random_inputs 11 | 12 | 13 | class TestModelManagers(unittest.TestCase): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | self.shape_inputs = (4,) 17 | self.output_dim = 2 18 | self.n_inputs = 10 19 | self.rng = random.PRNGKey(0) 20 | self.model = MLP(output_dim=self.output_dim) 21 | self.lik_log_var = MLP(output_dim=self.output_dim) 22 | 23 | def test_classifier_model_manager_apply(self): 24 | classifier_model_manager = ClassificationModelManager(self.model) 25 | params = FrozenDict( 26 | dict(model=self.model.init(self.rng, jnp.zeros((2,) + self.shape_inputs))) 27 | ) 28 | 29 | inputs = make_array_random_inputs( 30 | n_inputs=self.n_inputs, shape_inputs=self.shape_inputs 31 | ) 32 | assert classifier_model_manager.apply(params, inputs).shape == ( 33 | self.n_inputs, 34 | self.output_dim, 35 | ) 36 | 37 | def test_regressor_model_manager_apply(self): 38 | regressor_model_manager = RegressionModelManager(self.model, self.lik_log_var) 39 | params = FrozenDict( 40 | dict( 41 | model=self.model.init(self.rng, jnp.zeros((2,) + self.shape_inputs)), 42 | lik_log_var=self.model.init( 43 | self.rng, jnp.zeros((2,) + self.shape_inputs) 44 | ), 45 | ) 46 | ) 47 | 48 | inputs = make_array_random_inputs( 49 | n_inputs=self.n_inputs, shape_inputs=self.shape_inputs 50 | ) 51 | assert regressor_model_manager.apply(params, inputs).shape == ( 52 | self.n_inputs, 53 | 2 * self.output_dim, 54 | ) 55 | -------------------------------------------------------------------------------- /tests/fortuna/test_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | 5 | import numpy as np 6 | 7 | from fortuna.plot import plot_reliability_diagram 8 | 9 | 10 | class TestStates(unittest.TestCase): 11 | def test_reliability_diagram(self): 12 | with tempfile.TemporaryDirectory() as tmp_dir: 13 | accs = [np.random.normal(size=20), np.random.normal(size=20)] 14 | confs = [np.random.normal(size=20), np.random.normal(size=20)] 15 | labels = ["a", "b"] 16 | plot_reliability_diagram(accs, confs) 17 | plot_reliability_diagram(accs[0], confs[0]) 18 | plot_reliability_diagram(accs, confs, labels=labels) 19 | plot_reliability_diagram( 20 | accs, confs, fname=os.path.join(tmp_dir, "tmp.png") 21 | ) 22 | plot_reliability_diagram( 23 | accs, confs, fname=os.path.join(tmp_dir, "tmp.png") 24 | ) 25 | plot_reliability_diagram(accs, confs, title="bla") 26 | -------------------------------------------------------------------------------- /tests/fortuna/test_prior.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from jax import numpy as jnp 4 | from jax.flatten_util import ravel_pytree 5 | 6 | from fortuna.prob_model.prior import ( 7 | DiagonalGaussianPrior, 8 | IsotropicGaussianPrior, 9 | ) 10 | from fortuna.utils.random import RandomNumberGenerator 11 | 12 | 13 | class TestIsotropicDiagGaussianPrior(unittest.TestCase): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | self.log_var = 0.0 17 | self.prior = IsotropicGaussianPrior(log_var=self.log_var) 18 | self.prior.rng = RandomNumberGenerator(seed=0) 19 | self.params = dict(model=jnp.arange(3), lik_log_var=jnp.arange(4, 7)) 20 | 21 | def test_log_joint_prob(self): 22 | assert jnp.array([self.prior.log_joint_prob(self.params)]).shape == (1,) 23 | assert 
jnp.allclose( 24 | self.prior.log_joint_prob(jnp.zeros(2)), 25 | -(jnp.log(2 * jnp.pi) + self.log_var), 26 | ) 27 | 28 | def test_sample(self): 29 | n_params = len(ravel_pytree(self.params)[0]) 30 | rav_samples = ravel_pytree(self.prior.sample(self.params))[0] 31 | assert rav_samples.size == n_params 32 | 33 | 34 | class TestDiagGaussianPrior(unittest.TestCase): 35 | def __init__(self, *args, **kwargs): 36 | super().__init__(*args, **kwargs) 37 | self.log_var = 0.1 + jnp.arange(-2, 4) 38 | self.prior = DiagonalGaussianPrior(log_var=self.log_var) 39 | self.prior.rng = RandomNumberGenerator(seed=0) 40 | self.params = dict(model=jnp.arange(3), lik_log_var=jnp.arange(4, 7)) 41 | self.n_samples = 3 42 | 43 | def test_log_joint_prob(self): 44 | assert jnp.array([self.prior.log_joint_prob(self.params)]).shape == (1,) 45 | assert jnp.allclose( 46 | self.prior.log_joint_prob(jnp.zeros(len(self.log_var))), 47 | -0.5 * jnp.sum(jnp.log(2 * jnp.pi) + self.log_var), 48 | ) 49 | 50 | def test_sample(self): 51 | n_params = len(ravel_pytree(self.params)[0]) 52 | rav_samples = ravel_pytree(self.prior.sample(self.params))[0] 53 | assert rav_samples.size == n_params 54 | -------------------------------------------------------------------------------- /tests/fortuna/test_state.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | 5 | from fortuna.output_calib_model.state import OutputCalibState 6 | from fortuna.output_calibrator.output_calib_manager.state import OutputCalibManagerState 7 | from fortuna.prob_model.joint.state import JointState 8 | 9 | 10 | class TestStates(unittest.TestCase): 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | 14 | def test_joint_state(self): 15 | d = dict( 16 | model=dict(params=jnp.array([0.0]), batch_stats=jnp.array([0.0])), 17 | lik_log_var=dict(params=jnp.array([1.0]), batch_stats=jnp.array([1.0])), 18 | ) 19 | js = JointState.init_from_dict(d) 20 | assert js.params == dict( 21 | model=dict(params=jnp.array([0.0])), 22 | lik_log_var=dict(params=jnp.array([1.0])), 23 | ) 24 | assert js.mutable == dict( 25 | model=dict(batch_stats=jnp.array([0.0])), 26 | lik_log_var=dict(batch_stats=jnp.array([1.0])), 27 | ) 28 | 29 | def test_output_calib_manager_state(self): 30 | cs = OutputCalibManagerState.init_from_dict( 31 | dict( 32 | output_calibrator=dict( 33 | params=jnp.array([0.0]), batch_stats=jnp.array([0.0]) 34 | ) 35 | ) 36 | ) 37 | assert cs.params == dict(output_calibrator=dict(params=jnp.array([0.0]))) 38 | assert cs.mutable == dict(output_calibrator=dict(batch_stats=jnp.array([0.0]))) 39 | 40 | def test_calib_state(self): 41 | cs = OutputCalibState.init_from_dict(dict(params=dict(a=1), mutable=dict(b=2))) 42 | assert hasattr(cs.params, "unfreeze") 43 | assert "a" in cs.params 44 | assert hasattr(cs.mutable, "unfreeze") 45 | assert "b" in cs.mutable 46 | -------------------------------------------------------------------------------- /tests/fortuna/test_temp_scaling.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from jax import random 4 | import jax.numpy as jnp 5 | import optax 6 | 7 | from fortuna.data.loader import DataLoader 8 | from fortuna.model.mlp import MLP 9 | from fortuna.output_calibrator.classification import ClassificationTemperatureScaler 10 | from fortuna.prob_model.classification import ProbClassifier 11 | from fortuna.prob_model.posterior.map.map_approximator import 
MAPPosteriorApproximator 12 | from tests.make_data import make_array_random_data 13 | 14 | 15 | class TestCalibrators(unittest.TestCase): 16 | def __init__(self, *args, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | self.rng = random.PRNGKey(0) 19 | self.n_inputs = 6 20 | self.shape_inputs = (4,) 21 | self.output_dim = 2 22 | self.checkpoint_dir = "logs" 23 | 24 | self.class_data_loader = DataLoader.from_array_data( 25 | make_array_random_data( 26 | n_data=self.n_inputs, 27 | shape_inputs=self.shape_inputs, 28 | output_dim=self.output_dim, 29 | output_type="discrete", 30 | ) 31 | ) 32 | 33 | def test_calibrate_prob_model(self): 34 | prob_model = ProbClassifier( 35 | model=MLP(self.output_dim), 36 | output_calibrator=ClassificationTemperatureScaler(), 37 | posterior_approximator=MAPPosteriorApproximator(), 38 | ) 39 | status = prob_model.posterior.fit( 40 | train_data_loader=self.class_data_loader, 41 | optimizer=optax.adam(1e-2), 42 | n_epochs=2, 43 | ) 44 | status = prob_model.calibrate(self.class_data_loader) 45 | s = prob_model.posterior.state.get() 46 | assert s.calib_params["output_calibrator"]["params"]["log_temp"].shape == (1,) 47 | assert s.calib_params["output_calibrator"]["params"]["log_temp"] != jnp.array( 48 | [0.0] 49 | ) 50 | -------------------------------------------------------------------------------- /tests/make_model.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | 4 | from fortuna.model.utils.spectral_norm import WithSpectralNorm 5 | 6 | 7 | class MyModel(nn.Module): 8 | output_dim: int 9 | dense: nn.Module = nn.Dense 10 | 11 | @nn.compact 12 | def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray: 13 | if hasattr(self, "spectral_norm"): 14 | dense = self.spectral_norm(self.dense, train=train) 15 | else: 16 | dense = self.dense 17 | x = x.reshape(x.shape[0], -1) 18 | x = dense(2, name="l1")(x) 19 | x = nn.Dropout(rate=0.9)(x, deterministic=not train) 20 | x = dense(self.output_dim, name="l2")(x) 21 | return x 22 | 23 | 24 | class MyModelWithSpectralNorm(WithSpectralNorm, MyModel): 25 | pass 26 | --------------------------------------------------------------------------------
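A closing sketch of how the test model above can be exercised (shapes are illustrative; `train=True` would additionally require a `"dropout"` PRNG key, and the spectral-norm variant tracks extra mutable state):

    import jax
    import jax.numpy as jnp

    from tests.make_model import MyModel

    model = MyModel(output_dim=2)
    # train defaults to False, so dropout is deterministic and a single
    # "params" key suffices for initialization.
    variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
    outputs = model.apply(variables, jnp.ones((3, 4)))
    assert outputs.shape == (3, 2)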