├── .github ├── ISSUE_TEMPLATES │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── docs.yml │ ├── lint.yml │ ├── release.yml │ └── tests.yml ├── .gitignore ├── .vscode ├── extensions.json └── settings.json ├── LICENSE ├── docs ├── README.md ├── api │ ├── classifier.md │ ├── constrained_module.md │ ├── datasets.md │ ├── enums.md │ ├── feature_config.md │ ├── layers.md │ ├── model_configs.md │ ├── models.md │ ├── plots.md │ └── utils.md ├── concepts │ ├── calibrators.md │ ├── classifier.md │ ├── model_types.md │ ├── plotting.md │ └── shape_constraints.md ├── contributing.md ├── help.md ├── img │ ├── dnn_diagram.png │ ├── hours_per_week_calibrator.png │ ├── linear_coefficients.png │ ├── occupation_calibrator.png │ └── thal_calibrator.png ├── walkthroughs │ └── uci_adult_income.md └── why.md ├── examples ├── basic_classifier.py ├── calibrated_linear_classification.py ├── classifier_plotting.py └── classifier_with_monotonicity.py ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml ├── pytorch_lattice ├── __init__.py ├── classifier.py ├── constrained_module.py ├── datasets.py ├── enums.py ├── feature_config.py ├── layers │ ├── __init__.py │ ├── categorical_calibrator.py │ ├── lattice.py │ ├── linear.py │ ├── numerical_calibrator.py │ └── rtl.py ├── model_configs.py ├── models │ ├── __init__.py │ ├── calibrated_lattice.py │ ├── calibrated_linear.py │ └── features.py ├── plots.py └── utils │ ├── __init__.py │ ├── data.py │ └── models.py ├── requirements-dev.txt ├── requirements-docs.txt ├── requirements.txt └── tests ├── __init__.py ├── layers ├── __init__.py ├── test_categorical_calibrator.py ├── test_lattice.py ├── test_linear.py ├── test_numerical_calibrator.py └── test_rtl.py ├── models ├── __init__.py ├── test_calibrated_lattice.py ├── test_calibrated_linear.py └── test_features.py ├── test_classifier.py ├── test_feature_config.py ├── test_model_configs.py ├── testing_utils.py └── utils ├── test_data.py └── test_models.py /.github/ISSUE_TEMPLATES/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | ## Describe the bug: 4 | A clear and concise description of what the bug is. 5 | 6 | --- 7 | 8 | **To Reproduce** 9 | Steps to reproduce the behavior: 10 | 11 | 1. Go to '...' 12 | 2. Click on '....' 13 | 3. Scroll down to '....' 14 | 4. See error 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | 22 | **Desktop (please complete the following information):** 23 | 24 | - OS: [e.g. iOS] 25 | - Browser [e.g. chrome, safari] 26 | - Version [e.g. 22] 27 | 28 | **Smartphone (please complete the following information):** 29 | 30 | - Device: [e.g. iPhone6] 31 | - OS: [e.g. iOS8.1] 32 | - Browser [e.g. stock browser, safari] 33 | - Version [e.g. 22] 34 | 35 | **Additional context** 36 | Add any other context about the problem here. 37 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATES/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | ## About: 4 | Suggest an idea for this project 5 | 6 | --- 7 | 8 | **Is your feature request related to a problem? Please describe.** 9 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 10 | 11 | **Describe the solution you'd like** 12 | A clear and concise description of what you want to happen. 
13 | 14 | **Describe alternatives you've considered** 15 | A clear and concise description of any alternative solutions or features you've considered. 16 | 17 | **Additional context** 18 | Add any other context or screenshots about the feature request here. 19 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - uses: actions/setup-python@v4 17 | with: 18 | python-version: 3.x 19 | - uses: actions/cache@v2 20 | with: 21 | key: ${{ github.ref }} 22 | path: .cache 23 | 24 | - run: pip install -r requirements-docs.txt 25 | - run: mkdocs gh-deploy --force 26 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | run-name: ${{ github.actor }} is linting the package 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | pull_request: 9 | branches: 10 | - main 11 | 12 | jobs: 13 | lint: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v3 17 | 18 | - name: Run Ruff 19 | uses: chartboost/ruff-action@v1 20 | 21 | - name: Install Poetry 22 | uses: snok/install-poetry@v1.3.1 23 | 24 | - name: Install dependencies 25 | run: poetry install --with dev 26 | 27 | - name: Run mypy 28 | run: poetry run mypy . 29 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | run-name: ${{ github.actor }} is uploading a new release to PyPI 3 | 4 | on: 5 | release: 6 | types: [published] 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | release: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | 18 | - name: Set Up Python 3.10 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.10" 22 | 23 | - name: Install Poetry 24 | uses: snok/install-poetry@v1.3.1 25 | env: 26 | ACTIONS_ALLOW_UNSECURE_COMMANDS: "true" 27 | 28 | - name: Get Release Version 29 | run: echo "RELEASE_VERSION=$(poetry version | awk '{print $2}')" >> $GITHUB_ENV 30 | 31 | - name: Build And Publish Python Package 32 | run: poetry publish --build 33 | env: 34 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_PYTORCH_LATTICE_RELEASE_TOKEN }} 35 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | run-name: ${{ github.actor }} is testing the package 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | pull_request: 9 | branches: 10 | - main 11 | 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | 16 | strategy: 17 | matrix: 18 | python-version: ["3.9", "3.10", "3.11"] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | 23 | - name: Set Up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Install Poetry 29 | uses: snok/install-poetry@v1.3.1 30 | 31 | - name: Install dependencies 32 | run: poetry install --with dev 33 | 34 | - name: Run Tests 35 | run: poetry run pytest 
tests/ 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache 2 | **__pycache__ 3 | .ipynb_checkpoints 4 | .idea/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # mypy 32 | .mypy_cache/ 33 | .dmypy.json 34 | dmypy.json 35 | 36 | # Environments 37 | .env 38 | .venv 39 | env/ 40 | venv/ 41 | ENV/ 42 | env.bak/ 43 | venv.bak/ 44 | .envrc 45 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "charliermarsh.ruff", 4 | "eamodio.gitlens", 5 | "github.vscode-github-actions", 6 | "ms-toolsai.jupyter", 7 | "ms-toolsai.jupyter-keymap", 8 | "ms-toolsai.jupyter-renderers", 9 | "ms-python.mypy", 10 | "ms-python.python", 11 | "ms-python.vscode-pylance", 12 | "ms-toolsai.vscode-jupyter-cell-tags", 13 | "ms-toolsai.vscode-jupyter-slideshow" 14 | ] 15 | } 16 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.autoSave": "onFocusChange", 3 | "editor.rulers": [ 4 | 88 5 | ], 6 | "editor.formatOnSave": true, 7 | "editor.formatOnSaveMode": "file", 8 | "files.insertFinalNewline": true, 9 | "python.testing.unittestEnabled": false, 10 | "python.testing.pytestEnabled": true, 11 | "[python]": { 12 | "editor.tabSize": 4, 13 | "editor.codeActionsOnSave": { 14 | "source.organizeImports": "explicit" 15 | }, 16 | "editor.defaultFormatter": "charliermarsh.ruff" 17 | }, 18 | "[markdown]": { 19 | "editor.formatOnSave": false 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 William Bakst. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Getting Started with PyTorch Lattice 2 | 3 | A PyTorch implementation of constrained optimization and modeling techniques 4 | 5 | - **Transparent Models**: Glassbox models to provide increased interpretability and insights into your ML models. 6 | - **Shape Constraints**: Embed domain knowledge directly into the model through feature constraints. 7 | - **Rate Constraints (Coming soon...)**: Optimize any PyTorch model under a set of constraints on rates (e.g. FPR < 1%). Rates can be calculated both for the entire dataset as well as specific slices. 8 | 9 | --- 10 | 11 | [![GitHub stars](https://img.shields.io/github/stars/ControlAI/pytorch-lattice.svg)](https://github.com/ControlAI/pytorch-lattice/stargazers) 12 | [![Documentation](https://img.shields.io/badge/docs-available-brightgreen)](https://controlai.github.io/pytorch-lattice/) 13 | [![](https://github.com/ControlAI/pytorch-lattice/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/ControlAI/pytorch-lattice/actions/workflows/test.yml) 14 | [![GitHub issues](https://img.shields.io/github/issues/ControlAI/pytorch-lattice.svg)](https://github.com/ControlAI/pytorch-lattice/issues) 15 | [![Github discussions](https://img.shields.io/github/discussions/ControlAI/pytorch-lattice)](https:github.com/ControlAI/pytorch-lattice/discussions) 16 | [![GitHub license](https://img.shields.io/github/license/ControlAI/pytorch-lattice.svg)](https://github.com/ControlAI/pytorch-lattice/blob/main/LICENSE) 17 | [![PyPI version](https://img.shields.io/pypi/v/pytorch-lattice.svg)](https://pypi.python.org/pypi/pytorch-lattice) 18 | [![PyPI pyversions](https://img.shields.io/pypi/pyversions/pytorch-lattice.svg)](https://pypi.python.org/pypi/pytorch-lattice) 19 | 20 | --- 21 | 22 | ## Installation 23 | 24 | Install PyTorch Lattice and start training and analyzing calibrated models in minutes. 25 | 26 | ```sh 27 | $ pip install pytorch-lattice 28 | ``` 29 | 30 | ## Quickstart 31 | 32 | ### Step 1. Import the package 33 | 34 | First, import the PyTorch Lattice library: 35 | 36 | ```py 37 | import pytorch_lattice as pyl 38 | ``` 39 | 40 | ### Step 2. Load data and fit a classifier 41 | 42 | Load the UCI Statlog (Heart) dataset. Then create a base classifier and fit it to the data. Creating the base classifier requires only the feature names. 43 | 44 | ```py 45 | X, y = pyl.datasets.heart() 46 | clf = pyl.Classifier(X.columns).fit(X, y) 47 | ``` 48 | 49 | ### Step 3. Plot a feature calibrator 50 | 51 | Now that you've trained a classifier, you can plot the feature calibrators to better understand how the model is understanding each feature. 52 | 53 | ```py 54 | pyl.plots.calibrator(clf.model, "thal") 55 | ``` 56 | 57 | ![Thal Calibrator](img/thal_calibrator.png) 58 | 59 | ### Step 4. What's Next? 60 | 61 | - Check out the [Concepts](concepts/classifier.md) section to dive deeper into the library and the core features that make it powerful, such as [calibrators](concepts/calibrators.md) and [shape constraints](concepts/shape_constraints.md). 62 | - You can follow along with more detailed [walkthroughs](walkthroughs/uci_adult_income.md) to get a better understanding of how to utilize the library to effectively model your data. You can also take a look at [code examples](https://github.com/ControlAI/pytorch-lattice/tree/main/examples) in the repo. 
63 | - The [API Reference](api/layers.md) contains full details on all classes, methods, functions, etc. 64 | 65 | ## Related Research 66 | 67 | - [Monotonic Kronecker-Factored Lattice](https://openreview.net/forum?id=0pxiMpCyBtr), William Taylor Bakst, Nobuyuki Morioka, Erez Louidor, International Conference on Learning Representations (ICLR), 2021 68 | - [Multidimensional Shape Constraints](https://proceedings.mlr.press/v119/gupta20b.html), Maya Gupta, Erez Louidor, Oleksandr Mangylov, Nobu Morioka, Taman Narayan, Sen Zhao, Proceedings of the 37th International Conference on Machine Learning (PMLR), 2020 69 | - [Deontological Ethics By Monotonicity Shape Constraints](https://arxiv.org/abs/2001.11990), Serena Wang, Maya Gupta, International Conference on Artificial Intelligence and Statistics (AISTATS), 2020 70 | - [Shape Constraints for Set Functions](http://proceedings.mlr.press/v97/cotter19a.html), Andrew Cotter, Maya Gupta, H. Jiang, Erez Louidor, Jim Muller, Taman Narayan, Serena Wang, Tao Zhu. International Conference on Machine Learning (ICML), 2019 71 | - [Diminishing Returns Shape Constraints for Interpretability and Regularization](https://papers.nips.cc/paper/7916-diminishing-returns-shape-constraints-for-interpretability-and-regularization), Maya Gupta, Dara Bahri, Andrew Cotter, Kevin Canini, Advances in Neural Information Processing Systems (NeurIPS), 2018 72 | - [Deep Lattice Networks and Partial Monotonic Functions](https://research.google.com/pubs/pub46327.html), Seungil You, Kevin Canini, David Ding, Jan Pfeifer, Maya R. Gupta, Advances in Neural Information Processing Systems (NeurIPS), 2017 73 | - [Fast and Flexible Monotonic Functions with Ensembles of Lattices](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices), Mahdi Milani Fard, Kevin Canini, Andrew Cotter, Jan Pfeifer, Maya Gupta, Advances in Neural Information Processing Systems (NeurIPS), 2016 74 | - [Monotonic Calibrated Interpolated Look-Up Tables](http://jmlr.org/papers/v17/15-243.html), Maya Gupta, Andrew Cotter, Jan Pfeifer, Konstantin Voevodski, Kevin Canini, Alexander Mangylov, Wojciech Moczydlowski, Alexander van Esbroeck, Journal of Machine Learning Research (JMLR), 2016 75 | - [Optimized Regression for Efficient Function Evaluation](http://ieeexplore.ieee.org/document/6203580/), Eric Garcia, Raman Arora, Maya R. Gupta, IEEE Transactions on Image Processing, 2012 76 | - [Lattice Regression](https://papers.nips.cc/paper/3694-lattice-regression), Eric Garcia, Maya Gupta, Advances in Neural Information Processing Systems (NeurIPS), 2009 77 | 78 | ## Contributing 79 | 80 | PyTorch Lattice welcomes contributions from the community! See the [contribution guide](contributing.md) for more information on the development workflow. For bugs and feature requests, visit our [GitHub Issues](https://github.com/ControlAI/pytorch-lattice/issues) and check out our [templates](https://github.com/ControlAI/pytorch-lattice/tree/main/.github/ISSUE_TEMPLATES). 81 | 82 | ## How To Help 83 | 84 | Any and all help is greatly appreciated! Check out our page on [how you can help](help.md). 85 | 86 | ## Roadmap 87 | 88 | Check out the our [roadmap](https://github.com/orgs/ControlAI/projects/1/views/1) to see what's planned. If there's an item that you really want that isn't assigned or in progress, take a stab at it! 89 | 90 | ## Versioning 91 | 92 | PyTorch Lattice uses [Semantic Versioning](https://semver.org/). 
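Since releases follow SemVer, you can pin a compatible version range when installing if you want to avoid picking up breaking changes; for example (the range below assumes the 0.2.x series listed in `pyproject.toml`):

```sh
$ pip install "pytorch-lattice>=0.2.0,<0.3.0"
```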
93 | 94 | ## License 95 | 96 | This project is licensed under the terms of the [MIT License](https://github.com/ControlAI/pytorch-lattice/blob/main/LICENSE). 97 | -------------------------------------------------------------------------------- /docs/api/classifier.md: -------------------------------------------------------------------------------- 1 | # classifier 2 | 3 | ::: pytorch_lattice.classifier.Classifier 4 | -------------------------------------------------------------------------------- /docs/api/constrained_module.md: -------------------------------------------------------------------------------- 1 | # constrained_module 2 | 3 | ::: pytorch_lattice.constrained_module.ConstrainedModule 4 | -------------------------------------------------------------------------------- /docs/api/datasets.md: -------------------------------------------------------------------------------- 1 | # datasets 2 | 3 | ::: pytorch_lattice.datasets 4 | -------------------------------------------------------------------------------- /docs/api/enums.md: -------------------------------------------------------------------------------- 1 | # enums 2 | 3 | ::: pytorch_lattice.enums 4 | -------------------------------------------------------------------------------- /docs/api/feature_config.md: -------------------------------------------------------------------------------- 1 | # feature_config 2 | 3 | ::: pytorch_lattice.feature_config 4 | -------------------------------------------------------------------------------- /docs/api/layers.md: -------------------------------------------------------------------------------- 1 | # layers 2 | 3 | ::: pytorch_lattice.layers.CategoricalCalibrator 4 | 5 | ::: pytorch_lattice.layers.Lattice 6 | 7 | ::: pytorch_lattice.layers.Linear 8 | 9 | ::: pytorch_lattice.layers.NumericalCalibrator 10 | -------------------------------------------------------------------------------- /docs/api/model_configs.md: -------------------------------------------------------------------------------- 1 | # model_configs 2 | 3 | ::: pytorch_lattice.model_configs 4 | -------------------------------------------------------------------------------- /docs/api/models.md: -------------------------------------------------------------------------------- 1 | # models 2 | 3 | ::: pytorch_lattice.models.CalibratedLattice 4 | 5 | ::: pytorch_lattice.models.CalibratedLinear 6 | 7 | ::: pytorch_lattice.models.features.CategoricalFeature 8 | 9 | ::: pytorch_lattice.models.features.NumericalFeature 10 | -------------------------------------------------------------------------------- /docs/api/plots.md: -------------------------------------------------------------------------------- 1 | # plots 2 | 3 | ::: pytorch_lattice.plots 4 | -------------------------------------------------------------------------------- /docs/api/utils.md: -------------------------------------------------------------------------------- 1 | # utils 2 | 3 | ::: pytorch_lattice.utils.data 4 | 5 | ::: pytorch_lattice.utils.models 6 | -------------------------------------------------------------------------------- /docs/concepts/calibrators.md: -------------------------------------------------------------------------------- 1 | # Calibrators 2 | 3 | Calibrators are one of the core concepts of the PyTorch Lattice library. 
The library currently implements two types of calibrators: 4 | 5 | - [`CategoricalCalibrator`](../api/layers.md#pytorch_lattice.layers.CategoricalCalibrator): calibrates a categorical value through a mapping from a category to a learned value. 6 | - [`NumericalCalibrator`](../api/layers.md#pytorch_lattice.layers.NumericalCalibrator): calibrates a numerical value through a learned piece-wise linear function. 7 | 8 | Categorical Calibrator | Numerical Calibrator 9 | :------------------------------:|:----------------------------------------: 10 | ![](../img/thal_calibrator.png) | ![](../img/hours_per_week_calibrator.png) 11 | 12 | ## Feature Calibrators 13 | 14 | In a [calibrated model](model_types.md), the first layer is the calibration layer that calibrates each feature using a calibrator that's learned per feature. 15 | 16 | There are three primary benefits to using feature calibrators: 17 | 18 | - Automated Feature Pre-Processing. Rather than relying on the practitioner to determine how to best transform each feature, feature calibrators learn the best transformations from the data. 19 | - Additional Interpretability. Plotting calibrators as bar/line charts helps visualize how the model understands each feature. For example, if two input values for a feature have the same calibrated value, then the model considers those two input values equivalent with respect to the prediction. 20 | - [Shape Constraints](shape_constraints.md). Calibrators can be constrained to guarantee certain expected input/output behavior. For example, you might want a monotonicity constraint on a feature for square footage to ensure that increasing square footage always increases predicted price. Or perhaps you want a concavity constraint such that increasing a feature for price first increases and then decreases predicted sales. 21 | 22 | ## Output Calibration 23 | 24 | You can also use a `NumericalCalibrator` as the final layer for a model, which is called output calibration. This can provide additional flexibility to the overall model function. 25 | 26 | Furthermore, you can use an output calibrator for post-training distribution matching to calibrate your model to a new distribution without retraining the rest of the model. 27 | 28 | -------------------------------------------------------------------------------- /docs/concepts/classifier.md: -------------------------------------------------------------------------------- 1 | # Classifier 2 | 3 | The [`Classifier`](../api/classifier.md) class is a high-level wrapper around the calibrated modeling functionality to make it extremely easy to fit a calibrated model to a classification task. The class uses declarative configuration and automatically handles the data preparation, feature configuration, model creation, and model training necessary for properly training a calibrated model. 4 | 5 | ## Initialization 6 | 7 | The only required parameter for creating a classifier is the list of features to use: 8 | 9 | ```py 10 | clf = pyl.Classifier(["list", "of", "features"]) 11 | ``` 12 | 13 | You do not need to include all of the features present in your dataset. When you specify only a subset of the features, the classifier will automatically handle selecting only those features for training. 14 | 15 | ## Fitting 16 | 17 | Fitting the classifier to your data is as simple as calling `fit(...)`: 18 | 19 | ```py 20 | clf.fit(X, y) 21 | ``` 22 | 23 | You can further specify hyperparameters used for fitting, such as `epochs`, `batch_size`, and `learning_rate`.
Just pass the values in as parameters: 24 | 25 | ```py 26 | clf.fit(X, y, epochs=100, batch_size=512, learning_rate=1e-4) 27 | ``` 28 | 29 | When you call `fit`, the classifier will train a new model, overwriting any previously trained model. If you want to run a hyperparameter optimization job to find the best setting of hyperparameters, you can first extract the trained model before calling `fit` again: 30 | 31 | ```py 32 | models = [] 33 | for epochs, batch_size, learning_rate in hyperparameters: 34 | clf.fit(X, y, epochs=epochs, batch_size=batch_size, learning_rate=learning_rate) 35 | models.append(clf.model) 36 | ``` 37 | 38 | The benefit of extracting the model is that you can reuse the same classifier configuration; however, you can also always create a new classifier for each setting instead: 39 | 40 | ```py 41 | clfs = [] 42 | for epochs, batch_size, learning_rate in hyperparameters: 43 | clf = pyl.Classifier(X.columns).fit( 44 | X, y, epochs=epochs, batch_size=batch_size, learning_rate=learning_rate 45 | ) 46 | clfs.append(clf) 47 | ``` 48 | 49 | ## Generate Predictions 50 | 51 | You can generate predictions using the `predict(...)` function: 52 | 53 | ```py 54 | probabilities = clf.predict(X) 55 | logits = clf.predict(X, logits=True) 56 | ``` 57 | 58 | Just make sure that the input `pd.DataFrame` contains all of the features the classifier is expecting. 59 | 60 | ## Model Configuration 61 | 62 | To configure the type of calibrated model to use for the classifier, you can additionally provide a model configuration during initialization: 63 | 64 | ```py 65 | model_config = pyl.model_configs.LinearConfig(use_bias=False) 66 | clf = pyl.Classifier(["list", "of", "features"], model_config) 67 | ``` 68 | 69 | See [Model Types](model_types.md) for more information on the supported model types and [model_configs](../api/model_configs.md) for more information on configuring these models in a classifier. 70 | 71 | ## Feature Configuration 72 | 73 | When you first initialize a classifier, all features will be initialized using default values. You can further specify configuration options for features by retrieving the feature's configuration from the classifier and calling the corresponding function to set that option: 74 | 75 | ```py 76 | clf.configure("feature").monotonicity("increasing").num_keypoints(10) 77 | ``` 78 | 79 | See [feature_configs](../api/feature_config.md) for all of the available configuration options. 80 | 81 | ## Categorical Features 82 | 83 | If the value type for a feature in the dataset is not numerical (e.g. string), the classifier will automatically handle the feature as categorical, using all unique categories present in the dataset as the categories for the calibrator. 84 | 85 | If you want the classifier to handle a discrete numerical value as a categorical feature, simply convert the values to strings: 86 | 87 | ```py 88 | X["categorical_feature"] = X["categorical_feature"].astype(str) 89 | ``` 90 | 91 | Additionally, you can specify a list of categories to use as a configuration option: 92 | 93 | ```py 94 | clf.configure("categorical_feature").categories(["list", "of", "categories"]) 95 | ``` 96 | 97 | Any category in the dataset that is not present in the configured category list will be lumped together into a missing category bucket, which will also have a learned calibration. This can be particularly useful if there are categories in your dataset that appear in very few examples.
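As a rough sketch of how these options can be combined, you might configure only the frequently occurring categories so that rare values share the learned missing-category bucket (the feature name and count threshold below are hypothetical):

```py
# Hypothetical sketch: keep only categories that appear at least 100 times;
# all other values fall into the learned missing-category bucket.
counts = X["categorical_feature"].value_counts()
frequent_categories = counts[counts >= 100].index.tolist()
clf.configure("categorical_feature").categories(frequent_categories)
```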
98 | 99 | ## Saving & Loading 100 | 101 | The `Classifier` class also provides easy save/load functionality so that you can save your classifiers and load them as necessary to generate predictions: 102 | 103 | ```py 104 | clf.save("path/to/dir") 105 | loaded_clf = pyl.Classifier.load("path/to/dir") 106 | ``` 107 | -------------------------------------------------------------------------------- /docs/concepts/model_types.md: -------------------------------------------------------------------------------- 1 | # Model Types 2 | 3 | The PyTorch Lattice library currently supports two types of calibrated modeling: 4 | 5 | - [`CalibratedLinear`](../api/models.md#pytorch_lattice.models.CalibratedLinear): a calibrated linear model combines calibrated features using a standard [linear](../api/layers.md#pytorch_lattice.layers.Linear) layer, optionally followed by an output calibrator. 6 | 7 | - [`CalibratedLattice`](../api/models.md#pytorch_lattice.models.CalibratedLattice): a calibrated lattice model combines calibrated features using a [lattice](../api/layers.md#pytorch_lattice.layers.Lattice) layer, optionally followed by an output calibrator. The lattice layer can learn higher-order feature interactions, which can help increase model flexibility and thereby performance on more complex prediction tasks. 8 | -------------------------------------------------------------------------------- /docs/concepts/plotting.md: -------------------------------------------------------------------------------- 1 | # Plotting 2 | 3 | The `plots` module provides useful plotting utility functions for visualizing calibrated models. 4 | 5 | ## Feature Calibrators 6 | 7 | For any calibrated model, you can plot feature calibrators. The plotting utility will automatically determine the feature type and generate the corresponding calibrator visualization: 8 | 9 | ```py 10 | pyl.plots.calibrator(clf.model, "feature") 11 | ``` 12 | 13 | Categorical Calibrator | Numerical Calibrator 14 | :------------------------------:|:----------------------------------------: 15 | ![](../img/thal_calibrator.png) | ![](../img/hours_per_week_calibrator.png) 16 | 17 | 18 | The `calibrator(...)` function expects a calibrated model as the first argument so that you can use these functions even if you train a calibrated model manually without the `Classifier` class. 19 | 20 | ## Linear Coefficients 21 | 22 | For calibrated linear models, you can also plot the linear coefficients as a bar chart to better understand how the model is combining calibrated feature values: 23 | 24 | ```py 25 | pyl.plots.linear_coefficients(clf.model) 26 | ``` 27 | 28 | ![](../img/linear_coefficients.png) 29 | -------------------------------------------------------------------------------- /docs/concepts/shape_constraints.md: -------------------------------------------------------------------------------- 1 | # Shape Constraints 2 | 3 | Shape constraints play a crucial role in making calibrated models interpretable by allowing users to impose specific behavioral rules on their machine learning models. These constraints help to reduce – or even eliminate – the impact of noise and inherent biases contained in the data. 4 | 5 | [`Monotonicity`](../api/enums.md#pytorch_lattice.enums.Monotonicity) constraints ensure that the relationship between an input feature and the output prediction consistently increases or decreases. Let's consider our house price prediction task once more. 
A monotonic constraint on the square footage feature would guarantee that increasing the size of the property increases the predicted price. This makes sense. 6 | 7 | Unimodality constraints (coming soon) create a single peak in the model's output, ensuring that there is only one optimal value for a given input feature. For example, a feature for price used when predicting sales volume may be unimodal since lower prices generally lead to higher sales, but prices that are too low may indicate low quality with one single optimal price. 8 | 9 | Convexity/Concavity constraints (coming soon) ensure that the given feature's value has a convex/concave relationship with the model's output. Looking again at the feature for price for predicting sales volume, it may be that there is a range of optimal prices and not one single optimal price, which would instead be a concavity constraint. 10 | 11 | Trust constraints (coming soon) define the relative importance of input features depending on other features. For instance, a trust constraint can ensure that a model predicting product sales relies more on the star rating (1-5) when the number of reviews is higher, which forces the model's predictions to better align with real-world expectations and rules. 12 | 13 | Dominance constraints (coming soon) are intended to embed that a dominant feature is more important than a weak feature. For example, you might want to constrain a model predicting click-through-rate (CTR) for a specific web link to be more sensitive to past CTR for that web link than the average CTR for the whole website. 14 | 15 | Together, these shape constraints help create machine learning models that are both interpretable and trustworthy. 16 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | ## Setting Up Development Environment 4 | 5 | First, [install pyenv](https://github.com/pyenv/pyenv#installation) so you can run the code under all of the supported environments. Also make sure to [install pyenv-virtualenv](https://github.com/pyenv/pyenv-virtualenv#installation) so you can create python environments with the correct versions. 6 | 7 | To install a specific version of python, you can run e.g. `pyenv install 3.10.9`. You can then create a virtual environment to run and test code locally during development by running the following code from the base directory: 8 | 9 | ```sh 10 | pyenv virtualenv {python_version} env-name 11 | pyenv activate env-name 12 | pip install poetry 13 | poetry install 14 | ``` 15 | 16 | If you'd prefer, you can also use conda to manage your python versions and environments. For installing conda, see their [installation guide](https://conda.io/projects/conda/en/latest/user-guide/install/index.html). 17 | 18 | The following code is an example of how to set up such an environment: 19 | 20 | ```sh 21 | conda create -n env-name pip poetry python={python_version} 22 | conda activate env-name 23 | poetry install 24 | ``` 25 | 26 | Make sure to replace `{python_version}` in the above snippets with the version you want the environment to use (e.g. 3.10.9) and name the environment accordingly (e.g. env-name-3.10). 27 | 28 | ## Development Workflow 29 | 30 | 1. Search through existing [GitHub Issues](https://github.com/ControlAI/pytorch-lattice/issues) to see if what you want to work on has already been added. 31 | 32 | - If not, please create a new issue. 
This will help to reduce duplicated work. 33 | 34 | 2. For first-time contributors, visit [https://github.com/ControlAI/pytorch-lattice](https://github.com/ControlAI/pytorch-lattice) and "Fork" the repository (see the button in the top right corner). 35 | 36 | - You'll need to set up [SSH authentication](https://docs.github.com/en/authentication/connecting-to-github-with-ssh). 37 | 38 | - Clone the forked project and point it to the main project: 39 | 40 | ```shell 41 | git clone https://github.com//pytorch-lattice.git 42 | git remote add upstream https://github.com/ControlAI/pytorch-lattice.git 43 | ``` 44 | 45 | 3. Development. 46 | 47 | - Make sure you are in sync with the main repo: 48 | 49 | ```shell 50 | git checkout dev 51 | git pull upstream dev 52 | ``` 53 | 54 | - Create a `git` feature branch with a meaningful name where you will add your contributions. 55 | 56 | ```shell 57 | git checkout -b meaningful-branch-name 58 | ``` 59 | 60 | - Start coding! commit your changes locally as you work: 61 | 62 | ```shell 63 | git add pytorch-lattice/modified_file.py tests/test_modified_file.py 64 | git commit -m "feat: specific description of changes contained in commit" 65 | ``` 66 | 67 | - Format your code! 68 | 69 | ```shell 70 | poetry run ruff format . 71 | ``` 72 | 73 | - Lint and test your code! From the base directory, run: 74 | 75 | ```shell 76 | poetry run ruff check . 77 | poetry run mypy . 78 | ``` 79 | 80 | 4. Contributions are submitted through [GitHub Pull Requests](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests) 81 | 82 | - When you are ready to submit your contribution for review, push your branch: 83 | 84 | ```shell 85 | git push origin meaningful-branch-name 86 | ``` 87 | 88 | - Open the printed URL to open a PR. Make sure to fill in a detailed title and description. Submit your PR for review. 89 | 90 | - Link the issue you selected or created under "Development" 91 | 92 | - We will review your contribution and add any comments to the PR. Commit any updates you make in response to comments and push them to the branch (they will be automatically included in the PR) 93 | 94 | ### Pull Requests 95 | 96 | Please conform to the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification for all PR titles and commits. 97 | 98 | ## Formatting & Linting 99 | 100 | In an effort to keep the codebase clean and easy to work with, we use `ruff` for formatting and both `ruff` and `mypy` for linting. Before sending any PR for review, make sure to run both `ruff` and `mypy`. 101 | 102 | If you are using VS Code, then install the extensions in `.vscode/extensions.json` and the workspace settings should automatically run `ruff` formatting on save and show `ruff` and `mypy` errors. 
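For convenience, the full local check sequence (assuming the dev dependencies are installed via `poetry install --with dev`) looks roughly like this:

```shell
poetry run ruff format .
poetry run ruff check .
poetry run mypy .
poetry run pytest
```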
103 | -------------------------------------------------------------------------------- /docs/help.md: -------------------------------------------------------------------------------- 1 | # How to help PyTorch Lattice 2 | 3 | ## Star PyTorch Lattice on GitHub 4 | 5 | ⭐️ You can "star" PyTorch Lattice on [GitHub](https://github.com/ControlAI/pytorch-lattice) ⭐️ 6 | 7 | ## Connect with the author 8 | 9 | - Follow me on GitHub 10 | 11 | - See other related Open Source projects that might help you with machine learning 12 | 13 | - Follow me on [Twitter/X](https://twitter.com/WilliamBakst) 14 | 15 | - Tell me how you use lattice models 16 | - Hear about new announcements or releases 17 | 18 | - Connect with me on [LinkedIn](https://www.linkedin.com/in/wbakst/) 19 | 20 | - Give me any feedback about packages or suggestions 21 | 22 | ## Post about PyTorch Lattice 23 | 24 | - Twitter, Reddit, Hackernews, LinkedIn, and others. 25 | 26 | - We love to hear about how PyTorch Lattice has helped you and in which project/company you are using it. 27 | 28 | ## Help Others 29 | 30 | We are a kind and welcoming community that encourages you to help others with their questions on GitHub Issues / Discussions. 31 | 32 | - Guide for asking questions 33 | - First, search through issues and discussions to see if others have faced similar issues 34 | - Be as specific as possible, add minimal reproducible example 35 | - List out things you have tried, errors, etc 36 | - Close the issue if your question has been successfully answered 37 | - Guide for answering questions 38 | - Understand the question, ask clarifying questions 39 | - If there is sample code, reproduce the issue with code given by original poster 40 | - Give them solution or possibly an alternative that might be better than what original poster is trying to do 41 | - Ask original poster to close the issue 42 | 43 | ## Review Pull Requests 44 | 45 | You are encouraged to review any pull requests. 
Here is a guideline on how to review a pull request: 46 | 47 | - Understand the problem the pull request is trying to solve 48 | - Ask clarification questions to determine whether the pull request belongs in the package 49 | - Check the code, run it locally, see if it solves the problem described by the pull request 50 | - Add a comment with screenshots or accompanying code to verify that you have tested it 51 | - Check for tests 52 | - Request the original poster to add tests if they do not exist 53 | - Check that tests fail before the PR and succeed after 54 | - This will greatly speed up the review process for a PR and will ultimately make SOTAI a better package 55 | -------------------------------------------------------------------------------- /docs/img/dnn_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willbakst/pytorch-lattice/d444dba62d1c74708e0dc23f6d60c97785df46b6/docs/img/dnn_diagram.png -------------------------------------------------------------------------------- /docs/img/hours_per_week_calibrator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willbakst/pytorch-lattice/d444dba62d1c74708e0dc23f6d60c97785df46b6/docs/img/hours_per_week_calibrator.png -------------------------------------------------------------------------------- /docs/img/linear_coefficients.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willbakst/pytorch-lattice/d444dba62d1c74708e0dc23f6d60c97785df46b6/docs/img/linear_coefficients.png -------------------------------------------------------------------------------- /docs/img/occupation_calibrator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willbakst/pytorch-lattice/d444dba62d1c74708e0dc23f6d60c97785df46b6/docs/img/occupation_calibrator.png -------------------------------------------------------------------------------- /docs/img/thal_calibrator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willbakst/pytorch-lattice/d444dba62d1c74708e0dc23f6d60c97785df46b6/docs/img/thal_calibrator.png -------------------------------------------------------------------------------- /docs/walkthroughs/uci_adult_income.md: -------------------------------------------------------------------------------- 1 | # UCI Adult Income 2 | 3 | For this walkthrough you are going to fit a `Classifier` to the UCI Adult Income dataset to predict whether or not a given input makes more or less than $50K. 4 | 5 | ## Install and import packages 6 | 7 | On top of PyTorch Lattice, you will be using `scikit-learn` for calculating metrics to evaluate your classifiers. First, make sure you have these packages installed: 8 | 9 | ```shell 10 | $ pip install pytorch-lattice scikit-learn 11 | ``` 12 | 13 | Next, import these packages in your script: 14 | 15 | ```py 16 | from sklearn import metrics 17 | 18 | import pytorch_lattice as pyl 19 | ``` 20 | 21 | ## Load the UCI Adult Income dataset 22 | 23 | The `Classifier` expects a `pandas.DataFrame` containing the training data and a `numpy.ndarray` containing the labels for each example. 
You can use the PyTorch Lattice datasets module to load the data in this form: 24 | 25 | ```py 26 | X, y = pyl.datasets.adult() 27 | ``` 28 | 29 | ## Create and configure a `Classifier` 30 | 31 | Next, you'll want to create a `Classifier` that you can use to fit a calibrated model to the data. When creating a `Classifier`, the only required field is the list of features to use. For this guide, you are only going to use four features: `age`, `education_num`, `occupation`, and `hours_per_week`. 32 | 33 | By default, the calibrated model type used will be a calibrated linear model. If you'd like to further configure the model, you can provide a `model_config` with the attributes of your choice. For this guide, you will train a `CalibratedLattice` model. 34 | 35 | ```py 36 | model_config = pyl.model_configs.LatticeConfig() 37 | clf = pyl.Classifier( 38 | ["age", "education_num", "occupation", "hours_per_week"], model_config 39 | ) 40 | ``` 41 | 42 | ## Configure features 43 | 44 | One of the primary benefits of using PyTorch Lattice models is the ability to easily add shape constraints that guarantee certain real-world expectations. For example, it would make sense that someone with a higher level of education would be more likely to make more than $50K compared to someone with a lower level of education, all else being equal. Similarly, you might expect someone who works more hours per week to be more likely to make more than $50K compared to someone who works fewer hours per week, all else being equal. 45 | 46 | ```py 47 | clf.configure("education_num").monotonicity("increasing") 48 | clf.configure("hours_per_week").monotonicity("increasing") 49 | ``` 50 | 51 | PyTorch Lattice makes it very easy to ensure that the model behaves as expected even after training on data. By setting the `monotonicity` field to `increasing` for both `education_num` and `hours_per_week`, you will guarantee that increasing either of those features will increase the prediction (so long as no other feature values change). 52 | 53 | Of course, the classifier is still trained on data, so the only thing guaranteed is the relationship. How much increasing `education_num` or `hours_per_week` will increase the prediction will still be learned from training. Ultimately you are reducing the risk of unknown outcomes while still learning from data. 54 | 55 | ## Fit the classifier to the data 56 | 57 | Now that you've configured the classifier, fitting it to the data is easy: 58 | 59 | ```py 60 | clf.fit(X, y, batch_size=1024) # using a larger batch size for faster training 61 | ``` 62 | 63 | There are additional training configuration options, such as the number of epochs, the learning rate, and the batch size, which you can set as parameters of the `fit` function. 64 | 65 | ## Generate predictions and evaluate AUC 66 | 67 | Once you've fit your classifier, it's easy to generate predictions: 68 | 69 | ```py 70 | preds = clf.predict(X) 71 | ``` 72 | 73 | You can then use `scikit-learn` to calculate `AUC` to evaluate the predictive quality of your classifier: 74 | 75 | ```py 76 | fpr, tpr, _ = metrics.roc_curve(y, preds) 77 | print(f"Train AUC: {metrics.auc(fpr, tpr)}") 78 | # Train AUC: 0.8165439029459205 79 | ``` 80 | 81 | ## Plot calibrators for analysis 82 | 83 | Plotting the calibrators for each feature can help visualize how the model understands the features.
First, try plotting the calibrator for `occupation`: 84 | 85 | ```py 86 | pyl.plots.calibrator(clf.model, "occupation") 87 | ``` 88 | 89 | ![Occupation Calibrator](../img/occupation_calibrator.png) 90 | 91 | You can see here how each category for `occupation` gets calibrated before going into the lattice layer of the model, which shows us how the model understands each category relative to the others. For example, you can see that the model thinks that `Sales` and `Armed-Forces` have a similar likelihood of making more than $50K. 92 | 93 | Interestingly, plotting the calibrator for `hours_per_week` shows that there's a flat region starting around ~52 hours. This suggests that the `hours_per_week` feature may not actually be monotonically increasing, in which case you might consider training a new classifier where you do not constrain this feature. 94 | 95 | ![Hours Per Week Calibrator](../img/hours_per_week_calibrator.png) 96 | 97 | When setting constraints, there are two things to keep in mind: 98 | 99 | 1. Do you want to guarantee the constrained behavior regardless of performance? In this case, setting the constraint can make sure that model behavior matches your expectations on unseen examples, which is especially useful when using a model to make decisions. 100 | 2. Does the model have better performance on a validation dataset if you remove the constraint? It is important to remember that adding constraints to a feature may result in worse performance on the training set but actually result in better performance on the validation set. This is because the constraint helps the model to better handle unseen examples if the constraint should in fact be set. 101 | -------------------------------------------------------------------------------- /docs/why.md: -------------------------------------------------------------------------------- 1 | # Why use PyTorch Lattice? 2 | 3 | Many current state-of-the-art machine learning models are built using a black-box modeling approach, which means training an opaque but flexible model like a deep neural net (DNN) on a dataset of training examples. While we know the structure of DNNs, it is precisely this structure that makes them black-box models. 4 | 5 | ![DNN Diagram](img/dnn_diagram.png) 6 | 7 | Every feature goes through a series of fully-connected layers, meaning every node is a function of every feature. Each node becomes a function through training, but the purpose of any individual node is hidden from the user -- only the model knows. How are we supposed to understand or trust a model's predictions if we don't know what any function within the larger system is doing? 8 | 9 | Furthermore, black-box models are 100% reliant on the training data. This means that if a model is producing funky predictions, the solution is to either (1) find more training data and re-train the model, or (2) discover a new model structure tailored to the given task. Neither option is a great choice for the majority of data scientists and machine learning practitioners -- unless they work at a large tech company with the resources dedicated to making such solutions possible -- since gathering and cleaning data and discovering new model structures are not only inherently difficult tasks but also time and cost intensive.
10 | 11 | But every data scientist and machine learning practitioner, even those at large tech companies, has run into issues where their model behaves unexpectedly in the wild because the training data is too different from live examples, especially since real-world data distributions change frequently. 12 | 13 | So, what can we do to reduce the risk of unknown outcomes? 14 | 15 | ## Understanding The _Why_ Of A Model's Predictions 16 | 17 | Without the why, a model's prediction is opaque and difficult to trust, even if it's correct. That's why understanding the why is such an active area of research. It's worth noting that there is a distinction between the two approaches in this field that have seen success: Explainability vs. Interpretability. 18 | 19 | Explainability focuses on explaining a black-box model's predictions, which is a top-down approach. The benefit of this approach is that resulting methods apply to black-box models, meaning that they apply to any machine learning model. The current state-of-the-art explainability technology is Shapley values, which we can use to determine the importance of each feature for any machine learning model. Perhaps we train a model to predict the price of a house and learn that zip code is the most important feature. The downside of this approach is the limitations inherent to a black-box structure. While this knowledge of importance provides general insight into how the model is making predictions, does it really explain anything? How a particular zip code impacts a model's predictions is still a mystery. 20 | 21 | The sad truth is that Explainability often only points to common sense results -- not illuminating insights. 22 | 23 | Interpretability is instead a bottom-up approach focused on providing transparency through calibrated models structured specifically with illuminating insights and control in mind. The downside to this approach is that it requires more input from the user; however, this input is invaluable for the model to understand the system in the way we expect. The benefit of this approach is that the resulting models are much easier to understand. For example, we can analyze the way a calibrated model handles individual features by charting the corresponding feature calibration layer -- the layer specific to calibrated models that calibrates the initial input for later layers. For a categorical feature like zip code, the result will be a bar chart that shows us the calibrated values for each zip code. So now we know not only that zip code is the most important feature, but also the relative impact each zip code has on the predicted price. This is a far more granular understanding. 24 | 25 | ## Consistently Predicting How A Model Will Behave On Unseen Examples 26 | 27 | Okay, so we have a way to dig deeper and understand the why. That's great. But we have to remember that why is an afterthought -- for example, something went wrong and we want to know why. Of course, the why is incredibly useful and plays a big part in understanding how a model will behave, but it does not provide any guarantees on future behavior. Trust comes from the ability to predict behavior, so the more consistently one can predict a model's behavior, the more one can trust that model. 28 | 29 | Consider using a machine learning model to predict credit score where one of the input features is how late someone is on their payments. 
The behavior we want and expect is for the model to produce a better credit score for someone who pays their bills sooner, all else being equal. We can imagine that it would be unfair to penalize someone for paying their bills sooner. Even if we can understand the why, with black-box modeling we have no such guarantee. 30 | 31 | With calibrated modeling, we can constrain the shape of the model's function to provide certain guarantees. We call these shape constraints, and they come in many different flavors. The feature for payment lateness is a perfect fit for a decreasing monotonicity shape constraint. A decreasing monotonic functions's output always increases if the input decreases, and vice-versa. We want the model (function) to produce a higher credit score (output) if payment lateness (input) decreases, all else being equal. With PyTorch Lattice, just configure this behavior before training and it will be guaranteed. Pretty cool, right? 32 | 33 | Now, if you're not here to predict credit scores, you might be wondering how shape constraints can help you. What about the age of a house when predicting its price? Time since last repair for predictive maintenance? Number of similar items purchased when trying to predict a sale? 34 | 35 | Hopefully it's clear that many real-world features operate under these or similar constraints because they are part of real-world systems with certain fundamental rules. While we can hope that black-box models learn what we would expect from data, the ability to guarantee the behaviors we expect enables a higher level of trust in a model and eliminates toil. 36 | -------------------------------------------------------------------------------- /examples/basic_classifier.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | 3 | import pytorch_lattice as pyl 4 | 5 | # Load Data 6 | X, y = pyl.datasets.heart() 7 | 8 | # Fit Classifier 9 | clf = pyl.Classifier(X.columns).fit(X, y) 10 | 11 | # Generate Predictions 12 | preds = clf.predict(X) 13 | 14 | # Calculate AUC 15 | fpr, tpr, _ = metrics.roc_curve(y, preds) 16 | print(f"Train AUC: {metrics.auc(fpr, tpr)}") 17 | -------------------------------------------------------------------------------- /examples/calibrated_linear_classification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn import metrics 4 | 5 | import pytorch_lattice as pyl 6 | from pytorch_lattice.models.features import CategoricalFeature, NumericalFeature 7 | 8 | # Load Data 9 | X, y = pyl.datasets.heart() 10 | 11 | # Configure Features 12 | features = [ 13 | NumericalFeature("age", data=np.array(X["age"].values), monotonicity="increasing"), 14 | NumericalFeature( 15 | "trestbps", data=np.array(X["trestbps"].values), monotonicity="increasing" 16 | ), 17 | NumericalFeature( 18 | "chol", data=np.array(X["chol"].values), monotonicity="increasing" 19 | ), 20 | CategoricalFeature("ca", categories=X["ca"].unique().tolist()), 21 | CategoricalFeature( 22 | "thal", 23 | categories=["fixed", "normal", "reversible"], 24 | monotonicity_pairs=[("normal", "fixed"), ("normal", "reversible")], 25 | ), 26 | ] 27 | 28 | # Create Model (you can replace this with CalibratedLattice to train a lattice model) 29 | model = pyl.models.CalibratedLinear(features) 30 | 31 | # Fit Model 32 | optimizer = torch.optim.Adam(model.parameters(recurse=True), lr=1e-3) 33 | loss_fn = torch.nn.BCEWithLogitsLoss() 34 | dataset 
= pyl.utils.data.Dataset(X, y, features) 35 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True) 36 | for epoch in range(100): 37 | for inputs, labels in dataloader: 38 | optimizer.zero_grad() 39 | loss_fn(model(inputs), labels).backward() 40 | optimizer.step() 41 | model.apply_constraints() 42 | 43 | # Generate Predictions 44 | model.eval() 45 | X_copy = X[["age", "trestbps", "chol", "ca", "thal"]].copy() 46 | pyl.utils.data.prepare_features(X_copy, features) 47 | X_tensor = torch.tensor(X_copy.values).double() 48 | with torch.no_grad(): 49 | preds = model(X_tensor).numpy() 50 | 51 | # Calculate AUC 52 | fpr, tpr, _ = metrics.roc_curve(y, preds) 53 | print(f"Train AUC: {metrics.auc(fpr, tpr)}") 54 | -------------------------------------------------------------------------------- /examples/classifier_plotting.py: -------------------------------------------------------------------------------- 1 | import pytorch_lattice as pyl 2 | 3 | # Load Data 4 | X, y = pyl.datasets.heart() 5 | 6 | # Fit Classifier 7 | clf = pyl.Classifier(X.columns).fit(X, y) 8 | 9 | # Plot Calibrator For Feature "thal" 10 | pyl.plots.calibrator(clf.model, "thal") 11 | 12 | # Plot Linear Coefficients For Calibrated Linear Model 13 | pyl.plots.linear_coefficients(clf.model) 14 | -------------------------------------------------------------------------------- /examples/classifier_with_monotonicity.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | 3 | import pytorch_lattice as pyl 4 | 5 | # Load Data 6 | X, y = pyl.datasets.adult() 7 | 8 | # Create Classifier On Subset Of Features 9 | clf = pyl.Classifier( 10 | [ 11 | "age", 12 | "workclass", 13 | "education_num", 14 | "capital_gain", 15 | "capital_loss", 16 | "hours_per_week", 17 | ] 18 | ) 19 | 20 | # Configure Feature Monotonicity 21 | clf.configure("education_num").monotonicity("increasing") 22 | clf.configure("capital_gain").monotonicity("increasing") 23 | 24 | # Fit Classifier 25 | clf.fit(X, y) 26 | 27 | # Generate Predictions 28 | preds = clf.predict(X) 29 | 30 | # Calculate AUC 31 | fpr, tpr, _ = metrics.roc_curve(y, preds) 32 | print(f"Train AUC: {metrics.auc(fpr, tpr)}") 33 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | copyright: Copyright © 2023 William Bakst. 
2 | site_name: PyTorch Lattice 3 | site_url: https://willbakst.github.io/pytorch-lattice 4 | repo_name: pytorch-lattice 5 | repo_url: https://github.com/willbakst/pytorch-lattice/ 6 | theme: 7 | name: material 8 | icon: 9 | repo: fontawesome/brands/github 10 | features: 11 | - content.code.annotation 12 | - content.code.copy 13 | - content.code.link 14 | - navigation.footer 15 | - navigation.sections 16 | - navigation.tabs 17 | - navigation.top 18 | - search.highlight 19 | - search.suggest 20 | - toc.follow 21 | language: en 22 | palette: 23 | - scheme: default 24 | toggle: 25 | icon: material/brightness-7 26 | name: Switch to dark mode 27 | primary: indigo 28 | accent: indigo 29 | - scheme: slate 30 | toggle: 31 | icon: material/brightness-4 32 | name: Switch to light mode 33 | primary: indigo 34 | accent: indigo 35 | font: 36 | text: Roboto 37 | code: Roboto Mono 38 | 39 | extra: 40 | social: 41 | - icon: fontawesome/brands/github-alt 42 | link: https://github.com/willbakst 43 | - icon: fontawesome/brands/twitter 44 | link: https://twitter.com/WilliamBakst 45 | - icon: fontawesome/brands/linkedin 46 | link: https://www.linkedin.com/in/wbakst/ 47 | analytics: 48 | provider: google 49 | property: G-Q8WNH5KD11 50 | 51 | markdown_extensions: 52 | - pymdownx.highlight: 53 | anchor_linenums: true 54 | - pymdownx.inlinehilite 55 | - pymdownx.snippets 56 | - admonition 57 | - pymdownx.arithmatex: 58 | generic: true 59 | - footnotes 60 | - pymdownx.details 61 | - pymdownx.superfences 62 | - pymdownx.mark 63 | - attr_list 64 | - pymdownx.emoji: 65 | emoji_index: !!python/name:material.extensions.emoji.twemoji 66 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 67 | 68 | plugins: 69 | - search 70 | - mkdocstrings: 71 | handlers: 72 | python: 73 | options: 74 | show_root_heading: true 75 | docstring_style: google 76 | 77 | nav: 78 | - Get Started: 79 | - Welcome to PyTorch Lattice: "README.md" 80 | - Why use PyTorch Lattice: "why.md" 81 | - Contributing: "contributing.md" 82 | - How to help: "help.md" 83 | - Concepts: 84 | - Classifier: "concepts/classifier.md" 85 | - Calibrators: "concepts/calibrators.md" 86 | - Shape Constraints: "concepts/shape_constraints.md" 87 | - Model Types: "concepts/model_types.md" 88 | - Plotting: "concepts/plotting.md" 89 | - Walkthroughs: 90 | - UCI Adult Income: "walkthroughs/uci_adult_income.md" 91 | - API Reference: 92 | - layers: "api/layers.md" 93 | - models: "api/models.md" 94 | - utils: "api/utils.md" 95 | - classifier: "api/classifier.md" 96 | - constrained_module: "api/constrained_module.md" 97 | - datasets: "api/datasets.md" 98 | - enums: "api/enums.md" 99 | - feature_config: "api/feature_config.md" 100 | - model_configs: "api/model_configs.md" 101 | - plots: "api/plots.md" 102 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pytorch-lattice" 3 | version = "0.2.0" 4 | description = "A PyTorch Implementation Of Lattice Modeling Techniques" 5 | license = "MIT" 6 | authors = ["William Bakst "] 7 | readme = "docs/README.md" 8 | packages = [{ include = "pytorch_lattice" }] 9 | repository = "https://github.com/willbakst/pytorch-lattice" 10 | 11 | [tool.poetry.dependencies] 12 | python = ">=3.9, <=3.12" 13 | matplotlib = "^3.7.1" 14 | numpy = "^1.23.5" 15 | pandas = "^2.0.2" 16 | pydantic = "^2.0.2" 17 | torch = ">=2.0.0, !=2.0.1, !=2.1.0" 18 | tqdm = "^4.65.0" 19 | 20 | 
[tool.poetry.group.dev.dependencies] 21 | mkdocs = "^1.4.3" 22 | mkdocs-material = "^9.1.18" 23 | mkdocstrings = "^0.22.0" 24 | mkdocstrings-python = "^1.1.2" 25 | mypy = "^1.6.1" 26 | pytest = "^7.4.0" 27 | ruff = "^0.1.5" 28 | pandas-stubs = "^2.1.1.230928" 29 | types-tqdm = "^4.66.0.4" 30 | 31 | [tool.ruff] 32 | exclude = [ 33 | ".bzr", 34 | ".direnv", 35 | ".eggs", 36 | ".git", 37 | ".git-rewrite", 38 | ".hg", 39 | ".mypy_cache", 40 | ".nox", 41 | ".pants.d", 42 | ".pytype", 43 | ".ruff_cache", 44 | ".svn", 45 | ".tox", 46 | ".venv", 47 | "__pypackages__", 48 | "_build", 49 | "buck-out", 50 | "build", 51 | "dist", 52 | "node_modules", 53 | "venv", 54 | ] 55 | line-length = 88 56 | target-version = "py38" 57 | 58 | [tool.ruff.lint.per-file-ignores] 59 | "__init__.py" = ["F401"] 60 | 61 | [tool.ruff.lint] 62 | select = ["E4", "E7", "E9", "F"] 63 | ignore = [] 64 | fixable = ["ALL"] 65 | unfixable = [] 66 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 67 | 68 | [tool.ruff.format] 69 | quote-style = "double" 70 | indent-style = "space" 71 | skip-magic-trailing-comma = false 72 | line-ending = "auto" 73 | 74 | [tool.mypy] 75 | exclude = ["examples", "venv"] 76 | 77 | [build-system] 78 | requires = ["poetry-core"] 79 | build-backend = "poetry.core.masonry.api" 80 | -------------------------------------------------------------------------------- /pytorch_lattice/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch Lattice""" 2 | 3 | # This version must always be one version ahead of the current release, so it 4 | # matches the current state of development, which will always be ahead of the 5 | # current release. Use Semantic Versioning. 6 | __version__ = "0.2.0" 7 | 8 | from . import datasets, plots, utils 9 | from .classifier import Classifier 10 | from .enums import ( 11 | CategoricalCalibratorInit, 12 | InputKeypointsInit, 13 | InputKeypointsType, 14 | Interpolation, 15 | LatticeInit, 16 | Monotonicity, 17 | NumericalCalibratorInit, 18 | ) 19 | from .feature_config import FeatureConfig 20 | from .model_configs import LatticeConfig, LinearConfig 21 | -------------------------------------------------------------------------------- /pytorch_lattice/classifier.py: -------------------------------------------------------------------------------- 1 | """A class for training classifiers on tabular data using calibrated models.""" 2 | from __future__ import annotations 3 | 4 | import os 5 | import pickle 6 | from typing import Optional, Union 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from tqdm import trange 12 | 13 | from .feature_config import FeatureConfig 14 | from .model_configs import LatticeConfig, LinearConfig 15 | from .models import ( 16 | CalibratedLattice, 17 | CalibratedLinear, 18 | ) 19 | from .models.features import CategoricalFeature, NumericalFeature 20 | from .utils.data import Dataset, prepare_features 21 | 22 | MISSING_INPUT_VALUE = -123456789 23 | 24 | 25 | class Classifier: 26 | """A classifier for tabular data using calibrated models. 27 | 28 | Note: currently only handles binary classification targets. 29 | 30 | Example: 31 | ```python 32 | X, y = pyl.datasets.heart() 33 | clf = pyl.Classifier(X.columns) 34 | clf.configure("age").num_keypoints(10).monotonicity("increasing") 35 | clf.fit(X, y) 36 | ``` 37 | 38 | Attributes: 39 | features: A dict mapping feature names to their corresponding `FeatureConfig` 40 | instances. 
41 | model_config: The model configuration to use for fitting the classifier. 42 | self.model: The fitted model. This will be `None` until `fit` is called. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | feature_names: list[str], 48 | model_config: Optional[Union[LinearConfig, LatticeConfig]] = None, 49 | ): 50 | """Initializes an instance of `Classifier`.""" 51 | self.features = { 52 | feature_name: FeatureConfig(name=feature_name) 53 | for feature_name in feature_names 54 | } 55 | self.model_config = model_config if model_config is not None else LinearConfig() 56 | self.model: Optional[Union[CalibratedLinear, CalibratedLattice]] = None 57 | 58 | def configure(self, feature_name: str): 59 | """Returns a `FeatureConfig` object for the given feature name.""" 60 | return self.features[feature_name] 61 | 62 | def fit( 63 | self, 64 | X: pd.DataFrame, 65 | y: np.ndarray, 66 | epochs: int = 50, 67 | batch_size: int = 64, 68 | learning_rate: float = 1e-3, 69 | shuffle: bool = False, 70 | ) -> Classifier: 71 | """Returns this classifier after fitting a model to the given data. 72 | 73 | Note that calling this function will overwrite any existing model and train a 74 | new model from scratch. 75 | 76 | Args: 77 | X: A `pd.DataFrame` containing the features for the training data. 78 | y: A `np.ndarray` containing the labels for the training data. 79 | epochs: The number of epochs for which to fit the classifier. 80 | batch_size: The batch size to use for fitting. 81 | learning_rate: The learning rate to use for fitting the model. 82 | shuffle: Whether to shuffle the data before fitting. 83 | """ 84 | model = self._create_model(X) 85 | optimizer = torch.optim.Adam(model.parameters(recurse=True), lr=learning_rate) 86 | loss_fn = torch.nn.BCEWithLogitsLoss() 87 | 88 | dataset = Dataset(X, y, model.features) 89 | dataloader = torch.utils.data.DataLoader( 90 | dataset, batch_size=batch_size, shuffle=shuffle 91 | ) 92 | for _ in trange(epochs, desc="Training Progress"): 93 | for inputs, labels in dataloader: 94 | optimizer.zero_grad() 95 | outputs = model(inputs) 96 | loss = loss_fn(outputs, labels) 97 | loss.backward() 98 | optimizer.step() 99 | model.apply_constraints() 100 | 101 | self.model = model 102 | return self 103 | 104 | def predict(self, X: pd.DataFrame, logits: bool = False) -> np.ndarray: 105 | """Returns predictions for the given data. 106 | 107 | Args: 108 | X: a `pd.DataFrame` containing to data for which to generate predictions. 109 | logits: If `True`, returns the logits of the predictions. Otherwise, returns 110 | probabilities. 111 | """ 112 | if self.model is None: 113 | raise RuntimeError("Cannot predict before fitting the model.") 114 | 115 | self.model.eval() 116 | X_copy = X[[feature.feature_name for feature in self.model.features]].copy() 117 | prepare_features(X_copy, self.model.features) 118 | X_tensor = torch.tensor(X_copy.values).double() 119 | with torch.no_grad(): 120 | preds = self.model(X_tensor).numpy() 121 | 122 | if logits: 123 | return preds 124 | else: 125 | return 1.0 / (1.0 + np.exp(-preds)) 126 | 127 | def save(self, filepath: str): 128 | """Saves the classifier to the specified path. 129 | 130 | Args: 131 | filepath: The directory where the classifier will be saved. If the directory 132 | does not exist, this function will attempt to create it. If the 133 | directory already exists, this function will overwrite any existing 134 | content with conflicting filenames. 
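Example: the sketch below is illustrative only; `X` and `y` stand in for any training data, and the directory path is hypothetical.
```python
clf = Classifier(X.columns).fit(X, y)
clf.save("path/to/classifier")
loaded_clf = Classifier.load("path/to/classifier")
```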
135 | """ 136 | if not os.path.exists(filepath): 137 | os.makedirs(filepath) 138 | with open(os.path.join(filepath, "clf_attrs.pkl"), "wb") as f: 139 | attrs = {key: self.__dict__[key] for key in ["features", "model_config"]} 140 | pickle.dump(attrs, f) 141 | if self.model is not None: 142 | model_path = os.path.join(filepath, "model.pt") 143 | torch.save(self.model, model_path) 144 | 145 | @classmethod 146 | def load(cls, filepath: str) -> Classifier: 147 | """Loads a `Classifier` from the specified path. 148 | 149 | Args: 150 | filepath: The filepath from which to load the classifier. The filepath 151 | should point to the filepath used in the `save` method when saving the 152 | classifier. 153 | 154 | Returns: 155 | A `Classifier` instance. 156 | """ 157 | with open(os.path.join(filepath, "clf_attrs.pkl"), "rb") as f: 158 | attrs = pickle.load(f) 159 | 160 | clf = cls([]) 161 | clf.__dict__.update(attrs) 162 | 163 | model_path = os.path.join(filepath, "model.pt") 164 | if os.path.exists(model_path): 165 | clf.model = torch.load(model_path) 166 | 167 | return clf 168 | 169 | ################################################################################ 170 | ############################## PRIVATE METHODS ################################# 171 | ################################################################################ 172 | 173 | def _create_model( 174 | self, X: pd.DataFrame 175 | ) -> Union[CalibratedLinear, CalibratedLattice]: 176 | """Returns a model based on `self.features` and `self.model_config`.""" 177 | features: list[Union[CategoricalFeature, NumericalFeature]] = [] 178 | 179 | for feature_name, feature in self.features.items(): 180 | if X[feature_name].dtype.kind in ["S", "O", "b"]: # string, object, bool 181 | if feature._categories is None: 182 | categories = X[feature_name].unique().tolist() 183 | feature.categories(categories) 184 | else: 185 | categories = feature._categories 186 | if feature._monotonicity is not None and isinstance( 187 | feature._monotonicity, list 188 | ): 189 | monotonicity_pairs = feature._monotonicity 190 | else: 191 | monotonicity_pairs = None 192 | features.append( 193 | CategoricalFeature( 194 | feature_name=feature_name, 195 | categories=categories, 196 | missing_input_value=MISSING_INPUT_VALUE, 197 | monotonicity_pairs=monotonicity_pairs, 198 | lattice_size=feature._lattice_size, 199 | ) 200 | ) 201 | else: # numerical feature 202 | if feature._monotonicity is not None and isinstance( 203 | feature._monotonicity, str 204 | ): 205 | monotonicity = feature._monotonicity 206 | else: 207 | monotonicity = None 208 | features.append( 209 | NumericalFeature( 210 | feature_name=feature_name, 211 | data=np.array(X[feature_name].values), 212 | num_keypoints=feature._num_keypoints, 213 | input_keypoints_init=feature._input_keypoints_init, 214 | missing_input_value=MISSING_INPUT_VALUE, 215 | monotonicity=monotonicity, 216 | projection_iterations=feature._projection_iterations, 217 | lattice_size=feature._lattice_size, 218 | ) 219 | ) 220 | 221 | if isinstance(self.model_config, LinearConfig): 222 | return CalibratedLinear( 223 | features, 224 | self.model_config.output_min, 225 | self.model_config.output_max, 226 | self.model_config.use_bias, 227 | self.model_config.output_calibration_num_keypoints, 228 | ) 229 | else: 230 | return CalibratedLattice( 231 | features, 232 | True, 233 | self.model_config.output_min, 234 | self.model_config.output_max, 235 | self.model_config.kernel_init, 236 | self.model_config.interpolation, 237 | 
self.model_config.output_calibration_num_keypoints, 238 | ) 239 | -------------------------------------------------------------------------------- /pytorch_lattice/constrained_module.py: -------------------------------------------------------------------------------- 1 | """A virtual base class for constrained modules.""" 2 | from abc import abstractmethod 3 | from typing import Union 4 | 5 | import torch 6 | 7 | 8 | class ConstrainedModule(torch.nn.Module): 9 | """A base class for constrained implementations of a `torch.nn.Module`.""" 10 | 11 | @torch.no_grad() 12 | @abstractmethod 13 | def apply_constraints(self) -> None: 14 | """Applies defined constraints to the module.""" 15 | raise NotImplementedError() 16 | 17 | @torch.no_grad() 18 | @abstractmethod 19 | def assert_constraints( 20 | self, eps: float = 1e-6 21 | ) -> Union[list[str], dict[str, list[str]]]: 22 | """Asserts that the module satisfied specified constraints.""" 23 | raise NotImplementedError() 24 | -------------------------------------------------------------------------------- /pytorch_lattice/datasets.py: -------------------------------------------------------------------------------- 1 | """Functions for loading datasets to use with the PyTorch Lattice package.""" 2 | import numpy as np 3 | import pandas as pd 4 | 5 | 6 | def heart() -> tuple[pd.DataFrame, np.ndarray]: 7 | """Loads the UCI Statlog (Heart) dataset. 8 | 9 | The UCI Statlog (Heart) dataset is a classification dataset with 303 rows and 14 10 | columns. The target is binary, with 0 indicating no heart disease and 1 indicating 11 | heart disease. The features are a mix of categorical and numerical features. For 12 | more information, see https://archive.ics.uci.edu/ml/datasets/heart+Disease. 13 | 14 | Returns: 15 | A tuple `(X, y)` of the features and target. 16 | """ 17 | X = pd.read_csv( 18 | "https://raw.githubusercontent.com/ControlAI/datasets/main/heart.csv" 19 | ) 20 | y = np.array(X.pop("target").values) 21 | return X, y 22 | 23 | 24 | def adult() -> tuple[pd.DataFrame, np.ndarray]: 25 | """Loads the UCI Adult Income dataset. 26 | 27 | The UCI Adult Income dataset is a classification dataset with 48,842 rows and 14 28 | columns. The target is binary, with 0 indicating an income of less than $50k and 1 29 | indicating an income of at least $50k. The features are a mix of categorical and 30 | numerical features. For more information, see 31 | https://archive.ics.uci.edu/dataset/2/adult 32 | 33 | Returns: 34 | A tuple `(X, y)` of the features and target. 35 | """ 36 | X = pd.read_csv( 37 | "https://raw.githubusercontent.com/ControlAI/datasets/main/adult.csv" 38 | ) 39 | y = np.array(X.pop("label").values) 40 | return X, y 41 | -------------------------------------------------------------------------------- /pytorch_lattice/enums.py: -------------------------------------------------------------------------------- 1 | """Enum Classes for PyTorch Lattice.""" 2 | from enum import Enum, EnumMeta 3 | from typing import Any 4 | 5 | 6 | class _Metaclass(EnumMeta): 7 | """Base `EnumMeta` subclass for accessing enum members directly.""" 8 | 9 | def __getattribute__(cls, __name: str) -> Any: 10 | value = super().__getattribute__(__name) 11 | if isinstance(value, Enum): 12 | value = value.value 13 | return value 14 | 15 | 16 | class _Enum(str, Enum, metaclass=_Metaclass): 17 | """Base Enum Class.""" 18 | 19 | 20 | class InputKeypointsInit(_Enum): 21 | """Type of initialization to use for NumericalCalibrator input keypoints. 
22 | 23 | - QUANTILES: initialize the input keypoints such that each segment will see the same 24 | number of examples. 25 | - UNIFORM: initialize the input keypoints uniformly spaced in the feature range. 26 | """ 27 | 28 | QUANTILES = "quantiles" 29 | UNIFORM = "uniform" 30 | 31 | 32 | class InputKeypointsType(_Enum): 33 | """The type of input keypoints to use. 34 | 35 | - FIXED: the input keypoints will be fixed during initialization. 36 | - LEARNED: the interior keypoints will learn through training to best fit the 37 | piecewise linear function. 38 | """ 39 | 40 | FIXED = "fixed" 41 | LEARNED = "learned" 42 | 43 | 44 | class NumericalCalibratorInit(_Enum): 45 | """Type of kernel initialization to use for NumericalCalibrator. 46 | 47 | - EQUAL_HEIGHTS: initialize the kernel such that all segments have the same height. 48 | - EQUAL_SLOPES: initialize the kernel such that all segments have the same slope. 49 | """ 50 | 51 | EQUAL_HEIGHTS = "equal_heights" 52 | EQUAL_SLOPES = "equal_slopes" 53 | 54 | 55 | class CategoricalCalibratorInit(_Enum): 56 | """Type of kernel initialization to use for CategoricalCalibrator. 57 | 58 | - UNIFORM: initialize the kernel with uniformly distributed values. The sample range 59 | will be [`output_min`, `output_max`] if both are provided. 60 | - CONSTANT: initialize the kernel with a constant value for all categories. This 61 | value will be `(output_min + output_max) / 2` if both are provided. 62 | """ 63 | 64 | UNIFORM = "uniform" 65 | CONSTANT = "constant" 66 | 67 | 68 | class Monotonicity(_Enum): 69 | """Type of monotonicity constraint. 70 | 71 | - INCREASING: increasing monotonicity i.e. increasing input increases output. 72 | - DECREASING: decreasing monotonicity i.e. increasing input decreases output. 73 | """ 74 | 75 | INCREASING = "increasing" 76 | DECREASING = "decreasing" 77 | 78 | 79 | class Interpolation(_Enum): 80 | """Enum for interpolation method of lattice. 81 | 82 | - HYPERCUBE: n-dimensional hypercube surrounding input point(s). 83 | - SIMPLEX: uses only one of the n! simplices in the n-dim hypercube. 84 | """ 85 | 86 | HYPERCUBE = "hypercube" 87 | SIMPLEX = "simplex" 88 | 89 | 90 | class LatticeInit(_Enum): 91 | """Type of kernel initialization to use for CategoricalCalibrator. 92 | 93 | - LINEAR: initialize the kernel with weights represented by a linear function, 94 | conforming to monotonicity and unimodality constraints. 95 | - RANDOM_MONOTONIC: initialize the kernel with a uniformly random sampled 96 | lattice layer weight tensor, conforming to monotonicity and unimodality 97 | constraints. 98 | """ 99 | 100 | LINEAR = "linear" 101 | RANDOM_MONOTONIC = "random_monotonic" 102 | -------------------------------------------------------------------------------- /pytorch_lattice/feature_config.py: -------------------------------------------------------------------------------- 1 | """Configuration objects for the PyTorch Lattice library.""" 2 | from __future__ import annotations 3 | 4 | from typing import Optional, Union 5 | 6 | from .enums import InputKeypointsInit, InputKeypointsType, Monotonicity 7 | 8 | 9 | class FeatureConfig: 10 | """A configuration object for a feature in a calibrated model. 11 | 12 | This configuration object handles both numerical and categorical features. If the 13 | `categeories` attribute is `None`, then this feature will be handled as numerical. 14 | Otherwise, it will be handled as categorical. 
15 | 16 | Example: 17 | ```python 18 | fc = FeatureConfig(name="feature_name").num_keypoints(10).monotonicity("increasing") 19 | ``` 20 | 21 | Attributes: 22 | name: The name of the feature. 23 | """ 24 | 25 | def __init__(self, name: str): 26 | """Initializes an instance of `FeatureConfig` with default values.""" 27 | self.name = name 28 | self._categories: Optional[list[str]] = None 29 | self._num_keypoints: int = 5 30 | self._input_keypoints_init: InputKeypointsInit = InputKeypointsInit.QUANTILES 31 | self._input_keypoints_type: InputKeypointsType = InputKeypointsType.FIXED 32 | self._monotonicity: Optional[Union[Monotonicity, list[tuple[str, str]]]] = None 33 | self._projection_iterations: int = 8 34 | self._lattice_size: int = 2 # only used in lattice models 35 | 36 | def categories(self, categories: list[str]) -> FeatureConfig: 37 | """Sets the categories for a categorical feature.""" 38 | self._categories = categories 39 | return self 40 | 41 | def num_keypoints(self, num_keypoints: int) -> FeatureConfig: 42 | """Sets the categories for a categorical feature.""" 43 | self._num_keypoints = num_keypoints 44 | return self 45 | 46 | def input_keypoints_init( 47 | self, input_keypoints_init: InputKeypointsInit 48 | ) -> FeatureConfig: 49 | """Sets the input keypoints initialization method for a numerical calibrator.""" 50 | self._input_keypoints_init = input_keypoints_init 51 | return self 52 | 53 | def input_keypoints_type( 54 | self, input_keypoints_type: InputKeypointsType 55 | ) -> FeatureConfig: 56 | """Sets the input keypoints type for a numerical calibrator.""" 57 | self._input_keypoints_type = input_keypoints_type 58 | return self 59 | 60 | def monotonicity( 61 | self, monotonicity: Optional[Union[Monotonicity, list[tuple[str, str]]]] 62 | ) -> FeatureConfig: 63 | """Sets the monotonicity constraint for a feature.""" 64 | self._monotonicity = monotonicity 65 | return self 66 | 67 | def projection_iterations(self, projection_iterations: int) -> FeatureConfig: 68 | """Sets the number of projection iterations for a numerical calibrator.""" 69 | self._projection_iterations = projection_iterations 70 | return self 71 | 72 | def lattice_size(self, lattice_size: int) -> FeatureConfig: 73 | """Sets the lattice size for a feature.""" 74 | self._lattice_size = lattice_size 75 | return self 76 | -------------------------------------------------------------------------------- /pytorch_lattice/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """Layers used in calibrated modeling implemented as `torch.nn.Module`.""" 2 | from .categorical_calibrator import CategoricalCalibrator 3 | from .lattice import Lattice 4 | from .linear import Linear 5 | from .numerical_calibrator import NumericalCalibrator 6 | from .rtl import RTL 7 | -------------------------------------------------------------------------------- /pytorch_lattice/layers/categorical_calibrator.py: -------------------------------------------------------------------------------- 1 | """Categorical calibration module. 2 | 3 | PyTorch implementation of the categorical calibration module. This module takes in a 4 | single-dimensional input of categories represented as indices and transforms it by 5 | mapping a given category to its learned output value. 
6 | """
7 | from collections import defaultdict
8 | from typing import Optional
9 |
10 | import torch
11 | from graphlib import CycleError, TopologicalSorter
12 |
13 | from ..constrained_module import ConstrainedModule
14 | from ..enums import CategoricalCalibratorInit
15 |
16 |
17 | class CategoricalCalibrator(ConstrainedModule):
18 |     """A categorical calibrator.
19 |
20 |     This module takes an input of shape `(batch_size, 1)` and calibrates it by mapping a
21 |     given category to its learned output value. The output will have the same shape as
22 |     the input.
23 |
24 |     Attributes:
25 |         All: `__init__` arguments.
26 |         kernel: `torch.nn.Parameter` that stores the categorical mapping weights.
27 |
28 |     Example:
29 |     ```python
30 |     inputs = torch.tensor(...)  # shape: (batch_size, 1)
31 |     calibrator = CategoricalCalibrator(
32 |         num_categories=5,
33 |         missing_input_value=-1,
34 |         output_min=0.0,
35 |         output_max=1.0,
36 |         monotonicity_pairs=[(0, 1), (1, 2)],
37 |         kernel_init=CategoricalCalibratorInit.UNIFORM,
38 |     )
39 |     outputs = calibrator(inputs)
40 |     ```
41 |     """
42 |
43 |     def __init__(
44 |         self,
45 |         num_categories: int,
46 |         missing_input_value: Optional[float] = None,
47 |         output_min: Optional[float] = None,
48 |         output_max: Optional[float] = None,
49 |         monotonicity_pairs: Optional[list[tuple[int, int]]] = None,
50 |         kernel_init: CategoricalCalibratorInit = CategoricalCalibratorInit.UNIFORM,
51 |     ) -> None:
52 |         """Initializes an instance of `CategoricalCalibrator`.
53 |
54 |         Args:
55 |             num_categories: The number of known categories.
56 |             missing_input_value: If provided, the calibrator will learn to map all
57 |                 instances of this missing input value to a learned output value just
58 |                 the same as it does for known categories. Note that `num_categories`
59 |                 will be one greater to include this missing category.
60 |             output_min: Minimum output value. If `None`, the minimum output value will
61 |                 be unbounded.
62 |             output_max: Maximum output value. If `None`, the maximum output value will
63 |                 be unbounded.
64 |             monotonicity_pairs: List of pairs of indices `(i,j)` indicating that the
65 |                 calibrator output for index `j` should be greater than or equal to that
66 |                 of index `i`.
67 |             kernel_init: Initialization scheme to use for the kernel.
68 |
69 |         Raises:
70 |             ValueError: If `monotonicity_pairs` is cyclic.
71 |             ValueError: If `kernel_init` is invalid.
72 | """ 73 | super().__init__() 74 | 75 | self.num_categories = ( 76 | num_categories + 1 if missing_input_value is not None else num_categories 77 | ) 78 | self.missing_input_value = missing_input_value 79 | self.output_min = output_min 80 | self.output_max = output_max 81 | self.monotonicity_pairs = monotonicity_pairs 82 | if monotonicity_pairs: 83 | self._monotonicity_graph = defaultdict(list) 84 | self._reverse_monotonicity_graph = defaultdict(list) 85 | for i, j in monotonicity_pairs: 86 | self._monotonicity_graph[i].append(j) 87 | self._reverse_monotonicity_graph[j].append(i) 88 | try: 89 | self._monotonically_sorted_indices = [ 90 | *TopologicalSorter(self._reverse_monotonicity_graph).static_order() 91 | ] 92 | except CycleError as exc: 93 | raise ValueError("monotonicity_pairs is cyclic") from exc 94 | self.kernel_init = kernel_init 95 | 96 | self.kernel = torch.nn.Parameter(torch.Tensor(self.num_categories, 1).double()) 97 | if kernel_init == CategoricalCalibratorInit.CONSTANT: 98 | if output_min is not None and output_max is not None: 99 | init_value = (output_min + output_max) / 2 100 | elif output_min is not None: 101 | init_value = output_min 102 | elif output_max is not None: 103 | init_value = output_max 104 | else: 105 | init_value = 0.0 106 | torch.nn.init.constant_(self.kernel, init_value) 107 | elif kernel_init == CategoricalCalibratorInit.UNIFORM: 108 | if output_min is not None and output_max is not None: 109 | low, high = output_min, output_max 110 | elif output_min is None and output_max is not None: 111 | low, high = output_max - 0.05, output_max 112 | elif output_min is not None and output_max is None: 113 | low, high = output_min, output_min + 0.05 114 | else: 115 | low, high = -0.05, 0.05 116 | torch.nn.init.uniform_(self.kernel, low, high) 117 | else: 118 | raise ValueError(f"Unknown kernel init: {kernel_init}") 119 | 120 | def forward(self, x: torch.Tensor) -> torch.Tensor: 121 | """Calibrates categorical inputs through a learned mapping. 122 | 123 | Args: 124 | x: The input tensor of category indices of shape `(batch_size, 1)`. 125 | 126 | Returns: 127 | torch.Tensor of shape `(batch_size, 1)` containing calibrated input values. 128 | """ 129 | if self.missing_input_value is not None: 130 | missing_category_tensor = torch.zeros_like(x) + (self.num_categories - 1) 131 | x = torch.where(x == self.missing_input_value, missing_category_tensor, x) 132 | # TODO: test if using torch.gather is faster than one-hot matmul. 133 | one_hot = torch.nn.functional.one_hot( 134 | torch.squeeze(x, -1).long(), num_classes=self.num_categories 135 | ).double() 136 | return torch.mm(one_hot, self.kernel) 137 | 138 | @torch.no_grad() 139 | def apply_constraints(self) -> None: 140 | """Projects kernel into desired constraints.""" 141 | projected_kernel_data = self.kernel.data 142 | if self.monotonicity_pairs: 143 | projected_kernel_data = self._approximately_project_monotonicity_pairs( 144 | projected_kernel_data 145 | ) 146 | if self.output_min is not None: 147 | projected_kernel_data = torch.maximum( 148 | projected_kernel_data, torch.tensor(self.output_min) 149 | ) 150 | if self.output_max is not None: 151 | projected_kernel_data = torch.minimum( 152 | projected_kernel_data, torch.tensor(self.output_max) 153 | ) 154 | self.kernel.data = projected_kernel_data 155 | 156 | @torch.no_grad() 157 | def assert_constraints(self, eps: float = 1e-6) -> list[str]: 158 | """Asserts that layer satisfies specified constraints. 
159 | 160 | This checks that weights at the indexes of monotonicity pairs are in the correct 161 | order and that the output is within bounds. 162 | 163 | Args: 164 | eps: the margin of error allowed 165 | 166 | Returns: 167 | A list of messages describing violated constraints including violated 168 | monotonicity pairs. If no constraints violated, the list will be empty. 169 | """ 170 | weights = torch.squeeze(self.kernel.data) 171 | messages = [] 172 | 173 | if self.output_max is not None and torch.max(weights) > self.output_max + eps: 174 | messages.append("Max weight greater than output_max.") 175 | if self.output_min is not None and torch.min(weights) < self.output_min - eps: 176 | messages.append("Min weight less than output_min.") 177 | 178 | if self.monotonicity_pairs: 179 | violation_indices = [ 180 | (i, j) 181 | for (i, j) in self.monotonicity_pairs 182 | if weights[i] - weights[j] > eps 183 | ] 184 | if violation_indices: 185 | messages.append(f"Monotonicity violated at: {str(violation_indices)}.") 186 | 187 | return messages 188 | 189 | @torch.no_grad() 190 | def keypoints_inputs(self) -> torch.Tensor: 191 | """Returns a tensor of keypoint inputs (category indices).""" 192 | if self.missing_input_value is not None: 193 | return torch.cat( 194 | ( 195 | torch.arange(self.num_categories - 1), 196 | torch.tensor([self.missing_input_value]), 197 | ), 198 | 0, 199 | ) 200 | return torch.arange(self.num_categories) 201 | 202 | @torch.no_grad() 203 | def keypoints_outputs(self) -> torch.Tensor: 204 | """Returns a tensor of keypoint outputs.""" 205 | return torch.squeeze(self.kernel.data, -1) 206 | 207 | ################################################################################ 208 | ############################## PRIVATE METHODS ################################# 209 | ################################################################################ 210 | 211 | def _approximately_project_monotonicity_pairs(self, kernel_data) -> torch.Tensor: 212 | """Projects kernel such that the monotonicity pairs are satisfied. 213 | 214 | The kernel will be projected such that `kernel_data[i] <= kernel_data[j]`. This 215 | results in calibrated outputs that adhere to the desired constraints. 216 | 217 | Args: 218 | kernel_data: The tensor of shape `(self.num_categories, 1)` to be projected 219 | into the constraints specified by `self.monotonicity pairs`. 220 | 221 | Returns: 222 | Projected kernel data. To prevent the kernel from drifting in one direction, 223 | the data returned is the average of the min/max and max/min projections. 
224 | """ 225 | projected_kernel_data = torch.unbind(kernel_data, 0) 226 | 227 | def project(data, monotonicity_graph, step, minimum): 228 | projected_data = list(data) 229 | sorted_indices = self._monotonically_sorted_indices 230 | if minimum: 231 | sorted_indices = sorted_indices[::-1] 232 | for i in sorted_indices: 233 | if i in monotonicity_graph: 234 | projection = projected_data[i] 235 | for j in monotonicity_graph[i]: 236 | if minimum: 237 | projection = torch.minimum(projection, projected_data[j]) 238 | else: 239 | projection = torch.maximum(projection, projected_data[j]) 240 | if step == 1.0: 241 | projected_data[i] = projection 242 | else: 243 | projected_data[i] = ( 244 | step * projection + (1 - step) * projected_data[i] 245 | ) 246 | return projected_data 247 | 248 | projected_kernel_min_max = project( 249 | projected_kernel_data, self._monotonicity_graph, 0.5, minimum=True 250 | ) 251 | projected_kernel_min_max = project( 252 | projected_kernel_min_max, 253 | self._reverse_monotonicity_graph, 254 | 1.0, 255 | minimum=False, 256 | ) 257 | projected_kernel_min_max = torch.stack(projected_kernel_min_max) 258 | 259 | projected_kernel_max_min = project( 260 | projected_kernel_data, self._reverse_monotonicity_graph, 0.5, minimum=False 261 | ) 262 | projected_kernel_max_min = project( 263 | projected_kernel_max_min, self._monotonicity_graph, 1.0, minimum=True 264 | ) 265 | projected_kernel_max_min = torch.stack(projected_kernel_max_min) 266 | 267 | return (projected_kernel_min_max + projected_kernel_max_min) / 2 268 | -------------------------------------------------------------------------------- /pytorch_lattice/layers/linear.py: -------------------------------------------------------------------------------- 1 | """Linear module for use in calibrated modeling. 2 | 3 | PyTorch implementation of the calibrated linear module. This module takes in a 4 | single-dimensional input and transforms it using a linear transformation and optionally 5 | a bias term. This module supports monotonicity constraints. 6 | """ 7 | from typing import Optional 8 | 9 | import torch 10 | 11 | from ..constrained_module import ConstrainedModule 12 | from ..enums import Monotonicity 13 | 14 | 15 | class Linear(ConstrainedModule): 16 | """A constrained linear module. 17 | 18 | This module takes an input of shape `(batch_size, input_dim)` and applied a linear 19 | transformation. The output will have the same shape as the input. 20 | 21 | Attributes: 22 | All: `__init__` arguments. 23 | kernel: `torch.nn.Parameter` that stores the linear combination weighting. 24 | bias: `torch.nn.Parameter` that stores the bias term. Only available is 25 | `use_bias` is true. 26 | 27 | Example: 28 | ```python 29 | input_dim = 3 30 | inputs = torch.tensor(...) # shape: (batch_size, input_dim) 31 | linear = Linear( 32 | input_dim, 33 | monotonicities=[ 34 | None, 35 | Monotonicity.INCREASING, 36 | Monotonicity.DECREASING 37 | ], 38 | use_bias=False, 39 | weighted_average=True, 40 | ) 41 | outputs = linear(inputs) 42 | ``` 43 | """ 44 | 45 | def __init__( 46 | self, 47 | input_dim: int, 48 | monotonicities: Optional[list[Optional[Monotonicity]]] = None, 49 | use_bias: bool = True, 50 | weighted_average: bool = False, 51 | ) -> None: 52 | """Initializes an instance of `Linear`. 53 | 54 | Args: 55 | input_dim: The number of inputs that will be combined. 56 | monotonicities: If provided, specifies the monotonicity of each input 57 | dimension. 58 | use_bias: Whether to use a bias term for the linear combination. 
59 | weighted_average: Whether to make the output a weighted average i.e. all 60 | coefficients are positive and add up to a total of 1.0. No bias term 61 | will be used, and `use_bias` will be set to false regardless of the 62 | original value. `monotonicities` will also be set to increasing for all 63 | input dimensions to ensure that all coefficients are positive. 64 | 65 | Raises: 66 | ValueError: If monotonicities does not have length input_dim (if provided). 67 | """ 68 | super().__init__() 69 | 70 | self.input_dim = input_dim 71 | if monotonicities and len(monotonicities) != input_dim: 72 | raise ValueError("Monotonicities, if provided, must have length input_dim.") 73 | self.monotonicities = ( 74 | monotonicities 75 | if not weighted_average 76 | else [Monotonicity.INCREASING] * input_dim 77 | ) 78 | self.use_bias = use_bias if not weighted_average else False 79 | self.weighted_average = weighted_average 80 | 81 | self.kernel = torch.nn.Parameter(torch.Tensor(input_dim, 1).double()) 82 | torch.nn.init.constant_(self.kernel, 1.0 / input_dim) 83 | if use_bias: 84 | self.bias = torch.nn.Parameter(torch.Tensor(1).double()) 85 | torch.nn.init.constant_(self.bias, 0.0) 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | """Transforms inputs using a linear combination. 89 | 90 | Args: 91 | x: The input tensor of shape `(batch_size, input_dim)`. 92 | 93 | Returns: 94 | torch.Tensor of shape `(batch_size, 1)` containing transformed input values. 95 | """ 96 | result = torch.mm(x, self.kernel) 97 | if self.use_bias: 98 | result += self.bias 99 | return result 100 | 101 | @torch.no_grad() 102 | def apply_constraints(self) -> None: 103 | """Projects kernel into desired constraints.""" 104 | projected_kernel_data = self.kernel.data 105 | 106 | if self.monotonicities: 107 | if Monotonicity.INCREASING in self.monotonicities: 108 | increasing_mask = torch.tensor( 109 | [ 110 | [0.0] if m == Monotonicity.INCREASING else [1.0] 111 | for m in self.monotonicities 112 | ] 113 | ) 114 | projected_kernel_data = torch.maximum( 115 | projected_kernel_data, projected_kernel_data * increasing_mask 116 | ) 117 | if Monotonicity.DECREASING in self.monotonicities: 118 | decreasing_mask = torch.tensor( 119 | [ 120 | [0.0] if m == Monotonicity.DECREASING else [1.0] 121 | for m in self.monotonicities 122 | ] 123 | ) 124 | projected_kernel_data = torch.minimum( 125 | projected_kernel_data, projected_kernel_data * decreasing_mask 126 | ) 127 | 128 | if self.weighted_average: 129 | norm = torch.norm(projected_kernel_data, 1) 130 | norm = torch.where(norm < 1e-8, 1.0, norm) 131 | projected_kernel_data /= norm 132 | 133 | self.kernel.data = projected_kernel_data 134 | 135 | @torch.no_grad() 136 | def assert_constraints(self, eps: float = 1e-6) -> list[str]: 137 | """Asserts that layer satisfies specified constraints. 138 | 139 | This checks that decreasing monotonicity corresponds to negative weights, 140 | increasing monotonicity corresponds to positive weights, and weights sum to 1 141 | for weighted_average=True. 142 | 143 | Args: 144 | eps: the margin of error allowed 145 | 146 | Returns: 147 | A list of messages describing violated constraints. If no constraints 148 | violated, the list will be empty. 
149 | """ 150 | messages = [] 151 | 152 | if self.weighted_average: 153 | total_weight = torch.sum(self.kernel.data) 154 | if torch.abs(total_weight - 1.0) > eps: 155 | messages.append("Weights do not sum to 1.") 156 | 157 | if self.monotonicities: 158 | monotonicities_constant = torch.tensor( 159 | [ 160 | 1 161 | if m == Monotonicity.INCREASING 162 | else -1 163 | if m == Monotonicity.DECREASING 164 | else 0 165 | for m in self.monotonicities 166 | ], 167 | device=self.kernel.device, 168 | dtype=self.kernel.dtype, 169 | ).view(-1, 1) 170 | 171 | violated_monotonicities = (self.kernel * monotonicities_constant) < -eps 172 | violation_indices = torch.where(violated_monotonicities) 173 | if violation_indices[0].numel() > 0: 174 | messages.append( 175 | f"Monotonicity violated at: {violation_indices[0].tolist()}" 176 | ) 177 | 178 | return messages 179 | -------------------------------------------------------------------------------- /pytorch_lattice/layers/rtl.py: -------------------------------------------------------------------------------- 1 | """A PyTorch module implementing a calibrated modeling layer for Random Tiny Lattices. 2 | 3 | This module implements an ensemble of tiny lattices that each operate on a subset of the 4 | inputs. It utilizes the multi-unit functionality of the Lattice module to better 5 | optimize speed performance by putting feature subsets that have the same constraint 6 | structure into the same Lattice module as multiple units. 7 | """ 8 | import logging 9 | from typing import Optional, Union 10 | 11 | import numpy as np 12 | import torch 13 | 14 | from ..enums import Interpolation, LatticeInit, Monotonicity 15 | from .lattice import Lattice 16 | 17 | 18 | class RTL(torch.nn.Module): 19 | """A module that efficiently implements Random Tiny Lattices. 20 | 21 | This module creates an ensemble of lattices where each lattice in the ensemble takes 22 | as input a subset of the input features. For further efficiency, input subsets with 23 | the same constraint structure all go through the same lattice as multiple units in 24 | parallel. When creating the ensemble structure, features are shuffled and uniformly 25 | repeated if there are more available slots in the ensemble structure than there are 26 | features. 27 | 28 | Attributes: 29 | - All `__init__` arguments. 30 | 31 | Example: 32 | ```python 33 | inputs=torch.tensor(...) # shape: (batch_size, D) 34 | monotonicities = List[Monotonicity...] # len: D 35 | random_tiny_lattices = RTL( 36 | monotonicities, 37 | num_lattices=5 38 | lattice_rank=3, # num_lattices * lattice_rank must be greater than D 39 | ) 40 | output1 = random_tiny_lattices(inputs) 41 | 42 | # You can stack RTL modules based on the previous RTL's output monotonicities. 43 | rtl2 = RTL(random_tiny_lattices.output_monotonicities(), ...) 44 | outputs2 = rtl2(outputs) 45 | ``` 46 | """ 47 | 48 | def __init__( 49 | self, 50 | monotonicities: list[Monotonicity], 51 | num_lattices: int, 52 | lattice_rank: int, 53 | lattice_size: int = 2, 54 | output_min: Optional[float] = None, 55 | output_max: Optional[float] = None, 56 | kernel_init: LatticeInit = LatticeInit.LINEAR, 57 | clip_inputs: bool = True, 58 | interpolation: Interpolation = Interpolation.HYPERCUBE, 59 | average_outputs: bool = False, 60 | random_seed: int = 42, 61 | ) -> None: 62 | """Initializes an instance of 'RTL'. 63 | 64 | Args: 65 | monotonicities: List of `Monotonicity.INCREASING` or `None` 66 | indicating monotonicities of input features, ordered respectively. 
67 | num_lattices: number of lattices in RTL structure. 68 | lattice_rank: number of inputs for each lattice in RTL structure. 69 | output_min: Minimum output of each lattice in RTL. 70 | output_max: Maximum output of each lattice in RTL. 71 | kernel_init: Initialization scheme to use for lattices. 72 | clip_inputs: Whether input should be clipped to the range of each lattice. 73 | interpolation: Interpolation scheme for each lattice in RTL. 74 | average_outputs: Whether to average the outputs of every lattice RTL. 75 | random_seed: seed used for shuffling. 76 | 77 | Raises: 78 | ValueError: If size of RTL, determined by `num_lattices * lattice_rank`, is 79 | too small to support the number of input features. 80 | """ 81 | super().__init__() 82 | 83 | if len(monotonicities) > num_lattices * lattice_rank: 84 | raise ValueError( 85 | f"RTL with {num_lattices}x{lattice_rank}D structure cannot support " 86 | + f"{len(monotonicities)} input features." 87 | ) 88 | self.monotonicities = monotonicities 89 | self.num_lattices = num_lattices 90 | self.lattice_rank = lattice_rank 91 | self.lattice_size = lattice_size 92 | self.output_min = output_min 93 | self.output_max = output_max 94 | self.kernel_init = kernel_init 95 | self.clip_inputs = clip_inputs 96 | self.interpolation = interpolation 97 | self.average_outputs = average_outputs 98 | self.random_seed = random_seed 99 | 100 | rtl_indices = np.array( 101 | [i % len(self.monotonicities) for i in range(num_lattices * lattice_rank)] 102 | ) 103 | np.random.seed(self.random_seed) 104 | np.random.shuffle(rtl_indices) 105 | split_rtl_indices = [list(arr) for arr in np.split(rtl_indices, num_lattices)] 106 | swapped_rtl_indices = self._ensure_unique_sublattices(split_rtl_indices) 107 | monotonicity_groupings = {} 108 | for lattice_indices in swapped_rtl_indices: 109 | monotonic_count = sum( 110 | 1 111 | for idx in lattice_indices 112 | if self.monotonicities[idx] == Monotonicity.INCREASING 113 | ) 114 | if monotonic_count not in monotonicity_groupings: 115 | monotonicity_groupings[monotonic_count] = [lattice_indices] 116 | else: 117 | monotonicity_groupings[monotonic_count].append(lattice_indices) 118 | for monotonic_count, groups in monotonicity_groupings.items(): 119 | for i, lattice_indices in enumerate(groups): 120 | sorted_indices = sorted( 121 | lattice_indices, 122 | key=lambda x: (self.monotonicities[x] is None), 123 | reverse=False, 124 | ) 125 | groups[i] = sorted_indices 126 | 127 | self._lattice_layers = {} 128 | for monotonic_count, groups in monotonicity_groupings.items(): 129 | self._lattice_layers[monotonic_count] = ( 130 | Lattice( 131 | lattice_sizes=[self.lattice_size] * self.lattice_rank, 132 | output_min=self.output_min, 133 | output_max=self.output_max, 134 | kernel_init=self.kernel_init, 135 | monotonicities=[Monotonicity.INCREASING] * monotonic_count 136 | + [None] * (lattice_rank - monotonic_count), 137 | clip_inputs=self.clip_inputs, 138 | interpolation=self.interpolation, 139 | units=len(groups), 140 | ), 141 | groups, 142 | ) 143 | 144 | def forward(self, x: torch.Tensor) -> torch.Tensor: 145 | """Forward method computed by using forward methods of each lattice in ensemble. 146 | 147 | Args: 148 | x: input tensor of feature values with shape `(batch_size, num_features)`. 149 | 150 | Returns: 151 | `torch.Tensor` containing the outputs of each lattice within RTL structure. 152 | If `average_outputs == True`, then all outputs are averaged into a tensor of 153 | shape `(batch_size, 1)`. 
If `average_outputs == False`, shape of tensor is 154 | `(batch_size, num_lattices)`. 155 | """ 156 | forward_results = [] 157 | for _, (lattice, group) in sorted(self._lattice_layers.items()): 158 | if len(group) > 1: 159 | lattice_input = torch.stack([x[:, idx] for idx in group], dim=-2) 160 | else: 161 | lattice_input = x[:, group[0]] 162 | forward_results.append(lattice.forward(lattice_input)) 163 | result = torch.cat(forward_results, dim=-1) 164 | if not self.average_outputs: 165 | return result 166 | result = torch.mean(result, dim=-1, keepdim=True) 167 | 168 | return result 169 | 170 | @torch.no_grad() 171 | def output_monotonicities(self) -> list[Union[Monotonicity, None]]: 172 | """Gives the monotonicities of the outputs of RTL. 173 | 174 | Returns: 175 | List of `Monotonicity` corresponding to each output of the RTL layer, in the 176 | same order as outputs. 177 | """ 178 | monotonicities = [] 179 | for monotonic_count, (lattice, _) in sorted(self._lattice_layers.items()): 180 | if monotonic_count: 181 | monotonicity = Monotonicity.INCREASING 182 | else: 183 | monotonicity = None 184 | for _ in range(lattice.units): 185 | monotonicities.append(monotonicity) 186 | 187 | return monotonicities 188 | 189 | @torch.no_grad() 190 | def apply_constraints(self) -> None: 191 | """Enforces constraints for each lattice in RTL.""" 192 | for lattice, _ in self._lattice_layers.values(): 193 | lattice.apply_constraints() 194 | 195 | @torch.no_grad() 196 | def assert_constraints(self, eps: float = 1e-6) -> list[list[str]]: 197 | """Asserts that each Lattice in RTL satisfies all constraints. 198 | 199 | Args: 200 | eps: allowed constraints violations. 201 | 202 | Returns: 203 | List of lists, each with constraints violations for an individual Lattice. 204 | """ 205 | return list( 206 | lattice.assert_constraints(eps=eps) 207 | for lattice, _ in self._lattice_layers.values() 208 | ) 209 | 210 | @staticmethod 211 | def _ensure_unique_sublattices( 212 | rtl_indices: list[list[int]], 213 | max_swaps: int = 10000, 214 | ) -> list[list[int]]: 215 | """Attempts to ensure every lattice in RTL structure contains unique features. 216 | 217 | Args: 218 | rtl_indices: list of lists where inner lists are groupings of 219 | indices of input features to RTL layer. 220 | max_swaps: maximum number of swaps to perform before giving up. 221 | 222 | Returns: 223 | List of lists where elements between inner lists have been swapped in 224 | an attempt to remove any duplicates from every grouping. 
225 | """ 226 | swaps = 0 227 | num_sublattices = len(rtl_indices) 228 | 229 | def find_swap_candidate(current_index, element): 230 | """Helper function to find the next sublattice not containing element.""" 231 | for offset in range(1, num_sublattices): 232 | candidate_index = (current_index + offset) % num_sublattices 233 | if element not in rtl_indices[candidate_index]: 234 | return candidate_index 235 | return None 236 | 237 | for i, sublattice in enumerate(rtl_indices): 238 | unique_elements = set() 239 | for element in sublattice: 240 | if element in unique_elements: 241 | swap_with = find_swap_candidate(i, element) 242 | if swap_with is not None: 243 | for swap_element in rtl_indices[swap_with]: 244 | if swap_element not in sublattice: 245 | # Perform the swap 246 | idx_to_swap = rtl_indices[swap_with].index(swap_element) 247 | idx_duplicate = sublattice.index(element) 248 | ( 249 | rtl_indices[swap_with][idx_to_swap], 250 | sublattice[idx_duplicate], 251 | ) = element, swap_element 252 | swaps += 1 253 | break 254 | else: 255 | logging.info( 256 | "Some lattices in RTL may use the same feature multiple " 257 | "times." 258 | ) 259 | return rtl_indices 260 | else: 261 | unique_elements.add(element) 262 | if swaps >= max_swaps: 263 | logging.info( 264 | "Some lattices in RTL may use the same feature multiple times." 265 | ) 266 | return rtl_indices 267 | return rtl_indices 268 | -------------------------------------------------------------------------------- /pytorch_lattice/model_configs.py: -------------------------------------------------------------------------------- 1 | """Model configurations classes for PyTorch Calibrated Models.""" 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | from .enums import Interpolation, LatticeInit 6 | 7 | 8 | @dataclass 9 | class _BaseModelConfig: 10 | """Configuration for a calibrated model. 11 | 12 | Attributes: 13 | output_min: The minimum output value for the model. If None, then it will be 14 | assumed that there is no minimum output value. 15 | output_max: The maximum output value for the model. If None, then it will be 16 | assumed that there is no maximum output value. 17 | output_calibration_num_keypoints: The number of keypoints to use for the output 18 | calibrator. If `None`, no output calibration will be used. 19 | """ 20 | 21 | output_min: Optional[float] = None 22 | output_max: Optional[float] = None 23 | output_calibration_num_keypoints: Optional[int] = None 24 | 25 | 26 | @dataclass 27 | class LinearConfig(_BaseModelConfig): 28 | """Configuration for a calibrated linear model. 29 | 30 | Attributes: 31 | All: `_BaseModelConfig` attributes. 32 | use_bias: Whether to use a bias term for the linear combination. 33 | """ 34 | 35 | use_bias: bool = True 36 | 37 | 38 | @dataclass 39 | class LatticeConfig(_BaseModelConfig): 40 | """Configuration for a calibrated lattice model. 41 | 42 | Attributes: 43 | All: `_BaseModelConfig` attributes. 44 | kernel_init: The `LatticeInit` scheme to use to initialize the lattice kernel. 45 | interpolation: The `Interpolation` scheme to use in the lattice. Note that 46 | `HYPERCUBE` has exponential time complexity while `SIMPLEX` has 47 | log-linear time complexity. 
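Example: an illustrative sketch; `X` and `y` stand in for any training data, e.g. loaded via `pyl.datasets.heart()`.
```python
config = LatticeConfig(output_min=0.0, output_max=1.0)
clf = pyl.Classifier(X.columns, model_config=config).fit(X, y)
```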
48 | """ 49 | 50 | kernel_init: LatticeInit = LatticeInit.LINEAR 51 | interpolation: Interpolation = Interpolation.SIMPLEX 52 | -------------------------------------------------------------------------------- /pytorch_lattice/models/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch Calibrated Models to easily implement common calibrated model architectures. 2 | 3 | PyTorch Calibrated Models make it easy to construct common calibrated model 4 | architectures. To construct a PyTorch Calibrated Model, pass a calibrated modeling 5 | config to the corresponding calibrated model. 6 | """ 7 | from .calibrated_lattice import CalibratedLattice 8 | from .calibrated_linear import CalibratedLinear 9 | -------------------------------------------------------------------------------- /pytorch_lattice/models/calibrated_lattice.py: -------------------------------------------------------------------------------- 1 | """Class for easily constructing a calibrated lattice model.""" 2 | from typing import Optional, Union 3 | 4 | import torch 5 | 6 | from ..constrained_module import ConstrainedModule 7 | from ..enums import ( 8 | Interpolation, 9 | LatticeInit, 10 | ) 11 | from ..layers import Lattice 12 | from ..utils.models import ( 13 | calibrate_and_stack, 14 | initialize_feature_calibrators, 15 | initialize_monotonicities, 16 | initialize_output_calibrator, 17 | ) 18 | from .features import CategoricalFeature, NumericalFeature 19 | 20 | 21 | class CalibratedLattice(ConstrainedModule): 22 | """PyTorch Calibrated Lattice Model. 23 | 24 | Creates a `torch.nn.Module` representing a calibrated lattice model, which will be 25 | constructed using the provided model configuration. Note that the model inputs 26 | should match the order in which they are defined in the `feature_configs`. 27 | 28 | Attributes: 29 | All: `__init__` arguments. 30 | calibrators: A dictionary that maps feature names to their calibrators. 31 | lattice: The `Lattice` layer of the model. 32 | output_calibrator: The output `NumericalCalibrator` calibration layer. This 33 | will be `None` if no output calibration is desired. 34 | 35 | Example: 36 | 37 | ```python 38 | feature_configs = [...] 39 | calibrated_model = CalibratedLattice(feature_configs, ...) 40 | 41 | loss_fn = torch.nn.MSELoss() 42 | optimizer = torch.optim.Adam(calibrated_model.parameters(recurse=True), lr=1e-1) 43 | 44 | dataset = pyl.utils.data.Dataset(...) 45 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True) 46 | for epoch in range(100): 47 | for inputs, labels in dataloader: 48 | optimizer.zero_grad() 49 | outputs = calibrated_model(inputs) 50 | loss = loss_fn(outputs, labels) 51 | loss.backward() 52 | optimizer.step() 53 | calibrated_model.apply_constraints() 54 | ``` 55 | """ 56 | 57 | def __init__( 58 | self, 59 | features: list[Union[NumericalFeature, CategoricalFeature]], 60 | clip_inputs: bool = True, 61 | output_min: Optional[float] = None, 62 | output_max: Optional[float] = None, 63 | kernel_init: LatticeInit = LatticeInit.LINEAR, 64 | interpolation: Interpolation = Interpolation.HYPERCUBE, 65 | output_calibration_num_keypoints: Optional[int] = None, 66 | ) -> None: 67 | """Initializes an instance of `CalibratedLattice`. 68 | 69 | Args: 70 | features: A list of numerical and/or categorical feature configs. 71 | clip_inputs: Whether to restrict inputs to the bounds of lattice. 72 | output_min: The minimum output value for the model. 
If `None`, the minimum 73 | output value will be unbounded. 74 | output_max: The maximum output value for the model. If `None`, the maximum 75 | output value will be unbounded. 76 | kernel_init: the method of initializing kernel weights. If otherwise 77 | unspecified, will default to `LatticeInit.LINEAR`. 78 | interpolation: the method of interpolation in the lattice's forward pass. 79 | If otherwise unspecified, will default to `Interpolation.HYPERCUBE`. 80 | output_calibration_num_keypoints: The number of keypoints to use for the 81 | output calibrator. If `None`, no output calibration will be used. 82 | 83 | Raises: 84 | ValueError: If any feature configs are not `NUMERICAL` or `CATEGORICAL`. 85 | """ 86 | super().__init__() 87 | 88 | self.features = features 89 | self.clip_inputs = clip_inputs 90 | self.output_min = output_min 91 | self.output_max = output_max 92 | self.kernel_init = kernel_init 93 | self.interpolation = interpolation 94 | self.output_calibration_num_keypoints = output_calibration_num_keypoints 95 | self.monotonicities = initialize_monotonicities(features) 96 | self.calibrators = initialize_feature_calibrators( 97 | features=features, 98 | output_min=0, 99 | output_max=[feature.lattice_size - 1 for feature in features], 100 | ) 101 | 102 | self.lattice = Lattice( 103 | lattice_sizes=[feature.lattice_size for feature in features], 104 | monotonicities=self.monotonicities, 105 | clip_inputs=self.clip_inputs, 106 | output_min=self.output_min, 107 | output_max=self.output_max, 108 | interpolation=interpolation, 109 | kernel_init=kernel_init, 110 | ) 111 | 112 | self.output_calibrator = initialize_output_calibrator( 113 | output_calibration_num_keypoints=output_calibration_num_keypoints, 114 | monotonic=not all(m is None for m in self.monotonicities), 115 | output_min=output_min, 116 | output_max=output_max, 117 | ) 118 | 119 | def forward(self, x: torch.Tensor) -> torch.Tensor: 120 | """Runs an input through the network to produce a calibrated lattice output. 121 | 122 | Args: 123 | x: The input tensor of feature values of shape `(batch_size, num_features)`. 124 | 125 | Returns: 126 | torch.Tensor of shape `(batch_size, 1)` containing the model output result. 127 | """ 128 | result = calibrate_and_stack(x, self.calibrators) 129 | result = self.lattice(result) 130 | if self.output_calibrator is not None: 131 | result = self.output_calibrator(result) 132 | 133 | return result 134 | 135 | @torch.no_grad() 136 | def apply_constraints(self) -> None: 137 | """Constrains the model into desired constraints specified by the config.""" 138 | for calibrator in self.calibrators.values(): 139 | calibrator.apply_constraints() 140 | self.lattice.apply_constraints() 141 | if self.output_calibrator: 142 | self.output_calibrator.apply_constraints() 143 | 144 | @torch.no_grad() 145 | def assert_constraints(self, eps: float = 1e-6) -> dict[str, list[str]]: 146 | """Asserts all layers within model satisfied specified constraints. 147 | 148 | Asserts monotonicity pairs and output bounds for categorical calibrators, 149 | monotonicity and output bounds for numerical calibrators, and monotonicity and 150 | weights summing to 1 if weighted_average for linear layer. 151 | 152 | Args: 153 | eps: the margin of error allowed 154 | 155 | Returns: 156 | A dict where key is feature_name for calibrators and 'linear' for the linear 157 | layer, and value is the error messages for each layer. Layers with no error 158 | messages are not present in the dictionary. 
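Example: an illustrative sketch of checking constraints after projection; `calibrated_model` is assumed to be an already constructed `CalibratedLattice`.
```python
calibrated_model.apply_constraints()
violations = calibrated_model.assert_constraints()
if violations:
    print(f"Constraint violations: {violations}")
```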
159 | """ 160 | messages = {} 161 | 162 | for name, calibrator in self.calibrators.items(): 163 | calibrator_messages = calibrator.assert_constraints(eps) 164 | if calibrator_messages: 165 | messages[f"{name}_calibrator"] = calibrator_messages 166 | lattice_messages = self.lattice.assert_constraints(eps) 167 | if lattice_messages: 168 | messages["lattice"] = lattice_messages 169 | if self.output_calibrator: 170 | output_calibrator_messages = self.output_calibrator.assert_constraints(eps) 171 | if output_calibrator_messages: 172 | messages["output_calibrator"] = output_calibrator_messages 173 | 174 | return messages 175 | -------------------------------------------------------------------------------- /pytorch_lattice/models/calibrated_linear.py: -------------------------------------------------------------------------------- 1 | """Class for easily constructing a calibrated linear model.""" 2 | from typing import Optional, Union 3 | 4 | import torch 5 | 6 | from ..constrained_module import ConstrainedModule 7 | from ..layers import Linear 8 | from ..utils.models import ( 9 | calibrate_and_stack, 10 | initialize_feature_calibrators, 11 | initialize_monotonicities, 12 | initialize_output_calibrator, 13 | ) 14 | from .features import CategoricalFeature, NumericalFeature 15 | 16 | 17 | class CalibratedLinear(ConstrainedModule): 18 | """PyTorch Calibrated Linear Model. 19 | 20 | Creates a `torch.nn.Module` representing a calibrated linear model, which will be 21 | constructed using the provided model configuration. Note that the model inputs 22 | should match the order in which they are defined in the `feature_configs`. 23 | 24 | Attributes: 25 | All: `__init__` arguments. 26 | calibrators: A dictionary that maps feature names to their calibrators. 27 | linear: The `Linear` layer of the model. 28 | output_calibrator: The output `NumericalCalibrator` calibration layer. This 29 | will be `None` if no output calibration is desired. 30 | 31 | Example: 32 | 33 | ```python 34 | feature_configs = [...] 35 | calibrated_model = pyl.models.CalibratedLinear(feature_configs, ...) 36 | 37 | loss_fn = torch.nn.MSELoss() 38 | optimizer = torch.optim.Adam(calibrated_model.parameters(recurse=True), lr=1e-1) 39 | 40 | dataset = pyl.utils.data.Dataset(...) 41 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True) 42 | for epoch in range(100): 43 | for inputs, labels in dataloader: 44 | optimizer.zero_grad() 45 | outputs = calibrated_model(inputs) 46 | loss = loss_fn(outputs, labels) 47 | loss.backward() 48 | optimizer.step() 49 | calibrated_model.apply_constraints() 50 | ``` 51 | """ 52 | 53 | def __init__( 54 | self, 55 | features: list[Union[NumericalFeature, CategoricalFeature]], 56 | output_min: Optional[float] = None, 57 | output_max: Optional[float] = None, 58 | use_bias: bool = True, 59 | output_calibration_num_keypoints: Optional[int] = None, 60 | ) -> None: 61 | """Initializes an instance of `CalibratedLinear`. 62 | 63 | Args: 64 | features: A list of numerical and/or categorical feature configs. 65 | output_min: The minimum output value for the model. If `None`, the minimum 66 | output value will be unbounded. 67 | output_max: The maximum output value for the model. If `None`, the maximum 68 | output value will be unbounded. 69 | use_bias: Whether to use a bias term for the linear combination. If any of 70 | `output_min`, `output_max`, or `output_calibration_num_keypoints` are 71 | set, a bias term will not be used regardless of the setting here. 
72 | output_calibration_num_keypoints: The number of keypoints to use for the 73 | output calibrator. If `None`, no output calibration will be used. 74 | 75 | Raises: 76 | ValueError: If any feature configs are not `NUMERICAL` or `CATEGORICAL`. 77 | """ 78 | super().__init__() 79 | 80 | self.features = features 81 | self.output_min = output_min 82 | self.output_max = output_max 83 | self.use_bias = use_bias 84 | self.output_calibration_num_keypoints = output_calibration_num_keypoints 85 | self.monotonicities = initialize_monotonicities(features) 86 | self.calibrators = initialize_feature_calibrators( 87 | features=features, output_min=output_min, output_max=output_max 88 | ) 89 | 90 | self.linear = Linear( 91 | input_dim=len(features), 92 | monotonicities=self.monotonicities, 93 | use_bias=use_bias, 94 | weighted_average=bool( 95 | output_min is not None 96 | or output_max is not None 97 | or output_calibration_num_keypoints 98 | ), 99 | ) 100 | 101 | self.output_calibrator = initialize_output_calibrator( 102 | output_calibration_num_keypoints=output_calibration_num_keypoints, 103 | monotonic=not all(m is None for m in self.monotonicities), 104 | output_min=output_min, 105 | output_max=output_max, 106 | ) 107 | 108 | def forward(self, x: torch.Tensor) -> torch.Tensor: 109 | """Runs an input through the network to produce a calibrated linear output. 110 | 111 | Args: 112 | x: The input tensor of feature values of shape `(batch_size, num_features)`. 113 | 114 | Returns: 115 | torch.Tensor of shape `(batch_size, 1)` containing the model output result. 116 | """ 117 | result = calibrate_and_stack(x, self.calibrators) 118 | result = self.linear(result) 119 | if self.output_calibrator is not None: 120 | result = self.output_calibrator(result) 121 | 122 | return result 123 | 124 | @torch.no_grad() 125 | def apply_constraints(self) -> None: 126 | """Constrains the model into desired constraints specified by the config.""" 127 | for calibrator in self.calibrators.values(): 128 | calibrator.apply_constraints() 129 | self.linear.apply_constraints() 130 | if self.output_calibrator: 131 | self.output_calibrator.apply_constraints() 132 | 133 | @torch.no_grad() 134 | def assert_constraints( 135 | self, eps: float = 1e-6 136 | ) -> Union[list[str], dict[str, list[str]]]: 137 | """Asserts all layers within model satisfied specified constraints. 138 | 139 | Asserts monotonicity pairs and output bounds for categorical calibrators, 140 | monotonicity and output bounds for numerical calibrators, and monotonicity and 141 | weights summing to 1 if weighted_average for linear layer. 142 | 143 | Args: 144 | eps: the margin of error allowed 145 | 146 | Returns: 147 | A dict where key is feature_name for calibrators and 'linear' for the linear 148 | layer, and value is the error messages for each layer. Layers with no error 149 | messages are not present in the dictionary. 
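A minimal sketch of consuming this dictionary for a calibrated linear model, reusing the `calibrated_model` variable from the class example above (the `eps` value is only illustrative):

```python
violations = calibrated_model.assert_constraints(eps=1e-6)

# When a bounded model drifts, the "linear" key surfaces messages such as
# "Weights do not sum to 1." alongside any per-feature calibrator messages.
for layer_name, messages in violations.items():
    for message in messages:
        print(f"{layer_name}: {message}")

# An empty dict means every layer satisfied its constraints.
assert not violations
```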
150 | """ 151 | messages: dict[str, list[str]] = {} 152 | 153 | for name, calibrator in self.calibrators.items(): 154 | calibrator_messages = calibrator.assert_constraints(eps) 155 | if calibrator_messages: 156 | messages[f"{name}_calibrator"] = calibrator_messages 157 | linear_messages = self.linear.assert_constraints(eps) 158 | if linear_messages: 159 | messages["linear"] = linear_messages 160 | if self.output_calibrator: 161 | output_calibrator_messages = self.output_calibrator.assert_constraints(eps) 162 | if output_calibrator_messages: 163 | messages["output_calibrator"] = output_calibrator_messages 164 | 165 | return messages 166 | -------------------------------------------------------------------------------- /pytorch_lattice/models/features.py: -------------------------------------------------------------------------------- 1 | """Feature objects for use in models. 2 | 3 | To construct a calibrated model, create the calibrated model configuration and pass it 4 | in to the corresponding calibrated model constructor. 5 | 6 | Example: 7 | ```python 8 | feature_configs = [...] 9 | linear_config = CalibratedLinearConfig(feature_configs, ...) 10 | linear_model = CalibratedLinear(linear_config) 11 | ``` 12 | """ 13 | import logging 14 | from typing import Optional, Union 15 | 16 | import numpy as np 17 | 18 | from ..enums import InputKeypointsInit, Monotonicity 19 | 20 | 21 | class NumericalFeature: 22 | """Feature configuration for numerical features. 23 | 24 | Attributes: 25 | All: `__init__` arguments. 26 | input_keypoints: The input keypoints used for this feature's calibrator. These 27 | keypoints will be initialized using the given `data` under the desired 28 | `input_keypoints_init` scheme. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | feature_name: str, 34 | data: np.ndarray, 35 | num_keypoints: int = 5, 36 | input_keypoints_init: InputKeypointsInit = InputKeypointsInit.QUANTILES, 37 | missing_input_value: Optional[float] = None, 38 | monotonicity: Optional[Monotonicity] = None, 39 | projection_iterations: int = 8, 40 | lattice_size: int = 2, 41 | ) -> None: 42 | """Initializes a `NumericalFeatureConfig` instance. 43 | 44 | Args: 45 | feature_name: The name of the feature. This should match the header for the 46 | column in the dataset representing this feature. 47 | data: Numpy array of float-valued data used for calculating keypoint inputs 48 | and initializing keypoint outputs. 49 | num_keypoints: The number of keypoints used by the underlying piece-wise 50 | linear function of a NumericalCalibrator. There will be 51 | `num_keypoints - 1` total segments. 52 | input_keypoints_init: The scheme to use for initializing the input 53 | keypoints. See `InputKeypointsInit` for more details. 54 | missing_input_value: If provided, this feature's calibrator will learn to 55 | map all instances of this missing input value to a learned output value. 56 | monotonicity: Monotonicity constraint for this feature, if any. 57 | projection_iterations: Number of times to run Dykstra's projection 58 | algorithm when applying constraints. 59 | lattice_size: The default number of keypoints outputted by the 60 | calibrator. Only used within `Lattice` models. 61 | 62 | Raises: 63 | ValueError: If `data` contains NaN values. 64 | ValueError: If `input_keypoints_init` is invalid. 
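The keypoint initialization described above is easiest to see with a small example. A minimal sketch, assuming nine made-up data points chosen so the quantile indices land exactly on observed values:

```python
import numpy as np

from pytorch_lattice.enums import InputKeypointsInit
from pytorch_lattice.models.features import NumericalFeature

data = np.array([1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0])

quantile_feature = NumericalFeature("x", data, num_keypoints=5)
uniform_feature = NumericalFeature(
    "x", data, num_keypoints=5, input_keypoints_init=InputKeypointsInit.UNIFORM
)

# QUANTILES picks keypoints from the observed values at evenly spaced quantiles...
print(quantile_feature.input_keypoints)  # [ 1.  3.  8. 21. 55.]
# ...while UNIFORM spaces them evenly between the observed min and max.
print(uniform_feature.input_keypoints)   # [ 1.  14.5 28.  41.5 55. ]
```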
65 | """ 66 | self.feature_name = feature_name 67 | 68 | if np.isnan(data).any(): 69 | raise ValueError("Data contains NaN values.") 70 | 71 | self.data = data 72 | self.num_keypoints = num_keypoints 73 | self.input_keypoints_init = input_keypoints_init 74 | self.missing_input_value = missing_input_value 75 | self.monotonicity = monotonicity 76 | self.projection_iterations = projection_iterations 77 | self.lattice_size = lattice_size 78 | 79 | sorted_unique_values = np.unique(data) 80 | 81 | if input_keypoints_init == InputKeypointsInit.QUANTILES: 82 | if sorted_unique_values.size < num_keypoints: 83 | logging.info( 84 | "Observed fewer unique values for feature %s than %d desired " 85 | "keypoints. Using the observed %d unique values as keypoints.", 86 | feature_name, 87 | num_keypoints, 88 | sorted_unique_values.size, 89 | ) 90 | self.input_keypoints = sorted_unique_values 91 | else: 92 | quantiles = np.linspace(0.0, 1.0, num=num_keypoints) 93 | self.input_keypoints = np.quantile( 94 | sorted_unique_values, quantiles, method="nearest" 95 | ) 96 | elif input_keypoints_init == InputKeypointsInit.UNIFORM: 97 | self.input_keypoints = np.linspace( 98 | sorted_unique_values[0], sorted_unique_values[-1], num=num_keypoints 99 | ) 100 | else: 101 | raise ValueError(f"Unknown input keypoints init: {input_keypoints_init}") 102 | 103 | 104 | class CategoricalFeature: 105 | """Feature configuration for categorical features. 106 | 107 | Attributes: 108 | All: `__init__` arguments. 109 | category_indices: A dictionary mapping string categories to their index. 110 | monotonicity_index_pairs: A conversion of `monotonicity_pairs` from string 111 | categories to category indices. Only available if `monotonicity_pairs` are 112 | provided. 113 | """ 114 | 115 | def __init__( 116 | self, 117 | feature_name: str, 118 | categories: Union[list[int], list[str]], 119 | missing_input_value: Optional[float] = None, 120 | monotonicity_pairs: Optional[list[tuple[str, str]]] = None, 121 | lattice_size: int = 2, 122 | ) -> None: 123 | """Initializes a `CategoricalFeatureConfig` instance. 124 | 125 | Args: 126 | feature_name: The name of the feature. This should match the header for the 127 | column in the dataset representing this feature. 128 | categories: The categories that should be used for this feature. Any 129 | categories not contained will be considered missing or unknown. If you 130 | expect to have such missing categories, make sure to 131 | missing_input_value: If provided, this feature's calibrator will learn to 132 | map all instances of this missing input value to a learned output value. 133 | monotonicity_pairs: List of pairs of categories `(category_a, category_b)` 134 | indicating that the calibrator output for `category_b` should be greater 135 | than or equal to that of `category_a`. 136 | lattice_size: The default number of keypoints outputted by the calibrator. 137 | Only used within `Lattice` models. 
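A minimal sketch of how the derived attributes look for a small feature config; the category names and missing-input value are illustrative only:

```python
from pytorch_lattice.models.features import CategoricalFeature

thal = CategoricalFeature(
    feature_name="thal",
    categories=["normal", "fixed", "reversible"],
    missing_input_value=-1.0,
    monotonicity_pairs=[("normal", "reversible")],
)

print(thal.category_indices)          # {'normal': 0, 'fixed': 1, 'reversible': 2}
print(thal.monotonicity_index_pairs)  # [(0, 2)]
```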
138 | """ 139 | self.feature_name = feature_name 140 | self.categories = categories 141 | self.missing_input_value = missing_input_value 142 | self.monotonicity_pairs = monotonicity_pairs 143 | self.lattice_size = lattice_size 144 | 145 | self.category_indices = {category: i for i, category in enumerate(categories)} 146 | self.monotonicity_index_pairs = [ 147 | (self.category_indices[a], self.category_indices[b]) 148 | for a, b in monotonicity_pairs or [] 149 | ] 150 | -------------------------------------------------------------------------------- /pytorch_lattice/plots.py: -------------------------------------------------------------------------------- 1 | """Plotting functions for PyTorch Lattice calibrated models using matplotlib.""" 2 | from typing import Union 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | from .layers import CategoricalCalibrator 8 | from .models import CalibratedLattice, CalibratedLinear 9 | from .models.features import CategoricalFeature 10 | 11 | 12 | def calibrator( 13 | model: Union[CalibratedLinear, CalibratedLattice], 14 | feature_name: str, 15 | ) -> None: 16 | """Plots the calibrator for the given feature and calibrated model. 17 | 18 | Args: 19 | model: The calibrated model for which to plot calibrators. 20 | feature_name: The name of the feature for which to plot the calibrator. 21 | """ 22 | if feature_name not in model.calibrators: 23 | raise ValueError(f"Feature {feature_name} not found in model.") 24 | 25 | calibrator = model.calibrators[feature_name] 26 | input_keypoints = calibrator.keypoints_inputs().numpy() 27 | output_keypoints = calibrator.keypoints_outputs().numpy() 28 | 29 | if isinstance(calibrator, CategoricalCalibrator): 30 | model_feature = next( 31 | (x for x in model.features if x.feature_name == feature_name), None 32 | ) 33 | if isinstance(model_feature, CategoricalFeature): 34 | input_keypoints = np.array( 35 | [ 36 | model_feature.categories[i] 37 | if i < len(input_keypoints) - 1 38 | else "" 39 | for i, ik in enumerate(input_keypoints) 40 | ] 41 | ) 42 | plt.xticks(rotation=45) 43 | plt.bar(input_keypoints, output_keypoints) 44 | else: 45 | plt.plot(input_keypoints, output_keypoints) 46 | 47 | plt.title(f"Calibrator: {feature_name}") 48 | plt.xlabel("Input Keypoints") 49 | plt.ylabel("Output Keypoints") 50 | plt.show() 51 | 52 | 53 | def linear_coefficients(model: CalibratedLinear) -> None: 54 | """Plots the coefficients for the linear layer of a calibrated linear model.""" 55 | if not isinstance(model, CalibratedLinear): 56 | raise ValueError( 57 | "Model must be a `CalibratedLinear` model to plot linear coefficients." 
58 | ) 59 | linear_coefficients = dict( 60 | zip( 61 | [feature.feature_name for feature in model.features], 62 | model.linear.kernel.detach().numpy().flatten(), 63 | ) 64 | ) 65 | if model.use_bias: 66 | linear_coefficients["bias"] = model.linear.bias.detach().numpy()[0] 67 | 68 | plt.bar(list(linear_coefficients.keys()), list(linear_coefficients.values())) 69 | plt.title("Linear Coefficients") 70 | plt.xlabel("Feature Name") 71 | plt.xticks(rotation=45) 72 | plt.ylabel("Coefficient Value") 73 | plt.show() 74 | -------------------------------------------------------------------------------- /pytorch_lattice/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willbakst/pytorch-lattice/d444dba62d1c74708e0dc23f6d60c97785df46b6/pytorch_lattice/utils/__init__.py -------------------------------------------------------------------------------- /pytorch_lattice/utils/data.py: -------------------------------------------------------------------------------- 1 | """Utility functions and classes for handling data.""" 2 | from typing import Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | 8 | from ..models.features import CategoricalFeature, NumericalFeature 9 | 10 | 11 | def prepare_features( 12 | X: pd.DataFrame, features: list[Union[NumericalFeature, CategoricalFeature]] 13 | ): 14 | """Maps categorical features to their integer indices in place.""" 15 | for feature in features: 16 | feature_data = X[feature.feature_name] 17 | 18 | if isinstance(feature, CategoricalFeature): 19 | feature_data = feature_data.map(feature.category_indices) 20 | 21 | if feature.missing_input_value is not None: 22 | feature_data = feature_data.fillna(feature.missing_input_value) 23 | 24 | X[feature.feature_name] = feature_data 25 | 26 | 27 | class Dataset(torch.utils.data.Dataset): 28 | """A class for loading a dataset for a calibrated model.""" 29 | 30 | def __init__( 31 | self, 32 | X: pd.DataFrame, 33 | y: np.ndarray, 34 | features: list[Union[NumericalFeature, CategoricalFeature]], 35 | ): 36 | """Initializes an instance of `Dataset`.""" 37 | self.X = X.copy() 38 | self.y = y.copy() 39 | 40 | selected_features = [feature.feature_name for feature in features] 41 | unavailable_features = set(selected_features) - set(self.X.columns) 42 | if len(unavailable_features) > 0: 43 | raise ValueError(f"Features {unavailable_features} not found in dataset.") 44 | 45 | drop_features = list(set(self.X.columns) - set(selected_features)) 46 | self.X.drop(drop_features, axis=1, inplace=True) 47 | prepare_features(self.X, features) 48 | 49 | self.data = torch.from_numpy(self.X.values).double() 50 | self.labels = torch.from_numpy(self.y).double()[:, None] 51 | 52 | def __len__(self): 53 | return len(self.X) 54 | 55 | def __getitem__(self, idx): 56 | if isinstance(idx, torch.Tensor): 57 | idx = idx.tolist() 58 | 59 | return [self.data[idx], self.labels[idx]] 60 | -------------------------------------------------------------------------------- /pytorch_lattice/utils/models.py: -------------------------------------------------------------------------------- 1 | """Utility functions for use in model classes.""" 2 | from typing import Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ..enums import ( 8 | CategoricalCalibratorInit, 9 | Monotonicity, 10 | NumericalCalibratorInit, 11 | ) 12 | from ..layers.categorical_calibrator import CategoricalCalibrator 13 | from ..layers.numerical_calibrator import 
NumericalCalibrator 14 | from ..models.features import CategoricalFeature, NumericalFeature 15 | 16 | 17 | def initialize_feature_calibrators( 18 | features: list[Union[NumericalFeature, CategoricalFeature]], 19 | output_min: Optional[float] = None, 20 | output_max: Union[Optional[float], list[Optional[float]]] = None, 21 | ) -> torch.nn.ModuleDict: 22 | """Helper function to initialize calibrators for calibrated model. 23 | 24 | Args: 25 | features: A list of numerical and/or categorical feature configs. 26 | output_min: The minimum output value for the model. If `None`, the minimum 27 | output value will be unbounded. 28 | output_max: A list of maximum output value for each feature of the model. If 29 | `None`, the maximum output value will be unbounded. If a singular value, it 30 | will be taken as the maximum of all features. 31 | 32 | Returns: 33 | A `torch.nn.ModuleDict` of calibrators accessible by each feature's name. 34 | 35 | Raises: 36 | ValueError: If any feature configs are not `NUMERICAL` or `CATEGORICAL`. 37 | """ 38 | calibrators = torch.nn.ModuleDict() 39 | if not isinstance(output_max, list): 40 | output_max = [output_max] * len(features) 41 | for feature, feature_max in zip(features, output_max): 42 | if isinstance(feature, NumericalFeature): 43 | calibrators[feature.feature_name] = NumericalCalibrator( 44 | input_keypoints=feature.input_keypoints, 45 | missing_input_value=feature.missing_input_value, 46 | output_min=output_min, 47 | output_max=feature_max, 48 | monotonicity=feature.monotonicity, 49 | kernel_init=NumericalCalibratorInit.EQUAL_SLOPES, 50 | projection_iterations=feature.projection_iterations, 51 | ) 52 | elif isinstance(feature, CategoricalFeature): 53 | calibrators[feature.feature_name] = CategoricalCalibrator( 54 | num_categories=len(feature.categories), 55 | missing_input_value=feature.missing_input_value, 56 | output_min=output_min, 57 | output_max=feature_max, 58 | monotonicity_pairs=feature.monotonicity_index_pairs, 59 | kernel_init=CategoricalCalibratorInit.UNIFORM, 60 | ) 61 | else: 62 | raise ValueError(f"Unknown type {type(feature)} for feature {feature}") 63 | return calibrators 64 | 65 | 66 | def initialize_monotonicities( 67 | features: list[Union[NumericalFeature, CategoricalFeature]] 68 | ) -> list[Optional[Monotonicity]]: 69 | """Helper function to initialize monotonicities for calibrated model. 70 | 71 | Args: 72 | features: A list of numerical and/or categorical feature configs. 73 | 74 | Returns: 75 | A list of `None` or `Monotonicity.INCREASING` based on whether 76 | each feature has a monotonicity or not. 77 | """ 78 | monotonicities = [ 79 | None 80 | if (isinstance(feature, CategoricalFeature) and not feature.monotonicity_pairs) 81 | or (isinstance(feature, NumericalFeature) and feature.monotonicity is None) 82 | else Monotonicity.INCREASING 83 | for feature in features 84 | ] 85 | return monotonicities 86 | 87 | 88 | def initialize_output_calibrator( 89 | monotonic: bool, 90 | output_calibration_num_keypoints: Optional[int], 91 | output_min: Optional[float] = None, 92 | output_max: Optional[float] = None, 93 | ) -> Optional[NumericalCalibrator]: 94 | """Helper function to initialize output calibrator for calibrated model. 95 | 96 | Args: 97 | monotonic: Whether output calibrator should have monotonicity constraint. 98 | output_calibration_num_keypoints: The number of keypoints in output 99 | calibrator. If `0` or `None`, no output calibrator will be returned. 100 | output_min: The minimum output value for the model. 
If `None`, the minimum 101 | output value will be unbounded. 102 | output_max: The maximum output value for the model. If `None`, the maximum 103 | output value will be unbounded. 104 | 105 | Returns: 106 | A `torch.nn.ModuleDict` of calibrators accessible by each feature's name. 107 | 108 | Raises: 109 | ValueError: If any feature configs are not `NUMERICAL` or `CATEGORICAL`. 110 | """ 111 | if output_calibration_num_keypoints: 112 | output_calibrator = NumericalCalibrator( 113 | input_keypoints=np.linspace(0.0, 1.0, num=output_calibration_num_keypoints), 114 | missing_input_value=None, 115 | output_min=output_min, 116 | output_max=output_max, 117 | monotonicity=Monotonicity.INCREASING if monotonic else None, 118 | kernel_init=NumericalCalibratorInit.EQUAL_HEIGHTS, 119 | ) 120 | return output_calibrator 121 | return None 122 | 123 | 124 | def calibrate_and_stack( 125 | x: torch.Tensor, 126 | calibrators: torch.nn.ModuleDict, 127 | ) -> torch.Tensor: 128 | """Helper function to run calibrators along columns of given data. 129 | 130 | Args: 131 | x: The input tensor of feature values of shape `(batch_size, num_features)`. 132 | calibrators: A dictionary of calibrator functions. 133 | 134 | Returns: 135 | A torch.Tensor resulting from applying the calibrators and stacking the results. 136 | """ 137 | return torch.column_stack( 138 | tuple( 139 | calibrator(x[:, i, None]) 140 | for i, calibrator in enumerate(calibrators.values()) 141 | ) 142 | ) 143 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | mypy>=^1.6.1 2 | pytest 3 | ruff>=0.1.5 4 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | mkdocs-material 2 | mkdocstrings 3 | mkdocstrings-python 4 | mike 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.7.1 2 | numpy>=1.23.5 3 | pandas>=1.5.3 4 | pydantic>=2.0.2 5 | torch>=2.0.0,!=2.0.1,!=2.1.0 6 | tqdm>=4.65.0 7 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willbakst/pytorch-lattice/d444dba62d1c74708e0dc23f6d60c97785df46b6/tests/__init__.py -------------------------------------------------------------------------------- /tests/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willbakst/pytorch-lattice/d444dba62d1c74708e0dc23f6d60c97785df46b6/tests/layers/__init__.py -------------------------------------------------------------------------------- /tests/layers/test_linear.py: -------------------------------------------------------------------------------- 1 | """Tests for Linear module.""" 2 | import numpy as np 3 | import pytest 4 | import torch 5 | 6 | from pytorch_lattice import Monotonicity 7 | from pytorch_lattice.layers import Linear 8 | 9 | from ..testing_utils import train_calibrated_module 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "input_dim,monotonicities,use_bias,weighted_average", 14 | [(5, None, True, False), (5, None, True, True)], 15 | ) 16 | def test_initialization(input_dim, monotonicities, 
use_bias, weighted_average) -> None: 17 | """Tests that Linear initialization works properly""" 18 | linear = Linear(input_dim, monotonicities, use_bias, weighted_average) 19 | assert linear.input_dim == input_dim 20 | assert linear.monotonicities == ( 21 | monotonicities 22 | if not weighted_average 23 | else [Monotonicity.INCREASING] * input_dim 24 | ) 25 | assert linear.use_bias == (use_bias if not weighted_average else False) 26 | assert linear.weighted_average == weighted_average 27 | assert torch.allclose( 28 | linear.kernel.data, 29 | torch.tensor([[1.0 / input_dim] * input_dim]).double(), 30 | ) 31 | if use_bias: 32 | assert linear.bias.data.size() == torch.Size([1]) 33 | assert torch.all(linear.bias.data == 0.0) 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "kernel_data,bias_data,inputs,expected_outputs", 38 | [ 39 | ( 40 | torch.tensor([[1.0], [2.0], [3.0]]).double(), 41 | None, 42 | torch.tensor([[1.0, 1.0, 1.0], [3.0, 2.0, 1.0], [1.0, -2.0, 3.0]]).double(), 43 | torch.tensor([[6.0], [10.0], [6.0]]).double(), 44 | ), 45 | ( 46 | torch.tensor([[1.0], [2.0], [1.0]]).double(), 47 | torch.tensor([-1.0]).double(), 48 | torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 1.0], [4.0, -1.0, 2.0]]).double(), 49 | torch.tensor([[7.0], [8.0], [3.0]]).double(), 50 | ), 51 | ], 52 | ) 53 | def test_forward(kernel_data, bias_data, inputs, expected_outputs) -> None: 54 | """Tests that forward properly combined inputs.""" 55 | linear = Linear(kernel_data.size()[0], use_bias=bias_data is not None) 56 | linear.kernel.data = kernel_data 57 | if bias_data is not None: 58 | linear.bias.data = bias_data 59 | outputs = linear(inputs) 60 | assert torch.allclose(outputs, expected_outputs) 61 | 62 | 63 | @pytest.mark.parametrize( 64 | "monotonicities,kernel_data,expected_out", 65 | [ 66 | ( 67 | [Monotonicity.INCREASING, Monotonicity.INCREASING, Monotonicity.INCREASING], 68 | torch.tensor([[0.2], [0.1], [0.2]]).double(), 69 | [], 70 | ), 71 | ( 72 | [Monotonicity.INCREASING, Monotonicity.INCREASING, Monotonicity.INCREASING], 73 | torch.tensor([[0.2], [-0.01], [0.2]]).double(), 74 | [], 75 | ), 76 | ( 77 | [Monotonicity.DECREASING, Monotonicity.DECREASING, Monotonicity.DECREASING], 78 | torch.tensor([[-0.2], [-0.1], [-0.2]]).double(), 79 | [], 80 | ), 81 | ( 82 | [Monotonicity.DECREASING, Monotonicity.DECREASING, Monotonicity.DECREASING], 83 | torch.tensor([[-0.2], [0.01], [-0.2]]).double(), 84 | [], 85 | ), 86 | ( 87 | [Monotonicity.INCREASING, Monotonicity.DECREASING, Monotonicity.INCREASING], 88 | torch.tensor([[-0.2], [0.2], [-0.2]]).double(), 89 | ["Monotonicity violated at: [0, 1, 2]"], 90 | ), 91 | ( 92 | [ 93 | None, 94 | None, 95 | Monotonicity.DECREASING, 96 | Monotonicity.INCREASING, 97 | ], 98 | torch.tensor([[1.5], [-1.5], [0.01], [-0.01]]).double(), 99 | [], 100 | ), 101 | ], 102 | ) 103 | def test_assert_constraints_monotonicty( 104 | monotonicities, kernel_data, expected_out 105 | ) -> None: 106 | """Tests that assert_constraints properly checks monotonicity.""" 107 | linear = Linear( 108 | kernel_data.size()[0], monotonicities=monotonicities, weighted_average=False 109 | ) 110 | linear.kernel.data = kernel_data 111 | assert linear.assert_constraints(eps=0.05) == expected_out 112 | 113 | 114 | @pytest.mark.parametrize( 115 | "monotonicities,kernel_data,expected_out", 116 | [ 117 | ( 118 | [Monotonicity.INCREASING, Monotonicity.INCREASING], 119 | torch.tensor([[0.4], [0.6]]).double(), 120 | [], 121 | ), 122 | ( 123 | [Monotonicity.INCREASING, Monotonicity.INCREASING], 124 | torch.tensor([[0.4], 
[0.61]]).double(), 125 | [], 126 | ), 127 | ( 128 | [Monotonicity.INCREASING, Monotonicity.DECREASING], 129 | torch.tensor([[1.5], [-0.5]]).double(), 130 | [], 131 | ), 132 | ( 133 | [Monotonicity.INCREASING, Monotonicity.DECREASING], 134 | torch.tensor([[1.5], [-0.51]]).double(), 135 | [], 136 | ), 137 | ( 138 | [Monotonicity.INCREASING, Monotonicity.DECREASING, None], 139 | torch.tensor([[1.5], [-2.2], [1.7]]).double(), 140 | [], 141 | ), 142 | ( 143 | [Monotonicity.INCREASING, Monotonicity.DECREASING, None], 144 | torch.tensor([[1.5], [-2.2], [2.7]]).double(), 145 | ["Weights do not sum to 1."], 146 | ), 147 | ], 148 | ) 149 | def test_assert_constraints_weighted_average( 150 | monotonicities, kernel_data, expected_out 151 | ) -> None: 152 | """Tests assert_constraints checks weights sum to 1 when weighted_average=True.""" 153 | linear = Linear( 154 | kernel_data.size()[0], monotonicities=monotonicities, weighted_average=True 155 | ) 156 | linear.kernel.data = kernel_data 157 | linear.monotonicities = monotonicities 158 | assert linear.assert_constraints(eps=0.05) == expected_out 159 | 160 | 161 | @pytest.mark.parametrize( 162 | "monotonicities,kernel_data,expected_out", 163 | [ 164 | ( 165 | [Monotonicity.INCREASING, Monotonicity.INCREASING], 166 | torch.tensor([[0.4], [0.6]]).double(), 167 | [], 168 | ), 169 | ( 170 | [Monotonicity.DECREASING, None, Monotonicity.DECREASING], 171 | torch.tensor([[0.4], [0.01], [0.6]]).double(), 172 | ["Monotonicity violated at: [0, 2]"], 173 | ), 174 | ( 175 | [Monotonicity.INCREASING, None, Monotonicity.DECREASING], 176 | torch.tensor([[-0.5], [2.0], [0.5]]).double(), 177 | ["Weights do not sum to 1.", "Monotonicity violated at: [0, 2]"], 178 | ), 179 | ], 180 | ) 181 | def test_assert_constraints_combo(monotonicities, kernel_data, expected_out) -> None: 182 | """Tests asserts_constraints for both monotonicity and weighed_average.""" 183 | linear = Linear( 184 | kernel_data.size()[0], monotonicities=monotonicities, weighted_average=True 185 | ) 186 | linear.kernel.data = kernel_data 187 | linear.monotonicities = monotonicities 188 | assert linear.assert_constraints(eps=0.05) == expected_out 189 | 190 | 191 | @pytest.mark.parametrize( 192 | "monotonicities,kernel_data,bias_data", 193 | [ 194 | (None, torch.tensor([[1.2], [2.5], [3.1]]).double(), None), 195 | ( 196 | None, 197 | torch.tensor([[1.2], [2.5], [3.1]]).double(), 198 | torch.tensor([1.0]).double(), 199 | ), 200 | ( 201 | [None, None, None], 202 | torch.tensor([[1.2], [2.5], [3.1]]).double(), 203 | torch.tensor([1.0]).double(), 204 | ), 205 | ( 206 | [None, None, None], 207 | torch.tensor([[1.2], [2.5], [3.1]]).double(), 208 | torch.tensor([1.0]).double(), 209 | ), 210 | ], 211 | ) 212 | def test_constrain_no_constraints(monotonicities, kernel_data, bias_data) -> None: 213 | """Tests that constrain does nothing when there are no constraints.""" 214 | linear = Linear(kernel_data.size()[0], monotonicities=monotonicities) 215 | linear.kernel.data = kernel_data 216 | if bias_data is not None: 217 | linear.bias.data = bias_data 218 | linear.apply_constraints() 219 | assert torch.allclose(linear.kernel.data, kernel_data) 220 | if bias_data is not None: 221 | assert torch.allclose(linear.bias.data, bias_data) 222 | 223 | 224 | @pytest.mark.parametrize( 225 | "monotonicities,kernel_data,expected_projected_kernel_data", 226 | [ 227 | ( 228 | [ 229 | None, 230 | Monotonicity.INCREASING, 231 | Monotonicity.DECREASING, 232 | ], 233 | torch.tensor([[1.0], [-0.2], [0.2]]).double(), 234 | 
torch.tensor([[1.0], [0.0], [0.0]]).double(), 235 | ), 236 | ( 237 | [ 238 | None, 239 | Monotonicity.INCREASING, 240 | None, 241 | ], 242 | torch.tensor([[1.0], [0.2], [-2.0]]).double(), 243 | torch.tensor([[1.0], [0.2], [-2.0]]).double(), 244 | ), 245 | ( 246 | [ 247 | Monotonicity.DECREASING, 248 | Monotonicity.DECREASING, 249 | ], 250 | torch.tensor([[-1.0], [0.2]]).double(), 251 | torch.tensor([[-1.0], [0.0]]).double(), 252 | ), 253 | ( 254 | [ 255 | Monotonicity.INCREASING, 256 | Monotonicity.INCREASING, 257 | ], 258 | torch.tensor([[-1.0], [1.0]]).double(), 259 | torch.tensor([[0.0], [1.0]]).double(), 260 | ), 261 | ], 262 | ) 263 | def test_constrain_monotonicities( 264 | monotonicities, kernel_data, expected_projected_kernel_data 265 | ) -> None: 266 | """Tests that constrain properly projects kernel according to monotonicies.""" 267 | linear = Linear(kernel_data.size()[0], monotonicities=monotonicities) 268 | linear.kernel.data = kernel_data 269 | linear.apply_constraints() 270 | assert torch.allclose(linear.kernel.data, expected_projected_kernel_data) 271 | 272 | 273 | @pytest.mark.parametrize( 274 | "kernel_data,expected_projected_kernel_data", 275 | [ 276 | ( 277 | torch.tensor([[1.0], [2.0], [3.0]]).double(), 278 | torch.tensor([[1 / 6], [2 / 6], [0.5]]).double(), 279 | ), 280 | ( 281 | torch.tensor([[2.0], [-1.0], [1.0], [3.0]]).double(), 282 | torch.tensor([[2 / 6], [0.0], [1 / 6], [0.5]]).double(), 283 | ), 284 | ], 285 | ) 286 | def test_constrain_weighted_average( 287 | kernel_data, expected_projected_kernel_data 288 | ) -> None: 289 | """Tests that constrain properly projects kernel to be a weighted average.""" 290 | linear = Linear(kernel_data.size()[0], weighted_average=True) 291 | linear.kernel.data = kernel_data 292 | linear.apply_constraints() 293 | assert torch.allclose(linear.kernel.data, expected_projected_kernel_data) 294 | 295 | 296 | def test_training() -> None: 297 | """Tests that the `Linear` module can learn f(x_1,x_2) = 2x_1 + 3x_2""" 298 | num_examples = 1000 299 | input_min, input_max = 0.0, 10.0 300 | training_examples = torch.from_numpy( 301 | np.random.uniform(input_min, input_max, size=(1000, 2)) 302 | ) 303 | linear_coefficients = torch.tensor([2.0, 3.0]).double() 304 | training_labels = torch.sum( 305 | linear_coefficients * training_examples, dim=1, keepdim=True 306 | ) 307 | 308 | linear = Linear(2, use_bias=False) 309 | loss_fn = torch.nn.MSELoss() 310 | optimizer = torch.optim.Adam(linear.parameters(), lr=1e-2) 311 | 312 | train_calibrated_module( 313 | linear, 314 | training_examples, 315 | training_labels, 316 | loss_fn, 317 | optimizer, 318 | 300, 319 | num_examples // 10, 320 | ) 321 | 322 | assert torch.allclose(torch.squeeze(linear.kernel.data), linear_coefficients) 323 | -------------------------------------------------------------------------------- /tests/layers/test_rtl.py: -------------------------------------------------------------------------------- 1 | """Tests for RTL layer.""" 2 | from itertools import cycle 3 | from unittest.mock import Mock, patch 4 | 5 | import pytest 6 | import torch 7 | 8 | from pytorch_lattice import Interpolation, LatticeInit, Monotonicity 9 | from pytorch_lattice.layers import RTL, Lattice 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "monotonicities, num_lattices, lattice_rank, output_min, output_max, kernel_init," 14 | "clip_inputs, interpolation, average_outputs", 15 | [ 16 | ( 17 | [ 18 | None, 19 | None, 20 | None, 21 | None, 22 | ], 23 | 3, 24 | 3, 25 | None, 26 | 2.0, 27 | LatticeInit.LINEAR, 28 
| True, 29 | Interpolation.HYPERCUBE, 30 | True, 31 | ), 32 | ( 33 | [ 34 | Monotonicity.INCREASING, 35 | Monotonicity.INCREASING, 36 | None, 37 | None, 38 | ], 39 | 3, 40 | 3, 41 | -1.0, 42 | 4.0, 43 | LatticeInit.LINEAR, 44 | False, 45 | Interpolation.SIMPLEX, 46 | False, 47 | ), 48 | ( 49 | [Monotonicity.INCREASING, None] * 25, 50 | 20, 51 | 5, 52 | None, 53 | None, 54 | LatticeInit.LINEAR, 55 | True, 56 | Interpolation.HYPERCUBE, 57 | True, 58 | ), 59 | ], 60 | ) 61 | def test_initialization( 62 | monotonicities, 63 | num_lattices, 64 | lattice_rank, 65 | output_min, 66 | output_max, 67 | kernel_init, 68 | clip_inputs, 69 | interpolation, 70 | average_outputs, 71 | ): 72 | """Tests that RTL Initialization works properly.""" 73 | rtl = RTL( 74 | monotonicities=monotonicities, 75 | num_lattices=num_lattices, 76 | lattice_rank=lattice_rank, 77 | output_min=output_min, 78 | output_max=output_max, 79 | kernel_init=kernel_init, 80 | clip_inputs=clip_inputs, 81 | interpolation=interpolation, 82 | average_outputs=average_outputs, 83 | ) 84 | assert rtl.monotonicities == monotonicities 85 | assert rtl.num_lattices == num_lattices 86 | assert rtl.lattice_rank == lattice_rank 87 | assert rtl.output_min == output_min 88 | assert rtl.output_max == output_max 89 | assert rtl.kernel_init == kernel_init 90 | assert rtl.interpolation == interpolation 91 | assert rtl.average_outputs == average_outputs 92 | 93 | total_lattices = 0 94 | for monotonic_count, (lattice, group) in rtl._lattice_layers.items(): 95 | # test monotonic features have been sorted to front of list for lattice indices 96 | for single_lattice_indices in group: 97 | for i in range(lattice_rank): 98 | if i < monotonic_count: 99 | assert ( 100 | rtl.monotonicities[single_lattice_indices[i]] 101 | == Monotonicity.INCREASING 102 | ) 103 | else: 104 | assert rtl.monotonicities[single_lattice_indices[i]] is None 105 | 106 | assert len(lattice.monotonicities) == len(lattice.lattice_sizes) 107 | assert ( 108 | sum(1 for _ in lattice.monotonicities if _ == Monotonicity.INCREASING) 109 | == monotonic_count 110 | ) 111 | assert lattice.output_min == rtl.output_min 112 | assert lattice.output_max == rtl.output_max 113 | assert lattice.kernel_init == rtl.kernel_init 114 | assert lattice.clip_inputs == rtl.clip_inputs 115 | assert lattice.interpolation == rtl.interpolation 116 | 117 | # test number of lattices created is correct 118 | total_lattices += lattice.units 119 | 120 | assert total_lattices == num_lattices 121 | 122 | 123 | @pytest.mark.parametrize( 124 | "monotonicities, num_lattices, lattice_rank", 125 | [ 126 | ([None] * 9, 2, 2), 127 | ([Monotonicity.INCREASING] * 10, 3, 3), 128 | ], 129 | ) 130 | def test_initialization_invalid( 131 | monotonicities, 132 | num_lattices, 133 | lattice_rank, 134 | ): 135 | """Tests that RTL Initialization raises error when RTL is too small.""" 136 | with pytest.raises(ValueError) as exc_info: 137 | RTL( 138 | monotonicities=monotonicities, 139 | num_lattices=num_lattices, 140 | lattice_rank=lattice_rank, 141 | ) 142 | assert ( 143 | str(exc_info.value) 144 | == f"RTL with {num_lattices}x{lattice_rank}D structure cannot support " 145 | + f"{len(monotonicities)} input features." 
146 | ) 147 | 148 | 149 | @pytest.mark.parametrize( 150 | "num_features, num_lattices, lattice_rank, units, expected_lattice_args," 151 | "expected_result, expected_avg", 152 | [ 153 | ( 154 | 6, 155 | 6, 156 | 3, 157 | [3, 2, 1], 158 | [ 159 | torch.tensor([[[0.0, 0.1, 0.2], [0.3, 0.4, 0.5], [0.0, 0.1, 0.2]]]), 160 | torch.tensor([[[0.3, 0.4, 0.5], [0.0, 0.1, 0.2]]]), 161 | torch.tensor([[0.3, 0.4, 0.5]]), 162 | ], 163 | torch.tensor([[0.0, 0.0, 0.0, 1.0, 1.0, 2.0]]), 164 | torch.tensor([[2 / 3]]), 165 | ), 166 | ( 167 | 3, 168 | 3, 169 | 2, 170 | [1, 1, 1], 171 | [ 172 | torch.tensor([[0.0, 0.1]]), 173 | torch.tensor([[0.2, 0.0]]), 174 | torch.tensor([[0.1, 0.2]]), 175 | ], 176 | torch.tensor([[0.0, 1.0, 2.0]]), 177 | torch.tensor([[1.0]]), 178 | ), 179 | ( 180 | 6, 181 | 7, 182 | 5, 183 | [2, 3, 2], 184 | [ 185 | torch.tensor([[[0.0, 0.1, 0.2, 0.3, 0.4], [0.5, 0.0, 0.1, 0.2, 0.3]]]), 186 | torch.tensor( 187 | [ 188 | [ 189 | [0.4, 0.5, 0.0, 0.1, 0.2], 190 | [0.3, 0.4, 0.5, 0.0, 0.1], 191 | [0.2, 0.3, 0.4, 0.5, 0.0], 192 | ] 193 | ] 194 | ), 195 | torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5], [0.0, 0.1, 0.2, 0.3, 0.4]]]), 196 | ], 197 | torch.tensor([[0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0]]), 198 | torch.tensor([[1.0]]), 199 | ), 200 | ], 201 | ) 202 | def test_forward( 203 | num_features, 204 | num_lattices, 205 | lattice_rank, 206 | units, 207 | expected_lattice_args, 208 | expected_result, 209 | expected_avg, 210 | ): 211 | """Tests forward function of RTL Lattice.""" 212 | rtl = RTL( 213 | monotonicities=[None, Monotonicity.INCREASING] * (num_features // 2), 214 | num_lattices=num_lattices, 215 | lattice_rank=lattice_rank, 216 | ) 217 | # generate indices for each lattice in cyclic fashion based off units 218 | groups = [] 219 | feature_indices = cycle(range(num_features)) 220 | for lattice_units in units: 221 | group = [ 222 | [next(feature_indices) for _ in range(lattice_rank)] 223 | for _ in range(lattice_units) 224 | ] 225 | groups.append(group) 226 | lattice_indices = {i: groups[i % len(groups)] for i in range(len(units))} 227 | rtl._lattice_layers = { 228 | i: (Lattice(lattice_sizes=[2] * lattice_rank, units=unit), lattice_indices[i]) 229 | for i, unit in enumerate(units) 230 | } 231 | 232 | mock_forwards = [] 233 | for monotonic_count, (lattice, _) in rtl._lattice_layers.items(): 234 | mock_forward = Mock() 235 | lattice.forward = mock_forward 236 | mock_forward.return_value = torch.full( 237 | (1, units[monotonic_count]), 238 | float(monotonic_count), 239 | dtype=torch.float32, 240 | ) 241 | mock_forwards.append(mock_forward) 242 | 243 | x = torch.arange(0, num_features * 0.1, 0.1).unsqueeze(0) 244 | result = rtl.forward(x) 245 | 246 | # Assert the calls and results for each mock_forward based on expected_outs 247 | for i, mock_forward in enumerate(mock_forwards): 248 | mock_forward.assert_called_once() 249 | assert torch.allclose( 250 | mock_forward.call_args[0][0], expected_lattice_args[i], atol=1e-6 251 | ) 252 | assert torch.allclose(result, expected_result) 253 | rtl.average_outputs = True 254 | result = rtl.forward(x) 255 | assert torch.allclose(result, expected_avg) 256 | 257 | 258 | @pytest.mark.parametrize( 259 | "monotonic_counts, units, expected_out", 260 | [ 261 | ( 262 | [0, 1, 2, 3], 263 | [2, 1, 1, 1], 264 | [None] * 2 + [Monotonicity.INCREASING] * 3, 265 | ), 266 | ( 267 | [0, 4, 5, 7], 268 | [1, 2, 3, 4], 269 | [None] + [Monotonicity.INCREASING] * 9, 270 | ), 271 | ([0], [3], [None] * 3), 272 | ([1, 2, 3], [1, 1, 1], [Monotonicity.INCREASING] * 3), 273 | ], 274 | ) 
275 | def test_output_monotonicities( 276 | monotonic_counts, 277 | units, 278 | expected_out, 279 | ): 280 | """Tests output_monotonicities function.""" 281 | rtl = RTL( 282 | monotonicities=[None, Monotonicity.INCREASING], 283 | num_lattices=3, 284 | lattice_rank=3, 285 | ) 286 | rtl._lattice_layers = { 287 | monotonic_count: (Lattice(lattice_sizes=[2, 2], units=units[i]), []) 288 | for i, monotonic_count in enumerate(monotonic_counts) 289 | } 290 | assert rtl.output_monotonicities() == expected_out 291 | 292 | 293 | def test_apply_constraints(): 294 | """Tests RTL apply_constraints function.""" 295 | rtl = RTL( 296 | monotonicities=[None, Monotonicity.INCREASING], 297 | num_lattices=3, 298 | lattice_rank=3, 299 | ) 300 | mock_constrains = [] 301 | for lattice, _ in rtl._lattice_layers.values(): 302 | mock_constrain = Mock() 303 | lattice.apply_constraints = mock_constrain 304 | mock_constrains.append(mock_constrain) 305 | 306 | rtl.apply_constraints() 307 | for mock_constrain in mock_constrains: 308 | mock_constrain.assert_called_once() 309 | 310 | 311 | def test_assert_constraints(): 312 | """Tests RTL assert_constraints function.""" 313 | rtl = RTL( 314 | monotonicities=[None, Monotonicity.INCREASING], 315 | num_lattices=3, 316 | lattice_rank=3, 317 | ) 318 | mock_asserts = [] 319 | for lattice, _ in rtl._lattice_layers.values(): 320 | mock_assert = Mock() 321 | lattice.assert_constraints = mock_assert 322 | mock_assert.return_value = "violation" 323 | mock_asserts.append(mock_assert) 324 | 325 | violations = rtl.assert_constraints() 326 | for mock_assert in mock_asserts: 327 | mock_assert.assert_called_once() 328 | 329 | assert violations == ["violation"] * len(rtl._lattice_layers) 330 | 331 | 332 | @pytest.mark.parametrize( 333 | "rtl_indices", 334 | [ 335 | [[1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [6, 6]], 336 | [[1, 1, 1], [2, 2, 2], [3, 3, 3]], 337 | [[1, 1, 1], [2, 2, 2], [1, 2, 3], [3, 3, 3]], 338 | [ 339 | [1, 1, 2], 340 | [2, 3, 4], 341 | [1, 5, 5], 342 | [4, 6, 7], 343 | [1, 3, 4], 344 | [2, 3, 3], 345 | [4, 5, 6], 346 | [6, 6, 6], 347 | ], 348 | ], 349 | ) 350 | def test_ensure_unique_sublattices_possible(rtl_indices): 351 | """Tests _ensure_unique_sublattices removes duplicates from groups when possible.""" 352 | swapped_indices = RTL._ensure_unique_sublattices(rtl_indices) 353 | for group in swapped_indices: 354 | assert len(set(group)) == len(group) 355 | 356 | 357 | @pytest.mark.parametrize( 358 | "rtl_indices, max_swaps", 359 | [ 360 | ([[1, 1], [1, 2], [1, 3]], 100), 361 | ([[1, 1], [2, 2], [3, 3], [4, 4]], 2), 362 | ], 363 | ) 364 | def test_ensure_unique_sublattices_impossible(rtl_indices, max_swaps): 365 | """Tests _ensure_unique_sublattices logs when it can't remove duplicates.""" 366 | with patch("logging.info") as mock_logging_info: 367 | RTL._ensure_unique_sublattices(rtl_indices, max_swaps) 368 | mock_logging_info.assert_called_with( 369 | "Some lattices in RTL may use the same feature multiple times." 
370 | ) 371 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willbakst/pytorch-lattice/d444dba62d1c74708e0dc23f6d60c97785df46b6/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/test_calibrated_lattice.py: -------------------------------------------------------------------------------- 1 | """Tests for calibrated lattice model.""" 2 | from unittest.mock import Mock, patch 3 | 4 | import numpy as np 5 | import pytest 6 | import torch 7 | 8 | from pytorch_lattice import Interpolation, LatticeInit, Monotonicity 9 | from pytorch_lattice.layers import Lattice, NumericalCalibrator 10 | from pytorch_lattice.models import CalibratedLattice 11 | from pytorch_lattice.models.features import CategoricalFeature, NumericalFeature 12 | 13 | from ..testing_utils import train_calibrated_module 14 | 15 | 16 | def test_init_required_args(): 17 | """Tests `CalibratedLattice` initialization with only required arguments.""" 18 | calibrated_lattice = CalibratedLattice( 19 | features=[ 20 | NumericalFeature( 21 | feature_name="numerical_feature", 22 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 23 | num_keypoints=5, 24 | monotonicity=None, 25 | ), 26 | CategoricalFeature( 27 | feature_name="categorical_feature", 28 | categories=["a", "b", "c"], 29 | monotonicity_pairs=[("a", "b")], 30 | ), 31 | ] 32 | ) 33 | assert calibrated_lattice.clip_inputs 34 | assert calibrated_lattice.output_min is None 35 | assert calibrated_lattice.output_max is None 36 | assert calibrated_lattice.kernel_init == LatticeInit.LINEAR 37 | assert calibrated_lattice.interpolation == Interpolation.HYPERCUBE 38 | assert calibrated_lattice.lattice.lattice_sizes == [2, 2] 39 | assert calibrated_lattice.output_calibration_num_keypoints is None 40 | assert calibrated_lattice.output_calibrator is None 41 | for calibrator in calibrated_lattice.calibrators.values(): 42 | assert calibrator.output_min == 0.0 43 | assert calibrator.output_max == 1.0 44 | 45 | 46 | @pytest.mark.parametrize( 47 | "features, output_min, output_max, interpolation, output_num_keypoints," 48 | "expected_monotonicity, expected_lattice_sizes, expected_output_monotonicity", 49 | [ 50 | ( 51 | [ 52 | NumericalFeature( 53 | feature_name="numerical_feature", 54 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 55 | num_keypoints=5, 56 | monotonicity=Monotonicity.DECREASING, 57 | lattice_size=3, 58 | ), 59 | CategoricalFeature( 60 | feature_name="categorical_feature", 61 | categories=["a", "b", "c"], 62 | monotonicity_pairs=[("a", "b")], 63 | lattice_size=2, 64 | ), 65 | ], 66 | 0.5, 67 | 2.0, 68 | Interpolation.SIMPLEX, 69 | 4, 70 | [Monotonicity.INCREASING, Monotonicity.INCREASING], 71 | [3, 2], 72 | Monotonicity.INCREASING, 73 | ), 74 | ( 75 | [ 76 | NumericalFeature( 77 | feature_name="numerical_feature", 78 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 79 | num_keypoints=5, 80 | lattice_size=4, 81 | ), 82 | CategoricalFeature( 83 | feature_name="categorical_feature", 84 | categories=["a", "b", "c"], 85 | lattice_size=4, 86 | ), 87 | ], 88 | -0.5, 89 | 8.0, 90 | Interpolation.HYPERCUBE, 91 | 5, 92 | [None, None], 93 | [4, 4], 94 | None, 95 | ), 96 | ], 97 | ) 98 | def test_init_full_args( 99 | features, 100 | output_min, 101 | output_max, 102 | interpolation, 103 | output_num_keypoints, 104 | expected_monotonicity, 105 | expected_lattice_sizes, 
106 | expected_output_monotonicity, 107 | ): 108 | """Tests `CalibratedLattice` initialization with all arguments.""" 109 | calibrated_lattice = CalibratedLattice( 110 | features=features, 111 | output_min=output_min, 112 | output_max=output_max, 113 | interpolation=interpolation, 114 | output_calibration_num_keypoints=output_num_keypoints, 115 | ) 116 | assert calibrated_lattice.clip_inputs 117 | assert calibrated_lattice.output_min == output_min 118 | assert calibrated_lattice.output_max == output_max 119 | assert calibrated_lattice.interpolation == interpolation 120 | assert calibrated_lattice.output_calibration_num_keypoints == output_num_keypoints 121 | assert calibrated_lattice.output_calibrator.output_min == output_min 122 | assert calibrated_lattice.output_calibrator.output_max == output_max 123 | assert ( 124 | calibrated_lattice.output_calibrator.monotonicity 125 | == expected_output_monotonicity 126 | ) 127 | assert calibrated_lattice.monotonicities == expected_monotonicity 128 | assert calibrated_lattice.lattice.lattice_sizes == expected_lattice_sizes 129 | for calibrator, lattice_dim in zip( 130 | calibrated_lattice.calibrators.values(), expected_lattice_sizes 131 | ): 132 | assert calibrator.output_min == 0.0 133 | assert calibrator.output_max == lattice_dim - 1 134 | 135 | 136 | def test_forward(): 137 | """Tests all parts of calibrated lattice forward pass are called.""" 138 | calibrated_lattice = CalibratedLattice( 139 | features=[ 140 | NumericalFeature( 141 | feature_name="n", 142 | data=np.array([1.0, 2.0]), 143 | ), 144 | CategoricalFeature( 145 | feature_name="c", 146 | categories=["a", "b", "c"], 147 | ), 148 | ], 149 | output_calibration_num_keypoints=10, 150 | ) 151 | with patch( 152 | "pytorch_lattice.models.calibrated_lattice.calibrate_and_stack", 153 | ) as mock_calibrate_and_stack, patch.object( 154 | calibrated_lattice.lattice, 155 | "forward", 156 | ) as mock_lattice_forward, patch.object( 157 | calibrated_lattice.output_calibrator, 158 | "forward", 159 | ) as mock_output_calibrator: 160 | mock_calibrate_and_stack.return_value = torch.rand((1, 1)) 161 | mock_lattice_forward.return_value = torch.rand((1, 1)) 162 | mock_output_calibrator.return_value = torch.rand((1, 1)) 163 | input_tensor = torch.rand((1, 2)) 164 | 165 | result = calibrated_lattice.forward(input_tensor) 166 | 167 | mock_calibrate_and_stack.assert_called_once() 168 | assert torch.allclose(mock_calibrate_and_stack.call_args[0][0], input_tensor) 169 | assert ( 170 | mock_calibrate_and_stack.call_args[0][1] == calibrated_lattice.calibrators 171 | ) 172 | mock_lattice_forward.assert_called_once() 173 | assert torch.allclose( 174 | mock_lattice_forward.call_args[0][0], mock_calibrate_and_stack.return_value 175 | ) 176 | mock_output_calibrator.assert_called_once() 177 | assert torch.allclose( 178 | mock_output_calibrator.call_args[0][0], mock_lattice_forward.return_value 179 | ) 180 | assert torch.allclose(result, mock_output_calibrator.return_value) 181 | 182 | 183 | @pytest.mark.parametrize( 184 | "interpolation", 185 | [ 186 | Interpolation.HYPERCUBE, 187 | Interpolation.SIMPLEX, 188 | ], 189 | ) 190 | @pytest.mark.parametrize( 191 | "lattice_dim", 192 | [ 193 | 2, 194 | 3, 195 | 4, 196 | ], 197 | ) 198 | def test_training(interpolation, lattice_dim): 199 | """Tests `CalibratedLattice` training on data from f(x) = 0.7|x_1| + 0.3x_2.""" 200 | num_examples, num_categories = 3000, 3 201 | output_min, output_max = 0.0, num_categories - 1 202 | x_1_numpy = np.random.uniform(-output_max, output_max, 
size=num_examples) 203 | x_1 = torch.from_numpy(x_1_numpy)[:, None] 204 | num_examples_per_category = num_examples // num_categories 205 | x2_numpy = np.concatenate( 206 | [[c] * num_examples_per_category for c in range(num_categories)] 207 | ) 208 | x_2 = torch.from_numpy(x2_numpy)[:, None] 209 | training_examples = torch.column_stack((x_1, x_2)) 210 | linear_coefficients = torch.tensor([0.7, 0.3]).double() 211 | training_labels = torch.sum( 212 | torch.column_stack((torch.absolute(x_1), x_2)) * linear_coefficients, 213 | dim=1, 214 | keepdim=True, 215 | ) 216 | randperm = torch.randperm(training_examples.size()[0]) 217 | training_examples = training_examples[randperm] 218 | training_labels = training_labels[randperm] 219 | 220 | calibrated_lattice = CalibratedLattice( 221 | features=[ 222 | NumericalFeature( 223 | "x1", x_1_numpy, num_keypoints=4, lattice_size=lattice_dim 224 | ), 225 | CategoricalFeature( 226 | "x2", 227 | [0, 1, 2], 228 | monotonicity_pairs=[(0, 1), (1, 2)], 229 | lattice_size=lattice_dim, 230 | ), 231 | ], 232 | output_min=output_min, 233 | output_max=output_max, 234 | interpolation=interpolation, 235 | ) 236 | 237 | loss_fn = torch.nn.MSELoss() 238 | optimizer = torch.optim.Adam(calibrated_lattice.parameters(recurse=True), lr=1e-1) 239 | 240 | with torch.no_grad(): 241 | initial_predictions = calibrated_lattice(training_examples) 242 | initial_loss = loss_fn(initial_predictions, training_labels) 243 | 244 | train_calibrated_module( 245 | calibrated_lattice, 246 | training_examples, 247 | training_labels, 248 | loss_fn, 249 | optimizer, 250 | 500, 251 | num_examples // 10, 252 | ) 253 | 254 | with torch.no_grad(): 255 | trained_predictions = calibrated_lattice(training_examples) 256 | trained_loss = loss_fn(trained_predictions, training_labels) 257 | 258 | # calibrated_lattice.apply_constraints() 259 | assert not calibrated_lattice.assert_constraints() 260 | assert trained_loss < initial_loss 261 | assert trained_loss < 0.08 262 | 263 | 264 | @patch.object(Lattice, "assert_constraints") 265 | @patch.object(NumericalCalibrator, "assert_constraints") 266 | def test_assert_constraints( 267 | mock_lattice_assert_constraints, mock_output_assert_constraints 268 | ): 269 | """Tests `assert_constraints()` method calls internal assert_constraints.""" 270 | calibrated_lattice = CalibratedLattice( 271 | features=[ 272 | NumericalFeature( 273 | feature_name="numerical_feature", 274 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 275 | num_keypoints=5, 276 | monotonicity=None, 277 | ), 278 | CategoricalFeature( 279 | feature_name="categorical_feature", 280 | categories=["a", "b", "c"], 281 | monotonicity_pairs=[("a", "b")], 282 | ), 283 | ], 284 | output_calibration_num_keypoints=5, 285 | ) 286 | 287 | mock_asserts = [] 288 | for calibrator in calibrated_lattice.calibrators.values(): 289 | mock_assert = Mock() 290 | calibrator.assert_constraints = mock_assert 291 | mock_asserts.append(mock_assert) 292 | 293 | calibrated_lattice.assert_constraints() 294 | 295 | mock_lattice_assert_constraints.assert_called_once() 296 | for mock_assert in mock_asserts: 297 | mock_assert.assert_called_once() 298 | mock_output_assert_constraints.assert_called_once() 299 | 300 | 301 | @patch.object(Lattice, "apply_constraints") 302 | @patch.object(NumericalCalibrator, "apply_constraints") 303 | def test_constrain( 304 | mock_lattice_apply_constraints, mock_output_calibrator_apply_constraints 305 | ): 306 | """Tests `apply_constraints()` method calls internal constraint functions.""" 307 | 
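    # Note: the class-level patches in the decorators stub out
    # Lattice.apply_constraints and NumericalCalibrator.apply_constraints, while the
    # per-feature calibrators below get instance-level Mocks; the test then checks
    # that one call to CalibratedLattice.apply_constraints() reaches each exactly once.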
calibrated_lattice = CalibratedLattice( 308 | features=[ 309 | NumericalFeature( 310 | feature_name="numerical_feature", 311 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 312 | num_keypoints=5, 313 | monotonicity=None, 314 | ), 315 | CategoricalFeature( 316 | feature_name="categorical_feature", 317 | categories=["a", "b", "c"], 318 | monotonicity_pairs=[("a", "b")], 319 | ), 320 | ], 321 | output_calibration_num_keypoints=2, 322 | ) 323 | mock_apply_constraints_fns = [] 324 | for calibrator in calibrated_lattice.calibrators.values(): 325 | mock_calibrator_apply_constraints = Mock() 326 | calibrator.apply_constraints = mock_calibrator_apply_constraints 327 | mock_apply_constraints_fns.append(mock_calibrator_apply_constraints) 328 | 329 | calibrated_lattice.apply_constraints() 330 | 331 | mock_lattice_apply_constraints.assert_called_once() 332 | mock_output_calibrator_apply_constraints.assert_called_once() 333 | for mock_constrain in mock_apply_constraints_fns: 334 | mock_constrain.assert_called_once() 335 | -------------------------------------------------------------------------------- /tests/models/test_calibrated_linear.py: -------------------------------------------------------------------------------- 1 | """Tests for calibrated linear model.""" 2 | from unittest.mock import Mock, patch 3 | 4 | import numpy as np 5 | import pytest 6 | import torch 7 | 8 | from pytorch_lattice import Monotonicity 9 | from pytorch_lattice.layers import Linear, NumericalCalibrator 10 | from pytorch_lattice.models import CalibratedLinear 11 | from pytorch_lattice.models.features import CategoricalFeature, NumericalFeature 12 | 13 | from ..testing_utils import train_calibrated_module 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "features,output_min,output_max,output_calibration_num_keypoints," 18 | "expected_linear_monotonicities,expected_output_calibrator_monotonicity", 19 | [ 20 | ( 21 | [ 22 | NumericalFeature( 23 | feature_name="numerical_feature", 24 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 25 | num_keypoints=5, 26 | monotonicity=None, 27 | ), 28 | CategoricalFeature( 29 | feature_name="categorical_feature", 30 | categories=["a", "b", "c"], 31 | monotonicity_pairs=[("a", "b")], 32 | ), 33 | ], 34 | None, 35 | None, 36 | None, 37 | [ 38 | None, 39 | Monotonicity.INCREASING, 40 | ], 41 | Monotonicity.INCREASING, 42 | ), 43 | ( 44 | [ 45 | NumericalFeature( 46 | feature_name="numerical_feature", 47 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 48 | num_keypoints=5, 49 | monotonicity=None, 50 | ), 51 | CategoricalFeature( 52 | feature_name="categorical_feature", 53 | categories=["a", "b", "c"], 54 | monotonicity_pairs=None, 55 | ), 56 | ], 57 | -1.0, 58 | 1.0, 59 | 10, 60 | [ 61 | Monotonicity.INCREASING, 62 | Monotonicity.INCREASING, 63 | ], 64 | None, 65 | ), 66 | ( 67 | [ 68 | NumericalFeature( 69 | feature_name="numerical_feature", 70 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 71 | num_keypoints=5, 72 | monotonicity=Monotonicity.DECREASING, 73 | ), 74 | CategoricalFeature( 75 | feature_name="categorical_feature", 76 | categories=["a", "b", "c"], 77 | monotonicity_pairs=None, 78 | ), 79 | ], 80 | 0.0, 81 | None, 82 | None, 83 | [ 84 | Monotonicity.INCREASING, 85 | Monotonicity.INCREASING, 86 | ], 87 | Monotonicity.INCREASING, 88 | ), 89 | ], 90 | ) 91 | def test_initialization( 92 | features, 93 | output_min, 94 | output_max, 95 | output_calibration_num_keypoints, 96 | expected_linear_monotonicities, 97 | expected_output_calibrator_monotonicity, 98 | ): 99 | """Tests that `CalibratedLinear` 
initialization works.""" 100 | calibrated_linear = CalibratedLinear( 101 | features=features, 102 | output_min=output_min, 103 | output_max=output_max, 104 | output_calibration_num_keypoints=output_calibration_num_keypoints, 105 | ) 106 | assert calibrated_linear.features == features 107 | assert calibrated_linear.output_min == output_min 108 | assert calibrated_linear.output_max == output_max 109 | assert ( 110 | calibrated_linear.output_calibration_num_keypoints 111 | == output_calibration_num_keypoints 112 | ) 113 | assert len(calibrated_linear.calibrators) == len(features) 114 | for calibrator in calibrated_linear.calibrators.values(): 115 | assert calibrator.output_min == output_min 116 | assert calibrator.output_max == output_max 117 | assert calibrated_linear.linear.monotonicities == expected_linear_monotonicities 118 | if ( 119 | output_min is not None 120 | or output_max is not None 121 | or output_calibration_num_keypoints 122 | ): 123 | assert not calibrated_linear.linear.use_bias 124 | assert calibrated_linear.linear.weighted_average 125 | else: 126 | assert calibrated_linear.linear.use_bias 127 | assert not calibrated_linear.linear.weighted_average 128 | if not output_calibration_num_keypoints: 129 | assert calibrated_linear.output_calibrator is None 130 | else: 131 | assert calibrated_linear.output_calibrator.output_min == output_min 132 | assert calibrated_linear.output_calibrator.output_max == output_max 133 | assert ( 134 | calibrated_linear.output_calibrator.monotonicity 135 | == expected_output_calibrator_monotonicity 136 | ) 137 | 138 | 139 | def test_forward(): 140 | """Tests all parts of calibrated linear forward pass are called.""" 141 | calibrated_linear = CalibratedLinear( 142 | features=[ 143 | NumericalFeature( 144 | feature_name="n", 145 | data=np.array([1.0, 2.0]), 146 | ), 147 | CategoricalFeature( 148 | feature_name="c", 149 | categories=["a", "b", "c"], 150 | ), 151 | ], 152 | output_calibration_num_keypoints=10, 153 | ) 154 | 155 | with patch( 156 | "pytorch_lattice.models.calibrated_linear.calibrate_and_stack", 157 | ) as mock_calibrate_and_stack, patch.object( 158 | calibrated_linear.linear, 159 | "forward", 160 | ) as mock_linear_forward, patch.object( 161 | calibrated_linear.output_calibrator, 162 | "forward", 163 | ) as mock_output_calibrator: 164 | mock_calibrate_and_stack.return_value = torch.rand((1, 1)) 165 | mock_linear_forward.return_value = torch.rand((1, 1)) 166 | mock_output_calibrator.return_value = torch.rand((1, 1)) 167 | input_tensor = torch.rand((1, 2)) 168 | 169 | result = calibrated_linear.forward(input_tensor) 170 | 171 | mock_calibrate_and_stack.assert_called_once() 172 | assert torch.allclose(mock_calibrate_and_stack.call_args[0][0], input_tensor) 173 | assert mock_calibrate_and_stack.call_args[0][1] == calibrated_linear.calibrators 174 | mock_linear_forward.assert_called_once() 175 | assert torch.allclose( 176 | mock_linear_forward.call_args[0][0], mock_calibrate_and_stack.return_value 177 | ) 178 | mock_output_calibrator.assert_called_once() 179 | assert torch.allclose( 180 | mock_output_calibrator.call_args[0][0], mock_linear_forward.return_value 181 | ) 182 | assert torch.allclose(result, mock_output_calibrator.return_value) 183 | 184 | 185 | @patch.object(Linear, "assert_constraints") 186 | @patch.object(NumericalCalibrator, "assert_constraints") 187 | def test_assert_constraints( 188 | mock_linear_assert_constraints, mock_output_assert_constraints 189 | ): 190 | """Tests `assert_constraints()` method calls internal 
assert_constraints.""" 191 | calibrated_linear = CalibratedLinear( 192 | features=[ 193 | NumericalFeature( 194 | feature_name="numerical_feature", 195 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 196 | num_keypoints=5, 197 | monotonicity=None, 198 | ), 199 | CategoricalFeature( 200 | feature_name="categorical_feature", 201 | categories=["a", "b", "c"], 202 | monotonicity_pairs=[("a", "b")], 203 | ), 204 | ], 205 | output_calibration_num_keypoints=5, 206 | ) 207 | mock_asserts = [] 208 | for calibrator in calibrated_linear.calibrators.values(): 209 | mock_assert = Mock() 210 | calibrator.assert_constraints = mock_assert 211 | mock_asserts.append(mock_assert) 212 | 213 | calibrated_linear.assert_constraints() 214 | 215 | mock_linear_assert_constraints.assert_called_once() 216 | mock_output_assert_constraints.assert_called_once() 217 | for mock_assert in mock_asserts: 218 | mock_assert.assert_called_once() 219 | 220 | 221 | @patch.object(Linear, "apply_constraints") 222 | @patch.object(NumericalCalibrator, "apply_constraints") 223 | def test_constrain( 224 | mock_linear_apply_constraints, mock_output_calibrator_apply_constraints 225 | ): 226 | """Tests `apply_constraints()` method calls internal constraint functions.""" 227 | calibrated_linear = CalibratedLinear( 228 | features=[ 229 | NumericalFeature( 230 | feature_name="numerical_feature", 231 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 232 | num_keypoints=5, 233 | monotonicity=None, 234 | ), 235 | CategoricalFeature( 236 | feature_name="categorical_feature", 237 | categories=["a", "b", "c"], 238 | monotonicity_pairs=[("a", "b")], 239 | ), 240 | ], 241 | output_calibration_num_keypoints=2, 242 | ) 243 | mock_apply_constraints_fns = [] 244 | for calibrator in calibrated_linear.calibrators.values(): 245 | mock_calibrator_apply_constraints = Mock() 246 | calibrator.apply_constraints = mock_calibrator_apply_constraints 247 | mock_apply_constraints_fns.append(mock_calibrator_apply_constraints) 248 | 249 | calibrated_linear.apply_constraints() 250 | 251 | mock_linear_apply_constraints.assert_called_once() 252 | mock_output_calibrator_apply_constraints.assert_called_once() 253 | for mock_constrain in mock_apply_constraints_fns: 254 | mock_constrain.assert_called_once() 255 | 256 | 257 | def test_training(): 258 | """Tests `CalibratedLinear` training on data from f(x) = 0.7|x_1| + 0.3x_2.""" 259 | num_examples, num_categories = 3000, 3 260 | output_min, output_max = 0.0, num_categories - 1 261 | x_1_numpy = np.random.uniform(-output_max, output_max, size=num_examples) 262 | x_1 = torch.from_numpy(x_1_numpy)[:, None] 263 | num_examples_per_category = num_examples // num_categories 264 | x2_numpy = np.concatenate( 265 | [[c] * num_examples_per_category for c in range(num_categories)] 266 | ) 267 | x_2 = torch.from_numpy(x2_numpy)[:, None] 268 | training_examples = torch.column_stack((x_1, x_2)) 269 | linear_coefficients = torch.tensor([0.7, 0.3]).double() 270 | training_labels = torch.sum( 271 | torch.column_stack((torch.absolute(x_1), x_2)) * linear_coefficients, 272 | dim=1, 273 | keepdim=True, 274 | ) 275 | randperm = torch.randperm(training_examples.size()[0]) 276 | training_examples = training_examples[randperm] 277 | training_labels = training_labels[randperm] 278 | 279 | calibrated_linear = CalibratedLinear( 280 | features=[ 281 | NumericalFeature( 282 | "x1", 283 | x_1_numpy, 284 | num_keypoints=4, 285 | ), 286 | CategoricalFeature("x2", [0, 1, 2], monotonicity_pairs=[(0, 1), (1, 2)]), 287 | ], 288 | output_min=output_min, 289 | 
output_max=output_max, 290 | ) 291 | 292 | loss_fn = torch.nn.MSELoss() 293 | optimizer = torch.optim.Adam(calibrated_linear.parameters(recurse=True), lr=1e-1) 294 | 295 | with torch.no_grad(): 296 | initial_predictions = calibrated_linear(training_examples) 297 | initial_loss = loss_fn(initial_predictions, training_labels) 298 | 299 | train_calibrated_module( 300 | calibrated_linear, 301 | training_examples, 302 | training_labels, 303 | loss_fn, 304 | optimizer, 305 | 500, 306 | num_examples // 10, 307 | ) 308 | 309 | with torch.no_grad(): 310 | trained_predictions = calibrated_linear(training_examples) 311 | trained_loss = loss_fn(trained_predictions, training_labels) 312 | 313 | assert not calibrated_linear.assert_constraints() 314 | assert trained_loss < initial_loss 315 | assert trained_loss < 0.02 316 | -------------------------------------------------------------------------------- /tests/models/test_features.py: -------------------------------------------------------------------------------- 1 | """Tests for configuration objects.""" 2 | import numpy as np 3 | import pytest 4 | 5 | from pytorch_lattice import InputKeypointsInit, Monotonicity 6 | from pytorch_lattice.models.features import CategoricalFeature, NumericalFeature 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "data,num_keypoints,input_keypoints_init,missing_input_value,monotonicity," 11 | "expected_input_keypoints", 12 | [ 13 | ( 14 | np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 15 | 5, 16 | InputKeypointsInit.QUANTILES, 17 | None, 18 | None, 19 | np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 20 | ), 21 | ( 22 | np.array([1.0] * 10 + [2.0] * 10 + [3.0] * 10 + [4.0] * 10 + [5.0] * 10), 23 | 5, 24 | InputKeypointsInit.QUANTILES, 25 | None, 26 | Monotonicity.INCREASING, 27 | np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 28 | ), 29 | ( 30 | np.array([1.0] * 10 + [2.0] * 10 + [3.0] * 10 + [4.0] * 10 + [5.0] * 10), 31 | 5, 32 | InputKeypointsInit.UNIFORM, 33 | None, 34 | Monotonicity.INCREASING, 35 | np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 36 | ), 37 | ( 38 | np.array( 39 | [1.0] * 10 + [2.0] * 8 + [3.0] * 6 + [4.0] * 4 + [5.0] * 2 + [6.0] 40 | ), 41 | 4, 42 | InputKeypointsInit.QUANTILES, 43 | None, 44 | Monotonicity.INCREASING, 45 | np.array([1.0, 3.0, 4.0, 6.0]), 46 | ), 47 | ( 48 | np.array( 49 | [1.0] * 10 + [2.0] * 8 + [3.0] * 6 + [4.0] * 4 + [5.0] * 2 + [6.0] 50 | ), 51 | 10, 52 | InputKeypointsInit.QUANTILES, 53 | None, 54 | None, 55 | np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), 56 | ), 57 | ], 58 | ) 59 | def test_numerical_feature_config_initialization( 60 | data, 61 | num_keypoints, 62 | input_keypoints_init, 63 | missing_input_value, 64 | monotonicity, 65 | expected_input_keypoints, 66 | ) -> None: 67 | """Tests that numerical feature configs initialize properly.""" 68 | feature_name = "test_feature" 69 | feature = NumericalFeature( 70 | feature_name, 71 | data, 72 | num_keypoints, 73 | input_keypoints_init, 74 | missing_input_value, 75 | monotonicity, 76 | ) 77 | assert feature.feature_name == feature_name 78 | assert (feature.data == data).all() 79 | assert feature.num_keypoints == num_keypoints 80 | assert feature.input_keypoints_init == input_keypoints_init 81 | assert feature.missing_input_value == missing_input_value 82 | assert feature.monotonicity == monotonicity 83 | assert np.allclose(feature.input_keypoints, expected_input_keypoints) 84 | 85 | 86 | @pytest.mark.parametrize( 87 | "categories,missing_input_value,monotonicity_pairs,expected_category_indices," 88 | "expected_monotonicity_index_pairs", 89 | [ 90 | (["a", "b", "c"], None, 
None, {"a": 0, "b": 1, "c": 2}, []), 91 | ( 92 | ["a", "b", "c", "d"], 93 | -1.0, 94 | [("a", "b"), ("c", "d")], 95 | {"a": 0, "b": 1, "c": 2, "d": 3}, 96 | [(0, 1), (2, 3)], 97 | ), 98 | ], 99 | ) 100 | def test_categorical_feature_config_initialization( 101 | categories, 102 | missing_input_value, 103 | monotonicity_pairs, 104 | expected_category_indices, 105 | expected_monotonicity_index_pairs, 106 | ): 107 | """Tests that categorical feature configs initialize properly.""" 108 | feature_name = "test_feature" 109 | feature = CategoricalFeature( 110 | feature_name, categories, missing_input_value, monotonicity_pairs 111 | ) 112 | assert feature.feature_name == feature_name 113 | assert feature.categories == categories 114 | assert feature.missing_input_value == missing_input_value 115 | assert feature.monotonicity_pairs == monotonicity_pairs 116 | assert feature.category_indices == expected_category_indices 117 | assert feature.monotonicity_index_pairs == expected_monotonicity_index_pairs 118 | -------------------------------------------------------------------------------- /tests/test_classifier.py: -------------------------------------------------------------------------------- 1 | """Tests for the `Classifier` class.""" 2 | import numpy as np 3 | import pandas as pd 4 | import pytest 5 | 6 | from pytorch_lattice.classifier import Classifier 7 | from pytorch_lattice.feature_config import FeatureConfig 8 | from pytorch_lattice.model_configs import LatticeConfig, LinearConfig 9 | from pytorch_lattice.models import CalibratedLattice, CalibratedLinear 10 | 11 | 12 | def test_initialization(): 13 | """Tests that the classifier initializes properly.""" 14 | expected_features = {"a": FeatureConfig(name="a"), "b": FeatureConfig(name="b")} 15 | clf = Classifier(list(expected_features.keys())) 16 | assert list(clf.features.keys()) == list(expected_features.keys()) 17 | for name, config in clf.features.items(): 18 | assert config.name == expected_features[name].name 19 | assert isinstance(clf.model_config, LinearConfig) 20 | assert clf.model is None 21 | 22 | 23 | def test_configure(): 24 | """Tests that configure returns the correct feature configs.""" 25 | feature_names = ["a", "b"] 26 | clf = Classifier(feature_names) 27 | for name in feature_names: 28 | config = clf.configure(name) 29 | assert isinstance(config, FeatureConfig) 30 | assert config.name == name 31 | 32 | 33 | @pytest.fixture(name="X") 34 | def fixture_data(): 35 | """Randomized training data for fitting a classifier.""" 36 | return pd.DataFrame( 37 | { 38 | "numerical": np.random.rand(100), 39 | "categorical": np.random.choice(["a", "b", "c"], 100), 40 | } 41 | ) 42 | 43 | 44 | @pytest.fixture(name="y") 45 | def fixture_labels(X): 46 | """Randomized training labels for fitting a classifier.""" 47 | return np.random.randint(0, 2, len(X)) 48 | 49 | 50 | @pytest.mark.parametrize("model_config", [LinearConfig(), LatticeConfig()]) 51 | def test_fit_and_predict(model_config, X, y): 52 | """Tests that the classifier can be fit and generate predictions.""" 53 | clf = Classifier(X.columns, model_config).fit(X, y, epochs=1) 54 | assert clf.model is not None 55 | preds = clf.predict(X) 56 | assert isinstance(preds, np.ndarray) 57 | assert preds.shape == (100, 1) 58 | 59 | 60 | def test_fit_linear(X, y): 61 | """Tests that a linear config will fit a calibrated linear model.""" 62 | clf = Classifier(X.columns, LinearConfig()).fit(X, y, epochs=1) 63 | assert clf.model is not None 64 | assert isinstance(clf.model, CalibratedLinear) 65 | 66 | 
67 | def test_fit_lattice(X, y): 68 | """Tests that a lattice config will fit a calibrated lattice model.""" 69 | clf = Classifier(X.columns, LatticeConfig()).fit(X, y, epochs=1) 70 | assert clf.model is not None 71 | assert isinstance(clf.model, CalibratedLattice) 72 | 73 | 74 | @pytest.mark.parametrize("model_config", [LinearConfig(), LatticeConfig()]) 75 | def test_save_and_load(model_config, X, y, tmp_path): 76 | """Tests that the classifier can be saved and loaded.""" 77 | clf = Classifier(X.columns, model_config).fit(X, y, epochs=1) 78 | clf.save(tmp_path) 79 | loaded_clf = Classifier.load(tmp_path) 80 | assert list(loaded_clf.__dict__.keys()) == list(clf.__dict__.keys()) 81 | for name, config in clf.features.items(): 82 | assert loaded_clf.features[name].__dict__ == config.__dict__ 83 | assert loaded_clf.model_config.__dict__ == clf.model_config.__dict__ 84 | assert isinstance(loaded_clf.model, type(clf.model)) 85 | -------------------------------------------------------------------------------- /tests/test_feature_config.py: -------------------------------------------------------------------------------- 1 | """Tests for the `FeatureConfig` class.""" 2 | 3 | from pytorch_lattice.feature_config import FeatureConfig 4 | 5 | 6 | def test_initialization(): 7 | """Tests that the feature config initializes properly with default values.""" 8 | name = "name" 9 | fc = FeatureConfig(name) 10 | assert fc.name == name 11 | assert fc._categories is None 12 | assert fc._num_keypoints == 5 13 | assert fc._input_keypoints_init == "quantiles" 14 | assert fc._input_keypoints_type == "fixed" 15 | assert fc._monotonicity is None 16 | assert fc._projection_iterations == 8 17 | assert fc._lattice_size == 2 18 | 19 | 20 | def test_setters(): 21 | """Tests that setting configuration values through methods works.""" 22 | fc = FeatureConfig("name") 23 | 24 | # Categories 25 | categories = ["a", "b"] 26 | fc.categories(["a", "b"]) 27 | assert fc._categories == categories 28 | 29 | # Num Keypoints 30 | num_keypoints = 10 31 | fc.num_keypoints(num_keypoints) 32 | assert fc._num_keypoints == num_keypoints 33 | 34 | # Input Keypoints Init 35 | input_keypoints_init = "uniform" 36 | fc.input_keypoints_init(input_keypoints_init) 37 | assert fc._input_keypoints_init == input_keypoints_init 38 | 39 | # Input Keypoints Type (LEARNED not yet implemented) 40 | # input_keypoints_type = "learned_interior" 41 | # fc.input_keypoints_type(input_keypoints_type) 42 | # assert fc._input_keypoints_type == input_keypoints_type 43 | 44 | # Monotonicity 45 | monotonicity = "increasing" 46 | fc.monotonicity(monotonicity) 47 | assert fc._monotonicity == monotonicity 48 | 49 | # Projection Iterations 50 | projection_iterations = 10 51 | fc.projection_iterations(projection_iterations) 52 | assert fc._projection_iterations == projection_iterations 53 | 54 | # Lattice Size 55 | lattice_size = 10 56 | fc.lattice_size(lattice_size) 57 | assert fc._lattice_size == lattice_size 58 | -------------------------------------------------------------------------------- /tests/test_model_configs.py: -------------------------------------------------------------------------------- 1 | """Tests for the model configuration classes.""" 2 | 3 | from pytorch_lattice.model_configs import LatticeConfig, LinearConfig, _BaseModelConfig 4 | 5 | 6 | def test_base_model_config_initialization(): 7 | """Tests that the base model config initializes with proper defaults.""" 8 | config = _BaseModelConfig() 9 | assert config.output_min is None 10 | assert 
config.output_max is None 11 | assert config.output_calibration_num_keypoints is None 12 | 13 | 14 | def test_linear_config_initialization(): 15 | """Tests that the linear config initializes with proper defaults.""" 16 | base_config = _BaseModelConfig() 17 | config = LinearConfig() 18 | for key, value in base_config.__dict__.items(): 19 | assert config.__dict__[key] == value 20 | assert config.use_bias 21 | 22 | 23 | def test_lattice_config_initialization(): 24 | """Tests that the lattice config initializes with proper defaults.""" 25 | base_config = _BaseModelConfig() 26 | config = LatticeConfig() 27 | for key, value in base_config.__dict__.items(): 28 | assert config.__dict__[key] == value 29 | assert config.kernel_init == "linear" 30 | assert config.interpolation == "simplex" 31 | -------------------------------------------------------------------------------- /tests/testing_utils.py: -------------------------------------------------------------------------------- 1 | """Testing Utilities.""" 2 | import torch 3 | 4 | from pytorch_lattice.constrained_module import ConstrainedModule 5 | 6 | 7 | def _batch_data(examples: torch.Tensor, labels: torch.Tensor, batch_size: int): 8 | """A generator that yields batches of data.""" 9 | num_examples = examples.size()[0] 10 | for i in range(0, num_examples, batch_size): 11 | yield ( 12 | examples[i : i + batch_size], 13 | labels[i : i + batch_size], 14 | ) 15 | 16 | 17 | def train_calibrated_module( 18 | calibrated_module: ConstrainedModule, 19 | examples: torch.Tensor, 20 | labels: torch.Tensor, 21 | loss_fn: torch.nn.Module, 22 | optimizer: torch.optim.Optimizer, 23 | epochs: int, 24 | batch_size: int, 25 | ): 26 | """Trains a calibrated module for testing purposes.""" 27 | for _ in range(epochs): 28 | for batched_inputs, batched_labels in _batch_data(examples, labels, batch_size): 29 | optimizer.zero_grad() 30 | outputs = calibrated_module(batched_inputs) 31 | loss = loss_fn(outputs, batched_labels) 32 | loss.backward() 33 | optimizer.step() 34 | calibrated_module.apply_constraints() 35 | 36 | 37 | class MockResponse: 38 | """Mock response class for testing.""" 39 | 40 | def __init__(self, json_data, status_code=200): 41 | """Mock response for testing.""" 42 | self.json_data = json_data 43 | self.status_code = status_code 44 | 45 | def json(self): 46 | """Return json data.""" 47 | return self.json_data 48 | -------------------------------------------------------------------------------- /tests/utils/test_data.py: -------------------------------------------------------------------------------- 1 | """Tests for data utilities.""" 2 | import numpy as np 3 | import pandas as pd 4 | import pytest 5 | 6 | from pytorch_lattice.models.features import CategoricalFeature, NumericalFeature 7 | from pytorch_lattice.utils.data import Dataset, prepare_features 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "X,features,expected_prepared_X", 12 | [ 13 | ( 14 | pd.DataFrame({"a": [1.0, 2.0, 3.0]}), 15 | [NumericalFeature(feature_name="a", data=[1.0, 2.0, 3.0])], 16 | pd.DataFrame({"a": [1.0, 2.0, 3.0]}), 17 | ), 18 | ( 19 | pd.DataFrame({"b": ["a", "b", "c"]}), 20 | [CategoricalFeature(feature_name="b", categories=["a", "b", "c"])], 21 | pd.DataFrame({"b": [0, 1, 2]}), 22 | ), 23 | ( 24 | pd.DataFrame({"a": [1.0, 2.0, 3.0, np.nan], "b": ["a", "b", "c", np.nan]}), 25 | [ 26 | NumericalFeature( 27 | feature_name="a", data=[1.0, 2.0, 3.0], missing_input_value=-1.0 28 | ), 29 | CategoricalFeature( 30 | feature_name="b", categories=["a", "b"], missing_input_value=-1.0 31 
| ), 32 | ], 33 | pd.DataFrame({"a": [1.0, 2.0, 3.0, -1.0], "b": [0.0, 1.0, -1.0, -1.0]}), 34 | ), 35 | ], 36 | ) 37 | def test_prepare_features(X, features, expected_prepared_X): 38 | """Tests that the `prepare_features` function works as expected.""" 39 | prepare_features(X, features) 40 | assert X.equals(expected_prepared_X) 41 | 42 | 43 | @pytest.fixture(name="X") 44 | def fixture_X(): 45 | """Returns a `pd.DataFrame` fixture for testing.""" 46 | return pd.DataFrame( 47 | {"a": [1.0, 2.0, 3.0], "b": ["a", "b", "c"], "c": [4.0, 5.0, 6.0]} 48 | ) 49 | 50 | 51 | @pytest.fixture(name="features") 52 | def fixture_features(X): 53 | """Returns a list of model features for testing.""" 54 | return [ 55 | NumericalFeature(feature_name="a", data=X["a"].values), 56 | CategoricalFeature(feature_name="b", categories=list(X["b"].values)), 57 | ] 58 | 59 | 60 | @pytest.fixture(name="dataset") 61 | def fixture_dataset(X, features): 62 | """Returns a `Dataset` fixture for testing.""" 63 | y = np.array([1.0, 2.0, 3.0]) 64 | return Dataset(X, y, features) 65 | 66 | 67 | def test_initialization(dataset): 68 | """Tests that `Dataset` initialization work as expected.""" 69 | expected_X = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [0, 1, 2]}) 70 | assert dataset.X.equals(expected_X) 71 | 72 | 73 | def test_len(dataset): 74 | """Tests that `Dataset` __len__ is correct.""" 75 | assert len(dataset) == 3 76 | 77 | 78 | def test_get_item(dataset): 79 | """Tests that `Dataset` __getitem__ is correct.""" 80 | inputs, labels = dataset[:2] 81 | assert np.array_equal(inputs, np.array([[1.0, 0], [2.0, 1]])) 82 | assert np.array_equal(labels, np.array([[1.0], [2.0]])) 83 | -------------------------------------------------------------------------------- /tests/utils/test_models.py: -------------------------------------------------------------------------------- 1 | """Tests for models utilities.""" 2 | from unittest.mock import Mock 3 | 4 | import numpy as np 5 | import pytest 6 | import torch 7 | 8 | from pytorch_lattice import ( 9 | CategoricalCalibratorInit, 10 | Monotonicity, 11 | NumericalCalibratorInit, 12 | ) 13 | from pytorch_lattice.models.features import CategoricalFeature, NumericalFeature 14 | from pytorch_lattice.utils.models import ( 15 | calibrate_and_stack, 16 | initialize_feature_calibrators, 17 | initialize_monotonicities, 18 | initialize_output_calibrator, 19 | ) 20 | 21 | 22 | @pytest.mark.parametrize( 23 | "num_feat, cat_feat", 24 | [ 25 | ( 26 | NumericalFeature( 27 | feature_name="numerical_feature", 28 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 29 | monotonicity=None, 30 | ), 31 | CategoricalFeature( 32 | feature_name="categorical_feature", 33 | categories=["a", "b"], 34 | ), 35 | ), 36 | ( 37 | NumericalFeature( 38 | feature_name="numerical_feature", 39 | data=np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 40 | num_keypoints=5, 41 | monotonicity=Monotonicity.DECREASING, 42 | ), 43 | CategoricalFeature( 44 | feature_name="categorical_feature", 45 | categories=["a", "b", "c"], 46 | monotonicity_pairs=[("a", "b"), ("c", "b")], 47 | ), 48 | ), 49 | ], 50 | ) 51 | @pytest.mark.parametrize( 52 | "output_min, output_max, expected_output_max", 53 | [ 54 | (None, None, [None, None]), 55 | (None, 1, [1, 1]), 56 | (0, None, [None, None]), 57 | (None, [2, 3], [2, 3]), 58 | (0, [2, 2], [2, 2]), 59 | ], 60 | ) 61 | def test_initialize_feature_calibrators( 62 | num_feat, cat_feat, output_min, output_max, expected_output_max 63 | ) -> None: 64 | """Test for calibrator initialization helper function.""" 65 | features = 
[num_feat, cat_feat] 66 | calibrators_dict = initialize_feature_calibrators( 67 | features=features, 68 | output_min=output_min, 69 | output_max=output_max, 70 | ) 71 | 72 | np.testing.assert_allclose( 73 | calibrators_dict["numerical_feature"].input_keypoints, 74 | num_feat.input_keypoints, 75 | rtol=1e-5, 76 | ) 77 | assert ( 78 | calibrators_dict["numerical_feature"].missing_input_value 79 | == num_feat.missing_input_value 80 | ) 81 | assert calibrators_dict["numerical_feature"].output_min == output_min 82 | assert calibrators_dict["numerical_feature"].output_max == expected_output_max[0] 83 | assert calibrators_dict["numerical_feature"].monotonicity == num_feat.monotonicity 84 | assert ( 85 | calibrators_dict["numerical_feature"].kernel_init 86 | == NumericalCalibratorInit.EQUAL_SLOPES 87 | ) 88 | assert ( 89 | calibrators_dict["numerical_feature"].projection_iterations 90 | == num_feat.projection_iterations 91 | ) 92 | 93 | assert calibrators_dict["categorical_feature"].num_categories == len( 94 | cat_feat.categories 95 | ) 96 | assert ( 97 | calibrators_dict["categorical_feature"].missing_input_value 98 | == cat_feat.missing_input_value 99 | ) 100 | assert calibrators_dict["categorical_feature"].output_min == output_min 101 | assert calibrators_dict["categorical_feature"].output_max == expected_output_max[1] 102 | assert ( 103 | calibrators_dict["categorical_feature"].monotonicity_pairs 104 | == cat_feat.monotonicity_index_pairs 105 | ) 106 | assert ( 107 | calibrators_dict["categorical_feature"].kernel_init 108 | == CategoricalCalibratorInit.UNIFORM 109 | ) 110 | 111 | 112 | def test_initialize_feature_calibrators_invalid() -> None: 113 | """Test for calibrator initialization helper function for invalid feature type.""" 114 | with pytest.raises( 115 | ValueError, 116 | match=r"Unknown type for feature NOT A FEATURE", 117 | ): 118 | features = ["NOT A FEATURE"] 119 | initialize_feature_calibrators(features) # type: ignore[arg-type] 120 | 121 | 122 | @pytest.mark.parametrize( 123 | "features, expected_monotonicities", 124 | [ 125 | ( 126 | [ 127 | NumericalFeature( 128 | feature_name="n", 129 | data=np.array([1.0]), 130 | monotonicity=None, 131 | ), 132 | CategoricalFeature( 133 | feature_name="c", 134 | categories=["a", "b"], 135 | ), 136 | ], 137 | [None, None], 138 | ), 139 | ( 140 | [ 141 | NumericalFeature( 142 | feature_name="n", 143 | data=np.array([1.0]), 144 | monotonicity=None, 145 | ), 146 | CategoricalFeature( 147 | feature_name="c", 148 | categories=["a", "b"], 149 | monotonicity_pairs=[("a", "b")], 150 | ), 151 | ], 152 | [None, Monotonicity.INCREASING], 153 | ), 154 | ( 155 | [ 156 | NumericalFeature( 157 | feature_name="n", 158 | data=np.array([1.0]), 159 | monotonicity=Monotonicity.INCREASING, 160 | ), 161 | CategoricalFeature( 162 | feature_name="c", 163 | categories=["a", "b"], 164 | ), 165 | ], 166 | [Monotonicity.INCREASING, None], 167 | ), 168 | ( 169 | [ 170 | NumericalFeature( 171 | feature_name="n", 172 | data=np.array([1.0]), 173 | monotonicity=Monotonicity.DECREASING, 174 | ), 175 | CategoricalFeature( 176 | feature_name="c", 177 | categories=["a", "b"], 178 | monotonicity_pairs=[("a", "b")], 179 | ), 180 | ], 181 | [Monotonicity.INCREASING, Monotonicity.INCREASING], 182 | ), 183 | ], 184 | ) 185 | def test_initialize_monotonicities(features, expected_monotonicities) -> None: 186 | """Tests for monotonicity initialization logic in helper function.""" 187 | monotonicities = initialize_monotonicities(features) 188 | for mono, expected_mono in 
zip(monotonicities, expected_monotonicities): 189 | assert mono == expected_mono 190 | 191 | 192 | @pytest.mark.parametrize( 193 | "output_calibration_num_keypoints, monotonic, output_min, output_max", 194 | [ 195 | (None, None, None, None), 196 | (0, None, None, None), 197 | ], 198 | ) 199 | def test_initialize_output_calibrator_none( 200 | output_calibration_num_keypoints, monotonic, output_min, output_max 201 | ) -> None: 202 | """Tests helper function for initializing output calibrator when not initialized.""" 203 | output_cal = initialize_output_calibrator( 204 | output_calibration_num_keypoints=output_calibration_num_keypoints, 205 | monotonic=monotonic, 206 | output_min=output_min, 207 | output_max=output_max, 208 | ) 209 | assert output_cal is None 210 | 211 | 212 | @pytest.mark.parametrize( 213 | "output_calibration_num_keypoints, monotonic, output_min, output_max", 214 | [ 215 | (4, True, None, None), 216 | (5, False, 0.0, 1.0), 217 | ], 218 | ) 219 | def test_initialize_output_calibrator( 220 | output_calibration_num_keypoints, monotonic, output_min, output_max 221 | ) -> None: 222 | """Tests helper function for initializing output calibrator when initialized.""" 223 | output_cal = initialize_output_calibrator( 224 | output_calibration_num_keypoints=output_calibration_num_keypoints, 225 | monotonic=monotonic, 226 | output_min=output_min, 227 | output_max=output_max, 228 | ) 229 | assert output_cal is not None 230 | assert len(output_cal.input_keypoints) == output_calibration_num_keypoints 231 | assert output_cal.missing_input_value is None 232 | assert output_cal.output_max == output_max 233 | assert output_cal.output_min == output_min 234 | if monotonic: 235 | assert output_cal.monotonicity == Monotonicity.INCREASING 236 | else: 237 | assert output_cal.monotonicity is None 238 | 239 | 240 | @pytest.mark.parametrize( 241 | "data,expected_args", 242 | [ 243 | ( 244 | torch.tensor([[1.0, 2.0, 3.0]]), 245 | [torch.tensor([[1.0]]), torch.tensor([[2.0]]), torch.tensor([[3.0]])], 246 | ), 247 | ( 248 | torch.tensor([[4.0, 5.0], [6.0, 7.0]]), 249 | [torch.tensor([[4.0], [6.0]]), torch.tensor([[5.0], [7.0]])], 250 | ), 251 | ( 252 | torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]), 253 | [ 254 | torch.tensor([[1.0], [4.0], [7.0]]), 255 | torch.tensor([[2.0], [5.0], [8.0]]), 256 | torch.tensor([[3.0], [6.0], [9.0]]), 257 | ], 258 | ), 259 | ], 260 | ) 261 | def test_calibrate_and_stack(data, expected_args): 262 | """Tests slicing logic of calibrate_and_stack function used in forward passes.""" 263 | mock_calibrators = { 264 | f"calibrator_{i}": Mock( 265 | spec=torch.nn.Module, return_value=torch.tensor([[0.0]]) 266 | ) 267 | for i in range(data.shape[1]) 268 | } 269 | calibrators = torch.nn.ModuleDict(mock_calibrators) 270 | 271 | result = calibrate_and_stack(data, calibrators) 272 | 273 | for mock_calibrator, expected_arg in zip(calibrators.values(), expected_args): 274 | mock_calibrator.assert_called_once() 275 | assert torch.allclose(mock_calibrator.call_args[0][0], expected_arg) 276 | expected_result = torch.zeros(data.shape[0], data.shape[1]) 277 | assert torch.allclose(result, expected_result) 278 | --------------------------------------------------------------------------------