├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── citation.bib ├── docs ├── .nojekyll ├── api │ ├── _static │ │ ├── alabaster.css │ │ ├── basic.css │ │ ├── custom.css │ │ ├── doctools.js │ │ ├── documentation_options.js │ │ ├── file.png │ │ ├── github-banner.svg │ │ ├── language_data.js │ │ ├── minus.png │ │ ├── plus.png │ │ ├── pygments.css │ │ ├── searchtools.js │ │ └── sphinx_highlight.js │ ├── genindex.html │ ├── index.html │ ├── objects.inv │ ├── search.html │ └── searchindex.js ├── fig1.png ├── images │ └── video_screenshot.png ├── index.html └── style.css ├── examples ├── Data_Cortex_Nuclear.csv ├── accuracy.png ├── boston_housing.py ├── boston_housing_group.py ├── cox_experiments.py ├── cox_regression.py ├── data │ ├── hnscc_x.csv │ └── hnscc_y.csv ├── diabetes.py ├── friedman.py ├── friedman │ ├── download.sh │ └── main.py ├── generated.py ├── miceprotein.py ├── mnist_ae.py ├── mnist_classif.py └── mnist_reconstruction.py ├── experiments ├── README.MD ├── data_utils.py └── run.py ├── lassonet ├── __init__.py ├── cox.py ├── interfaces.py ├── model.py ├── plot.py ├── prox.py ├── r.py └── utils.py ├── setup.py ├── sphinx_docs ├── Makefile ├── conf.py └── index.rst └── tests └── test_interface.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | Icon 3 | __pycache__ 4 | .pytest_cache 5 | *.egg-info/ 6 | .DS_Store 7 | sphinx_docs/_*/ 8 | docs/api/_sources 9 | dist/ 10 | build/ 11 | *.png 12 | *.jpg 13 | *.csv 14 | .vscode 15 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "examples/spinet"] 2 | path = examples/spinet 3 | url = https://github.com/meixide/spinet 4 | ignore = dirty -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: black 5 | name: Black Formatter 6 | entry: black . 7 | language: system 8 | types: [python] 9 | 10 | - id: isort 11 | name: isort 12 | entry: isort --profile=black . 13 | language: system 14 | types: [python] 15 | 16 | - id: make-docs 17 | name: Generate Docs 18 | entry: make docs 19 | language: system 20 | pass_filenames: false 21 | 22 | - id: add-docs 23 | name: Stage Docs 24 | entry: git add docs/api/ 25 | language: system 26 | pass_filenames: false 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Louis Abraham, Ismael Lemhadri 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | pypi: dist 2 | twine upload dist/* 3 | 4 | dist: 5 | - rm -rf dist 6 | python3 setup.py sdist bdist_wheel 7 | 8 | docs: 9 | cd sphinx_docs && $(MAKE) html 10 | - rm -rf docs/api 11 | mkdir docs/api 12 | cp -r sphinx_docs/_build/html/* docs/api 13 | 14 | .PHONY: docs dist -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI version](https://badge.fury.io/py/lassonet.svg)](https://badge.fury.io/py/lassonet) 2 | [![Downloads](https://static.pepy.tech/badge/lassonet)](https://pepy.tech/project/lassonet) 3 | 4 | # LassoNet 5 | 6 | LassoNet is a new family of models that combines feature selection with neural networks. 7 | 8 | LassoNet works by adding a linear skip connection from the input features to the output. An L1 penalty (LASSO-inspired) is added to that skip connection, along with a constraint on the network so that whenever a feature is ignored by the skip connection, it is ignored by the whole network. 9 | 10 | Promo Video 11 | 12 | ## Installation 13 | 14 | ``` 15 | pip install lassonet 16 | ``` 17 | 18 | ## Usage 19 | 20 | We have designed the code to follow scikit-learn's standards as closely as possible (e.g. [linear_model.Lasso](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html)). 21 | 22 | ```python 23 | from lassonet import LassoNetClassifierCV 24 | model = LassoNetClassifierCV() # or LassoNetRegressorCV for regression 25 | model.fit(X_train, y_train) 26 | print("Best model scored", model.score(X_test, y_test)) 27 | print("Lambda =", model.best_lambda_) 28 | ``` 29 | 30 | Always normalize your data before passing it to LassoNet, since it uses neural networks under the hood.
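The skip-connection constraint described above can be illustrated with a small, hypothetical sketch. It assumes a single-output network with skip weights `theta` (one per feature) and first hidden layer weights `W1`. As in the LassoNet paper, each feature's first-layer weights are bounded by `M` times the magnitude of its skip weight, so a zero skip weight forces that feature's first-layer weights to zero. The helpers `clamp_to_hierarchy`, `satisfies_hierarchy` and `selected_features` are illustrative names and are not part of the `lassonet` API. The clipping is only a naive way to satisfy the bound and is not the hierarchical proximal step described in the paper.

```python
from typing import List

# Rows are hidden units, columns are input features (single-output network assumed).
Matrix = List[List[float]]


def clamp_to_hierarchy(theta: List[float], W1: Matrix, M: float) -> Matrix:
    """Clip column j of W1 so that max_k |W1[k][j]| <= M * |theta[j]|.

    Naive entry-wise clipping, for illustration only: this is NOT the
    hierarchical proximal operator from the LassoNet paper.
    """
    bounds = [M * abs(t) for t in theta]
    return [[min(max(w, -b), b) for w, b in zip(row, bounds)] for row in W1]


def satisfies_hierarchy(theta: List[float], W1: Matrix, M: float) -> bool:
    """True if every first-layer column is bounded by M times its skip weight."""
    return all(
        max(abs(row[j]) for row in W1) <= M * abs(t) for j, t in enumerate(theta)
    )


def selected_features(theta: List[float]) -> List[int]:
    """Under the constraint, theta[j] == 0 forces column j of W1 to zero."""
    return [j for j, t in enumerate(theta) if t != 0.0]


theta = [0.5, 0.0, -0.2]   # toy skip-connection weights
W1 = [[3.0, 1.0, -0.5],    # toy first-layer weights
      [-1.0, 2.0, 0.1]]
W1_ok = clamp_to_hierarchy(theta, W1, M=10.0)
print(satisfies_hierarchy(theta, W1, M=10.0))     # False: feature 1 leaks through W1
print(satisfies_hierarchy(theta, W1_ok, M=10.0))  # True
print(W1_ok)                     # [[3.0, 0.0, -0.5], [-1.0, 0.0, 0.1]]
print(selected_features(theta))  # [0, 2]
```

Feature 1 has a zero skip weight, so once the bound holds it has no path into the network. In this toy case, that is exactly what makes the skip connection's sparsity carry over to the whole model.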
31 | 32 | You can read the full [documentation](https://lasso-net.github.io//lassonet/api/) or browse the [examples](https://github.com/lasso-net/lassonet/tree/master/examples), which cover most features. We also provide a Quickstart section below. 33 | 34 | 35 | 36 | ## Quickstart 37 | 38 | Here we guide you through the features of LassoNet and how you typically use them. 39 | 40 | ### Task 41 | 42 | LassoNet is based on neural networks and can be used for any kind of data. Currently, we have implemented losses for the following tasks: 43 | 44 | - regression: `LassoNetRegressor` 45 | - classification: `LassoNetClassifier` 46 | - Cox regression: `LassoNetCoxRegressor` 47 | - interval-censored Cox regression: `LassoNetIntervalRegressor` 48 | 49 | If features naturally belong to groups, you can use the `groups` parameter to specify them. This lets the model penalize groups of features instead of each feature individually. 50 | 51 | ### Data preparation 52 | 53 | You should always normalize your data before passing it to the model to avoid very large (or very small) values in the data. 54 | 55 | ### What do you want to do? 56 | 57 | The LassoNet family of models can do a lot of things. 58 | 59 | Here are some examples of what you can do with LassoNet. Note that you can replace `LassoNetRegressor` with any of the other models to perform the same operations. 60 | 61 | #### Using the base interface 62 | 63 | The original paper describes how to train LassoNet along a regularization path. This requires the user to manually select a model from the path. It also makes the `.fit()` method useless on its own, since the model at the end of the path is always empty (every feature removed). This workflow is still available through the `.path(return_state_dicts=True)` method of any base model, which returns a list of checkpoints that can be loaded with `.load()`.
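Training along a regularization path means fitting the model for a sequence of increasing penalties, with one training run per value. The minimal sketch below assumes that this sequence is geometric, so that each lambda is the previous one multiplied by `path_multiplier` (the parameter described under Important parameters). Under that assumption, it shows how the multiplier controls how many steps, and therefore how many training runs, are needed to cover a given range. `lambda_path` and `steps_to_reach` are hypothetical helpers, not library functions.

```python
from typing import List


def lambda_path(lambda_start: float, path_multiplier: float, n_steps: int) -> List[float]:
    """Geometric schedule (assumed): lambda_k = lambda_start * path_multiplier ** k."""
    if lambda_start <= 0 or path_multiplier <= 1 or n_steps < 1:
        raise ValueError("need lambda_start > 0, path_multiplier > 1, n_steps >= 1")
    lambdas = [lambda_start]
    for _ in range(n_steps - 1):
        lambdas.append(lambdas[-1] * path_multiplier)
    return lambdas


def steps_to_reach(lambda_start: float, lambda_target: float, path_multiplier: float) -> int:
    """Number of multiplicative steps before lambda first reaches lambda_target."""
    if lambda_start <= 0 or path_multiplier <= 1:
        raise ValueError("need lambda_start > 0 and path_multiplier > 1")
    steps, lam = 0, lambda_start
    while lam < lambda_target:
        lam *= path_multiplier
        steps += 1
    return steps


print(lambda_path(1.0, 1.5, 4))        # [1.0, 1.5, 2.25, 3.375]
print(steps_to_reach(1.0, 2.0, 1.02))  # 36 steps to double lambda
print(steps_to_reach(1.0, 2.0, 1.01))  # 70 steps: a finer path, more training runs
```

Halving the step from `1.02` to `1.01` roughly doubles the number of models trained over the same range. This is the precision/time trade-off the parameter controls.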
64 | 65 | 66 | ```python 67 | from lassonet import LassoNetRegressor, plot_path 68 | 69 | model = LassoNetRegressor() 70 | path = model.path(X_train, y_train, return_state_dicts=True) 71 | plot_path(model, X_test, y_test) 72 | 73 | # choose `best_id` based on the plot 74 | model.load(path[best_id].state_dict) 75 | print(model.score(X_test, y_test)) 76 | ``` 77 | 78 | You can also retrieve the mask of the selected features and train a dense model on them. 79 | 80 | ```python 81 | selected = path[best_id].selected 82 | model.fit(X_train[:, selected], y_train, dense_only=True) 83 | print(model.score(X_test[:, selected], y_test)) 84 | ``` 85 | 86 | The `model.feature_importances_` attribute gives, for each feature, the value of the L1 regularization parameter at which that feature is removed. This can give you an idea of the most important features, but it is very unstable across runs. Use stability selection to find the most stable features. 87 | 88 | #### Using the cross-validation interface 89 | 90 | 91 | We integrated support for cross-validation (5-fold by default) in the estimators whose names end with `CV`. For each fold, a path is trained. The best regularization value is then chosen to maximize the average score over all folds. Finally, the model is retrained on the whole training set up to that regularization value. 92 | 93 | ```python 94 | model = LassoNetRegressorCV() 95 | model.fit(X_train, y_train) 96 | model.score(X_test, y_test) 97 | ``` 98 | 99 | You can also use the `plot_cv` helper to get more information. 100 | 101 | Some attributes give you more information about the best model, like `best_lambda_`, `best_selected_` or `best_cv_score_`. 102 | 103 | You can pass this information to a base model to retrain it from scratch with the best regularization parameter or on the best subset of features.
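The cross-validation selection rule above averages the per-fold scores for each regularization value and then keeps the best one. It can be sketched in a few lines. This is an illustration only: it assumes every fold was scored on the same grid of lambda values, which may not match how the `CV` estimators handle fold-specific paths internally. `best_lambda` is a hypothetical helper, not the library's API.

```python
from statistics import mean
from typing import List, Tuple


def best_lambda(lambdas: List[float], fold_scores: List[List[float]]) -> Tuple[float, float]:
    """Return (lambda, mean score) maximizing the average validation score.

    fold_scores[f][k] is the score of fold f at lambdas[k]; higher is better.
    Assumes every fold was evaluated on the same lambda grid.
    """
    if not fold_scores or any(len(s) != len(lambdas) for s in fold_scores):
        raise ValueError("each fold needs exactly one score per lambda")
    avg = [mean(col) for col in zip(*fold_scores)]  # average over folds, per lambda
    k = max(range(len(lambdas)), key=lambda i: avg[i])
    return lambdas[k], avg[k]


# Toy numbers for illustration only (3 folds, 4 lambdas), not real results.
lambdas = [0.1, 0.2, 0.4, 0.8]
fold_scores = [
    [0.70, 0.74, 0.72, 0.60],
    [0.68, 0.75, 0.71, 0.58],
    [0.71, 0.73, 0.74, 0.61],
]
lam, score = best_lambda(lambdas, fold_scores)
print(lam, round(score, 3))  # 0.2 0.74
```

Higher scores are treated as better, matching scikit-learn's `score` convention (e.g. R² or accuracy). Note that `0.4` wins on the third fold alone but loses on average, which is why averaging across folds gives a more stable choice.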
104 | 105 | #### Using the stability selection interface 106 | 107 | 108 | [Stability selection](https://arxiv.org/abs/0809.2932) selects the features that are chosen most consistently when the model is run multiple times on different random subsamples of the data. It is probably the best way to select the most important features. 109 | 110 | ```python 111 | model = LassoNetRegressor() 112 | oracle, order, wrong, paths, prob = model.stability_selection(X_train, y_train) 113 | ``` 114 | 115 | - `oracle` is a heuristic that can detect the most stable features when introducing noise. 116 | - `order` lists the features by decreasing importance. 117 | - `wrong[k]` is a measure of error when selecting the first k+1 features (read [this paper](https://arxiv.org/pdf/2206.06885) for more details). You can `plt.plot(wrong)` to see the error as a function of the number of selected features. 118 | - `paths` stores all the computed paths. 119 | - `prob` is the probability that a feature is selected at each value of the regularization parameter. 120 | 121 | In practice, you might want to train multiple dense models on different subsets of features to get a better understanding of the importance of each feature. 122 | 123 | For example: 124 | 125 | ```python 126 | for i in range(1, 11): # top-1 to top-10 features (i = 0 would select nothing) 127 | selected = order[:i] 128 | model.fit(X_train[:, selected], y_train, dense_only=True) 129 | print(i, model.score(X_test[:, selected], y_test)) 130 | ``` 131 | 132 | ### Important parameters 133 | 134 | Here are the most important parameters you should be aware of: 135 | 136 | - `hidden_dims`: the number of neurons in each hidden layer. The default value is `(100,)` but you might want to try smaller and deeper networks like `(10, 10)`. 137 | - `path_multiplier`: the multiplicative factor between consecutive lambda values on the path. The closer it is to 1, the more precise the path, but the longer it takes.
The default value is a trade-off favoring fast training, but you might want to try smaller values like `1.01` or `1.005` to get a better model. 138 | - `lambda_start`: the starting value of the regularization parameter. The default value is `"auto"`, in which case the model tries to select a good starting value using an unpublished heuristic (read the code to learn more). You can identify a bad `lambda_start` by plotting the path. If `lambda_start` is too small, the model will stay dense for a long time, which does not hurt performance but takes longer. If `lambda_start` is too large, the number of features will decrease very fast and the path will not be accurate; in that case, decrease `lambda_start`. 139 | - `gamma`: adds an L2 penalty on the network. The default is `0.0`, which means no penalty, but a small value can improve performance, especially on small datasets. 140 | - more standard MLP training parameters are accessible: `dropout`, `batch_size`, `optim`, `n_iters`, `patience`, `tol`, `backtrack`, `val_size`. In particular, `batch_size` lets you use stochastic (minibatch) gradient descent instead of full-batch gradient descent, which also avoids memory issues on large datasets. 141 | - `M`: this parameter has almost no effect on the model. 142 | 143 | ## Features 144 | 145 | - regression, classification, [Cox regression](https://en.wikipedia.org/wiki/Proportional_hazards_model) and [interval-censored Cox regression](https://arxiv.org/abs/2206.06885) with `LassoNetRegressor`, `LassoNetClassifier`, `LassoNetCoxRegressor` and `LassoNetIntervalRegressor`.
146 | - cross-validation with `LassoNetRegressorCV`, `LassoNetClassifierCV`, `LassoNetCoxRegressorCV` and `LassoNetIntervalRegressorCV` 147 | - [stability selection](https://arxiv.org/abs/0809.2932) with `model.stability_selection()` 148 | - group feature selection with the `groups` argument 149 | - `lambda_start="auto"` heuristic (default) 150 | 151 | Note that cross-validation, group feature selection and automatic `lambda_start` selection have not been published in papers; you can read the code or [post an issue](https://github.com/lasso-net/lassonet/issues/new) to request more details. 152 | 153 | Among other things, we are currently working on adding support for convolutional layers, auto-encoders and online logging of experiments. 154 | 155 | ## Website 156 | 157 | LassoNet's website is [https://lasso-net.github.io/](https://lasso-net.github.io/). It contains many useful references, including the paper, live talks and additional documentation. 158 | 159 | ## References 160 | 161 | - Lemhadri, Ismael, Feng Ruan, Louis Abraham, and Robert Tibshirani. "LassoNet: A Neural Network with Feature Sparsity." Journal of Machine Learning Research 22, no. 127 (2021). [pdf](https://arxiv.org/pdf/1907.12207.pdf) [bibtex](https://github.com/lasso-net/lassonet/blob/master/citation.bib) 162 | - Yang, Xuelin, Louis Abraham, Sejin Kim, Petr Smirnov, Feng Ruan, Benjamin Haibe-Kains, and Robert Tibshirani. "FastCPH: Efficient Survival Analysis for Neural Networks." In NeurIPS 2022 Workshop on Learning from Time Series for Health. [pdf](https://arxiv.org/pdf/2208.09793.pdf) 163 | - Meixide, Carlos García, Marcos Matabuena, Louis Abraham, and Michael R. Kosorok. "Neural interval‐censored survival regression with feature selection." Statistical Analysis and Data Mining: The ASA Data Science Journal 17, no. 4 (2024): e11704.
[pdf](https://arxiv.org/pdf/2206.06885) -------------------------------------------------------------------------------- /citation.bib: -------------------------------------------------------------------------------- 1 | @article{JMLR:v22:20-848, 2 | author = {Ismael Lemhadri and Feng Ruan and Louis Abraham and Robert Tibshirani}, 3 | title = {LassoNet: A Neural Network with Feature Sparsity}, 4 | journal = {Journal of Machine Learning Research}, 5 | year = {2021}, 6 | volume = {22}, 7 | number = {127}, 8 | pages = {1-29}, 9 | url = {http://jmlr.org/papers/v22/20-848.html} 10 | } 11 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/.nojekyll -------------------------------------------------------------------------------- /docs/api/_static/alabaster.css: -------------------------------------------------------------------------------- 1 | /* -- page layout ----------------------------------------------------------- */ 2 | 3 | body { 4 | font-family: Georgia, serif; 5 | font-size: 17px; 6 | background-color: #fff; 7 | color: #000; 8 | margin: 0; 9 | padding: 0; 10 | } 11 | 12 | 13 | div.document { 14 | width: 940px; 15 | margin: 30px auto 0 auto; 16 | } 17 | 18 | div.documentwrapper { 19 | float: left; 20 | width: 100%; 21 | } 22 | 23 | div.bodywrapper { 24 | margin: 0 0 0 220px; 25 | } 26 | 27 | div.sphinxsidebar { 28 | width: 220px; 29 | font-size: 14px; 30 | line-height: 1.5; 31 | } 32 | 33 | hr { 34 | border: 1px solid #B1B4B6; 35 | } 36 | 37 | div.body { 38 | background-color: #fff; 39 | color: #3E4349; 40 | padding: 0 30px 0 30px; 41 | } 42 | 43 | div.body > .section { 44 | text-align: left; 45 | } 46 | 47 | div.footer { 48 | width: 940px; 49 | margin: 20px auto 30px auto; 50 | font-size: 14px; 51 | color: #888; 52 | text-align: right; 
53 | } 54 | 55 | div.footer a { 56 | color: #888; 57 | } 58 | 59 | p.caption { 60 | font-family: inherit; 61 | font-size: inherit; 62 | } 63 | 64 | 65 | div.relations { 66 | display: none; 67 | } 68 | 69 | 70 | div.sphinxsidebar { 71 | max-height: 100%; 72 | overflow-y: auto; 73 | } 74 | 75 | div.sphinxsidebar a { 76 | color: #444; 77 | text-decoration: none; 78 | border-bottom: 1px dotted #999; 79 | } 80 | 81 | div.sphinxsidebar a:hover { 82 | border-bottom: 1px solid #999; 83 | } 84 | 85 | div.sphinxsidebarwrapper { 86 | padding: 18px 10px; 87 | } 88 | 89 | div.sphinxsidebarwrapper p.logo { 90 | padding: 0; 91 | margin: -10px 0 0 0px; 92 | text-align: center; 93 | } 94 | 95 | div.sphinxsidebarwrapper h1.logo { 96 | margin-top: -10px; 97 | text-align: center; 98 | margin-bottom: 5px; 99 | text-align: left; 100 | } 101 | 102 | div.sphinxsidebarwrapper h1.logo-name { 103 | margin-top: 0px; 104 | } 105 | 106 | div.sphinxsidebarwrapper p.blurb { 107 | margin-top: 0; 108 | font-style: normal; 109 | } 110 | 111 | div.sphinxsidebar h3, 112 | div.sphinxsidebar h4 { 113 | font-family: Georgia, serif; 114 | color: #444; 115 | font-size: 24px; 116 | font-weight: normal; 117 | margin: 0 0 5px 0; 118 | padding: 0; 119 | } 120 | 121 | div.sphinxsidebar h4 { 122 | font-size: 20px; 123 | } 124 | 125 | div.sphinxsidebar h3 a { 126 | color: #444; 127 | } 128 | 129 | div.sphinxsidebar p.logo a, 130 | div.sphinxsidebar h3 a, 131 | div.sphinxsidebar p.logo a:hover, 132 | div.sphinxsidebar h3 a:hover { 133 | border: none; 134 | } 135 | 136 | div.sphinxsidebar p { 137 | color: #555; 138 | margin: 10px 0; 139 | } 140 | 141 | div.sphinxsidebar ul { 142 | margin: 10px 0; 143 | padding: 0; 144 | color: #000; 145 | } 146 | 147 | div.sphinxsidebar ul li.toctree-l1 > a { 148 | font-size: 120%; 149 | } 150 | 151 | div.sphinxsidebar ul li.toctree-l2 > a { 152 | font-size: 110%; 153 | } 154 | 155 | div.sphinxsidebar input { 156 | border: 1px solid #CCC; 157 | font-family: Georgia, serif; 158 | 
font-size: 1em; 159 | } 160 | 161 | div.sphinxsidebar #searchbox { 162 | margin: 1em 0; 163 | } 164 | 165 | div.sphinxsidebar .search > div { 166 | display: table-cell; 167 | } 168 | 169 | div.sphinxsidebar hr { 170 | border: none; 171 | height: 1px; 172 | color: #AAA; 173 | background: #AAA; 174 | 175 | text-align: left; 176 | margin-left: 0; 177 | width: 50%; 178 | } 179 | 180 | div.sphinxsidebar .badge { 181 | border-bottom: none; 182 | } 183 | 184 | div.sphinxsidebar .badge:hover { 185 | border-bottom: none; 186 | } 187 | 188 | /* To address an issue with donation coming after search */ 189 | div.sphinxsidebar h3.donation { 190 | margin-top: 10px; 191 | } 192 | 193 | /* -- body styles ----------------------------------------------------------- */ 194 | 195 | a { 196 | color: #004B6B; 197 | text-decoration: underline; 198 | } 199 | 200 | a:hover { 201 | color: #6D4100; 202 | text-decoration: underline; 203 | } 204 | 205 | div.body h1, 206 | div.body h2, 207 | div.body h3, 208 | div.body h4, 209 | div.body h5, 210 | div.body h6 { 211 | font-family: Georgia, serif; 212 | font-weight: normal; 213 | margin: 30px 0px 10px 0px; 214 | padding: 0; 215 | } 216 | 217 | div.body h1 { margin-top: 0; padding-top: 0; font-size: 240%; } 218 | div.body h2 { font-size: 180%; } 219 | div.body h3 { font-size: 150%; } 220 | div.body h4 { font-size: 130%; } 221 | div.body h5 { font-size: 100%; } 222 | div.body h6 { font-size: 100%; } 223 | 224 | a.headerlink { 225 | color: #DDD; 226 | padding: 0 4px; 227 | text-decoration: none; 228 | } 229 | 230 | a.headerlink:hover { 231 | color: #444; 232 | background: #EAEAEA; 233 | } 234 | 235 | div.body p, div.body dd, div.body li { 236 | line-height: 1.4em; 237 | } 238 | 239 | div.admonition { 240 | margin: 20px 0px; 241 | padding: 10px 30px; 242 | background-color: #EEE; 243 | border: 1px solid #CCC; 244 | } 245 | 246 | div.admonition tt.xref, div.admonition code.xref, div.admonition a tt { 247 | background-color: #FBFBFB; 248 | 
border-bottom: 1px solid #fafafa; 249 | } 250 | 251 | div.admonition p.admonition-title { 252 | font-family: Georgia, serif; 253 | font-weight: normal; 254 | font-size: 24px; 255 | margin: 0 0 10px 0; 256 | padding: 0; 257 | line-height: 1; 258 | } 259 | 260 | div.admonition p.last { 261 | margin-bottom: 0; 262 | } 263 | 264 | dt:target, .highlight { 265 | background: #FAF3E8; 266 | } 267 | 268 | div.warning { 269 | background-color: #FCC; 270 | border: 1px solid #FAA; 271 | } 272 | 273 | div.danger { 274 | background-color: #FCC; 275 | border: 1px solid #FAA; 276 | -moz-box-shadow: 2px 2px 4px #D52C2C; 277 | -webkit-box-shadow: 2px 2px 4px #D52C2C; 278 | box-shadow: 2px 2px 4px #D52C2C; 279 | } 280 | 281 | div.error { 282 | background-color: #FCC; 283 | border: 1px solid #FAA; 284 | -moz-box-shadow: 2px 2px 4px #D52C2C; 285 | -webkit-box-shadow: 2px 2px 4px #D52C2C; 286 | box-shadow: 2px 2px 4px #D52C2C; 287 | } 288 | 289 | div.caution { 290 | background-color: #FCC; 291 | border: 1px solid #FAA; 292 | } 293 | 294 | div.attention { 295 | background-color: #FCC; 296 | border: 1px solid #FAA; 297 | } 298 | 299 | div.important { 300 | background-color: #EEE; 301 | border: 1px solid #CCC; 302 | } 303 | 304 | div.note { 305 | background-color: #EEE; 306 | border: 1px solid #CCC; 307 | } 308 | 309 | div.tip { 310 | background-color: #EEE; 311 | border: 1px solid #CCC; 312 | } 313 | 314 | div.hint { 315 | background-color: #EEE; 316 | border: 1px solid #CCC; 317 | } 318 | 319 | div.seealso { 320 | background-color: #EEE; 321 | border: 1px solid #CCC; 322 | } 323 | 324 | div.topic { 325 | background-color: #EEE; 326 | } 327 | 328 | p.admonition-title { 329 | display: inline; 330 | } 331 | 332 | p.admonition-title:after { 333 | content: ":"; 334 | } 335 | 336 | pre, tt, code { 337 | font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace; 338 | font-size: 0.9em; 339 | } 340 | 341 | .hll { 342 | background-color: #FFC; 343 | margin: 0 
-12px; 344 | padding: 0 12px; 345 | display: block; 346 | } 347 | 348 | img.screenshot { 349 | } 350 | 351 | tt.descname, tt.descclassname, code.descname, code.descclassname { 352 | font-size: 0.95em; 353 | } 354 | 355 | tt.descname, code.descname { 356 | padding-right: 0.08em; 357 | } 358 | 359 | img.screenshot { 360 | -moz-box-shadow: 2px 2px 4px #EEE; 361 | -webkit-box-shadow: 2px 2px 4px #EEE; 362 | box-shadow: 2px 2px 4px #EEE; 363 | } 364 | 365 | table.docutils { 366 | border: 1px solid #888; 367 | -moz-box-shadow: 2px 2px 4px #EEE; 368 | -webkit-box-shadow: 2px 2px 4px #EEE; 369 | box-shadow: 2px 2px 4px #EEE; 370 | } 371 | 372 | table.docutils td, table.docutils th { 373 | border: 1px solid #888; 374 | padding: 0.25em 0.7em; 375 | } 376 | 377 | table.field-list, table.footnote { 378 | border: none; 379 | -moz-box-shadow: none; 380 | -webkit-box-shadow: none; 381 | box-shadow: none; 382 | } 383 | 384 | table.footnote { 385 | margin: 15px 0; 386 | width: 100%; 387 | border: 1px solid #EEE; 388 | background: #FDFDFD; 389 | font-size: 0.9em; 390 | } 391 | 392 | table.footnote + table.footnote { 393 | margin-top: -15px; 394 | border-top: none; 395 | } 396 | 397 | table.field-list th { 398 | padding: 0 0.8em 0 0; 399 | } 400 | 401 | table.field-list td { 402 | padding: 0; 403 | } 404 | 405 | table.field-list p { 406 | margin-bottom: 0.8em; 407 | } 408 | 409 | /* Cloned from 410 | * https://github.com/sphinx-doc/sphinx/commit/ef60dbfce09286b20b7385333d63a60321784e68 411 | */ 412 | .field-name { 413 | -moz-hyphens: manual; 414 | -ms-hyphens: manual; 415 | -webkit-hyphens: manual; 416 | hyphens: manual; 417 | } 418 | 419 | table.footnote td.label { 420 | width: .1px; 421 | padding: 0.3em 0 0.3em 0.5em; 422 | } 423 | 424 | table.footnote td { 425 | padding: 0.3em 0.5em; 426 | } 427 | 428 | dl { 429 | margin-left: 0; 430 | margin-right: 0; 431 | margin-top: 0; 432 | padding: 0; 433 | } 434 | 435 | dl dd { 436 | margin-left: 30px; 437 | } 438 | 439 | blockquote { 440 | 
margin: 0 0 0 30px; 441 | padding: 0; 442 | } 443 | 444 | ul, ol { 445 | /* Matches the 30px from the narrow-screen "li > ul" selector below */ 446 | margin: 10px 0 10px 30px; 447 | padding: 0; 448 | } 449 | 450 | pre { 451 | background: unset; 452 | padding: 7px 30px; 453 | margin: 15px 0px; 454 | line-height: 1.3em; 455 | } 456 | 457 | div.viewcode-block:target { 458 | background: #ffd; 459 | } 460 | 461 | dl pre, blockquote pre, li pre { 462 | margin-left: 0; 463 | padding-left: 30px; 464 | } 465 | 466 | tt, code { 467 | background-color: #ecf0f3; 468 | color: #222; 469 | /* padding: 1px 2px; */ 470 | } 471 | 472 | tt.xref, code.xref, a tt { 473 | background-color: #FBFBFB; 474 | border-bottom: 1px solid #fff; 475 | } 476 | 477 | a.reference { 478 | text-decoration: none; 479 | border-bottom: 1px dotted #004B6B; 480 | } 481 | 482 | a.reference:hover { 483 | border-bottom: 1px solid #6D4100; 484 | } 485 | 486 | /* Don't put an underline on images */ 487 | a.image-reference, a.image-reference:hover { 488 | border-bottom: none; 489 | } 490 | 491 | a.footnote-reference { 492 | text-decoration: none; 493 | font-size: 0.7em; 494 | vertical-align: top; 495 | border-bottom: 1px dotted #004B6B; 496 | } 497 | 498 | a.footnote-reference:hover { 499 | border-bottom: 1px solid #6D4100; 500 | } 501 | 502 | a:hover tt, a:hover code { 503 | background: #EEE; 504 | } 505 | 506 | @media screen and (max-width: 940px) { 507 | 508 | body { 509 | margin: 0; 510 | padding: 20px 30px; 511 | } 512 | 513 | div.documentwrapper { 514 | float: none; 515 | background: #fff; 516 | margin-left: 0; 517 | margin-top: 0; 518 | margin-right: 0; 519 | margin-bottom: 0; 520 | } 521 | 522 | div.sphinxsidebar { 523 | display: block; 524 | float: none; 525 | width: unset; 526 | margin: 50px -30px -20px -30px; 527 | padding: 10px 20px; 528 | background: #333; 529 | color: #FFF; 530 | } 531 | 532 | div.sphinxsidebar h3, div.sphinxsidebar h4, div.sphinxsidebar p, 533 | div.sphinxsidebar h3 a { 534 | 
color: #fff; 535 | } 536 | 537 | div.sphinxsidebar a { 538 | color: #AAA; 539 | } 540 | 541 | div.sphinxsidebar p.logo { 542 | display: none; 543 | } 544 | 545 | div.document { 546 | width: 100%; 547 | margin: 0; 548 | } 549 | 550 | div.footer { 551 | display: none; 552 | } 553 | 554 | div.bodywrapper { 555 | margin: 0; 556 | } 557 | 558 | div.body { 559 | min-height: 0; 560 | min-width: auto; /* fixes width on small screens, breaks .hll */ 561 | padding: 0; 562 | } 563 | 564 | .hll { 565 | /* "fixes" the breakage */ 566 | width: max-content; 567 | } 568 | 569 | .rtd_doc_footer { 570 | display: none; 571 | } 572 | 573 | .document { 574 | width: auto; 575 | } 576 | 577 | .footer { 578 | width: auto; 579 | } 580 | 581 | .github { 582 | display: none; 583 | } 584 | 585 | ul { 586 | margin-left: 0; 587 | } 588 | 589 | li > ul { 590 | /* Matches the 30px from the "ul, ol" selector above */ 591 | margin-left: 30px; 592 | } 593 | } 594 | 595 | 596 | /* misc. */ 597 | 598 | .revsys-inline { 599 | display: none!important; 600 | } 601 | 602 | /* Hide ugly table cell borders in ..bibliography:: directive output */ 603 | table.docutils.citation, table.docutils.citation td, table.docutils.citation th { 604 | border: none; 605 | /* Below needed in some edge cases; if not applied, bottom shadows appear */ 606 | -moz-box-shadow: none; 607 | -webkit-box-shadow: none; 608 | box-shadow: none; 609 | } 610 | 611 | 612 | /* relbar */ 613 | 614 | .related { 615 | line-height: 30px; 616 | width: 100%; 617 | font-size: 0.9rem; 618 | } 619 | 620 | .related.top { 621 | border-bottom: 1px solid #EEE; 622 | margin-bottom: 20px; 623 | } 624 | 625 | .related.bottom { 626 | border-top: 1px solid #EEE; 627 | } 628 | 629 | .related ul { 630 | padding: 0; 631 | margin: 0; 632 | list-style: none; 633 | } 634 | 635 | .related li { 636 | display: inline; 637 | } 638 | 639 | nav#rellinks { 640 | float: right; 641 | } 642 | 643 | nav#rellinks li+li:before { 644 | content: "|"; 645 | } 646 | 647 | 
nav#breadcrumbs li+li:before { 648 | content: "\00BB"; 649 | } 650 | 651 | /* Hide certain items when printing */ 652 | @media print { 653 | div.related { 654 | display: none; 655 | } 656 | } 657 | 658 | img.github { 659 | position: absolute; 660 | top: 0; 661 | border: 0; 662 | right: 0; 663 | } -------------------------------------------------------------------------------- /docs/api/_static/basic.css: -------------------------------------------------------------------------------- 1 | /* 2 | * basic.css 3 | * ~~~~~~~~~ 4 | * 5 | * Sphinx stylesheet -- basic theme. 6 | * 7 | * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | /* -- main layout ----------------------------------------------------------- */ 13 | 14 | div.clearer { 15 | clear: both; 16 | } 17 | 18 | div.section::after { 19 | display: block; 20 | content: ''; 21 | clear: left; 22 | } 23 | 24 | /* -- relbar ---------------------------------------------------------------- */ 25 | 26 | div.related { 27 | width: 100%; 28 | font-size: 90%; 29 | } 30 | 31 | div.related h3 { 32 | display: none; 33 | } 34 | 35 | div.related ul { 36 | margin: 0; 37 | padding: 0 0 0 10px; 38 | list-style: none; 39 | } 40 | 41 | div.related li { 42 | display: inline; 43 | } 44 | 45 | div.related li.right { 46 | float: right; 47 | margin-right: 5px; 48 | } 49 | 50 | /* -- sidebar --------------------------------------------------------------- */ 51 | 52 | div.sphinxsidebarwrapper { 53 | padding: 10px 5px 0 10px; 54 | } 55 | 56 | div.sphinxsidebar { 57 | float: left; 58 | width: 230px; 59 | margin-left: -100%; 60 | font-size: 90%; 61 | word-wrap: break-word; 62 | overflow-wrap : break-word; 63 | } 64 | 65 | div.sphinxsidebar ul { 66 | list-style: none; 67 | } 68 | 69 | div.sphinxsidebar ul ul, 70 | div.sphinxsidebar ul.want-points { 71 | margin-left: 20px; 72 | list-style: square; 73 | } 74 | 75 | div.sphinxsidebar ul ul { 76 | margin-top: 0; 77 
| margin-bottom: 0; 78 | } 79 | 80 | div.sphinxsidebar form { 81 | margin-top: 10px; 82 | } 83 | 84 | div.sphinxsidebar input { 85 | border: 1px solid #98dbcc; 86 | font-family: sans-serif; 87 | font-size: 1em; 88 | } 89 | 90 | div.sphinxsidebar #searchbox form.search { 91 | overflow: hidden; 92 | } 93 | 94 | div.sphinxsidebar #searchbox input[type="text"] { 95 | float: left; 96 | width: 80%; 97 | padding: 0.25em; 98 | box-sizing: border-box; 99 | } 100 | 101 | div.sphinxsidebar #searchbox input[type="submit"] { 102 | float: left; 103 | width: 20%; 104 | border-left: none; 105 | padding: 0.25em; 106 | box-sizing: border-box; 107 | } 108 | 109 | 110 | img { 111 | border: 0; 112 | max-width: 100%; 113 | } 114 | 115 | /* -- search page ----------------------------------------------------------- */ 116 | 117 | ul.search { 118 | margin: 10px 0 0 20px; 119 | padding: 0; 120 | } 121 | 122 | ul.search li { 123 | padding: 5px 0 5px 20px; 124 | background-image: url(file.png); 125 | background-repeat: no-repeat; 126 | background-position: 0 7px; 127 | } 128 | 129 | ul.search li a { 130 | font-weight: bold; 131 | } 132 | 133 | ul.search li p.context { 134 | color: #888; 135 | margin: 2px 0 0 30px; 136 | text-align: left; 137 | } 138 | 139 | ul.keywordmatches li.goodmatch a { 140 | font-weight: bold; 141 | } 142 | 143 | /* -- index page ------------------------------------------------------------ */ 144 | 145 | table.contentstable { 146 | width: 90%; 147 | margin-left: auto; 148 | margin-right: auto; 149 | } 150 | 151 | table.contentstable p.biglink { 152 | line-height: 150%; 153 | } 154 | 155 | a.biglink { 156 | font-size: 1.3em; 157 | } 158 | 159 | span.linkdescr { 160 | font-style: italic; 161 | padding-top: 5px; 162 | font-size: 90%; 163 | } 164 | 165 | /* -- general index --------------------------------------------------------- */ 166 | 167 | table.indextable { 168 | width: 100%; 169 | } 170 | 171 | table.indextable td { 172 | text-align: left; 173 | vertical-align: top; 
174 | } 175 | 176 | table.indextable ul { 177 | margin-top: 0; 178 | margin-bottom: 0; 179 | list-style-type: none; 180 | } 181 | 182 | table.indextable > tbody > tr > td > ul { 183 | padding-left: 0em; 184 | } 185 | 186 | table.indextable tr.pcap { 187 | height: 10px; 188 | } 189 | 190 | table.indextable tr.cap { 191 | margin-top: 10px; 192 | background-color: #f2f2f2; 193 | } 194 | 195 | img.toggler { 196 | margin-right: 3px; 197 | margin-top: 3px; 198 | cursor: pointer; 199 | } 200 | 201 | div.modindex-jumpbox { 202 | border-top: 1px solid #ddd; 203 | border-bottom: 1px solid #ddd; 204 | margin: 1em 0 1em 0; 205 | padding: 0.4em; 206 | } 207 | 208 | div.genindex-jumpbox { 209 | border-top: 1px solid #ddd; 210 | border-bottom: 1px solid #ddd; 211 | margin: 1em 0 1em 0; 212 | padding: 0.4em; 213 | } 214 | 215 | /* -- domain module index --------------------------------------------------- */ 216 | 217 | table.modindextable td { 218 | padding: 2px; 219 | border-collapse: collapse; 220 | } 221 | 222 | /* -- general body styles --------------------------------------------------- */ 223 | 224 | div.body { 225 | min-width: inherit; 226 | max-width: 800px; 227 | } 228 | 229 | div.body p, div.body dd, div.body li, div.body blockquote { 230 | -moz-hyphens: auto; 231 | -ms-hyphens: auto; 232 | -webkit-hyphens: auto; 233 | hyphens: auto; 234 | } 235 | 236 | a.headerlink { 237 | visibility: hidden; 238 | } 239 | 240 | a:visited { 241 | color: #551A8B; 242 | } 243 | 244 | h1:hover > a.headerlink, 245 | h2:hover > a.headerlink, 246 | h3:hover > a.headerlink, 247 | h4:hover > a.headerlink, 248 | h5:hover > a.headerlink, 249 | h6:hover > a.headerlink, 250 | dt:hover > a.headerlink, 251 | caption:hover > a.headerlink, 252 | p.caption:hover > a.headerlink, 253 | div.code-block-caption:hover > a.headerlink { 254 | visibility: visible; 255 | } 256 | 257 | div.body p.caption { 258 | text-align: inherit; 259 | } 260 | 261 | div.body td { 262 | text-align: left; 263 | } 264 | 265 | 
.first { 266 | margin-top: 0 !important; 267 | } 268 | 269 | p.rubric { 270 | margin-top: 30px; 271 | font-weight: bold; 272 | } 273 | 274 | img.align-left, figure.align-left, .figure.align-left, object.align-left { 275 | clear: left; 276 | float: left; 277 | margin-right: 1em; 278 | } 279 | 280 | img.align-right, figure.align-right, .figure.align-right, object.align-right { 281 | clear: right; 282 | float: right; 283 | margin-left: 1em; 284 | } 285 | 286 | img.align-center, figure.align-center, .figure.align-center, object.align-center { 287 | display: block; 288 | margin-left: auto; 289 | margin-right: auto; 290 | } 291 | 292 | img.align-default, figure.align-default, .figure.align-default { 293 | display: block; 294 | margin-left: auto; 295 | margin-right: auto; 296 | } 297 | 298 | .align-left { 299 | text-align: left; 300 | } 301 | 302 | .align-center { 303 | text-align: center; 304 | } 305 | 306 | .align-default { 307 | text-align: center; 308 | } 309 | 310 | .align-right { 311 | text-align: right; 312 | } 313 | 314 | /* -- sidebars -------------------------------------------------------------- */ 315 | 316 | div.sidebar, 317 | aside.sidebar { 318 | margin: 0 0 0.5em 1em; 319 | border: 1px solid #ddb; 320 | padding: 7px; 321 | background-color: #ffe; 322 | width: 40%; 323 | float: right; 324 | clear: right; 325 | overflow-x: auto; 326 | } 327 | 328 | p.sidebar-title { 329 | font-weight: bold; 330 | } 331 | 332 | nav.contents, 333 | aside.topic, 334 | div.admonition, div.topic, blockquote { 335 | clear: left; 336 | } 337 | 338 | /* -- topics ---------------------------------------------------------------- */ 339 | 340 | nav.contents, 341 | aside.topic, 342 | div.topic { 343 | border: 1px solid #ccc; 344 | padding: 7px; 345 | margin: 10px 0 10px 0; 346 | } 347 | 348 | p.topic-title { 349 | font-size: 1.1em; 350 | font-weight: bold; 351 | margin-top: 10px; 352 | } 353 | 354 | /* -- admonitions ----------------------------------------------------------- */ 355 | 
356 | div.admonition { 357 | margin-top: 10px; 358 | margin-bottom: 10px; 359 | padding: 7px; 360 | } 361 | 362 | div.admonition dt { 363 | font-weight: bold; 364 | } 365 | 366 | p.admonition-title { 367 | margin: 0px 10px 5px 0px; 368 | font-weight: bold; 369 | } 370 | 371 | div.body p.centered { 372 | text-align: center; 373 | margin-top: 25px; 374 | } 375 | 376 | /* -- content of sidebars/topics/admonitions -------------------------------- */ 377 | 378 | div.sidebar > :last-child, 379 | aside.sidebar > :last-child, 380 | nav.contents > :last-child, 381 | aside.topic > :last-child, 382 | div.topic > :last-child, 383 | div.admonition > :last-child { 384 | margin-bottom: 0; 385 | } 386 | 387 | div.sidebar::after, 388 | aside.sidebar::after, 389 | nav.contents::after, 390 | aside.topic::after, 391 | div.topic::after, 392 | div.admonition::after, 393 | blockquote::after { 394 | display: block; 395 | content: ''; 396 | clear: both; 397 | } 398 | 399 | /* -- tables ---------------------------------------------------------------- */ 400 | 401 | table.docutils { 402 | margin-top: 10px; 403 | margin-bottom: 10px; 404 | border: 0; 405 | border-collapse: collapse; 406 | } 407 | 408 | table.align-center { 409 | margin-left: auto; 410 | margin-right: auto; 411 | } 412 | 413 | table.align-default { 414 | margin-left: auto; 415 | margin-right: auto; 416 | } 417 | 418 | table caption span.caption-number { 419 | font-style: italic; 420 | } 421 | 422 | table caption span.caption-text { 423 | } 424 | 425 | table.docutils td, table.docutils th { 426 | padding: 1px 8px 1px 5px; 427 | border-top: 0; 428 | border-left: 0; 429 | border-right: 0; 430 | border-bottom: 1px solid #aaa; 431 | } 432 | 433 | th { 434 | text-align: left; 435 | padding-right: 5px; 436 | } 437 | 438 | table.citation { 439 | border-left: solid 1px gray; 440 | margin-left: 1px; 441 | } 442 | 443 | table.citation td { 444 | border-bottom: none; 445 | } 446 | 447 | th > :first-child, 448 | td > :first-child { 449 | 
margin-top: 0px; 450 | } 451 | 452 | th > :last-child, 453 | td > :last-child { 454 | margin-bottom: 0px; 455 | } 456 | 457 | /* -- figures --------------------------------------------------------------- */ 458 | 459 | div.figure, figure { 460 | margin: 0.5em; 461 | padding: 0.5em; 462 | } 463 | 464 | div.figure p.caption, figcaption { 465 | padding: 0.3em; 466 | } 467 | 468 | div.figure p.caption span.caption-number, 469 | figcaption span.caption-number { 470 | font-style: italic; 471 | } 472 | 473 | div.figure p.caption span.caption-text, 474 | figcaption span.caption-text { 475 | } 476 | 477 | /* -- field list styles ----------------------------------------------------- */ 478 | 479 | table.field-list td, table.field-list th { 480 | border: 0 !important; 481 | } 482 | 483 | .field-list ul { 484 | margin: 0; 485 | padding-left: 1em; 486 | } 487 | 488 | .field-list p { 489 | margin: 0; 490 | } 491 | 492 | .field-name { 493 | -moz-hyphens: manual; 494 | -ms-hyphens: manual; 495 | -webkit-hyphens: manual; 496 | hyphens: manual; 497 | } 498 | 499 | /* -- hlist styles ---------------------------------------------------------- */ 500 | 501 | table.hlist { 502 | margin: 1em 0; 503 | } 504 | 505 | table.hlist td { 506 | vertical-align: top; 507 | } 508 | 509 | /* -- object description styles --------------------------------------------- */ 510 | 511 | .sig { 512 | font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace; 513 | } 514 | 515 | .sig-name, code.descname { 516 | background-color: transparent; 517 | font-weight: bold; 518 | } 519 | 520 | .sig-name { 521 | font-size: 1.1em; 522 | } 523 | 524 | code.descname { 525 | font-size: 1.2em; 526 | } 527 | 528 | .sig-prename, code.descclassname { 529 | background-color: transparent; 530 | } 531 | 532 | .optional { 533 | font-size: 1.3em; 534 | } 535 | 536 | .sig-paren { 537 | font-size: larger; 538 | } 539 | 540 | .sig-param.n { 541 | font-style: italic; 542 | } 543 | 544 | /* C++ 
specific styling */ 545 | 546 | .sig-inline.c-texpr, 547 | .sig-inline.cpp-texpr { 548 | font-family: unset; 549 | } 550 | 551 | .sig.c .k, .sig.c .kt, 552 | .sig.cpp .k, .sig.cpp .kt { 553 | color: #0033B3; 554 | } 555 | 556 | .sig.c .m, 557 | .sig.cpp .m { 558 | color: #1750EB; 559 | } 560 | 561 | .sig.c .s, .sig.c .sc, 562 | .sig.cpp .s, .sig.cpp .sc { 563 | color: #067D17; 564 | } 565 | 566 | 567 | /* -- other body styles ----------------------------------------------------- */ 568 | 569 | ol.arabic { 570 | list-style: decimal; 571 | } 572 | 573 | ol.loweralpha { 574 | list-style: lower-alpha; 575 | } 576 | 577 | ol.upperalpha { 578 | list-style: upper-alpha; 579 | } 580 | 581 | ol.lowerroman { 582 | list-style: lower-roman; 583 | } 584 | 585 | ol.upperroman { 586 | list-style: upper-roman; 587 | } 588 | 589 | :not(li) > ol > li:first-child > :first-child, 590 | :not(li) > ul > li:first-child > :first-child { 591 | margin-top: 0px; 592 | } 593 | 594 | :not(li) > ol > li:last-child > :last-child, 595 | :not(li) > ul > li:last-child > :last-child { 596 | margin-bottom: 0px; 597 | } 598 | 599 | ol.simple ol p, 600 | ol.simple ul p, 601 | ul.simple ol p, 602 | ul.simple ul p { 603 | margin-top: 0; 604 | } 605 | 606 | ol.simple > li:not(:first-child) > p, 607 | ul.simple > li:not(:first-child) > p { 608 | margin-top: 0; 609 | } 610 | 611 | ol.simple p, 612 | ul.simple p { 613 | margin-bottom: 0; 614 | } 615 | 616 | aside.footnote > span, 617 | div.citation > span { 618 | float: left; 619 | } 620 | aside.footnote > span:last-of-type, 621 | div.citation > span:last-of-type { 622 | padding-right: 0.5em; 623 | } 624 | aside.footnote > p { 625 | margin-left: 2em; 626 | } 627 | div.citation > p { 628 | margin-left: 4em; 629 | } 630 | aside.footnote > p:last-of-type, 631 | div.citation > p:last-of-type { 632 | margin-bottom: 0em; 633 | } 634 | aside.footnote > p:last-of-type:after, 635 | div.citation > p:last-of-type:after { 636 | content: ""; 637 | clear: both; 638 | } 
639 | 640 | dl.field-list { 641 | display: grid; 642 | grid-template-columns: fit-content(30%) auto; 643 | } 644 | 645 | dl.field-list > dt { 646 | font-weight: bold; 647 | word-break: break-word; 648 | padding-left: 0.5em; 649 | padding-right: 5px; 650 | } 651 | 652 | dl.field-list > dd { 653 | padding-left: 0.5em; 654 | margin-top: 0em; 655 | margin-left: 0em; 656 | margin-bottom: 0em; 657 | } 658 | 659 | dl { 660 | margin-bottom: 15px; 661 | } 662 | 663 | dd > :first-child { 664 | margin-top: 0px; 665 | } 666 | 667 | dd ul, dd table { 668 | margin-bottom: 10px; 669 | } 670 | 671 | dd { 672 | margin-top: 3px; 673 | margin-bottom: 10px; 674 | margin-left: 30px; 675 | } 676 | 677 | .sig dd { 678 | margin-top: 0px; 679 | margin-bottom: 0px; 680 | } 681 | 682 | .sig dl { 683 | margin-top: 0px; 684 | margin-bottom: 0px; 685 | } 686 | 687 | dl > dd:last-child, 688 | dl > dd:last-child > :last-child { 689 | margin-bottom: 0; 690 | } 691 | 692 | dt:target, span.highlighted { 693 | background-color: #fbe54e; 694 | } 695 | 696 | rect.highlighted { 697 | fill: #fbe54e; 698 | } 699 | 700 | dl.glossary dt { 701 | font-weight: bold; 702 | font-size: 1.1em; 703 | } 704 | 705 | .versionmodified { 706 | font-style: italic; 707 | } 708 | 709 | .system-message { 710 | background-color: #fda; 711 | padding: 5px; 712 | border: 3px solid red; 713 | } 714 | 715 | .footnote:target { 716 | background-color: #ffa; 717 | } 718 | 719 | .line-block { 720 | display: block; 721 | margin-top: 1em; 722 | margin-bottom: 1em; 723 | } 724 | 725 | .line-block .line-block { 726 | margin-top: 0; 727 | margin-bottom: 0; 728 | margin-left: 1.5em; 729 | } 730 | 731 | .guilabel, .menuselection { 732 | font-family: sans-serif; 733 | } 734 | 735 | .accelerator { 736 | text-decoration: underline; 737 | } 738 | 739 | .classifier { 740 | font-style: oblique; 741 | } 742 | 743 | .classifier:before { 744 | font-style: normal; 745 | margin: 0 0.5em; 746 | content: ":"; 747 | display: inline-block; 748 | } 749 | 
750 | abbr, acronym { 751 | border-bottom: dotted 1px; 752 | cursor: help; 753 | } 754 | 755 | .translated { 756 | background-color: rgba(207, 255, 207, 0.2) 757 | } 758 | 759 | .untranslated { 760 | background-color: rgba(255, 207, 207, 0.2) 761 | } 762 | 763 | /* -- code displays --------------------------------------------------------- */ 764 | 765 | pre { 766 | overflow: auto; 767 | overflow-y: hidden; /* fixes display issues on Chrome browsers */ 768 | } 769 | 770 | pre, div[class*="highlight-"] { 771 | clear: both; 772 | } 773 | 774 | span.pre { 775 | -moz-hyphens: none; 776 | -ms-hyphens: none; 777 | -webkit-hyphens: none; 778 | hyphens: none; 779 | white-space: nowrap; 780 | } 781 | 782 | div[class*="highlight-"] { 783 | margin: 1em 0; 784 | } 785 | 786 | td.linenos pre { 787 | border: 0; 788 | background-color: transparent; 789 | color: #aaa; 790 | } 791 | 792 | table.highlighttable { 793 | display: block; 794 | } 795 | 796 | table.highlighttable tbody { 797 | display: block; 798 | } 799 | 800 | table.highlighttable tr { 801 | display: flex; 802 | } 803 | 804 | table.highlighttable td { 805 | margin: 0; 806 | padding: 0; 807 | } 808 | 809 | table.highlighttable td.linenos { 810 | padding-right: 0.5em; 811 | } 812 | 813 | table.highlighttable td.code { 814 | flex: 1; 815 | overflow: hidden; 816 | } 817 | 818 | .highlight .hll { 819 | display: block; 820 | } 821 | 822 | div.highlight pre, 823 | table.highlighttable pre { 824 | margin: 0; 825 | } 826 | 827 | div.code-block-caption + div { 828 | margin-top: 0; 829 | } 830 | 831 | div.code-block-caption { 832 | margin-top: 1em; 833 | padding: 2px 5px; 834 | font-size: small; 835 | } 836 | 837 | div.code-block-caption code { 838 | background-color: transparent; 839 | } 840 | 841 | table.highlighttable td.linenos, 842 | span.linenos, 843 | div.highlight span.gp { /* gp: Generic.Prompt */ 844 | user-select: none; 845 | -webkit-user-select: text; /* Safari fallback only */ 846 | -webkit-user-select: none; /* 
Chrome/Safari */ 847 | -moz-user-select: none; /* Firefox */ 848 | -ms-user-select: none; /* IE10+ */ 849 | } 850 | 851 | div.code-block-caption span.caption-number { 852 | padding: 0.1em 0.3em; 853 | font-style: italic; 854 | } 855 | 856 | div.code-block-caption span.caption-text { 857 | } 858 | 859 | div.literal-block-wrapper { 860 | margin: 1em 0; 861 | } 862 | 863 | code.xref, a code { 864 | background-color: transparent; 865 | font-weight: bold; 866 | } 867 | 868 | h1 code, h2 code, h3 code, h4 code, h5 code, h6 code { 869 | background-color: transparent; 870 | } 871 | 872 | .viewcode-link { 873 | float: right; 874 | } 875 | 876 | .viewcode-back { 877 | float: right; 878 | font-family: sans-serif; 879 | } 880 | 881 | div.viewcode-block:target { 882 | margin: -1px -10px; 883 | padding: 0 10px; 884 | } 885 | 886 | /* -- math display ---------------------------------------------------------- */ 887 | 888 | img.math { 889 | vertical-align: middle; 890 | } 891 | 892 | div.body div.math p { 893 | text-align: center; 894 | } 895 | 896 | span.eqno { 897 | float: right; 898 | } 899 | 900 | span.eqno a.headerlink { 901 | position: absolute; 902 | z-index: 1; 903 | } 904 | 905 | div.math:hover a.headerlink { 906 | visibility: visible; 907 | } 908 | 909 | /* -- printout stylesheet --------------------------------------------------- */ 910 | 911 | @media print { 912 | div.document, 913 | div.documentwrapper, 914 | div.bodywrapper { 915 | margin: 0 !important; 916 | width: 100%; 917 | } 918 | 919 | div.sphinxsidebar, 920 | div.related, 921 | div.footer, 922 | #top-link { 923 | display: none; 924 | } 925 | } -------------------------------------------------------------------------------- /docs/api/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* This file intentionally left blank. 
*/ 2 | -------------------------------------------------------------------------------- /docs/api/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Base JavaScript utilities for all Sphinx HTML documentation. 6 | * 7 | * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | "use strict"; 12 | 13 | const BLACKLISTED_KEY_CONTROL_ELEMENTS = new Set([ 14 | "TEXTAREA", 15 | "INPUT", 16 | "SELECT", 17 | "BUTTON", 18 | ]); 19 | 20 | const _ready = (callback) => { 21 | if (document.readyState !== "loading") { 22 | callback(); 23 | } else { 24 | document.addEventListener("DOMContentLoaded", callback); 25 | } 26 | }; 27 | 28 | /** 29 | * Small JavaScript module for the documentation. 30 | */ 31 | const Documentation = { 32 | init: () => { 33 | Documentation.initDomainIndexTable(); 34 | Documentation.initOnKeyListeners(); 35 | }, 36 | 37 | /** 38 | * i18n support 39 | */ 40 | TRANSLATIONS: {}, 41 | PLURAL_EXPR: (n) => (n === 1 ? 0 : 1), 42 | LOCALE: "unknown", 43 | 44 | // gettext and ngettext don't access this so that the functions 45 | // can safely bound to a different name (_ = Documentation.gettext) 46 | gettext: (string) => { 47 | const translated = Documentation.TRANSLATIONS[string]; 48 | switch (typeof translated) { 49 | case "undefined": 50 | return string; // no translation 51 | case "string": 52 | return translated; // translation exists 53 | default: 54 | return translated[0]; // (singular, plural) translation tuple exists 55 | } 56 | }, 57 | 58 | ngettext: (singular, plural, n) => { 59 | const translated = Documentation.TRANSLATIONS[singular]; 60 | if (typeof translated !== "undefined") 61 | return translated[Documentation.PLURAL_EXPR(n)]; 62 | return n === 1 ? 
singular : plural; 63 | }, 64 | 65 | addTranslations: (catalog) => { 66 | Object.assign(Documentation.TRANSLATIONS, catalog.messages); 67 | Documentation.PLURAL_EXPR = new Function( 68 | "n", 69 | `return (${catalog.plural_expr})` 70 | ); 71 | Documentation.LOCALE = catalog.locale; 72 | }, 73 | 74 | /** 75 | * helper function to focus on search bar 76 | */ 77 | focusSearchBar: () => { 78 | document.querySelectorAll("input[name=q]")[0]?.focus(); 79 | }, 80 | 81 | /** 82 | * Initialise the domain index toggle buttons 83 | */ 84 | initDomainIndexTable: () => { 85 | const toggler = (el) => { 86 | const idNumber = el.id.substr(7); 87 | const toggledRows = document.querySelectorAll(`tr.cg-${idNumber}`); 88 | if (el.src.substr(-9) === "minus.png") { 89 | el.src = `${el.src.substr(0, el.src.length - 9)}plus.png`; 90 | toggledRows.forEach((el) => (el.style.display = "none")); 91 | } else { 92 | el.src = `${el.src.substr(0, el.src.length - 8)}minus.png`; 93 | toggledRows.forEach((el) => (el.style.display = "")); 94 | } 95 | }; 96 | 97 | const togglerElements = document.querySelectorAll("img.toggler"); 98 | togglerElements.forEach((el) => 99 | el.addEventListener("click", (event) => toggler(event.currentTarget)) 100 | ); 101 | togglerElements.forEach((el) => (el.style.display = "")); 102 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) togglerElements.forEach(toggler); 103 | }, 104 | 105 | initOnKeyListeners: () => { 106 | // only install a listener if it is really needed 107 | if ( 108 | !DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS && 109 | !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS 110 | ) 111 | return; 112 | 113 | document.addEventListener("keydown", (event) => { 114 | // bail for input elements 115 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; 116 | // bail with special keys 117 | if (event.altKey || event.ctrlKey || event.metaKey) return; 118 | 119 | if (!event.shiftKey) { 120 | switch (event.key) { 121 | case "ArrowLeft": 122 | if 
(!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 123 | 124 | const prevLink = document.querySelector('link[rel="prev"]'); 125 | if (prevLink && prevLink.href) { 126 | window.location.href = prevLink.href; 127 | event.preventDefault(); 128 | } 129 | break; 130 | case "ArrowRight": 131 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 132 | 133 | const nextLink = document.querySelector('link[rel="next"]'); 134 | if (nextLink && nextLink.href) { 135 | window.location.href = nextLink.href; 136 | event.preventDefault(); 137 | } 138 | break; 139 | } 140 | } 141 | 142 | // some keyboard layouts may need Shift to get / 143 | switch (event.key) { 144 | case "/": 145 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) break; 146 | Documentation.focusSearchBar(); 147 | event.preventDefault(); 148 | } 149 | }); 150 | }, 151 | }; 152 | 153 | // quick alias for translations 154 | const _ = Documentation.gettext; 155 | 156 | _ready(Documentation.init); 157 | -------------------------------------------------------------------------------- /docs/api/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | const DOCUMENTATION_OPTIONS = { 2 | VERSION: '', 3 | LANGUAGE: 'en', 4 | COLLAPSE_INDEX: false, 5 | BUILDER: 'html', 6 | FILE_SUFFIX: '.html', 7 | LINK_SUFFIX: '.html', 8 | HAS_SOURCE: true, 9 | SOURCELINK_SUFFIX: '.txt', 10 | NAVIGATION_WITH_KEYS: false, 11 | SHOW_SEARCH_SUMMARY: true, 12 | ENABLE_SEARCH_SHORTCUTS: true, 13 | }; -------------------------------------------------------------------------------- /docs/api/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/api/_static/file.png -------------------------------------------------------------------------------- /docs/api/_static/github-banner.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /docs/api/_static/language_data.js: -------------------------------------------------------------------------------- 1 | /* 2 | * language_data.js 3 | * ~~~~~~~~~~~~~~~~ 4 | * 5 | * This script contains the language-specific data used by searchtools.js, 6 | * namely the list of stopwords, stemmer, scorer and splitter. 7 | * 8 | * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS. 9 | * :license: BSD, see LICENSE for details. 10 | * 11 | */ 12 | 13 | var stopwords = ["a", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it", "near", "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"]; 14 | 15 | 16 | /* Non-minified version is copied as a separate JS file, if available */ 17 | 18 | /** 19 | * Porter Stemmer 20 | */ 21 | var Stemmer = function() { 22 | 23 | var step2list = { 24 | ational: 'ate', 25 | tional: 'tion', 26 | enci: 'ence', 27 | anci: 'ance', 28 | izer: 'ize', 29 | bli: 'ble', 30 | alli: 'al', 31 | entli: 'ent', 32 | eli: 'e', 33 | ousli: 'ous', 34 | ization: 'ize', 35 | ation: 'ate', 36 | ator: 'ate', 37 | alism: 'al', 38 | iveness: 'ive', 39 | fulness: 'ful', 40 | ousness: 'ous', 41 | aliti: 'al', 42 | iviti: 'ive', 43 | biliti: 'ble', 44 | logi: 'log' 45 | }; 46 | 47 | var step3list = { 48 | icate: 'ic', 49 | ative: '', 50 | alize: 'al', 51 | iciti: 'ic', 52 | ical: 'ic', 53 | ful: '', 54 | ness: '' 55 | }; 56 | 57 | var c = "[^aeiou]"; // consonant 58 | var v = "[aeiouy]"; // vowel 59 | var C = c + "[^aeiouy]*"; // consonant sequence 60 | var V = v + "[aeiou]*"; // vowel sequence 61 | 62 | var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0 63 | var meq1 = "^(" + C + ")?" 
+ V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 64 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 65 | var s_v = "^(" + C + ")?" + v; // vowel in stem 66 | 67 | this.stemWord = function (w) { 68 | var stem; 69 | var suffix; 70 | var firstch; 71 | var origword = w; 72 | 73 | if (w.length < 3) 74 | return w; 75 | 76 | var re; 77 | var re2; 78 | var re3; 79 | var re4; 80 | 81 | firstch = w.substr(0,1); 82 | if (firstch == "y") 83 | w = firstch.toUpperCase() + w.substr(1); 84 | 85 | // Step 1a 86 | re = /^(.+?)(ss|i)es$/; 87 | re2 = /^(.+?)([^s])s$/; 88 | 89 | if (re.test(w)) 90 | w = w.replace(re,"$1$2"); 91 | else if (re2.test(w)) 92 | w = w.replace(re2,"$1$2"); 93 | 94 | // Step 1b 95 | re = /^(.+?)eed$/; 96 | re2 = /^(.+?)(ed|ing)$/; 97 | if (re.test(w)) { 98 | var fp = re.exec(w); 99 | re = new RegExp(mgr0); 100 | if (re.test(fp[1])) { 101 | re = /.$/; 102 | w = w.replace(re,""); 103 | } 104 | } 105 | else if (re2.test(w)) { 106 | var fp = re2.exec(w); 107 | stem = fp[1]; 108 | re2 = new RegExp(s_v); 109 | if (re2.test(stem)) { 110 | w = stem; 111 | re2 = /(at|bl|iz)$/; 112 | re3 = new RegExp("([^aeiouylsz])\\1$"); 113 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 114 | if (re2.test(w)) 115 | w = w + "e"; 116 | else if (re3.test(w)) { 117 | re = /.$/; 118 | w = w.replace(re,""); 119 | } 120 | else if (re4.test(w)) 121 | w = w + "e"; 122 | } 123 | } 124 | 125 | // Step 1c 126 | re = /^(.+?)y$/; 127 | if (re.test(w)) { 128 | var fp = re.exec(w); 129 | stem = fp[1]; 130 | re = new RegExp(s_v); 131 | if (re.test(stem)) 132 | w = stem + "i"; 133 | } 134 | 135 | // Step 2 136 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; 137 | if (re.test(w)) { 138 | var fp = re.exec(w); 139 | stem = fp[1]; 140 | suffix = fp[2]; 141 | re = new RegExp(mgr0); 142 | if (re.test(stem)) 143 | w = stem + step2list[suffix]; 144 | } 145 | 146 | // Step 3 147 | re = 
/^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; 148 | if (re.test(w)) { 149 | var fp = re.exec(w); 150 | stem = fp[1]; 151 | suffix = fp[2]; 152 | re = new RegExp(mgr0); 153 | if (re.test(stem)) 154 | w = stem + step3list[suffix]; 155 | } 156 | 157 | // Step 4 158 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; 159 | re2 = /^(.+?)(s|t)(ion)$/; 160 | if (re.test(w)) { 161 | var fp = re.exec(w); 162 | stem = fp[1]; 163 | re = new RegExp(mgr1); 164 | if (re.test(stem)) 165 | w = stem; 166 | } 167 | else if (re2.test(w)) { 168 | var fp = re2.exec(w); 169 | stem = fp[1] + fp[2]; 170 | re2 = new RegExp(mgr1); 171 | if (re2.test(stem)) 172 | w = stem; 173 | } 174 | 175 | // Step 5 176 | re = /^(.+?)e$/; 177 | if (re.test(w)) { 178 | var fp = re.exec(w); 179 | stem = fp[1]; 180 | re = new RegExp(mgr1); 181 | re2 = new RegExp(meq1); 182 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 183 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) 184 | w = stem; 185 | } 186 | re = /ll$/; 187 | re2 = new RegExp(mgr1); 188 | if (re.test(w) && re2.test(w)) { 189 | re = /.$/; 190 | w = w.replace(re,""); 191 | } 192 | 193 | // and turn initial Y back to y 194 | if (firstch == "y") 195 | w = firstch.toLowerCase() + w.substr(1); 196 | return w; 197 | } 198 | } 199 | 200 | -------------------------------------------------------------------------------- /docs/api/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/api/_static/minus.png -------------------------------------------------------------------------------- /docs/api/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/api/_static/plus.png 
-------------------------------------------------------------------------------- /docs/api/_static/pygments.css: -------------------------------------------------------------------------------- 1 | pre { line-height: 125%; } 2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 6 | .highlight .hll { background-color: #ffffcc } 7 | .highlight { background: #f8f8f8; } 8 | .highlight .c { color: #8f5902; font-style: italic } /* Comment */ 9 | .highlight .err { color: #a40000; border: 1px solid #ef2929 } /* Error */ 10 | .highlight .g { color: #000000 } /* Generic */ 11 | .highlight .k { color: #004461; font-weight: bold } /* Keyword */ 12 | .highlight .l { color: #000000 } /* Literal */ 13 | .highlight .n { color: #000000 } /* Name */ 14 | .highlight .o { color: #582800 } /* Operator */ 15 | .highlight .x { color: #000000 } /* Other */ 16 | .highlight .p { color: #000000; font-weight: bold } /* Punctuation */ 17 | .highlight .ch { color: #8f5902; font-style: italic } /* Comment.Hashbang */ 18 | .highlight .cm { color: #8f5902; font-style: italic } /* Comment.Multiline */ 19 | .highlight .cp { color: #8f5902 } /* Comment.Preproc */ 20 | .highlight .cpf { color: #8f5902; font-style: italic } /* Comment.PreprocFile */ 21 | .highlight .c1 { color: #8f5902; font-style: italic } /* Comment.Single */ 22 | .highlight .cs { color: #8f5902; font-style: italic } /* Comment.Special */ 23 | .highlight .gd { color: #a40000 } /* Generic.Deleted */ 24 | .highlight .ge { color: #000000; font-style: italic } /* Generic.Emph */ 25 | .highlight .ges { color: #000000 } /* Generic.EmphStrong */ 26 | .highlight 
.gr { color: #ef2929 } /* Generic.Error */ 27 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 28 | .highlight .gi { color: #00A000 } /* Generic.Inserted */ 29 | .highlight .go { color: #888888 } /* Generic.Output */ 30 | .highlight .gp { color: #745334 } /* Generic.Prompt */ 31 | .highlight .gs { color: #000000; font-weight: bold } /* Generic.Strong */ 32 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 33 | .highlight .gt { color: #a40000; font-weight: bold } /* Generic.Traceback */ 34 | .highlight .kc { color: #004461; font-weight: bold } /* Keyword.Constant */ 35 | .highlight .kd { color: #004461; font-weight: bold } /* Keyword.Declaration */ 36 | .highlight .kn { color: #004461; font-weight: bold } /* Keyword.Namespace */ 37 | .highlight .kp { color: #004461; font-weight: bold } /* Keyword.Pseudo */ 38 | .highlight .kr { color: #004461; font-weight: bold } /* Keyword.Reserved */ 39 | .highlight .kt { color: #004461; font-weight: bold } /* Keyword.Type */ 40 | .highlight .ld { color: #000000 } /* Literal.Date */ 41 | .highlight .m { color: #990000 } /* Literal.Number */ 42 | .highlight .s { color: #4e9a06 } /* Literal.String */ 43 | .highlight .na { color: #c4a000 } /* Name.Attribute */ 44 | .highlight .nb { color: #004461 } /* Name.Builtin */ 45 | .highlight .nc { color: #000000 } /* Name.Class */ 46 | .highlight .no { color: #000000 } /* Name.Constant */ 47 | .highlight .nd { color: #888888 } /* Name.Decorator */ 48 | .highlight .ni { color: #ce5c00 } /* Name.Entity */ 49 | .highlight .ne { color: #cc0000; font-weight: bold } /* Name.Exception */ 50 | .highlight .nf { color: #000000 } /* Name.Function */ 51 | .highlight .nl { color: #f57900 } /* Name.Label */ 52 | .highlight .nn { color: #000000 } /* Name.Namespace */ 53 | .highlight .nx { color: #000000 } /* Name.Other */ 54 | .highlight .py { color: #000000 } /* Name.Property */ 55 | .highlight .nt { color: #004461; font-weight: bold } /* Name.Tag 
*/ 56 | .highlight .nv { color: #000000 } /* Name.Variable */ 57 | .highlight .ow { color: #004461; font-weight: bold } /* Operator.Word */ 58 | .highlight .pm { color: #000000; font-weight: bold } /* Punctuation.Marker */ 59 | .highlight .w { color: #f8f8f8 } /* Text.Whitespace */ 60 | .highlight .mb { color: #990000 } /* Literal.Number.Bin */ 61 | .highlight .mf { color: #990000 } /* Literal.Number.Float */ 62 | .highlight .mh { color: #990000 } /* Literal.Number.Hex */ 63 | .highlight .mi { color: #990000 } /* Literal.Number.Integer */ 64 | .highlight .mo { color: #990000 } /* Literal.Number.Oct */ 65 | .highlight .sa { color: #4e9a06 } /* Literal.String.Affix */ 66 | .highlight .sb { color: #4e9a06 } /* Literal.String.Backtick */ 67 | .highlight .sc { color: #4e9a06 } /* Literal.String.Char */ 68 | .highlight .dl { color: #4e9a06 } /* Literal.String.Delimiter */ 69 | .highlight .sd { color: #8f5902; font-style: italic } /* Literal.String.Doc */ 70 | .highlight .s2 { color: #4e9a06 } /* Literal.String.Double */ 71 | .highlight .se { color: #4e9a06 } /* Literal.String.Escape */ 72 | .highlight .sh { color: #4e9a06 } /* Literal.String.Heredoc */ 73 | .highlight .si { color: #4e9a06 } /* Literal.String.Interpol */ 74 | .highlight .sx { color: #4e9a06 } /* Literal.String.Other */ 75 | .highlight .sr { color: #4e9a06 } /* Literal.String.Regex */ 76 | .highlight .s1 { color: #4e9a06 } /* Literal.String.Single */ 77 | .highlight .ss { color: #4e9a06 } /* Literal.String.Symbol */ 78 | .highlight .bp { color: #3465a4 } /* Name.Builtin.Pseudo */ 79 | .highlight .fm { color: #000000 } /* Name.Function.Magic */ 80 | .highlight .vc { color: #000000 } /* Name.Variable.Class */ 81 | .highlight .vg { color: #000000 } /* Name.Variable.Global */ 82 | .highlight .vi { color: #000000 } /* Name.Variable.Instance */ 83 | .highlight .vm { color: #000000 } /* Name.Variable.Magic */ 84 | .highlight .il { color: #990000 } /* Literal.Number.Integer.Long */ 
-------------------------------------------------------------------------------- /docs/api/_static/searchtools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * searchtools.js 3 | * ~~~~~~~~~~~~~~~~ 4 | * 5 | * Sphinx JavaScript utilities for the full-text search. 6 | * 7 | * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | "use strict"; 12 | 13 | /** 14 | * Simple result scoring code. 15 | */ 16 | if (typeof Scorer === "undefined") { 17 | var Scorer = { 18 | // Implement the following function to further tweak the score for each result 19 | // The function takes a result array [docname, title, anchor, descr, score, filename] 20 | // and returns the new score. 21 | /* 22 | score: result => { 23 | const [docname, title, anchor, descr, score, filename] = result 24 | return score 25 | }, 26 | */ 27 | 28 | // query matches the full name of an object 29 | objNameMatch: 11, 30 | // or matches in the last dotted part of the object name 31 | objPartialMatch: 6, 32 | // Additive scores depending on the priority of the object 33 | objPrio: { 34 | 0: 15, // used to be importantResults 35 | 1: 5, // used to be objectResults 36 | 2: -5, // used to be unimportantResults 37 | }, 38 | // Used when the priority is not in the mapping. 
39 | objPrioDefault: 0, 40 | 41 | // query found in title 42 | title: 15, 43 | partialTitle: 7, 44 | // query found in terms 45 | term: 5, 46 | partialTerm: 2, 47 | }; 48 | } 49 | 50 | const _removeChildren = (element) => { 51 | while (element && element.lastChild) element.removeChild(element.lastChild); 52 | }; 53 | 54 | /** 55 | * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Regular_Expressions#escaping 56 | */ 57 | const _escapeRegExp = (string) => 58 | string.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string 59 | 60 | const _displayItem = (item, searchTerms, highlightTerms) => { 61 | const docBuilder = DOCUMENTATION_OPTIONS.BUILDER; 62 | const docFileSuffix = DOCUMENTATION_OPTIONS.FILE_SUFFIX; 63 | const docLinkSuffix = DOCUMENTATION_OPTIONS.LINK_SUFFIX; 64 | const showSearchSummary = DOCUMENTATION_OPTIONS.SHOW_SEARCH_SUMMARY; 65 | const contentRoot = document.documentElement.dataset.content_root; 66 | 67 | const [docName, title, anchor, descr, score, _filename] = item; 68 | 69 | let listItem = document.createElement("li"); 70 | let requestUrl; 71 | let linkUrl; 72 | if (docBuilder === "dirhtml") { 73 | // dirhtml builder 74 | let dirname = docName + "/"; 75 | if (dirname.match(/\/index\/$/)) 76 | dirname = dirname.substring(0, dirname.length - 6); 77 | else if (dirname === "index/") dirname = ""; 78 | requestUrl = contentRoot + dirname; 79 | linkUrl = requestUrl; 80 | } else { 81 | // normal html builders 82 | requestUrl = contentRoot + docName + docFileSuffix; 83 | linkUrl = docName + docLinkSuffix; 84 | } 85 | let linkEl = listItem.appendChild(document.createElement("a")); 86 | linkEl.href = linkUrl + anchor; 87 | linkEl.dataset.score = score; 88 | linkEl.innerHTML = title; 89 | if (descr) { 90 | listItem.appendChild(document.createElement("span")).innerHTML = 91 | " (" + descr + ")"; 92 | // highlight search terms in the description 93 | if (SPHINX_HIGHLIGHT_ENABLED) // set in sphinx_highlight.js 94 | 
highlightTerms.forEach((term) => _highlightText(listItem, term, "highlighted")); 95 | } 96 | else if (showSearchSummary) 97 | fetch(requestUrl) 98 | .then((responseData) => responseData.text()) 99 | .then((data) => { 100 | if (data) 101 | listItem.appendChild( 102 | Search.makeSearchSummary(data, searchTerms, anchor) 103 | ); 104 | // highlight search terms in the summary 105 | if (SPHINX_HIGHLIGHT_ENABLED) // set in sphinx_highlight.js 106 | highlightTerms.forEach((term) => _highlightText(listItem, term, "highlighted")); 107 | }); 108 | Search.output.appendChild(listItem); 109 | }; 110 | const _finishSearch = (resultCount) => { 111 | Search.stopPulse(); 112 | Search.title.innerText = _("Search Results"); 113 | if (!resultCount) 114 | Search.status.innerText = Documentation.gettext( 115 | "Your search did not match any documents. Please make sure that all words are spelled correctly and that you've selected enough categories." 116 | ); 117 | else 118 | Search.status.innerText = _( 119 | "Search finished, found ${resultCount} page(s) matching the search query." 120 | ).replace('${resultCount}', resultCount); 121 | }; 122 | const _displayNextItem = ( 123 | results, 124 | resultCount, 125 | searchTerms, 126 | highlightTerms, 127 | ) => { 128 | // results left, load the summary and display it 129 | // this is intended to be dynamic (don't sub resultsCount) 130 | if (results.length) { 131 | _displayItem(results.pop(), searchTerms, highlightTerms); 132 | setTimeout( 133 | () => _displayNextItem(results, resultCount, searchTerms, highlightTerms), 134 | 5 135 | ); 136 | } 137 | // search finished, update title and status message 138 | else _finishSearch(resultCount); 139 | }; 140 | // Helper function used by query() to order search results. 141 | // Each input is an array of [docname, title, anchor, descr, score, filename]. 
142 | // Order the results by score (in opposite order of appearance, since the 143 | // `_displayNextItem` function uses pop() to retrieve items) and then alphabetically. 144 | const _orderResultsByScoreThenName = (a, b) => { 145 | const leftScore = a[4]; 146 | const rightScore = b[4]; 147 | if (leftScore === rightScore) { 148 | // same score: sort alphabetically 149 | const leftTitle = a[1].toLowerCase(); 150 | const rightTitle = b[1].toLowerCase(); 151 | if (leftTitle === rightTitle) return 0; 152 | return leftTitle > rightTitle ? -1 : 1; // inverted is intentional 153 | } 154 | return leftScore > rightScore ? 1 : -1; 155 | }; 156 | 157 | /** 158 | * Default splitQuery function. Can be overridden in ``sphinx.search`` with a 159 | * custom function per language. 160 | * 161 | * The regular expression works by splitting the string on consecutive characters 162 | * that are not Unicode letters, numbers, underscores, or emoji characters. 163 | * This is the same as ``\W+`` in Python, preserving the surrogate pair area. 164 | */ 165 | if (typeof splitQuery === "undefined") { 166 | var splitQuery = (query) => query 167 | .split(/[^\p{Letter}\p{Number}_\p{Emoji_Presentation}]+/gu) 168 | .filter(term => term) // remove remaining empty strings 169 | } 170 | 171 | /** 172 | * Search Module 173 | */ 174 | const Search = { 175 | _index: null, 176 | _queued_query: null, 177 | _pulse_status: -1, 178 | 179 | htmlToText: (htmlString, anchor) => { 180 | const htmlElement = new DOMParser().parseFromString(htmlString, 'text/html'); 181 | for (const removalQuery of [".headerlink", "script", "style"]) { 182 | htmlElement.querySelectorAll(removalQuery).forEach((el) => { el.remove() }); 183 | } 184 | if (anchor) { 185 | const anchorContent = htmlElement.querySelector(`[role="main"] ${anchor}`); 186 | if (anchorContent) return anchorContent.textContent; 187 | 188 | console.warn( 189 | `Anchored content block not found. 
Sphinx search tries to obtain it via DOM query '[role=main] ${anchor}'. Check your theme or template.` 190 | ); 191 | } 192 | 193 | // if anchor not specified or not found, fall back to main content 194 | const docContent = htmlElement.querySelector('[role="main"]'); 195 | if (docContent) return docContent.textContent; 196 | 197 | console.warn( 198 | "Content block not found. Sphinx search tries to obtain it via DOM query '[role=main]'. Check your theme or template." 199 | ); 200 | return ""; 201 | }, 202 | 203 | init: () => { 204 | const query = new URLSearchParams(window.location.search).get("q"); 205 | document 206 | .querySelectorAll('input[name="q"]') 207 | .forEach((el) => (el.value = query)); 208 | if (query) Search.performSearch(query); 209 | }, 210 | 211 | loadIndex: (url) => 212 | (document.body.appendChild(document.createElement("script")).src = url), 213 | 214 | setIndex: (index) => { 215 | Search._index = index; 216 | if (Search._queued_query !== null) { 217 | const query = Search._queued_query; 218 | Search._queued_query = null; 219 | Search.query(query); 220 | } 221 | }, 222 | 223 | hasIndex: () => Search._index !== null, 224 | 225 | deferQuery: (query) => (Search._queued_query = query), 226 | 227 | stopPulse: () => (Search._pulse_status = -1), 228 | 229 | startPulse: () => { 230 | if (Search._pulse_status >= 0) return; 231 | 232 | const pulse = () => { 233 | Search._pulse_status = (Search._pulse_status + 1) % 4; 234 | Search.dots.innerText = ".".repeat(Search._pulse_status); 235 | if (Search._pulse_status >= 0) window.setTimeout(pulse, 500); 236 | }; 237 | pulse(); 238 | }, 239 | 240 | /** 241 | * perform a search for something (or wait until index is loaded) 242 | */ 243 | performSearch: (query) => { 244 | // create the required interface elements 245 | const searchText = document.createElement("h2"); 246 | searchText.textContent = _("Searching"); 247 | const searchSummary = document.createElement("p"); 248 | 
searchSummary.classList.add("search-summary"); 249 | searchSummary.innerText = ""; 250 | const searchList = document.createElement("ul"); 251 | searchList.classList.add("search"); 252 | 253 | const out = document.getElementById("search-results"); 254 | Search.title = out.appendChild(searchText); 255 | Search.dots = Search.title.appendChild(document.createElement("span")); 256 | Search.status = out.appendChild(searchSummary); 257 | Search.output = out.appendChild(searchList); 258 | 259 | const searchProgress = document.getElementById("search-progress"); 260 | // Some themes don't use the search progress node 261 | if (searchProgress) { 262 | searchProgress.innerText = _("Preparing search..."); 263 | } 264 | Search.startPulse(); 265 | 266 | // index already loaded, the browser was quick! 267 | if (Search.hasIndex()) Search.query(query); 268 | else Search.deferQuery(query); 269 | }, 270 | 271 | _parseQuery: (query) => { 272 | // stem the search terms and add them to the correct list 273 | const stemmer = new Stemmer(); 274 | const searchTerms = new Set(); 275 | const excludedTerms = new Set(); 276 | const highlightTerms = new Set(); 277 | const objectTerms = new Set(splitQuery(query.toLowerCase().trim())); 278 | splitQuery(query.trim()).forEach((queryTerm) => { 279 | const queryTermLower = queryTerm.toLowerCase(); 280 | 281 | // maybe skip this "word" 282 | // stopwords array is from language_data.js 283 | if ( 284 | stopwords.indexOf(queryTermLower) !== -1 || 285 | queryTerm.match(/^\d+$/) 286 | ) 287 | return; 288 | 289 | // stem the word 290 | let word = stemmer.stemWord(queryTermLower); 291 | // select the correct list 292 | if (word[0] === "-") excludedTerms.add(word.substr(1)); 293 | else { 294 | searchTerms.add(word); 295 | highlightTerms.add(queryTermLower); 296 | } 297 | }); 298 | 299 | if (SPHINX_HIGHLIGHT_ENABLED) { // set in sphinx_highlight.js 300 | localStorage.setItem("sphinx_highlight_terms", [...highlightTerms].join(" ")) 301 | } 302 | 303 | // 
console.debug("SEARCH: searching for:"); 304 | // console.info("required: ", [...searchTerms]); 305 | // console.info("excluded: ", [...excludedTerms]); 306 | 307 | return [query, searchTerms, excludedTerms, highlightTerms, objectTerms]; 308 | }, 309 | 310 | /** 311 | * execute search (requires search index to be loaded) 312 | */ 313 | _performSearch: (query, searchTerms, excludedTerms, highlightTerms, objectTerms) => { 314 | const filenames = Search._index.filenames; 315 | const docNames = Search._index.docnames; 316 | const titles = Search._index.titles; 317 | const allTitles = Search._index.alltitles; 318 | const indexEntries = Search._index.indexentries; 319 | 320 | // Collect multiple result groups to be sorted separately and then ordered. 321 | // Each is an array of [docname, title, anchor, descr, score, filename]. 322 | const normalResults = []; 323 | const nonMainIndexResults = []; 324 | 325 | _removeChildren(document.getElementById("search-progress")); 326 | 327 | const queryLower = query.toLowerCase().trim(); 328 | for (const [title, foundTitles] of Object.entries(allTitles)) { 329 | if (title.toLowerCase().trim().includes(queryLower) && (queryLower.length >= title.length/2)) { 330 | for (const [file, id] of foundTitles) { 331 | const score = Math.round(Scorer.title * queryLower.length / title.length); 332 | const boost = titles[file] === title ? 1 : 0; // add a boost for document titles 333 | normalResults.push([ 334 | docNames[file], 335 | titles[file] !== title ? `${titles[file]} > ${title}` : title, 336 | id !== null ? 
"#" + id : "", 337 | null, 338 | score + boost, 339 | filenames[file], 340 | ]); 341 | } 342 | } 343 | } 344 | 345 | // search for explicit entries in index directives 346 | for (const [entry, foundEntries] of Object.entries(indexEntries)) { 347 | if (entry.includes(queryLower) && (queryLower.length >= entry.length/2)) { 348 | for (const [file, id, isMain] of foundEntries) { 349 | const score = Math.round(100 * queryLower.length / entry.length); 350 | const result = [ 351 | docNames[file], 352 | titles[file], 353 | id ? "#" + id : "", 354 | null, 355 | score, 356 | filenames[file], 357 | ]; 358 | if (isMain) { 359 | normalResults.push(result); 360 | } else { 361 | nonMainIndexResults.push(result); 362 | } 363 | } 364 | } 365 | } 366 | 367 | // lookup as object 368 | objectTerms.forEach((term) => 369 | normalResults.push(...Search.performObjectSearch(term, objectTerms)) 370 | ); 371 | 372 | // lookup as search terms in fulltext 373 | normalResults.push(...Search.performTermsSearch(searchTerms, excludedTerms)); 374 | 375 | // let the scorer override scores with a custom scoring function 376 | if (Scorer.score) { 377 | normalResults.forEach((item) => (item[4] = Scorer.score(item))); 378 | nonMainIndexResults.forEach((item) => (item[4] = Scorer.score(item))); 379 | } 380 | 381 | // Sort each group of results by score and then alphabetically by name. 382 | normalResults.sort(_orderResultsByScoreThenName); 383 | nonMainIndexResults.sort(_orderResultsByScoreThenName); 384 | 385 | // Combine the result groups in (reverse) order. 386 | // Non-main index entries are typically arbitrary cross-references, 387 | // so display them after other results. 
388 | let results = [...nonMainIndexResults, ...normalResults]; 389 | 390 | // remove duplicate search results 391 | // note the reversing of results, so that in the case of duplicates, the highest-scoring entry is kept 392 | let seen = new Set(); 393 | results = results.reverse().reduce((acc, result) => { 394 | let resultStr = result.slice(0, 4).concat([result[5]]).map(v => String(v)).join(','); 395 | if (!seen.has(resultStr)) { 396 | acc.push(result); 397 | seen.add(resultStr); 398 | } 399 | return acc; 400 | }, []); 401 | 402 | return results.reverse(); 403 | }, 404 | 405 | query: (query) => { 406 | const [searchQuery, searchTerms, excludedTerms, highlightTerms, objectTerms] = Search._parseQuery(query); 407 | const results = Search._performSearch(searchQuery, searchTerms, excludedTerms, highlightTerms, objectTerms); 408 | 409 | // for debugging 410 | //Search.lastresults = results.slice(); // a copy 411 | // console.info("search results:", Search.lastresults); 412 | 413 | // print the results 414 | _displayNextItem(results, results.length, searchTerms, highlightTerms); 415 | }, 416 | 417 | /** 418 | * search for object names 419 | */ 420 | performObjectSearch: (object, objectTerms) => { 421 | const filenames = Search._index.filenames; 422 | const docNames = Search._index.docnames; 423 | const objects = Search._index.objects; 424 | const objNames = Search._index.objnames; 425 | const titles = Search._index.titles; 426 | 427 | const results = []; 428 | 429 | const objectSearchCallback = (prefix, match) => { 430 | const name = match[4] 431 | const fullname = (prefix ? prefix + "." : "") + name; 432 | const fullnameLower = fullname.toLowerCase(); 433 | if (fullnameLower.indexOf(object) < 0) return; 434 | 435 | let score = 0; 436 | const parts = fullnameLower.split("."); 437 | 438 | // check for different match types: exact matches of full name or 439 | // "last name" (i.e. 
last dotted part) 440 | if (fullnameLower === object || parts.slice(-1)[0] === object) 441 | score += Scorer.objNameMatch; 442 | else if (parts.slice(-1)[0].indexOf(object) > -1) 443 | score += Scorer.objPartialMatch; // matches in last name 444 | 445 | const objName = objNames[match[1]][2]; 446 | const title = titles[match[0]]; 447 | 448 | // If more than one term searched for, we require other words to be 449 | // found in the name/title/description 450 | const otherTerms = new Set(objectTerms); 451 | otherTerms.delete(object); 452 | if (otherTerms.size > 0) { 453 | const haystack = `${prefix} ${name} ${objName} ${title}`.toLowerCase(); 454 | if ( 455 | [...otherTerms].some((otherTerm) => haystack.indexOf(otherTerm) < 0) 456 | ) 457 | return; 458 | } 459 | 460 | let anchor = match[3]; 461 | if (anchor === "") anchor = fullname; 462 | else if (anchor === "-") anchor = objNames[match[1]][1] + "-" + fullname; 463 | 464 | const descr = objName + _(", in ") + title; 465 | 466 | // add custom score for some objects according to scorer 467 | if (Scorer.objPrio.hasOwnProperty(match[2])) 468 | score += Scorer.objPrio[match[2]]; 469 | else score += Scorer.objPrioDefault; 470 | 471 | results.push([ 472 | docNames[match[0]], 473 | fullname, 474 | "#" + anchor, 475 | descr, 476 | score, 477 | filenames[match[0]], 478 | ]); 479 | }; 480 | Object.keys(objects).forEach((prefix) => 481 | objects[prefix].forEach((array) => 482 | objectSearchCallback(prefix, array) 483 | ) 484 | ); 485 | return results; 486 | }, 487 | 488 | /** 489 | * search for full-text terms in the index 490 | */ 491 | performTermsSearch: (searchTerms, excludedTerms) => { 492 | // prepare search 493 | const terms = Search._index.terms; 494 | const titleTerms = Search._index.titleterms; 495 | const filenames = Search._index.filenames; 496 | const docNames = Search._index.docnames; 497 | const titles = Search._index.titles; 498 | 499 | const scoreMap = new Map(); 500 | const fileMap = new Map(); 501 | 502 | // 
perform the search on the required terms 503 | searchTerms.forEach((word) => { 504 | const files = []; 505 | const arr = [ 506 | { files: terms[word], score: Scorer.term }, 507 | { files: titleTerms[word], score: Scorer.title }, 508 | ]; 509 | // add support for partial matches 510 | if (word.length > 2) { 511 | const escapedWord = _escapeRegExp(word); 512 | if (!terms.hasOwnProperty(word)) { 513 | Object.keys(terms).forEach((term) => { 514 | if (term.match(escapedWord)) 515 | arr.push({ files: terms[term], score: Scorer.partialTerm }); 516 | }); 517 | } 518 | if (!titleTerms.hasOwnProperty(word)) { 519 | Object.keys(titleTerms).forEach((term) => { 520 | if (term.match(escapedWord)) 521 | arr.push({ files: titleTerms[term], score: Scorer.partialTitle }); 522 | }); 523 | } 524 | } 525 | 526 | // no match but word was a required one 527 | if (arr.every((record) => record.files === undefined)) return; 528 | 529 | // found search word in contents 530 | arr.forEach((record) => { 531 | if (record.files === undefined) return; 532 | 533 | let recordFiles = record.files; 534 | if (recordFiles.length === undefined) recordFiles = [recordFiles]; 535 | files.push(...recordFiles); 536 | 537 | // set score for the word in each file 538 | recordFiles.forEach((file) => { 539 | if (!scoreMap.has(file)) scoreMap.set(file, {}); 540 | scoreMap.get(file)[word] = record.score; 541 | }); 542 | }); 543 | 544 | // create the mapping 545 | files.forEach((file) => { 546 | if (!fileMap.has(file)) fileMap.set(file, [word]); 547 | else if (fileMap.get(file).indexOf(word) === -1) fileMap.get(file).push(word); 548 | }); 549 | }); 550 | 551 | // now check if the files don't contain excluded terms 552 | const results = []; 553 | for (const [file, wordList] of fileMap) { 554 | // check if all requirements are matched 555 | 556 | // as search terms with length < 3 are discarded 557 | const filteredTermCount = [...searchTerms].filter( 558 | (term) => term.length > 2 559 | ).length; 560 | if ( 561 | 
wordList.length !== searchTerms.size && 562 | wordList.length !== filteredTermCount 563 | ) 564 | continue; 565 | 566 | // ensure that none of the excluded terms is in the search result 567 | if ( 568 | [...excludedTerms].some( 569 | (term) => 570 | terms[term] === file || 571 | titleTerms[term] === file || 572 | (terms[term] || []).includes(file) || 573 | (titleTerms[term] || []).includes(file) 574 | ) 575 | ) 576 | break; 577 | 578 | // select one (max) score for the file. 579 | const score = Math.max(...wordList.map((w) => scoreMap.get(file)[w])); 580 | // add result to the result list 581 | results.push([ 582 | docNames[file], 583 | titles[file], 584 | "", 585 | null, 586 | score, 587 | filenames[file], 588 | ]); 589 | } 590 | return results; 591 | }, 592 | 593 | /** 594 | * helper function to return a node containing the 595 | * search summary for a given text. keywords is a list 596 | * of stemmed words. 597 | */ 598 | makeSearchSummary: (htmlText, keywords, anchor) => { 599 | const text = Search.htmlToText(htmlText, anchor); 600 | if (text === "") return null; 601 | 602 | const textLower = text.toLowerCase(); 603 | const actualStartPosition = [...keywords] 604 | .map((k) => textLower.indexOf(k.toLowerCase())) 605 | .filter((i) => i > -1) 606 | .slice(-1)[0]; 607 | const startWithContext = Math.max(actualStartPosition - 120, 0); 608 | 609 | const top = startWithContext === 0 ? "" : "..."; 610 | const tail = startWithContext + 240 < text.length ? "..." 
: ""; 611 | 612 | let summary = document.createElement("p"); 613 | summary.classList.add("context"); 614 | summary.textContent = top + text.substr(startWithContext, 240).trim() + tail; 615 | 616 | return summary; 617 | }, 618 | }; 619 | 620 | _ready(Search.init); 621 | -------------------------------------------------------------------------------- /docs/api/_static/sphinx_highlight.js: -------------------------------------------------------------------------------- 1 | /* Highlighting utilities for Sphinx HTML documentation. */ 2 | "use strict"; 3 | 4 | const SPHINX_HIGHLIGHT_ENABLED = true 5 | 6 | /** 7 | * highlight a given string on a node by wrapping it in 8 | * span elements with the given class name. 9 | */ 10 | const _highlight = (node, addItems, text, className) => { 11 | if (node.nodeType === Node.TEXT_NODE) { 12 | const val = node.nodeValue; 13 | const parent = node.parentNode; 14 | const pos = val.toLowerCase().indexOf(text); 15 | if ( 16 | pos >= 0 && 17 | !parent.classList.contains(className) && 18 | !parent.classList.contains("nohighlight") 19 | ) { 20 | let span; 21 | 22 | const closestNode = parent.closest("body, svg, foreignObject"); 23 | const isInSVG = closestNode && closestNode.matches("svg"); 24 | if (isInSVG) { 25 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 26 | } else { 27 | span = document.createElement("span"); 28 | span.classList.add(className); 29 | } 30 | 31 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 32 | const rest = document.createTextNode(val.substr(pos + text.length)); 33 | parent.insertBefore( 34 | span, 35 | parent.insertBefore( 36 | rest, 37 | node.nextSibling 38 | ) 39 | ); 40 | node.nodeValue = val.substr(0, pos); 41 | /* There may be more occurrences of search term in this node. So call this 42 | * function recursively on the remaining fragment. 
43 | */ 44 | _highlight(rest, addItems, text, className); 45 | 46 | if (isInSVG) { 47 | const rect = document.createElementNS( 48 | "http://www.w3.org/2000/svg", 49 | "rect" 50 | ); 51 | const bbox = parent.getBBox(); 52 | rect.x.baseVal.value = bbox.x; 53 | rect.y.baseVal.value = bbox.y; 54 | rect.width.baseVal.value = bbox.width; 55 | rect.height.baseVal.value = bbox.height; 56 | rect.setAttribute("class", className); 57 | addItems.push({ parent: parent, target: rect }); 58 | } 59 | } 60 | } else if (node.matches && !node.matches("button, select, textarea")) { 61 | node.childNodes.forEach((el) => _highlight(el, addItems, text, className)); 62 | } 63 | }; 64 | const _highlightText = (thisNode, text, className) => { 65 | let addItems = []; 66 | _highlight(thisNode, addItems, text, className); 67 | addItems.forEach((obj) => 68 | obj.parent.insertAdjacentElement("beforebegin", obj.target) 69 | ); 70 | }; 71 | 72 | /** 73 | * Small JavaScript module for the documentation. 74 | */ 75 | const SphinxHighlight = { 76 | 77 | /** 78 | * highlight the search words provided in localstorage in the text 79 | */ 80 | highlightSearchWords: () => { 81 | if (!SPHINX_HIGHLIGHT_ENABLED) return; // bail if no highlight 82 | 83 | // get and clear terms from localstorage 84 | const url = new URL(window.location); 85 | const highlight = 86 | localStorage.getItem("sphinx_highlight_terms") 87 | || url.searchParams.get("highlight") 88 | || ""; 89 | localStorage.removeItem("sphinx_highlight_terms") 90 | url.searchParams.delete("highlight"); 91 | window.history.replaceState({}, "", url); 92 | 93 | // get individual terms from highlight string 94 | const terms = highlight.toLowerCase().split(/\s+/).filter(x => x); 95 | if (terms.length === 0) return; // nothing to do 96 | 97 | // There should never be more than one element matching "div.body" 98 | const divBody = document.querySelectorAll("div.body"); 99 | const body = divBody.length ? 
divBody[0] : document.querySelector("body"); 100 | window.setTimeout(() => { 101 | terms.forEach((term) => _highlightText(body, term, "highlighted")); 102 | }, 10); 103 | 104 | const searchBox = document.getElementById("searchbox"); 105 | if (searchBox === null) return; 106 | searchBox.appendChild( 107 | document 108 | .createRange() 109 | .createContextualFragment( 110 | '" 114 | ) 115 | ); 116 | }, 117 | 118 | /** 119 | * helper function to hide the search marks again 120 | */ 121 | hideSearchWords: () => { 122 | document 123 | .querySelectorAll("#searchbox .highlight-link") 124 | .forEach((el) => el.remove()); 125 | document 126 | .querySelectorAll("span.highlighted") 127 | .forEach((el) => el.classList.remove("highlighted")); 128 | localStorage.removeItem("sphinx_highlight_terms") 129 | }, 130 | 131 | initEscapeListener: () => { 132 | // only install a listener if it is really needed 133 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return; 134 | 135 | document.addEventListener("keydown", (event) => { 136 | // bail for input elements 137 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; 138 | // bail with special keys 139 | if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return; 140 | if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) { 141 | SphinxHighlight.hideSearchWords(); 142 | event.preventDefault(); 143 | } 144 | }); 145 | }, 146 | }; 147 | 148 | _ready(() => { 149 | /* Do not call highlightSearchWords() when we are on the search page. 150 | * It will highlight words from the *previous* search query. 
151 | */ 152 | if (typeof Search === "undefined") SphinxHighlight.highlightSearchWords(); 153 | SphinxHighlight.initEscapeListener(); 154 | }); 155 | -------------------------------------------------------------------------------- /docs/api/genindex.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Index — LassoNet documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 |
28 |
29 | 30 | 31 |
32 | 33 | 34 |

Index

35 | 36 |
37 | F 38 | | G 39 | | L 40 | | P 41 | | S 42 | 43 |
44 |

F

45 | 46 | 62 |
63 | 64 |

G

65 | 66 | 82 | 98 |
99 | 100 |

L

101 | 102 | 110 | 120 |
121 | 122 |

P

123 | 124 | 140 | 144 |
145 | 146 |

S

147 | 148 | 192 | 222 |
223 | 224 | 225 | 226 |
227 | 228 |
229 |
230 | 269 |
270 |
271 | 279 | 280 | 281 | 282 | 283 | 284 | -------------------------------------------------------------------------------- /docs/api/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/api/objects.inv -------------------------------------------------------------------------------- /docs/api/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Search — LassoNet documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 |
34 |
35 |
36 | 37 | 38 |
39 | 40 |

Search

41 | 42 | 50 | 51 | 52 |

53 | Searching for multiple words only shows matches that contain 54 | all words. 55 |

56 | 57 | 58 |
59 | 60 | 61 | 62 |
63 | 64 | 65 |
66 | 67 | 68 |
69 | 70 |
71 |
72 | 102 |
103 |
104 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /docs/api/searchindex.js: -------------------------------------------------------------------------------- 1 | Search.setIndex({"alltitles": {"API": [[0, "api"]], "Installation": [[0, "installation"]], "Welcome to LassoNet\u2019s documentation!": [[0, null]]}, "docnames": ["index"], "envversion": {"sphinx": 63, "sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2}, "filenames": ["index.rst"], "indexentries": {"fit() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.fit", false]], "fit() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.fit", false]], "fit() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.fit", false]], "fit() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.fit", false]], "fit() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.fit", false]], "fit() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.fit", false]], "get_metadata_routing() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.get_metadata_routing", false]], "get_metadata_routing() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.get_metadata_routing", false]], "get_metadata_routing() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.get_metadata_routing", false]], "get_metadata_routing() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.get_metadata_routing", false]], "get_metadata_routing() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.get_metadata_routing", false]], 
"get_metadata_routing() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.get_metadata_routing", false]], "get_params() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.get_params", false]], "get_params() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.get_params", false]], "get_params() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.get_params", false]], "get_params() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.get_params", false]], "get_params() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.get_params", false]], "get_params() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.get_params", false]], "lassonet_path() (in module lassonet)": [[0, "lassonet.lassonet_path", false]], "lassonetclassifier (class in lassonet)": [[0, "lassonet.LassoNetClassifier", false]], "lassonetclassifiercv (class in lassonet)": [[0, "lassonet.LassoNetClassifierCV", false]], "lassonetcoxregressor (class in lassonet)": [[0, "lassonet.LassoNetCoxRegressor", false]], "lassonetcoxregressorcv (class in lassonet)": [[0, "lassonet.LassoNetCoxRegressorCV", false]], "lassonetregressor (class in lassonet)": [[0, "lassonet.LassoNetRegressor", false]], "lassonetregressorcv (class in lassonet)": [[0, "lassonet.LassoNetRegressorCV", false]], "path() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.path", false]], "path() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.path", false]], "path() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.path", false]], "path() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.path", false]], "path() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.path", false]], "path() (lassonet.lassonetregressorcv method)": [[0, 
"lassonet.LassoNetRegressorCV.path", false]], "plot_path() (in module lassonet)": [[0, "lassonet.plot_path", false]], "score() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.score", false]], "score() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.score", false]], "score() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.score", false]], "score() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.score", false]], "score() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.score", false]], "score() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.score", false]], "set_fit_request() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.set_fit_request", false]], "set_fit_request() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.set_fit_request", false]], "set_fit_request() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.set_fit_request", false]], "set_fit_request() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.set_fit_request", false]], "set_fit_request() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.set_fit_request", false]], "set_fit_request() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.set_fit_request", false]], "set_params() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.set_params", false]], "set_params() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.set_params", false]], "set_params() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.set_params", false]], "set_params() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.set_params", false]], "set_params() (lassonet.lassonetregressor method)": [[0, 
"lassonet.LassoNetRegressor.set_params", false]], "set_params() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.set_params", false]], "set_score_request() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.set_score_request", false]], "set_score_request() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.set_score_request", false]], "set_score_request() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.set_score_request", false]], "set_score_request() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.set_score_request", false]], "set_score_request() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.set_score_request", false]], "set_score_request() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.set_score_request", false]], "stability_selection() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.stability_selection", false]], "stability_selection() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.stability_selection", false]], "stability_selection() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.stability_selection", false]], "stability_selection() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.stability_selection", false]], "stability_selection() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.stability_selection", false]], "stability_selection() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.stability_selection", false]]}, "objects": {"lassonet": [[0, 0, 1, "", "LassoNetClassifier"], [0, 0, 1, "", "LassoNetClassifierCV"], [0, 0, 1, "", "LassoNetCoxRegressor"], [0, 0, 1, "", "LassoNetCoxRegressorCV"], [0, 0, 1, "", "LassoNetRegressor"], [0, 0, 1, "", "LassoNetRegressorCV"], [0, 2, 1, "", "lassonet_path"], [0, 2, 1, "", 
"plot_path"]], "lassonet.LassoNetClassifier": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetClassifierCV": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetCoxRegressor": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetCoxRegressorCV": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetRegressor": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetRegressorCV": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]]}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "method", "Python method"], "2": ["py", "function", "Python 
function"]}, "objtypes": {"0": "py:class", "1": "py:method", "2": "py:function"}, "terms": {"0": 0, "02": 0, "1": 0, "10": 0, "100": 0, "1000": 0, "1e": 0, "2": 0, "20": 0, "23": 0, "3": 0, "5": 0, "9": 0, "99": 0, "A": 0, "For": 0, "If": 0, "In": 0, "The": 0, "There": 0, "To": 0, "__": 0, "accord": 0, "accuraci": 0, "ad": 0, "adam": 0, "alia": 0, "all": 0, "allow": 0, "alwai": 0, "an": 0, "ani": 0, "anymor": 0, "approxim": 0, "ar": 0, "arbitrarili": 0, "arrai": 0, "auto": 0, "automat": 0, "avail": 0, "backtrack": 0, "baselassonet": 0, "batch": 0, "batch_siz": 0, "becaus": 0, "beforehand": 0, "being": 0, "best": 0, "bool": 0, "bound": 0, "breslow": 0, "call": 0, "callback": 0, "can": 0, "chang": 0, "check": 0, "check_cv": 0, "class": 0, "class_weight": 0, "classif": 0, "classifi": 0, "coeffici": 0, "compon": 0, "comput": 0, "concord": 0, "connect": 0, "consist": 0, "constant": 0, "contain": 0, "correctli": 0, "cox": 0, "cpu": 0, "cross": 0, "cv": 0, "data": 0, "decreas": 0, "deep": 0, "default": 0, "defin": 0, "dens": 0, "dense_onli": 0, "determin": 0, "devic": 0, "dict": 0, "differ": 0, "dimens": 0, "disabl": 0, "disable_lambda_warn": 0, "disregard": 0, "doe": 0, "dropout": 0, "durat": 0, "dure": 0, "e": 0, "each": 0, "earli": 0, "effect": 0, "efron": 0, "els": 0, "enable_metadata_rout": 0, "encapsul": 0, "ensur": 0, "epoch": 0, "epsilon": 0, "error": 0, "estim": 0, "event": 0, "evolut": 0, "except": 0, "exist": 0, "expect": 0, "factor": 0, "fals": 0, "featur": 0, "first": 0, "fit": 0, "float": 0, "fold": 0, "form": 0, "frac": 0, "from": 0, "function": 0, "g": 0, "gamma": 0, "gamma_skip": 0, "gener": 0, "get": 0, "get_metadata_rout": 0, "get_param": 0, "given": 0, "go": 0, "gpu": 0, "group": 0, "guid": 0, "ha": 0, "harsh": 0, "have": 0, "hidden": 0, "hidden_dim": 0, "hierarchi": 0, "histori": 0, "historyitem": 0, "how": 0, "html": 0, "http": 0, "i": 0, "ideal": 0, "ignor": 0, "improv": 0, "increas": 0, "increment": 0, "index": 0, "indic": 0, "inf": 0, "influenc": 
0, "inform": 0, "initi": 0, "input": 0, "insid": 0, "instanc": 0, "instead": 0, "int": 0, "iter": 0, "keep": 0, "kernel": 0, "kwarg": 0, "l2": 0, "label": 0, "lambda": 0, "lambda_": 0, "lambda_max": 0, "lambda_seq": 0, "lambda_start": 0, "lassonet_path": 0, "lassonetclassifi": 0, "lassonetclassifiercv": 0, "lassonetcoxregressor": 0, "lassonetcoxregressorcv": 0, "lassonetregressor": 0, "lassonetregressorcv": 0, "latter": 0, "layer": 0, "learn": 0, "leav": 0, "like": 0, "list": 0, "longtensor": 0, "lr": 0, "m": 0, "mai": 0, "main": 0, "map": 0, "matrix": 0, "maximum": 0, "mean": 0, "mechan": 0, "meta": 0, "metadata": 0, "metadata_rout": 0, "metadatarequest": 0, "method": 0, "metric": 0, "minimum": 0, "model": 0, "model_select": 0, "modul": 0, "momentum": 0, "most": 0, "multi": 0, "multioutput": 0, "multioutputregressor": 0, "multipl": 0, "must": 0, "n_featur": 0, "n_iter": 0, "n_model": 0, "n_output": 0, "n_sampl": 0, "n_samples_fit": 0, "n_step": 0, "name": 0, "neg": 0, "nest": 0, "network": 0, "new": 0, "none": 0, "note": 0, "number": 0, "object": 0, "old": 0, "one": 0, "onli": 0, "optim": 0, "option": 0, "oracl": 0, "order": 0, "org": 0, "origin": 0, "other": 0, "otherwis": 0, "output": [], "over": 0, "pair": 0, "param": 0, "paramet": 0, "pass": 0, "path": 0, "path_multipli": 0, "patienc": 0, "penal": 0, "penalti": 0, "per": 0, "pip": 0, "pipelin": 0, "pleas": 0, "plot": 0, "plot_path": 0, "possibl": 0, "precomput": 0, "predict": 0, "prob": 0, "probabl": 0, "proport": 0, "provid": 0, "pytorch": 0, "r": 0, "r2_score": 0, "rais": 0, "random": 0, "random_st": 0, "regress": 0, "regressor": 0, "regular": 0, "relev": 0, "request": 0, "requir": 0, "residu": 0, "retain": 0, "return": 0, "return_state_dict": 0, "rout": 0, "sampl": 0, "sample_weight": 0, "scikit": 0, "score": 0, "score_funct": 0, "see": 0, "select": 0, "selection_prob": 0, "self": 0, "sequenc": 0, "set": 0, "set_config": 0, "set_fit_request": 0, "set_param": 0, "set_score_request": 0, "sgd": 0, "shape": 0, 
"should": 0, "shuffl": 0, "simpl": 0, "sinc": 0, "skip": 0, "sklearn": 0, "so": 0, "some": 0, "specifi": 0, "split": 0, "squar": 0, "stabil": 0, "stability_select": 0, "stabl": 0, "start": 0, "state": 0, "step": 0, "stop": 0, "str": 0, "strategi": 0, "sub": 0, "subobject": 0, "subset": 0, "sum": 0, "t": 0, "take": 0, "target": 0, "task": 0, "tensor": 0, "test": 0, "th": 0, "thi": 0, "tie": 0, "tie_approxim": 0, "tol": 0, "torch": 0, "torch_se": 0, "total": 0, "train": 0, "true": 0, "tupl": 0, "two": 0, "type": 0, "u": 0, "unchang": 0, "uniform_averag": 0, "until": 0, "updat": 0, "upper": 0, "us": 0, "user": 0, "util": 0, "v": 0, "val_siz": 0, "valid": 0, "valu": 0, "variabl": 0, "verbos": 0, "version": 0, "w": 0, "wait": 0, "websit": 0, "weight": 0, "well": 0, "when": 0, "where": 0, "which": 0, "without": 0, "work": 0, "wors": 0, "would": 0, "wrong": 0, "x": 0, "x_test": 0, "x_val": 0, "y": 0, "y_pred": 0, "y_test": 0, "y_true": 0, "y_val": 0, "you": 0, "zero": 0}, "titles": ["Welcome to LassoNet\u2019s documentation!"], "titleterms": {"": 0, "api": 0, "document": 0, "instal": 0, "lassonet": 0, "welcom": 0}}) -------------------------------------------------------------------------------- /docs/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/fig1.png -------------------------------------------------------------------------------- /docs/images/video_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/images/video_screenshot.png -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | LassoNet: Neural Networks with Feature Sparsity 7 | 8 | 9 
| 11 | 13 | 44 | 45 | 46 | 48 | 49 | 50 | 51 |
52 |

LassoNet: Neural Networks with Feature Sparsity

53 | 54 | 55 | 56 | 59 | 62 | 65 | 68 |
57 | Ismael Lemhadri 58 | 60 | Feng Ruan 61 | 63 | Louis Abraham 64 | 66 | Rob Tibshirani 67 |
69 | 70 | 71 | 72 | 75 | 78 | 81 | 84 | 87 | 88 | 89 | 90 | 91 | 92 |

93 | LassoNet is a method for feature selection in neural networks, to enhance interpretability of the final 94 | network. 95 |
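To make "feature selection" concrete: the examples in this repository (e.g. `examples/boston_housing.py`) rank inputs with `model.feature_importances_` after computing a regularization path with `model.path(...)`. A natural way to turn a path into such a ranking, and our assumption about what LassoNet does, is to record for each feature the λ at which it first leaves the model, so features that survive stronger penalties rank higher. The NumPy sketch below implements that rule; the function name and the `np.inf` convention for never-dropped features are hypothetical, not the library's API.

```python
import numpy as np


def feature_importances_from_path(lambdas, selected):
    """Hypothetical helper: lambda at which each feature first leaves the model.

    lambdas:  increasing penalty values along the path, shape (T,).
    selected: boolean mask, shape (T, n_features); selected[t, j] is True
              while feature j is still active at lambdas[t].
    Features that are never dropped get np.inf (an assumed convention).
    """
    lambdas = np.asarray(lambdas, dtype=float)
    selected = np.asarray(selected, dtype=bool)
    importances = np.full(selected.shape[1], np.inf)
    active = selected[0].copy()
    importances[~active] = lambdas[0]  # already inactive at the first step
    for lam, mask in zip(lambdas[1:], selected[1:]):
        dropped = active & ~mask  # features that leave the model at this lambda
        importances[dropped] = lam
        active &= mask  # once dropped, a feature stays dropped for ranking purposes
    return importances


# Toy path: feature 2 leaves at lambda=2, feature 1 at lambda=4, feature 0 never.
lambdas = [1.0, 2.0, 4.0, 8.0]
selected = [
    [True, True, True],
    [True, True, False],
    [True, False, False],
    [True, False, False],
]
imp = feature_importances_from_path(lambdas, selected)
print(imp.tolist())  # [inf, 4.0, 2.0]
print(np.argsort(imp)[::-1].tolist())  # [0, 1, 2]: most important first
```

Sorting in decreasing order, as `examples/boston_housing.py` does with `np.argsort(importances)[::-1]`, lists the features that resist the penalty longest first.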

111 |

112 | 113 |
114 | 115 |

LassoNet in 2 minutes

116 | 118 | 119 |
120 | 121 |

Installation

122 | 123 |
pip install lassonet
124 | 125 |
126 |
127 |

Tips

128 | 129 | LassoNet sometimes requires fine-tuning. For optimal performance, consider: 130 | 140 | 141 |
142 | 143 | 144 |
145 |
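One tuning knob that the examples here set explicitly is `path_multiplier` (1.02 in `examples/cox_regression.py`; 1.01 and 1.02 in `examples/cox_experiments.py`). Assuming, as the name suggests, that LassoNet grows the penalty geometrically, λ_k = λ_start · m^k, the number of steps needed to move from `lambda_start` to a larger λ is ⌈log(λ / λ_start) / log m⌉. The sketch below, with hypothetical helper names, counts the steps both by iterating and in closed form, and shows the trade-off: a multiplier closer to 1 gives a finer path at the cost of more training steps.

```python
import math


def geometric_lambda_path(lambda_start, path_multiplier, lambda_stop):
    """Hypothetical helper: lambda_start * path_multiplier**k until lambda_stop is reached.

    The returned list ends with the first value >= lambda_stop.
    """
    if path_multiplier <= 1:
        raise ValueError("path_multiplier must be > 1 for the penalty to grow")
    lambdas = [lambda_start]
    while lambdas[-1] < lambda_stop:
        lambdas.append(lambdas[-1] * path_multiplier)
    return lambdas


def n_path_steps(lambda_start, path_multiplier, lambda_stop):
    """Closed form: smallest k such that lambda_start * path_multiplier**k >= lambda_stop."""
    return math.ceil(math.log(lambda_stop / lambda_start) / math.log(path_multiplier))


# Going from lambda_start=1e-2 to lambda=10 (a 1000x increase):
for m in (1.1, 1.02, 1.01):
    print(m, n_path_steps(1e-2, m, 10.0))
# 1.1 73
# 1.02 349
# 1.01 695
```

Halving `log m` roughly doubles the number of fits along the path. How far the library actually runs the path, and what it uses as defaults, may differ from this sketch, so check `lassonet/interfaces.py` before relying on these counts.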

Intro video

146 | 149 | 150 |

Talk

151 | 154 |
155 |
156 | 157 |
158 | 159 |
160 |

Citation

161 | The algorithms and method used in this package came primarily out of research in Rob Tibshirani's lab at Stanford 162 | University. If you use LassoNet in your research we would appreciate a citation to the paper: 163 |
164 |       @article{lemhadri2019neural,
165 |         title={LassoNet: Neural Networks with Feature Sparsity},
166 |         author={Lemhadri, Ismael and Ruan, Feng and
167 |                 Abraham, Louis and Tibshirani, Robert},
168 |         journal={arXiv preprint arXiv:1907.12207},
169 |         year={2019}
170 |       }
171 |     
172 |
173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /docs/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif; 3 | font-weight: 300; 4 | font-size: 18px; 5 | margin-left: auto; 6 | margin-right: auto; 7 | width: 1100px; 8 | } 9 | 10 | h1 { 11 | font-weight: 300; 12 | } 13 | 14 | .disclaimerbox { 15 | background-color: #eee; 16 | border: 1px solid #eeeeee; 17 | border-radius: 10px; 18 | -moz-border-radius: 10px; 19 | -webkit-border-radius: 10px; 20 | padding: 20px; 21 | } 22 | 23 | video.header-vid { 24 | height: 140px; 25 | border: 1px solid black; 26 | border-radius: 10px; 27 | -moz-border-radius: 10px; 28 | -webkit-border-radius: 10px; 29 | } 30 | 31 | img.header-img { 32 | height: 140px; 33 | border: 1px solid black; 34 | border-radius: 10px; 35 | -moz-border-radius: 10px; 36 | -webkit-border-radius: 10px; 37 | } 38 | 39 | img.rounded { 40 | border: 1px solid #eeeeee; 41 | border-radius: 10px; 42 | -moz-border-radius: 10px; 43 | -webkit-border-radius: 10px; 44 | } 45 | 46 | a:link, 47 | a:visited { 48 | color: #1367a7; 49 | text-decoration: none; 50 | } 51 | 52 | a:hover { 53 | color: #208799; 54 | } 55 | 56 | td.dl-link { 57 | height: 160px; 58 | text-align: center; 59 | font-size: 22px; 60 | } 61 | 62 | .layered-paper-big { 63 | /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */ 64 | box-shadow: 65 | 0px 0px 1px 1px rgba(0, 0, 0, 0.35), 66 | /* The top layer shadow */ 67 | 5px 5px 0 0px #fff, 68 | /* The second layer */ 69 | 5px 5px 1px 1px rgba(0, 0, 0, 0.35), 70 | /* The second layer shadow */ 71 | 10px 10px 0 0px #fff, 72 | /* The third layer */ 73 | 10px 10px 1px 1px rgba(0, 0, 0, 0.35), 74 | /* The third layer shadow */ 75 | 15px 15px 0 0px #fff, 76 | /* The fourth layer */ 77 | 15px 15px 1px 1px rgba(0, 0, 
0, 0.35), 78 | /* The fourth layer shadow */ 79 | 20px 20px 0 0px #fff, 80 | /* The fifth layer */ 81 | 20px 20px 1px 1px rgba(0, 0, 0, 0.35), 82 | /* The fifth layer shadow */ 83 | 25px 25px 0 0px #fff, 84 | /* The sixth layer */ 85 | 25px 25px 1px 1px rgba(0, 0, 0, 0.35); 86 | /* The sixth layer shadow */ 87 | margin-left: 10px; 88 | margin-right: 45px; 89 | } 90 | 91 | .paper-big { 92 | /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */ 93 | box-shadow: 94 | 0px 0px 1px 1px rgba(0, 0, 0, 0.35); 95 | /* The top layer shadow */ 96 | 97 | margin-left: 10px; 98 | margin-right: 45px; 99 | } 100 | 101 | 102 | .layered-paper { 103 | /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */ 104 | box-shadow: 105 | 0px 0px 1px 1px rgba(0, 0, 0, 0.35), 106 | /* The top layer shadow */ 107 | 5px 5px 0 0px #fff, 108 | /* The second layer */ 109 | 5px 5px 1px 1px rgba(0, 0, 0, 0.35), 110 | /* The second layer shadow */ 111 | 10px 10px 0 0px #fff, 112 | /* The third layer */ 113 | 10px 10px 1px 1px rgba(0, 0, 0, 0.35); 114 | /* The third layer shadow */ 115 | margin-top: 5px; 116 | margin-left: 10px; 117 | margin-right: 30px; 118 | margin-bottom: 5px; 119 | } 120 | 121 | .vert-cent { 122 | position: relative; 123 | top: 50%; 124 | transform: translateY(-50%); 125 | } 126 | 127 | hr { 128 | border: 0; 129 | height: 1px; 130 | background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0)); 131 | } 132 | 133 | .center { 134 | margin: auto; 135 | text-align: center; 136 | } 137 | 138 | img.center { 139 | display: block; 140 | 141 | } 142 | 143 | h1 { 144 | text-align: center; 145 | font-size: 42px 146 | } 147 | 148 | #authors { 149 | font-size: 24px; 150 | width: 800px; 151 | } 152 | 153 | #authors td { 154 | width: 100px; 155 | } 156 | 157 | #links td { 158 | font-size: 24px; 159 | width: 130px; 160 | } 161 | 162 | 163 | iframe { 164 | width: 792px; 165 | height: 328px; 166 | 
display: block; 167 |
border-style: none; 168 | margin: 0 auto; 169 | 170 | } -------------------------------------------------------------------------------- /examples/accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/examples/accuracy.png -------------------------------------------------------------------------------- /examples/boston_housing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from sklearn.datasets import load_boston 6 | from sklearn.metrics import mean_squared_error 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.preprocessing import StandardScaler, scale 9 | 10 | from lassonet import LassoNetRegressor 11 | 12 | dataset = load_boston() 13 | X = dataset.data 14 | y = dataset.target 15 | _, true_features = X.shape 16 | # add dummy feature 17 | X = np.concatenate([X, np.random.randn(*X.shape)], axis=1) 18 | feature_names = list(dataset.feature_names) + ["fake"] * true_features 19 | 20 | # standardize 21 | X = StandardScaler().fit_transform(X) 22 | y = scale(y) 23 | 24 | 25 | X_train, X_test, y_train, y_test = train_test_split(X, y) 26 | 27 | model = LassoNetRegressor( 28 | hidden_dims=(10,), 29 | verbose=True, 30 | patience=(100, 5), 31 | ) 32 | path = model.path(X_train, y_train) 33 | 34 | n_selected = [] 35 | mse = [] 36 | lambda_ = [] 37 | 38 | for save in path: 39 | model.load(save.state_dict) 40 | y_pred = model.predict(X_test) 41 | n_selected.append(save.selected.sum().cpu().numpy()) 42 | mse.append(mean_squared_error(y_test, y_pred)) 43 | lambda_.append(save.lambda_) 44 | 45 | 46 | fig = plt.figure(figsize=(12, 12)) 47 | 48 | plt.subplot(311) 49 | plt.grid(True) 50 | plt.plot(n_selected, mse, ".-") 51 | plt.xlabel("number of selected features") 52 | plt.ylabel("MSE") 53 | 54 | 
plt.subplot(312) 55 | plt.grid(True) 56 | plt.plot(lambda_, mse, ".-") 57 | plt.xlabel("lambda") 58 | plt.xscale("log") 59 | plt.ylabel("MSE") 60 | 61 | plt.subplot(313) 62 | plt.grid(True) 63 | plt.plot(lambda_, n_selected, ".-") 64 | plt.xlabel("lambda") 65 | plt.xscale("log") 66 | plt.ylabel("number of selected features") 67 | 68 | plt.savefig("boston.png") 69 | 70 | plt.clf() 71 | 72 | n_features = X.shape[1] 73 | importances = model.feature_importances_.numpy() 74 | order = np.argsort(importances)[::-1] 75 | importances = importances[order] 76 | ordered_feature_names = [feature_names[i] for i in order] 77 | color = np.array(["g"] * true_features + ["r"] * (n_features - true_features))[order] 78 | 79 | 80 | plt.subplot(211) 81 | plt.bar( 82 | np.arange(n_features), 83 | importances, 84 | color=color, 85 | ) 86 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90) 87 | colors = {"real features": "g", "fake features": "r"} 88 | labels = list(colors.keys()) 89 | handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels] 90 | plt.legend(handles, labels) 91 | plt.ylabel("Feature importance") 92 | 93 | _, order = np.unique(importances, return_inverse=True) 94 | 95 | plt.subplot(212) 96 | plt.bar( 97 | np.arange(n_features), 98 | order + 1, 99 | color=color, 100 | ) 101 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90) 102 | plt.legend(handles, labels) 103 | plt.ylabel("Feature order") 104 | 105 | plt.savefig("boston-bar.png") 106 | -------------------------------------------------------------------------------- /examples/boston_housing_group.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from sklearn.datasets import load_boston 6 | from sklearn.metrics import mean_squared_error 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.preprocessing import 
StandardScaler, scale 9 | 10 | from lassonet import LassoNetRegressor 11 | 12 | dataset = load_boston() 13 | X = dataset.data 14 | y = dataset.target 15 | _, true_features = X.shape 16 | # add dummy feature 17 | X = np.concatenate([X, np.random.randn(*X.shape)], axis=1) 18 | feature_names = list(dataset.feature_names) + ["fake"] * true_features 19 | 20 | # standardize 21 | X = StandardScaler().fit_transform(X) 22 | y = scale(y) 23 | 24 | 25 | X_train, X_test, y_train, y_test = train_test_split(X, y) 26 | 27 | model = LassoNetRegressor( 28 | hidden_dims=(10,), 29 | verbose=True, 30 | patience=(100, 5), 31 | groups=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12], list(range(13, 26))], 32 | ) 33 | path = model.path(X_train, y_train) 34 | 35 | n_selected = [] 36 | mse = [] 37 | lambda_ = [] 38 | 39 | for save in path: 40 | model.load(save.state_dict) 41 | y_pred = model.predict(X_test) 42 | n_selected.append(save.selected.sum()) 43 | mse.append(mean_squared_error(y_test, y_pred)) 44 | lambda_.append(save.lambda_) 45 | 46 | 47 | fig = plt.figure(figsize=(12, 12)) 48 | 49 | plt.subplot(311) 50 | plt.grid(True) 51 | plt.plot(n_selected, mse, ".-") 52 | plt.xlabel("number of selected features") 53 | plt.ylabel("MSE") 54 | 55 | plt.subplot(312) 56 | plt.grid(True) 57 | plt.plot(lambda_, mse, ".-") 58 | plt.xlabel("lambda") 59 | plt.xscale("log") 60 | plt.ylabel("MSE") 61 | 62 | plt.subplot(313) 63 | plt.grid(True) 64 | plt.plot(lambda_, n_selected, ".-") 65 | plt.xlabel("lambda") 66 | plt.xscale("log") 67 | plt.ylabel("number of selected features") 68 | 69 | plt.savefig("boston-group.png") 70 | 71 | plt.clf() 72 | 73 | n_features = X.shape[1] 74 | importances = model.feature_importances_.numpy() 75 | order = np.argsort(importances)[::-1] 76 | importances = importances[order] 77 | ordered_feature_names = [feature_names[i] for i in order] 78 | color = np.array(["g"] * true_features + ["r"] * (n_features - true_features))[order] 79 | 80 | 81 | plt.subplot(211) 82 | 
plt.bar( 83 | np.arange(n_features), 84 | importances, 85 | color=color, 86 | ) 87 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90) 88 | colors = {"real features": "g", "fake features": "r"} 89 | labels = list(colors.keys()) 90 | handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels] 91 | plt.legend(handles, labels) 92 | plt.ylabel("Feature importance") 93 | 94 | _, order = np.unique(importances, return_inverse=True) 95 | 96 | plt.subplot(212) 97 | plt.bar( 98 | np.arange(n_features), 99 | order + 1, 100 | color=color, 101 | ) 102 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90) 103 | plt.legend(handles, labels) 104 | plt.ylabel("Feature order") 105 | 106 | plt.savefig("boston-bar-group.png") 107 | -------------------------------------------------------------------------------- /examples/cox_experiments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Install required packages with: 4 | 5 | pip install scipy joblib tqdm_joblib scikit-survival 6 | """ 7 | 8 | 9 | from pathlib import Path 10 | from time import time 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import sksurv.datasets 15 | from joblib import Parallel, delayed 16 | from matplotlib import pyplot as plt 17 | from sklearn import preprocessing 18 | from sklearn.model_selection import StratifiedKFold, train_test_split 19 | from tqdm import tqdm 20 | from tqdm_joblib import tqdm_joblib 21 | 22 | from lassonet import LassoNetCoxRegressorCV, plot_cv 23 | from lassonet.utils import confidence_interval 24 | 25 | DATA_PATH = Path(__file__).parent / "data" 26 | DATA_PATH.mkdir(exist_ok=True) 27 | 28 | FIGURES_PATH = Path() / "cox_figures" 29 | FIGURES_PATH.mkdir(exist_ok=True) 30 | 31 | 32 | def transform_one_hot(input_matrix, col_name): 33 | one_hot_col = pd.get_dummies(input_matrix[col_name], prefix=col_name) 
34 | input_matrix = input_matrix.drop([col_name], axis=1) 35 |
input_matrix = input_matrix.join(one_hot_col) 36 | return input_matrix 37 | 38 | 39 | def dump(array, name): 40 | pd.DataFrame(array).to_csv(DATA_PATH / name, index=False) 41 | 42 | 43 | def gen_data(dataset): 44 | if dataset == "breast": 45 | X, y = sksurv.datasets.load_breast_cancer() 46 | di_er = {"negative": 0, "positive": 1} 47 | di_grade = { 48 | "poorly differentiated": -1, 49 | "intermediate": 0, 50 | "well differentiated": 1, 51 | "unkown": 0, 52 | } 53 | X = X.replace({"er": di_er, "grade": di_grade}) 54 | y_temp = pd.DataFrame(y, columns=["t.tdm", "e.tdm"]) 55 | di_event = {True: 1, False: 0} 56 | y_temp = y_temp.replace({"e.tdm": di_event}) 57 | y = y_temp 58 | 59 | elif dataset == "fl_chain": 60 | X, y = sksurv.datasets.load_flchain() 61 | di_mgus = {"no": 0, "yes": 1} 62 | X = X.replace({"mgus": di_mgus, "creatinine": {np.nan: 0}}) 63 | col_names = ["chapter", "sex", "sample.yr", "flc.grp"] 64 | for col_name in col_names: 65 | X = transform_one_hot(X, col_name) 66 | y_temp = pd.DataFrame(y, columns=["futime", "death"]) 67 | di_event = {True: 0, False: 1} 68 | y_temp = y_temp.replace({"death": di_event}) 69 | y = y_temp 70 | 71 | elif dataset == "whas500": 72 | X, y = sksurv.datasets.load_whas500() 73 | y_temp = pd.DataFrame(y, columns=["lenfol", "fstat"]) 74 | di_event = {True: 1, False: 0} 75 | y_temp = y_temp.replace({"fstat": di_event}) 76 | y = y_temp 77 | 78 | elif dataset == "veterans": 79 | X, y = sksurv.datasets.load_veterans_lung_cancer() 80 | col_names = ["Celltype", "Prior_therapy", "Treatment"] 81 | for col_name in col_names: 82 | X = transform_one_hot(X, col_name) 83 | y_temp = pd.DataFrame(y, columns=["Survival_in_days", "Status"]) 84 | di_event = {False: 0, True: 1} 85 | y_temp = y_temp.replace({"Status": di_event}) 86 | y = y_temp 87 | elif dataset == "hnscc": 88 | raise ValueError("Dataset exists") 89 | else: 90 | raise ValueError("Dataset unknown") 91 | 92 | dump(X, f"{dataset}_x.csv") 93 | dump(y, f"{dataset}_y.csv") 94 | 95 | 96 | 
def load_data(dataset): 97 | path_x = DATA_PATH / f"{dataset}_x.csv" 98 | path_y = DATA_PATH / f"{dataset}_y.csv" 99 | if not (path_x.exists() and path_y.exists()): 100 | gen_data(dataset) 101 | X = np.genfromtxt(path_x, delimiter=",", skip_header=1) 102 | y = np.genfromtxt(path_y, delimiter=",", skip_header=1) 103 | X = preprocessing.StandardScaler().fit(X).transform(X) 104 | return X, y 105 | 106 | 107 | def run( 108 | dataset, 109 | hidden_dims, 110 | path_multiplier, 111 | *, 112 | random_state, 113 | tie_approximation="breslow", 114 | dump_splits=False, 115 | verbose=False, 116 | ): 117 | X, y = load_data(dataset) 118 | X_train, X_test, y_train, y_test = train_test_split( 119 | X, y, random_state=random_state, stratify=y[:, 1], test_size=0.20 120 | ) 121 | 122 | if dump_splits: 123 | for array, name in [ 124 | (X_train, "x_train"), 125 | (y_train, "y_train"), 126 | (X_test, "x_test"), 127 | (y_test, "y_test"), 128 | ]: 129 | dump(array, f"{dataset}_{name}_{random_state}.csv") 130 | 131 | cv = list( 132 | StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state).split( 133 | X_train, y_train[:, 1] 134 | ) 135 | ) 136 | 137 | model = LassoNetCoxRegressorCV( 138 | tie_approximation=tie_approximation, 139 | hidden_dims=hidden_dims, 140 | path_multiplier=path_multiplier, 141 | cv=cv, 142 | torch_seed=random_state, 143 | verbose=verbose, 144 | ) 145 | model.path(X_train, y_train) 146 | plot_cv(model, X_test, y_test) 147 | plt.savefig( 148 | FIGURES_PATH 149 | / ( 150 | f"cox-cv-{dataset}-{random_state}" 151 | f"-{model.hidden_dims}-{model.path_multiplier}.jpg" 152 | ), 153 | dpi=300, 154 | ) 155 | 156 | test_score = model.score(X_test, y_test) 157 | 158 | return test_score 159 | 160 | 161 | def experiment(dataset, hidden_dims, path_multiplier, n_runs, n_jobs): 162 | start = time() 163 | with tqdm_joblib(desc=f"Running on {dataset}", total=n_runs): 164 | scores = np.array( 165 | Parallel(n_jobs=n_jobs)( 166 | delayed(run)( 167 | dataset, 168 | 
hidden_dims=hidden_dims, 169 | path_multiplier=path_multiplier, 170 | tie_approximation=tie_approximation, 171 | random_state=random_state, 172 | ) 173 | for random_state in range(n_runs) 174 | ) 175 | ) 176 | scores_str = np.array2string( 177 | scores, separator=", ", formatter={"float_kind": "{:.2f}".format} 178 | ) 179 | log = ( 180 | f"Dataset: {dataset}\n" 181 | f"Arch: {hidden_dims}\n" 182 | f"Path multiplier: {path_multiplier}\n" 183 | f"Score: {scores.mean():.04f} " 184 | f"± {confidence_interval(scores):.04f} " 185 | f"with {n_runs} runs\n" 186 | f"Raw scores: {scores_str}\n" 187 | f"Running time: {time() - start:.00f}s with {n_jobs} processors\n" 188 | f"Running time per run per cpu: {(time() - start) * n_jobs / n_runs:.00f}s\n" 189 | f"{'-' * 50}\n" 190 | ) 191 | 192 | tqdm.write(log) 193 | with open("cox_experiments.log", "a") as f: 194 | print(log, file=f) 195 | 196 | 197 | if __name__ == "__main__": 198 | """ 199 | run with python3 script.py dataset [method] 200 | 201 | dataset=all runs all experiments 202 | 203 | method can be "breslow" or "efron" (default "efron") 204 | """ 205 | 206 | import sys 207 | 208 | dataset = sys.argv[1] 209 | tie_approximation = sys.argv[2] if len(sys.argv) > 2 else "efron" 210 | if dataset == "all": 211 | datasets = ["breast", "whas500", "veterans", "hnscc"] 212 | verbose = False 213 | else: 214 | datasets = [dataset] 215 | verbose = 1 216 | 217 | N_RUNS = 15 218 | N_JOBS = 5 # set to a divisor of `n_runs` for maximal efficiency 219 | 220 | for hidden_dims in [(16, 16), (32,), (32, 16), (64,)]: 221 | for path_multiplier in [1.01, 1.02]: 222 | for dataset in datasets: 223 | experiment( 224 | dataset=dataset, 225 | hidden_dims=hidden_dims, 226 | path_multiplier=path_multiplier, 227 | n_runs=N_RUNS, 228 | n_jobs=N_JOBS, 229 | ) 230 | 231 | 232 | # import optuna 233 | 234 | # def objective(trial: optuna.Trial): 235 | # model = LassoNetCoxRegressorCV( 236 | # tie_approximation="breslow", 237 | # 
hidden_dims=(trial.suggest_int("hidden_dims", 8, 128),), 238 | # path_multiplier=1.01, 239 | # M=trial.suggest_float("M", 1e-3, 1e3), 240 | # cv=cv, 241 | # torch_seed=random_state, 242 | # verbose=False, 243 | # ) 244 | # model.fit(X_train, y_train) 245 | # test_score = model.score(X_test, y_test) 246 | # trial.set_user_attr("score std", model.best_cv_score_std_) 247 | # trial.set_user_attr("test score", test_score) 248 | # print("test score", test_score) 249 | # return model.best_cv_score_ 250 | 251 | 252 | # if __name__ == "__main__": 253 | # study = optuna.create_study( 254 | # storage="sqlite:///optuna.db", 255 | # study_name="fastcph-lassonet", 256 | # direction="maximize", 257 | # load_if_exists=True, 258 | # ) 259 | # study.optimize(objective, n_trials=100) 260 | -------------------------------------------------------------------------------- /examples/cox_regression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from sklearn.model_selection import train_test_split 7 | 8 | from lassonet import LassoNetCoxRegressor, plot_path 9 | 10 | data = Path(__file__).parent / "data" 11 | X = np.genfromtxt(data / "hnscc_x.csv", delimiter=",", skip_header=1) 12 | y = np.genfromtxt(data / "hnscc_y.csv", delimiter=",", skip_header=1) 13 | 14 | 15 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) 16 | 17 | model = LassoNetCoxRegressor( 18 | hidden_dims=(100,), 19 | lambda_start=1e-2, 20 | path_multiplier=1.02, 21 | gamma=1, 22 | verbose=True, 23 | tie_approximation="breslow", 24 | ) 25 | 26 | model.path(X_train, y_train, return_state_dicts=True) 27 | 28 | plot_path(model, X_test, y_test) 29 | plt.savefig("cox_regression.png") 30 | 31 | 32 | model = LassoNetCoxRegressor( 33 | hidden_dims=(100,), 34 | lambda_start=1e-2, 35 | path_multiplier=1.02, 36 | gamma=1, 37 | verbose=True, 38 | 
tie_approximation="efron", 39 | ) 40 | 41 | path = model.path(X_train, y_train) 42 | 43 | plot_path(model, X_test, y_test) 44 | plt.savefig("cox_regression_efron.png") 45 | -------------------------------------------------------------------------------- /examples/data/hnscc_y.csv: -------------------------------------------------------------------------------- 1 | Overall survival_duration,OS4Y 2 | 1763.0,0.0 3 | 407.0,1.0 4 | 2329.0,0.0 5 | 3982.0,0.0 6 | 2921.0,0.0 7 | 2560.0,0.0 8 | 1462.0,0.0 9 | 3101.0,0.0 10 | 2815.0,0.0 11 | 915.0,0.0 12 | 4347.0,0.0 13 | 1988.0,0.0 14 | 1277.0,0.0 15 | 3435.0,0.0 16 | 3982.0,0.0 17 | 879.0,1.0 18 | 3162.0,0.0 19 | 3739.0,0.0 20 | 1946.0,0.0 21 | 1933.0,0.0 22 | 3587.0,0.0 23 | 183.0,1.0 24 | 2329.0,0.0 25 | 2441.0,0.0 26 | 2678.0,0.0 27 | 1933.0,0.0 28 | 775.0,0.0 29 | 2113.0,0.0 30 | 1788.0,0.0 31 | 3587.0,0.0 32 | 568.0,1.0 33 | 2207.0,0.0 34 | 304.0,1.0 35 | 958.0,0.0 36 | 303.0,0.0 37 | 2648.0,0.0 38 | 3374.0,0.0 39 | 227.0,1.0 40 | 298.0,1.0 41 | 2207.0,0.0 42 | 3162.0,0.0 43 | 1909.0,0.0 44 | 426.0,1.0 45 | 2994.0,0.0 46 | 851.0,1.0 47 | 2608.0,0.0 48 | 2414.0,0.0 49 | 3314.0,0.0 50 | 2289.0,0.0 51 | 1857.0,0.0 52 | 2900.0,0.0 53 | 3344.0,0.0 54 | 2776.0,0.0 55 | 3709.0,0.0 56 | 3952.0,0.0 57 | 3283.0,0.0 58 | 2833.0,0.0 59 | 447.0,1.0 60 | 1760.0,0.0 61 | 3314.0,0.0 62 | 1028.0,0.0 63 | 3709.0,0.0 64 | 2244.0,0.0 65 | 1860.0,0.0 66 | 3101.0,0.0 67 | 1088.0,0.0 68 | 1480.0,0.0 69 | 2298.0,0.0 70 | 2481.0,0.0 71 | 277.0,1.0 72 | 3618.0,0.0 73 | 3830.0,0.0 74 | 979.0,1.0 75 | 1827.0,0.0 76 | 3283.0,0.0 77 | 2882.0,0.0 78 | 2411.0,0.0 79 | 863.0,1.0 80 | 502.0,1.0 81 | 669.0,1.0 82 | 1006.0,0.0 83 | 2779.0,0.0 84 | 876.0,1.0 85 | 1043.0,0.0 86 | 1842.0,0.0 87 | 33.0,1.0 88 | 2399.0,0.0 89 | 1593.0,0.0 90 | 2572.0,0.0 91 | 2468.0,0.0 92 | 2584.0,0.0 93 | 1809.0,0.0 94 | 2092.0,0.0 95 | 1067.0,1.0 96 | 3101.0,0.0 97 | 1985.0,0.0 98 | 2450.0,0.0 99 | 1067.0,0.0 100 | 3000.0,0.0 101 | 3192.0,0.0 102 | 3222.0,0.0 103 | 
1964.0,0.0 104 | 444.0,0.0 105 | 2462.0,0.0 106 | 2739.0,0.0 107 | 1867.0,0.0 108 | 4043.0,0.0 109 | 2879.0,0.0 110 | 1991.0,0.0 111 | 2538.0,0.0 112 | 1943.0,0.0 113 | 2043.0,0.0 114 | 3222.0,0.0 115 | 486.0,1.0 116 | 322.0,1.0 117 | 2301.0,0.0 118 | 3435.0,0.0 119 | 2274.0,0.0 120 | 2578.0,0.0 121 | 2560.0,0.0 122 | 2107.0,0.0 123 | 851.0,0.0 124 | 3070.0,0.0 125 | 2763.0,0.0 126 | 2505.0,0.0 127 | 1724.0,0.0 128 | 3739.0,0.0 129 | 3192.0,0.0 130 | 1240.0,0.0 131 | 1973.0,0.0 132 | 1234.0,0.0 133 | 2247.0,0.0 134 | 2645.0,0.0 135 | 882.0,1.0 136 | 2107.0,0.0 137 | 1766.0,0.0 138 | 2472.0,0.0 139 | 3374.0,0.0 140 | 3496.0,0.0 141 | 1414.0,1.0 142 | 3283.0,0.0 143 | 2338.0,0.0 144 | 1155.0,0.0 145 | 450.0,1.0 146 | 2335.0,0.0 147 | 605.0,1.0 148 | 3709.0,0.0 149 | 2584.0,0.0 150 | 2915.0,0.0 151 | 2204.0,0.0 152 | 4195.0,0.0 153 | 502.0,1.0 154 | 1791.0,0.0 155 | 2253.0,0.0 156 | 4165.0,0.0 157 | 4226.0,0.0 158 | 3405.0,0.0 159 | 2329.0,0.0 160 | 4286.0,0.0 161 | 2353.0,0.0 162 | 1788.0,0.0 163 | 1924.0,0.0 164 | 4165.0,0.0 165 | 3162.0,0.0 166 | 2809.0,0.0 167 | 2240.0,0.0 168 | 3466.0,0.0 169 | 1645.0,0.0 170 | 1176.0,0.0 171 | 2441.0,0.0 172 | 882.0,0.0 173 | 3678.0,0.0 174 | 3070.0,0.0 175 | 3557.0,0.0 176 | 824.0,1.0 177 | 2897.0,0.0 178 | 857.0,0.0 179 | 1888.0,0.0 180 | 118.0,0.0 181 | 3040.0,0.0 182 | 3070.0,0.0 183 | 2630.0,0.0 184 | 471.0,1.0 185 | 3496.0,0.0 186 | 3007.0,0.0 187 | 3435.0,0.0 188 | 2289.0,0.0 189 | 769.0,1.0 190 | 2943.0,0.0 191 | 2724.0,0.0 192 | 1943.0,0.0 193 | 2766.0,0.0 194 | 4134.0,0.0 195 | 2554.0,0.0 196 | 2490.0,0.0 197 | 2687.0,0.0 198 | 2265.0,0.0 199 | 2785.0,0.0 200 | 2143.0,0.0 201 | 1958.0,0.0 202 | 815.0,1.0 203 | 3344.0,0.0 204 | 2985.0,0.0 205 | 2675.0,0.0 206 | 2937.0,0.0 207 | 2660.0,0.0 208 | 529.0,1.0 209 | 1912.0,0.0 210 | 2852.0,0.0 211 | 3678.0,0.0 212 | 1788.0,0.0 213 | 2526.0,0.0 214 | 3222.0,0.0 215 | 1705.0,0.0 216 | 2912.0,0.0 217 | 3770.0,0.0 218 | 468.0,1.0 219 | 3253.0,0.0 220 | 2110.0,0.0 221 | 1538.0,0.0 
222 | 304.0,1.0 223 | 2350.0,0.0 224 | 1857.0,0.0 225 | 3800.0,0.0 226 | 2809.0,0.0 227 | 1906.0,0.0 228 | 2073.0,0.0 229 | 3283.0,0.0 230 | 3405.0,0.0 231 | 3800.0,0.0 232 | 3101.0,0.0 233 | 3739.0,0.0 234 | 3922.0,0.0 235 | 2544.0,0.0 236 | 3435.0,0.0 237 | 2347.0,0.0 238 | 2529.0,0.0 239 | 2079.0,0.0 240 | 1368.0,1.0 241 | 2429.0,0.0 242 | 2201.0,0.0 243 | 1338.0,1.0 244 | 2386.0,0.0 245 | 1216.0,1.0 246 | 854.0,1.0 247 | 3070.0,0.0 248 | 2888.0,0.0 249 | 2800.0,0.0 250 | 4165.0,0.0 251 | 3253.0,0.0 252 | 1882.0,0.0 253 | 1915.0,0.0 254 | 3374.0,0.0 255 | 3861.0,0.0 256 | 830.0,1.0 257 | 1912.0,0.0 258 | 3466.0,0.0 259 | 3678.0,0.0 260 | 2748.0,0.0 261 | 2438.0,0.0 262 | 1277.0,0.0 263 | 2304.0,0.0 264 | 2909.0,0.0 265 | 2852.0,0.0 266 | 3101.0,0.0 267 | 2788.0,0.0 268 | 4195.0,0.0 269 | 2721.0,0.0 270 | 3587.0,0.0 271 | 4195.0,0.0 272 | 4165.0,0.0 273 | 3557.0,0.0 274 | 638.0,1.0 275 | 538.0,1.0 276 | 3070.0,0.0 277 | 3587.0,0.0 278 | 2666.0,0.0 279 | 1122.0,1.0 280 | 3253.0,0.0 281 | 3709.0,0.0 282 | 3891.0,0.0 283 | 1702.0,0.0 284 | 2216.0,0.0 285 | 4347.0,0.0 286 | 328.0,1.0 287 | 4165.0,0.0 288 | 3131.0,0.0 289 | 2724.0,0.0 290 | 3891.0,0.0 291 | 1985.0,0.0 292 | 4165.0,0.0 293 | 2587.0,0.0 294 | 2572.0,0.0 295 | 2377.0,0.0 296 | 2271.0,0.0 297 | 1508.0,0.0 298 | 2371.0,0.0 299 | 2052.0,0.0 300 | 3070.0,0.0 301 | 2648.0,0.0 302 | 2313.0,0.0 303 | 2541.0,0.0 304 | 2748.0,0.0 305 | 1815.0,0.0 306 | 3982.0,0.0 307 | 2690.0,0.0 308 | 1857.0,0.0 309 | 1936.0,0.0 310 | 3982.0,0.0 311 | 1611.0,0.0 312 | 1794.0,0.0 313 | 2438.0,0.0 314 | 2703.0,0.0 315 | 1693.0,0.0 316 | 2991.0,0.0 317 | 4043.0,0.0 318 | 2946.0,0.0 319 | 2903.0,0.0 320 | 3070.0,0.0 321 | 3040.0,0.0 322 | 1389.0,0.0 323 | 2554.0,0.0 324 | 1982.0,0.0 325 | 2788.0,0.0 326 | 2098.0,0.0 327 | 1529.0,0.0 328 | 2304.0,0.0 329 | 3435.0,0.0 330 | 2481.0,0.0 331 | 2405.0,0.0 332 | 2852.0,0.0 333 | 2000.0,0.0 334 | 3800.0,0.0 335 | 1210.0,1.0 336 | 499.0,1.0 337 | 2298.0,0.0 338 | 3374.0,0.0 339 | 3253.0,0.0 
340 | 3709.0,0.0 341 | 2146.0,0.0 342 | 2700.0,0.0 343 | 3739.0,0.0 344 | 2551.0,0.0 345 | 1614.0,0.0 346 | 3922.0,0.0 347 | 3283.0,0.0 348 | 1693.0,0.0 349 | 2979.0,0.0 350 | 1836.0,0.0 351 | 2417.0,0.0 352 | 1052.0,1.0 353 | 3344.0,0.0 354 | 3618.0,0.0 355 | 3982.0,0.0 356 | 223.0,1.0 357 | 1599.0,0.0 358 | 614.0,1.0 359 | 386.0,1.0 360 | 1818.0,0.0 361 | 4499.0,0.0 362 | 2973.0,0.0 363 | 1395.0,1.0 364 | 3618.0,0.0 365 | 2301.0,0.0 366 | 1456.0,1.0 367 | 2988.0,0.0 368 | 1681.0,0.0 369 | 3770.0,0.0 370 | 1155.0,0.0 371 | 2672.0,0.0 372 | 784.0,0.0 373 | 3557.0,0.0 374 | 73.0,1.0 375 | 1541.0,0.0 376 | 3861.0,0.0 377 | 3861.0,0.0 378 | 3314.0,0.0 379 | 3070.0,0.0 380 | 3466.0,0.0 381 | 3192.0,0.0 382 | 4134.0,0.0 383 | 2450.0,0.0 384 | 1207.0,1.0 385 | 2064.0,0.0 386 | 1377.0,0.0 387 | 3374.0,0.0 388 | 1502.0,0.0 389 | 3025.0,0.0 390 | 1900.0,0.0 391 | 869.0,1.0 392 | 3618.0,0.0 393 | 2611.0,0.0 394 | 2918.0,0.0 395 | 1800.0,0.0 396 | 3253.0,0.0 397 | 2748.0,0.0 398 | 395.0,1.0 399 | 3314.0,0.0 400 | 3405.0,0.0 401 | 1338.0,0.0 402 | 2204.0,0.0 403 | 3770.0,0.0 404 | 1958.0,0.0 405 | 2280.0,0.0 406 | 2964.0,0.0 407 | 3496.0,0.0 408 | 3283.0,0.0 409 | 1587.0,0.0 410 | 1745.0,0.0 411 | 1651.0,0.0 412 | 784.0,0.0 413 | 2031.0,0.0 414 | 1897.0,0.0 415 | 1970.0,0.0 416 | 3034.0,0.0 417 | 2122.0,0.0 418 | 2766.0,0.0 419 | 1848.0,0.0 420 | 1818.0,0.0 421 | 2335.0,0.0 422 | 383.0,1.0 423 | 2377.0,0.0 424 | 340.0,0.0 425 | 3557.0,0.0 426 | 377.0,1.0 427 | 3891.0,0.0 428 | 1833.0,0.0 429 | 2289.0,0.0 430 | 3952.0,0.0 431 | 3131.0,0.0 432 | 2727.0,0.0 433 | 3222.0,0.0 434 | 2943.0,0.0 435 | 2174.0,0.0 436 | 3344.0,0.0 437 | 2052.0,0.0 438 | 389.0,1.0 439 | 3000.0,0.0 440 | 2885.0,0.0 441 | 3040.0,0.0 442 | 2493.0,0.0 443 | 1879.0,0.0 444 | 4226.0,0.0 445 | 3101.0,0.0 446 | 2213.0,0.0 447 | 2079.0,0.0 448 | 3800.0,0.0 449 | 3374.0,0.0 450 | 2794.0,0.0 451 | 2997.0,0.0 452 | 3222.0,0.0 453 | -------------------------------------------------------------------------------- 
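Each row of `hnscc_y.csv` holds an observed duration and a 0/1 flag, and `cox_regression.py` passes both columns together as `y`. Assuming the second column (`OS4Y`) is the event indicator (1 = event observed, 0 = censored), the sketch below splits the target and lists tied event times. Tied event times are the only case where the `breslow` and `efron` tie approximations used in the examples differ. The sample rows are taken from the file above, including the two `502.0,1.0` rows. `summarize_survival` is a hypothetical helper, not part of LassoNet.

```python
import io
from collections import Counter

import numpy as np

# A few rows taken from examples/data/hnscc_y.csv (the 502.0 rows are a tied event time)
SAMPLE = """Overall survival_duration,OS4Y
1763.0,0.0
407.0,1.0
2329.0,0.0
502.0,1.0
502.0,1.0
2329.0,0.0
"""


def summarize_survival(y):
    """Hypothetical helper: split an (n, 2) survival target and find tied event times.

    Assumes column 0 is the duration and column 1 is the event indicator.
    """
    durations, events = y[:, 0], y[:, 1].astype(bool)
    # Count how many observed events share each duration
    event_time_counts = Counter(durations[events].tolist())
    tied_event_times = sorted(t for t, c in event_time_counts.items() if c > 1)
    return {
        "n_samples": len(y),
        "n_events": int(events.sum()),
        "n_censored": int((~events).sum()),
        "tied_event_times": tied_event_times,
    }


# Same genfromtxt call as cox_regression.py, but reading from an in-memory buffer
y = np.genfromtxt(io.StringIO(SAMPLE), delimiter=",", skip_header=1)
print(summarize_survival(y))
```

Run on the full file, the same helper lists every duration shared by two or more observed events. If that list is empty, the Breslow and Efron approximations give the same partial likelihood.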
/examples/diabetes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from sklearn.datasets import load_diabetes 6 | from sklearn.model_selection import train_test_split 7 | from sklearn.preprocessing import StandardScaler, scale 8 | 9 | from lassonet import LassoNetRegressor, plot_path 10 | 11 | dataset = load_diabetes() 12 | X = dataset.data 13 | y = dataset.target 14 | _, true_features = X.shape 15 | # add dummy feature 16 | X = np.concatenate([X, np.random.randn(*X.shape)], axis=1) 17 | feature_names = list(dataset.feature_names) + ["fake"] * true_features 18 | 19 | # standardize 20 | X = StandardScaler().fit_transform(X) 21 | y = scale(y) 22 | 23 | 24 | X_train, X_test, y_train, y_test = train_test_split(X, y) 25 | 26 | model = LassoNetRegressor( 27 | hidden_dims=(10,), 28 | verbose=True, 29 | ) 30 | path = model.path(X_train, y_train, return_state_dicts=True) 31 | 32 | plot_path(model, X_test, y_test) 33 | 34 | plt.savefig("diabetes.png") 35 | 36 | plt.clf() 37 | 38 | n_features = X.shape[1] 39 | importances = model.feature_importances_.numpy() 40 | order = np.argsort(importances)[::-1] 41 | importances = importances[order] 42 | ordered_feature_names = [feature_names[i] for i in order] 43 | color = np.array(["g"] * true_features + ["r"] * (n_features - true_features))[order] 44 | 45 | 46 | plt.subplot(211) 47 | plt.bar( 48 | np.arange(n_features), 49 | importances, 50 | color=color, 51 | ) 52 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90) 53 | colors = {"real features": "g", "fake features": "r"} 54 | labels = list(colors.keys()) 55 | handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels] 56 | plt.legend(handles, labels) 57 | plt.ylabel("Feature importance") 58 | 59 | _, order = np.unique(importances, return_inverse=True) 60 | 61 | plt.subplot(212) 62 | plt.bar( 63 | 
np.arange(n_features), 64 | order + 1, 65 | color=color, 66 | ) 67 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90) 68 | plt.legend(handles, labels) 69 | plt.ylabel("Feature order") 70 | 71 | plt.savefig("diabetes-bar.png") 72 | -------------------------------------------------------------------------------- /examples/friedman.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from sklearn.metrics import r2_score 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.preprocessing import StandardScaler 6 | 7 | from lassonet import LassoNetRegressor, plot_path 8 | 9 | 10 | def friedman(linear_terms=True): 11 | n = 1000 12 | p = 200 13 | X = np.random.rand(n, p) 14 | y = 10 * np.sin(np.pi * X[:, 0] * X[:, 1]) + 20 * (X[:, 2] - 0.5) ** 2 15 | if linear_terms: 16 | y += 10 * X[:, 3] + 5 * X[:, 4] 17 | return X, y 18 | 19 | 20 | np.random.seed(0) 21 | X, y = friedman(linear_terms=True) 22 | X = StandardScaler().fit_transform(X) 23 | X_train, X_test, y_train, y_test = train_test_split(X, y) 24 | 25 | y_std = 0.5 * y_train.std() 26 | y_train += np.random.randn(*y_train.shape) * y_std 27 | 28 | for y in [y_train, y_test]: 29 | y -= y.mean() 30 | y /= y.std() 31 | 32 | 33 | def rrmse(y, y_pred): 34 | return np.sqrt(1 - r2_score(y, y_pred)) 35 | 36 | 37 | for path_multiplier in [1.01, 1.001]: 38 | print("path_multiplier:", path_multiplier) 39 | for M in [10, 100, 1_000, 10_000, 100_000]: 40 | print("M:", M) 41 | model = LassoNetRegressor( 42 | hidden_dims=(10, 10), 43 | random_state=0, 44 | torch_seed=0, 45 | path_multiplier=path_multiplier, 46 | M=M, 47 | ) 48 | path = model.path(X_train, y_train, return_state_dicts=True) 49 | print( 50 | "rrmse:", 51 | min(rrmse(y_test, model.load(save).predict(X_test)) for save in path), 52 | ) 53 | plot_path(model, X_test, y_test, score_function=rrmse) 54 | 
plt.savefig(f"friedman_path({path_multiplier})_M({M}).jpg") 55 | 56 | path_multiplier = 1.001 57 | print("path_multiplier:", path_multiplier) 58 | for M in [100, 1_000, 10_000, 100_000]: 59 | print("M:", M) 60 | model = LassoNetRegressor( 61 | hidden_dims=(10, 10), 62 | random_state=0, 63 | torch_seed=0, 64 | path_multiplier=path_multiplier, 65 | M=M, 66 | backtrack=True, 67 | ) 68 | path = model.path(X_train, y_train) 69 | print( 70 | "rrmse:", 71 | min(rrmse(y_test, model.load(save).predict(X_test)) for save in path), 72 | ) 73 | plot_path(model, X_test, y_test, score_function=rrmse) 74 | plt.savefig(f"friedman_path({path_multiplier})_M({M})_backtrack.jpg") 75 | 76 | 77 | for path_multiplier in [1.01, 1.001]: 78 | M = 100_000 79 | print("path_multiplier:", path_multiplier) 80 | print("M:", M) 81 | model = LassoNetRegressor( 82 | hidden_dims=(10, 10), 83 | random_state=0, 84 | torch_seed=0, 85 | path_multiplier=path_multiplier, 86 | M=M, 87 | patience=100, 88 | n_iters=1000, 89 | ) 90 | path = model.path(X_train, y_train) 91 | print( 92 | "rrmse:", 93 | min(rrmse(y_test, model.load(save).predict(X_test)) for save in path), 94 | ) 95 | plot_path(model, X_test, y_test, score_function=rrmse) 96 | plt.savefig(f"friedman_path({path_multiplier})_M({M})_long.jpg") 97 | 98 | # if __name__ == "__main__": 99 | 100 | # model = LassoNetRegressor(verbose=True, path_multiplier=1.01, hidden_dims=(10, 10)) 101 | # path = model.path(X_train, y_train) 102 | 103 | # plot_path(model, X_test, y_test, score_function=rrmse) 104 | # plt.show() 105 | -------------------------------------------------------------------------------- /examples/friedman/download.sh: -------------------------------------------------------------------------------- 1 | URL="https://raw.githubusercontent.com/warbelo/Lockout/main/Synthetic_Data2/dataset_a" 2 | 3 | for name in X.csv Y.csv xtest.csv xtrain.csv xvalid.csv ytest.csv ytrain.csv yvalid.csv; do 4 | curl $URL/$name > $name 5 | done 
-------------------------------------------------------------------------------- /examples/friedman/main.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from sklearn.metrics import r2_score 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | from lassonet import LassoNetRegressor, plot_path 7 | 8 | 9 | def load(s): 10 | return np.loadtxt(s, delimiter=",") 11 | 12 | 13 | X_train = load("xtrain.csv") 14 | y_train = load("ytrain.csv") 15 | X_val = load("xvalid.csv") 16 | y_val = load("yvalid.csv") 17 | X_test = load("xtest.csv") 18 | y_test = load("ytest.csv") 19 | 20 | X_train, X_val, X_test = np.split( 21 | StandardScaler().fit_transform(np.concatenate((X_train, X_val, X_test))), 3 22 | ) 23 | 24 | y = np.concatenate((y_train, y_val)) 25 | y_mean = y.mean() 26 | y_std = y.std() 27 | 28 | for y in [y_train, y_val, y_test]: 29 | y -= y_mean 30 | y /= y_std 31 | 32 | 33 | def rrmse(y, y_pred): 34 | return np.sqrt(1 - r2_score(y, y_pred)) 35 | 36 | 37 | if __name__ == "__main__": 38 | 39 | model = LassoNetRegressor( 40 | path_multiplier=1.001, 41 | M=100_000, 42 | hidden_dims=(10, 10), 43 | torch_seed=0, 44 | ) 45 | path = model.path( 46 | X_train, y_train, X_val=X_val, y_val=y_val, return_state_dicts=True 47 | ) 48 | print( 49 | "rrmse:", 50 | min(rrmse(y_test, model.load(save).predict(X_test)) for save in path), 51 | ) 52 | plot_path(model, X_test, y_test, score_function=rrmse) 53 | plt.show() 54 | -------------------------------------------------------------------------------- /examples/generated.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | from sklearn.metrics import r2_score 5 | from sklearn.model_selection import train_test_split 6 | from sklearn.preprocessing import StandardScaler 7 | 8 | from lassonet import LassoNetRegressor, plot_path 9 | 10 | 11 | 
def linear(): 12 | p = 10 13 | n = 400 14 | coef = np.concatenate([np.random.choice([-1, 1], size=p), [0] * p]) 15 | X = np.random.randn(n, 2 * p) 16 | 17 | linear = X.dot(coef) 18 | noise = np.random.randn(n) 19 | 20 | y = linear + noise 21 | return X, y 22 | 23 | 24 | def strong_linear(): 25 | p = 10 26 | n = 400 27 | coef = np.concatenate([np.random.choice([-1, 1], size=p), [0] * p]) 28 | X = np.random.randn(n, 2 * p) 29 | 30 | linear = X.dot(coef) 31 | noise = np.random.randn(n) 32 | x1, x2, x3, *_ = X.T 33 | nonlinear = 2 * (x1**3 - 3 * x1) + 4 * (x2**2 * x3 - x3) 34 | y = 6 * linear + 8 * noise + nonlinear 35 | return X, y 36 | 37 | 38 | def friedman_lockout(): 39 | p = 200 40 | n = 1000 41 | X = np.random.rand(n, p) 42 | y = ( 43 | 10 * np.sin(np.pi * X[:, 0] * X[:, 1]) 44 | + 20 * (X[:, 2] - 0.5) ** 2 45 | + 10 * X[:, 3] 46 | + 5 * X[:, 4] 47 | ) 48 | return X, y 49 | 50 | 51 | for generator in [linear, strong_linear, friedman_lockout]: 52 | X, y = generator() 53 | X = StandardScaler().fit_transform(X) 54 | y -= y.mean() 55 | y /= y.std() 56 | X_train, X_test, y_train, y_test = train_test_split(X, y) 57 | 58 | model = LassoNetRegressor(verbose=True, path_multiplier=1.01, hidden_dims=(10, 10)) 59 | 60 | path = model.path(X_train, y_train, return_state_dicts=True) 61 | import matplotlib.pyplot as plt 62 | 63 | def score(self, X, y, sample_weight=None): 64 | y_pred = self.predict(X) 65 | return np.sqrt(1 - r2_score(y, y_pred, sample_weight=sample_weight)) 66 | 67 | model.score = partial(score, model) 68 | 69 | plot_path(model, X_test, y_test) 70 | plt.show() 71 | -------------------------------------------------------------------------------- /examples/miceprotein.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | """ 5 | We run Lassonet over [the Mice Dataset](https://archive.ics.uci.edu/ml/datasets/Mice%20Protein%20Expression). 
6 | This dataset consists of protein expression levels measured in the cortex of normal and 7 | trisomic mice who had been exposed to different experimental conditions. 8 | Each feature is the expression level of one protein. 9 | """ 10 | 11 | import matplotlib.pyplot as plt 12 | from sklearn.datasets import fetch_openml 13 | from sklearn.impute import SimpleImputer 14 | from sklearn.model_selection import train_test_split 15 | from sklearn.preprocessing import LabelEncoder, StandardScaler 16 | 17 | from lassonet import LassoNetClassifier, plot_cv, plot_path 18 | from lassonet.interfaces import LassoNetClassifierCV 19 | 20 | X, y = fetch_openml(name="miceprotein", return_X_y=True) 21 | # Fill missing values with the mean 22 | X = SimpleImputer().fit_transform(X) 23 | # Convert labels to scalar 24 | y = LabelEncoder().fit_transform(y) 25 | 26 | # standardize 27 | X = StandardScaler().fit_transform(X) 28 | 29 | X_train, X_test, y_train, y_test = train_test_split(X, y) 30 | 31 | 32 | model = LassoNetClassifierCV() 33 | model.path(X_train, y_train, return_state_dicts=True) 34 | print("Best model scored", model.score(X_test, y_test)) 35 | print("Lambda =", model.best_lambda_) 36 | plot_cv(model, X_test, y_test) 37 | plt.savefig("miceprotein-cv.png") 38 | 39 | model = LassoNetClassifier() 40 | path = model.path(X_train, y_train, return_state_dicts=True) 41 | plot_path(model, X_test, y_test) 42 | plt.savefig("miceprotein.png") 43 | 44 | model = LassoNetClassifier(dropout=0.5) 45 | path = model.path(X_train, y_train, return_state_dicts=True) 46 | plot_path(model, X_test, y_test) 47 | plt.savefig("miceprotein_dropout.png") 48 | 49 | model = LassoNetClassifier(hidden_dims=(100, 100)) 50 | path = model.path(X_train, y_train, return_state_dicts=True) 51 | plot_path(model, X_test, y_test) 52 | plt.savefig("miceprotein_deep.png") 53 | 54 | model = LassoNetClassifier(hidden_dims=(100, 100), gamma=0.01) 55 | path = model.path(X_train, y_train, return_state_dicts=True) 56 | 
plot_path(model, X_test, y_test) 57 | plt.savefig("miceprotein_deep_l2_weak.png") 58 | 59 | model = LassoNetClassifier(hidden_dims=(100, 100), gamma=0.1) 60 | path = model.path(X_train, y_train, return_state_dicts=True) 61 | plot_path(model, X_test, y_test) 62 | plt.savefig("miceprotein_deep_l2_strong.png") 63 | 64 | model = LassoNetClassifier(hidden_dims=(100, 100), gamma=1) 65 | path = model.path(X_train, y_train, return_state_dicts=True) 66 | plot_path(model, X_test, y_test) 67 | plt.savefig("miceprotein_deep_l2_super_strong.png") 68 | 69 | model = LassoNetClassifier(hidden_dims=(100, 100), dropout=0.5) 70 | path = model.path(X_train, y_train, return_state_dicts=True) 71 | plot_path(model, X_test, y_test) 72 | plt.savefig("miceprotein_deep_dropout.png") 73 | 74 | model = LassoNetClassifier(hidden_dims=(100, 100), backtrack=True, dropout=0.5) 75 | path = model.path(X_train, y_train, return_state_dicts=True) 76 | plot_path(model, X_test, y_test) 77 | plt.savefig("miceprotein_deep_dropout_backtrack.png") 78 | 79 | model = LassoNetClassifier(batch_size=64) 80 | path = model.path(X_train, y_train, return_state_dicts=True) 81 | plot_path(model, X_test, y_test) 82 | plt.savefig("miceprotein_64.png") 83 | 84 | model = LassoNetClassifier(backtrack=True) 85 | path = model.path(X_train, y_train, return_state_dicts=True) 86 | plot_path(model, X_test, y_test) 87 | plt.savefig("miceprotein_backtrack.png") 88 | 89 | model = LassoNetClassifier(batch_size=64, backtrack=True) 90 | path = model.path(X_train, y_train, return_state_dicts=True) 91 | plot_path(model, X_test, y_test) 92 | plt.savefig("miceprotein_backtrack_64.png") 93 | 94 | model = LassoNetClassifier(class_weight=[0.1, 0.2, 0.3, 0.1, 0.3, 0, 0, 0]) 95 | path = model.path(X_train, y_train, return_state_dicts=True) 96 | plot_path(model, X_test, y_test) 97 | plt.savefig("miceprotein_weighted.png") 98 | -------------------------------------------------------------------------------- /examples/mnist_ae.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from sklearn.datasets import fetch_openml 3 | from sklearn.metrics import mean_squared_error 4 | 5 | from lassonet import LassoNetAutoEncoder 6 | 7 | X, y = fetch_openml(name="mnist_784", return_X_y=True) 8 | filter = y == "3" 9 | X = X[filter].values / 255 10 | 11 | model = LassoNetAutoEncoder( 12 | M=30, n_iters=(3000, 500), path_multiplier=1.05, verbose=True 13 | ) 14 | path = model.path(X) 15 | 16 | img = model.feature_importances_.reshape(28, 28) 17 | 18 | plt.title("Feature importance to reconstruct 3") 19 | plt.imshow(img) 20 | plt.colorbar() 21 | plt.savefig("mnist-ae-importance.png") 22 | 23 | 24 | n_selected = [] 25 | score = [] 26 | lambda_ = [] 27 | 28 | for save in path: 29 | model.load(save.state_dict) 30 | X_pred = model.predict(X) 31 | n_selected.append(save.selected.sum()) 32 | score.append(mean_squared_error(X_pred, X)) 33 | lambda_.append(save.lambda_) 34 | 35 | to_plot = [160, 220, 300] 36 | 37 | for i, save in zip(n_selected, path): 38 | if not to_plot: 39 | break 40 | if i > to_plot[-1]: 41 | continue 42 | to_plot.pop() 43 | plt.clf() 44 | plt.title(f"Linear model with {i} features") 45 | weight = save.state_dict["skip.weight"] 46 | img = (weight[1] - weight[0]).reshape(28, 28) 47 | plt.imshow(img) 48 | plt.colorbar() 49 | plt.savefig(f"mnist-ae-{i}.png") 50 | 51 | plt.clf() 52 | 53 | fig = plt.figure(figsize=(12, 12)) 54 | 55 | plt.subplot(311) 56 | plt.grid(True) 57 | plt.plot(n_selected, score, ".-") 58 | plt.xlabel("number of selected features") 59 | plt.ylabel("MSE") 60 | 61 | plt.subplot(312) 62 | plt.grid(True) 63 | plt.plot(lambda_, score, ".-") 64 | plt.xlabel("lambda") 65 | plt.xscale("log") 66 | plt.ylabel("MSE") 67 | 68 | plt.subplot(313) 69 | plt.grid(True) 70 | plt.plot(lambda_, n_selected, ".-") 71 | plt.xlabel("lambda") 72 | plt.xscale("log") 73 | plt.ylabel("number of selected features") 74 | 75 | 
plt.savefig("mnist-ae-training.png") 76 | 77 | 78 | plt.subplot(221) 79 | plt.imshow(X[150].reshape(28, 28)) 80 | plt.subplot(222) 81 | plt.imshow(model.predict(X[150]).reshape(28, 28)) 82 | plt.subplot(223) 83 | plt.imshow(X[250].reshape(28, 28)) 84 | plt.subplot(224) 85 | plt.imshow(model.predict(X[250]).reshape(28, 28)) 86 | -------------------------------------------------------------------------------- /examples/mnist_classif.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from sklearn.datasets import fetch_openml 3 | from sklearn.metrics import accuracy_score 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.preprocessing import LabelEncoder 6 | 7 | from lassonet import LassoNetClassifier 8 | 9 | X, y = fetch_openml(name="mnist_784", return_X_y=True) 10 | filter = y.isin(["5", "6"]) 11 | X = X[filter].values / 255 12 | y = LabelEncoder().fit_transform(y[filter]) 13 | 14 | X_train, X_test, y_train, y_test = train_test_split(X, y) 15 | 16 | model = LassoNetClassifier(M=30, verbose=True) 17 | path = model.path(X_train, y_train) 18 | 19 | img = model.feature_importances_.reshape(28, 28) 20 | 21 | plt.title("Feature importance to discriminate 5 and 6") 22 | plt.imshow(img) 23 | plt.colorbar() 24 | plt.savefig("mnist-classification-importance.png") 25 | 26 | n_selected = [] 27 | accuracy = [] 28 | lambda_ = [] 29 | 30 | for save in path: 31 | model.load(save.state_dict) 32 | y_pred = model.predict(X_test) 33 | n_selected.append(save.selected.sum()) 34 | accuracy.append(accuracy_score(y_test, y_pred)) 35 | lambda_.append(save.lambda_) 36 | 37 | to_plot = [160, 220, 300] 38 | 39 | for i, save in zip(n_selected, path): 40 | if not to_plot: 41 | break 42 | if i > to_plot[-1]: 43 | continue 44 | to_plot.pop() 45 | plt.clf() 46 | plt.title(f"Linear model with {i} features") 47 | weight = save.state_dict["skip.weight"] 48 | img = (weight[1] - weight[0]).reshape(28, 28) 49 
| plt.imshow(img) 50 | plt.colorbar() 51 | plt.savefig(f"mnist-classification-{i}.png") 52 | 53 | fig = plt.figure(figsize=(12, 12)) 54 | 55 | plt.subplot(311) 56 | plt.grid(True) 57 | plt.plot(n_selected, accuracy, ".-") 58 | plt.xlabel("number of selected features") 59 | plt.ylabel("classification accuracy") 60 | 61 | plt.subplot(312) 62 | plt.grid(True) 63 | plt.plot(lambda_, accuracy, ".-") 64 | plt.xlabel("lambda") 65 | plt.xscale("log") 66 | plt.ylabel("classification accuracy") 67 | 68 | plt.subplot(313) 69 | plt.grid(True) 70 | plt.plot(lambda_, n_selected, ".-") 71 | plt.xlabel("lambda") 72 | plt.xscale("log") 73 | plt.ylabel("number of selected features") 74 | 75 | plt.savefig("mnist-classification-training.png") 76 | -------------------------------------------------------------------------------- /examples/mnist_reconstruction.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from sklearn.datasets import fetch_openml 3 | from sklearn.metrics import mean_squared_error 4 | 5 | from lassonet import LassoNetRegressor 6 | 7 | X, y = fetch_openml(name="mnist_784", return_X_y=True) 8 | filter = y == "3" 9 | X = X[filter].values / 255 10 | 11 | model = LassoNetRegressor(M=30, n_iters=(3000, 500), path_multiplier=1.05, verbose=True) 12 | path = model.path(X, X) 13 | 14 | img = model.feature_importances_.reshape(28, 28) 15 | 16 | plt.title("Feature importance to reconstruct 3") 17 | plt.imshow(img) 18 | plt.colorbar() 19 | plt.savefig("mnist-reconstruction-importance.png") 20 | 21 | 22 | n_selected = [] 23 | score = [] 24 | lambda_ = [] 25 | 26 | for save in path: 27 | model.load(save.state_dict) 28 | X_pred = model.predict(X) 29 | n_selected.append(save.selected.sum()) 30 | score.append(mean_squared_error(X_pred, X)) 31 | lambda_.append(save.lambda_) 32 | 33 | to_plot = [160, 220, 300] 34 | 35 | for i, save in zip(n_selected, path): 36 | if not to_plot: 37 | break 38 | if i > to_plot[-1]: 
39 | continue 40 | to_plot.pop() 41 | plt.clf() 42 | plt.title(f"Linear model with {i} features") 43 | weight = save.state_dict["skip.weight"] 44 | img = (weight[1] - weight[0]).reshape(28, 28) 45 | plt.imshow(img) 46 | plt.colorbar() 47 | plt.savefig(f"mnist-reconstruction-{i}.png") 48 | 49 | plt.clf() 50 | 51 | fig = plt.figure(figsize=(12, 12)) 52 | 53 | plt.subplot(311) 54 | plt.grid(True) 55 | plt.plot(n_selected, score, ".-") 56 | plt.xlabel("number of selected features") 57 | plt.ylabel("MSE") 58 | 59 | plt.subplot(312) 60 | plt.grid(True) 61 | plt.plot(lambda_, score, ".-") 62 | plt.xlabel("lambda") 63 | plt.xscale("log") 64 | plt.ylabel("MSE") 65 | 66 | plt.subplot(313) 67 | plt.grid(True) 68 | plt.plot(lambda_, n_selected, ".-") 69 | plt.xlabel("lambda") 70 | plt.xscale("log") 71 | plt.ylabel("number of selected features") 72 | 73 | plt.savefig("mnist-reconstruction-training.png") 74 | 75 | 76 | plt.subplot(221) 77 | plt.imshow(X[150].reshape(28, 28)) 78 | plt.subplot(222) 79 | plt.imshow(model.predict(X[150]).reshape(28, 28)) 80 | plt.subplot(223) 81 | plt.imshow(X[250].reshape(28, 28)) 82 | plt.subplot(224) 83 | plt.imshow(model.predict(X[250]).reshape(28, 28)) 84 | -------------------------------------------------------------------------------- /experiments/README.MD: -------------------------------------------------------------------------------- 1 | - The data to reproduce Table 1 are available in [this Google Drive repository](https://drive.google.com/open?id=1quiURu7w0nU3Pxcc448xRgfeI80okLsS). Alternatively, you can download all the datasets (except MNIST and MNIST-Fashion) directly from the [UCI Repository](https://archive.ics.uci.edu/ml/datasets.php). 2 | - `data_utils.py` contains starter code to load the 6 datasets in Table 1 of [the paper](https://arxiv.org/abs/1907.12207). 
3 | - You will need to download the files, unzip them, and change the `/home/lemisma/datasets` paths in `data_utils.py` to point to your local copy.
4 | - `run.py` contains the necessary code to reproduce the results in Table 1. This script allows you to specify the dataset and other parameters. To run the script, navigate to the directory containing `run.py` and use the following command: `python run.py`. 5 | 6 | -------------------------------------------------------------------------------- /experiments/data_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import defaultdict 3 | from os.path import join 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.preprocessing import MinMaxScaler 9 | 10 | # The code to load some of these datasets is reproduced from 11 | # https://github.com/mfbalin/Concrete-Autoencoders/blob/master/experiments/generate_comparison_figures.py 12 | 13 | 14 | def load_mice(one_hot=False): 15 | filling_value = -100000 16 | X = np.genfromtxt( 17 | "/home/lemisma/datasets/MICE/Data_Cortex_Nuclear.csv", 18 | delimiter=",", 19 | skip_header=1, 20 | usecols=range(1, 78), 21 | filling_values=filling_value, 22 | encoding="UTF-8", 23 | ) 24 | classes = np.genfromtxt( 25 | "/home/lemisma/datasets/MICE/Data_Cortex_Nuclear.csv", 26 | delimiter=",", 27 | skip_header=1, 28 | usecols=range(78, 81), 29 | dtype=None, 30 | encoding="UTF-8", 31 | ) 32 | 33 | for i, row in enumerate(X): 34 | for j, val in enumerate(row): 35 | if val == filling_value: 36 | X[i, j] = np.mean( 37 | [ 38 | X[k, j] 39 | for k in range(classes.shape[0]) 40 | if np.all(classes[i] == classes[k]) 41 | ] 42 | ) 43 | 44 | DY = np.zeros((classes.shape[0]), dtype=np.uint8) 45 | for i, row in enumerate(classes): 46 | for j, (val, label) in enumerate(zip(row, ["Control", "Memantine", "C/S"])): 47 | DY[i] += (2**j) * (val == label) 48 | 49 | Y = np.zeros((DY.shape[0], np.unique(DY).shape[0])) 50 | for idx, val in enumerate(DY): 51 | Y[idx, val] = 1 52 | 53 | X = 
MinMaxScaler(feature_range=(0, 1)).fit_transform(X) 54 | 55 | indices = np.arange(X.shape[0]) 56 | np.random.shuffle(indices) 57 | X = X[indices] 58 | Y = Y[indices] 59 | DY = DY[indices] 60 | classes = classes[indices] 61 | 62 | if not one_hot: 63 | Y = DY 64 | 65 | X = X.astype(np.float32) 66 | Y = Y.astype(np.float32) 67 | 68 | print("X shape: {}, Y shape: {}".format(X.shape, Y.shape)) 69 | 70 | return (X[: X.shape[0] * 4 // 5], Y[: X.shape[0] * 4 // 5]), ( 71 | X[X.shape[0] * 4 // 5 :], 72 | Y[X.shape[0] * 4 // 5 :], 73 | ) 74 | 75 | 76 | def load_isolet(): 77 | x_train = np.genfromtxt( 78 | "/home/lemisma/datasets/isolet/isolet1+2+3+4.data", 79 | delimiter=",", 80 | usecols=range(0, 617), 81 | encoding="UTF-8", 82 | ) 83 | y_train = np.genfromtxt( 84 | "/home/lemisma/datasets/isolet/isolet1+2+3+4.data", 85 | delimiter=",", 86 | usecols=[617], 87 | encoding="UTF-8", 88 | ) 89 | x_test = np.genfromtxt( 90 | "/home/lemisma/datasets/isolet/isolet5.data", 91 | delimiter=",", 92 | usecols=range(0, 617), 93 | encoding="UTF-8", 94 | ) 95 | y_test = np.genfromtxt( 96 | "/home/lemisma/datasets/isolet/isolet5.data", 97 | delimiter=",", 98 | usecols=[617], 99 | encoding="UTF-8", 100 | ) 101 | 102 | X = MinMaxScaler(feature_range=(0, 1)).fit_transform( 103 | np.concatenate((x_train, x_test)) 104 | ) 105 | x_train = X[: len(y_train)] 106 | x_test = X[len(y_train) :] 107 | 108 | print(x_train.shape, y_train.shape) 109 | print(x_test.shape, y_test.shape) 110 | 111 | return (x_train, y_train - 1), (x_test, y_test - 1) 112 | 113 | 114 | import numpy as np 115 | 116 | 117 | def load_epileptic(): 118 | filling_value = -100000 119 | 120 | X = np.genfromtxt( 121 | "/home/lemisma/datasets/data.csv", 122 | delimiter=",", 123 | skip_header=1, 124 | usecols=range(1, 179), 125 | filling_values=filling_value, 126 | encoding="UTF-8", 127 | ) 128 | Y = np.genfromtxt( 129 | "/home/lemisma/datasets/data.csv", 130 | delimiter=",", 131 | skip_header=1, 132 | usecols=range(179, 180), 133 |
encoding="UTF-8", 134 | ) 135 | 136 | X = MinMaxScaler(feature_range=(0, 1)).fit_transform(X) 137 | 138 | indices = np.arange(X.shape[0]) 139 | np.random.shuffle(indices) 140 | X = X[indices] 141 | Y = Y[indices] 142 | 143 | print(X.shape, Y.shape) 144 | 145 | return (X[:8000], Y[:8000]), (X[8000:], Y[8000:]) 146 | 147 | 148 | import os 149 | 150 | from PIL import Image 151 | 152 | 153 | def load_coil(): 154 | samples = [] 155 | for i in range(1, 21): 156 | for image_index in range(72): 157 | obj_img = Image.open( 158 | os.path.join( 159 | "/home/lemisma/datasets/coil-20-proc", 160 | "obj%d__%d.png" % (i, image_index), 161 | ) 162 | ) 163 | rescaled = obj_img.resize((20, 20)) 164 | pixels_values = [float(x) for x in list(rescaled.getdata())] 165 | sample = np.array(pixels_values + [i]) 166 | samples.append(sample) 167 | samples = np.array(samples) 168 | np.random.shuffle(samples) 169 | data = samples[:, :-1] 170 | targets = (samples[:, -1] + 0.5).astype(np.int64) 171 | data = (data - data.min()) / (data.max() - data.min()) 172 | 173 | l = data.shape[0] * 4 // 5 174 | train = (data[:l], targets[:l] - 1) 175 | test = (data[l:], targets[l:] - 1) 176 | print(train[0].shape, train[1].shape) 177 | print(test[0].shape, test[1].shape) 178 | return train, test 179 | 180 | 181 | import tensorflow as tf 182 | from sklearn.model_selection import train_test_split 183 | 184 | 185 | def load_data(fashion=False, digit=None, normalize=False): 186 | if fashion: 187 | (x_train, y_train), (x_test, y_test) = ( 188 | tf.keras.datasets.fashion_mnist.load_data() 189 | ) 190 | else: 191 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() 192 | 193 | if digit is not None and 0 <= digit <= 9: 194 | train, test = {y: [] for y in range(10)}, {y: [] for y in range(10)} 195 | for x, y in zip(x_train, y_train): 196 | train[y].append(x) 197 | for x, y in zip(x_test, y_test): 198 | test[y].append(x) 199 | 200 | for y in range(10): 201 | 202 | train[y] = np.asarray(train[y]) 203 | test[y] = 
np.asarray(test[y]) 204 | 205 | x_train = train[digit] 206 | x_test = test[digit] 207 | 208 | x_train = x_train.reshape((-1, x_train.shape[1] * x_train.shape[2])).astype( 209 | np.float32 210 | ) 211 | x_test = x_test.reshape((-1, x_test.shape[1] * x_test.shape[2])).astype(np.float32) 212 | 213 | if normalize: 214 | X = np.concatenate((x_train, x_test)) 215 | X = (X - X.min()) / (X.max() - X.min()) 216 | x_train = X[: len(y_train)] 217 | x_test = X[len(y_train) :] 218 | 219 | # print(x_train.shape, y_train.shape) 220 | # print(x_test.shape, y_test.shape) 221 | return (x_train, y_train), (x_test, y_test) 222 | 223 | 224 | def load_mnist(): 225 | train, test = load_data(fashion=False, normalize=True) 226 | x_train, x_test, y_train, y_test = train_test_split(test[0], test[1], test_size=0.2) 227 | return (x_train, y_train), (x_test, y_test) 228 | 229 | 230 | def load_fashion(): 231 | train, test = load_data(fashion=True, normalize=True) 232 | x_train, x_test, y_train, y_test = train_test_split(test[0], test[1], test_size=0.2) 233 | return (x_train, y_train), (x_test, y_test) 234 | 235 | 236 | def load_mnist_two_digits(digit1, digit2): 237 | train_digit_1, _ = load_data(digit=digit1) 238 | train_digit_2, _ = load_data(digit=digit2) 239 | 240 | X_train_1, X_test_1 = train_test_split(train_digit_1[0], test_size=0.6) 241 | X_train_2, X_test_2 = train_test_split(train_digit_2[0], test_size=0.6) 242 | 243 | X_train = np.concatenate((X_train_1, X_train_2)) 244 | y_train = np.array([0] * X_train_1.shape[0] + [1] * X_train_2.shape[0]) 245 | shuffled_idx = np.random.permutation(X_train.shape[0]) 246 | np.take(X_train, shuffled_idx, axis=0, out=X_train) 247 | np.take(y_train, shuffled_idx, axis=0, out=y_train) 248 | 249 | X_test = np.concatenate((X_test_1, X_test_2)) 250 | y_test = np.array([0] * X_test_1.shape[0] + [1] * X_test_2.shape[0]) 251 | shuffled_idx = np.random.permutation(X_test.shape[0]) 252 | np.take(X_test, shuffled_idx, axis=0, out=X_test) 253 | np.take(y_test, 
shuffled_idx, axis=0, out=y_test) 254 | 255 | print(X_train.shape, y_train.shape) 256 | print(X_test.shape, y_test.shape) 257 | 258 | return (X_train, y_train), (X_test, y_test) 259 | 260 | 261 | import os 262 | 263 | from sklearn.preprocessing import MinMaxScaler 264 | 265 | 266 | def load_activity(): 267 | x_train = np.loadtxt( 268 | os.path.join("/home/lemisma/datasets/dataset_uci", "final_X_train.txt"), 269 | delimiter=",", 270 | encoding="UTF-8", 271 | ) 272 | x_test = np.loadtxt( 273 | os.path.join("/home/lemisma/datasets/dataset_uci", "final_X_test.txt"), 274 | delimiter=",", 275 | encoding="UTF-8", 276 | ) 277 | y_train = ( 278 | np.loadtxt( 279 | os.path.join("/home/lemisma/datasets/dataset_uci", "final_y_train.txt"), 280 | delimiter=",", 281 | encoding="UTF-8", 282 | ) 283 | - 1 284 | ) 285 | y_test = ( 286 | np.loadtxt( 287 | os.path.join("/home/lemisma/datasets/dataset_uci", "final_y_test.txt"), 288 | delimiter=",", 289 | encoding="UTF-8", 290 | ) 291 | - 1 292 | ) 293 | 294 | X = MinMaxScaler(feature_range=(0, 1)).fit_transform( 295 | np.concatenate((x_train, x_test)) 296 | ) 297 | x_train = X[: len(y_train)] 298 | x_test = X[len(y_train) :] 299 | 300 | print(x_train.shape, y_train.shape) 301 | print(x_test.shape, y_test.shape) 302 | return (x_train, y_train), (x_test, y_test) 303 | 304 | 305 | def load_dataset(dataset): 306 | if dataset == "MNIST": 307 | return load_mnist() 308 | elif dataset == "MNIST-Fashion": 309 | return load_fashion() 310 | elif dataset == "MICE": 311 | return load_mice() 312 | elif dataset == "COIL": 313 | return load_coil() 314 | elif dataset == "ISOLET": 315 | return load_isolet() 316 | elif dataset == "Activity": 317 | return load_activity() 318 | else: 319 | raise ValueError(f"Please specify a valid dataset, got {dataset!r}") 320 | -------------------------------------------------------------------------------- /experiments/run.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 
import torch 4 | from data_utils import load_dataset 5 | from sklearn.model_selection import train_test_split 6 | 7 | from lassonet import LassoNetClassifier 8 | from lassonet.utils import eval_on_path 9 | 10 | seed = None 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | batch_size = 256 13 | K = 50 # Number of features to select 14 | n_epochs = 1000 15 | dataset = "ISOLET" 16 | 17 | # Load dataset and split the data 18 | (X_train_valid, y_train_valid), (X_test, y_test) = load_dataset(dataset) 19 | X_train, X_val, y_train, y_val = train_test_split( 20 | X_train_valid, y_train_valid, test_size=0.125, random_state=seed 21 | ) 22 | 23 | # Set the dimensions of the hidden layers 24 | data_dim = X_test.shape[1] 25 | hidden_dim = (data_dim // 3,) 26 | 27 | # Initialize the LassoNetClassifier model and compute the path 28 | lasso_model = LassoNetClassifier( 29 | M=10, 30 | hidden_dims=hidden_dim, 31 | verbose=1, 32 | torch_seed=seed, 33 | random_state=seed, 34 | device=device, 35 | n_iters=n_epochs, 36 | batch_size=batch_size, 37 | ) 38 | path = lasso_model.path(X_train, y_train, X_val=X_val, y_val=y_val) 39 | 40 | # Select the features 41 | desired_save = next(save for save in path if save.selected.sum().item() <= K) 42 | SELECTED_FEATURES = desired_save.selected 43 | print("Number of selected features:", SELECTED_FEATURES.sum().item()) 44 | 45 | # Select the features from the training, validation, and test data 46 | X_train_selected = X_train[:, SELECTED_FEATURES] 47 | X_val_selected = X_val[:, SELECTED_FEATURES] 48 | X_test_selected = X_test[:, SELECTED_FEATURES] 49 | 50 | # Initialize another LassoNetClassifier for retraining with the selected features 51 | lasso_sparse = LassoNetClassifier( 52 | M=10, 53 | hidden_dims=hidden_dim, 54 | verbose=1, 55 | torch_seed=seed, 56 | random_state=seed, 57 | device=device, 58 | n_iters=n_epochs, 59 | ) 60 | path_sparse = lasso_sparse.path( 61 | X_train_selected, 62 | y_train, 63 | X_val=X_val_selected, 64 | 
y_val=y_val, 65 | lambda_seq=[0], 66 | return_state_dicts=True, 67 | )[:1] 68 | 69 | # Evaluate the model on the test data 70 | score = eval_on_path(lasso_sparse, path_sparse, X_test_selected, y_test) 71 | print("Test accuracy:", score) 72 | 73 | # Save the path 74 | with open(f"{dataset}_path.pkl", "wb") as f: 75 | pickle.dump(path_sparse, f) 76 | -------------------------------------------------------------------------------- /lassonet/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .interfaces import ( 3 | LassoNetClassifier, 4 | LassoNetClassifierCV, 5 | LassoNetCoxRegressor, 6 | LassoNetCoxRegressorCV, 7 | LassoNetRegressor, 8 | LassoNetRegressorCV, 9 | lassonet_path, 10 | ) 11 | from .model import LassoNet 12 | from .plot import plot_cv, plot_path 13 | from .prox import prox 14 | -------------------------------------------------------------------------------- /lassonet/cox.py: -------------------------------------------------------------------------------- 1 | """ 2 | implement CoxPHLoss 3 | """ 4 | 5 | __all__ = ["CoxPHLoss", "concordance_index"] 6 | 7 | import torch 8 | from sortedcontainers import SortedList 9 | 10 | from .utils import log_substract, scatter_logsumexp 11 | 12 | 13 | class CoxPHLoss(torch.nn.Module): 14 | """Loss for CoxPH model.""" 15 | 16 | allowed = ("breslow", "efron") 17 | 18 | def __init__(self, method): 19 | super().__init__() 20 | assert method in self.allowed, f"Method must be one of {self.allowed}" 21 | self.method = method 22 | 23 | def forward(self, log_h, y): 24 | log_h = log_h.flatten() 25 | 26 | durations, events = y.T 27 | 28 | # sort input 29 | durations, idx = durations.sort(descending=True) 30 | log_h = log_h[idx] 31 | events = events[idx] 32 | 33 | event_ind = events.nonzero().flatten() 34 | if event_ind.nelement() == 0: 35 | # return 0 while connecting the gradient 36 | return log_h[:0].sum() 37 | 38 | # numerator 39 | log_num = 
log_h[event_ind].mean() 40 | 41 | # logcumsumexp of events 42 | event_lcse = torch.logcumsumexp(log_h, dim=0)[event_ind] 43 | 44 | # number of events for each unique risk set 45 | _, tie_inverses, tie_count = torch.unique_consecutive( 46 | durations[event_ind], return_counts=True, return_inverse=True 47 | ) 48 | 49 | # position of last event (lowest duration) of each unique risk set 50 | tie_pos = tie_count.cumsum(axis=0) - 1 51 | 52 | # logcumsumexp by tie for each event 53 | event_tie_lcse = event_lcse[tie_pos][tie_inverses] 54 | 55 | if self.method == "breslow": 56 | log_den = event_tie_lcse.mean() 57 | 58 | elif self.method == "efron": 59 | # based on https://bydmitry.github.io/efron-tensorflow.html 60 | 61 | # logsumexp of ties, duplicated within tie set 62 | tie_lse = scatter_logsumexp(log_h[event_ind], tie_inverses, dim=0)[ 63 | tie_inverses 64 | ] 65 | # multiply (add in log space) with corrective factor 66 | aux = torch.ones_like(tie_inverses) 67 | aux[tie_pos[:-1] + 1] -= tie_count[:-1] 68 | event_id_in_tie = torch.cumsum(aux, dim=0) - 1 69 | discounted_tie_lse = ( 70 | tie_lse 71 | + torch.log(event_id_in_tie) 72 | - torch.log(tie_count[tie_inverses]) 73 | ) 74 | 75 | # denominator 76 | log_den = log_substract(event_tie_lcse, discounted_tie_lse).mean() 77 | 78 | # loss is negative log likelihood 79 | return log_den - log_num 80 | 81 | 82 | def concordance_index(risk, time, event): 83 | """ 84 | O(n log n) implementation of https://square.github.io/pysurvival/metrics/c_index.html 85 | """ 86 | assert len(risk) == len(time) == len(event) 87 | n = len(risk) 88 | order = sorted(range(n), key=time.__getitem__) 89 | past = SortedList() 90 | num = 0 91 | den = 0 92 | for i in order: 93 | num += len(past) - past.bisect_right(risk[i]) 94 | den += len(past) 95 | if event[i]: 96 | past.add(risk[i]) 97 | return num / den 98 | -------------------------------------------------------------------------------- /lassonet/model.py: 
-------------------------------------------------------------------------------- 1 | from itertools import islice 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .prox import inplace_group_prox, inplace_prox, prox 8 | 9 | 10 | class LassoNet(nn.Module): 11 | def __init__(self, *dims, groups=None, dropout=None): 12 | """ 13 | first dimension is input 14 | last dimension is output 15 | `groups` is a list of list such that `groups[i]` 16 | contains the indices of the features in the i-th group 17 | 18 | """ 19 | assert len(dims) > 2 20 | if groups is not None: 21 | n_inputs = dims[0] 22 | all_indices = [] 23 | for g in groups: 24 | for i in g: 25 | all_indices.append(i) 26 | assert len(all_indices) == n_inputs and set(all_indices) == set( 27 | range(n_inputs) 28 | ), f"Groups must be a partition of range(n_inputs={n_inputs})" 29 | 30 | self.groups = groups 31 | 32 | super().__init__() 33 | 34 | self.dropout = nn.Dropout(p=dropout) if dropout is not None else None 35 | self.layers = nn.ModuleList( 36 | [nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)] 37 | ) 38 | self.skip = nn.Linear(dims[0], dims[-1], bias=False) 39 | 40 | def forward(self, inp): 41 | current_layer = inp 42 | result = self.skip(inp) 43 | for theta in self.layers: 44 | current_layer = theta(current_layer) 45 | if theta is not self.layers[-1]: 46 | if self.dropout is not None: 47 | current_layer = self.dropout(current_layer) 48 | current_layer = F.relu(current_layer) 49 | return result + current_layer 50 | 51 | def prox(self, *, lambda_, lambda_bar=0, M=1): 52 | if self.groups is None: 53 | with torch.no_grad(): 54 | inplace_prox( 55 | beta=self.skip, 56 | theta=self.layers[0], 57 | lambda_=lambda_, 58 | lambda_bar=lambda_bar, 59 | M=M, 60 | ) 61 | else: 62 | with torch.no_grad(): 63 | inplace_group_prox( 64 | groups=self.groups, 65 | beta=self.skip, 66 | theta=self.layers[0], 67 | lambda_=lambda_, 68 | lambda_bar=lambda_bar, 69 | M=M, 
70 | ) 71 | 72 | def lambda_start( 73 | self, 74 | M=1, 75 | lambda_bar=0, 76 | factor=2, 77 | ): 78 | """Estimate when the model will start to sparsify.""" 79 | 80 | def is_sparse(lambda_): 81 | with torch.no_grad(): 82 | beta = self.skip.weight.data 83 | theta = self.layers[0].weight.data 84 | 85 | for _ in range(10000): 86 | new_beta, theta = prox( 87 | beta, 88 | theta, 89 | lambda_=lambda_, 90 | lambda_bar=lambda_bar, 91 | M=M, 92 | ) 93 | if torch.abs(beta - new_beta).max() < 1e-5: 94 | break 95 | beta = new_beta 96 | return (torch.norm(beta, p=2, dim=0) == 0).sum() 97 | 98 | start = 1e-6 99 | while not is_sparse(factor * start): 100 | start *= factor 101 | return start 102 | 103 | def l2_regularization(self): 104 | """ 105 | L2 regularization of the MLP without the first layer, 106 | which is bounded by the skip connection 107 | """ 108 | ans = 0 109 | for layer in islice(self.layers, 1, None): 110 | ans += ( 111 | torch.norm( 112 | layer.weight, # not .data, so that the penalty has a gradient 113 | p=2, 114 | ) 115 | ** 2 116 | ) 117 | return ans 118 | 119 | def l1_regularization_skip(self): 120 | return torch.norm(self.skip.weight.data, p=2, dim=0).sum() 121 | 122 | def l2_regularization_skip(self): 123 | return torch.norm(self.skip.weight, p=2) 124 | 125 | def input_mask(self): 126 | with torch.no_grad(): 127 | return torch.norm(self.skip.weight.data, p=2, dim=0) != 0 128 | 129 | def selected_count(self): 130 | return self.input_mask().sum().item() 131 | 132 | def cpu_state_dict(self): 133 | return {k: v.detach().clone().cpu() for k, v in self.state_dict().items()} 134 | -------------------------------------------------------------------------------- /lassonet/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | from .interfaces import BaseLassoNetCV 4 | from .utils import confidence_interval, eval_on_path 5 | 6 | 7 | def plot_path(model, X_test, y_test, *, score_function=None): 8 | """ 9 | Plot the evolution of
the model on the path, namely: 10 | - lambda 11 | - number of selected variables 12 | - score 13 | 14 | Requires model.path(return_state_dicts=True) to have been called beforehand. 15 | 16 | 17 | Parameters 18 | ========== 19 | model : LassoNetClassifier or LassoNetRegressor 20 | X_test : array-like 21 | y_test : array-like 22 | score_function : function or None 23 | if None, use score_function=model.score 24 | otherwise, score_function is called as score_function(y_test, model.predict(X_test)), like a scikit-learn metric 25 | """ 26 | # TODO: plot with manually computed score 27 | score = eval_on_path( 28 | model, model.path_, X_test, y_test, score_function=score_function 29 | ) 30 | n_selected = [save.selected.sum() for save in model.path_] 31 | lambda_ = [save.lambda_ for save in model.path_] 32 | 33 | plt.figure(figsize=(16, 16)) 34 | 35 | plt.subplot(311) 36 | plt.grid(True) 37 | plt.plot(n_selected, score, ".-") 38 | plt.xlabel("number of selected features") 39 | plt.ylabel("score") 40 | 41 | plt.subplot(312) 42 | plt.grid(True) 43 | plt.plot(lambda_, score, ".-") 44 | plt.xlabel("lambda") 45 | plt.xscale("log") 46 | plt.ylabel("score") 47 | 48 | plt.subplot(313) 49 | plt.grid(True) 50 | plt.plot(lambda_, n_selected, ".-") 51 | plt.xlabel("lambda") 52 | plt.xscale("log") 53 | plt.ylabel("number of selected features") 54 | 55 | plt.tight_layout() 56 | 57 | 58 | def plot_cv(model: BaseLassoNetCV, X_test, y_test, *, score_function=None): 59 | # TODO: plot with manually computed score 60 | lambda_ = [save.lambda_ for save in model.path_] 61 | lambdas = [[h.lambda_ for h in p] for p in model.raw_paths_] 62 | 63 | score = eval_on_path( 64 | model, model.path_, X_test, y_test, score_function=score_function 65 | ) 66 | 67 | plt.figure(figsize=(16, 16)) 68 | 69 | plt.subplot(211) 70 | plt.grid(True) 71 | first = True 72 | for sl, ss in zip(lambdas, model.raw_scores_): 73 | plt.plot( 74 | sl, 75 | ss, 76 | "r.-", 77 | markersize=5, 78 | alpha=0.2, 79 | label="cross-validation" if first else None, 80 | ) 81 | first = False 82 | avg = 
model.interp_scores_.mean(axis=1) 83 | ci = confidence_interval(model.interp_scores_) 84 | plt.plot( 85 | model.lambdas_, 86 | avg, 87 | "g.-", 88 | markersize=5, 89 | alpha=0.2, 90 | label="average cv with 95% CI", 91 | ) 92 | plt.fill_between(model.lambdas_, avg - ci, avg + ci, color="g", alpha=0.1) 93 | plt.plot(lambda_, score, "b.-", markersize=5, alpha=0.2, label="test") 94 | plt.legend() 95 | plt.xlabel("lambda") 96 | plt.xscale("log") 97 | plt.ylabel("score") 98 | 99 | plt.subplot(212) 100 | plt.grid(True) 101 | first = True 102 | for sl, path in zip(lambdas, model.raw_paths_): 103 | plt.plot( 104 | sl, 105 | [save.selected.sum() for save in path], 106 | "r.-", 107 | markersize=5, 108 | alpha=0.2, 109 | label="cross-validation" if first else None, 110 | ) 111 | first = False 112 | plt.plot( 113 | lambda_, 114 | [save.selected.sum() for save in model.path_], 115 | "b.-", 116 | markersize=5, 117 | alpha=0.2, 118 | label="test", 119 | ) 120 | plt.legend() 121 | plt.xlabel("lambda") 122 | plt.xscale("log") 123 | plt.ylabel("number of selected features") 124 | 125 | plt.tight_layout() 126 | -------------------------------------------------------------------------------- /lassonet/prox.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 5 | def soft_threshold(l, x): 6 | return torch.sign(x) * torch.relu(torch.abs(x) - l) 7 | 8 | 9 | def sign_binary(x): 10 | ones = torch.ones_like(x) 11 | return torch.where(x >= 0, ones, -ones) 12 | 13 | 14 | def prox(v, u, *, lambda_, lambda_bar, M): 15 | """ 16 | v has shape (m,) or (m, batches) 17 | u has shape (k,) or (k, batches) 18 | 19 | supports GPU tensors 20 | """ 21 | onedim = len(v.shape) == 1 22 | if onedim: 23 | v = v.unsqueeze(-1) 24 | u = u.unsqueeze(-1) 25 | 26 | u_abs_sorted = torch.sort(u.abs(), dim=0, descending=True).values 27 | 28 | k, batch = u.shape 29 | 30 | s = torch.arange(k + 1.0).view(-1, 1).to(v) 31 | 
zeros = torch.zeros(1, batch).to(u) 32 | 33 | a_s = lambda_ - M * torch.cat( 34 | [zeros, torch.cumsum(u_abs_sorted - lambda_bar, dim=0)] 35 | ) 36 | 37 | norm_v = torch.norm(v, p=2, dim=0) 38 | 39 | x = F.relu(1 - a_s / norm_v) / (1 + s * M**2) 40 | 41 | w = M * x * norm_v 42 | intervals = soft_threshold(lambda_bar, u_abs_sorted) 43 | lower = torch.cat([intervals, zeros]) 44 | 45 | idx = torch.sum(lower > w, dim=0).unsqueeze(0) 46 | 47 | x_star = torch.gather(x, 0, idx).view(1, batch) 48 | w_star = torch.gather(w, 0, idx).view(1, batch) 49 | 50 | beta_star = x_star * v 51 | theta_star = sign_binary(u) * torch.min(soft_threshold(lambda_bar, u.abs()), w_star) 52 | 53 | if onedim: 54 | beta_star.squeeze_(-1) 55 | theta_star.squeeze_(-1) 56 | 57 | return beta_star, theta_star 58 | 59 | 60 | def inplace_prox(beta, theta, lambda_, lambda_bar, M): 61 | beta.weight.data, theta.weight.data = prox( 62 | beta.weight.data, theta.weight.data, lambda_=lambda_, lambda_bar=lambda_bar, M=M 63 | ) 64 | 65 | 66 | def inplace_group_prox(groups, beta, theta, lambda_, lambda_bar, M): 67 | """ 68 | groups is an iterable such that group[i] contains the indices of features in group i 69 | """ 70 | beta_ = beta.weight.data 71 | theta_ = theta.weight.data 72 | beta_ans = torch.empty_like(beta_) 73 | theta_ans = torch.empty_like(theta_) 74 | for g in groups: 75 | group_beta = beta_[:, g] 76 | group_beta_shape = group_beta.shape 77 | group_theta = theta_[:, g] 78 | group_theta_shape = group_theta.shape 79 | group_beta, group_theta = prox( 80 | group_beta.reshape(-1), 81 | group_theta.reshape(-1), 82 | lambda_=lambda_, 83 | lambda_bar=lambda_bar, 84 | M=M, 85 | ) 86 | beta_ans[:, g] = group_beta.reshape(*group_beta_shape) 87 | theta_ans[:, g] = group_theta.reshape(*group_theta_shape) 88 | beta.weight.data, theta.weight.data = beta_ans, theta_ans 89 | -------------------------------------------------------------------------------- /lassonet/r.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .interfaces import LassoNetClassifier, LassoNetRegressor 7 | from .interfaces import lassonet_path as _lassonet_path 8 | 9 | 10 | def make_writable(x): 11 | if isinstance(x, np.ndarray): 12 | x = x.copy() 13 | return x 14 | 15 | 16 | def lassonet_path(X, y, task, *args, **kwargs): 17 | X = make_writable(X) 18 | y = make_writable(y) 19 | 20 | def convert_item(item): 21 | item = asdict(item) 22 | item["state_dict"] = {k: v.numpy() for k, v in item["state_dict"].items()} 23 | item["selected"] = item["selected"].numpy() 24 | return item 25 | 26 | return list(map(convert_item, _lassonet_path(X, y, task, *args, **kwargs))) 27 | 28 | 29 | def lassonet_eval(X, task, state_dict, **kwargs): 30 | X = make_writable(X) 31 | 32 | if task == "classification": 33 | model = LassoNetClassifier(**kwargs) 34 | elif task == "regression": 35 | model = LassoNetRegressor(**kwargs) 36 | else: 37 | raise ValueError('task must be "classification" or "regression"') 38 | state_dict = {k: torch.tensor(v) for k, v in state_dict.items()} 39 | model.load(state_dict) 40 | if hasattr(model, "predict_proba"): 41 | return model.predict_proba(X) 42 | else: 43 | return model.predict(X) 44 | -------------------------------------------------------------------------------- /lassonet/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | from typing import TYPE_CHECKING, Iterable, List 3 | 4 | import scipy.stats 5 | import torch 6 | 7 | if TYPE_CHECKING: 8 | from lassonet.interfaces import HistoryItem 9 | 10 | 11 | def eval_on_path(model, path, X_test, y_test, *, score_function=None): 12 | if score_function is None: 13 | score_fun = model.score 14 | else: 15 | assert callable(score_function) 16 | 17 | def score_fun(X_test, y_test): 18 | return 
score_function(y_test, model.predict(X_test)) 19 | 20 | score = [] 21 | for save in path: 22 | model.load(save.state_dict) 23 | score.append(score_fun(X_test, y_test)) 24 | return score 25 | 26 | 27 | if hasattr(torch.Tensor, "scatter_reduce_"): 28 | # version >= 1.12 29 | def scatter_reduce(input, dim, index, reduce, *, output_size=None): 30 | src = input 31 | if output_size is None: 32 | output_size = index.max() + 1 33 | return torch.empty(output_size, dtype=input.dtype, device=input.device).scatter_reduce( 34 | dim=dim, index=index, src=src, reduce=reduce, include_self=False 35 | ) 36 | 37 | else: 38 | scatter_reduce = torch.scatter_reduce 39 | 40 | 41 | def scatter_logsumexp(input, index, *, dim=-1, output_size=None): 42 | """Inspired by torch_scatter.logsumexp 43 | Uses torch.scatter_reduce for performance 44 | """ 45 | max_value_per_index = scatter_reduce( 46 | input, dim=dim, index=index, output_size=output_size, reduce="amax" 47 | ) 48 | max_per_src_element = max_value_per_index.gather(dim, index) 49 | recentered_scores = input - max_per_src_element 50 | sum_per_index = scatter_reduce( 51 | recentered_scores.exp(), 52 | dim=dim, 53 | index=index, 54 | output_size=output_size, 55 | reduce="sum", 56 | ) 57 | return max_value_per_index + sum_per_index.log() 58 | 59 | 60 | def log_substract(x, y): 61 | """log(exp(x) - exp(y))""" 62 | return x + torch.log1p(-(y - x).exp()) 63 | 64 | 65 | def confidence_interval(data, confidence=0.95): 66 | if isinstance(data[0], Iterable): 67 | return [confidence_interval(d, confidence) for d in data] 68 | return scipy.stats.t.interval( 69 | confidence, 70 | len(data) - 1, 71 | scale=scipy.stats.sem(data), 72 | )[1] 73 | 74 | 75 | def selection_probability(paths: List[List["HistoryItem"]]): 76 | """Compute the selection probability of each feature at each step. 77 | The individual curves are smoothed so that they are monotonically decreasing.
78 | 79 | Input 80 | ----- 81 | paths: List of List of HistoryItem 82 | The lambda paths must be the same for all models. 83 | 84 | Output 85 | ------ 86 | prob: torch.Tensor 87 | Tensor of shape (n_steps, n_features) containing the selection probability 88 | of each feature at each lambda value. 89 | expected_wrong: tuple of (Tensor, LongTensor) 90 | Expected number of wrong features. 91 | (values, indices) where values are the expected number of wrong features 92 | and indices are the order of the selected features. 93 | """ 94 | n_models = len(paths) 95 | 96 | prob = [] 97 | selected = torch.ones_like(paths[0][0].selected) 98 | iterable = zip_longest( 99 | *[[it.selected for it in path] for path in paths], 100 | fillvalue=torch.zeros_like(paths[0][0].selected), 101 | ) 102 | for its in iterable: 103 | sel = sum(its) / n_models 104 | selected = torch.minimum(selected, sel) 105 | prob.append(selected) 106 | prob = torch.stack(prob) 107 | 108 | expected_wrong = ( 109 | prob.shape[1] * (prob.mean(dim=1, keepdim=True)) ** 2 / (2 * prob - 1) 110 | ) 111 | expected_wrong[prob <= 0.5] = float("inf") 112 | return prob, expected_wrong.min(axis=0).values.sort() 113 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import setup 4 | 5 | 6 | def read(fname): 7 | return (Path(__file__).parent / fname).read_text(encoding="utf-8") 8 | 9 | 10 | setup( 11 | name="lassonet", 12 | version="0.0.20", 13 | author="Louis Abraham, Ismael Lemhadri", 14 | author_email="louis.abraham@yahoo.fr, lemhadri@stanford.edu", 15 | license="MIT", 16 | description="Reference implementation of LassoNet", 17 | long_description=read("README.md"), 18 | long_description_content_type="text/markdown", 19 | url="https://github.com/lasso-net/lassonet", 20 | classifiers=[ 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: MIT 
License", 23 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 24 | "Operating System :: OS Independent", 25 | ], 26 | packages=["lassonet"], 27 | install_requires=[ 28 | "torch >= 1.11", 29 | "scikit-learn", 30 | "matplotlib", 31 | "sortedcontainers", 32 | "tqdm", 33 | ], 34 | extras_require={ 35 | "dev": [ 36 | "pre-commit", 37 | "sphinx", 38 | "black", 39 | ] 40 | }, 41 | tests_require=["pytest"], 42 | python_requires=">=3.8", 43 | ) 44 | -------------------------------------------------------------------------------- /sphinx_docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /sphinx_docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("..")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "LassoNet" 22 | copyright = "2021, Louis Abraham, Ismael Lemhadri" 23 | author = "Louis Abraham, Ismael Lemhadri" 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx.ext.autosummary"] 32 | 33 | # Add any paths that contain templates here, relative to this directory. 34 | templates_path = ["_templates"] 35 | 36 | # List of patterns, relative to source directory, that match files and 37 | # directories to ignore when looking for source files. 38 | # This pattern also affects html_static_path and html_extra_path. 39 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 40 | 41 | 42 | # -- Options for HTML output ------------------------------------------------- 43 | 44 | # The theme to use for HTML and HTML Help pages. See the documentation for 45 | # a list of builtin themes. 46 | # 47 | html_theme = "alabaster" 48 | 49 | # Add any paths that contain custom static files (such as style sheets) here, 50 | # relative to this directory. 
They are copied after the builtin static files, 51 | # so a file named "default.css" will overwrite the builtin "default.css". 52 | # html_static_path = ["_static"] 53 | 54 | 55 | # -- extensions options ------------------------------------------------------ 56 | autoclass_content = "both" 57 | 58 | autodoc_default_options = {"members": True} 59 | autosummary_generate = True 60 | -------------------------------------------------------------------------------- /sphinx_docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to LassoNet's documentation! 2 | ==================================== 3 | 4 | `Go to main website <../>`_ 5 | 6 | Installation 7 | ------------ 8 | 9 | :: 10 | 11 | pip install lassonet 12 | 13 | 14 | API 15 | --- 16 | 17 | .. autosummary:: 18 | :toctree: 19 | 20 | .. autoclass:: lassonet.LassoNetRegressor 21 | :members: 22 | :inherited-members: 23 | 24 | .. autoclass:: lassonet.LassoNetClassifier 25 | :members: 26 | :inherited-members: 27 | 28 | .. autoclass:: lassonet.LassoNetCoxRegressor 29 | :members: 30 | :inherited-members: 31 | 32 | .. autoclass:: lassonet.LassoNetRegressorCV 33 | :members: 34 | :inherited-members: 35 | 36 | .. autoclass:: lassonet.LassoNetClassifierCV 37 | :members: 38 | :inherited-members: 39 | 40 | .. autoclass:: lassonet.LassoNetCoxRegressorCV 41 | :members: 42 | :inherited-members: 43 | 44 | .. autofunction:: lassonet.plot_path 45 | 46 | .. 
autofunction:: lassonet.lassonet_path 47 | -------------------------------------------------------------------------------- /tests/test_interface.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_diabetes, load_digits 2 | 3 | from lassonet import LassoNetClassifier, LassoNetRegressor 4 | 5 | 6 | def test_regressor(): 7 | X, y = load_diabetes(return_X_y=True) 8 | model = LassoNetRegressor() 9 | model.fit(X, y) 10 | model.score(X, y) 11 | 12 | 13 | def test_classifier(): 14 | X, y = load_digits(return_X_y=True) 15 | model = LassoNetClassifier() 16 | model.fit(X, y) 17 | model.score(X, y) 18 | --------------------------------------------------------------------------------
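The `concordance_index` function in `lassonet/cox.py` describes itself as an O(n log n) sweep: samples are visited in increasing time order, and each one is compared against the sorted risks of earlier samples that had an event. A quadratic pairwise count makes a convenient cross-check. The sketch below is an illustration only, not part of the package: the names `concordance_index_naive` and `concordance_index_bisect` are hypothetical, it assumes all times are distinct, and it counts ties in risk as non-concordant, exactly as the `bisect_right` call in the original does. It uses the standard-library `bisect` module instead of `sortedcontainers.SortedList`, so list insertion makes it O(n²) in the worst case.

```python
import bisect
import random
from itertools import permutations


def concordance_index_naive(risk, time, event):
    """O(n^2) reference: over ordered pairs (j, i) where j has an observed event
    and an earlier time than i, return the fraction in which j has the strictly
    higher risk. Assumes all times are distinct (hypothetical helper)."""
    num = den = 0
    for j, i in permutations(range(len(risk)), 2):
        if event[j] and time[j] < time[i]:
            den += 1
            num += risk[j] > risk[i]
    return num / den


def concordance_index_bisect(risk, time, event):
    """Same sweep as lassonet.cox.concordance_index, but with a plain sorted
    list and the standard-library bisect module instead of SortedList."""
    order = sorted(range(len(risk)), key=time.__getitem__)
    past = []  # sorted risks of earlier samples that had an event
    num = den = 0
    for i in order:
        # count earlier event risks that are strictly greater than risk[i]
        num += len(past) - bisect.bisect_right(past, risk[i])
        den += len(past)
        if event[i]:
            bisect.insort(past, risk[i])
    return num / den


if __name__ == "__main__":
    rng = random.Random(0)
    n = 50
    time = rng.sample(range(1000), n)  # distinct survival times
    risk = [rng.random() for _ in range(n)]
    event = [rng.random() < 0.7 for _ in range(n)]
    print(concordance_index_naive(risk, time, event))
    print(concordance_index_bisect(risk, time, event))
```

Both functions should print the same value. A risk that decreases with time gives an index of 1.0, and a risk equal to the time gives 0.0.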
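`CoxPHLoss` in `lassonet/cox.py` sorts samples by decreasing duration, so `torch.logcumsumexp` gives the log-sum of hazards over each risk set. It then returns the mean over observed events of that log-sum minus the event's own log-hazard. When no durations are tied, the Breslow and Efron variants coincide with the plain negative log partial likelihood. The pure-Python sketch below spells that quantity out one event at a time. The name `breslow_nll_naive` is hypothetical, the sketch assumes distinct durations, and it is meant as a readable reference for small inputs, not a replacement for the vectorized loss.

```python
import math


def breslow_nll_naive(log_h, durations, events):
    """Mean negative log partial likelihood over observed events.
    Assumes distinct durations, where Breslow and Efron coincide (hypothetical helper)."""
    terms = []
    for i, (t_i, e_i) in enumerate(zip(durations, events)):
        if not e_i:
            continue  # censored samples only enter through risk sets
        # risk set at t_i: every sample whose duration is at least t_i
        risk_set = [h for h, t in zip(log_h, durations) if t >= t_i]
        m = max(risk_set)  # shift for a numerically stable log-sum-exp
        log_risk = m + math.log(sum(math.exp(h - m) for h in risk_set))
        terms.append(log_risk - log_h[i])
    if not terms:
        return 0.0  # no observed events, mirroring the early return in CoxPHLoss
    return sum(terms) / len(terms)


# Three uncensored samples with equal log-hazards:
# the terms are log 3, log 2 and log 1, so the loss is log(6) / 3.
print(breslow_nll_naive([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], [1, 1, 1]))
```

Adding the same constant to every log-hazard leaves the result unchanged, because the partial likelihood only depends on hazard ratios.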