├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── LICENSE
├── Makefile
├── README.md
├── citation.bib
├── docs
│   ├── .nojekyll
│   ├── api
│   │   ├── _static
│   │   │   ├── alabaster.css
│   │   │   ├── basic.css
│   │   │   ├── custom.css
│   │   │   ├── doctools.js
│   │   │   ├── documentation_options.js
│   │   │   ├── file.png
│   │   │   ├── github-banner.svg
│   │   │   ├── language_data.js
│   │   │   ├── minus.png
│   │   │   ├── plus.png
│   │   │   ├── pygments.css
│   │   │   ├── searchtools.js
│   │   │   └── sphinx_highlight.js
│   │   ├── genindex.html
│   │   ├── index.html
│   │   ├── objects.inv
│   │   ├── search.html
│   │   └── searchindex.js
│   ├── fig1.png
│   ├── images
│   │   └── video_screenshot.png
│   ├── index.html
│   └── style.css
├── examples
│   ├── Data_Cortex_Nuclear.csv
│   ├── accuracy.png
│   ├── boston_housing.py
│   ├── boston_housing_group.py
│   ├── cox_experiments.py
│   ├── cox_regression.py
│   ├── data
│   │   ├── hnscc_x.csv
│   │   └── hnscc_y.csv
│   ├── diabetes.py
│   ├── friedman.py
│   ├── friedman
│   │   ├── download.sh
│   │   └── main.py
│   ├── generated.py
│   ├── miceprotein.py
│   ├── mnist_ae.py
│   ├── mnist_classif.py
│   └── mnist_reconstruction.py
├── experiments
│   ├── README.MD
│   ├── data_utils.py
│   └── run.py
├── lassonet
│   ├── __init__.py
│   ├── cox.py
│   ├── interfaces.py
│   ├── model.py
│   ├── plot.py
│   ├── prox.py
│   ├── r.py
│   └── utils.py
├── setup.py
├── sphinx_docs
│   ├── Makefile
│   ├── conf.py
│   └── index.rst
└── tests
    └── test_interface.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | Icon
3 | __pycache__
4 | .pytest_cache
5 | *.egg-info/
6 | .DS_Store
7 | sphinx_docs/_*/
8 | docs/api/_sources
9 | dist/
10 | build/
11 | *.png
12 | *.jpg
13 | *.csv
14 | .vscode
15 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "examples/spinet"]
2 | path = examples/spinet
3 | url = https://github.com/meixide/spinet
4 | ignore = dirty
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: black
5 | name: Black Formatter
6 | entry: black .
7 | language: system
8 | types: [python]
9 |
10 | - id: isort
11 | name: isort
12 | entry: isort --profile=black .
13 | language: system
14 | types: [python]
15 |
16 | - id: make-docs
17 | name: Generate Docs
18 | entry: make docs
19 | language: system
20 | pass_filenames: false
21 |
22 | - id: add-docs
23 | name: Stage Docs
24 | entry: git add docs/api/
25 | language: system
26 | pass_filenames: false
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2020 Louis Abraham, Ismael Lemhadri
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | pypi: dist
2 | twine upload dist/*
3 |
4 | dist:
5 | - rm -rf dist
6 | python3 setup.py sdist bdist_wheel
7 |
8 | docs:
9 | cd sphinx_docs && $(MAKE) html
10 | - rm -rf docs/api
11 | mkdir docs/api
12 | cp -r sphinx_docs/_build/html/* docs/api
13 |
14 | .PHONY: docs dist
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![PyPI version](https://badge.fury.io/py/lassonet.svg)](https://badge.fury.io/py/lassonet)
2 | [![Downloads](https://pepy.tech/badge/lassonet)](https://pepy.tech/project/lassonet)
3 |
4 | # LassoNet
5 |
6 | LassoNet is a family of models that combines feature selection with neural networks.
7 |
8 | LassoNet works by adding a linear skip connection from the input features to the output. An L1 penalty (LASSO-inspired) is added to that skip connection, along with a constraint on the network so that whenever a feature is ignored by the skip connection, it is ignored by the whole network.
9 |
10 |
11 |
12 | ## Installation
13 |
14 | ```
15 | pip install lassonet
16 | ```
17 |
18 | ## Usage
19 |
20 | We have designed the code to follow scikit-learn's standards to the extent possible (e.g. [linear_model.Lasso](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html)).
21 |
22 | ```python
23 | from lassonet import LassoNetClassifierCV
24 | model = LassoNetClassifierCV() # LassoNetRegressorCV
25 | model.fit(X_train, y_train)
26 | print("Best model scored", model.score(X_test, y_test))
27 | print("Lambda =", model.best_lambda_)
28 | ```
29 |
30 | You should always try to give normalized data to LassoNet as it uses neural networks under the hood.
31 |
32 | You can read the full [documentation](https://lasso-net.github.io/lassonet/api/) or browse the [examples](https://github.com/lasso-net/lassonet/tree/master/examples) that cover most features. We also provide a Quickstart section below.
33 |
34 |
35 |
36 | ## Quickstart
37 |
38 | Here we guide you through the features of LassoNet and how you typically use them.
39 |
40 | ### Task
41 |
42 | LassoNet is based on neural networks and can be used for any kind of data. Currently, we have implemented losses for the following tasks:
43 |
44 | - regression: `LassoNetRegressor`
45 | - classification: `LassoNetClassifier`
46 | - Cox regression: `LassoNetCoxRegressor` (see the survival-data sketch after this list)
47 | - interval-censored Cox regression: `LassoNetIntervalRegressor`
48 |
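For survival tasks, the target must carry both the observed time and the event indicator. As a minimal sketch (assuming, as in the FastCPH-based implementation, that `y` is a two-column array of `(time, event)` pairs; see `examples/cox_regression.py` for the exact format):

```python
import numpy as np
from lassonet import LassoNetCoxRegressor

# Hypothetical survival data: column 0 = observed time,
# column 1 = event indicator (1 = event occurred, 0 = censored).
X = np.random.randn(200, 20)
time = np.random.exponential(scale=1.0, size=200)
event = np.random.binomial(1, 0.7, size=200)
y = np.column_stack([time, event])

model = LassoNetCoxRegressor(hidden_dims=(32,))
model.fit(X, y)
```
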
49 | If features naturally belong to groups, you can use the `groups` parameter to specify them. This will allow the model to put a penalty on groups of features instead of each feature individually.
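As a minimal sketch (assuming `groups` is given as a list of lists of feature indices):

```python
from lassonet import LassoNetRegressor

# Hypothetical 5-feature dataset: features 0-2 form one group, 3-4 another.
model = LassoNetRegressor(groups=[[0, 1, 2], [3, 4]])
model.fit(X_train, y_train)
```
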
50 |
51 | ### Data preparation
52 |
53 | You should always normalize your data before passing it to the model to avoid too large (or too small) values in the data.
54 |
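For example, with scikit-learn (fit the scaler on the training set only, then reuse its statistics on the test set):

```python
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)  # same statistics as the training set
```
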
55 | ### What do you want to do?
56 |
57 | The LassoNet family of models can do a lot of things.
58 |
59 | Here are some examples of what you can do with LassoNet. Note that you can switch `LassoNetRegressor` with any of the other models to perform the same operations.
60 |
61 | #### Using the base interface
62 |
63 | The original paper describes how to train LassoNet along a regularization path. This requires the user to manually select a model from the path, and it makes the `.fit()` method useless since the model at the end of the path is always empty. This feature is still available through the `.path(return_state_dicts=True)` method of any base model; it returns a list of checkpoints that can be loaded with `.load()`.
64 |
65 |
66 | ```python
67 | from lassonet import LassoNetRegressor, plot_path
68 |
69 | model = LassoNetRegressor()
70 | path = model.path(X_train, y_train, return_state_dicts=True)
71 | plot_path(model, X_test, y_test)
72 |
73 | # choose `best_id` based on the plot
74 | model.load(path[best_id].state_dict)
75 | print(model.score(X_test, y_test))
76 | ```
77 |
78 | You can also retrieve the mask of the selected features and train a dense model on the selected features.
79 |
80 | ```python
81 | selected = path[best_id].selected
82 | model.fit(X_train[:, selected], y_train, dense_only=True)
83 | print(model.score(X_test[:, selected], y_test))
84 | ```
85 |
86 | You also get a `model.feature_importances_` attribute that gives, for each feature, the value of the L1 regularization parameter at which that feature is removed from the path. This can give you an idea of the most important features, but it is very unstable across different runs. You should use stability selection to select the most stable features.
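
For a quick (but, again, unstable) ranking, a sketch like this works:

```python
import numpy as np

# feature_importances_ is the lambda at which each feature is dropped;
# larger values mean the feature survives more regularization.
importances = np.asarray(model.feature_importances_)  # convert in case it is a tensor
ranking = np.argsort(importances)[::-1]
print("Most important features first:", ranking[:10])
```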
87 |
88 | #### Using the cross-validation interface
89 |
90 |
91 | We integrated support for cross-validation (5-fold by default) in the estimators whose name ends with `CV`. For each fold, a path is trained. The best regularization value is then chosen to maximize the average score over all folds. The model is then retrained on the whole training dataset until it reaches that regularization value.
92 |
93 | ```python
94 | model = LassoNetRegressorCV()
95 | model.fit(X_train, y_train)
96 | model.score(X_test, y_test)
97 | ```
98 |
99 | You can also use the `plot_cv` method to get more information.
100 |
101 | Some attributes give you more information about the best model, like `best_lambda_`, `best_selected_` or `best_cv_score_`.
102 |
103 | This information is useful to pass to a base model to train it from scratch with the best regularization parameter or the best subset of features.
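
For example (a sketch assuming `plot_cv` is exported like `plot_path` and takes the fitted model with held-out data):

```python
from lassonet import LassoNetRegressorCV, plot_cv

model = LassoNetRegressorCV()
model.fit(X_train, y_train)

plot_cv(model, X_test, y_test)
print("Best lambda:", model.best_lambda_)
print("Selected features:", model.best_selected_)
print("Best CV score:", model.best_cv_score_)
```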
104 |
105 | #### Using the stability selection interface
106 |
107 |
108 | [Stability selection](https://arxiv.org/abs/0809.2932) is a method to select the most stable features when running the model multiple times on different random subsamples of the data. It is probably the best way to select the most important features.
109 |
110 | ```python
111 | model = LassoNetRegressor()
112 | oracle, order, wrong, paths, prob = model.stability_selection(X_train, y_train)
113 | ```
114 |
115 | - `oracle` is a heuristic that can detect the most stable features when introducing noise.
116 | - `order` sorts the features by their decreasing importance.
117 | - `wrong[k]` is a measure of error when selecting the first k+1 features (read [this paper](https://arxiv.org/pdf/2206.06885) for more details). You can `plt.plot(wrong)` to see the error as a function of the number of selected features, as sketched after this list.
118 | - `paths` stores all the computed paths.
119 | - `prob` is the probability that a feature is selected at each value of the regularization parameter.
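
A minimal plotting sketch (remember that `wrong[k]` corresponds to k+1 selected features):

```python
import matplotlib.pyplot as plt

plt.plot(range(1, len(wrong) + 1), wrong)
plt.xlabel("number of selected features")
plt.ylabel("error measure")
plt.show()
```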
120 |
121 | In practice, you might want to train multiple dense models on different subsets of features to get a better understanding of the importance of each feature.
122 |
123 | For example:
124 |
125 | ```python
126 | for i in range(1, 11):
127 | selected = order[:i]
128 | model.fit(X_train[:, selected], y_train, dense_only=True)
129 | print(model.score(X_test[:, selected], y_test))
130 | ```
131 |
132 | ### Important parameters
133 |
134 | Here are the most important parameters you should be aware of (a configuration sketch follows the list):
135 |
136 | - `hidden_dims`: the number of neurons in each hidden layer. The default value is `(100,)` but you might want to try smaller and deeper networks like `(10, 10)`.
137 | - `path_multiplier`: the multiplicative factor applied to the regularization parameter at each step of the path. The closer it is to 1, the finer (and more precise) the path, but the longer training takes. The default value is a trade-off for fast training, but you might want to try smaller values like `1.01` or `1.005` to get a better model.
138 | - `lambda_start`: the starting value of the regularization parameter. The default value is `"auto"`, which makes the model select a good starting value according to an unpublished heuristic (read the code to know more). You can identify a bad `lambda_start` by plotting the path. If `lambda_start` is too small, the model will stay dense for a long time, which does not affect performance but takes longer. If `lambda_start` is too large, the number of features will decrease very fast and the path will not be accurate; in that case you might want to decrease `lambda_start`.
139 | - `gamma`: puts some L2 penalty on the network. The default is `0.0` which means no penalty but some small value can improve the performance, especially on small datasets.
140 | - more standard MLP training parameters are accessible: `dropout`, `batch_size`, `optim`, `n_iters`, `patience`, `tol`, `backtrack`, `val_size`. In particular, `batch_size` can be useful to do stochastic gradient descent instead of full batch gradient descent and to avoid memory issues on large datasets.
141 | - `M`: this parameter has almost no effect on the model.
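
Putting a few of these together (the values below are purely illustrative, not recommendations):

```python
from lassonet import LassoNetRegressorCV

model = LassoNetRegressorCV(
    hidden_dims=(10, 10),  # smaller and deeper network
    path_multiplier=1.01,  # finer path: more precise but slower
    gamma=0.01,            # light L2 penalty on the network
    dropout=0.2,
    batch_size=256,        # mini-batch training for large datasets
)
```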
142 |
143 | ## Features
144 |
145 | - regression, classification, [Cox regression](https://en.wikipedia.org/wiki/Proportional_hazards_model) and [interval-censored Cox regression](https://arxiv.org/abs/2206.06885) with `LassoNetRegressor`, `LassoNetClassifier`, `LassoNetCoxRegressor` and `LassoNetIntervalRegressor`.
146 | - cross-validation with `LassoNetRegressorCV`, `LassoNetClassifierCV`, `LassoNetCoxRegressorCV` and `LassoNetIntervalRegressorCV`
147 | - [stability selection](https://arxiv.org/abs/0809.2932) with `model.stability_selection()`
148 | - group feature selection with the `groups` argument
149 | - `lambda_start="auto"` heuristic (default)
150 |
151 | Note that cross-validation, group feature selection and automatic `lambda_start` selection have not been published in papers; you can read the code or [post an issue](https://github.com/lasso-net/lassonet/issues/new) to request more details.
152 |
153 | We are currently working on (among other things) adding support for convolution layers, auto-encoders and online logging of experiments.
154 |
155 | ## Website
156 |
157 | LassoNet's website is [https://lasso-net.github.io/](https://lasso-net.github.io/). It contains many useful references including the paper, live talks and additional documentation.
158 |
159 | ## References
160 |
161 | - Lemhadri, Ismael, Feng Ruan, Louis Abraham, and Robert Tibshirani. "LassoNet: A Neural Network with Feature Sparsity." Journal of Machine Learning Research 22, no. 127 (2021). [pdf](https://arxiv.org/pdf/1907.12207.pdf) [bibtex](https://github.com/lasso-net/lassonet/blob/master/citation.bib)
162 | - Yang, Xuelin, Louis Abraham, Sejin Kim, Petr Smirnov, Feng Ruan, Benjamin Haibe-Kains, and Robert Tibshirani. "FastCPH: Efficient Survival Analysis for Neural Networks." In NeurIPS 2022 Workshop on Learning from Time Series for Health. [pdf](https://arxiv.org/pdf/2208.09793.pdf)
163 | - Meixide, Carlos García, Marcos Matabuena, Louis Abraham, and Michael R. Kosorok. "Neural interval‐censored survival regression with feature selection." Statistical Analysis and Data Mining: The ASA Data Science Journal 17, no. 4 (2024): e11704. [pdf](https://arxiv.org/pdf/2206.06885)
--------------------------------------------------------------------------------
/citation.bib:
--------------------------------------------------------------------------------
1 | @article{JMLR:v22:20-848,
2 | author = {Ismael Lemhadri and Feng Ruan and Louis Abraham and Robert Tibshirani},
3 | title = {LassoNet: A Neural Network with Feature Sparsity},
4 | journal = {Journal of Machine Learning Research},
5 | year = {2021},
6 | volume = {22},
7 | number = {127},
8 | pages = {1-29},
9 | url = {http://jmlr.org/papers/v22/20-848.html}
10 | }
11 |
--------------------------------------------------------------------------------
/docs/.nojekyll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/.nojekyll
--------------------------------------------------------------------------------
/docs/api/_static/alabaster.css:
--------------------------------------------------------------------------------
1 | /* -- page layout ----------------------------------------------------------- */
2 |
3 | body {
4 | font-family: Georgia, serif;
5 | font-size: 17px;
6 | background-color: #fff;
7 | color: #000;
8 | margin: 0;
9 | padding: 0;
10 | }
11 |
12 |
13 | div.document {
14 | width: 940px;
15 | margin: 30px auto 0 auto;
16 | }
17 |
18 | div.documentwrapper {
19 | float: left;
20 | width: 100%;
21 | }
22 |
23 | div.bodywrapper {
24 | margin: 0 0 0 220px;
25 | }
26 |
27 | div.sphinxsidebar {
28 | width: 220px;
29 | font-size: 14px;
30 | line-height: 1.5;
31 | }
32 |
33 | hr {
34 | border: 1px solid #B1B4B6;
35 | }
36 |
37 | div.body {
38 | background-color: #fff;
39 | color: #3E4349;
40 | padding: 0 30px 0 30px;
41 | }
42 |
43 | div.body > .section {
44 | text-align: left;
45 | }
46 |
47 | div.footer {
48 | width: 940px;
49 | margin: 20px auto 30px auto;
50 | font-size: 14px;
51 | color: #888;
52 | text-align: right;
53 | }
54 |
55 | div.footer a {
56 | color: #888;
57 | }
58 |
59 | p.caption {
60 | font-family: inherit;
61 | font-size: inherit;
62 | }
63 |
64 |
65 | div.relations {
66 | display: none;
67 | }
68 |
69 |
70 | div.sphinxsidebar {
71 | max-height: 100%;
72 | overflow-y: auto;
73 | }
74 |
75 | div.sphinxsidebar a {
76 | color: #444;
77 | text-decoration: none;
78 | border-bottom: 1px dotted #999;
79 | }
80 |
81 | div.sphinxsidebar a:hover {
82 | border-bottom: 1px solid #999;
83 | }
84 |
85 | div.sphinxsidebarwrapper {
86 | padding: 18px 10px;
87 | }
88 |
89 | div.sphinxsidebarwrapper p.logo {
90 | padding: 0;
91 | margin: -10px 0 0 0px;
92 | text-align: center;
93 | }
94 |
95 | div.sphinxsidebarwrapper h1.logo {
96 | margin-top: -10px;
97 | text-align: center;
98 | margin-bottom: 5px;
99 | text-align: left;
100 | }
101 |
102 | div.sphinxsidebarwrapper h1.logo-name {
103 | margin-top: 0px;
104 | }
105 |
106 | div.sphinxsidebarwrapper p.blurb {
107 | margin-top: 0;
108 | font-style: normal;
109 | }
110 |
111 | div.sphinxsidebar h3,
112 | div.sphinxsidebar h4 {
113 | font-family: Georgia, serif;
114 | color: #444;
115 | font-size: 24px;
116 | font-weight: normal;
117 | margin: 0 0 5px 0;
118 | padding: 0;
119 | }
120 |
121 | div.sphinxsidebar h4 {
122 | font-size: 20px;
123 | }
124 |
125 | div.sphinxsidebar h3 a {
126 | color: #444;
127 | }
128 |
129 | div.sphinxsidebar p.logo a,
130 | div.sphinxsidebar h3 a,
131 | div.sphinxsidebar p.logo a:hover,
132 | div.sphinxsidebar h3 a:hover {
133 | border: none;
134 | }
135 |
136 | div.sphinxsidebar p {
137 | color: #555;
138 | margin: 10px 0;
139 | }
140 |
141 | div.sphinxsidebar ul {
142 | margin: 10px 0;
143 | padding: 0;
144 | color: #000;
145 | }
146 |
147 | div.sphinxsidebar ul li.toctree-l1 > a {
148 | font-size: 120%;
149 | }
150 |
151 | div.sphinxsidebar ul li.toctree-l2 > a {
152 | font-size: 110%;
153 | }
154 |
155 | div.sphinxsidebar input {
156 | border: 1px solid #CCC;
157 | font-family: Georgia, serif;
158 | font-size: 1em;
159 | }
160 |
161 | div.sphinxsidebar #searchbox {
162 | margin: 1em 0;
163 | }
164 |
165 | div.sphinxsidebar .search > div {
166 | display: table-cell;
167 | }
168 |
169 | div.sphinxsidebar hr {
170 | border: none;
171 | height: 1px;
172 | color: #AAA;
173 | background: #AAA;
174 |
175 | text-align: left;
176 | margin-left: 0;
177 | width: 50%;
178 | }
179 |
180 | div.sphinxsidebar .badge {
181 | border-bottom: none;
182 | }
183 |
184 | div.sphinxsidebar .badge:hover {
185 | border-bottom: none;
186 | }
187 |
188 | /* To address an issue with donation coming after search */
189 | div.sphinxsidebar h3.donation {
190 | margin-top: 10px;
191 | }
192 |
193 | /* -- body styles ----------------------------------------------------------- */
194 |
195 | a {
196 | color: #004B6B;
197 | text-decoration: underline;
198 | }
199 |
200 | a:hover {
201 | color: #6D4100;
202 | text-decoration: underline;
203 | }
204 |
205 | div.body h1,
206 | div.body h2,
207 | div.body h3,
208 | div.body h4,
209 | div.body h5,
210 | div.body h6 {
211 | font-family: Georgia, serif;
212 | font-weight: normal;
213 | margin: 30px 0px 10px 0px;
214 | padding: 0;
215 | }
216 |
217 | div.body h1 { margin-top: 0; padding-top: 0; font-size: 240%; }
218 | div.body h2 { font-size: 180%; }
219 | div.body h3 { font-size: 150%; }
220 | div.body h4 { font-size: 130%; }
221 | div.body h5 { font-size: 100%; }
222 | div.body h6 { font-size: 100%; }
223 |
224 | a.headerlink {
225 | color: #DDD;
226 | padding: 0 4px;
227 | text-decoration: none;
228 | }
229 |
230 | a.headerlink:hover {
231 | color: #444;
232 | background: #EAEAEA;
233 | }
234 |
235 | div.body p, div.body dd, div.body li {
236 | line-height: 1.4em;
237 | }
238 |
239 | div.admonition {
240 | margin: 20px 0px;
241 | padding: 10px 30px;
242 | background-color: #EEE;
243 | border: 1px solid #CCC;
244 | }
245 |
246 | div.admonition tt.xref, div.admonition code.xref, div.admonition a tt {
247 | background-color: #FBFBFB;
248 | border-bottom: 1px solid #fafafa;
249 | }
250 |
251 | div.admonition p.admonition-title {
252 | font-family: Georgia, serif;
253 | font-weight: normal;
254 | font-size: 24px;
255 | margin: 0 0 10px 0;
256 | padding: 0;
257 | line-height: 1;
258 | }
259 |
260 | div.admonition p.last {
261 | margin-bottom: 0;
262 | }
263 |
264 | dt:target, .highlight {
265 | background: #FAF3E8;
266 | }
267 |
268 | div.warning {
269 | background-color: #FCC;
270 | border: 1px solid #FAA;
271 | }
272 |
273 | div.danger {
274 | background-color: #FCC;
275 | border: 1px solid #FAA;
276 | -moz-box-shadow: 2px 2px 4px #D52C2C;
277 | -webkit-box-shadow: 2px 2px 4px #D52C2C;
278 | box-shadow: 2px 2px 4px #D52C2C;
279 | }
280 |
281 | div.error {
282 | background-color: #FCC;
283 | border: 1px solid #FAA;
284 | -moz-box-shadow: 2px 2px 4px #D52C2C;
285 | -webkit-box-shadow: 2px 2px 4px #D52C2C;
286 | box-shadow: 2px 2px 4px #D52C2C;
287 | }
288 |
289 | div.caution {
290 | background-color: #FCC;
291 | border: 1px solid #FAA;
292 | }
293 |
294 | div.attention {
295 | background-color: #FCC;
296 | border: 1px solid #FAA;
297 | }
298 |
299 | div.important {
300 | background-color: #EEE;
301 | border: 1px solid #CCC;
302 | }
303 |
304 | div.note {
305 | background-color: #EEE;
306 | border: 1px solid #CCC;
307 | }
308 |
309 | div.tip {
310 | background-color: #EEE;
311 | border: 1px solid #CCC;
312 | }
313 |
314 | div.hint {
315 | background-color: #EEE;
316 | border: 1px solid #CCC;
317 | }
318 |
319 | div.seealso {
320 | background-color: #EEE;
321 | border: 1px solid #CCC;
322 | }
323 |
324 | div.topic {
325 | background-color: #EEE;
326 | }
327 |
328 | p.admonition-title {
329 | display: inline;
330 | }
331 |
332 | p.admonition-title:after {
333 | content: ":";
334 | }
335 |
336 | pre, tt, code {
337 | font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace;
338 | font-size: 0.9em;
339 | }
340 |
341 | .hll {
342 | background-color: #FFC;
343 | margin: 0 -12px;
344 | padding: 0 12px;
345 | display: block;
346 | }
347 |
348 | img.screenshot {
349 | }
350 |
351 | tt.descname, tt.descclassname, code.descname, code.descclassname {
352 | font-size: 0.95em;
353 | }
354 |
355 | tt.descname, code.descname {
356 | padding-right: 0.08em;
357 | }
358 |
359 | img.screenshot {
360 | -moz-box-shadow: 2px 2px 4px #EEE;
361 | -webkit-box-shadow: 2px 2px 4px #EEE;
362 | box-shadow: 2px 2px 4px #EEE;
363 | }
364 |
365 | table.docutils {
366 | border: 1px solid #888;
367 | -moz-box-shadow: 2px 2px 4px #EEE;
368 | -webkit-box-shadow: 2px 2px 4px #EEE;
369 | box-shadow: 2px 2px 4px #EEE;
370 | }
371 |
372 | table.docutils td, table.docutils th {
373 | border: 1px solid #888;
374 | padding: 0.25em 0.7em;
375 | }
376 |
377 | table.field-list, table.footnote {
378 | border: none;
379 | -moz-box-shadow: none;
380 | -webkit-box-shadow: none;
381 | box-shadow: none;
382 | }
383 |
384 | table.footnote {
385 | margin: 15px 0;
386 | width: 100%;
387 | border: 1px solid #EEE;
388 | background: #FDFDFD;
389 | font-size: 0.9em;
390 | }
391 |
392 | table.footnote + table.footnote {
393 | margin-top: -15px;
394 | border-top: none;
395 | }
396 |
397 | table.field-list th {
398 | padding: 0 0.8em 0 0;
399 | }
400 |
401 | table.field-list td {
402 | padding: 0;
403 | }
404 |
405 | table.field-list p {
406 | margin-bottom: 0.8em;
407 | }
408 |
409 | /* Cloned from
410 | * https://github.com/sphinx-doc/sphinx/commit/ef60dbfce09286b20b7385333d63a60321784e68
411 | */
412 | .field-name {
413 | -moz-hyphens: manual;
414 | -ms-hyphens: manual;
415 | -webkit-hyphens: manual;
416 | hyphens: manual;
417 | }
418 |
419 | table.footnote td.label {
420 | width: .1px;
421 | padding: 0.3em 0 0.3em 0.5em;
422 | }
423 |
424 | table.footnote td {
425 | padding: 0.3em 0.5em;
426 | }
427 |
428 | dl {
429 | margin-left: 0;
430 | margin-right: 0;
431 | margin-top: 0;
432 | padding: 0;
433 | }
434 |
435 | dl dd {
436 | margin-left: 30px;
437 | }
438 |
439 | blockquote {
440 | margin: 0 0 0 30px;
441 | padding: 0;
442 | }
443 |
444 | ul, ol {
445 | /* Matches the 30px from the narrow-screen "li > ul" selector below */
446 | margin: 10px 0 10px 30px;
447 | padding: 0;
448 | }
449 |
450 | pre {
451 | background: unset;
452 | padding: 7px 30px;
453 | margin: 15px 0px;
454 | line-height: 1.3em;
455 | }
456 |
457 | div.viewcode-block:target {
458 | background: #ffd;
459 | }
460 |
461 | dl pre, blockquote pre, li pre {
462 | margin-left: 0;
463 | padding-left: 30px;
464 | }
465 |
466 | tt, code {
467 | background-color: #ecf0f3;
468 | color: #222;
469 | /* padding: 1px 2px; */
470 | }
471 |
472 | tt.xref, code.xref, a tt {
473 | background-color: #FBFBFB;
474 | border-bottom: 1px solid #fff;
475 | }
476 |
477 | a.reference {
478 | text-decoration: none;
479 | border-bottom: 1px dotted #004B6B;
480 | }
481 |
482 | a.reference:hover {
483 | border-bottom: 1px solid #6D4100;
484 | }
485 |
486 | /* Don't put an underline on images */
487 | a.image-reference, a.image-reference:hover {
488 | border-bottom: none;
489 | }
490 |
491 | a.footnote-reference {
492 | text-decoration: none;
493 | font-size: 0.7em;
494 | vertical-align: top;
495 | border-bottom: 1px dotted #004B6B;
496 | }
497 |
498 | a.footnote-reference:hover {
499 | border-bottom: 1px solid #6D4100;
500 | }
501 |
502 | a:hover tt, a:hover code {
503 | background: #EEE;
504 | }
505 |
506 | @media screen and (max-width: 940px) {
507 |
508 | body {
509 | margin: 0;
510 | padding: 20px 30px;
511 | }
512 |
513 | div.documentwrapper {
514 | float: none;
515 | background: #fff;
516 | margin-left: 0;
517 | margin-top: 0;
518 | margin-right: 0;
519 | margin-bottom: 0;
520 | }
521 |
522 | div.sphinxsidebar {
523 | display: block;
524 | float: none;
525 | width: unset;
526 | margin: 50px -30px -20px -30px;
527 | padding: 10px 20px;
528 | background: #333;
529 | color: #FFF;
530 | }
531 |
532 | div.sphinxsidebar h3, div.sphinxsidebar h4, div.sphinxsidebar p,
533 | div.sphinxsidebar h3 a {
534 | color: #fff;
535 | }
536 |
537 | div.sphinxsidebar a {
538 | color: #AAA;
539 | }
540 |
541 | div.sphinxsidebar p.logo {
542 | display: none;
543 | }
544 |
545 | div.document {
546 | width: 100%;
547 | margin: 0;
548 | }
549 |
550 | div.footer {
551 | display: none;
552 | }
553 |
554 | div.bodywrapper {
555 | margin: 0;
556 | }
557 |
558 | div.body {
559 | min-height: 0;
560 | min-width: auto; /* fixes width on small screens, breaks .hll */
561 | padding: 0;
562 | }
563 |
564 | .hll {
565 | /* "fixes" the breakage */
566 | width: max-content;
567 | }
568 |
569 | .rtd_doc_footer {
570 | display: none;
571 | }
572 |
573 | .document {
574 | width: auto;
575 | }
576 |
577 | .footer {
578 | width: auto;
579 | }
580 |
581 | .github {
582 | display: none;
583 | }
584 |
585 | ul {
586 | margin-left: 0;
587 | }
588 |
589 | li > ul {
590 | /* Matches the 30px from the "ul, ol" selector above */
591 | margin-left: 30px;
592 | }
593 | }
594 |
595 |
596 | /* misc. */
597 |
598 | .revsys-inline {
599 | display: none!important;
600 | }
601 |
602 | /* Hide ugly table cell borders in ..bibliography:: directive output */
603 | table.docutils.citation, table.docutils.citation td, table.docutils.citation th {
604 | border: none;
605 | /* Below needed in some edge cases; if not applied, bottom shadows appear */
606 | -moz-box-shadow: none;
607 | -webkit-box-shadow: none;
608 | box-shadow: none;
609 | }
610 |
611 |
612 | /* relbar */
613 |
614 | .related {
615 | line-height: 30px;
616 | width: 100%;
617 | font-size: 0.9rem;
618 | }
619 |
620 | .related.top {
621 | border-bottom: 1px solid #EEE;
622 | margin-bottom: 20px;
623 | }
624 |
625 | .related.bottom {
626 | border-top: 1px solid #EEE;
627 | }
628 |
629 | .related ul {
630 | padding: 0;
631 | margin: 0;
632 | list-style: none;
633 | }
634 |
635 | .related li {
636 | display: inline;
637 | }
638 |
639 | nav#rellinks {
640 | float: right;
641 | }
642 |
643 | nav#rellinks li+li:before {
644 | content: "|";
645 | }
646 |
647 | nav#breadcrumbs li+li:before {
648 | content: "\00BB";
649 | }
650 |
651 | /* Hide certain items when printing */
652 | @media print {
653 | div.related {
654 | display: none;
655 | }
656 | }
657 |
658 | img.github {
659 | position: absolute;
660 | top: 0;
661 | border: 0;
662 | right: 0;
663 | }
--------------------------------------------------------------------------------
/docs/api/_static/basic.css:
--------------------------------------------------------------------------------
1 | /*
2 | * basic.css
3 | * ~~~~~~~~~
4 | *
5 | * Sphinx stylesheet -- basic theme.
6 | *
7 | * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
8 | * :license: BSD, see LICENSE for details.
9 | *
10 | */
11 |
12 | /* -- main layout ----------------------------------------------------------- */
13 |
14 | div.clearer {
15 | clear: both;
16 | }
17 |
18 | div.section::after {
19 | display: block;
20 | content: '';
21 | clear: left;
22 | }
23 |
24 | /* -- relbar ---------------------------------------------------------------- */
25 |
26 | div.related {
27 | width: 100%;
28 | font-size: 90%;
29 | }
30 |
31 | div.related h3 {
32 | display: none;
33 | }
34 |
35 | div.related ul {
36 | margin: 0;
37 | padding: 0 0 0 10px;
38 | list-style: none;
39 | }
40 |
41 | div.related li {
42 | display: inline;
43 | }
44 |
45 | div.related li.right {
46 | float: right;
47 | margin-right: 5px;
48 | }
49 |
50 | /* -- sidebar --------------------------------------------------------------- */
51 |
52 | div.sphinxsidebarwrapper {
53 | padding: 10px 5px 0 10px;
54 | }
55 |
56 | div.sphinxsidebar {
57 | float: left;
58 | width: 230px;
59 | margin-left: -100%;
60 | font-size: 90%;
61 | word-wrap: break-word;
62 | overflow-wrap : break-word;
63 | }
64 |
65 | div.sphinxsidebar ul {
66 | list-style: none;
67 | }
68 |
69 | div.sphinxsidebar ul ul,
70 | div.sphinxsidebar ul.want-points {
71 | margin-left: 20px;
72 | list-style: square;
73 | }
74 |
75 | div.sphinxsidebar ul ul {
76 | margin-top: 0;
77 | margin-bottom: 0;
78 | }
79 |
80 | div.sphinxsidebar form {
81 | margin-top: 10px;
82 | }
83 |
84 | div.sphinxsidebar input {
85 | border: 1px solid #98dbcc;
86 | font-family: sans-serif;
87 | font-size: 1em;
88 | }
89 |
90 | div.sphinxsidebar #searchbox form.search {
91 | overflow: hidden;
92 | }
93 |
94 | div.sphinxsidebar #searchbox input[type="text"] {
95 | float: left;
96 | width: 80%;
97 | padding: 0.25em;
98 | box-sizing: border-box;
99 | }
100 |
101 | div.sphinxsidebar #searchbox input[type="submit"] {
102 | float: left;
103 | width: 20%;
104 | border-left: none;
105 | padding: 0.25em;
106 | box-sizing: border-box;
107 | }
108 |
109 |
110 | img {
111 | border: 0;
112 | max-width: 100%;
113 | }
114 |
115 | /* -- search page ----------------------------------------------------------- */
116 |
117 | ul.search {
118 | margin: 10px 0 0 20px;
119 | padding: 0;
120 | }
121 |
122 | ul.search li {
123 | padding: 5px 0 5px 20px;
124 | background-image: url(file.png);
125 | background-repeat: no-repeat;
126 | background-position: 0 7px;
127 | }
128 |
129 | ul.search li a {
130 | font-weight: bold;
131 | }
132 |
133 | ul.search li p.context {
134 | color: #888;
135 | margin: 2px 0 0 30px;
136 | text-align: left;
137 | }
138 |
139 | ul.keywordmatches li.goodmatch a {
140 | font-weight: bold;
141 | }
142 |
143 | /* -- index page ------------------------------------------------------------ */
144 |
145 | table.contentstable {
146 | width: 90%;
147 | margin-left: auto;
148 | margin-right: auto;
149 | }
150 |
151 | table.contentstable p.biglink {
152 | line-height: 150%;
153 | }
154 |
155 | a.biglink {
156 | font-size: 1.3em;
157 | }
158 |
159 | span.linkdescr {
160 | font-style: italic;
161 | padding-top: 5px;
162 | font-size: 90%;
163 | }
164 |
165 | /* -- general index --------------------------------------------------------- */
166 |
167 | table.indextable {
168 | width: 100%;
169 | }
170 |
171 | table.indextable td {
172 | text-align: left;
173 | vertical-align: top;
174 | }
175 |
176 | table.indextable ul {
177 | margin-top: 0;
178 | margin-bottom: 0;
179 | list-style-type: none;
180 | }
181 |
182 | table.indextable > tbody > tr > td > ul {
183 | padding-left: 0em;
184 | }
185 |
186 | table.indextable tr.pcap {
187 | height: 10px;
188 | }
189 |
190 | table.indextable tr.cap {
191 | margin-top: 10px;
192 | background-color: #f2f2f2;
193 | }
194 |
195 | img.toggler {
196 | margin-right: 3px;
197 | margin-top: 3px;
198 | cursor: pointer;
199 | }
200 |
201 | div.modindex-jumpbox {
202 | border-top: 1px solid #ddd;
203 | border-bottom: 1px solid #ddd;
204 | margin: 1em 0 1em 0;
205 | padding: 0.4em;
206 | }
207 |
208 | div.genindex-jumpbox {
209 | border-top: 1px solid #ddd;
210 | border-bottom: 1px solid #ddd;
211 | margin: 1em 0 1em 0;
212 | padding: 0.4em;
213 | }
214 |
215 | /* -- domain module index --------------------------------------------------- */
216 |
217 | table.modindextable td {
218 | padding: 2px;
219 | border-collapse: collapse;
220 | }
221 |
222 | /* -- general body styles --------------------------------------------------- */
223 |
224 | div.body {
225 | min-width: inherit;
226 | max-width: 800px;
227 | }
228 |
229 | div.body p, div.body dd, div.body li, div.body blockquote {
230 | -moz-hyphens: auto;
231 | -ms-hyphens: auto;
232 | -webkit-hyphens: auto;
233 | hyphens: auto;
234 | }
235 |
236 | a.headerlink {
237 | visibility: hidden;
238 | }
239 |
240 | a:visited {
241 | color: #551A8B;
242 | }
243 |
244 | h1:hover > a.headerlink,
245 | h2:hover > a.headerlink,
246 | h3:hover > a.headerlink,
247 | h4:hover > a.headerlink,
248 | h5:hover > a.headerlink,
249 | h6:hover > a.headerlink,
250 | dt:hover > a.headerlink,
251 | caption:hover > a.headerlink,
252 | p.caption:hover > a.headerlink,
253 | div.code-block-caption:hover > a.headerlink {
254 | visibility: visible;
255 | }
256 |
257 | div.body p.caption {
258 | text-align: inherit;
259 | }
260 |
261 | div.body td {
262 | text-align: left;
263 | }
264 |
265 | .first {
266 | margin-top: 0 !important;
267 | }
268 |
269 | p.rubric {
270 | margin-top: 30px;
271 | font-weight: bold;
272 | }
273 |
274 | img.align-left, figure.align-left, .figure.align-left, object.align-left {
275 | clear: left;
276 | float: left;
277 | margin-right: 1em;
278 | }
279 |
280 | img.align-right, figure.align-right, .figure.align-right, object.align-right {
281 | clear: right;
282 | float: right;
283 | margin-left: 1em;
284 | }
285 |
286 | img.align-center, figure.align-center, .figure.align-center, object.align-center {
287 | display: block;
288 | margin-left: auto;
289 | margin-right: auto;
290 | }
291 |
292 | img.align-default, figure.align-default, .figure.align-default {
293 | display: block;
294 | margin-left: auto;
295 | margin-right: auto;
296 | }
297 |
298 | .align-left {
299 | text-align: left;
300 | }
301 |
302 | .align-center {
303 | text-align: center;
304 | }
305 |
306 | .align-default {
307 | text-align: center;
308 | }
309 |
310 | .align-right {
311 | text-align: right;
312 | }
313 |
314 | /* -- sidebars -------------------------------------------------------------- */
315 |
316 | div.sidebar,
317 | aside.sidebar {
318 | margin: 0 0 0.5em 1em;
319 | border: 1px solid #ddb;
320 | padding: 7px;
321 | background-color: #ffe;
322 | width: 40%;
323 | float: right;
324 | clear: right;
325 | overflow-x: auto;
326 | }
327 |
328 | p.sidebar-title {
329 | font-weight: bold;
330 | }
331 |
332 | nav.contents,
333 | aside.topic,
334 | div.admonition, div.topic, blockquote {
335 | clear: left;
336 | }
337 |
338 | /* -- topics ---------------------------------------------------------------- */
339 |
340 | nav.contents,
341 | aside.topic,
342 | div.topic {
343 | border: 1px solid #ccc;
344 | padding: 7px;
345 | margin: 10px 0 10px 0;
346 | }
347 |
348 | p.topic-title {
349 | font-size: 1.1em;
350 | font-weight: bold;
351 | margin-top: 10px;
352 | }
353 |
354 | /* -- admonitions ----------------------------------------------------------- */
355 |
356 | div.admonition {
357 | margin-top: 10px;
358 | margin-bottom: 10px;
359 | padding: 7px;
360 | }
361 |
362 | div.admonition dt {
363 | font-weight: bold;
364 | }
365 |
366 | p.admonition-title {
367 | margin: 0px 10px 5px 0px;
368 | font-weight: bold;
369 | }
370 |
371 | div.body p.centered {
372 | text-align: center;
373 | margin-top: 25px;
374 | }
375 |
376 | /* -- content of sidebars/topics/admonitions -------------------------------- */
377 |
378 | div.sidebar > :last-child,
379 | aside.sidebar > :last-child,
380 | nav.contents > :last-child,
381 | aside.topic > :last-child,
382 | div.topic > :last-child,
383 | div.admonition > :last-child {
384 | margin-bottom: 0;
385 | }
386 |
387 | div.sidebar::after,
388 | aside.sidebar::after,
389 | nav.contents::after,
390 | aside.topic::after,
391 | div.topic::after,
392 | div.admonition::after,
393 | blockquote::after {
394 | display: block;
395 | content: '';
396 | clear: both;
397 | }
398 |
399 | /* -- tables ---------------------------------------------------------------- */
400 |
401 | table.docutils {
402 | margin-top: 10px;
403 | margin-bottom: 10px;
404 | border: 0;
405 | border-collapse: collapse;
406 | }
407 |
408 | table.align-center {
409 | margin-left: auto;
410 | margin-right: auto;
411 | }
412 |
413 | table.align-default {
414 | margin-left: auto;
415 | margin-right: auto;
416 | }
417 |
418 | table caption span.caption-number {
419 | font-style: italic;
420 | }
421 |
422 | table caption span.caption-text {
423 | }
424 |
425 | table.docutils td, table.docutils th {
426 | padding: 1px 8px 1px 5px;
427 | border-top: 0;
428 | border-left: 0;
429 | border-right: 0;
430 | border-bottom: 1px solid #aaa;
431 | }
432 |
433 | th {
434 | text-align: left;
435 | padding-right: 5px;
436 | }
437 |
438 | table.citation {
439 | border-left: solid 1px gray;
440 | margin-left: 1px;
441 | }
442 |
443 | table.citation td {
444 | border-bottom: none;
445 | }
446 |
447 | th > :first-child,
448 | td > :first-child {
449 | margin-top: 0px;
450 | }
451 |
452 | th > :last-child,
453 | td > :last-child {
454 | margin-bottom: 0px;
455 | }
456 |
457 | /* -- figures --------------------------------------------------------------- */
458 |
459 | div.figure, figure {
460 | margin: 0.5em;
461 | padding: 0.5em;
462 | }
463 |
464 | div.figure p.caption, figcaption {
465 | padding: 0.3em;
466 | }
467 |
468 | div.figure p.caption span.caption-number,
469 | figcaption span.caption-number {
470 | font-style: italic;
471 | }
472 |
473 | div.figure p.caption span.caption-text,
474 | figcaption span.caption-text {
475 | }
476 |
477 | /* -- field list styles ----------------------------------------------------- */
478 |
479 | table.field-list td, table.field-list th {
480 | border: 0 !important;
481 | }
482 |
483 | .field-list ul {
484 | margin: 0;
485 | padding-left: 1em;
486 | }
487 |
488 | .field-list p {
489 | margin: 0;
490 | }
491 |
492 | .field-name {
493 | -moz-hyphens: manual;
494 | -ms-hyphens: manual;
495 | -webkit-hyphens: manual;
496 | hyphens: manual;
497 | }
498 |
499 | /* -- hlist styles ---------------------------------------------------------- */
500 |
501 | table.hlist {
502 | margin: 1em 0;
503 | }
504 |
505 | table.hlist td {
506 | vertical-align: top;
507 | }
508 |
509 | /* -- object description styles --------------------------------------------- */
510 |
511 | .sig {
512 | font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace;
513 | }
514 |
515 | .sig-name, code.descname {
516 | background-color: transparent;
517 | font-weight: bold;
518 | }
519 |
520 | .sig-name {
521 | font-size: 1.1em;
522 | }
523 |
524 | code.descname {
525 | font-size: 1.2em;
526 | }
527 |
528 | .sig-prename, code.descclassname {
529 | background-color: transparent;
530 | }
531 |
532 | .optional {
533 | font-size: 1.3em;
534 | }
535 |
536 | .sig-paren {
537 | font-size: larger;
538 | }
539 |
540 | .sig-param.n {
541 | font-style: italic;
542 | }
543 |
544 | /* C++ specific styling */
545 |
546 | .sig-inline.c-texpr,
547 | .sig-inline.cpp-texpr {
548 | font-family: unset;
549 | }
550 |
551 | .sig.c .k, .sig.c .kt,
552 | .sig.cpp .k, .sig.cpp .kt {
553 | color: #0033B3;
554 | }
555 |
556 | .sig.c .m,
557 | .sig.cpp .m {
558 | color: #1750EB;
559 | }
560 |
561 | .sig.c .s, .sig.c .sc,
562 | .sig.cpp .s, .sig.cpp .sc {
563 | color: #067D17;
564 | }
565 |
566 |
567 | /* -- other body styles ----------------------------------------------------- */
568 |
569 | ol.arabic {
570 | list-style: decimal;
571 | }
572 |
573 | ol.loweralpha {
574 | list-style: lower-alpha;
575 | }
576 |
577 | ol.upperalpha {
578 | list-style: upper-alpha;
579 | }
580 |
581 | ol.lowerroman {
582 | list-style: lower-roman;
583 | }
584 |
585 | ol.upperroman {
586 | list-style: upper-roman;
587 | }
588 |
589 | :not(li) > ol > li:first-child > :first-child,
590 | :not(li) > ul > li:first-child > :first-child {
591 | margin-top: 0px;
592 | }
593 |
594 | :not(li) > ol > li:last-child > :last-child,
595 | :not(li) > ul > li:last-child > :last-child {
596 | margin-bottom: 0px;
597 | }
598 |
599 | ol.simple ol p,
600 | ol.simple ul p,
601 | ul.simple ol p,
602 | ul.simple ul p {
603 | margin-top: 0;
604 | }
605 |
606 | ol.simple > li:not(:first-child) > p,
607 | ul.simple > li:not(:first-child) > p {
608 | margin-top: 0;
609 | }
610 |
611 | ol.simple p,
612 | ul.simple p {
613 | margin-bottom: 0;
614 | }
615 |
616 | aside.footnote > span,
617 | div.citation > span {
618 | float: left;
619 | }
620 | aside.footnote > span:last-of-type,
621 | div.citation > span:last-of-type {
622 | padding-right: 0.5em;
623 | }
624 | aside.footnote > p {
625 | margin-left: 2em;
626 | }
627 | div.citation > p {
628 | margin-left: 4em;
629 | }
630 | aside.footnote > p:last-of-type,
631 | div.citation > p:last-of-type {
632 | margin-bottom: 0em;
633 | }
634 | aside.footnote > p:last-of-type:after,
635 | div.citation > p:last-of-type:after {
636 | content: "";
637 | clear: both;
638 | }
639 |
640 | dl.field-list {
641 | display: grid;
642 | grid-template-columns: fit-content(30%) auto;
643 | }
644 |
645 | dl.field-list > dt {
646 | font-weight: bold;
647 | word-break: break-word;
648 | padding-left: 0.5em;
649 | padding-right: 5px;
650 | }
651 |
652 | dl.field-list > dd {
653 | padding-left: 0.5em;
654 | margin-top: 0em;
655 | margin-left: 0em;
656 | margin-bottom: 0em;
657 | }
658 |
659 | dl {
660 | margin-bottom: 15px;
661 | }
662 |
663 | dd > :first-child {
664 | margin-top: 0px;
665 | }
666 |
667 | dd ul, dd table {
668 | margin-bottom: 10px;
669 | }
670 |
671 | dd {
672 | margin-top: 3px;
673 | margin-bottom: 10px;
674 | margin-left: 30px;
675 | }
676 |
677 | .sig dd {
678 | margin-top: 0px;
679 | margin-bottom: 0px;
680 | }
681 |
682 | .sig dl {
683 | margin-top: 0px;
684 | margin-bottom: 0px;
685 | }
686 |
687 | dl > dd:last-child,
688 | dl > dd:last-child > :last-child {
689 | margin-bottom: 0;
690 | }
691 |
692 | dt:target, span.highlighted {
693 | background-color: #fbe54e;
694 | }
695 |
696 | rect.highlighted {
697 | fill: #fbe54e;
698 | }
699 |
700 | dl.glossary dt {
701 | font-weight: bold;
702 | font-size: 1.1em;
703 | }
704 |
705 | .versionmodified {
706 | font-style: italic;
707 | }
708 |
709 | .system-message {
710 | background-color: #fda;
711 | padding: 5px;
712 | border: 3px solid red;
713 | }
714 |
715 | .footnote:target {
716 | background-color: #ffa;
717 | }
718 |
719 | .line-block {
720 | display: block;
721 | margin-top: 1em;
722 | margin-bottom: 1em;
723 | }
724 |
725 | .line-block .line-block {
726 | margin-top: 0;
727 | margin-bottom: 0;
728 | margin-left: 1.5em;
729 | }
730 |
731 | .guilabel, .menuselection {
732 | font-family: sans-serif;
733 | }
734 |
735 | .accelerator {
736 | text-decoration: underline;
737 | }
738 |
739 | .classifier {
740 | font-style: oblique;
741 | }
742 |
743 | .classifier:before {
744 | font-style: normal;
745 | margin: 0 0.5em;
746 | content: ":";
747 | display: inline-block;
748 | }
749 |
750 | abbr, acronym {
751 | border-bottom: dotted 1px;
752 | cursor: help;
753 | }
754 |
755 | .translated {
756 | background-color: rgba(207, 255, 207, 0.2)
757 | }
758 |
759 | .untranslated {
760 | background-color: rgba(255, 207, 207, 0.2)
761 | }
762 |
763 | /* -- code displays --------------------------------------------------------- */
764 |
765 | pre {
766 | overflow: auto;
767 | overflow-y: hidden; /* fixes display issues on Chrome browsers */
768 | }
769 |
770 | pre, div[class*="highlight-"] {
771 | clear: both;
772 | }
773 |
774 | span.pre {
775 | -moz-hyphens: none;
776 | -ms-hyphens: none;
777 | -webkit-hyphens: none;
778 | hyphens: none;
779 | white-space: nowrap;
780 | }
781 |
782 | div[class*="highlight-"] {
783 | margin: 1em 0;
784 | }
785 |
786 | td.linenos pre {
787 | border: 0;
788 | background-color: transparent;
789 | color: #aaa;
790 | }
791 |
792 | table.highlighttable {
793 | display: block;
794 | }
795 |
796 | table.highlighttable tbody {
797 | display: block;
798 | }
799 |
800 | table.highlighttable tr {
801 | display: flex;
802 | }
803 |
804 | table.highlighttable td {
805 | margin: 0;
806 | padding: 0;
807 | }
808 |
809 | table.highlighttable td.linenos {
810 | padding-right: 0.5em;
811 | }
812 |
813 | table.highlighttable td.code {
814 | flex: 1;
815 | overflow: hidden;
816 | }
817 |
818 | .highlight .hll {
819 | display: block;
820 | }
821 |
822 | div.highlight pre,
823 | table.highlighttable pre {
824 | margin: 0;
825 | }
826 |
827 | div.code-block-caption + div {
828 | margin-top: 0;
829 | }
830 |
831 | div.code-block-caption {
832 | margin-top: 1em;
833 | padding: 2px 5px;
834 | font-size: small;
835 | }
836 |
837 | div.code-block-caption code {
838 | background-color: transparent;
839 | }
840 |
841 | table.highlighttable td.linenos,
842 | span.linenos,
843 | div.highlight span.gp { /* gp: Generic.Prompt */
844 | user-select: none;
845 | -webkit-user-select: text; /* Safari fallback only */
846 | -webkit-user-select: none; /* Chrome/Safari */
847 | -moz-user-select: none; /* Firefox */
848 | -ms-user-select: none; /* IE10+ */
849 | }
850 |
851 | div.code-block-caption span.caption-number {
852 | padding: 0.1em 0.3em;
853 | font-style: italic;
854 | }
855 |
856 | div.code-block-caption span.caption-text {
857 | }
858 |
859 | div.literal-block-wrapper {
860 | margin: 1em 0;
861 | }
862 |
863 | code.xref, a code {
864 | background-color: transparent;
865 | font-weight: bold;
866 | }
867 |
868 | h1 code, h2 code, h3 code, h4 code, h5 code, h6 code {
869 | background-color: transparent;
870 | }
871 |
872 | .viewcode-link {
873 | float: right;
874 | }
875 |
876 | .viewcode-back {
877 | float: right;
878 | font-family: sans-serif;
879 | }
880 |
881 | div.viewcode-block:target {
882 | margin: -1px -10px;
883 | padding: 0 10px;
884 | }
885 |
886 | /* -- math display ---------------------------------------------------------- */
887 |
888 | img.math {
889 | vertical-align: middle;
890 | }
891 |
892 | div.body div.math p {
893 | text-align: center;
894 | }
895 |
896 | span.eqno {
897 | float: right;
898 | }
899 |
900 | span.eqno a.headerlink {
901 | position: absolute;
902 | z-index: 1;
903 | }
904 |
905 | div.math:hover a.headerlink {
906 | visibility: visible;
907 | }
908 |
909 | /* -- printout stylesheet --------------------------------------------------- */
910 |
911 | @media print {
912 | div.document,
913 | div.documentwrapper,
914 | div.bodywrapper {
915 | margin: 0 !important;
916 | width: 100%;
917 | }
918 |
919 | div.sphinxsidebar,
920 | div.related,
921 | div.footer,
922 | #top-link {
923 | display: none;
924 | }
925 | }
--------------------------------------------------------------------------------
/docs/api/_static/custom.css:
--------------------------------------------------------------------------------
1 | /* This file intentionally left blank. */
2 |
--------------------------------------------------------------------------------
/docs/api/_static/doctools.js:
--------------------------------------------------------------------------------
1 | /*
2 | * doctools.js
3 | * ~~~~~~~~~~~
4 | *
5 | * Base JavaScript utilities for all Sphinx HTML documentation.
6 | *
7 | * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
8 | * :license: BSD, see LICENSE for details.
9 | *
10 | */
11 | "use strict";
12 |
13 | const BLACKLISTED_KEY_CONTROL_ELEMENTS = new Set([
14 | "TEXTAREA",
15 | "INPUT",
16 | "SELECT",
17 | "BUTTON",
18 | ]);
19 |
20 | const _ready = (callback) => {
21 | if (document.readyState !== "loading") {
22 | callback();
23 | } else {
24 | document.addEventListener("DOMContentLoaded", callback);
25 | }
26 | };
27 |
28 | /**
29 | * Small JavaScript module for the documentation.
30 | */
31 | const Documentation = {
32 | init: () => {
33 | Documentation.initDomainIndexTable();
34 | Documentation.initOnKeyListeners();
35 | },
36 |
37 | /**
38 | * i18n support
39 | */
40 | TRANSLATIONS: {},
41 | PLURAL_EXPR: (n) => (n === 1 ? 0 : 1),
42 | LOCALE: "unknown",
43 |
44 | // gettext and ngettext don't access this so that the functions
45 | // can safely bound to a different name (_ = Documentation.gettext)
46 | gettext: (string) => {
47 | const translated = Documentation.TRANSLATIONS[string];
48 | switch (typeof translated) {
49 | case "undefined":
50 | return string; // no translation
51 | case "string":
52 | return translated; // translation exists
53 | default:
54 | return translated[0]; // (singular, plural) translation tuple exists
55 | }
56 | },
57 |
58 | ngettext: (singular, plural, n) => {
59 | const translated = Documentation.TRANSLATIONS[singular];
60 | if (typeof translated !== "undefined")
61 | return translated[Documentation.PLURAL_EXPR(n)];
62 | return n === 1 ? singular : plural;
63 | },
64 |
65 | addTranslations: (catalog) => {
66 | Object.assign(Documentation.TRANSLATIONS, catalog.messages);
67 | Documentation.PLURAL_EXPR = new Function(
68 | "n",
69 | `return (${catalog.plural_expr})`
70 | );
71 | Documentation.LOCALE = catalog.locale;
72 | },
73 |
74 | /**
75 | * helper function to focus on search bar
76 | */
77 | focusSearchBar: () => {
78 | document.querySelectorAll("input[name=q]")[0]?.focus();
79 | },
80 |
81 | /**
82 | * Initialise the domain index toggle buttons
83 | */
84 | initDomainIndexTable: () => {
85 | const toggler = (el) => {
86 | const idNumber = el.id.substr(7);
87 | const toggledRows = document.querySelectorAll(`tr.cg-${idNumber}`);
88 | if (el.src.substr(-9) === "minus.png") {
89 | el.src = `${el.src.substr(0, el.src.length - 9)}plus.png`;
90 | toggledRows.forEach((el) => (el.style.display = "none"));
91 | } else {
92 | el.src = `${el.src.substr(0, el.src.length - 8)}minus.png`;
93 | toggledRows.forEach((el) => (el.style.display = ""));
94 | }
95 | };
96 |
97 | const togglerElements = document.querySelectorAll("img.toggler");
98 | togglerElements.forEach((el) =>
99 | el.addEventListener("click", (event) => toggler(event.currentTarget))
100 | );
101 | togglerElements.forEach((el) => (el.style.display = ""));
102 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) togglerElements.forEach(toggler);
103 | },
104 |
105 | initOnKeyListeners: () => {
106 | // only install a listener if it is really needed
107 | if (
108 | !DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS &&
109 | !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS
110 | )
111 | return;
112 |
113 | document.addEventListener("keydown", (event) => {
114 | // bail for input elements
115 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return;
116 | // bail with special keys
117 | if (event.altKey || event.ctrlKey || event.metaKey) return;
118 |
119 | if (!event.shiftKey) {
120 | switch (event.key) {
121 | case "ArrowLeft":
122 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break;
123 |
124 | const prevLink = document.querySelector('link[rel="prev"]');
125 | if (prevLink && prevLink.href) {
126 | window.location.href = prevLink.href;
127 | event.preventDefault();
128 | }
129 | break;
130 | case "ArrowRight":
131 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break;
132 |
133 | const nextLink = document.querySelector('link[rel="next"]');
134 | if (nextLink && nextLink.href) {
135 | window.location.href = nextLink.href;
136 | event.preventDefault();
137 | }
138 | break;
139 | }
140 | }
141 |
142 | // some keyboard layouts may need Shift to get /
143 | switch (event.key) {
144 | case "/":
145 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) break;
146 | Documentation.focusSearchBar();
147 | event.preventDefault();
148 | }
149 | });
150 | },
151 | };
152 |
153 | // quick alias for translations
154 | const _ = Documentation.gettext;
155 |
156 | _ready(Documentation.init);
157 |
--------------------------------------------------------------------------------
/docs/api/_static/documentation_options.js:
--------------------------------------------------------------------------------
1 | const DOCUMENTATION_OPTIONS = {
2 | VERSION: '',
3 | LANGUAGE: 'en',
4 | COLLAPSE_INDEX: false,
5 | BUILDER: 'html',
6 | FILE_SUFFIX: '.html',
7 | LINK_SUFFIX: '.html',
8 | HAS_SOURCE: true,
9 | SOURCELINK_SUFFIX: '.txt',
10 | NAVIGATION_WITH_KEYS: false,
11 | SHOW_SEARCH_SUMMARY: true,
12 | ENABLE_SEARCH_SHORTCUTS: true,
13 | };
--------------------------------------------------------------------------------
/docs/api/_static/file.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/api/_static/file.png
--------------------------------------------------------------------------------
/docs/api/_static/github-banner.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/api/_static/github-banner.svg
--------------------------------------------------------------------------------
/docs/api/_static/language_data.js:
--------------------------------------------------------------------------------
1 | /*
2 | * language_data.js
3 | * ~~~~~~~~~~~~~~~~
4 | *
5 | * This script contains the language-specific data used by searchtools.js,
6 | * namely the list of stopwords, stemmer, scorer and splitter.
7 | *
8 | * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
9 | * :license: BSD, see LICENSE for details.
10 | *
11 | */
12 |
13 | var stopwords = ["a", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it", "near", "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"];
14 |
15 |
16 | /* Non-minified version is copied as a separate JS file, if available */
17 |
18 | /**
19 | * Porter Stemmer
20 | */
21 | var Stemmer = function() {
22 |
23 | var step2list = {
24 | ational: 'ate',
25 | tional: 'tion',
26 | enci: 'ence',
27 | anci: 'ance',
28 | izer: 'ize',
29 | bli: 'ble',
30 | alli: 'al',
31 | entli: 'ent',
32 | eli: 'e',
33 | ousli: 'ous',
34 | ization: 'ize',
35 | ation: 'ate',
36 | ator: 'ate',
37 | alism: 'al',
38 | iveness: 'ive',
39 | fulness: 'ful',
40 | ousness: 'ous',
41 | aliti: 'al',
42 | iviti: 'ive',
43 | biliti: 'ble',
44 | logi: 'log'
45 | };
46 |
47 | var step3list = {
48 | icate: 'ic',
49 | ative: '',
50 | alize: 'al',
51 | iciti: 'ic',
52 | ical: 'ic',
53 | ful: '',
54 | ness: ''
55 | };
56 |
57 | var c = "[^aeiou]"; // consonant
58 | var v = "[aeiouy]"; // vowel
59 | var C = c + "[^aeiouy]*"; // consonant sequence
60 | var V = v + "[aeiou]*"; // vowel sequence
61 |
62 | var mgr0 = "^(" + C + ")?" + V + C; // [C]VC... is m>0
63 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1
64 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1
65 | var s_v = "^(" + C + ")?" + v; // vowel in stem
66 |
67 | this.stemWord = function (w) {
68 | var stem;
69 | var suffix;
70 | var firstch;
71 | var origword = w;
72 |
73 | if (w.length < 3)
74 | return w;
75 |
76 | var re;
77 | var re2;
78 | var re3;
79 | var re4;
80 |
81 | firstch = w.substr(0,1);
82 | if (firstch == "y")
83 | w = firstch.toUpperCase() + w.substr(1);
84 |
85 | // Step 1a
86 | re = /^(.+?)(ss|i)es$/;
87 | re2 = /^(.+?)([^s])s$/;
88 |
89 | if (re.test(w))
90 | w = w.replace(re,"$1$2");
91 | else if (re2.test(w))
92 | w = w.replace(re2,"$1$2");
93 |
94 | // Step 1b
95 | re = /^(.+?)eed$/;
96 | re2 = /^(.+?)(ed|ing)$/;
97 | if (re.test(w)) {
98 | var fp = re.exec(w);
99 | re = new RegExp(mgr0);
100 | if (re.test(fp[1])) {
101 | re = /.$/;
102 | w = w.replace(re,"");
103 | }
104 | }
105 | else if (re2.test(w)) {
106 | var fp = re2.exec(w);
107 | stem = fp[1];
108 | re2 = new RegExp(s_v);
109 | if (re2.test(stem)) {
110 | w = stem;
111 | re2 = /(at|bl|iz)$/;
112 | re3 = new RegExp("([^aeiouylsz])\\1$");
113 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$");
114 | if (re2.test(w))
115 | w = w + "e";
116 | else if (re3.test(w)) {
117 | re = /.$/;
118 | w = w.replace(re,"");
119 | }
120 | else if (re4.test(w))
121 | w = w + "e";
122 | }
123 | }
124 |
125 | // Step 1c
126 | re = /^(.+?)y$/;
127 | if (re.test(w)) {
128 | var fp = re.exec(w);
129 | stem = fp[1];
130 | re = new RegExp(s_v);
131 | if (re.test(stem))
132 | w = stem + "i";
133 | }
134 |
135 | // Step 2
136 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/;
137 | if (re.test(w)) {
138 | var fp = re.exec(w);
139 | stem = fp[1];
140 | suffix = fp[2];
141 | re = new RegExp(mgr0);
142 | if (re.test(stem))
143 | w = stem + step2list[suffix];
144 | }
145 |
146 | // Step 3
147 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/;
148 | if (re.test(w)) {
149 | var fp = re.exec(w);
150 | stem = fp[1];
151 | suffix = fp[2];
152 | re = new RegExp(mgr0);
153 | if (re.test(stem))
154 | w = stem + step3list[suffix];
155 | }
156 |
157 | // Step 4
158 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/;
159 | re2 = /^(.+?)(s|t)(ion)$/;
160 | if (re.test(w)) {
161 | var fp = re.exec(w);
162 | stem = fp[1];
163 | re = new RegExp(mgr1);
164 | if (re.test(stem))
165 | w = stem;
166 | }
167 | else if (re2.test(w)) {
168 | var fp = re2.exec(w);
169 | stem = fp[1] + fp[2];
170 | re2 = new RegExp(mgr1);
171 | if (re2.test(stem))
172 | w = stem;
173 | }
174 |
175 | // Step 5
176 | re = /^(.+?)e$/;
177 | if (re.test(w)) {
178 | var fp = re.exec(w);
179 | stem = fp[1];
180 | re = new RegExp(mgr1);
181 | re2 = new RegExp(meq1);
182 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$");
183 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem))))
184 | w = stem;
185 | }
186 | re = /ll$/;
187 | re2 = new RegExp(mgr1);
188 | if (re.test(w) && re2.test(w)) {
189 | re = /.$/;
190 | w = w.replace(re,"");
191 | }
192 |
193 | // and turn initial Y back to y
194 | if (firstch == "y")
195 | w = firstch.toLowerCase() + w.substr(1);
196 | return w;
197 | }
198 | }
199 |
200 |
--------------------------------------------------------------------------------
/docs/api/_static/minus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/api/_static/minus.png
--------------------------------------------------------------------------------
/docs/api/_static/plus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/api/_static/plus.png
--------------------------------------------------------------------------------
/docs/api/_static/pygments.css:
--------------------------------------------------------------------------------
1 | pre { line-height: 125%; }
2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
6 | .highlight .hll { background-color: #ffffcc }
7 | .highlight { background: #f8f8f8; }
8 | .highlight .c { color: #8f5902; font-style: italic } /* Comment */
9 | .highlight .err { color: #a40000; border: 1px solid #ef2929 } /* Error */
10 | .highlight .g { color: #000000 } /* Generic */
11 | .highlight .k { color: #004461; font-weight: bold } /* Keyword */
12 | .highlight .l { color: #000000 } /* Literal */
13 | .highlight .n { color: #000000 } /* Name */
14 | .highlight .o { color: #582800 } /* Operator */
15 | .highlight .x { color: #000000 } /* Other */
16 | .highlight .p { color: #000000; font-weight: bold } /* Punctuation */
17 | .highlight .ch { color: #8f5902; font-style: italic } /* Comment.Hashbang */
18 | .highlight .cm { color: #8f5902; font-style: italic } /* Comment.Multiline */
19 | .highlight .cp { color: #8f5902 } /* Comment.Preproc */
20 | .highlight .cpf { color: #8f5902; font-style: italic } /* Comment.PreprocFile */
21 | .highlight .c1 { color: #8f5902; font-style: italic } /* Comment.Single */
22 | .highlight .cs { color: #8f5902; font-style: italic } /* Comment.Special */
23 | .highlight .gd { color: #a40000 } /* Generic.Deleted */
24 | .highlight .ge { color: #000000; font-style: italic } /* Generic.Emph */
25 | .highlight .ges { color: #000000 } /* Generic.EmphStrong */
26 | .highlight .gr { color: #ef2929 } /* Generic.Error */
27 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */
28 | .highlight .gi { color: #00A000 } /* Generic.Inserted */
29 | .highlight .go { color: #888888 } /* Generic.Output */
30 | .highlight .gp { color: #745334 } /* Generic.Prompt */
31 | .highlight .gs { color: #000000; font-weight: bold } /* Generic.Strong */
32 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */
33 | .highlight .gt { color: #a40000; font-weight: bold } /* Generic.Traceback */
34 | .highlight .kc { color: #004461; font-weight: bold } /* Keyword.Constant */
35 | .highlight .kd { color: #004461; font-weight: bold } /* Keyword.Declaration */
36 | .highlight .kn { color: #004461; font-weight: bold } /* Keyword.Namespace */
37 | .highlight .kp { color: #004461; font-weight: bold } /* Keyword.Pseudo */
38 | .highlight .kr { color: #004461; font-weight: bold } /* Keyword.Reserved */
39 | .highlight .kt { color: #004461; font-weight: bold } /* Keyword.Type */
40 | .highlight .ld { color: #000000 } /* Literal.Date */
41 | .highlight .m { color: #990000 } /* Literal.Number */
42 | .highlight .s { color: #4e9a06 } /* Literal.String */
43 | .highlight .na { color: #c4a000 } /* Name.Attribute */
44 | .highlight .nb { color: #004461 } /* Name.Builtin */
45 | .highlight .nc { color: #000000 } /* Name.Class */
46 | .highlight .no { color: #000000 } /* Name.Constant */
47 | .highlight .nd { color: #888888 } /* Name.Decorator */
48 | .highlight .ni { color: #ce5c00 } /* Name.Entity */
49 | .highlight .ne { color: #cc0000; font-weight: bold } /* Name.Exception */
50 | .highlight .nf { color: #000000 } /* Name.Function */
51 | .highlight .nl { color: #f57900 } /* Name.Label */
52 | .highlight .nn { color: #000000 } /* Name.Namespace */
53 | .highlight .nx { color: #000000 } /* Name.Other */
54 | .highlight .py { color: #000000 } /* Name.Property */
55 | .highlight .nt { color: #004461; font-weight: bold } /* Name.Tag */
56 | .highlight .nv { color: #000000 } /* Name.Variable */
57 | .highlight .ow { color: #004461; font-weight: bold } /* Operator.Word */
58 | .highlight .pm { color: #000000; font-weight: bold } /* Punctuation.Marker */
59 | .highlight .w { color: #f8f8f8 } /* Text.Whitespace */
60 | .highlight .mb { color: #990000 } /* Literal.Number.Bin */
61 | .highlight .mf { color: #990000 } /* Literal.Number.Float */
62 | .highlight .mh { color: #990000 } /* Literal.Number.Hex */
63 | .highlight .mi { color: #990000 } /* Literal.Number.Integer */
64 | .highlight .mo { color: #990000 } /* Literal.Number.Oct */
65 | .highlight .sa { color: #4e9a06 } /* Literal.String.Affix */
66 | .highlight .sb { color: #4e9a06 } /* Literal.String.Backtick */
67 | .highlight .sc { color: #4e9a06 } /* Literal.String.Char */
68 | .highlight .dl { color: #4e9a06 } /* Literal.String.Delimiter */
69 | .highlight .sd { color: #8f5902; font-style: italic } /* Literal.String.Doc */
70 | .highlight .s2 { color: #4e9a06 } /* Literal.String.Double */
71 | .highlight .se { color: #4e9a06 } /* Literal.String.Escape */
72 | .highlight .sh { color: #4e9a06 } /* Literal.String.Heredoc */
73 | .highlight .si { color: #4e9a06 } /* Literal.String.Interpol */
74 | .highlight .sx { color: #4e9a06 } /* Literal.String.Other */
75 | .highlight .sr { color: #4e9a06 } /* Literal.String.Regex */
76 | .highlight .s1 { color: #4e9a06 } /* Literal.String.Single */
77 | .highlight .ss { color: #4e9a06 } /* Literal.String.Symbol */
78 | .highlight .bp { color: #3465a4 } /* Name.Builtin.Pseudo */
79 | .highlight .fm { color: #000000 } /* Name.Function.Magic */
80 | .highlight .vc { color: #000000 } /* Name.Variable.Class */
81 | .highlight .vg { color: #000000 } /* Name.Variable.Global */
82 | .highlight .vi { color: #000000 } /* Name.Variable.Instance */
83 | .highlight .vm { color: #000000 } /* Name.Variable.Magic */
84 | .highlight .il { color: #990000 } /* Literal.Number.Integer.Long */
--------------------------------------------------------------------------------
/docs/api/_static/searchtools.js:
--------------------------------------------------------------------------------
1 | /*
2 | * searchtools.js
3 | * ~~~~~~~~~~~~~~~~
4 | *
5 | * Sphinx JavaScript utilities for the full-text search.
6 | *
7 | * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
8 | * :license: BSD, see LICENSE for details.
9 | *
10 | */
11 | "use strict";
12 |
13 | /**
14 | * Simple result scoring code.
15 | */
16 | if (typeof Scorer === "undefined") {
17 | var Scorer = {
18 | // Implement the following function to further tweak the score for each result
19 | // The function takes a result array [docname, title, anchor, descr, score, filename]
20 | // and returns the new score.
21 | /*
22 | score: result => {
23 | const [docname, title, anchor, descr, score, filename] = result
24 | return score
25 | },
26 | */
27 |
28 | // query matches the full name of an object
29 | objNameMatch: 11,
30 | // or matches in the last dotted part of the object name
31 | objPartialMatch: 6,
32 | // Additive scores depending on the priority of the object
33 | objPrio: {
34 | 0: 15, // used to be importantResults
35 | 1: 5, // used to be objectResults
36 | 2: -5, // used to be unimportantResults
37 | },
38 | // Used when the priority is not in the mapping.
39 | objPrioDefault: 0,
40 |
41 | // query found in title
42 | title: 15,
43 | partialTitle: 7,
44 | // query found in terms
45 | term: 5,
46 | partialTerm: 2,
47 | };
48 | }
49 |
50 | const _removeChildren = (element) => {
51 | while (element && element.lastChild) element.removeChild(element.lastChild);
52 | };
53 |
54 | /**
55 | * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Regular_Expressions#escaping
56 | */
57 | const _escapeRegExp = (string) =>
58 | string.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string
59 |
60 | const _displayItem = (item, searchTerms, highlightTerms) => {
61 | const docBuilder = DOCUMENTATION_OPTIONS.BUILDER;
62 | const docFileSuffix = DOCUMENTATION_OPTIONS.FILE_SUFFIX;
63 | const docLinkSuffix = DOCUMENTATION_OPTIONS.LINK_SUFFIX;
64 | const showSearchSummary = DOCUMENTATION_OPTIONS.SHOW_SEARCH_SUMMARY;
65 | const contentRoot = document.documentElement.dataset.content_root;
66 |
67 | const [docName, title, anchor, descr, score, _filename] = item;
68 |
69 | let listItem = document.createElement("li");
70 | let requestUrl;
71 | let linkUrl;
72 | if (docBuilder === "dirhtml") {
73 | // dirhtml builder
74 | let dirname = docName + "/";
75 | if (dirname.match(/\/index\/$/))
76 | dirname = dirname.substring(0, dirname.length - 6);
77 | else if (dirname === "index/") dirname = "";
78 | requestUrl = contentRoot + dirname;
79 | linkUrl = requestUrl;
80 | } else {
81 | // normal html builders
82 | requestUrl = contentRoot + docName + docFileSuffix;
83 | linkUrl = docName + docLinkSuffix;
84 | }
85 | let linkEl = listItem.appendChild(document.createElement("a"));
86 | linkEl.href = linkUrl + anchor;
87 | linkEl.dataset.score = score;
88 | linkEl.innerHTML = title;
89 | if (descr) {
90 | listItem.appendChild(document.createElement("span")).innerHTML =
91 | " (" + descr + ")";
92 | // highlight search terms in the description
93 | if (SPHINX_HIGHLIGHT_ENABLED) // set in sphinx_highlight.js
94 | highlightTerms.forEach((term) => _highlightText(listItem, term, "highlighted"));
95 | }
96 | else if (showSearchSummary)
97 | fetch(requestUrl)
98 | .then((responseData) => responseData.text())
99 | .then((data) => {
100 | if (data)
101 | listItem.appendChild(
102 | Search.makeSearchSummary(data, searchTerms, anchor)
103 | );
104 | // highlight search terms in the summary
105 | if (SPHINX_HIGHLIGHT_ENABLED) // set in sphinx_highlight.js
106 | highlightTerms.forEach((term) => _highlightText(listItem, term, "highlighted"));
107 | });
108 | Search.output.appendChild(listItem);
109 | };
110 | const _finishSearch = (resultCount) => {
111 | Search.stopPulse();
112 | Search.title.innerText = _("Search Results");
113 | if (!resultCount)
114 | Search.status.innerText = Documentation.gettext(
115 | "Your search did not match any documents. Please make sure that all words are spelled correctly and that you've selected enough categories."
116 | );
117 | else
118 | Search.status.innerText = _(
119 | "Search finished, found ${resultCount} page(s) matching the search query."
120 | ).replace('${resultCount}', resultCount);
121 | };
122 | const _displayNextItem = (
123 | results,
124 | resultCount,
125 | searchTerms,
126 | highlightTerms,
127 | ) => {
128 | // results left, load the summary and display it
129 | // this is intended to be dynamic (don't sub resultsCount)
130 | if (results.length) {
131 | _displayItem(results.pop(), searchTerms, highlightTerms);
132 | setTimeout(
133 | () => _displayNextItem(results, resultCount, searchTerms, highlightTerms),
134 | 5
135 | );
136 | }
137 | // search finished, update title and status message
138 | else _finishSearch(resultCount);
139 | };
140 | // Helper function used by query() to order search results.
141 | // Each input is an array of [docname, title, anchor, descr, score, filename].
142 | // Order the results by score (in opposite order of appearance, since the
143 | // `_displayNextItem` function uses pop() to retrieve items) and then alphabetically.
144 | const _orderResultsByScoreThenName = (a, b) => {
145 | const leftScore = a[4];
146 | const rightScore = b[4];
147 | if (leftScore === rightScore) {
148 | // same score: sort alphabetically
149 | const leftTitle = a[1].toLowerCase();
150 | const rightTitle = b[1].toLowerCase();
151 | if (leftTitle === rightTitle) return 0;
152 | return leftTitle > rightTitle ? -1 : 1; // inverted is intentional
153 | }
154 | return leftScore > rightScore ? 1 : -1;
155 | };
156 |
157 | /**
158 | * Default splitQuery function. Can be overridden in ``sphinx.search`` with a
159 | * custom function per language.
160 | *
161 | * The regular expression works by splitting the string on consecutive characters
162 | * that are not Unicode letters, numbers, underscores, or emoji characters.
163 | * This is the same as ``\W+`` in Python, preserving the surrogate pair area.
164 | */
165 | if (typeof splitQuery === "undefined") {
166 | var splitQuery = (query) => query
167 | .split(/[^\p{Letter}\p{Number}_\p{Emoji_Presentation}]+/gu)
168 | .filter(term => term) // remove remaining empty strings
169 | }
170 |
171 | /**
172 | * Search Module
173 | */
174 | const Search = {
175 | _index: null,
176 | _queued_query: null,
177 | _pulse_status: -1,
178 |
179 | htmlToText: (htmlString, anchor) => {
180 | const htmlElement = new DOMParser().parseFromString(htmlString, 'text/html');
181 | for (const removalQuery of [".headerlink", "script", "style"]) {
182 | htmlElement.querySelectorAll(removalQuery).forEach((el) => { el.remove() });
183 | }
184 | if (anchor) {
185 | const anchorContent = htmlElement.querySelector(`[role="main"] ${anchor}`);
186 | if (anchorContent) return anchorContent.textContent;
187 |
188 | console.warn(
189 | `Anchored content block not found. Sphinx search tries to obtain it via DOM query '[role=main] ${anchor}'. Check your theme or template.`
190 | );
191 | }
192 |
193 | // if anchor not specified or not found, fall back to main content
194 | const docContent = htmlElement.querySelector('[role="main"]');
195 | if (docContent) return docContent.textContent;
196 |
197 | console.warn(
198 | "Content block not found. Sphinx search tries to obtain it via DOM query '[role=main]'. Check your theme or template."
199 | );
200 | return "";
201 | },
202 |
203 | init: () => {
204 | const query = new URLSearchParams(window.location.search).get("q");
205 | document
206 | .querySelectorAll('input[name="q"]')
207 | .forEach((el) => (el.value = query));
208 | if (query) Search.performSearch(query);
209 | },
210 |
211 | loadIndex: (url) =>
212 | (document.body.appendChild(document.createElement("script")).src = url),
213 |
214 | setIndex: (index) => {
215 | Search._index = index;
216 | if (Search._queued_query !== null) {
217 | const query = Search._queued_query;
218 | Search._queued_query = null;
219 | Search.query(query);
220 | }
221 | },
222 |
223 | hasIndex: () => Search._index !== null,
224 |
225 | deferQuery: (query) => (Search._queued_query = query),
226 |
227 | stopPulse: () => (Search._pulse_status = -1),
228 |
229 | startPulse: () => {
230 | if (Search._pulse_status >= 0) return;
231 |
232 | const pulse = () => {
233 | Search._pulse_status = (Search._pulse_status + 1) % 4;
234 | Search.dots.innerText = ".".repeat(Search._pulse_status);
235 | if (Search._pulse_status >= 0) window.setTimeout(pulse, 500);
236 | };
237 | pulse();
238 | },
239 |
240 | /**
241 | * perform a search for something (or wait until index is loaded)
242 | */
243 | performSearch: (query) => {
244 | // create the required interface elements
245 | const searchText = document.createElement("h2");
246 | searchText.textContent = _("Searching");
247 | const searchSummary = document.createElement("p");
248 | searchSummary.classList.add("search-summary");
249 | searchSummary.innerText = "";
250 | const searchList = document.createElement("ul");
251 | searchList.classList.add("search");
252 |
253 | const out = document.getElementById("search-results");
254 | Search.title = out.appendChild(searchText);
255 | Search.dots = Search.title.appendChild(document.createElement("span"));
256 | Search.status = out.appendChild(searchSummary);
257 | Search.output = out.appendChild(searchList);
258 |
259 | const searchProgress = document.getElementById("search-progress");
260 | // Some themes don't use the search progress node
261 | if (searchProgress) {
262 | searchProgress.innerText = _("Preparing search...");
263 | }
264 | Search.startPulse();
265 |
266 | // index already loaded, the browser was quick!
267 | if (Search.hasIndex()) Search.query(query);
268 | else Search.deferQuery(query);
269 | },
270 |
271 | _parseQuery: (query) => {
272 | // stem the search terms and add them to the correct list
273 | const stemmer = new Stemmer();
274 | const searchTerms = new Set();
275 | const excludedTerms = new Set();
276 | const highlightTerms = new Set();
277 | const objectTerms = new Set(splitQuery(query.toLowerCase().trim()));
278 | splitQuery(query.trim()).forEach((queryTerm) => {
279 | const queryTermLower = queryTerm.toLowerCase();
280 |
281 | // maybe skip this "word"
282 | // stopwords array is from language_data.js
283 | if (
284 | stopwords.indexOf(queryTermLower) !== -1 ||
285 | queryTerm.match(/^\d+$/)
286 | )
287 | return;
288 |
289 | // stem the word
290 | let word = stemmer.stemWord(queryTermLower);
291 | // select the correct list
292 | if (word[0] === "-") excludedTerms.add(word.substr(1));
293 | else {
294 | searchTerms.add(word);
295 | highlightTerms.add(queryTermLower);
296 | }
297 | });
298 |
299 | if (SPHINX_HIGHLIGHT_ENABLED) { // set in sphinx_highlight.js
300 | localStorage.setItem("sphinx_highlight_terms", [...highlightTerms].join(" "))
301 | }
302 |
303 | // console.debug("SEARCH: searching for:");
304 | // console.info("required: ", [...searchTerms]);
305 | // console.info("excluded: ", [...excludedTerms]);
306 |
307 | return [query, searchTerms, excludedTerms, highlightTerms, objectTerms];
308 | },
309 |
310 | /**
311 | * execute search (requires search index to be loaded)
312 | */
313 | _performSearch: (query, searchTerms, excludedTerms, highlightTerms, objectTerms) => {
314 | const filenames = Search._index.filenames;
315 | const docNames = Search._index.docnames;
316 | const titles = Search._index.titles;
317 | const allTitles = Search._index.alltitles;
318 | const indexEntries = Search._index.indexentries;
319 |
320 | // Collect multiple result groups to be sorted separately and then ordered.
321 | // Each is an array of [docname, title, anchor, descr, score, filename].
322 | const normalResults = [];
323 | const nonMainIndexResults = [];
324 |
325 | _removeChildren(document.getElementById("search-progress"));
326 |
327 | const queryLower = query.toLowerCase().trim();
328 | for (const [title, foundTitles] of Object.entries(allTitles)) {
329 | if (title.toLowerCase().trim().includes(queryLower) && (queryLower.length >= title.length/2)) {
330 | for (const [file, id] of foundTitles) {
331 | const score = Math.round(Scorer.title * queryLower.length / title.length);
332 | const boost = titles[file] === title ? 1 : 0; // add a boost for document titles
333 | normalResults.push([
334 | docNames[file],
335 | titles[file] !== title ? `${titles[file]} > ${title}` : title,
336 | id !== null ? "#" + id : "",
337 | null,
338 | score + boost,
339 | filenames[file],
340 | ]);
341 | }
342 | }
343 | }
344 |
345 | // search for explicit entries in index directives
346 | for (const [entry, foundEntries] of Object.entries(indexEntries)) {
347 | if (entry.includes(queryLower) && (queryLower.length >= entry.length/2)) {
348 | for (const [file, id, isMain] of foundEntries) {
349 | const score = Math.round(100 * queryLower.length / entry.length);
350 | const result = [
351 | docNames[file],
352 | titles[file],
353 | id ? "#" + id : "",
354 | null,
355 | score,
356 | filenames[file],
357 | ];
358 | if (isMain) {
359 | normalResults.push(result);
360 | } else {
361 | nonMainIndexResults.push(result);
362 | }
363 | }
364 | }
365 | }
366 |
367 | // lookup as object
368 | objectTerms.forEach((term) =>
369 | normalResults.push(...Search.performObjectSearch(term, objectTerms))
370 | );
371 |
372 | // lookup as search terms in fulltext
373 | normalResults.push(...Search.performTermsSearch(searchTerms, excludedTerms));
374 |
375 | // let the scorer override scores with a custom scoring function
376 | if (Scorer.score) {
377 | normalResults.forEach((item) => (item[4] = Scorer.score(item)));
378 | nonMainIndexResults.forEach((item) => (item[4] = Scorer.score(item)));
379 | }
380 |
381 | // Sort each group of results by score and then alphabetically by name.
382 | normalResults.sort(_orderResultsByScoreThenName);
383 | nonMainIndexResults.sort(_orderResultsByScoreThenName);
384 |
385 | // Combine the result groups in (reverse) order.
386 | // Non-main index entries are typically arbitrary cross-references,
387 | // so display them after other results.
388 | let results = [...nonMainIndexResults, ...normalResults];
389 |
390 | // remove duplicate search results
391 | // note the reversing of results, so that in the case of duplicates, the highest-scoring entry is kept
392 | let seen = new Set();
393 | results = results.reverse().reduce((acc, result) => {
394 | let resultStr = result.slice(0, 4).concat([result[5]]).map(v => String(v)).join(',');
395 | if (!seen.has(resultStr)) {
396 | acc.push(result);
397 | seen.add(resultStr);
398 | }
399 | return acc;
400 | }, []);
401 |
402 | return results.reverse();
403 | },
404 |
405 | query: (query) => {
406 | const [searchQuery, searchTerms, excludedTerms, highlightTerms, objectTerms] = Search._parseQuery(query);
407 | const results = Search._performSearch(searchQuery, searchTerms, excludedTerms, highlightTerms, objectTerms);
408 |
409 | // for debugging
410 | //Search.lastresults = results.slice(); // a copy
411 | // console.info("search results:", Search.lastresults);
412 |
413 | // print the results
414 | _displayNextItem(results, results.length, searchTerms, highlightTerms);
415 | },
416 |
417 | /**
418 | * search for object names
419 | */
420 | performObjectSearch: (object, objectTerms) => {
421 | const filenames = Search._index.filenames;
422 | const docNames = Search._index.docnames;
423 | const objects = Search._index.objects;
424 | const objNames = Search._index.objnames;
425 | const titles = Search._index.titles;
426 |
427 | const results = [];
428 |
429 | const objectSearchCallback = (prefix, match) => {
430 | const name = match[4]
431 | const fullname = (prefix ? prefix + "." : "") + name;
432 | const fullnameLower = fullname.toLowerCase();
433 | if (fullnameLower.indexOf(object) < 0) return;
434 |
435 | let score = 0;
436 | const parts = fullnameLower.split(".");
437 |
438 | // check for different match types: exact matches of full name or
439 | // "last name" (i.e. last dotted part)
440 | if (fullnameLower === object || parts.slice(-1)[0] === object)
441 | score += Scorer.objNameMatch;
442 | else if (parts.slice(-1)[0].indexOf(object) > -1)
443 | score += Scorer.objPartialMatch; // matches in last name
444 |
445 | const objName = objNames[match[1]][2];
446 | const title = titles[match[0]];
447 |
448 | // If more than one term searched for, we require other words to be
449 | // found in the name/title/description
450 | const otherTerms = new Set(objectTerms);
451 | otherTerms.delete(object);
452 | if (otherTerms.size > 0) {
453 | const haystack = `${prefix} ${name} ${objName} ${title}`.toLowerCase();
454 | if (
455 | [...otherTerms].some((otherTerm) => haystack.indexOf(otherTerm) < 0)
456 | )
457 | return;
458 | }
459 |
460 | let anchor = match[3];
461 | if (anchor === "") anchor = fullname;
462 | else if (anchor === "-") anchor = objNames[match[1]][1] + "-" + fullname;
463 |
464 | const descr = objName + _(", in ") + title;
465 |
466 | // add custom score for some objects according to scorer
467 | if (Scorer.objPrio.hasOwnProperty(match[2]))
468 | score += Scorer.objPrio[match[2]];
469 | else score += Scorer.objPrioDefault;
470 |
471 | results.push([
472 | docNames[match[0]],
473 | fullname,
474 | "#" + anchor,
475 | descr,
476 | score,
477 | filenames[match[0]],
478 | ]);
479 | };
480 | Object.keys(objects).forEach((prefix) =>
481 | objects[prefix].forEach((array) =>
482 | objectSearchCallback(prefix, array)
483 | )
484 | );
485 | return results;
486 | },
487 |
488 | /**
489 | * search for full-text terms in the index
490 | */
491 | performTermsSearch: (searchTerms, excludedTerms) => {
492 | // prepare search
493 | const terms = Search._index.terms;
494 | const titleTerms = Search._index.titleterms;
495 | const filenames = Search._index.filenames;
496 | const docNames = Search._index.docnames;
497 | const titles = Search._index.titles;
498 |
499 | const scoreMap = new Map();
500 | const fileMap = new Map();
501 |
502 | // perform the search on the required terms
503 | searchTerms.forEach((word) => {
504 | const files = [];
505 | const arr = [
506 | { files: terms[word], score: Scorer.term },
507 | { files: titleTerms[word], score: Scorer.title },
508 | ];
509 | // add support for partial matches
510 | if (word.length > 2) {
511 | const escapedWord = _escapeRegExp(word);
512 | if (!terms.hasOwnProperty(word)) {
513 | Object.keys(terms).forEach((term) => {
514 | if (term.match(escapedWord))
515 | arr.push({ files: terms[term], score: Scorer.partialTerm });
516 | });
517 | }
518 | if (!titleTerms.hasOwnProperty(word)) {
519 | Object.keys(titleTerms).forEach((term) => {
520 | if (term.match(escapedWord))
521 | arr.push({ files: titleTerms[term], score: Scorer.partialTitle });
522 | });
523 | }
524 | }
525 |
526 | // no match but word was a required one
527 | if (arr.every((record) => record.files === undefined)) return;
528 |
529 | // found search word in contents
530 | arr.forEach((record) => {
531 | if (record.files === undefined) return;
532 |
533 | let recordFiles = record.files;
534 | if (recordFiles.length === undefined) recordFiles = [recordFiles];
535 | files.push(...recordFiles);
536 |
537 | // set score for the word in each file
538 | recordFiles.forEach((file) => {
539 | if (!scoreMap.has(file)) scoreMap.set(file, {});
540 | scoreMap.get(file)[word] = record.score;
541 | });
542 | });
543 |
544 | // create the mapping
545 | files.forEach((file) => {
546 | if (!fileMap.has(file)) fileMap.set(file, [word]);
547 | else if (fileMap.get(file).indexOf(word) === -1) fileMap.get(file).push(word);
548 | });
549 | });
550 |
551 | // now check if the files don't contain excluded terms
552 | const results = [];
553 | for (const [file, wordList] of fileMap) {
554 | // check if all requirements are matched
555 |
556 | // as search terms with length < 3 are discarded
557 | const filteredTermCount = [...searchTerms].filter(
558 | (term) => term.length > 2
559 | ).length;
560 | if (
561 | wordList.length !== searchTerms.size &&
562 | wordList.length !== filteredTermCount
563 | )
564 | continue;
565 |
566 | // ensure that none of the excluded terms is in the search result
567 | if (
568 | [...excludedTerms].some(
569 | (term) =>
570 | terms[term] === file ||
571 | titleTerms[term] === file ||
572 | (terms[term] || []).includes(file) ||
573 | (titleTerms[term] || []).includes(file)
574 | )
575 | )
576 | break;
577 |
578 | // select one (max) score for the file.
579 | const score = Math.max(...wordList.map((w) => scoreMap.get(file)[w]));
580 | // add result to the result list
581 | results.push([
582 | docNames[file],
583 | titles[file],
584 | "",
585 | null,
586 | score,
587 | filenames[file],
588 | ]);
589 | }
590 | return results;
591 | },
592 |
593 | /**
594 | * helper function to return a node containing the
595 | * search summary for a given text. keywords is a list
596 | * of stemmed words.
597 | */
598 | makeSearchSummary: (htmlText, keywords, anchor) => {
599 | const text = Search.htmlToText(htmlText, anchor);
600 | if (text === "") return null;
601 |
602 | const textLower = text.toLowerCase();
603 | const actualStartPosition = [...keywords]
604 | .map((k) => textLower.indexOf(k.toLowerCase()))
605 | .filter((i) => i > -1)
606 | .slice(-1)[0];
607 | const startWithContext = Math.max(actualStartPosition - 120, 0);
608 |
609 | const top = startWithContext === 0 ? "" : "...";
610 | const tail = startWithContext + 240 < text.length ? "..." : "";
611 |
612 | let summary = document.createElement("p");
613 | summary.classList.add("context");
614 | summary.textContent = top + text.substr(startWithContext, 240).trim() + tail;
615 |
616 | return summary;
617 | },
618 | };
619 |
620 | _ready(Search.init);
621 |
--------------------------------------------------------------------------------
/docs/api/_static/sphinx_highlight.js:
--------------------------------------------------------------------------------
1 | /* Highlighting utilities for Sphinx HTML documentation. */
2 | "use strict";
3 |
4 | const SPHINX_HIGHLIGHT_ENABLED = true
5 |
6 | /**
7 | * highlight a given string on a node by wrapping it in
8 | * span elements with the given class name.
9 | */
10 | const _highlight = (node, addItems, text, className) => {
11 | if (node.nodeType === Node.TEXT_NODE) {
12 | const val = node.nodeValue;
13 | const parent = node.parentNode;
14 | const pos = val.toLowerCase().indexOf(text);
15 | if (
16 | pos >= 0 &&
17 | !parent.classList.contains(className) &&
18 | !parent.classList.contains("nohighlight")
19 | ) {
20 | let span;
21 |
22 | const closestNode = parent.closest("body, svg, foreignObject");
23 | const isInSVG = closestNode && closestNode.matches("svg");
24 | if (isInSVG) {
25 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan");
26 | } else {
27 | span = document.createElement("span");
28 | span.classList.add(className);
29 | }
30 |
31 | span.appendChild(document.createTextNode(val.substr(pos, text.length)));
32 | const rest = document.createTextNode(val.substr(pos + text.length));
33 | parent.insertBefore(
34 | span,
35 | parent.insertBefore(
36 | rest,
37 | node.nextSibling
38 | )
39 | );
40 | node.nodeValue = val.substr(0, pos);
41 | /* There may be more occurrences of search term in this node. So call this
42 | * function recursively on the remaining fragment.
43 | */
44 | _highlight(rest, addItems, text, className);
45 |
46 | if (isInSVG) {
47 | const rect = document.createElementNS(
48 | "http://www.w3.org/2000/svg",
49 | "rect"
50 | );
51 | const bbox = parent.getBBox();
52 | rect.x.baseVal.value = bbox.x;
53 | rect.y.baseVal.value = bbox.y;
54 | rect.width.baseVal.value = bbox.width;
55 | rect.height.baseVal.value = bbox.height;
56 | rect.setAttribute("class", className);
57 | addItems.push({ parent: parent, target: rect });
58 | }
59 | }
60 | } else if (node.matches && !node.matches("button, select, textarea")) {
61 | node.childNodes.forEach((el) => _highlight(el, addItems, text, className));
62 | }
63 | };
64 | const _highlightText = (thisNode, text, className) => {
65 | let addItems = [];
66 | _highlight(thisNode, addItems, text, className);
67 | addItems.forEach((obj) =>
68 | obj.parent.insertAdjacentElement("beforebegin", obj.target)
69 | );
70 | };
71 |
72 | /**
73 | * Small JavaScript module for the documentation.
74 | */
75 | const SphinxHighlight = {
76 |
77 | /**
78 | * highlight the search words provided in localstorage in the text
79 | */
80 | highlightSearchWords: () => {
81 | if (!SPHINX_HIGHLIGHT_ENABLED) return; // bail if no highlight
82 |
83 | // get and clear terms from localstorage
84 | const url = new URL(window.location);
85 | const highlight =
86 | localStorage.getItem("sphinx_highlight_terms")
87 | || url.searchParams.get("highlight")
88 | || "";
89 | localStorage.removeItem("sphinx_highlight_terms")
90 | url.searchParams.delete("highlight");
91 | window.history.replaceState({}, "", url);
92 |
93 | // get individual terms from highlight string
94 | const terms = highlight.toLowerCase().split(/\s+/).filter(x => x);
95 | if (terms.length === 0) return; // nothing to do
96 |
97 | // There should never be more than one element matching "div.body"
98 | const divBody = document.querySelectorAll("div.body");
99 | const body = divBody.length ? divBody[0] : document.querySelector("body");
100 | window.setTimeout(() => {
101 | terms.forEach((term) => _highlightText(body, term, "highlighted"));
102 | }, 10);
103 |
104 | const searchBox = document.getElementById("searchbox");
105 | if (searchBox === null) return;
106 | searchBox.appendChild(
107 | document
108 | .createRange()
109 | .createContextualFragment(
110 |           '<p class="highlight-link">' +
111 |             '<a href="javascript:SphinxHighlight.hideSearchWords()">' +
112 |             _("Hide Search Matches") +
113 |             "</a></p>"
114 | )
115 | );
116 | },
117 |
118 | /**
119 | * helper function to hide the search marks again
120 | */
121 | hideSearchWords: () => {
122 | document
123 | .querySelectorAll("#searchbox .highlight-link")
124 | .forEach((el) => el.remove());
125 | document
126 | .querySelectorAll("span.highlighted")
127 | .forEach((el) => el.classList.remove("highlighted"));
128 | localStorage.removeItem("sphinx_highlight_terms")
129 | },
130 |
131 | initEscapeListener: () => {
132 | // only install a listener if it is really needed
133 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return;
134 |
135 | document.addEventListener("keydown", (event) => {
136 | // bail for input elements
137 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return;
138 | // bail with special keys
139 | if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return;
140 | if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) {
141 | SphinxHighlight.hideSearchWords();
142 | event.preventDefault();
143 | }
144 | });
145 | },
146 | };
147 |
148 | _ready(() => {
149 | /* Do not call highlightSearchWords() when we are on the search page.
150 | * It will highlight words from the *previous* search query.
151 | */
152 | if (typeof Search === "undefined") SphinxHighlight.highlightSearchWords();
153 | SphinxHighlight.initEscapeListener();
154 | });
155 |
--------------------------------------------------------------------------------
/docs/api/genindex.html:
--------------------------------------------------------------------------------
[Sphinx-generated index page; HTML markup stripped in this dump. Recoverable content: title "Index — LassoNet documentation", an "Index" heading, and entry groups under the letters F, G, L, P and S.]
--------------------------------------------------------------------------------
/docs/api/objects.inv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/api/objects.inv
--------------------------------------------------------------------------------
/docs/api/search.html:
--------------------------------------------------------------------------------
[Sphinx-generated search page; HTML markup stripped in this dump. Recoverable content: title "Search — LassoNet documentation", a "Search" heading, the noscript notice "Please activate JavaScript to enable the search functionality.", and the note "Searching for multiple words only shows matches that contain all words."]
--------------------------------------------------------------------------------
/docs/api/searchindex.js:
--------------------------------------------------------------------------------
1 | Search.setIndex({"alltitles": {"API": [[0, "api"]], "Installation": [[0, "installation"]], "Welcome to LassoNet\u2019s documentation!": [[0, null]]}, "docnames": ["index"], "envversion": {"sphinx": 63, "sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2}, "filenames": ["index.rst"], "indexentries": {"fit() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.fit", false]], "fit() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.fit", false]], "fit() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.fit", false]], "fit() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.fit", false]], "fit() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.fit", false]], "fit() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.fit", false]], "get_metadata_routing() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.get_metadata_routing", false]], "get_metadata_routing() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.get_metadata_routing", false]], "get_metadata_routing() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.get_metadata_routing", false]], "get_metadata_routing() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.get_metadata_routing", false]], "get_metadata_routing() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.get_metadata_routing", false]], "get_metadata_routing() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.get_metadata_routing", false]], "get_params() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.get_params", false]], "get_params() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.get_params", false]], "get_params() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.get_params", false]], "get_params() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.get_params", false]], "get_params() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.get_params", false]], "get_params() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.get_params", false]], "lassonet_path() (in module lassonet)": [[0, "lassonet.lassonet_path", false]], "lassonetclassifier (class in lassonet)": [[0, "lassonet.LassoNetClassifier", false]], "lassonetclassifiercv (class in lassonet)": [[0, "lassonet.LassoNetClassifierCV", false]], "lassonetcoxregressor (class in lassonet)": [[0, "lassonet.LassoNetCoxRegressor", false]], "lassonetcoxregressorcv (class in lassonet)": [[0, "lassonet.LassoNetCoxRegressorCV", false]], "lassonetregressor (class in lassonet)": [[0, "lassonet.LassoNetRegressor", false]], "lassonetregressorcv (class in lassonet)": [[0, "lassonet.LassoNetRegressorCV", false]], "path() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.path", false]], "path() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.path", false]], "path() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.path", false]], "path() (lassonet.lassonetcoxregressorcv method)": [[0, 
"lassonet.LassoNetCoxRegressorCV.path", false]], "path() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.path", false]], "path() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.path", false]], "plot_path() (in module lassonet)": [[0, "lassonet.plot_path", false]], "score() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.score", false]], "score() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.score", false]], "score() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.score", false]], "score() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.score", false]], "score() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.score", false]], "score() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.score", false]], "set_fit_request() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.set_fit_request", false]], "set_fit_request() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.set_fit_request", false]], "set_fit_request() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.set_fit_request", false]], "set_fit_request() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.set_fit_request", false]], "set_fit_request() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.set_fit_request", false]], "set_fit_request() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.set_fit_request", false]], "set_params() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.set_params", false]], "set_params() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.set_params", false]], "set_params() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.set_params", false]], "set_params() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.set_params", false]], "set_params() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.set_params", false]], "set_params() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.set_params", false]], "set_score_request() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.set_score_request", false]], "set_score_request() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.set_score_request", false]], "set_score_request() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.set_score_request", false]], "set_score_request() (lassonet.lassonetcoxregressorcv method)": [[0, "lassonet.LassoNetCoxRegressorCV.set_score_request", false]], "set_score_request() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.set_score_request", false]], "set_score_request() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.set_score_request", false]], "stability_selection() (lassonet.lassonetclassifier method)": [[0, "lassonet.LassoNetClassifier.stability_selection", false]], "stability_selection() (lassonet.lassonetclassifiercv method)": [[0, "lassonet.LassoNetClassifierCV.stability_selection", false]], "stability_selection() (lassonet.lassonetcoxregressor method)": [[0, "lassonet.LassoNetCoxRegressor.stability_selection", false]], "stability_selection() (lassonet.lassonetcoxregressorcv method)": [[0, 
"lassonet.LassoNetCoxRegressorCV.stability_selection", false]], "stability_selection() (lassonet.lassonetregressor method)": [[0, "lassonet.LassoNetRegressor.stability_selection", false]], "stability_selection() (lassonet.lassonetregressorcv method)": [[0, "lassonet.LassoNetRegressorCV.stability_selection", false]]}, "objects": {"lassonet": [[0, 0, 1, "", "LassoNetClassifier"], [0, 0, 1, "", "LassoNetClassifierCV"], [0, 0, 1, "", "LassoNetCoxRegressor"], [0, 0, 1, "", "LassoNetCoxRegressorCV"], [0, 0, 1, "", "LassoNetRegressor"], [0, 0, 1, "", "LassoNetRegressorCV"], [0, 2, 1, "", "lassonet_path"], [0, 2, 1, "", "plot_path"]], "lassonet.LassoNetClassifier": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetClassifierCV": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetCoxRegressor": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetCoxRegressorCV": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetRegressor": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]], "lassonet.LassoNetRegressorCV": [[0, 1, 1, "", "fit"], [0, 1, 1, "", "get_metadata_routing"], [0, 1, 1, "", "get_params"], [0, 1, 1, "", "path"], [0, 1, 1, "", "score"], [0, 1, 1, "", "set_fit_request"], [0, 1, 1, "", "set_params"], [0, 1, 1, "", "set_score_request"], [0, 1, 1, "", "stability_selection"]]}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "method", "Python method"], "2": ["py", "function", "Python function"]}, "objtypes": {"0": "py:class", "1": "py:method", "2": "py:function"}, "terms": {"0": 0, "02": 0, "1": 0, "10": 0, "100": 0, "1000": 0, "1e": 0, "2": 0, "20": 0, "23": 0, "3": 0, "5": 0, "9": 0, "99": 0, "A": 0, "For": 0, "If": 0, "In": 0, "The": 0, "There": 0, "To": 0, "__": 0, "accord": 0, "accuraci": 0, "ad": 0, "adam": 0, "alia": 0, "all": 0, "allow": 0, "alwai": 0, "an": 0, "ani": 0, "anymor": 0, "approxim": 0, "ar": 0, "arbitrarili": 0, "arrai": 0, "auto": 0, "automat": 0, "avail": 0, "backtrack": 0, "baselassonet": 0, "batch": 0, "batch_siz": 0, "becaus": 0, "beforehand": 0, "being": 0, "best": 0, "bool": 0, "bound": 0, "breslow": 0, "call": 0, "callback": 0, "can": 0, "chang": 0, "check": 0, "check_cv": 0, "class": 0, "class_weight": 0, "classif": 0, "classifi": 0, "coeffici": 0, "compon": 0, "comput": 0, "concord": 0, "connect": 0, "consist": 0, "constant": 0, "contain": 0, "correctli": 0, "cox": 0, "cpu": 0, 
"cross": 0, "cv": 0, "data": 0, "decreas": 0, "deep": 0, "default": 0, "defin": 0, "dens": 0, "dense_onli": 0, "determin": 0, "devic": 0, "dict": 0, "differ": 0, "dimens": 0, "disabl": 0, "disable_lambda_warn": 0, "disregard": 0, "doe": 0, "dropout": 0, "durat": 0, "dure": 0, "e": 0, "each": 0, "earli": 0, "effect": 0, "efron": 0, "els": 0, "enable_metadata_rout": 0, "encapsul": 0, "ensur": 0, "epoch": 0, "epsilon": 0, "error": 0, "estim": 0, "event": 0, "evolut": 0, "except": 0, "exist": 0, "expect": 0, "factor": 0, "fals": 0, "featur": 0, "first": 0, "fit": 0, "float": 0, "fold": 0, "form": 0, "frac": 0, "from": 0, "function": 0, "g": 0, "gamma": 0, "gamma_skip": 0, "gener": 0, "get": 0, "get_metadata_rout": 0, "get_param": 0, "given": 0, "go": 0, "gpu": 0, "group": 0, "guid": 0, "ha": 0, "harsh": 0, "have": 0, "hidden": 0, "hidden_dim": 0, "hierarchi": 0, "histori": 0, "historyitem": 0, "how": 0, "html": 0, "http": 0, "i": 0, "ideal": 0, "ignor": 0, "improv": 0, "increas": 0, "increment": 0, "index": 0, "indic": 0, "inf": 0, "influenc": 0, "inform": 0, "initi": 0, "input": 0, "insid": 0, "instanc": 0, "instead": 0, "int": 0, "iter": 0, "keep": 0, "kernel": 0, "kwarg": 0, "l2": 0, "label": 0, "lambda": 0, "lambda_": 0, "lambda_max": 0, "lambda_seq": 0, "lambda_start": 0, "lassonet_path": 0, "lassonetclassifi": 0, "lassonetclassifiercv": 0, "lassonetcoxregressor": 0, "lassonetcoxregressorcv": 0, "lassonetregressor": 0, "lassonetregressorcv": 0, "latter": 0, "layer": 0, "learn": 0, "leav": 0, "like": 0, "list": 0, "longtensor": 0, "lr": 0, "m": 0, "mai": 0, "main": 0, "map": 0, "matrix": 0, "maximum": 0, "mean": 0, "mechan": 0, "meta": 0, "metadata": 0, "metadata_rout": 0, "metadatarequest": 0, "method": 0, "metric": 0, "minimum": 0, "model": 0, "model_select": 0, "modul": 0, "momentum": 0, "most": 0, "multi": 0, "multioutput": 0, "multioutputregressor": 0, "multipl": 0, "must": 0, "n_featur": 0, "n_iter": 0, "n_model": 0, "n_output": 0, "n_sampl": 0, "n_samples_fit": 0, "n_step": 0, "name": 0, "neg": 0, "nest": 0, "network": 0, "new": 0, "none": 0, "note": 0, "number": 0, "object": 0, "old": 0, "one": 0, "onli": 0, "optim": 0, "option": 0, "oracl": 0, "order": 0, "org": 0, "origin": 0, "other": 0, "otherwis": 0, "output": [], "over": 0, "pair": 0, "param": 0, "paramet": 0, "pass": 0, "path": 0, "path_multipli": 0, "patienc": 0, "penal": 0, "penalti": 0, "per": 0, "pip": 0, "pipelin": 0, "pleas": 0, "plot": 0, "plot_path": 0, "possibl": 0, "precomput": 0, "predict": 0, "prob": 0, "probabl": 0, "proport": 0, "provid": 0, "pytorch": 0, "r": 0, "r2_score": 0, "rais": 0, "random": 0, "random_st": 0, "regress": 0, "regressor": 0, "regular": 0, "relev": 0, "request": 0, "requir": 0, "residu": 0, "retain": 0, "return": 0, "return_state_dict": 0, "rout": 0, "sampl": 0, "sample_weight": 0, "scikit": 0, "score": 0, "score_funct": 0, "see": 0, "select": 0, "selection_prob": 0, "self": 0, "sequenc": 0, "set": 0, "set_config": 0, "set_fit_request": 0, "set_param": 0, "set_score_request": 0, "sgd": 0, "shape": 0, "should": 0, "shuffl": 0, "simpl": 0, "sinc": 0, "skip": 0, "sklearn": 0, "so": 0, "some": 0, "specifi": 0, "split": 0, "squar": 0, "stabil": 0, "stability_select": 0, "stabl": 0, "start": 0, "state": 0, "step": 0, "stop": 0, "str": 0, "strategi": 0, "sub": 0, "subobject": 0, "subset": 0, "sum": 0, "t": 0, "take": 0, "target": 0, "task": 0, "tensor": 0, "test": 0, "th": 0, "thi": 0, "tie": 0, "tie_approxim": 0, "tol": 0, "torch": 0, "torch_se": 0, "total": 0, "train": 0, "true": 0, "tupl": 0, 
"two": 0, "type": 0, "u": 0, "unchang": 0, "uniform_averag": 0, "until": 0, "updat": 0, "upper": 0, "us": 0, "user": 0, "util": 0, "v": 0, "val_siz": 0, "valid": 0, "valu": 0, "variabl": 0, "verbos": 0, "version": 0, "w": 0, "wait": 0, "websit": 0, "weight": 0, "well": 0, "when": 0, "where": 0, "which": 0, "without": 0, "work": 0, "wors": 0, "would": 0, "wrong": 0, "x": 0, "x_test": 0, "x_val": 0, "y": 0, "y_pred": 0, "y_test": 0, "y_true": 0, "y_val": 0, "you": 0, "zero": 0}, "titles": ["Welcome to LassoNet\u2019s documentation!"], "titleterms": {"": 0, "api": 0, "document": 0, "instal": 0, "lassonet": 0, "welcom": 0}})
--------------------------------------------------------------------------------
/docs/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/fig1.png
--------------------------------------------------------------------------------
/docs/images/video_screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/docs/images/video_screenshot.png
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
[Project landing page; most HTML markup, scripts and navigation were stripped in this dump. The readable text follows.]

LassoNet: Neural Networks with Feature Sparsity
LassoNet is a method for feature selection in neural networks that enhances the interpretability of the final network.

It uses a novel objective function and learning algorithm that encourage the network to use only a subset of the available input features; that is, the resulting network is "feature sparse". This is achieved not by post-hoc analysis of a standard neural network but is built into the objective function itself:

- Input-to-output (skip layer) connections are added to the network, with an L1 penalty on their weights.
- The weight for each feature in this skip layer acts as an upper bound for all hidden-layer weights involving that feature.

The result is an entire path of network solutions with varying amounts of feature sparsity, analogous to the lasso solution path for linear regression.
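For illustration, here is a minimal sketch of computing such a path with this package. The estimator and plotting names appear in this repo's API index, but the exact signatures, parameters and toy data below are assumptions:

    import numpy as np
    from lassonet import LassoNetRegressor, plot_path

    # Toy data: only the first two of 20 features are informative.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 20))
    y = (X[:, 0] - 2 * X[:, 1] + 0.1 * rng.normal(size=200)).reshape(-1, 1)

    model = LassoNetRegressor(hidden_dims=(16,))
    path = model.path(X, y)  # checkpoints from the dense network down to zero features
    plot_path(model, X, y)   # assumed signature: plots score against amount of sparsity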
LassoNet in 2 minutes

[Embedded video stripped in this dump.]

Installation

    pip install lassonet
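Once installed, a minimal end-to-end sketch looks like this. The class names come from this repo's API index; the dataset and defaults are illustrative assumptions:

    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from lassonet import LassoNetClassifierCV

    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # The CV variant fits a path and keeps the model selected by cross-validation.
    model = LassoNetClassifierCV()
    model.fit(X_train, y_train)
    print("test accuracy:", model.score(X_test, y_test))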
Tips

LassoNet sometimes requires fine-tuning. For optimal performance, consider the following (a sketch follows the list):

- standardizing the inputs;
- making sure that the initial dense model (with $\lambda = 0$) has trained well before starting the LassoNet regularization path. This may involve hyper-parameter tuning, choosing the right optimizer, and so on. If the dense model underperforms, the sparser models likely will too;
- making sure the step size over the $\lambda$ path is not too large. By default, $\lambda$ grows in geometric increments until no feature is left.
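A sketch of how those tips translate into code, continuing the train/test split above. The `path_multiplier` name appears in this repo's API vocabulary, but the values and the scaling recipe here are assumptions:

    from sklearn.preprocessing import StandardScaler
    from lassonet import LassoNetClassifier

    # Standardize using training statistics only; reuse them on the test set.
    scaler = StandardScaler()
    X_train_s = scaler.fit_transform(X_train)
    X_test_s = scaler.transform(X_test)

    model = LassoNetClassifier(
        hidden_dims=(32,),
        path_multiplier=1.01,  # smaller multiplier = finer (but slower) lambda grid
    )
    path = model.path(X_train_s, y_train)
    # Check the first (dense, lambda ~ 0) checkpoint before trusting sparser ones.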
Intro video

[Embedded video stripped in this dump.]

Talk

[Embedded video stripped in this dump.]

Citation

The algorithms and methods used in this package came primarily out of research in Rob Tibshirani's lab at Stanford University. If you use LassoNet in your research, we would appreciate a citation to the paper:

    @article{lemhadri2019neural,
      title={LassoNet: Neural Networks with Feature Sparsity},
      author={Lemhadri, Ismael and Ruan, Feng and
              Abraham, Louis and Tibshirani, Robert},
      journal={arXiv preprint arXiv:1907.12207},
      year={2019}
    }
--------------------------------------------------------------------------------
/docs/style.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif;
3 | font-weight: 300;
4 | font-size: 18px;
5 | margin-left: auto;
6 | margin-right: auto;
7 | width: 1100px;
8 | }
9 |
10 | h1 {
11 | font-weight: 300;
12 | }
13 |
14 | .disclaimerbox {
15 | background-color: #eee;
16 | border: 1px solid #eeeeee;
17 | border-radius: 10px;
18 | -moz-border-radius: 10px;
19 | -webkit-border-radius: 10px;
20 | padding: 20px;
21 | }
22 |
23 | video.header-vid {
24 | height: 140px;
25 | border: 1px solid black;
26 | border-radius: 10px;
27 | -moz-border-radius: 10px;
28 | -webkit-border-radius: 10px;
29 | }
30 |
31 | img.header-img {
32 | height: 140px;
33 | border: 1px solid black;
34 | border-radius: 10px;
35 | -moz-border-radius: 10px;
36 | -webkit-border-radius: 10px;
37 | }
38 |
39 | img.rounded {
40 | border: 1px solid #eeeeee;
41 | border-radius: 10px;
42 | -moz-border-radius: 10px;
43 | -webkit-border-radius: 10px;
44 | }
45 |
46 | a:link,
47 | a:visited {
48 | color: #1367a7;
49 | text-decoration: none;
50 | }
51 |
52 | a:hover {
53 | color: #208799;
54 | }
55 |
56 | td.dl-link {
57 | height: 160px;
58 | text-align: center;
59 | font-size: 22px;
60 | }
61 |
62 | .layered-paper-big {
63 | /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
64 | box-shadow:
65 | 0px 0px 1px 1px rgba(0, 0, 0, 0.35),
66 | /* The top layer shadow */
67 | 5px 5px 0 0px #fff,
68 | /* The second layer */
69 | 5px 5px 1px 1px rgba(0, 0, 0, 0.35),
70 | /* The second layer shadow */
71 | 10px 10px 0 0px #fff,
72 | /* The third layer */
73 | 10px 10px 1px 1px rgba(0, 0, 0, 0.35),
74 | /* The third layer shadow */
75 | 15px 15px 0 0px #fff,
76 | /* The fourth layer */
77 | 15px 15px 1px 1px rgba(0, 0, 0, 0.35),
78 | /* The fourth layer shadow */
79 | 20px 20px 0 0px #fff,
80 | /* The fifth layer */
81 | 20px 20px 1px 1px rgba(0, 0, 0, 0.35),
82 | /* The fifth layer shadow */
83 | 25px 25px 0 0px #fff,
84 | /* The sixth layer */
85 | 25px 25px 1px 1px rgba(0, 0, 0, 0.35);
86 | /* The sixth layer shadow */
87 | margin-left: 10px;
88 | margin-right: 45px;
89 | }
90 |
91 | .paper-big {
92 | /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
93 | box-shadow:
94 | 0px 0px 1px 1px rgba(0, 0, 0, 0.35);
95 | /* The top layer shadow */
96 |
97 | margin-left: 10px;
98 | margin-right: 45px;
99 | }
100 |
101 |
102 | .layered-paper {
103 | /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
104 | box-shadow:
105 | 0px 0px 1px 1px rgba(0, 0, 0, 0.35),
106 | /* The top layer shadow */
107 | 5px 5px 0 0px #fff,
108 | /* The second layer */
109 | 5px 5px 1px 1px rgba(0, 0, 0, 0.35),
110 | /* The second layer shadow */
111 | 10px 10px 0 0px #fff,
112 | /* The third layer */
113 | 10px 10px 1px 1px rgba(0, 0, 0, 0.35);
114 | /* The third layer shadow */
115 | margin-top: 5px;
116 | margin-left: 10px;
117 | margin-right: 30px;
118 | margin-bottom: 5px;
119 | }
120 |
121 | .vert-cent {
122 | position: relative;
123 | top: 50%;
124 | transform: translateY(-50%);
125 | }
126 |
127 | hr {
128 | border: 0;
129 | height: 1px;
130 | background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
131 | }
132 |
133 | .center {
134 | margin: auto;
135 | text-align: center;
136 | }
137 |
138 | img.center {
139 | display: block;
140 |
141 | }
142 |
143 | h1 {
144 | text-align: center;
145 | font-size: 42px;
146 | }
147 |
148 | #authors {
149 | font-size: 24px;
150 | width: 800px;
151 | }
152 |
153 | #authors td {
154 | width: 100px;
155 | }
156 |
157 | #links td {
158 | font-size: 24px;
159 | width: 130px;
160 | }
161 |
162 |
163 | iframe {
164 | width: 792px;
165 | height: 328px;
166 | display: block;
167 | border-style: none;
168 | margin: 0 auto;
169 |
170 | }
--------------------------------------------------------------------------------
/examples/accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lasso-net/lassonet/f2bf7d21274b21519dddc52e759a137b1aa868c4/examples/accuracy.png
--------------------------------------------------------------------------------
/examples/boston_housing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from sklearn.datasets import load_boston
6 | from sklearn.metrics import mean_squared_error
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.preprocessing import StandardScaler, scale
9 |
10 | from lassonet import LassoNetRegressor
11 |
12 | dataset = load_boston()
13 | X = dataset.data
14 | y = dataset.target
15 | _, true_features = X.shape
16 | # add dummy feature
17 | X = np.concatenate([X, np.random.randn(*X.shape)], axis=1)
18 | feature_names = list(dataset.feature_names) + ["fake"] * true_features
19 |
20 | # standardize
21 | X = StandardScaler().fit_transform(X)
22 | y = scale(y)
23 |
24 |
25 | X_train, X_test, y_train, y_test = train_test_split(X, y)
26 |
27 | model = LassoNetRegressor(
28 | hidden_dims=(10,),
29 | verbose=True,
30 | patience=(100, 5),
31 | )
32 | path = model.path(X_train, y_train)
33 |
34 | n_selected = []
35 | mse = []
36 | lambda_ = []
37 |
38 | for save in path:
39 | model.load(save.state_dict)
40 | y_pred = model.predict(X_test)
41 | n_selected.append(save.selected.sum().cpu().numpy())
42 | mse.append(mean_squared_error(y_test, y_pred))
43 | lambda_.append(save.lambda_)
44 |
45 |
46 | fig = plt.figure(figsize=(12, 12))
47 |
48 | plt.subplot(311)
49 | plt.grid(True)
50 | plt.plot(n_selected, mse, ".-")
51 | plt.xlabel("number of selected features")
52 | plt.ylabel("MSE")
53 |
54 | plt.subplot(312)
55 | plt.grid(True)
56 | plt.plot(lambda_, mse, ".-")
57 | plt.xlabel("lambda")
58 | plt.xscale("log")
59 | plt.ylabel("MSE")
60 |
61 | plt.subplot(313)
62 | plt.grid(True)
63 | plt.plot(lambda_, n_selected, ".-")
64 | plt.xlabel("lambda")
65 | plt.xscale("log")
66 | plt.ylabel("number of selected features")
67 |
68 | plt.savefig("boston.png")
69 |
70 | plt.clf()
71 |
72 | n_features = X.shape[1]
73 | importances = model.feature_importances_.numpy()
74 | order = np.argsort(importances)[::-1]
75 | importances = importances[order]
76 | ordered_feature_names = [feature_names[i] for i in order]
77 | color = np.array(["g"] * true_features + ["r"] * (n_features - true_features))[order]
78 |
79 |
80 | plt.subplot(211)
81 | plt.bar(
82 | np.arange(n_features),
83 | importances,
84 | color=color,
85 | )
86 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90)
87 | colors = {"real features": "g", "fake features": "r"}
88 | labels = list(colors.keys())
89 | handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels]
90 | plt.legend(handles, labels)
91 | plt.ylabel("Feature importance")
92 |
93 | _, order = np.unique(importances, return_inverse=True)
94 |
95 | plt.subplot(212)
96 | plt.bar(
97 | np.arange(n_features),
98 | order + 1,
99 | color=color,
100 | )
101 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90)
102 | plt.legend(handles, labels)
103 | plt.ylabel("Feature order")
104 |
105 | plt.savefig("boston-bar.png")
106 |
--------------------------------------------------------------------------------
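Note that `load_boston` was removed in scikit-learn 1.2, so the script above requires an older scikit-learn. A hedged sketch of adapting it: any regression dataset exposing `data`, `target` and `feature_names` drops in unchanged, e.g. `fetch_california_housing`:

    from sklearn.datasets import fetch_california_housing

    dataset = fetch_california_housing()
    X = dataset.data
    y = dataset.target
    feature_names = list(dataset.feature_names)
    # the rest of the script (dummy features, scaling, model.path) is unchanged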
/examples/boston_housing_group.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from sklearn.datasets import load_boston
6 | from sklearn.metrics import mean_squared_error
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.preprocessing import StandardScaler, scale
9 |
10 | from lassonet import LassoNetRegressor
11 |
12 | dataset = load_boston()
13 | X = dataset.data
14 | y = dataset.target
15 | _, true_features = X.shape
16 | # add dummy feature
17 | X = np.concatenate([X, np.random.randn(*X.shape)], axis=1)
18 | feature_names = list(dataset.feature_names) + ["fake"] * true_features
19 |
20 | # standardize
21 | X = StandardScaler().fit_transform(X)
22 | y = scale(y)
23 |
24 |
25 | X_train, X_test, y_train, y_test = train_test_split(X, y)
26 |
27 | model = LassoNetRegressor(
28 | hidden_dims=(10,),
29 | verbose=True,
30 | patience=(100, 5),
31 | groups=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12], list(range(13, 26))],
32 | )
33 | path = model.path(X_train, y_train)
34 |
35 | n_selected = []
36 | mse = []
37 | lambda_ = []
38 |
39 | for save in path:
40 | model.load(save.state_dict)
41 | y_pred = model.predict(X_test)
42 | n_selected.append(save.selected.sum())
43 | mse.append(mean_squared_error(y_test, y_pred))
44 | lambda_.append(save.lambda_)
45 |
46 |
47 | fig = plt.figure(figsize=(12, 12))
48 |
49 | plt.subplot(311)
50 | plt.grid(True)
51 | plt.plot(n_selected, mse, ".-")
52 | plt.xlabel("number of selected features")
53 | plt.ylabel("MSE")
54 |
55 | plt.subplot(312)
56 | plt.grid(True)
57 | plt.plot(lambda_, mse, ".-")
58 | plt.xlabel("lambda")
59 | plt.xscale("log")
60 | plt.ylabel("MSE")
61 |
62 | plt.subplot(313)
63 | plt.grid(True)
64 | plt.plot(lambda_, n_selected, ".-")
65 | plt.xlabel("lambda")
66 | plt.xscale("log")
67 | plt.ylabel("number of selected features")
68 |
69 | plt.savefig("boston-group.png")
70 |
71 | plt.clf()
72 |
73 | n_features = X.shape[1]
74 | importances = model.feature_importances_.numpy()
75 | order = np.argsort(importances)[::-1]
76 | importances = importances[order]
77 | ordered_feature_names = [feature_names[i] for i in order]
78 | color = np.array(["g"] * true_features + ["r"] * (n_features - true_features))[order]
79 |
80 |
81 | plt.subplot(211)
82 | plt.bar(
83 | np.arange(n_features),
84 | importances,
85 | color=color,
86 | )
87 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90)
88 | colors = {"real features": "g", "fake features": "r"}
89 | labels = list(colors.keys())
90 | handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels]
91 | plt.legend(handles, labels)
92 | plt.ylabel("Feature importance")
93 |
94 | _, order = np.unique(importances, return_inverse=True)
95 |
96 | plt.subplot(212)
97 | plt.bar(
98 | np.arange(n_features),
99 | order + 1,
100 | color=color,
101 | )
102 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90)
103 | plt.legend(handles, labels)
104 | plt.ylabel("Feature order")
105 |
106 | plt.savefig("boston-bar-group.png")
107 |
--------------------------------------------------------------------------------
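The only change from boston_housing.py is the `groups` argument, which partitions the feature indices so that the penalty treats each group as a unit and grouped features enter or leave the model together. A hedged check of that behavior (assuming `selected` is a boolean mask over features, as used above):

    import numpy as np

    groups = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12], list(range(13, 26))]
    for save in path:
        sel = np.asarray(save.selected.cpu())
        for g in groups:
            # within a group, features should be all kept or all dropped
            assert sel[g].all() or not sel[g].any()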
/examples/cox_experiments.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Install required packages with:
4 |
5 | pip install scikit-survival joblib tqdm tqdm_joblib
6 | """
7 |
8 |
9 | from pathlib import Path
10 | from time import time
11 |
12 | import numpy as np
13 | import pandas as pd
14 | import sksurv.datasets
15 | from joblib import Parallel, delayed
16 | from matplotlib import pyplot as plt
17 | from sklearn import preprocessing
18 | from sklearn.model_selection import StratifiedKFold, train_test_split
19 | from tqdm import tqdm
20 | from tqdm_joblib import tqdm_joblib
21 |
22 | from lassonet import LassoNetCoxRegressorCV, plot_cv
23 | from lassonet.utils import confidence_interval
24 |
25 | DATA_PATH = Path(__file__).parent / "data"
26 | DATA_PATH.mkdir(exist_ok=True)
27 |
28 | FIGURES_PATH = Path() / "cox_figures"
29 | FIGURES_PATH.mkdir(exist_ok=True)
30 |
31 |
32 | def transform_one_hot(input_matrix, col_name):
33 | one_hot_col = pd.get_dummies(input_matrix[col_name], prefix=col_name)
34 | input_matrix = input_matrix.drop([col_name], axis=1)
35 | input_matrix = input_matrix.join(one_hot_col)
36 | return input_matrix
37 |
38 |
39 | def dump(array, name):
40 | pd.DataFrame(array).to_csv(DATA_PATH / name, index=False)
41 |
42 |
43 | def gen_data(dataset):
44 | if dataset == "breast":
45 | X, y = sksurv.datasets.load_breast_cancer()
46 | di_er = {"negative": 0, "positive": 1}
47 | di_grade = {
48 | "poorly differentiated": -1,
49 | "intermediate": 0,
50 | "well differentiated": 1,
51 | "unkown": 0,
52 | }
53 | X = X.replace({"er": di_er, "grade": di_grade})
54 | y_temp = pd.DataFrame(y, columns=["t.tdm", "e.tdm"])
55 | di_event = {True: 1, False: 0}
56 | y_temp = y_temp.replace({"e.tdm": di_event})
57 | y = y_temp
58 |
59 | elif dataset == "fl_chain":
60 | X, y = sksurv.datasets.load_flchain()
61 | di_mgus = {"no": 0, "yes": 1}
62 | X = X.replace({"mgus": di_mgus, "creatinine": {np.nan: 0}})
63 | col_names = ["chapter", "sex", "sample.yr", "flc.grp"]
64 | for col_name in col_names:
65 | X = transform_one_hot(X, col_name)
66 | y_temp = pd.DataFrame(y, columns=["futime", "death"])
67 | di_event = {True: 0, False: 1}
68 | y_temp = y_temp.replace({"death": di_event})
69 | y = y_temp
70 |
71 | elif dataset == "whas500":
72 | X, y = sksurv.datasets.load_whas500()
73 | y_temp = pd.DataFrame(y, columns=["lenfol", "fstat"])
74 | di_event = {True: 1, False: 0}
75 | y_temp = y_temp.replace({"fstat": di_event})
76 | y = y_temp
77 |
78 | elif dataset == "veterans":
79 | X, y = sksurv.datasets.load_veterans_lung_cancer()
80 | col_names = ["Celltype", "Prior_therapy", "Treatment"]
81 | for col_name in col_names:
82 | X = transform_one_hot(X, col_name)
83 | y_temp = pd.DataFrame(y, columns=["Survival_in_days", "Status"])
84 | di_event = {False: 0, True: 1}
85 | y_temp = y_temp.replace({"Status": di_event})
86 | y = y_temp
87 | elif dataset == "hnscc":
88 | raise ValueError("hnscc data ships with the repository (see examples/data)")
89 | else:
90 | raise ValueError("Dataset unknown")
91 |
92 | dump(X, f"{dataset}_x.csv")
93 | dump(y, f"{dataset}_y.csv")
94 |
95 |
96 | def load_data(dataset):
97 | path_x = DATA_PATH / f"{dataset}_x.csv"
98 | path_y = DATA_PATH / f"{dataset}_y.csv"
99 | if not (path_x.exists() and path_y.exists()):
100 | gen_data(dataset)
101 | X = np.genfromtxt(path_x, delimiter=",", skip_header=1)
102 | y = np.genfromtxt(path_y, delimiter=",", skip_header=1)
103 | X = preprocessing.StandardScaler().fit(X).transform(X)
104 | return X, y
105 |
106 |
107 | def run(
108 | dataset,
109 | hidden_dims,
110 | path_multiplier,
111 | *,
112 | random_state,
113 | tie_approximation="breslow",
114 | dump_splits=False,
115 | verbose=False,
116 | ):
117 | X, y = load_data(dataset)
118 | X_train, X_test, y_train, y_test = train_test_split(
119 | X, y, random_state=random_state, stratify=y[:, 1], test_size=0.20
120 | )
121 |
122 | if dump_splits:
123 | for array, name in [
124 | (X_train, "x_train"),
125 | (y_train, "y_train"),
126 | (X_test, "x_test"),
127 | (y_test, "y_test"),
128 | ]:
129 | dump(array, f"{dataset}_{name}_{random_state}.csv")
130 |
131 | cv = list(
132 | StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state).split(
133 | X_train, y_train[:, 1]
134 | )
135 | )
136 |
137 | model = LassoNetCoxRegressorCV(
138 | tie_approximation=tie_approximation,
139 | hidden_dims=hidden_dims,
140 | path_multiplier=path_multiplier,
141 | cv=cv,
142 | torch_seed=random_state,
143 | verbose=verbose,
144 | )
145 | model.path(X_train, y_train)
146 | plot_cv(model, X_test, y_test)
147 | plt.savefig(
148 | FIGURES_PATH
149 | / (
150 | f"cox-cv-{dataset}-{random_state}"
151 | f"-{model.hidden_dims}-{model.path_multiplier}.jpg"
152 | ),
153 | dpi=300,
154 | )
155 |
156 | test_score = model.score(X_test, y_test)
157 |
158 | return test_score
159 |
160 |
161 | def experiment(dataset, hidden_dims, path_multiplier, n_runs, n_jobs):
162 | start = time()
163 | with tqdm_joblib(desc=f"Running on {dataset}", total=n_runs):
164 | scores = np.array(
165 | Parallel(n_jobs=n_jobs)(
166 | delayed(run)(
167 | dataset,
168 | hidden_dims=hidden_dims,
169 | path_multiplier=path_multiplier,
170 | tie_approximation=tie_approximation,
171 | random_state=random_state,
172 | )
173 | for random_state in range(n_runs)
174 | )
175 | )
176 | scores_str = np.array2string(
177 | scores, separator=", ", formatter={"float_kind": "{:.2f}".format}
178 | )
179 | log = (
180 | f"Dataset: {dataset}\n"
181 | f"Arch: {hidden_dims}\n"
182 | f"Path multiplier: {path_multiplier}\n"
183 | f"Score: {scores.mean():.04f} "
184 | f"± {confidence_interval(scores):.04f} "
185 | f"with {n_runs} runs\n"
186 | f"Raw scores: {scores_str}\n"
187 | f"Running time: {time() - start:.00f}s with {n_jobs} processors\n"
188 | f"Running time per run per cpu: {(time() - start) * n_jobs / n_runs:.00f}s\n"
189 | f"{'-' * 50}\n"
190 | )
191 |
192 | tqdm.write(log)
193 | with open("cox_experiments.log", "a") as f:
194 | print(log, file=f)
195 |
196 |
197 | if __name__ == "__main__":
198 | """
199 | run with python3 script.py dataset [method]
200 |
201 | dataset=all runs all experiments
202 |
203 | method can be "breslow" or "efron" (default "efron")
204 | """
205 |
206 | import sys
207 |
208 | dataset = sys.argv[1]
209 | tie_approximation = sys.argv[2] if len(sys.argv) > 2 else "efron"
210 | if dataset == "all":
211 | datasets = ["breast", "whas500", "veterans", "hnscc"]
212 | verbose = False
213 | else:
214 | datasets = [dataset]
215 | verbose = 1
216 |
217 | N_RUNS = 15
218 | N_JOBS = 5 # set to a divisor of `n_runs` for maximal efficiency
219 |
220 | for hidden_dims in [(16, 16), (32,), (32, 16), (64,)]:
221 | for path_multiplier in [1.01, 1.02]:
222 | for dataset in datasets:
223 | experiment(
224 | dataset=dataset,
225 | hidden_dims=hidden_dims,
226 | path_multiplier=path_multiplier,
227 | n_runs=N_RUNS,
228 | n_jobs=N_JOBS,
229 | )
230 |
231 |
232 | # import optuna
233 |
234 | # def objective(trial: optuna.Trial):
235 | # model = LassoNetCoxRegressorCV(
236 | # tie_approximation="breslow",
237 | # hidden_dims=(trial.suggest_int("hidden_dims", 8, 128),),
238 | # path_multiplier=1.01,
239 | # M=trial.suggest_float("M", 1e-3, 1e3),
240 | # cv=cv,
241 | # torch_seed=random_state,
242 | # verbose=False,
243 | # )
244 | # model.fit(X_train, y_train)
245 | # test_score = model.score(X_test, y_test)
246 | # trial.set_user_attr("score std", model.best_cv_score_std_)
247 | # trial.set_user_attr("test score", test_score)
248 | # print("test score", test_score)
249 | # return model.best_cv_score_
250 |
251 |
252 | # if __name__ == "__main__":
253 | # study = optuna.create_study(
254 | # storage="sqlite:///optuna.db",
255 | # study_name="fastcph-lassonet",
256 | # direction="maximize",
257 | # load_if_exists=True,
258 | # )
259 | # study.optimize(objective, n_trials=100)
260 |
--------------------------------------------------------------------------------
/examples/cox_regression.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from pathlib import Path
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from sklearn.model_selection import train_test_split
7 |
8 | from lassonet import LassoNetCoxRegressor, plot_path
9 |
10 | data = Path(__file__).parent / "data"
11 | X = np.genfromtxt(data / "hnscc_x.csv", delimiter=",", skip_header=1)
12 | y = np.genfromtxt(data / "hnscc_y.csv", delimiter=",", skip_header=1)
13 |
14 |
15 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
16 |
17 | model = LassoNetCoxRegressor(
18 | hidden_dims=(100,),
19 | lambda_start=1e-2,
20 | path_multiplier=1.02,
21 | gamma=1,
22 | verbose=True,
23 | tie_approximation="breslow",
24 | )
25 |
26 | model.path(X_train, y_train, return_state_dicts=True)
27 |
28 | plot_path(model, X_test, y_test)
29 | plt.savefig("cox_regression.png")
30 |
31 |
32 | model = LassoNetCoxRegressor(
33 | hidden_dims=(100,),
34 | lambda_start=1e-2,
35 | path_multiplier=1.02,
36 | gamma=1,
37 | verbose=True,
38 | tie_approximation="efron",
39 | )
40 |
41 | path = model.path(X_train, y_train)
42 |
43 | plot_path(model, X_test, y_test)
44 | plt.savefig("cox_regression_efron.png")
45 |
--------------------------------------------------------------------------------
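Both Cox scripts pass `y` as a two-column array: column 0 is the follow-up time and column 1 the event indicator (1.0 = event observed, 0.0 = censored), matching the header and first rows of hnscc_y.csv below:

    import numpy as np

    y = np.array([
        [1763.0, 0.0],  # censored after 1763 days
        [407.0, 1.0],   # event observed at day 407
    ])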
/examples/data/hnscc_y.csv:
--------------------------------------------------------------------------------
1 | Overall survival_duration,OS4Y
2 | 1763.0,0.0
3 | 407.0,1.0
4 | 2329.0,0.0
5 | 3982.0,0.0
6 | 2921.0,0.0
7 | 2560.0,0.0
8 | 1462.0,0.0
9 | 3101.0,0.0
10 | 2815.0,0.0
11 | 915.0,0.0
12 | 4347.0,0.0
13 | 1988.0,0.0
14 | 1277.0,0.0
15 | 3435.0,0.0
16 | 3982.0,0.0
17 | 879.0,1.0
18 | 3162.0,0.0
19 | 3739.0,0.0
20 | 1946.0,0.0
21 | 1933.0,0.0
22 | 3587.0,0.0
23 | 183.0,1.0
24 | 2329.0,0.0
25 | 2441.0,0.0
26 | 2678.0,0.0
27 | 1933.0,0.0
28 | 775.0,0.0
29 | 2113.0,0.0
30 | 1788.0,0.0
31 | 3587.0,0.0
32 | 568.0,1.0
33 | 2207.0,0.0
34 | 304.0,1.0
35 | 958.0,0.0
36 | 303.0,0.0
37 | 2648.0,0.0
38 | 3374.0,0.0
39 | 227.0,1.0
40 | 298.0,1.0
41 | 2207.0,0.0
42 | 3162.0,0.0
43 | 1909.0,0.0
44 | 426.0,1.0
45 | 2994.0,0.0
46 | 851.0,1.0
47 | 2608.0,0.0
48 | 2414.0,0.0
49 | 3314.0,0.0
50 | 2289.0,0.0
51 | 1857.0,0.0
52 | 2900.0,0.0
53 | 3344.0,0.0
54 | 2776.0,0.0
55 | 3709.0,0.0
56 | 3952.0,0.0
57 | 3283.0,0.0
58 | 2833.0,0.0
59 | 447.0,1.0
60 | 1760.0,0.0
61 | 3314.0,0.0
62 | 1028.0,0.0
63 | 3709.0,0.0
64 | 2244.0,0.0
65 | 1860.0,0.0
66 | 3101.0,0.0
67 | 1088.0,0.0
68 | 1480.0,0.0
69 | 2298.0,0.0
70 | 2481.0,0.0
71 | 277.0,1.0
72 | 3618.0,0.0
73 | 3830.0,0.0
74 | 979.0,1.0
75 | 1827.0,0.0
76 | 3283.0,0.0
77 | 2882.0,0.0
78 | 2411.0,0.0
79 | 863.0,1.0
80 | 502.0,1.0
81 | 669.0,1.0
82 | 1006.0,0.0
83 | 2779.0,0.0
84 | 876.0,1.0
85 | 1043.0,0.0
86 | 1842.0,0.0
87 | 33.0,1.0
88 | 2399.0,0.0
89 | 1593.0,0.0
90 | 2572.0,0.0
91 | 2468.0,0.0
92 | 2584.0,0.0
93 | 1809.0,0.0
94 | 2092.0,0.0
95 | 1067.0,1.0
96 | 3101.0,0.0
97 | 1985.0,0.0
98 | 2450.0,0.0
99 | 1067.0,0.0
100 | 3000.0,0.0
101 | 3192.0,0.0
102 | 3222.0,0.0
103 | 1964.0,0.0
104 | 444.0,0.0
105 | 2462.0,0.0
106 | 2739.0,0.0
107 | 1867.0,0.0
108 | 4043.0,0.0
109 | 2879.0,0.0
110 | 1991.0,0.0
111 | 2538.0,0.0
112 | 1943.0,0.0
113 | 2043.0,0.0
114 | 3222.0,0.0
115 | 486.0,1.0
116 | 322.0,1.0
117 | 2301.0,0.0
118 | 3435.0,0.0
119 | 2274.0,0.0
120 | 2578.0,0.0
121 | 2560.0,0.0
122 | 2107.0,0.0
123 | 851.0,0.0
124 | 3070.0,0.0
125 | 2763.0,0.0
126 | 2505.0,0.0
127 | 1724.0,0.0
128 | 3739.0,0.0
129 | 3192.0,0.0
130 | 1240.0,0.0
131 | 1973.0,0.0
132 | 1234.0,0.0
133 | 2247.0,0.0
134 | 2645.0,0.0
135 | 882.0,1.0
136 | 2107.0,0.0
137 | 1766.0,0.0
138 | 2472.0,0.0
139 | 3374.0,0.0
140 | 3496.0,0.0
141 | 1414.0,1.0
142 | 3283.0,0.0
143 | 2338.0,0.0
144 | 1155.0,0.0
145 | 450.0,1.0
146 | 2335.0,0.0
147 | 605.0,1.0
148 | 3709.0,0.0
149 | 2584.0,0.0
150 | 2915.0,0.0
151 | 2204.0,0.0
152 | 4195.0,0.0
153 | 502.0,1.0
154 | 1791.0,0.0
155 | 2253.0,0.0
156 | 4165.0,0.0
157 | 4226.0,0.0
158 | 3405.0,0.0
159 | 2329.0,0.0
160 | 4286.0,0.0
161 | 2353.0,0.0
162 | 1788.0,0.0
163 | 1924.0,0.0
164 | 4165.0,0.0
165 | 3162.0,0.0
166 | 2809.0,0.0
167 | 2240.0,0.0
168 | 3466.0,0.0
169 | 1645.0,0.0
170 | 1176.0,0.0
171 | 2441.0,0.0
172 | 882.0,0.0
173 | 3678.0,0.0
174 | 3070.0,0.0
175 | 3557.0,0.0
176 | 824.0,1.0
177 | 2897.0,0.0
178 | 857.0,0.0
179 | 1888.0,0.0
180 | 118.0,0.0
181 | 3040.0,0.0
182 | 3070.0,0.0
183 | 2630.0,0.0
184 | 471.0,1.0
185 | 3496.0,0.0
186 | 3007.0,0.0
187 | 3435.0,0.0
188 | 2289.0,0.0
189 | 769.0,1.0
190 | 2943.0,0.0
191 | 2724.0,0.0
192 | 1943.0,0.0
193 | 2766.0,0.0
194 | 4134.0,0.0
195 | 2554.0,0.0
196 | 2490.0,0.0
197 | 2687.0,0.0
198 | 2265.0,0.0
199 | 2785.0,0.0
200 | 2143.0,0.0
201 | 1958.0,0.0
202 | 815.0,1.0
203 | 3344.0,0.0
204 | 2985.0,0.0
205 | 2675.0,0.0
206 | 2937.0,0.0
207 | 2660.0,0.0
208 | 529.0,1.0
209 | 1912.0,0.0
210 | 2852.0,0.0
211 | 3678.0,0.0
212 | 1788.0,0.0
213 | 2526.0,0.0
214 | 3222.0,0.0
215 | 1705.0,0.0
216 | 2912.0,0.0
217 | 3770.0,0.0
218 | 468.0,1.0
219 | 3253.0,0.0
220 | 2110.0,0.0
221 | 1538.0,0.0
222 | 304.0,1.0
223 | 2350.0,0.0
224 | 1857.0,0.0
225 | 3800.0,0.0
226 | 2809.0,0.0
227 | 1906.0,0.0
228 | 2073.0,0.0
229 | 3283.0,0.0
230 | 3405.0,0.0
231 | 3800.0,0.0
232 | 3101.0,0.0
233 | 3739.0,0.0
234 | 3922.0,0.0
235 | 2544.0,0.0
236 | 3435.0,0.0
237 | 2347.0,0.0
238 | 2529.0,0.0
239 | 2079.0,0.0
240 | 1368.0,1.0
241 | 2429.0,0.0
242 | 2201.0,0.0
243 | 1338.0,1.0
244 | 2386.0,0.0
245 | 1216.0,1.0
246 | 854.0,1.0
247 | 3070.0,0.0
248 | 2888.0,0.0
249 | 2800.0,0.0
250 | 4165.0,0.0
251 | 3253.0,0.0
252 | 1882.0,0.0
253 | 1915.0,0.0
254 | 3374.0,0.0
255 | 3861.0,0.0
256 | 830.0,1.0
257 | 1912.0,0.0
258 | 3466.0,0.0
259 | 3678.0,0.0
260 | 2748.0,0.0
261 | 2438.0,0.0
262 | 1277.0,0.0
263 | 2304.0,0.0
264 | 2909.0,0.0
265 | 2852.0,0.0
266 | 3101.0,0.0
267 | 2788.0,0.0
268 | 4195.0,0.0
269 | 2721.0,0.0
270 | 3587.0,0.0
271 | 4195.0,0.0
272 | 4165.0,0.0
273 | 3557.0,0.0
274 | 638.0,1.0
275 | 538.0,1.0
276 | 3070.0,0.0
277 | 3587.0,0.0
278 | 2666.0,0.0
279 | 1122.0,1.0
280 | 3253.0,0.0
281 | 3709.0,0.0
282 | 3891.0,0.0
283 | 1702.0,0.0
284 | 2216.0,0.0
285 | 4347.0,0.0
286 | 328.0,1.0
287 | 4165.0,0.0
288 | 3131.0,0.0
289 | 2724.0,0.0
290 | 3891.0,0.0
291 | 1985.0,0.0
292 | 4165.0,0.0
293 | 2587.0,0.0
294 | 2572.0,0.0
295 | 2377.0,0.0
296 | 2271.0,0.0
297 | 1508.0,0.0
298 | 2371.0,0.0
299 | 2052.0,0.0
300 | 3070.0,0.0
301 | 2648.0,0.0
302 | 2313.0,0.0
303 | 2541.0,0.0
304 | 2748.0,0.0
305 | 1815.0,0.0
306 | 3982.0,0.0
307 | 2690.0,0.0
308 | 1857.0,0.0
309 | 1936.0,0.0
310 | 3982.0,0.0
311 | 1611.0,0.0
312 | 1794.0,0.0
313 | 2438.0,0.0
314 | 2703.0,0.0
315 | 1693.0,0.0
316 | 2991.0,0.0
317 | 4043.0,0.0
318 | 2946.0,0.0
319 | 2903.0,0.0
320 | 3070.0,0.0
321 | 3040.0,0.0
322 | 1389.0,0.0
323 | 2554.0,0.0
324 | 1982.0,0.0
325 | 2788.0,0.0
326 | 2098.0,0.0
327 | 1529.0,0.0
328 | 2304.0,0.0
329 | 3435.0,0.0
330 | 2481.0,0.0
331 | 2405.0,0.0
332 | 2852.0,0.0
333 | 2000.0,0.0
334 | 3800.0,0.0
335 | 1210.0,1.0
336 | 499.0,1.0
337 | 2298.0,0.0
338 | 3374.0,0.0
339 | 3253.0,0.0
340 | 3709.0,0.0
341 | 2146.0,0.0
342 | 2700.0,0.0
343 | 3739.0,0.0
344 | 2551.0,0.0
345 | 1614.0,0.0
346 | 3922.0,0.0
347 | 3283.0,0.0
348 | 1693.0,0.0
349 | 2979.0,0.0
350 | 1836.0,0.0
351 | 2417.0,0.0
352 | 1052.0,1.0
353 | 3344.0,0.0
354 | 3618.0,0.0
355 | 3982.0,0.0
356 | 223.0,1.0
357 | 1599.0,0.0
358 | 614.0,1.0
359 | 386.0,1.0
360 | 1818.0,0.0
361 | 4499.0,0.0
362 | 2973.0,0.0
363 | 1395.0,1.0
364 | 3618.0,0.0
365 | 2301.0,0.0
366 | 1456.0,1.0
367 | 2988.0,0.0
368 | 1681.0,0.0
369 | 3770.0,0.0
370 | 1155.0,0.0
371 | 2672.0,0.0
372 | 784.0,0.0
373 | 3557.0,0.0
374 | 73.0,1.0
375 | 1541.0,0.0
376 | 3861.0,0.0
377 | 3861.0,0.0
378 | 3314.0,0.0
379 | 3070.0,0.0
380 | 3466.0,0.0
381 | 3192.0,0.0
382 | 4134.0,0.0
383 | 2450.0,0.0
384 | 1207.0,1.0
385 | 2064.0,0.0
386 | 1377.0,0.0
387 | 3374.0,0.0
388 | 1502.0,0.0
389 | 3025.0,0.0
390 | 1900.0,0.0
391 | 869.0,1.0
392 | 3618.0,0.0
393 | 2611.0,0.0
394 | 2918.0,0.0
395 | 1800.0,0.0
396 | 3253.0,0.0
397 | 2748.0,0.0
398 | 395.0,1.0
399 | 3314.0,0.0
400 | 3405.0,0.0
401 | 1338.0,0.0
402 | 2204.0,0.0
403 | 3770.0,0.0
404 | 1958.0,0.0
405 | 2280.0,0.0
406 | 2964.0,0.0
407 | 3496.0,0.0
408 | 3283.0,0.0
409 | 1587.0,0.0
410 | 1745.0,0.0
411 | 1651.0,0.0
412 | 784.0,0.0
413 | 2031.0,0.0
414 | 1897.0,0.0
415 | 1970.0,0.0
416 | 3034.0,0.0
417 | 2122.0,0.0
418 | 2766.0,0.0
419 | 1848.0,0.0
420 | 1818.0,0.0
421 | 2335.0,0.0
422 | 383.0,1.0
423 | 2377.0,0.0
424 | 340.0,0.0
425 | 3557.0,0.0
426 | 377.0,1.0
427 | 3891.0,0.0
428 | 1833.0,0.0
429 | 2289.0,0.0
430 | 3952.0,0.0
431 | 3131.0,0.0
432 | 2727.0,0.0
433 | 3222.0,0.0
434 | 2943.0,0.0
435 | 2174.0,0.0
436 | 3344.0,0.0
437 | 2052.0,0.0
438 | 389.0,1.0
439 | 3000.0,0.0
440 | 2885.0,0.0
441 | 3040.0,0.0
442 | 2493.0,0.0
443 | 1879.0,0.0
444 | 4226.0,0.0
445 | 3101.0,0.0
446 | 2213.0,0.0
447 | 2079.0,0.0
448 | 3800.0,0.0
449 | 3374.0,0.0
450 | 2794.0,0.0
451 | 2997.0,0.0
452 | 3222.0,0.0
453 |
--------------------------------------------------------------------------------
/examples/diabetes.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from sklearn.datasets import load_diabetes
6 | from sklearn.model_selection import train_test_split
7 | from sklearn.preprocessing import StandardScaler, scale
8 |
9 | from lassonet import LassoNetRegressor, plot_path
10 |
11 | dataset = load_diabetes()
12 | X = dataset.data
13 | y = dataset.target
14 | _, true_features = X.shape
15 | # add dummy feature
16 | X = np.concatenate([X, np.random.randn(*X.shape)], axis=1)
17 | feature_names = list(dataset.feature_names) + ["fake"] * true_features
18 |
19 | # standardize
20 | X = StandardScaler().fit_transform(X)
21 | y = scale(y)
22 |
23 |
24 | X_train, X_test, y_train, y_test = train_test_split(X, y)
25 |
26 | model = LassoNetRegressor(
27 | hidden_dims=(10,),
28 | verbose=True,
29 | )
30 | path = model.path(X_train, y_train, return_state_dicts=True)
31 |
32 | plot_path(model, X_test, y_test)
33 |
34 | plt.savefig("diabetes.png")
35 |
36 | plt.clf()
37 |
38 | n_features = X.shape[1]
39 | importances = model.feature_importances_.numpy()
40 | order = np.argsort(importances)[::-1]
41 | importances = importances[order]
42 | ordered_feature_names = [feature_names[i] for i in order]
43 | color = np.array(["g"] * true_features + ["r"] * (n_features - true_features))[order]
44 |
45 |
46 | plt.subplot(211)
47 | plt.bar(
48 | np.arange(n_features),
49 | importances,
50 | color=color,
51 | )
52 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90)
53 | colors = {"real features": "g", "fake features": "r"}
54 | labels = list(colors.keys())
55 | handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels]
56 | plt.legend(handles, labels)
57 | plt.ylabel("Feature importance")
58 |
59 | _, order = np.unique(importances, return_inverse=True)
60 |
61 | plt.subplot(212)
62 | plt.bar(
63 | np.arange(n_features),
64 | order + 1,
65 | color=color,
66 | )
67 | plt.xticks(np.arange(n_features), ordered_feature_names, rotation=90)
68 | plt.legend(handles, labels)
69 | plt.ylabel("Feature order")
70 |
71 | plt.savefig("diabetes-bar.png")
72 |
--------------------------------------------------------------------------------
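A hedged note on `feature_importances_` as used above: it holds, for each feature, the value of lambda at which that feature leaves the path, so a higher value means the feature survives stronger regularization. Extracting the five most important features:

    import numpy as np

    top5 = np.argsort(model.feature_importances_.numpy())[::-1][:5]
    print([feature_names[i] for i in top5])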
/examples/friedman.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from sklearn.metrics import r2_score
4 | from sklearn.model_selection import train_test_split
5 | from sklearn.preprocessing import StandardScaler
6 |
7 | from lassonet import LassoNetRegressor, plot_path
8 |
9 |
10 | def friedman(linear_terms=True):
11 | n = 1000
12 | p = 200
13 | X = np.random.rand(n, p)
14 | y = 10 * np.sin(np.pi * X[:, 0] * X[:, 1]) + 20 * (X[:, 2] - 0.5) ** 2
15 | if linear_terms:
16 | y += 10 * X[:, 3] + 5 * X[:, 4]
17 | return X, y
18 |
19 |
20 | np.random.seed(0)
21 | X, y = friedman(linear_terms=True)
22 | X = StandardScaler().fit_transform(X)
23 | X_train, X_test, y_train, y_test = train_test_split(X, y)
24 |
25 | y_std = 0.5 * y_train.std()
26 | y_train += np.random.randn(*y_train.shape) * y_std
27 |
28 | for y in [y_train, y_test]:
29 | y -= y.mean()
30 | y /= y.std()
31 |
32 |
33 | def rrmse(y, y_pred):
34 | return np.sqrt(1 - r2_score(y, y_pred))
35 |
36 |
37 | for path_multiplier in [1.01, 1.001]:
38 | print("path_multiplier:", path_multiplier)
39 | for M in [10, 100, 1_000, 10_000, 100_000]:
40 | print("M:", M)
41 | model = LassoNetRegressor(
42 | hidden_dims=(10, 10),
43 | random_state=0,
44 | torch_seed=0,
45 | path_multiplier=path_multiplier,
46 | M=M,
47 | )
48 | path = model.path(X_train, y_train, return_state_dicts=True)
49 | print(
50 | "rrmse:",
51 | min(rrmse(y_test, model.load(save).predict(X_test)) for save in path),
52 | )
53 | plot_path(model, X_test, y_test, score_function=rrmse)
54 | plt.savefig(f"friedman_path({path_multiplier})_M({M}).jpg")
55 |
56 | path_multiplier = 1.001
57 | print("path_multiplier:", path_multiplier)
58 | for M in [100, 1_000, 10_000, 100_000]:
59 | print("M:", M)
60 | model = LassoNetRegressor(
61 | hidden_dims=(10, 10),
62 | random_state=0,
63 | torch_seed=0,
64 | path_multiplier=path_multiplier,
65 | M=M,
66 | backtrack=True,
67 | )
68 | path = model.path(X_train, y_train)
69 | print(
70 | "rrmse:",
71 | min(rrmse(y_test, model.load(save).predict(X_test)) for save in path),
72 | )
73 | plot_path(model, X_test, y_test, score_function=rrmse)
74 | plt.savefig(f"friedman_path({path_multiplier})_M({M})_backtrack.jpg")
75 |
76 |
77 | for path_multiplier in [1.01, 1.001]:
78 | M = 100_000
79 | print("path_multiplier:", path_multiplier)
80 | print("M:", M)
81 | model = LassoNetRegressor(
82 | hidden_dims=(10, 10),
83 | random_state=0,
84 | torch_seed=0,
85 | path_multiplier=path_multiplier,
86 | M=M,
87 | patience=100,
88 | n_iters=1000,
89 | )
90 | path = model.path(X_train, y_train)
91 | print(
92 | "rrmse:",
93 | min(rrmse(y_test, model.load(save).predict(X_test)) for save in path),
94 | )
95 | plot_path(model, X_test, y_test, score_function=rrmse)
96 | plt.savefig(f"friedman_path({path_multiplier})_M({M})_long.jpg")
97 |
98 | # if __name__ == "__main__":
99 |
100 | # model = LassoNetRegressor(verbose=True, path_multiplier=1.01, hidden_dims=(10, 10))
101 | # path = model.path(X_train, y_train)
102 |
103 | # plot_path(model, X_test, y_test, score_function=rrmse)
104 | # plt.show()
105 |
--------------------------------------------------------------------------------
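The `rrmse` helper used above is the relative root mean squared error, i.e. the RMSE normalized by the standard deviation of the target: $\mathrm{rrmse}(y, \hat y) = \sqrt{1 - R^2} = \sqrt{\sum_i (y_i - \hat y_i)^2 / \sum_i (y_i - \bar y)^2}$, so 0 is a perfect fit and 1 is no better than predicting the mean.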
/examples/friedman/download.sh:
--------------------------------------------------------------------------------
1 | URL="https://raw.githubusercontent.com/warbelo/Lockout/main/Synthetic_Data2/dataset_a"
2 |
3 | for name in X.csv Y.csv xtest.csv xtrain.csv xvalid.csv ytest.csv ytrain.csv yvalid.csv; do
4 | curl "$URL/$name" > "$name"
5 | done
--------------------------------------------------------------------------------
/examples/friedman/main.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from sklearn.metrics import r2_score
4 | from sklearn.preprocessing import StandardScaler
5 |
6 | from lassonet import LassoNetRegressor, plot_path
7 |
8 |
9 | def load(s):
10 | return np.loadtxt(s, delimiter=",")
11 |
12 |
13 | X_train = load("xtrain.csv")
14 | y_train = load("ytrain.csv")
15 | X_val = load("xvalid.csv")
16 | y_val = load("yvalid.csv")
17 | X_test = load("xtest.csv")
18 | y_test = load("ytest.csv")
19 |
20 | X_train, X_val, X_test = np.split(
21 | StandardScaler().fit_transform(np.concatenate((X_train, X_val, X_test))), 3
22 | )
23 |
24 | y = np.concatenate((y_train, y_val))
25 | y_mean = y.mean()
26 | y_std = y.std()
27 |
28 | for y in [y_train, y_val, y_test]:
29 | y -= y_mean
30 | y /= y_std
31 |
32 |
33 | def rrmse(y, y_pred):
34 | return np.sqrt(1 - r2_score(y, y_pred))
35 |
36 |
37 | if __name__ == "__main__":
38 |
39 | model = LassoNetRegressor(
40 | path_multiplier=1.001,
41 | M=100_000,
42 | hidden_dims=(10, 10),
43 | torch_seed=0,
44 | )
45 | path = model.path(
46 | X_train, y_train, X_val=X_val, y_val=y_val, return_state_dicts=True
47 | )
48 | print(
49 | "rrmse:",
50 | min(rrmse(y_test, model.load(save).predict(X_test)) for save in path),
51 | )
52 | plot_path(model, X_test, y_test, score_function=rrmse)
53 | plt.show()
54 |
--------------------------------------------------------------------------------
/examples/generated.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import numpy as np
4 | from sklearn.metrics import r2_score
5 | from sklearn.model_selection import train_test_split
6 | from sklearn.preprocessing import StandardScaler
7 |
8 | from lassonet import LassoNetRegressor, plot_path
9 |
10 |
11 | def linear():
12 | p = 10
13 | n = 400
14 | coef = np.concatenate([np.random.choice([-1, 1], size=p), [0] * p])
15 | X = np.random.randn(n, 2 * p)
16 |
17 | linear = X.dot(coef)
18 | noise = np.random.randn(n)
19 |
20 | y = linear + noise
21 | return X, y
22 |
23 |
24 | def strong_linear():
25 | p = 10
26 | n = 400
27 | coef = np.concatenate([np.random.choice([-1, 1], size=p), [0] * p])
28 | X = np.random.randn(n, 2 * p)
29 |
30 | linear = X.dot(coef)
31 | noise = np.random.randn(n)
32 | x1, x2, x3, *_ = X.T
33 | nonlinear = 2 * (x1**3 - 3 * x1) + 4 * (x2**2 * x3 - x3)
34 | y = 6 * linear + 8 * noise + nonlinear
35 | return X, y
36 |
37 |
38 | def friedman_lockout():
39 | p = 200
40 | n = 1000
41 | X = np.random.rand(n, p)
42 | y = (
43 | 10 * np.sin(np.pi * X[:, 0] * X[:, 1])
44 | + 20 * (X[:, 2] - 0.5) ** 2
45 | + 10 * X[:, 3]
46 | + 5 * X[:, 4]
47 | )
48 | return X, y
49 |
50 |
51 | for generator in [linear, strong_linear, friedman_lockout]:
52 | X, y = generator()
53 | X = StandardScaler().fit_transform(X)
54 | y -= y.mean()
55 | y /= y.std()
56 | X_train, X_test, y_train, y_test = train_test_split(X, y)
57 |
58 | model = LassoNetRegressor(verbose=True, path_multiplier=1.01, hidden_dims=(10, 10))
59 |
60 | path = model.path(X_train, y_train, return_state_dicts=True)
61 | import matplotlib.pyplot as plt  # in-loop import is harmless (modules are cached) but usually lives at the top
62 |
63 | def score(self, X, y, sample_weight=None):
64 | y_pred = self.predict(X)
65 | return np.sqrt(1 - r2_score(y, y_pred, sample_weight=sample_weight))
66 |
67 | model.score = partial(score, model)
68 |
69 | plot_path(model, X_test, y_test)
70 | plt.show()
71 |
--------------------------------------------------------------------------------
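Instead of monkey-patching `model.score` with `functools.partial` as above, the same plot can be obtained through `plot_path`'s `score_function` argument, exactly as examples/friedman.py does:

    plot_path(model, X_test, y_test, score_function=lambda y, y_pred: np.sqrt(1 - r2_score(y, y_pred)))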
/examples/miceprotein.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | """
5 | We run Lassonet over [the Mice Dataset](https://archive.ics.uci.edu/ml/datasets/Mice%20Protein%20Expression).
6 | This dataset consists of protein expression levels measured in the cortex of normal and
7 | trisomic mice who had been exposed to different experimental conditions.
8 | Each feature is the expression level of one protein.
9 | """
10 |
11 | import matplotlib.pyplot as plt
12 | from sklearn.datasets import fetch_openml
13 | from sklearn.impute import SimpleImputer
14 | from sklearn.model_selection import train_test_split
15 | from sklearn.preprocessing import LabelEncoder, StandardScaler
16 |
17 | from lassonet import LassoNetClassifier, plot_cv, plot_path
18 | from lassonet.interfaces import LassoNetClassifierCV
19 |
20 | X, y = fetch_openml(name="miceprotein", return_X_y=True)
21 | # Fill missing values with the mean
22 | X = SimpleImputer().fit_transform(X)
23 | # Convert labels to scalar
24 | y = LabelEncoder().fit_transform(y)
25 |
26 | # standardize
27 | X = StandardScaler().fit_transform(X)
28 |
29 | X_train, X_test, y_train, y_test = train_test_split(X, y)
30 |
31 |
32 | model = LassoNetClassifierCV()
33 | model.path(X_train, y_train, return_state_dicts=True)
34 | print("Best model scored", model.score(X_test, y_test))
35 | print("Lambda =", model.best_lambda_)
36 | plot_cv(model, X_test, y_test)
37 | plt.savefig("miceprotein-cv.png")
38 |
39 | model = LassoNetClassifier()
40 | path = model.path(X_train, y_train, return_state_dicts=True)
41 | plot_path(model, X_test, y_test)
42 | plt.savefig("miceprotein.png")
43 |
44 | model = LassoNetClassifier(dropout=0.5)
45 | path = model.path(X_train, y_train, return_state_dicts=True)
46 | plot_path(model, X_test, y_test)
47 | plt.savefig("miceprotein_dropout.png")
48 |
49 | model = LassoNetClassifier(hidden_dims=(100, 100))
50 | path = model.path(X_train, y_train, return_state_dicts=True)
51 | plot_path(model, X_test, y_test)
52 | plt.savefig("miceprotein_deep.png")
53 |
54 | model = LassoNetClassifier(hidden_dims=(100, 100), gamma=0.01)
55 | path = model.path(X_train, y_train, return_state_dicts=True)
56 | plot_path(model, X_test, y_test)
57 | plt.savefig("miceprotein_deep_l2_weak.png")
58 |
59 | model = LassoNetClassifier(hidden_dims=(100, 100), gamma=0.1)
60 | path = model.path(X_train, y_train, return_state_dicts=True)
61 | plot_path(model, X_test, y_test)
62 | plt.savefig("miceprotein_deep_l2_strong.png")
63 |
64 | model = LassoNetClassifier(hidden_dims=(100, 100), gamma=1)
65 | path = model.path(X_train, y_train, return_state_dicts=True)
66 | plot_path(model, X_test, y_test)
67 | plt.savefig("miceprotein_deep_l2_super_strong.png")
68 |
69 | model = LassoNetClassifier(hidden_dims=(100, 100), dropout=0.5)
70 | path = model.path(X_train, y_train, return_state_dicts=True)
71 | plot_path(model, X_test, y_test)
72 | plt.savefig("miceprotein_deep_dropout.png")
73 |
74 | model = LassoNetClassifier(hidden_dims=(100, 100), backtrack=True, dropout=0.5)
75 | path = model.path(X_train, y_train, return_state_dicts=True)
76 | plot_path(model, X_test, y_test)
77 | plt.savefig("miceprotein_deep_dropout_backtrack.png")
78 |
79 | model = LassoNetClassifier(batch_size=64)
80 | path = model.path(X_train, y_train, return_state_dicts=True)
81 | plot_path(model, X_test, y_test)
82 | plt.savefig("miceprotein_64.png")
83 |
84 | model = LassoNetClassifier(backtrack=True)
85 | path = model.path(X_train, y_train, return_state_dicts=True)
86 | plot_path(model, X_test, y_test)
87 | plt.savefig("miceprotein_backtrack.png")
88 |
89 | model = LassoNetClassifier(batch_size=64, backtrack=True)
90 | path = model.path(X_train, y_train, return_state_dicts=True)
91 | plot_path(model, X_test, y_test)
92 | plt.savefig("miceprotein_backtrack_64.png")
93 |
94 | model = LassoNetClassifier(class_weight=[0.1, 0.2, 0.3, 0.1, 0.3, 0, 0, 0])
95 | path = model.path(X_train, y_train, return_state_dicts=True)
96 | plot_path(model, X_test, y_test)
97 | plt.savefig("miceprotein_weighted.png")
98 |
--------------------------------------------------------------------------------
/examples/mnist_ae.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from sklearn.datasets import fetch_openml
3 | from sklearn.metrics import mean_squared_error
4 |
5 | from lassonet import LassoNetAutoEncoder
6 |
7 | X, y = fetch_openml(name="mnist_784", return_X_y=True)
8 | filter = y == "3"
9 | X = X[filter].values / 255
10 |
11 | model = LassoNetAutoEncoder(
12 | M=30, n_iters=(3000, 500), path_multiplier=1.05, verbose=True
13 | )
14 | path = model.path(X)
15 |
16 | img = model.feature_importances_.reshape(28, 28)
17 |
18 | plt.title("Feature importance to reconstruct 3")
19 | plt.imshow(img)
20 | plt.colorbar()
21 | plt.savefig("mnist-ae-importance.png")
22 |
23 |
24 | n_selected = []
25 | score = []
26 | lambda_ = []
27 |
28 | for save in path:
29 | model.load(save.state_dict)
30 | X_pred = model.predict(X)
31 | n_selected.append(save.selected.sum())
32 | score.append(mean_squared_error(X_pred, X))
33 | lambda_.append(save.lambda_)
34 |
35 | to_plot = [160, 220, 300]  # snapshot the first save whose feature count drops to each threshold, largest first
36 |
37 | for i, save in zip(n_selected, path):
38 | if not to_plot:
39 | break
40 | if i > to_plot[-1]:
41 | continue
42 | to_plot.pop()
43 | plt.clf()
44 | plt.title(f"Linear model with {i} features")
45 | weight = save.state_dict["skip.weight"]
46 | img = (weight[1] - weight[0]).reshape(28, 28)
47 | plt.imshow(img)
48 | plt.colorbar()
49 | plt.savefig(f"mnist-ae-{i}.png")
50 |
51 | plt.clf()
52 |
53 | fig = plt.figure(figsize=(12, 12))
54 |
55 | plt.subplot(311)
56 | plt.grid(True)
57 | plt.plot(n_selected, score, ".-")
58 | plt.xlabel("number of selected features")
59 | plt.ylabel("MSE")
60 |
61 | plt.subplot(312)
62 | plt.grid(True)
63 | plt.plot(lambda_, score, ".-")
64 | plt.xlabel("lambda")
65 | plt.xscale("log")
66 | plt.ylabel("MSE")
67 |
68 | plt.subplot(313)
69 | plt.grid(True)
70 | plt.plot(lambda_, n_selected, ".-")
71 | plt.xlabel("lambda")
72 | plt.xscale("log")
73 | plt.ylabel("number of selected features")
74 |
75 | plt.savefig("mnist-ae-training.png")
76 |
77 |
78 | plt.subplot(221)
79 | plt.imshow(X[150].reshape(28, 28))
80 | plt.subplot(222)
81 | plt.imshow(model.predict(X[150]).reshape(28, 28))
82 | plt.subplot(223)
83 | plt.imshow(X[250].reshape(28, 28))
84 | plt.subplot(224)
85 | plt.imshow(model.predict(X[250]).reshape(28, 28))
86 |
--------------------------------------------------------------------------------
/examples/mnist_classif.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from sklearn.datasets import fetch_openml
3 | from sklearn.metrics import accuracy_score
4 | from sklearn.model_selection import train_test_split
5 | from sklearn.preprocessing import LabelEncoder
6 |
7 | from lassonet import LassoNetClassifier
8 |
9 | X, y = fetch_openml(name="mnist_784", return_X_y=True)
10 | filter = y.isin(["5", "6"])
11 | X = X[filter].values / 255
12 | y = LabelEncoder().fit_transform(y[filter])
13 |
14 | X_train, X_test, y_train, y_test = train_test_split(X, y)
15 |
16 | model = LassoNetClassifier(M=30, verbose=True)
17 | path = model.path(X_train, y_train)
18 |
19 | img = model.feature_importances_.reshape(28, 28)
20 |
21 | plt.title("Feature importance to discriminate 5 and 6")
22 | plt.imshow(img)
23 | plt.colorbar()
24 | plt.savefig("mnist-classification-importance.png")
25 |
26 | n_selected = []
27 | accuracy = []
28 | lambda_ = []
29 |
30 | for save in path:
31 | model.load(save.state_dict)
32 | y_pred = model.predict(X_test)
33 | n_selected.append(save.selected.sum())
34 | accuracy.append(accuracy_score(y_test, y_pred))
35 | lambda_.append(save.lambda_)
36 |
37 | to_plot = [160, 220, 300]
38 |
39 | for i, save in zip(n_selected, path):
40 | if not to_plot:
41 | break
42 | if i > to_plot[-1]:
43 | continue
44 | to_plot.pop()
45 | plt.clf()
46 | plt.title(f"Linear model with {i} features")
47 | weight = save.state_dict["skip.weight"]
48 | img = (weight[1] - weight[0]).reshape(28, 28)
49 | plt.imshow(img)
50 | plt.colorbar()
51 | plt.savefig(f"mnist-classification-{i}.png")
52 |
53 | fig = plt.figure(figsize=(12, 12))
54 |
55 | plt.subplot(311)
56 | plt.grid(True)
57 | plt.plot(n_selected, accuracy, ".-")
58 | plt.xlabel("number of selected features")
59 | plt.ylabel("classification accuracy")
60 |
61 | plt.subplot(312)
62 | plt.grid(True)
63 | plt.plot(lambda_, accuracy, ".-")
64 | plt.xlabel("lambda")
65 | plt.xscale("log")
66 | plt.ylabel("classification accuracy")
67 |
68 | plt.subplot(313)
69 | plt.grid(True)
70 | plt.plot(lambda_, n_selected, ".-")
71 | plt.xlabel("lambda")
72 | plt.xscale("log")
73 | plt.ylabel("number of selected features")
74 |
75 | plt.savefig("mnist-classification-training.png")
76 |
--------------------------------------------------------------------------------
/examples/mnist_reconstruction.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from sklearn.datasets import fetch_openml
3 | from sklearn.metrics import mean_squared_error
4 |
5 | from lassonet import LassoNetRegressor
6 |
7 | X, y = fetch_openml(name="mnist_784", return_X_y=True)
8 | filter = y == "3"
9 | X = X[filter].values / 255
10 |
11 | model = LassoNetRegressor(M=30, n_iters=(3000, 500), path_multiplier=1.05, verbose=True)
12 | path = model.path(X, X)
13 |
14 | img = model.feature_importances_.reshape(28, 28)
15 |
16 | plt.title("Feature importance to reconstruct 3")
17 | plt.imshow(img)
18 | plt.colorbar()
19 | plt.savefig("mnist-reconstruction-importance.png")
20 |
21 |
22 | n_selected = []
23 | score = []
24 | lambda_ = []
25 |
26 | for save in path:
27 | model.load(save.state_dict)
28 | X_pred = model.predict(X)
29 | n_selected.append(save.selected.sum())
30 | score.append(mean_squared_error(X_pred, X))
31 | lambda_.append(save.lambda_)
32 |
33 | to_plot = [160, 220, 300]
34 |
35 | for i, save in zip(n_selected, path):
36 | if not to_plot:
37 | break
38 | if i > to_plot[-1]:
39 | continue
40 | to_plot.pop()
41 | plt.clf()
42 | plt.title(f"Linear model with {i} features")
43 | weight = save.state_dict["skip.weight"]
44 | img = (weight[1] - weight[0]).reshape(28, 28)
45 | plt.imshow(img)
46 | plt.colorbar()
47 | plt.savefig(f"mnist-reconstruction-{i}.png")
48 |
49 | plt.clf()
50 |
51 | fig = plt.figure(figsize=(12, 12))
52 |
53 | plt.subplot(311)
54 | plt.grid(True)
55 | plt.plot(n_selected, score, ".-")
56 | plt.xlabel("number of selected features")
57 | plt.ylabel("MSE")
58 |
59 | plt.subplot(312)
60 | plt.grid(True)
61 | plt.plot(lambda_, score, ".-")
62 | plt.xlabel("lambda")
63 | plt.xscale("log")
64 | plt.ylabel("MSE")
65 |
66 | plt.subplot(313)
67 | plt.grid(True)
68 | plt.plot(lambda_, n_selected, ".-")
69 | plt.xlabel("lambda")
70 | plt.xscale("log")
71 | plt.ylabel("number of selected features")
72 |
73 | plt.savefig("mnist-reconstruction-training.png")
74 |
75 |
76 | plt.subplot(221)
77 | plt.imshow(X[150].reshape(28, 28))
78 | plt.subplot(222)
79 | plt.imshow(model.predict(X[150]).reshape(28, 28))
80 | plt.subplot(223)
81 | plt.imshow(X[250].reshape(28, 28))
82 | plt.subplot(224)
83 | plt.imshow(model.predict(X[250]).reshape(28, 28))
84 |
--------------------------------------------------------------------------------
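This script and mnist_ae.py above trace the same reconstruction path through two entry points; a sketch of the equivalence, with constructor arguments copied from the two files:

    from lassonet import LassoNetAutoEncoder, LassoNetRegressor

    ae = LassoNetAutoEncoder(M=30, n_iters=(3000, 500), path_multiplier=1.05)
    reg = LassoNetRegressor(M=30, n_iters=(3000, 500), path_multiplier=1.05)
    # ae.path(X) and reg.path(X, X) both learn to reconstruct X from a sparse
    # subset of pixels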
/experiments/README.MD:
--------------------------------------------------------------------------------
1 | - The data to reproduce Table 1 are available in [this Google Drive repository](https://drive.google.com/open?id=1quiURu7w0nU3Pxcc448xRgfeI80okLsS). Alternatively, you can download all the data sets (except MNIST and MNIST-Fashion) directly from the [UCI Repository](https://archive.ics.uci.edu/ml/datasets.php).
2 | - `data_utils.py` contains starter code to load the 6 datasets in Table 1 of [the paper](https://arxiv.org/abs/1907.12207).
3 | - You will need to download the files, unzip them, and modify the `/home/lemisma/datasets` path in `data_utils.py` to point to your local copy.
4 | - `run.py` contains the necessary code to reproduce the results in Table 1. This script allows you to specify the dataset and other parameters. To run the script, navigate to the directory containing `run.py` and use the following command: `python run.py`.
5 |
6 |
--------------------------------------------------------------------------------
/experiments/data_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from collections import defaultdict
3 | from os.path import join
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.preprocessing import MinMaxScaler
9 |
10 | # The code to load some of these datasets is reproduced from
11 | # https://github.com/mfbalin/Concrete-Autoencoders/blob/master/experiments/generate_comparison_figures.py
12 |
13 |
14 | def load_mice(one_hot=False):
15 | filling_value = -100000
16 | X = np.genfromtxt(
17 | "/home/lemisma/datasets/MICE/Data_Cortex_Nuclear.csv",
18 | delimiter=",",
19 | skip_header=1,
20 | usecols=range(1, 78),
21 | filling_values=filling_value,
22 | encoding="UTF-8",
23 | )
24 | classes = np.genfromtxt(
25 | "/home/lemisma/datasets/MICE/Data_Cortex_Nuclear.csv",
26 | delimiter=",",
27 | skip_header=1,
28 | usecols=range(78, 81),
29 | dtype=None,
30 | encoding="UTF-8",
31 | )
32 |
33 | for i, row in enumerate(X):
34 | for j, val in enumerate(row):
35 | if val == filling_value:
36 | X[i, j] = np.mean(
37 | [
38 | X[k, j]
39 | for k in range(classes.shape[0])
40 | if np.all(classes[i] == classes[k])
41 | ]
42 | )
43 |
44 | DY = np.zeros((classes.shape[0]), dtype=np.uint8)
45 | for i, row in enumerate(classes):
46 | for j, (val, label) in enumerate(zip(row, ["Control", "Memantine", "C/S"])):
47 | DY[i] += (2**j) * (val == label)
48 |
49 | Y = np.zeros((DY.shape[0], np.unique(DY).shape[0]))
50 | for idx, val in enumerate(DY):
51 | Y[idx, val] = 1
52 |
53 | X = MinMaxScaler(feature_range=(0, 1)).fit_transform(X)
54 |
55 | indices = np.arange(X.shape[0])
56 | np.random.shuffle(indices)
57 | X = X[indices]
58 | Y = Y[indices]
59 | DY = DY[indices]
60 | classes = classes[indices]
61 |
62 | if not one_hot:
63 | Y = DY
64 |
65 | X = X.astype(np.float32)
66 | Y = Y.astype(np.float32)
67 |
68 | print("X shape: {}, Y shape: {}".format(X.shape, Y.shape))
69 |
70 | return (X[: X.shape[0] * 4 // 5], Y[: X.shape[0] * 4 // 5]), (
71 | X[X.shape[0] * 4 // 5 :],
72 | Y[X.shape[0] * 4 // 5 :],
73 | )
74 |
75 |
76 | def load_isolet():
77 | x_train = np.genfromtxt(
78 | "/home/lemisma/datasets/isolet/isolet1+2+3+4.data",
79 | delimiter=",",
80 | usecols=range(0, 617),
81 | encoding="UTF-8",
82 | )
83 | y_train = np.genfromtxt(
84 | "/home/lemisma/datasets/isolet/isolet1+2+3+4.data",
85 | delimiter=",",
86 | usecols=[617],
87 | encoding="UTF-8",
88 | )
89 | x_test = np.genfromtxt(
90 | "/home/lemisma/datasets/isolet/isolet5.data",
91 | delimiter=",",
92 | usecols=range(0, 617),
93 | encoding="UTF-8",
94 | )
95 | y_test = np.genfromtxt(
96 | "/home/lemisma/datasets/isolet/isolet5.data",
97 | delimiter=",",
98 | usecols=[617],
99 | encoding="UTF-8",
100 | )
101 |
102 | X = MinMaxScaler(feature_range=(0, 1)).fit_transform(
103 | np.concatenate((x_train, x_test))
104 | )
105 | x_train = X[: len(y_train)]
106 | x_test = X[len(y_train) :]
107 |
108 | print(x_train.shape, y_train.shape)
109 | print(x_test.shape, y_test.shape)
110 |
111 | return (x_train, y_train - 1), (x_test, y_test - 1)
112 |
113 |
114 | import numpy as np
115 |
116 |
117 | def load_epileptic():
118 | filling_value = -100000
119 |
120 | X = np.genfromtxt(
121 | "/home/lemisma/datasets/data.csv",
122 | delimiter=",",
123 | skip_header=1,
124 | usecols=range(1, 179),
125 | filling_values=filling_value,
126 | encoding="UTF-8",
127 | )
128 | Y = np.genfromtxt(
129 | "/homelemisma/datasets/data.csv",
130 | delimiter=",",
131 | skip_header=1,
132 | usecols=range(179, 180),
133 | encoding="UTF-8",
134 | )
135 |
136 | X = MinMaxScaler(feature_range=(0, 1)).fit_transform(X)
137 |
138 | indices = np.arange(X.shape[0])
139 | np.random.shuffle(indices)
140 | X = X[indices]
141 | Y = Y[indices]
142 |
143 | print(X.shape, Y.shape)
144 |
145 | return (X[:8000], Y[:8000]), (X[8000:], Y[8000:])
146 |
147 |
148 | import os
149 |
150 | from PIL import Image
151 |
152 |
153 | def load_coil():
154 | samples = []
155 | for i in range(1, 21):
156 | for image_index in range(72):
157 | obj_img = Image.open(
158 | os.path.join(
159 | "/home/lemisma/datasets/coil-20-proc",
160 | "obj%d__%d.png" % (i, image_index),
161 | )
162 | )
163 | rescaled = obj_img.resize((20, 20))
164 | pixels_values = [float(x) for x in list(rescaled.getdata())]
165 | sample = np.array(pixels_values + [i])
166 | samples.append(sample)
167 | samples = np.array(samples)
168 | np.random.shuffle(samples)
169 | data = samples[:, :-1]
170 | targets = (samples[:, -1] + 0.5).astype(np.int64)
171 | data = (data - data.min()) / (data.max() - data.min())
172 |
173 | l = data.shape[0] * 4 // 5
174 | train = (data[:l], targets[:l] - 1)
175 | test = (data[l:], targets[l:] - 1)
176 | print(train[0].shape, train[1].shape)
177 | print(test[0].shape, test[1].shape)
178 | return train, test
179 |
180 |
181 | import tensorflow as tf
182 | from sklearn.model_selection import train_test_split
183 |
184 |
185 | def load_data(fashion=False, digit=None, normalize=False):
186 | if fashion:
187 | (x_train, y_train), (x_test, y_test) = (
188 | tf.keras.datasets.fashion_mnist.load_data()
189 | )
190 | else:
191 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
192 |
193 | if digit is not None and 0 <= digit <= 9:
194 | train, test = {y: [] for y in range(10)}, {y: [] for y in range(10)}  # separate dicts; "train = test = {...}" aliased a single dict
195 | for x, y in zip(x_train, y_train):
196 | train[y].append(x)
197 | for x, y in zip(x_test, y_test):
198 | test[y].append(x)
199 |
200 | for y in range(10):
201 |
202 | train[y] = np.asarray(train[y])
203 | test[y] = np.asarray(test[y])
204 |
205 | x_train = train[digit]
206 | x_test = test[digit]
207 |
208 | x_train = x_train.reshape((-1, x_train.shape[1] * x_train.shape[2])).astype(
209 | np.float32
210 | )
211 | x_test = x_test.reshape((-1, x_test.shape[1] * x_test.shape[2])).astype(np.float32)
212 |
213 | if normalize:
214 | X = np.concatenate((x_train, x_test))
215 | X = (X - X.min()) / (X.max() - X.min())
216 | x_train = X[: len(y_train)]
217 | x_test = X[len(y_train) :]
218 |
219 | # print(x_train.shape, y_train.shape)
220 | # print(x_test.shape, y_test.shape)
221 | return (x_train, y_train), (x_test, y_test)
222 |
223 |
224 | def load_mnist():
225 | train, test = load_data(fashion=False, normalize=True)
226 | x_train, x_test, y_train, y_test = train_test_split(test[0], test[1], test_size=0.2)
227 | return (x_train, y_train), (x_test, y_test)
228 |
229 |
230 | def load_fashion():
231 | train, test = load_data(fashion=True, normalize=True)
232 | x_train, x_test, y_train, y_test = train_test_split(test[0], test[1], test_size=0.2)
233 | return (x_train, y_train), (x_test, y_test)
234 |
235 |
236 | def load_mnist_two_digits(digit1, digit2):
237 | train_digit_1, _ = load_data(digit=digit1)
238 | train_digit_2, _ = load_data(digit=digit2)
239 |
240 | X_train_1, X_test_1 = train_test_split(train_digit_1[0], test_size=0.6)
241 | X_train_2, X_test_2 = train_test_split(train_digit_2[0], test_size=0.6)
242 |
243 | X_train = np.concatenate((X_train_1, X_train_2))
244 | y_train = np.array([0] * X_train_1.shape[0] + [1] * X_train_2.shape[0])
245 | shuffled_idx = np.random.permutation(X_train.shape[0])
246 | np.take(X_train, shuffled_idx, axis=0, out=X_train)
247 | np.take(y_train, shuffled_idx, axis=0, out=y_train)
248 |
249 | X_test = np.concatenate((X_test_1, X_test_2))
250 | y_test = np.array([0] * X_test_1.shape[0] + [1] * X_test_2.shape[0])
251 | shuffled_idx = np.random.permutation(X_test.shape[0])
252 | np.take(X_test, shuffled_idx, axis=0, out=X_test)
253 | np.take(y_test, shuffled_idx, axis=0, out=y_test)
254 |
255 | print(X_train.shape, y_train.shape)
256 | print(X_test.shape, y_test.shape)
257 |
258 | return (X_train, y_train), (X_test, y_test)
259 |
260 |
261 | import os
262 |
263 | from sklearn.preprocessing import MinMaxScaler
264 |
265 |
266 | def load_activity():
267 | x_train = np.loadtxt(
268 | os.path.join("/home/lemisma/datasets/dataset_uci", "final_X_train.txt"),
269 | delimiter=",",
270 | encoding="UTF-8",
271 | )
272 | x_test = np.loadtxt(
273 | os.path.join("/home/lemisma/datasets/dataset_uci", "final_X_test.txt"),
274 | delimiter=",",
275 | encoding="UTF-8",
276 | )
277 | y_train = (
278 | np.loadtxt(
279 | os.path.join("/home/lemisma/datasets/dataset_uci", "final_y_train.txt"),
280 | delimiter=",",
281 | encoding="UTF-8",
282 | )
283 | - 1
284 | )
285 | y_test = (
286 | np.loadtxt(
287 | os.path.join("/home/lemisma/datasets/dataset_uci", "final_y_test.txt"),
288 | delimiter=",",
289 | encoding="UTF-8",
290 | )
291 | - 1
292 | )
293 |
294 | X = MinMaxScaler(feature_range=(0, 1)).fit_transform(
295 | np.concatenate((x_train, x_test))
296 | )
297 | x_train = X[: len(y_train)]
298 | x_test = X[len(y_train) :]
299 |
300 | print(x_train.shape, y_train.shape)
301 | print(x_test.shape, y_test.shape)
302 | return (x_train, y_train), (x_test, y_test)
303 |
304 |
305 | def load_dataset(dataset):
306 | if dataset == "MNIST":
307 | return load_mnist()
308 | elif dataset == "MNIST-Fashion":
309 | return load_fashion()
310 | elif dataset == "MICE":
311 | return load_mice()
312 | elif dataset == "COIL":
313 | return load_coil()
314 | elif dataset == "ISOLET":
315 | return load_isolet()
316 | elif dataset == "Activity":
317 | return load_activity()
318 | else:
319 | print("Please specify a valid dataset")
320 | return None
321 |
--------------------------------------------------------------------------------
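
All loaders in data_utils.py share the same ((X_train, y_train), (X_test, y_test)) return convention, which is what lets run.py stay dataset-agnostic. A minimal consumption sketch ("MNIST" is chosen here because it downloads via tf.keras instead of relying on the hardcoded local paths above):

# Sketch: every loader can be unpacked the same way.
(X_train, y_train), (X_test, y_test) = load_dataset("MNIST")
assert X_train.shape[1] == X_test.shape[1]  # train and test share the feature dimension
print("train:", X_train.shape, "test:", X_test.shape)
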
/experiments/run.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import torch
4 | from data_utils import load_dataset
5 | from sklearn.model_selection import train_test_split
6 |
7 | from lassonet import LassoNetClassifier
8 | from lassonet.utils import eval_on_path
9 |
10 | seed = None
11 | device = "cuda" if torch.cuda.is_available() else "cpu"
12 | batch_size = 256
13 | K = 50 # Number of features to select
14 | n_epochs = 1000
15 | dataset = "ISOLET"
16 |
17 | # Load dataset and split the data
18 | (X_train_valid, y_train_valid), (X_test, y_test) = load_dataset(dataset)
19 | X_train, X_val, y_train, y_val = train_test_split(
20 | X_train_valid, y_train_valid, test_size=0.125, random_state=seed
21 | )
22 |
23 | # Set the dimensions of the hidden layers
24 | data_dim = X_test.shape[1]
25 | hidden_dim = (data_dim // 3,)
26 |
27 | # Initialize the LassoNetClassifier model and compute the path
28 | lasso_model = LassoNetClassifier(
29 | M=10,
30 | hidden_dims=hidden_dim,
31 | verbose=1,
32 | torch_seed=seed,
33 | random_state=seed,
34 | device=device,
35 | n_iters=n_epochs,
36 | batch_size=batch_size,
37 | )
38 | path = lasso_model.path(X_train, y_train, X_val=X_val, y_val=y_val)
39 |
40 | # Select the features
41 | desired_save = next(save for save in path if save.selected.sum().item() <= K)
42 | SELECTED_FEATURES = desired_save.selected
43 | print("Number of selected features:", SELECTED_FEATURES.sum().item())
44 |
45 | # Select the features from the training, validation, and test data
46 | X_train_selected = X_train[:, SELECTED_FEATURES]
47 | X_val_selected = X_val[:, SELECTED_FEATURES]
48 | X_test_selected = X_test[:, SELECTED_FEATURES]
49 |
50 | # Initialize another LassoNetClassifier for retraining with the selected features
51 | lasso_sparse = LassoNetClassifier(
52 | M=10,
53 | hidden_dims=hidden_dim,
54 | verbose=1,
55 | torch_seed=seed,
56 | random_state=seed,
57 | device=device,
58 | n_iters=n_epochs,
59 | )
60 | path_sparse = lasso_sparse.path(
61 | X_train_selected,
62 | y_train,
63 | X_val=X_val_selected,
64 | y_val=y_val,
65 | lambda_seq=[0],
66 | return_state_dicts=True,
67 | )[:1]
68 |
69 | # Evaluate the model on the test data
70 | score = eval_on_path(lasso_sparse, path_sparse, X_test_selected, y_test)
71 | print("Test accuracy:", score)
72 |
73 | # Save the path
74 | with open(f"{dataset}_path.pkl", "wb") as f:
75 | pickle.dump(path_sparse, f)
76 |
--------------------------------------------------------------------------------
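
The feature-selection step in run.py works because path() computes checkpoints for increasing lambda, so the first save with at most K active features is the densest one within the budget. A standalone sketch of the same logic (the helper name and the explicit error are illustrative additions):

def first_save_within_budget(path, K):
    """Return the first (densest) checkpoint selecting at most K features.

    Assumes the path is ordered by increasing lambda, i.e. by
    non-increasing number of selected features.
    """
    for save in path:
        if save.selected.sum().item() <= K:
            return save
    raise ValueError(f"no checkpoint on the path selects <= {K} features")
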
/lassonet/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .interfaces import (
3 | LassoNetClassifier,
4 | LassoNetClassifierCV,
5 | LassoNetCoxRegressor,
6 | LassoNetCoxRegressorCV,
7 | LassoNetRegressor,
8 | LassoNetRegressorCV,
9 | lassonet_path,
10 | )
11 | from .model import LassoNet
12 | from .plot import plot_cv, plot_path
13 | from .prox import prox
14 |
--------------------------------------------------------------------------------
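
The public API re-exported here is deliberately small and scikit-learn flavored. A minimal end-to-end sketch on synthetic data (hyperparameters are illustrative, not tuned):

import numpy as np

from lassonet import LassoNetRegressor

rng = np.random.RandomState(0)
X = rng.randn(200, 10).astype(np.float32)
y = (X[:, 0] - 2 * X[:, 1] + 0.1 * rng.randn(200)).astype(np.float32)

model = LassoNetRegressor(hidden_dims=(16,))
model.fit(X, y)  # same entry point the tests exercise
print("train R^2:", model.score(X, y))
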
/lassonet/cox.py:
--------------------------------------------------------------------------------
1 | """
2 | Implement the Cox proportional hazards loss (CoxPHLoss) and the concordance index.
3 | """
4 |
5 | __all__ = ["CoxPHLoss", "concordance_index"]
6 |
7 | import torch
8 | from sortedcontainers import SortedList
9 |
10 | from .utils import log_substract, scatter_logsumexp
11 |
12 |
13 | class CoxPHLoss(torch.nn.Module):
14 | """Loss for CoxPH model."""
15 |
16 | allowed = ("breslow", "efron")
17 |
18 | def __init__(self, method):
19 | super().__init__()
20 | assert method in self.allowed, f"Method must be one of {self.allowed}"
21 | self.method = method
22 |
23 | def forward(self, log_h, y):
24 | log_h = log_h.flatten()
25 |
26 | durations, events = y.T
27 |
28 | # sort input
29 | durations, idx = durations.sort(descending=True)
30 | log_h = log_h[idx]
31 | events = events[idx]
32 |
33 | event_ind = events.nonzero().flatten()
34 | if event_ind.nelement() == 0:
35 | # return a zero that stays connected to the autograd graph
36 | return log_h[:0].sum()
37 |
38 | # numerator
39 | log_num = log_h[event_ind].mean()
40 |
41 | # logcumsumexp of events
42 | event_lcse = torch.logcumsumexp(log_h, dim=0)[event_ind]
43 |
44 | # number of events for each unique risk set
45 | _, tie_inverses, tie_count = torch.unique_consecutive(
46 | durations[event_ind], return_counts=True, return_inverse=True
47 | )
48 |
49 | # position of last event (lowest duration) of each unique risk set
50 | tie_pos = tie_count.cumsum(axis=0) - 1
51 |
52 | # logcumsumexp by tie for each event
53 | event_tie_lcse = event_lcse[tie_pos][tie_inverses]
54 |
55 | if self.method == "breslow":
56 | log_den = event_tie_lcse.mean()
57 |
58 | elif self.method == "efron":
59 | # based on https://bydmitry.github.io/efron-tensorflow.html
60 |
61 | # logsumexp of ties, duplicated within tie set
62 | tie_lse = scatter_logsumexp(log_h[event_ind], tie_inverses, dim=0)[
63 | tie_inverses
64 | ]
65 | # multiply (add in log space) with corrective factor
66 | aux = torch.ones_like(tie_inverses)
67 | aux[tie_pos[:-1] + 1] -= tie_count[:-1]
68 | event_id_in_tie = torch.cumsum(aux, dim=0) - 1
69 | discounted_tie_lse = (
70 | tie_lse
71 | + torch.log(event_id_in_tie)
72 | - torch.log(tie_count[tie_inverses])
73 | )
74 |
75 | # denominator
76 | log_den = log_substract(event_tie_lcse, discounted_tie_lse).mean()
77 |
78 | # loss is negative log likelihood
79 | return log_den - log_num
80 |
81 |
82 | def concordance_index(risk, time, event):
83 | """
84 | O(n log n) implementation of https://square.github.io/pysurvival/metrics/c_index.html
85 | """
86 | assert len(risk) == len(time) == len(event)
87 | n = len(risk)
88 | order = sorted(range(n), key=time.__getitem__)
89 | past = SortedList()
90 | num = 0
91 | den = 0
92 | for i in order:
93 | num += len(past) - past.bisect_right(risk[i])
94 | den += len(past)
95 | if event[i]:
96 | past.add(risk[i])
97 | return num / den
98 |
--------------------------------------------------------------------------------
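
A toy sanity check for the two exports above: CoxPHLoss expects y to stack one (duration, event) pair per row, and concordance_index works on plain sequences. All values below are made up for illustration:

import torch

from lassonet.cox import CoxPHLoss, concordance_index

log_h = torch.tensor([0.3, -0.1, 0.5, 0.0])  # predicted log hazards
y = torch.tensor(
    [[5.0, 1.0], [8.0, 0.0], [3.0, 1.0], [9.0, 1.0]]  # (duration, event) rows
)
loss = CoxPHLoss(method="breslow")(log_h, y)
ci = concordance_index(risk=[0.3, -0.1, 0.5, 0.0], time=[5, 8, 3, 9], event=[1, 0, 1, 1])
print(float(loss), ci)
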
/lassonet/model.py:
--------------------------------------------------------------------------------
1 | from itertools import islice
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from .prox import inplace_group_prox, inplace_prox, prox
8 |
9 |
10 | class LassoNet(nn.Module):
11 | def __init__(self, *dims, groups=None, dropout=None):
12 | """
13 | first dimension is input
14 | last dimension is output
15 | `groups` is a list of lists such that `groups[i]`
16 | contains the indices of the features in the i-th group
17 |
18 | """
19 | assert len(dims) > 2
20 | if groups is not None:
21 | n_inputs = dims[0]
22 | all_indices = []
23 | for g in groups:
24 | for i in g:
25 | all_indices.append(i)
26 | assert len(all_indices) == n_inputs and set(all_indices) == set(
27 | range(n_inputs)
28 | ), f"Groups must be a partition of range(n_inputs={n_inputs})"
29 |
30 | self.groups = groups
31 |
32 | super().__init__()
33 |
34 | self.dropout = nn.Dropout(p=dropout) if dropout is not None else None
35 | self.layers = nn.ModuleList(
36 | [nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]
37 | )
38 | self.skip = nn.Linear(dims[0], dims[-1], bias=False)
39 |
40 | def forward(self, inp):
41 | current_layer = inp
42 | result = self.skip(inp)
43 | for theta in self.layers:
44 | current_layer = theta(current_layer)
45 | if theta is not self.layers[-1]:
46 | if self.dropout is not None:
47 | current_layer = self.dropout(current_layer)
48 | current_layer = F.relu(current_layer)
49 | return result + current_layer
50 |
51 | def prox(self, *, lambda_, lambda_bar=0, M=1):
52 | if self.groups is None:
53 | with torch.no_grad():
54 | inplace_prox(
55 | beta=self.skip,
56 | theta=self.layers[0],
57 | lambda_=lambda_,
58 | lambda_bar=lambda_bar,
59 | M=M,
60 | )
61 | else:
62 | with torch.no_grad():
63 | inplace_group_prox(
64 | groups=self.groups,
65 | beta=self.skip,
66 | theta=self.layers[0],
67 | lambda_=lambda_,
68 | lambda_bar=lambda_bar,
69 | M=M,
70 | )
71 |
72 | def lambda_start(
73 | self,
74 | M=1,
75 | lambda_bar=0,
76 | factor=2,
77 | ):
78 | """Estimate when the model will start to sparsify."""
79 |
80 | def is_sparse(lambda_):
81 | with torch.no_grad():
82 | beta = self.skip.weight.data
83 | theta = self.layers[0].weight.data
84 |
85 | for _ in range(10000):
86 | new_beta, theta = prox(
87 | beta,
88 | theta,
89 | lambda_=lambda_,
90 | lambda_bar=lambda_bar,
91 | M=M,
92 | )
93 | if torch.abs(beta - new_beta).max() < 1e-5:
94 | break
95 | beta = new_beta
96 | return (torch.norm(beta, p=2, dim=0) == 0).sum()
97 |
98 | start = 1e-6
99 | while not is_sparse(factor * start):
100 | start *= factor
101 | return start
102 |
103 | def l2_regularization(self):
104 | """
105 | L2 regularization of the MLP, excluding the first layer,
106 | whose norm is already bounded by the skip connection
107 | """
108 | ans = 0
109 | for layer in islice(self.layers, 1, None):
110 | ans += (
111 | torch.norm(
112 | layer.weight.data,
113 | p=2,
114 | )
115 | ** 2
116 | )
117 | return ans
118 |
119 | def l1_regularization_skip(self):
120 | return torch.norm(self.skip.weight.data, p=2, dim=0).sum()
121 |
122 | def l2_regularization_skip(self):
123 | return torch.norm(self.skip.weight.data, p=2)
124 |
125 | def input_mask(self):
126 | with torch.no_grad():
127 | return torch.norm(self.skip.weight.data, p=2, dim=0) != 0
128 |
129 | def selected_count(self):
130 | return self.input_mask().sum().item()
131 |
132 | def cpu_state_dict(self):
133 | return {k: v.detach().clone().cpu() for k, v in self.state_dict().items()}
134 |
--------------------------------------------------------------------------------
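
A small sketch of the residual architecture and the proximal step (dimensions are arbitrary; the assertion in __init__ requires at least three dims, i.e. one hidden layer):

import torch

from lassonet.model import LassoNet

net = LassoNet(10, 32, 1)  # input dim 10, one hidden layer of 32, scalar output
out = net(torch.randn(4, 10))  # skip connection plus ReLU MLP

net.prox(lambda_=0.1, M=10)  # hierarchical proximal update on skip + first layer
print(out.shape, net.selected_count(), "features still active")
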
/lassonet/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | from .interfaces import BaseLassoNetCV
4 | from .utils import confidence_interval, eval_on_path
5 |
6 |
7 | def plot_path(model, X_test, y_test, *, score_function=None):
8 | """
9 | Plot the evolution of the model on the path, namely:
10 | - lambda
11 | - number of selected variables
12 | - score
13 |
14 | Requires model.path(return_state_dicts=True) to have been called beforehand.
15 |
16 |
17 | Parameters
18 | ==========
19 | model : LassoNetClassifier or LassoNetRegressor
20 | X_test : array-like
21 | y_test : array-like
22 | score_function : function or None
23 | if None, use score_function=model.score
24 | score_function is called as score_function(y_test, model.predict(X_test))
25 | """
26 | # TODO: plot with manually computed score
27 | score = eval_on_path(
28 | model, model.path_, X_test, y_test, score_function=score_function
29 | )
30 | n_selected = [save.selected.sum() for save in model.path_]
31 | lambda_ = [save.lambda_ for save in model.path_]
32 |
33 | plt.figure(figsize=(16, 16))
34 |
35 | plt.subplot(311)
36 | plt.grid(True)
37 | plt.plot(n_selected, score, ".-")
38 | plt.xlabel("number of selected features")
39 | plt.ylabel("score")
40 |
41 | plt.subplot(312)
42 | plt.grid(True)
43 | plt.plot(lambda_, score, ".-")
44 | plt.xlabel("lambda")
45 | plt.xscale("log")
46 | plt.ylabel("score")
47 |
48 | plt.subplot(313)
49 | plt.grid(True)
50 | plt.plot(lambda_, n_selected, ".-")
51 | plt.xlabel("lambda")
52 | plt.xscale("log")
53 | plt.ylabel("number of selected features")
54 |
55 | plt.tight_layout()
56 |
57 |
58 | def plot_cv(model: BaseLassoNetCV, X_test, y_test, *, score_function=None):
59 | # TODO: plot with manually computed score
60 | lambda_ = [save.lambda_ for save in model.path_]
61 | lambdas = [[h.lambda_ for h in p] for p in model.raw_paths_]
62 |
63 | score = eval_on_path(
64 | model, model.path_, X_test, y_test, score_function=score_function
65 | )
66 |
67 | plt.figure(figsize=(16, 16))
68 |
69 | plt.subplot(211)
70 | plt.grid(True)
71 | first = True
72 | for sl, ss in zip(lambdas, model.raw_scores_):
73 | plt.plot(
74 | sl,
75 | ss,
76 | "r.-",
77 | markersize=5,
78 | alpha=0.2,
79 | label="cross-validation" if first else None,
80 | )
81 | first = False
82 | avg = model.interp_scores_.mean(axis=1)
83 | ci = confidence_interval(model.interp_scores_)
84 | plt.plot(
85 | model.lambdas_,
86 | avg,
87 | "g.-",
88 | markersize=5,
89 | alpha=0.2,
90 | label="average cv with 95% CI",
91 | )
92 | plt.fill_between(model.lambdas_, avg - ci, avg + ci, color="g", alpha=0.1)
93 | plt.plot(lambda_, score, "b.-", markersize=5, alpha=0.2, label="test")
94 | plt.legend()
95 | plt.xlabel("lambda")
96 | plt.xscale("log")
97 | plt.ylabel("score")
98 |
99 | plt.subplot(212)
100 | plt.grid(True)
101 | first = True
102 | for sl, path in zip(lambdas, model.raw_paths_):
103 | plt.plot(
104 | sl,
105 | [save.selected.sum() for save in path],
106 | "r.-",
107 | markersize=5,
108 | alpha=0.2,
109 | label="cross-validation" if first else None,
110 | )
111 | first = False
112 | plt.plot(
113 | lambda_,
114 | [save.selected.sum() for save in model.path_],
115 | "b.-",
116 | markersize=5,
117 | alpha=0.2,
118 | label="test",
119 | )
120 | plt.legend()
121 | plt.xlabel("lambda")
122 | plt.xscale("log")
123 | plt.ylabel("number of selected features")
124 |
125 | plt.tight_layout()
126 |
--------------------------------------------------------------------------------
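
A runnable sketch of plot_path, including the return_state_dicts=True requirement called out in its docstring (dataset choice and figure name are arbitrary):

import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

from lassonet import LassoNetClassifier, plot_path

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

model = LassoNetClassifier(hidden_dims=(32,))
model.path(X_train, y_train, return_state_dicts=True)  # populates model.path_
plot_path(model, X_test, y_test)
plt.savefig("lassonet_path.png")
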
/lassonet/prox.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 |
5 | def soft_threshold(threshold, x):
6 | return torch.sign(x) * torch.relu(torch.abs(x) - threshold)
7 |
8 |
9 | def sign_binary(x):
10 | ones = torch.ones_like(x)
11 | return torch.where(x >= 0, ones, -ones)
12 |
13 |
14 | def prox(v, u, *, lambda_, lambda_bar, M):
15 | """
16 | v has shape (m,) or (m, batches)
17 | u has shape (k,) or (k, batches)
18 |
19 | supports GPU tensors
20 | """
21 | onedim = len(v.shape) == 1
22 | if onedim:
23 | v = v.unsqueeze(-1)
24 | u = u.unsqueeze(-1)
25 |
26 | u_abs_sorted = torch.sort(u.abs(), dim=0, descending=True).values
27 |
28 | k, batch = u.shape
29 |
30 | s = torch.arange(k + 1.0).view(-1, 1).to(v)
31 | zeros = torch.zeros(1, batch).to(u)
32 |
33 | a_s = lambda_ - M * torch.cat(
34 | [zeros, torch.cumsum(u_abs_sorted - lambda_bar, dim=0)]
35 | )
36 |
37 | norm_v = torch.norm(v, p=2, dim=0)
38 |
39 | x = F.relu(1 - a_s / norm_v) / (1 + s * M**2)
40 |
41 | w = M * x * norm_v
42 | intervals = soft_threshold(lambda_bar, u_abs_sorted)
43 | lower = torch.cat([intervals, zeros])
44 |
45 | idx = torch.sum(lower > w, dim=0).unsqueeze(0)
46 |
47 | x_star = torch.gather(x, 0, idx).view(1, batch)
48 | w_star = torch.gather(w, 0, idx).view(1, batch)
49 |
50 | beta_star = x_star * v
51 | theta_star = sign_binary(u) * torch.min(soft_threshold(lambda_bar, u.abs()), w_star)
52 |
53 | if onedim:
54 | beta_star.squeeze_(-1)
55 | theta_star.squeeze_(-1)
56 |
57 | return beta_star, theta_star
58 |
59 |
60 | def inplace_prox(beta, theta, lambda_, lambda_bar, M):
61 | beta.weight.data, theta.weight.data = prox(
62 | beta.weight.data, theta.weight.data, lambda_=lambda_, lambda_bar=lambda_bar, M=M
63 | )
64 |
65 |
66 | def inplace_group_prox(groups, beta, theta, lambda_, lambda_bar, M):
67 | """
68 | groups is an iterable such that group[i] contains the indices of features in group i
69 | """
70 | beta_ = beta.weight.data
71 | theta_ = theta.weight.data
72 | beta_ans = torch.empty_like(beta_)
73 | theta_ans = torch.empty_like(theta_)
74 | for g in groups:
75 | group_beta = beta_[:, g]
76 | group_beta_shape = group_beta.shape
77 | group_theta = theta_[:, g]
78 | group_theta_shape = group_theta.shape
79 | group_beta, group_theta = prox(
80 | group_beta.reshape(-1),
81 | group_theta.reshape(-1),
82 | lambda_=lambda_,
83 | lambda_bar=lambda_bar,
84 | M=M,
85 | )
86 | beta_ans[:, g] = group_beta.reshape(*group_beta_shape)
87 | theta_ans[:, g] = group_theta.reshape(*group_theta_shape)
88 | beta.weight.data, theta.weight.data = beta_ans, theta_ans
89 |
--------------------------------------------------------------------------------
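
A shape-level sketch of prox as inplace_prox uses it: beta holds the skip weights with one column per feature, theta the first-layer weights (random values, arbitrary hyperparameters):

import torch

from lassonet.prox import prox

beta = torch.randn(1, 20)  # skip weights: (output_dim, n_features)
theta = torch.randn(32, 20)  # first-layer weights: (hidden_dim, n_features)

new_beta, new_theta = prox(beta, theta, lambda_=0.5, lambda_bar=0.0, M=10)
print(new_beta.shape, new_theta.shape)  # shapes are preserved
print((new_beta.norm(dim=0) == 0).sum().item(), "features zeroed out")
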
/lassonet/r.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from .interfaces import LassoNetClassifier, LassoNetRegressor
7 | from .interfaces import lassonet_path as _lassonet_path
8 |
9 |
10 | def make_writable(x):
11 | if isinstance(x, np.ndarray):
12 | x = x.copy()
13 | return x
14 |
15 |
16 | def lassonet_path(X, y, task, *args, **kwargs):
17 | X = make_writable(X)
18 | y = make_writable(y)
19 |
20 | def convert_item(item):
21 | item = asdict(item)
22 | item["state_dict"] = {k: v.numpy() for k, v in item["state_dict"].items()}
23 | item["selected"] = item["selected"].numpy()
24 | return item
25 |
26 | return list(map(convert_item, _lassonet_path(X, y, task, *args, **kwargs)))
27 |
28 |
29 | def lassonet_eval(X, task, state_dict, **kwargs):
30 | X = make_writable(X)
31 |
32 | if task == "classification":
33 | model = LassoNetClassifier(**kwargs)
34 | elif task == "regression":
35 | model = LassoNetRegressor(**kwargs)
36 | else:
37 | raise ValueError('task must be "classification" or "regression"')
38 | state_dict = {k: torch.tensor(v) for k, v in state_dict.items()}
39 | model.load(state_dict)
40 | if hasattr(model, "predict_proba"):
41 | return model.predict_proba(X)
42 | else:
43 | return model.predict(X)
44 |
--------------------------------------------------------------------------------
/lassonet/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import zip_longest
2 | from typing import TYPE_CHECKING, Iterable, List
3 |
4 | import scipy.stats
5 | import torch
6 |
7 | if TYPE_CHECKING:
8 | from lassonet.interfaces import HistoryItem
9 |
10 |
11 | def eval_on_path(model, path, X_test, y_test, *, score_function=None):
12 | if score_function is None:
13 | score_fun = model.score
14 | else:
15 | assert callable(score_function)
16 |
17 | def score_fun(X_test, y_test):
18 | return score_function(y_test, model.predict(X_test))
19 |
20 | score = []
21 | for save in path:
22 | model.load(save.state_dict)
23 | score.append(score_fun(X_test, y_test))
24 | return score
25 |
26 |
27 | if hasattr(torch.Tensor, "scatter_reduce_"):
28 | # torch >= 1.12 provides Tensor.scatter_reduce_
29 | def scatter_reduce(input, dim, index, reduce, *, output_size=None):
30 | src = input
31 | if output_size is None:
32 | output_size = index.max() + 1
33 | return torch.empty(output_size, device=input.device).scatter_reduce(
34 | dim=dim, index=index, src=src, reduce=reduce, include_self=False
35 | )
36 |
37 | else:
38 | scatter_reduce = torch.scatter_reduce
39 |
40 |
41 | def scatter_logsumexp(input, index, *, dim=-1, output_size=None):
42 | """Inspired by torch_scatter.logsumexp
43 | Uses torch.scatter_reduce for performance
44 | """
45 | max_value_per_index = scatter_reduce(
46 | input, dim=dim, index=index, output_size=output_size, reduce="amax"
47 | )
48 | max_per_src_element = max_value_per_index.gather(dim, index)
49 | recentered_scores = input - max_per_src_element
50 | sum_per_index = scatter_reduce(
51 | recentered_scores.exp(),
52 | dim=dim,
53 | index=index,
54 | output_size=output_size,
55 | reduce="sum",
56 | )
57 | return max_value_per_index + sum_per_index.log()
58 |
59 |
60 | def log_substract(x, y):
61 | """log(exp(x) - exp(y))"""
62 | return x + torch.log1p(-(y - x).exp())
63 |
64 |
65 | def confidence_interval(data, confidence=0.95):
66 | if isinstance(data[0], Iterable):
67 | return [confidence_interval(d, confidence) for d in data]
68 | return scipy.stats.t.interval(
69 | confidence,
70 | len(data) - 1,
71 | scale=scipy.stats.sem(data),
72 | )[1]
73 |
74 |
75 | def selection_probability(paths: List[List["HistoryItem"]]):
76 | """Compute the selection probability of each feature at each step.
77 | The individual curves are smoothed so that they are monotonically decreasing.
78 |
79 | Input
80 | -----
81 | paths: List of List of HistoryItem
82 | The lambda paths must be the same for all models.
83 |
84 | Output
85 | ------
86 | prob: torch.Tensor
87 | Tensor of shape (n_steps, n_features) containing the selection probability
88 | of each feature at each lambda value.
89 | expected_wrong: tuple of (Tensor, LongTensor)
90 | Expected number of wrong features.
91 | (values, indices) where values are the expected number of wrong features
92 | and indices are the order of the selected features.
93 | """
94 | n_models = len(paths)
95 |
96 | prob = []
97 | selected = torch.ones_like(paths[0][0].selected)
98 | iterable = zip_longest(
99 | *[[it.selected for it in path] for path in paths],
100 | fillvalue=torch.zeros_like(paths[0][0].selected),
101 | )
102 | for its in iterable:
103 | sel = sum(its) / n_models
104 | selected = torch.minimum(selected, sel)
105 | prob.append(selected)
106 | prob = torch.stack(prob)
107 |
108 | expected_wrong = (
109 | prob.shape[1] * (prob.mean(dim=1, keepdim=True)) ** 2 / (2 * prob - 1)
110 | )
111 | expected_wrong[prob <= 0.5] = float("inf")
112 | return prob, expected_wrong.min(axis=0).values.sort()
113 |
--------------------------------------------------------------------------------
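
A quick numerical check that scatter_logsumexp agrees with a naive per-group torch.logsumexp (toy values, two groups of two):

import torch

from lassonet.utils import scatter_logsumexp

x = torch.tensor([0.0, 1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 1, 1])  # group 0: first two elements, group 1: last two

out = scatter_logsumexp(x, index, dim=0, output_size=2)
naive = torch.stack([torch.logsumexp(x[:2], 0), torch.logsumexp(x[2:], 0)])
print(torch.allclose(out, naive))  # expected: True
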
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from setuptools import setup
4 |
5 |
6 | def read(fname):
7 | return (Path(__file__).parent / fname).read_text(encoding="utf-8")
8 |
9 |
10 | setup(
11 | name="lassonet",
12 | version="0.0.20",
13 | author="Louis Abraham, Ismael Lemhadri",
14 | author_email="louis.abraham@yahoo.fr, lemhadri@stanford.edu",
15 | license="MIT",
16 | description="Reference implementation of LassoNet",
17 | long_description=read("README.md"),
18 | long_description_content_type="text/markdown",
19 | url="https://github.com/lasso-net/lassonet",
20 | classifiers=[
21 | "Programming Language :: Python :: 3",
22 | "License :: OSI Approved :: MIT License",
23 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
24 | "Operating System :: OS Independent",
25 | ],
26 | packages=["lassonet"],
27 | install_requires=[
28 | "torch >= 1.11",
29 | "scikit-learn",
30 | "matplotlib",
31 | "sortedcontainers",
32 | "tqdm",
33 | ],
34 | extras_require={
35 | "dev": [
36 | "pre-commit",
37 | "sphinx",
38 | "black",
39 | ]
40 | },
41 | tests_require=["pytest"],
42 | python_requires=">=3.8",
43 | )
44 |
--------------------------------------------------------------------------------
/sphinx_docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/sphinx_docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 |
16 | sys.path.insert(0, os.path.abspath(".."))
17 |
18 |
19 | # -- Project information -----------------------------------------------------
20 |
21 | project = "LassoNet"
22 | copyright = "2021, Louis Abraham, Ismael Lemhadri"
23 | author = "Louis Abraham, Ismael Lemhadri"
24 |
25 |
26 | # -- General configuration ---------------------------------------------------
27 |
28 | # Add any Sphinx extension module names here, as strings. They can be
29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
30 | # ones.
31 | extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx.ext.autosummary"]
32 |
33 | # Add any paths that contain templates here, relative to this directory.
34 | templates_path = ["_templates"]
35 |
36 | # List of patterns, relative to source directory, that match files and
37 | # directories to ignore when looking for source files.
38 | # This pattern also affects html_static_path and html_extra_path.
39 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
40 |
41 |
42 | # -- Options for HTML output -------------------------------------------------
43 |
44 | # The theme to use for HTML and HTML Help pages. See the documentation for
45 | # a list of builtin themes.
46 | #
47 | html_theme = "alabaster"
48 |
49 | # Add any paths that contain custom static files (such as style sheets) here,
50 | # relative to this directory. They are copied after the builtin static files,
51 | # so a file named "default.css" will overwrite the builtin "default.css".
52 | # html_static_path = ["_static"]
53 |
54 |
55 | # -- extensions options ------------------------------------------------------
56 | autoclass_content = "both"
57 |
58 | autodoc_default_flags = ["members"]
59 | autosummary_generate = True
60 |
--------------------------------------------------------------------------------
/sphinx_docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to LassoNet's documentation!
2 | ====================================
3 |
4 | `Go to main website <../>`_
5 |
6 | Installation
7 | ------------
8 |
9 | ::
10 |
11 | pip install lassonet
12 |
13 |
14 | API
15 | ---
16 |
17 | .. autosummary::
18 | :toctree:
19 |
20 | .. autoclass:: lassonet.LassoNetRegressor
21 | :members:
22 | :inherited-members:
23 |
24 | .. autoclass:: lassonet.LassoNetClassifier
25 | :members:
26 | :inherited-members:
27 |
28 | .. autoclass:: lassonet.LassoNetCoxRegressor
29 | :members:
30 | :inherited-members:
31 |
32 | .. autoclass:: lassonet.LassoNetRegressorCV
33 | :members:
34 | :inherited-members:
35 |
36 | .. autoclass:: lassonet.LassoNetClassifierCV
37 | :members:
38 | :inherited-members:
39 |
40 | .. autoclass:: lassonet.LassoNetCoxRegressorCV
41 | :members:
42 | :inherited-members:
43 |
44 | .. autofunction:: lassonet.plot_path
45 |
46 | .. autofunction:: lassonet.lassonet_path
47 |
--------------------------------------------------------------------------------
/tests/test_interface.py:
--------------------------------------------------------------------------------
1 | from sklearn.datasets import load_diabetes, load_digits
2 |
3 | from lassonet import LassoNetClassifier, LassoNetRegressor
4 |
5 |
6 | def test_regressor():
7 | X, y = load_diabetes(return_X_y=True)
8 | model = LassoNetRegressor()
9 | model.fit(X, y)
10 | model.score(X, y)
11 |
12 |
13 | def test_classifier():
14 | X, y = load_digits(return_X_y=True)
15 | model = LassoNetClassifier()
16 | model.fit(X, y)
17 | model.score(X, y)
18 |
--------------------------------------------------------------------------------
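
The two smoke tests above only cover fit and score. A hedged sketch of one more test exercising the path API, meant to live in the same file (it reuses the imports above; `selected` and return_state_dicts are the attributes used elsewhere in this repo):

def test_regressor_path():
    X, y = load_diabetes(return_X_y=True)
    model = LassoNetRegressor()
    path = model.path(X, y, return_state_dicts=True)
    # lambda grows along the path, so the number of selected features
    # should not increase from the first checkpoint to the last
    counts = [save.selected.sum().item() for save in path]
    assert counts[0] >= counts[-1]
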