├── .circleci └── config.yml ├── .coveragerc ├── .github └── workflows │ ├── build.yml │ ├── flake8.yml │ └── main.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── celer ├── PN_logreg.pyx ├── __init__.py ├── cython_utils.pxd ├── cython_utils.pyx ├── datasets │ ├── __init__.py │ ├── climate.py │ ├── libsvm.py │ ├── ml_uci.py │ ├── simulated.py │ └── tests │ │ └── test_datasets.py ├── dropin_sklearn.py ├── group_fast.pyx ├── homotopy.py ├── lasso_fast.pyx ├── multitask_fast.pyx ├── tests │ ├── test_docstring_parameters.py │ ├── test_enet.py │ ├── test_lasso.py │ ├── test_logreg.py │ └── test_mtl.py └── utils │ ├── __init__.py │ └── testing.py ├── codecov.yml ├── doc ├── Makefile ├── _static │ └── style.css ├── api.rst ├── conf.py ├── contribute.rst ├── get_started.rst ├── github_link.py └── index.rst ├── examples ├── README.txt ├── plot_finance_path.py ├── plot_group_lasso.py ├── plot_lasso_cv.py ├── plot_leukemia_path.py ├── plot_logreg_timings.py └── plot_multitask_lasso_cv.py ├── pyproject.toml ├── setup.cfg ├── setup.py └── whats_new.rst /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build_docs: 4 | docker: 5 | - image: circleci/python:3.8.1-buster 6 | steps: 7 | - checkout 8 | - run: 9 | name: Set BASH_ENV 10 | command: | 11 | echo "set -e" >> $BASH_ENV 12 | echo "export DISPLAY=:99" >> $BASH_ENV 13 | echo "export OPENBLAS_NUM_THREADS=4" >> $BASH_ENV 14 | echo "export LIBSVMDATA_HOME=$HOME/celer_data" >> $BASH_ENV 15 | echo "BASH_ENV:" 16 | cat $BASH_ENV 17 | 18 | - run: 19 | name: Merge with upstream 20 | command: | 21 | echo $(git log -1 --pretty=%B) | tee gitlog.txt 22 | echo ${CI_PULL_REQUEST//*pull\//} | tee merge.txt 23 | if [[ $(cat merge.txt) != "" ]]; then 24 | echo "Merging $(cat merge.txt)"; 25 | git remote add upstream https://github.com/mathurinm/celer.git; 26 | git pull --ff-only upstream "refs/pull/$(cat merge.txt)/merge"; 27 | git fetch upstream main; 28 | fi 29 | 30 | # If both keys are in the same command only one is restored 31 | - restore_cache: 32 | keys: 33 | - pip-cache 34 | 35 | # do not merge with above, only the first cache is restored then 36 | - restore_cache: 37 | keys: 38 | - celer-data-1 39 | 40 | - run: 41 | name: Spin up Xvfb 42 | command: | 43 | /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset; 44 | 45 | - run: 46 | name: Get Python running 47 | command: | 48 | python -m pip install --user --upgrade --progress-bar off pip 49 | python -m pip install --user -e . 
50 | python -m pip install --user --upgrade --progress-bar off -e .[doc] 51 | 52 | - save_cache: 53 | key: pip-cache 54 | paths: 55 | - ~/.cache/pip 56 | 57 | # Look at what we have and fail early if there is some library conflict 58 | - run: 59 | name: Check installation 60 | command: | 61 | which python 62 | python -c "import celer" 63 | 64 | # Build docs 65 | - run: 66 | name: make html 67 | no_output_timeout: 100m 68 | command: | 69 | cd doc; 70 | make clean; 71 | make SPHINXOPTS=-v html; 72 | cd ..; 73 | 74 | - run: 75 | name: List celer after make 76 | command: | 77 | ls $HOME/celer_data 78 | find $HOME/celer_data -maxdepth 3 79 | 80 | - save_cache: 81 | key: celer-data-1 82 | paths: 83 | - /home/circleci/celer_data 84 | 85 | # Deploy docs 86 | - run: 87 | name: deploy 88 | command: | 89 | if [[ ${CIRCLE_BRANCH} == "main" ]]; then 90 | set -e 91 | mkdir -p ~/.ssh 92 | echo -e "Host *\nStrictHostKeyChecking no" > ~/.ssh/config 93 | chmod og= ~/.ssh/config 94 | cd doc; 95 | pip install ghp-import; 96 | make install 97 | fi 98 | 99 | # Save the outputs 100 | - store_artifacts: 101 | path: doc/_build/html/ 102 | destination: dev 103 | - persist_to_workspace: 104 | root: doc/_build 105 | paths: 106 | - html 107 | 108 | 109 | workflows: 110 | version: 2 111 | 112 | default: 113 | jobs: 114 | - build_docs 115 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | # Configuration for coverage.py 2 | 3 | [run] 4 | branch = True 5 | source = celer 6 | include = */celer/* 7 | omit = */setup.py 8 | 9 | [report] 10 | exclude_lines = 11 | pragma: no cover 12 | def __repr__ 13 | if self.debug: 14 | if settings.DEBUG 15 | raise AssertionError 16 | raise NotImplementedError 17 | if 0: 18 | if __name__ == .__main__.: 19 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | pull_request: 8 | branches: 9 | - 'main' 10 | 11 | jobs: 12 | build-linux: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 3.8 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: 3.8 21 | 22 | - name: install 23 | run: | 24 | pip install -U pip 25 | pip install -e . 
26 | 27 | - name: test 28 | run: | 29 | pip install pytest pytest-cov coverage numpydoc 30 | pytest -lv --cov-report term-missing celer --cov=celer --cov-config .coveragerc 31 | - name: codecov 32 | uses: codecov/codecov-action@v4.0.1 33 | with: 34 | slug: mathurinm/celer 35 | token: ${{ secrets.CODECOV_TOKEN }} 36 | files: .coveragerc 37 | flags: unittests 38 | fail_ci_if_error: true 39 | verbose: true 40 | -------------------------------------------------------------------------------- /.github/workflows/flake8.yml: -------------------------------------------------------------------------------- 1 | name: linter 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | pull_request: 8 | branches: 9 | - 'main' 10 | 11 | 12 | jobs: 13 | lint: 14 | name: Lint code base 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v2 20 | 21 | - name: Setup Python 3.8 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: 3.8 25 | 26 | - name: Lint with flake 27 | run: | 28 | pip install --upgrade pip 29 | pip install flake8 30 | flake8 --max-line-length=88 celer/ 31 | flake8 --max-line-length=88 examples/ -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: circleci-redirector 2 | on: [status] 3 | jobs: 4 | circleci_artifacts_redirector_job: 5 | runs-on: ubuntu-latest 6 | name: Run CircleCI artifacts redirector 7 | steps: 8 | - name: GitHub Action step 9 | uses: larsoner/circleci-artifacts-redirector-action@master 10 | with: 11 | api-token: ${{ secrets.CIRCLE_TOKEN }} 12 | repo-token: ${{ secrets.GITHUB_TOKEN }} 13 | artifact-path: 0/dev/index.html 14 | circleci-jobs: build_docs 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Cython generation: 2 | *.cpp 3 | *.so 4 | *.o 5 | *.c 6 | *.html 7 | celer.egg-info 8 | .eggs 9 | 10 | # Python precompilation 11 | *pyc 12 | *pyd 13 | 14 | # build 15 | build 16 | 17 | # generated doc 18 | doc/_build/ 19 | doc/auto_examples/ 20 | doc/generated 21 | 22 | # cache 23 | .pytest_cache 24 | __pycache__ 25 | 26 | 27 | coverage/* 28 | .coverage 29 | 30 | .DS_Store 31 | MANIFEST 32 | *.examples 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018-2022, celer 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.rst
2 | include *.md
3 | include *.in
4 | include LICENSE
5 | include celer/__init__.py
6 | 
7 | recursive-include examples *.py
8 | recursive-include examples *.txt
9 | recursive-include celer *.py
10 | recursive-include celer *.pyx
11 | recursive-include celer *.pxd
12 | 
13 | ### Exclude
14 | 
15 | exclude .gitignore
16 | exclude Makefile
17 | exclude .coveragerc
18 | exclude *.yml
19 | exclude .circleci/config.yml
20 | exclude .mailmap
21 | recursive-exclude celer *.pyc
22 | recursive-exclude celer *.cpp
23 | recursive-exclude examples *
24 | recursive-exclude .github/ *
25 | 
26 | recursive-exclude doc *
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # simple makefile to simplify repetitive build env management tasks under posix
2 | 
3 | PYTHON ?= python
4 | CYTHON ?= cython
5 | PYTESTS ?= pytest
6 | 
7 | CTAGS ?= ctags
8 | 
9 | all: clean inplace test
10 | 
11 | clean-pyc:
12 | 	find . -name "*.pyc" | xargs rm -f
13 | 	find . -name "__pycache__" | xargs rm -rf
14 | 
15 | clean-so:
16 | 	find . -name "*.so" | xargs rm -f
17 | 	find . -name "*.pyd" | xargs rm -f
18 | 	find . -name "*.cpp" | xargs rm -f
19 | 	find . -name "*.c" | xargs rm -f
20 | 
21 | clean-build:
22 | 	rm -rf build
23 | 
24 | clean-ctags:
25 | 	rm -f tags
26 | 
27 | clean: clean-build clean-pyc clean-so clean-ctags
28 | 
29 | in: inplace # just a shortcut
30 | inplace:
31 | 	$(PYTHON) setup.py build_ext -i
32 | 
33 | test-code:
34 | 	$(PYTESTS) celer
35 | 
36 | test-doc:
37 | 	$(PYTESTS) $(shell find doc -name '*.rst' | sort)
38 | 
39 | test-coverage:
40 | 	rm -rf coverage .coverage
41 | 	$(PYTESTS) celer --cov=celer --cov-report html:coverage
42 | 
43 | test: test-code test-doc test-manifest
44 | 
45 | trailing-spaces:
46 | 	find . -name "*.py" | xargs perl -pi -e 's/[ \t]*$$//'
-name "*.py" | xargs perl -pi -e 's/[ \t]*$$//' 47 | 48 | cython: 49 | find -name "*.pyx" | xargs $(CYTHON) 50 | 51 | ctags: 52 | # make tags for symbol based navigation in emacs and vim 53 | # Install with: sudo apt-get install exuberant-ctags 54 | $(CTAGS) -R * 55 | 56 | .PHONY : doc-plot 57 | doc-plot: 58 | make -C doc html 59 | 60 | .PHONY : doc 61 | doc: 62 | make -C doc html-noplot 63 | 64 | test-manifest: 65 | check-manifest --ignore doc,celer/*/tests; 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # celer 2 | 3 | ![build](https://github.com/mathurinm/celer/workflows/build/badge.svg) 4 | ![coverage](https://codecov.io/gh/mathurinm/celer/branch/main/graphs/badge.svg?branch=main) 5 | ![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg) 6 | [![Downloads](https://static.pepy.tech/badge/celer)](https://pepy.tech/project/celer) 7 | [![Downloads](https://pepy.tech/badge/celer/month)](https://pepy.tech/project/celer) 8 | [![PyPI version](https://badge.fury.io/py/celer.svg)](https://pypi.org/project/celer) 9 | 10 | 11 | ``celer`` is a Python package that solves Lasso-like problems and provides estimators that follow the ``scikit-learn`` API. Thanks to a tailored implementation, ``celer`` provides a fast solver that tackles large-scale datasets with millions of features **up to 100 times faster than ``scikit-learn``**. 12 | 13 | Currently, the package handles the following problems: 14 | 15 | 16 | | Problem | Support Weights | Native cross-validation 17 | | ----------- | ----------- |---------------- 18 | | Lasso | ✓ | ✓ 19 | | ElasticNet | ✓ | ✓ 20 | | Group Lasso | ✓ | ✓ 21 | | Multitask Lasso | ✕ | ✓ 22 | | Sparse Logistic regression | ✕ | ✕ 23 | 24 | If you are interested in other models, such as non convex penalties (SCAD, MCP), sparse group lasso, group logistic regression, Poisson regression, Tweedie regression, have a look at our companion package [``skglm``](https://github.com/scikit-learn-contrib/skglm) 25 | 26 | ## Cite 27 | 28 | ``celer`` is licensed under the [BSD 3-Clause](https://github.com/mathurinm/celer/blob/main/LICENSE). Hence, you are free to use it. 29 | If you do so, please cite: 30 | 31 | 32 | ```bibtex 33 | @InProceedings{pmlr-v80-massias18a, 34 | title = {Celer: a Fast Solver for the Lasso with Dual Extrapolation}, 35 | author = {Massias, Mathurin and Gramfort, Alexandre and Salmon, Joseph}, 36 | booktitle = {Proceedings of the 35th International Conference on Machine Learning}, 37 | pages = {3321--3330}, 38 | year = {2018}, 39 | volume = {80}, 40 | } 41 | 42 | @article{massias2020dual, 43 | author = {Mathurin Massias and Samuel Vaiter and Alexandre Gramfort and Joseph Salmon}, 44 | title = {Dual Extrapolation for Sparse GLMs}, 45 | journal = {Journal of Machine Learning Research}, 46 | year = {2020}, 47 | volume = {21}, 48 | number = {234}, 49 | pages = {1-33}, 50 | url = {http://jmlr.org/papers/v21/19-587.html} 51 | } 52 | ``` 53 | 54 | ## Why ``celer``? 55 | 56 | ``celer`` is specially designed to handle Lasso-like problems which makes it a fast solver of such problems. 57 | In particular, it comes with tools such as: 58 | 59 | - automated parallel cross-validation 60 | - support of sparse and dense data 61 | - optional feature centering and normalization 62 | - unpenalized intercept fitting 63 | 64 | ``celer`` also provides easy-to-use estimators as it is designed under the ``scikit-learn`` API. 
65 | 
66 | 
67 | 
68 | ## Get started
69 | 
70 | To get started, install ``celer`` via pip:
71 | 
72 | ```shell
73 | pip install -U celer
74 | ```
75 | 
76 | In your Python console,
77 | run the following commands to fit a Lasso estimator on a toy dataset:
78 | 
79 | ```python
80 | >>> from celer import Lasso
81 | >>> from celer.datasets import make_correlated_data
82 | >>> X, y, _ = make_correlated_data(n_samples=100, n_features=1000)
83 | >>> estimator = Lasso()
84 | >>> estimator.fit(X, y)
85 | ```
86 | 
87 | This is just a starter example.
88 | Make sure to browse [the ``celer`` documentation](https://mathurinm.github.io/celer/) to learn more about its features.
89 | To get familiar with the [``celer`` API](https://mathurinm.github.io/celer/api.html), you can also explore the gallery of examples,
90 | which includes examples on real-life datasets as well as timing comparisons with other solvers.
91 | 
92 | 
93 | 
94 | ## Contribute to celer
95 | 
96 | ``celer`` is an open-source project and hence relies on community efforts to evolve.
97 | Your contribution is highly valuable and can come in three forms:
98 | 
99 | - **bug report:** you may encounter a bug while using ``celer``. Don't hesitate to report it on the [issue section](https://github.com/mathurinm/celer/issues).
100 | - **feature request:** you may want to extend/add new features to ``celer``. You can use the [issue section](https://github.com/mathurinm/celer/issues) to make suggestions.
101 | - **pull request:** you may have fixed a bug, enhanced the documentation, etc.; you can submit a [pull request](https://github.com/mathurinm/celer/pulls) and we will respond as soon as possible.
102 | 
103 | For the last kind of contribution, here are the steps to help you set up ``celer`` on your local machine:
104 | 
105 | 1. Fork the repository, then run the following command to clone it on your local machine:
106 | 
107 | ```shell
108 | git clone https://github.com/{YOUR_GITHUB_USERNAME}/celer.git
109 | ```
110 | 
111 | 2. ``cd`` to the ``celer`` directory and install it in editable mode by running:
112 | 
113 | ```shell
114 | cd celer
115 | pip install -e .
116 | ```
117 | 
118 | 3. To run the gallery examples and build the documentation, run the following:
119 | 
120 | ```shell
121 | pip install -e .[doc]
122 | cd doc
123 | make html
124 | ```
125 | 
126 | 
127 | 
128 | ## Further links
129 | 
130 | - https://mathurinm.github.io/celer/
131 | - https://arxiv.org/abs/1802.07481
132 | - https://arxiv.org/abs/1907.05830
133 | 
134 | 
--------------------------------------------------------------------------------
/celer/PN_logreg.pyx:
--------------------------------------------------------------------------------
1 | #cython: language_level=3
2 | # Author: Mathurin Massias
3 | # License: BSD 3 clause
4 | 
5 | cimport cython
6 | import numpy as np
7 | cimport numpy as np
8 | import warnings
9 | 
10 | from numpy.linalg import norm
11 | from cython cimport floating
12 | from libc.math cimport fabs, sqrt, exp
13 | from sklearn.exceptions import ConvergenceWarning
14 | 
15 | from .cython_utils cimport fdot, faxpy, fcopy, fposv, fscal, fnrm2
16 | from .cython_utils cimport (primal, dual, create_dual_pt,
17 |                             sigmoid, ST, LOGREG, dnorm_enet,
18 |                             compute_Xw, compute_norms_X_col, set_prios)
19 | 
20 | cdef:
21 |     int inc = 1
22 | 
23 | @cython.boundscheck(False)
24 | @cython.wraparound(False)
25 | @cython.cdivision(True)
26 | def newton_celer(
27 |         bint is_sparse, floating[::1, :] X, floating[:] X_data,
28 |         int[:] X_indices, int[:] X_indptr, floating[:] y, floating alpha,
29 |         floating[:] w, int max_iter, floating tol=1e-4, int p0=100,
30 |         int verbose=0, bint use_accel=1, bint prune=1, bint blitz_sc=False,
31 |         int max_pn_iter=50):
32 | 
33 |     if floating is double:
34 |         dtype = np.float64
35 |     else:
36 |         dtype = np.float32
37 | 
38 |     # Enet not supported for Logreg
39 |     cdef floating l1_ratio = 1.0
40 |     cdef floating norm_w2 = 0.
41 | 
42 |     cdef int verbose_in = max(0, verbose - 1)
43 |     cdef int n_samples = y.shape[0]
44 |     cdef int n_features = w.shape[0]
45 |     # scale tol for when problem has large or small p_obj
46 |     tol *= n_samples * np.log(2)
47 | 
48 |     cdef int t = 0
49 |     cdef int i, j, k
50 |     cdef floating p_obj, d_obj, dnorm_XTtheta, theta_scaling
51 |     cdef floating gap = -1  # initialized for the warning if max_iter=0
52 |     cdef int info_dposv
53 |     cdef int ws_size
54 |     cdef floating eps_inner = 0.1
55 |     cdef floating growth = 2.
56 | 
57 | 
58 |     cdef floating[:] weights_pen = np.ones(n_features, dtype=dtype)
59 |     cdef int[:] all_features = np.arange(n_features, dtype=np.int32)
60 |     cdef floating[:] prios = np.empty(n_features, dtype=dtype)
61 |     cdef int[:] WS
62 |     cdef floating[:] gaps = np.zeros(max(1, max_iter), dtype=dtype)
63 |     gaps[0] = gap  # support max_iter = 0
64 |     cdef floating[:] X_mean = np.zeros(n_features, dtype=dtype)
65 |     cdef bint center = False
66 |     # TODO support centering
67 |     cdef int[:] screened = np.zeros(n_features, dtype=np.int32)
68 |     cdef int n_screened = 0
69 |     cdef floating radius = 10000
70 | 
71 |     cdef floating d_obj_acc = 0.
72 | cdef floating tol_inner 73 | 74 | cdef int K = 6 75 | cdef floating[:, :] last_K_Xw = np.zeros((K, n_samples), dtype=dtype) 76 | cdef floating[:, :] U = np.zeros((K - 1, n_samples), dtype=dtype) 77 | cdef floating[:, :] UUt = np.zeros((K - 1, K - 1), dtype=dtype) 78 | cdef floating[:] onesK = np.ones((K - 1), dtype=dtype) 79 | 80 | cdef floating[:] norms_X_col = np.zeros(n_features, dtype=dtype) 81 | compute_norms_X_col(is_sparse, norms_X_col, n_samples, X, 82 | X_data, X_indices, X_indptr, X_mean) 83 | cdef floating[:] Xw = np.zeros(n_samples, dtype=dtype) 84 | compute_Xw(is_sparse, LOGREG, Xw, w, y, center, X, X_data, X_indices, 85 | X_indptr, X_mean) 86 | 87 | cdef floating[:] theta = np.empty(n_samples, dtype=dtype) 88 | cdef floating[:] theta_acc = np.empty(n_samples, dtype=dtype) 89 | 90 | cdef floating[:] exp_Xw = np.empty(n_samples, dtype=dtype) 91 | for i in range(n_samples): 92 | exp_Xw[i] = exp(Xw[i]) 93 | cdef floating[:] low_exp_Xw = np.empty(n_samples, dtype=dtype) 94 | cdef floating[:] aux = np.empty(n_samples, dtype=dtype) 95 | cdef int[:] is_positive_label = np.empty(n_samples, dtype=np.int32) 96 | for i in range(n_samples): 97 | if y[i] > 0.: 98 | is_positive_label[i] = 1 99 | else: 100 | is_positive_label[i] = 0 101 | 102 | cdef char * char_U = 'U' 103 | cdef int one = 1 104 | cdef int Kminus1 = K - 1 105 | cdef floating sum_z 106 | cdef bint positive = 0 107 | 108 | for t in range(max_iter): 109 | p_obj = primal(LOGREG, alpha, l1_ratio, Xw, y, w, weights_pen) 110 | 111 | # theta = y * sigmoid(-y * Xw) 112 | create_dual_pt(LOGREG, n_samples, &theta[0], &Xw[0], &y[0]) 113 | dnorm_XTtheta = dnorm_enet( 114 | is_sparse, theta, w, X, X_data, X_indices, X_indptr, 115 | screened, X_mean, weights_pen, center, positive, alpha, l1_ratio) 116 | 117 | if dnorm_XTtheta > alpha: 118 | theta_scaling = alpha / dnorm_XTtheta 119 | fscal(&n_samples, &theta_scaling, &theta[0], &inc) 120 | 121 | d_obj = dual(LOGREG, n_samples, alpha, l1_ratio, 0., norm_w2, &theta[0], &y[0]) 122 | gap = p_obj - d_obj 123 | 124 | if t != 0 and use_accel: 125 | # do some epochs of CD to create an extrapolated dual point 126 | for k in range(K): 127 | cd_one_pass(w, is_sparse, X, X_data, 128 | X_indices, X_indptr, y, alpha, Xw) 129 | fcopy(&n_samples, &Xw[0], &inc, 130 | &last_K_Xw[k, 0], &inc) 131 | 132 | # TODO use function in utils 133 | for k in range(K - 1): 134 | for i in range(n_samples): 135 | U[k, i] = last_K_Xw[k + 1, i] - last_K_Xw[k, i] 136 | 137 | for k in range(K - 1): 138 | for j in range(k, K - 1): 139 | UUt[k, j] = fdot(&n_samples, &U[k, 0], &inc, 140 | &U[j, 0], &inc) 141 | UUt[j, k] = UUt[k, j] 142 | 143 | for k in range(K - 1): 144 | onesK[k] = 1. 145 | 146 | fposv(char_U, &Kminus1, &one, &UUt[0, 0], &Kminus1, 147 | &onesK[0], &Kminus1, &info_dposv) 148 | 149 | if info_dposv != 0: 150 | for k in range(K - 2): 151 | onesK[k] = 0 152 | onesK[K - 2] = 1 153 | 154 | sum_z = 0. 155 | for k in range(K - 1): 156 | sum_z += onesK[k] 157 | for k in range(K - 1): 158 | onesK[k] /= sum_z 159 | 160 | for i in range(n_samples): 161 | theta_acc[i] = 0. 
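            # extrapolated primal point: theta_acc = sum_k onesK[k] * last_K_Xw[k],
            # mapped below through y * sigmoid(-y * .) to obtain a dual point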
162 | for k in range(K - 1): 163 | for i in range(n_samples): 164 | theta_acc[i] += onesK[k] * last_K_Xw[k, i] 165 | for i in range(n_samples): 166 | theta_acc[i] = y[i] * sigmoid(- y[i] * theta_acc[i]) 167 | 168 | # do not forget to update exp_Xw 169 | for i in range(n_samples): 170 | exp_Xw[i] = exp(Xw[i]) 171 | 172 | dnorm_XTtheta = dnorm_enet( 173 | is_sparse, theta_acc, w, X, X_data, X_indices, X_indptr, 174 | screened, X_mean, weights_pen, center, positive, alpha, l1_ratio) 175 | 176 | if dnorm_XTtheta > alpha: 177 | theta_scaling = alpha / dnorm_XTtheta 178 | fscal(&n_samples, &theta_scaling, &theta_acc[0], &inc) 179 | 180 | d_obj_acc = dual(LOGREG, n_samples, alpha, l1_ratio, 0., norm_w2, &theta_acc[0], &y[0]) 181 | if d_obj_acc > d_obj: 182 | fcopy(&n_samples, &theta_acc[0], &inc, &theta[0], &inc) 183 | gap = p_obj - d_obj_acc 184 | 185 | gaps[t] = gap 186 | if verbose: 187 | print("Iter %d: primal %.10f, gap %.2e" % (t, p_obj, gap)) 188 | 189 | if gap <= tol: 190 | if verbose: 191 | print("Early exit, gap: %.2e < %.2e" % (gap, tol)) 192 | break 193 | 194 | 195 | set_prios(is_sparse, theta, w, alpha, l1_ratio, X, X_data, X_indices, X_indptr, 196 | norms_X_col, weights_pen, prios, screened, radius, 197 | &n_screened, 0) 198 | 199 | if prune: 200 | if t == 0: 201 | ws_size = p0 202 | else: 203 | ws_size = 0 204 | for j in range(n_features): 205 | if w[j] != 0: 206 | prios[j] = -1 207 | ws_size += 1 208 | ws_size = 2 * ws_size 209 | else: 210 | if t == 0: 211 | ws_size = p0 212 | else: 213 | ws_size *= 2 214 | 215 | if ws_size >= n_features: 216 | ws_size = n_features 217 | WS = all_features # argpartition breaks otherwise 218 | else: 219 | WS = np.asarray(np.argpartition(prios, ws_size)[:ws_size]).astype(np.int32) 220 | np.asarray(WS).sort() 221 | tol_inner = eps_inner * gap 222 | if verbose: 223 | print("Solving subproblem with %d constraints" % len(WS)) 224 | 225 | PN_logreg(is_sparse, w, WS, X, X_data, X_indices, X_indptr, y, 226 | alpha, tol_inner, Xw, exp_Xw, low_exp_Xw, 227 | aux, is_positive_label, X_mean, weights_pen, center, 228 | blitz_sc, verbose_in, max_pn_iter) 229 | else: 230 | warnings.warn( 231 | 'Objective did not converge: duality ' + 232 | f'gap: {gap}, tolerance: {tol}. Increasing `tol` may make the' + 233 | ' solver faster without affecting the results much. \n' + 234 | 'Fitting data with very small alpha causes precision issues.', 235 | ConvergenceWarning) 236 | return np.asarray(w), np.asarray(theta), np.asarray(gaps[:t + 1]) 237 | 238 | 239 | @cython.boundscheck(False) 240 | @cython.wraparound(False) 241 | @cython.cdivision(True) 242 | cpdef int PN_logreg( 243 | bint is_sparse, floating[:] w, int[:] WS, 244 | floating[::1, :] X, floating[:] X_data, int[:] X_indices, 245 | int[:] X_indptr, floating[:] y, floating alpha, 246 | floating tol_inner, floating[:] Xw, 247 | floating[:] exp_Xw, floating[:] low_exp_Xw, floating[:] aux, 248 | int[:] is_positive_label, floating[:] X_mean, 249 | floating[:] weights_pen, bint center, bint blitz_sc, int verbose_in, 250 | int max_pn_iter): 251 | 252 | cdef int n_samples = Xw.shape[0] 253 | cdef int ws_size = WS.shape[0] 254 | cdef int n_features = w.shape[0] 255 | 256 | # Enet not supported for Logreg 257 | cdef floating l1_ratio = 1.0 258 | cdef floating norm_w2 = 0. 259 | 260 | if floating is double: 261 | dtype = np.float64 262 | else: 263 | dtype = np.float32 264 | 265 | cdef: 266 | int MAX_BACKTRACK_ITR = 10 267 | int MAX_PN_CD_ITR = 10 268 | int MIN_PN_CD_ITR = 2 269 | cdef floating PN_EPSILON_RATIO = 10. 
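    # Each prox-Newton iteration builds a quadratic approximation of the
    # logistic loss: weights[i] holds the Hessian diagonal and grad[i] the
    # gradient at the current Xw; the inner CD loop below then solves the
    # resulting l1-regularized quadratic restricted to the working set WS.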
270 | 271 | cdef floating[:] weights = np.zeros(n_samples, dtype=dtype) 272 | cdef floating[:] grad = np.zeros(n_samples, dtype=dtype) 273 | 274 | # solve a Lasso with other X and y (see paper) 275 | cdef floating[:] bias = np.zeros(ws_size, dtype=dtype) 276 | cdef floating[:] lc = np.zeros(ws_size, dtype=dtype) 277 | cdef floating[:] delta_w = np.zeros(ws_size, dtype=dtype) 278 | cdef floating[:] X_delta_w = np.zeros(n_samples, dtype=dtype) 279 | 280 | cdef floating[:] theta = np.zeros(n_samples, dtype=dtype) 281 | 282 | # for CD stopping criterion 283 | cdef bint first_pn_iteration = True 284 | cdef floating pn_grad_diff = 0. 285 | cdef floating approx_grad, actual_grad, sum_sq_hess_diff, pn_epsilon 286 | cdef floating[:] pn_grad_cache = np.zeros(ws_size, dtype=dtype) 287 | 288 | cdef int i, j, ind, max_cd_itr, cd_itr, pn_iter 289 | cdef floating prob 290 | 291 | cdef int start_ptr, end_ptr 292 | cdef floating gap, p_obj, d_obj, dnorm_XTtheta 293 | cdef floating tmp, theta_scaling, new_value, old_value, diff 294 | 295 | cdef int[:] notin_WS = np.ones(n_features, dtype=np.int32) 296 | for ind in range(ws_size): 297 | notin_WS[WS[ind]] = 0 298 | 299 | for pn_iter in range(max_pn_iter): 300 | # run prox newton iterations: 301 | for i in range(n_samples): 302 | prob = 1. / (1. + exp(y[i] * Xw[i])) 303 | weights[i] = prob * (1. - prob) 304 | grad[i] = - y[i] * prob 305 | 306 | for ind in range(ws_size): 307 | lc[ind] = wdot(Xw, weights, WS[ind], is_sparse, X, X_data, 308 | X_indices, X_indptr, 1) 309 | bias[ind] = xj_dot(grad, WS[ind], is_sparse, X, 310 | X_data, X_indices, X_indptr, n_samples) 311 | 312 | if first_pn_iteration: 313 | # very weird: first cd iter, do only 314 | max_cd_itr = MIN_PN_CD_ITR 315 | pn_epsilon = 0 316 | first_pn_iteration = False 317 | else: 318 | max_cd_itr = MAX_PN_CD_ITR 319 | pn_epsilon = PN_EPSILON_RATIO * pn_grad_diff 320 | 321 | for ind in range(ws_size): 322 | delta_w[ind] = 0. 323 | for i in range(n_samples): 324 | X_delta_w[i] = 0 325 | for cd_itr in range(max_cd_itr): 326 | sum_sq_hess_diff = 0. 327 | 328 | for ind in range(ws_size): 329 | j = WS[ind] 330 | old_value = w[j] + delta_w[ind] 331 | tmp = wdot(X_delta_w, weights, j, is_sparse, X, 332 | X_data, X_indices, X_indptr, jj=False) 333 | new_value = ST(old_value - (bias[ind] + tmp) / lc[ind], 334 | alpha / lc[ind]) 335 | 336 | diff = new_value - old_value 337 | if diff != 0: 338 | sum_sq_hess_diff += lc[ind] ** 2 * diff ** 2 339 | delta_w[ind] = new_value - w[j] 340 | if is_sparse: 341 | start_ptr, end_ptr = X_indptr[j], X_indptr[j + 1] 342 | for i in range(start_ptr, end_ptr): 343 | X_delta_w[X_indices[i]] += diff * X_data[i] 344 | else: 345 | for i in range(n_samples): 346 | X_delta_w[i] += diff * X[i, j] 347 | if (sum_sq_hess_diff < pn_epsilon and 348 | cd_itr + 1 >= MIN_PN_CD_ITR): 349 | break 350 | 351 | do_line_search(w, WS, delta_w, X_delta_w, Xw, alpha, is_sparse, X, X_data, 352 | X_indices, X_indptr, MAX_BACKTRACK_ITR, y, 353 | exp_Xw, low_exp_Xw, aux, is_positive_label) 354 | # aux is an up-to-date gradient (= - alpha * unscaled dual point) 355 | create_dual_pt(LOGREG, n_samples, &aux[0], &Xw[0], &y[0]) 356 | 357 | if blitz_sc: # blitz stopping criterion for CD iter 358 | pn_grad_diff = 0. 359 | for ind in range(ws_size): 360 | j = WS[ind] 361 | actual_grad = xj_dot( 362 | aux, j, is_sparse, X, X_data, X_indices, X_indptr, n_features) 363 | # TODO step_size taken into account? 
364 | approx_grad = pn_grad_cache[ind] + wdot( 365 | X_delta_w, weights, j, is_sparse, X, X_data, X_indices, 366 | X_indptr, False) 367 | pn_grad_cache[ind] = actual_grad 368 | diff = approx_grad - actual_grad 369 | 370 | pn_grad_diff += diff ** 2 371 | 372 | dnorm_XTtheta = 0. 373 | for ind in range(ws_size): 374 | theta_scaling = fabs(pn_grad_cache[ind]) 375 | if theta_scaling > dnorm_XTtheta: 376 | dnorm_XTtheta = theta_scaling 377 | 378 | else: 379 | # rescale aux to create dual point 380 | dnorm_XTtheta = dnorm_enet( 381 | is_sparse, aux, w, X, X_data, X_indices, X_indptr, 382 | notin_WS, X_mean, weights_pen, center, 0, alpha, l1_ratio) 383 | 384 | for i in range(n_samples): 385 | aux[i] /= max(1, dnorm_XTtheta / alpha) 386 | 387 | d_obj = dual(LOGREG, n_samples, alpha, l1_ratio, 0., norm_w2, &aux[0], &y[0]) 388 | p_obj = primal(LOGREG, alpha, l1_ratio, Xw, y, w, weights_pen) 389 | 390 | gap = p_obj - d_obj 391 | if verbose_in: 392 | print("iter %d, p_obj %.10f, d_obj % .10f" % (pn_iter, p_obj, d_obj)) 393 | if gap <= tol_inner: 394 | if verbose_in: 395 | print("%.2e < %.2e, exit." % (gap, tol_inner)) 396 | break 397 | 398 | 399 | @cython.boundscheck(False) 400 | @cython.wraparound(False) 401 | @cython.cdivision(True) 402 | cpdef void do_line_search( 403 | floating[:] w, int[:] WS, floating[:] delta_w, 404 | floating[:] X_delta_w, floating[:] Xw, floating alpha, bint is_sparse, 405 | floating[::1, :] X, floating[:] X_data, 406 | int[:] X_indices, int[:] X_indptr, int MAX_BACKTRACK_ITR, 407 | floating[:] y, floating[:] exp_Xw, floating[:] low_exp_Xw, 408 | floating[:] aux, int[:] is_positive_label) nogil: 409 | 410 | cdef int i, ind, backtrack_itr 411 | cdef floating deriv 412 | cdef floating step_size = 1. 413 | 414 | cdef int n_samples = y.shape[0] 415 | fcopy(&n_samples, &exp_Xw[0], &inc, &low_exp_Xw[0], &inc) 416 | for i in range(n_samples): 417 | exp_Xw[i] = exp(Xw[i] + X_delta_w[i]) 418 | 419 | for backtrack_itr in range(MAX_BACKTRACK_ITR): 420 | compute_aux(aux, is_positive_label, exp_Xw) 421 | 422 | deriv = compute_derivative( 423 | w, WS, delta_w, X_delta_w, alpha, aux, step_size, y) 424 | 425 | if deriv < 1e-7: 426 | break 427 | else: 428 | step_size = step_size / 2. 429 | for i in range(n_samples): 430 | exp_Xw[i] = sqrt(exp_Xw[i] * low_exp_Xw[i]) 431 | else: 432 | pass 433 | # TODO what do we do in this case? 434 | 435 | # a suitable step size is found, perform step: 436 | for ind in range(WS.shape[0]): 437 | w[WS[ind]] += step_size * delta_w[ind] 438 | for i in range(Xw.shape[0]): 439 | Xw[i] += step_size * X_delta_w[i] 440 | 441 | 442 | @cython.boundscheck(False) 443 | @cython.wraparound(False) 444 | @cython.cdivision(True) 445 | cpdef floating compute_derivative( 446 | floating[:] w, int[:] WS, floating[:] delta_w, 447 | floating[:] X_delta_w, floating alpha, floating[:] aux, 448 | floating step_size, floating[:] y) nogil: 449 | 450 | cdef int j 451 | cdef floating deriv_l1 = 0. 
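    # directional derivative of the primal objective along delta_w, evaluated
    # at the current step_size; used by the backtracking line search above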
452 | cdef floating deriv_loss, wj 453 | cdef int n_samples = X_delta_w.shape[0] 454 | 455 | for j in range(WS.shape[0]): 456 | 457 | wj = w[WS[j]] + step_size * delta_w[j] 458 | if wj == 0.: 459 | deriv_l1 -= fabs(delta_w[j]) 460 | else: 461 | deriv_l1 += wj / fabs(wj) * delta_w[j] 462 | 463 | deriv_loss = fdot(&n_samples, &X_delta_w[0], &inc, &aux[0], &inc) 464 | return deriv_loss + alpha * deriv_l1 465 | 466 | 467 | @cython.boundscheck(False) 468 | @cython.wraparound(False) 469 | @cython.cdivision(True) 470 | cpdef void compute_aux(floating[:] aux, int[:] is_positive_label, 471 | floating[:] exp_Xw) nogil: 472 | """-y / (1. + exp(y * Xw))""" 473 | cdef int i 474 | for i in range(is_positive_label.shape[0]): 475 | if is_positive_label[i]: 476 | aux[i] = -1 / (1. + exp_Xw[i]) 477 | else: 478 | aux[i] = 1. - 1 / (1. + exp_Xw[i]) 479 | 480 | 481 | 482 | @cython.boundscheck(False) 483 | @cython.wraparound(False) 484 | @cython.cdivision(True) 485 | cpdef floating wdot(floating[:] v, floating[:] weights, int j, 486 | bint is_sparse, floating[::1, :] X, floating[:] X_data, 487 | int[:] X_indices, int[:] X_indptr, bint jj) nogil: 488 | """Weighted dot product between j-th column of X and v. 489 | 490 | Parameters: 491 | ---------- 492 | jj: bool 493 | If true, v is ignored and dot product is between X[:, j] and X[:, j] 494 | """ 495 | cdef floating tmp = 0 496 | cdef int start, end 497 | cdef int i 498 | 499 | if jj: 500 | if is_sparse: 501 | start, end = X_indptr[j], X_indptr[j + 1] 502 | for i in range(start, end): 503 | tmp += X_data[i] * X_data[i] * weights[X_indices[i]] 504 | else: 505 | for i in range(X.shape[0]): 506 | tmp += X[i, j] ** 2 * weights[i] 507 | else: 508 | if is_sparse: 509 | start, end = X_indptr[j], X_indptr[j + 1] 510 | for i in range(start, end): 511 | tmp += X_data[i] * v[X_indices[i]] * weights[X_indices[i]] 512 | else: 513 | for i in range(X.shape[0]): 514 | tmp += X[i, j] * v[i] * weights[i] 515 | return tmp 516 | 517 | 518 | @cython.boundscheck(False) 519 | @cython.wraparound(False) 520 | @cython.cdivision(True) 521 | cpdef double xj_dot(floating[:] v, int j, bint is_sparse, 522 | floating[::1, :] X, floating[:] X_data, int[:] X_indices, 523 | int[:] X_indptr, int n_samples) nogil: 524 | """Dot product between j-th column of X and v.""" 525 | cdef floating tmp = 0 526 | cdef int start, end 527 | cdef int i 528 | 529 | 530 | if is_sparse: 531 | start, end = X_indptr[j], X_indptr[j + 1] 532 | for i in range(start, end): 533 | tmp += X_data[i] * v[X_indices[i]] 534 | else: 535 | for i in range(n_samples): 536 | tmp += X[i, j] * v[i] 537 | return tmp 538 | 539 | 540 | @cython.boundscheck(False) 541 | @cython.wraparound(False) 542 | @cython.cdivision(True) 543 | cpdef void cd_one_pass( 544 | floating[:] w, bint is_sparse, 545 | floating[::1, :] X, floating[:] X_data, 546 | int[:] X_indices, int[:] X_indptr, floating[:] y, 547 | floating alpha, floating[:] Xw): 548 | """ 549 | Do one pass of CD on non zero elements of w. Modifies w and Xw inplace 550 | """ 551 | cdef int n_features = w.shape[0] 552 | cdef int n_samples = Xw.shape[0] 553 | 554 | cdef floating old_w_j, grad_j, lc_j, exp_yXw_i, tmp 555 | cdef int startptr, endptr 556 | cdef int i, j, ind 557 | 558 | for j in range(n_features): 559 | if not w[j]: 560 | continue 561 | old_w_j = w[j] 562 | grad_j = 0. 563 | 564 | if is_sparse: 565 | startptr = X_indptr[j] 566 | endptr = X_indptr[j + 1] 567 | 568 | for i in range(startptr, endptr): 569 | ind = X_indices[i] 570 | grad_j -= X_data[i] * y[ind] / (1. 
+ exp(y[ind] * Xw[ind])) 571 | else: 572 | for i in range(n_samples): 573 | grad_j -= X[i, j] * y[i] / (1. + exp(y[i] * Xw[i])) 574 | 575 | lc_j = 0. 576 | 577 | if is_sparse: 578 | startptr = X_indptr[j] 579 | endptr = X_indptr[j + 1] 580 | 581 | for i in range(startptr, endptr): 582 | ind = X_indices[i] 583 | exp_yXw_i = exp(-y[ind] * Xw[ind]) 584 | lc_j += X_data[i] ** 2 * exp_yXw_i / (1. + exp_yXw_i) ** 2 585 | else: 586 | for i in range(n_samples): 587 | exp_yXw_i = exp(- y[i] * Xw[i]) 588 | lc_j += (X[i, j] ** 2 * exp_yXw_i / (1. + exp_yXw_i) ** 2) 589 | w[j] = ST(w[j] - grad_j / lc_j, alpha / lc_j) 590 | 591 | if old_w_j != w[j]: 592 | if is_sparse: 593 | startptr = X_indptr[j] 594 | endptr = X_indptr[j + 1] 595 | tmp = w[j] - old_w_j 596 | 597 | for i in range(startptr, endptr): 598 | Xw[X_indices[i]] += tmp * X_data[i] 599 | 600 | else: 601 | for i in range(n_samples): 602 | Xw[i] += (w[j] - old_w_j) * X[i, j] -------------------------------------------------------------------------------- /celer/__init__.py: -------------------------------------------------------------------------------- 1 | """Celer algorithm to solve L1-type regularized problems.""" 2 | 3 | from .homotopy import celer_path 4 | from .dropin_sklearn import (ElasticNet, ElasticNetCV, 5 | GroupLasso, GroupLassoCV, 6 | Lasso, LassoCV, LogisticRegression, 7 | MultiTaskLasso, MultiTaskLassoCV) 8 | 9 | 10 | __version__ = '0.7.5dev0' 11 | -------------------------------------------------------------------------------- /celer/cython_utils.pxd: -------------------------------------------------------------------------------- 1 | # Author: Mathurin Massias 2 | # License: BSD 3 clause 3 | from cython cimport floating 4 | cimport numpy as np 5 | 6 | cdef int LASSO 7 | cdef int LOGREG 8 | 9 | cdef floating ST(floating, floating) nogil 10 | 11 | cdef floating fweighted_norm_w2(floating[:], floating[:]) nogil 12 | 13 | cdef floating dual(int, int, floating, floating, floating, floating, floating *, floating *) nogil 14 | cdef floating primal(int, floating, floating, floating[:], floating [:], 15 | floating [:], floating[:]) nogil 16 | cdef void create_dual_pt(int, int, floating *, floating *, floating *) nogil 17 | 18 | cdef floating Nh(floating) nogil 19 | cdef floating sigmoid(floating) nogil 20 | 21 | cdef floating fdot(int *, floating *, int *, floating *, int *) nogil 22 | cdef floating fasum(int *, floating *, int *) nogil 23 | cdef void faxpy(int *, floating *, floating *, int *, floating *, int *) nogil 24 | cdef floating fnrm2(int * , floating *, int *) nogil 25 | cdef void fcopy(int *, floating *, int *, floating *, int *) nogil 26 | cdef void fscal(int *, floating *, floating *, int *) nogil 27 | 28 | cdef void fposv(char *, int *, int *, floating *, 29 | int *, floating *, int *, int *) nogil 30 | 31 | cdef int create_accel_pt( 32 | int, int, int, int, floating *, floating *, 33 | floating *, floating[:, :], floating[:, :], floating[:], floating[:]) 34 | 35 | 36 | cpdef void compute_Xw( 37 | bint, int, floating[:], floating[:], 38 | floating[:], bint, floating[::1, :], 39 | floating[:], int[:], int[:], floating[:]) 40 | 41 | 42 | cpdef void compute_norms_X_col( 43 | bint, floating[:], int, floating[::1, :], 44 | floating[:], int[:], int[:], floating[:]) 45 | 46 | 47 | cpdef floating dnorm_enet( 48 | bint, floating[:], floating[:], floating[::1, :], floating[:], 49 | int[:], int[:], int[:], floating[:], floating[:], bint, bint, floating, floating) nogil 50 | 51 | 52 | cdef void set_prios( 53 | bint, floating[:], 
floating[:], floating, floating, floating[::1, :], floating[:], int[:], 54 | int[:], floating[:], floating[:], floating[:], int[:], floating, int *, bint) nogil 55 | -------------------------------------------------------------------------------- /celer/cython_utils.pyx: -------------------------------------------------------------------------------- 1 | # Author: Mathurin Massias 2 | # Alexandre Gramfort 3 | # Joseph Salmon 4 | # License: BSD 3 clause 5 | 6 | 7 | cimport cython 8 | cimport numpy as np 9 | 10 | from scipy.linalg.cython_blas cimport ddot, dasum, daxpy, dnrm2, dcopy, dscal 11 | from scipy.linalg.cython_blas cimport sdot, sasum, saxpy, snrm2, scopy, sscal 12 | from scipy.linalg.cython_lapack cimport sposv, dposv 13 | from libc.math cimport fabs, log, exp, sqrt, INFINITY 14 | from cython cimport floating 15 | 16 | 17 | cdef: 18 | int LASSO = 0 19 | int LOGREG = 1 20 | int GRPLASSO = 2 21 | int inc = 1 22 | 23 | 24 | cdef floating fdot(int * n, floating * x, int * inc1, floating * y, 25 | int * inc2) nogil: 26 | if floating is double: 27 | return ddot(n, x, inc1, y, inc2) 28 | else: 29 | return sdot(n, x, inc1, y, inc2) 30 | 31 | 32 | cdef floating fasum(int * n, floating * x, int * inc) nogil: 33 | if floating is double: 34 | return dasum(n, x, inc) 35 | else: 36 | return sasum(n, x, inc) 37 | 38 | 39 | cdef void faxpy(int * n, floating * alpha, floating * x, int * incx, 40 | floating * y, int * incy) nogil: 41 | if floating is double: 42 | daxpy(n, alpha, x, incx, y, incy) 43 | else: 44 | saxpy(n, alpha, x, incx, y, incy) 45 | 46 | 47 | cdef floating fnrm2(int * n, floating * x, int * inc) nogil: 48 | if floating is double: 49 | return dnrm2(n, x, inc) 50 | else: 51 | return snrm2(n, x, inc) 52 | 53 | 54 | cdef void fcopy(int * n, floating * x, int * incx, floating * y, 55 | int * incy) nogil: 56 | if floating is double: 57 | dcopy(n, x, incx, y, incy) 58 | else: 59 | scopy(n, x, incx, y, incy) 60 | 61 | 62 | cdef void fscal(int * n, floating * alpha, floating * x, 63 | int * incx) nogil: 64 | if floating is double: 65 | dscal(n, alpha, x, incx) 66 | else: 67 | sscal(n, alpha, x, incx) 68 | 69 | 70 | cdef void fposv(char * uplo, int * n, int * nrhs, floating * a, 71 | int * lda, floating * b, int * ldb, int * info) nogil: 72 | if floating is double: 73 | dposv(uplo, n, nrhs, a, lda, b, ldb, info) 74 | else: 75 | sposv(uplo, n, nrhs, a, lda, b, ldb, info) 76 | 77 | 78 | cdef inline floating ST(floating x, floating u) nogil: 79 | if x > u: 80 | return x - u 81 | elif x < - u: 82 | return x + u 83 | else: 84 | return 0 85 | 86 | 87 | cdef floating log_1pexp(floating x) nogil: 88 | """Compute log(1. + exp(x)) while avoiding over/underflow.""" 89 | if x < - 18: 90 | return exp(x) 91 | elif x > 18: 92 | return x 93 | else: 94 | return log(1. + exp(x)) 95 | 96 | 97 | cdef inline floating xlogx(floating x) nogil: 98 | if x < 1e-10: 99 | return 0. 100 | else: 101 | return x * log(x) 102 | 103 | cdef inline floating Nh(floating x) nogil: 104 | """Negative entropy of scalar x.""" 105 | if 0. <= x <= 1.: 106 | return xlogx(x) + xlogx(1. - x) 107 | else: 108 | return INFINITY # not - INFINITY 109 | 110 | 111 | @cython.boundscheck(False) 112 | @cython.wraparound(False) 113 | cdef floating fweighted_norm_w2(floating[:] w, floating[:] weights) nogil: 114 | cdef floating weighted_norm = 0. 
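    # computes sum_j weights[j] * w[j] ** 2, skipping features whose weight
    # is INFINITY (those features are excluded from the penalized model)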
115 | cdef int n_features = w.shape[0] 116 | cdef int j 117 | 118 | for j in range(n_features): 119 | if weights[j] == INFINITY: 120 | continue 121 | weighted_norm += weights[j] * w[j] ** 2 122 | return weighted_norm 123 | 124 | 125 | @cython.boundscheck(False) 126 | @cython.wraparound(False) 127 | @cython.cdivision(True) 128 | cdef inline floating sigmoid(floating x) nogil: 129 | return 1. / (1. + exp(- x)) 130 | 131 | 132 | @cython.boundscheck(False) 133 | @cython.wraparound(False) 134 | @cython.cdivision(True) 135 | cdef floating primal_logreg( 136 | floating alpha, floating[:] Xw, floating[:] y, floating[:] w, 137 | floating[:] weights) nogil: 138 | cdef int inc = 1 139 | cdef int n_samples = Xw.shape[0] 140 | cdef int n_features = w.shape[0] 141 | cdef floating p_obj = 0. 142 | cdef int i, j 143 | for i in range(n_samples): 144 | p_obj += log_1pexp(- y[i] * Xw[i]) 145 | for j in range(n_features): 146 | # avoid nan when weights[j] is INFINITY 147 | if w[j]: 148 | p_obj += alpha * weights[j] * fabs(w[j]) 149 | return p_obj 150 | 151 | 152 | # todo check normalization by 1 / n_samples everywhere 153 | @cython.boundscheck(False) 154 | @cython.wraparound(False) 155 | @cython.cdivision(True) 156 | cdef floating primal_lasso( 157 | floating alpha, floating l1_ratio, floating[:] R, floating[:] w, 158 | floating[:] weights) nogil: 159 | cdef int n_samples = R.shape[0] 160 | cdef int n_features = w.shape[0] 161 | cdef int inc = 1 162 | cdef int j 163 | cdef floating p_obj = 0. 164 | p_obj = fdot(&n_samples, &R[0], &inc, &R[0], &inc) / (2. * n_samples) 165 | for j in range(n_features): 166 | # avoid nan when weights[j] is INFINITY 167 | if w[j]: 168 | p_obj += alpha * weights[j] * ( 169 | l1_ratio * fabs(w[j]) + 170 | 0.5 * (1. - l1_ratio) * w[j] ** 2) 171 | return p_obj 172 | 173 | 174 | cdef floating primal( 175 | int pb, floating alpha, floating l1_ratio, floating[:] R, floating[:] y, 176 | floating[:] w, floating[:] weights) nogil: 177 | if pb == LASSO: 178 | return primal_lasso(alpha, l1_ratio, R, w, weights) 179 | else: 180 | return primal_logreg(alpha, R, y, w, weights) 181 | 182 | 183 | @cython.boundscheck(False) 184 | @cython.wraparound(False) 185 | @cython.cdivision(True) 186 | cdef floating dual_enet(int n_samples, floating alpha, floating l1_ratio, 187 | floating norm_y2, floating norm_w2, floating * theta, 188 | floating * y) nogil: 189 | """Theta must be feasible""" 190 | cdef int i 191 | cdef floating d_obj = 0. 192 | 193 | for i in range(n_samples): 194 | d_obj -= (y[i] - n_samples * theta[i]) ** 2 195 | d_obj *= 0.5 / n_samples 196 | d_obj += norm_y2 / (2. * n_samples) 197 | if l1_ratio != 1.0: 198 | d_obj -= 0.5 * alpha * (1 - l1_ratio) * norm_w2 199 | return d_obj 200 | 201 | 202 | @cython.boundscheck(False) 203 | @cython.wraparound(False) 204 | @cython.cdivision(True) 205 | cdef floating dual_logreg(int n_samples, floating * theta, 206 | floating * y) nogil: 207 | """Compute dual objective value at theta, which must be feasible.""" 208 | cdef int i 209 | cdef floating d_obj = 0. 
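    # logistic dual objective: d(theta) = -sum_i Nh(y[i] * theta[i]),
    # with Nh the negative entropy helper defined above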
210 | 211 | for i in range(n_samples): 212 | d_obj -= Nh(y[i] * theta[i]) 213 | return d_obj 214 | 215 | 216 | cdef floating dual(int pb, int n_samples, floating alpha, floating l1_ratio, 217 | floating norm_y2, floating norm_w2, floating * theta, floating * y) nogil: 218 | if pb == LASSO: 219 | return dual_enet(n_samples, alpha, l1_ratio, norm_y2, norm_w2, &theta[0], &y[0]) 220 | else: 221 | return dual_logreg(n_samples, &theta[0], &y[0]) 222 | 223 | 224 | @cython.boundscheck(False) 225 | @cython.wraparound(False) 226 | @cython.cdivision(True) 227 | cdef void create_dual_pt( 228 | int pb, int n_samples, floating * out, 229 | floating * R, floating * y) nogil: 230 | cdef floating tmp = 1. 231 | if pb == LASSO: # out = R / n_samples 232 | tmp = 1. / n_samples 233 | fcopy(&n_samples, &R[0], &inc, &out[0], &inc) 234 | else: # out = y * sigmoid(-y * Xw) 235 | for i in range(n_samples): 236 | out[i] = y[i] * sigmoid(-y[i] * R[i]) 237 | 238 | fscal(&n_samples, &tmp, &out[0], &inc) 239 | 240 | 241 | @cython.boundscheck(False) 242 | @cython.wraparound(False) 243 | @cython.cdivision(True) 244 | cdef int create_accel_pt( 245 | int pb, int n_samples, int epoch, int gap_freq, 246 | floating * R, floating * out, floating * last_K_R, floating[:, :] U, 247 | floating[:, :] UtU, floating[:] onesK, floating[:] y): 248 | 249 | # solving linear system in cython 250 | # doc at https://software.intel.com/en-us/node/468894 251 | 252 | # cdef int n_samples = y.shape[0] cannot use this for MTL 253 | cdef int K = U.shape[0] + 1 254 | cdef char * char_U = 'U' 255 | cdef int one = 1 256 | cdef int Kminus1 = K - 1 257 | cdef int inc = 1 258 | cdef floating sum_z 259 | cdef int info_dposv 260 | 261 | cdef int i, j, k 262 | # warning: this is wrong (n_samples) for MTL, it is handled outside 263 | cdef floating tmp = 1. if pb == LOGREG else 1. / n_samples 264 | 265 | if epoch // gap_freq < K: 266 | # last_K_R[it // f_gap] = R: 267 | fcopy(&n_samples, R, &inc, 268 | &last_K_R[(epoch // gap_freq) * n_samples], &inc) 269 | else: 270 | for k in range(K - 1): 271 | fcopy(&n_samples, &last_K_R[(k + 1) * n_samples], &inc, 272 | &last_K_R[k * n_samples], &inc) 273 | fcopy(&n_samples, R, &inc, &last_K_R[(K - 1) * n_samples], &inc) 274 | for k in range(K - 1): 275 | for i in range(n_samples): 276 | U[k, i] = last_K_R[(k + 1) * n_samples + i] - \ 277 | last_K_R[k * n_samples + i] 278 | 279 | for k in range(K - 1): 280 | for j in range(k, K - 1): 281 | UtU[k, j] = fdot(&n_samples, &U[k, 0], &inc, &U[j, 0], &inc) 282 | UtU[j, k] = UtU[k, j] 283 | 284 | # refill onesK with ones because it has been overwritten 285 | # by dposv 286 | for k in range(K - 1): 287 | onesK[k] = 1. 288 | 289 | fposv(char_U, &Kminus1, &one, &UtU[0, 0], &Kminus1, 290 | &onesK[0], &Kminus1, &info_dposv) 291 | 292 | # onesK now holds the solution in x to UtU dot x = onesK 293 | if info_dposv != 0: 294 | # don't use accel for this iteration 295 | for k in range(K - 2): 296 | onesK[k] = 0 297 | onesK[K - 2] = 1 298 | 299 | sum_z = 0. 300 | for k in range(K - 1): 301 | sum_z += onesK[k] 302 | for k in range(K - 1): 303 | onesK[k] /= sum_z 304 | 305 | for i in range(n_samples): 306 | out[i] = 0. 
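    # extrapolate: out = sum_k onesK[k] * last_K_R[k], where onesK holds the
    # (normalized) solution of the small system UtU @ c = ones solved above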
307 | for k in range(K - 1): 308 | for i in range(n_samples): 309 | out[i] += onesK[k] * last_K_R[k * n_samples + i] 310 | 311 | if pb == LOGREG: 312 | for i in range(n_samples): 313 | out[i] = y[i] * sigmoid(- y[i] * out[i]) 314 | 315 | fscal(&n_samples, &tmp, &out[0], &inc) 316 | # out now holds the extrapolated dual point: 317 | # LASSO: (y - Xw) / n_samples 318 | # LOGREG: y * sigmoid(-y * Xw) 319 | 320 | return info_dposv 321 | 322 | 323 | @cython.boundscheck(False) 324 | @cython.wraparound(False) 325 | @cython.cdivision(True) 326 | cpdef void compute_norms_X_col( 327 | bint is_sparse, floating[:] norms_X_col, int n_samples, 328 | floating[::1, :] X, floating[:] X_data, int[:] X_indices, 329 | int[:] X_indptr, floating[:] X_mean): 330 | cdef int j, startptr, endptr 331 | cdef floating tmp, X_mean_j 332 | cdef int n_features = norms_X_col.shape[0] 333 | 334 | for j in range(n_features): 335 | if is_sparse: 336 | startptr = X_indptr[j] 337 | endptr = X_indptr[j + 1] 338 | X_mean_j = X_mean[j] 339 | tmp = 0. 340 | for i in range(startptr, endptr): 341 | tmp += (X_data[i] - X_mean_j) ** 2 342 | tmp += (n_samples - endptr + startptr) * X_mean_j ** 2 343 | norms_X_col[j] = sqrt(tmp) 344 | else: 345 | norms_X_col[j] = fnrm2(&n_samples, &X[0, j], &inc) 346 | 347 | 348 | @cython.boundscheck(False) 349 | @cython.wraparound(False) 350 | @cython.cdivision(True) 351 | cpdef void compute_Xw( 352 | bint is_sparse, int pb, floating[:] R, floating[:] w, 353 | floating[:] y, bint center, floating[::1, :] X, floating[:] X_data, 354 | int[:] X_indices, int[:] X_indptr, floating[:] X_mean): 355 | # R holds residuals if LASSO, Xw for LOGREG 356 | cdef int i, j, startptr, endptr 357 | cdef floating tmp, X_mean_j 358 | cdef int inc = 1 359 | cdef int n_samples = y.shape[0] 360 | cdef int n_features = w.shape[0] 361 | 362 | for j in range(n_features): 363 | if w[j] != 0: 364 | if is_sparse: 365 | startptr, endptr = X_indptr[j], X_indptr[j + 1] 366 | for i in range(startptr, endptr): 367 | R[X_indices[i]] += w[j] * X_data[i] 368 | if center: 369 | X_mean_j = X_mean[j] 370 | for i in range(n_samples): 371 | R[i] -= X_mean_j * w[j] 372 | else: 373 | tmp = w[j] 374 | faxpy(&n_samples, &tmp, &X[0, j], &inc, &R[0], &inc) 375 | # currently R = X @ w, update for LASSO/GRPLASSO: 376 | if pb in (LASSO, GRPLASSO): 377 | for i in range(n_samples): 378 | R[i] = y[i] - R[i] 379 | 380 | 381 | @cython.boundscheck(False) 382 | @cython.wraparound(False) 383 | @cython.cdivision(True) 384 | cpdef floating dnorm_enet( 385 | bint is_sparse, floating[:] theta, floating[:] w, floating[::1, :] X, 386 | floating[:] X_data, int[:] X_indices, int[:] X_indptr, int[:] skip, 387 | floating[:] X_mean, floating[:] weights, bint center, 388 | bint positive, floating alpha, floating l1_ratio) nogil: 389 | """compute norm(X[:, ~skip].T.dot(theta), ord=inf)""" 390 | cdef int n_samples = theta.shape[0] 391 | cdef int n_features = skip.shape[0] 392 | cdef floating Xj_theta 393 | cdef floating dnorm_XTtheta = 0. 394 | cdef floating theta_sum = 0. 395 | cdef int i, j, Cj, startptr, endptr 396 | 397 | if is_sparse: 398 | # TODO by design theta_sum should always be 0 when center 399 | if center: 400 | for i in range(n_samples): 401 | theta_sum += theta[i] 402 | 403 | # max over feature for which skip[j] == False 404 | for j in range(n_features): 405 | if skip[j] or weights[j] == INFINITY: 406 | continue 407 | if is_sparse: 408 | startptr = X_indptr[j] 409 | endptr = X_indptr[j + 1] 410 | Xj_theta = 0. 
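            # sparse dot product: Xj_theta = X[:, j] @ theta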
411 |             for i in range(startptr, endptr):
412 |                 Xj_theta += X_data[i] * theta[X_indices[i]]
413 |             if center:
414 |                 Xj_theta -= theta_sum * X_mean[j]
415 |         else:
416 |             Xj_theta = fdot(&n_samples, &theta[0], &inc, &X[0, j], &inc)
417 | 
418 |         # minus sign to consider the choice theta = y - Xw and not theta = Xw - y
419 |         if l1_ratio != 1:
420 |             Xj_theta -= alpha * (1 - l1_ratio) * weights[j] * w[j]
421 | 
422 |         if not positive:
423 |             Xj_theta = fabs(Xj_theta)
424 |         dnorm_XTtheta = max(dnorm_XTtheta, Xj_theta / weights[j])
425 |     return dnorm_XTtheta
426 | 
427 | 
428 | @cython.boundscheck(False)
429 | @cython.wraparound(False)
430 | @cython.cdivision(True)
431 | cdef void set_prios(
432 |         bint is_sparse, floating[:] theta, floating[:] w, floating alpha, floating l1_ratio,
433 |         floating[::1, :] X, floating[:] X_data, int[:] X_indices, int[:] X_indptr,
434 |         floating[:] norms_X_col, floating[:] weights, floating[:] prios,
435 |         int[:] screened, floating radius, int * n_screened, bint positive) nogil:
436 |     cdef int i, j, startptr, endptr
437 |     cdef floating Xj_theta
438 |     cdef int n_samples = theta.shape[0]
439 |     cdef int n_features = prios.shape[0]
440 |     cdef floating norms_X_col_j = 0.
441 | 
442 |     # TODO we do not subtract theta_sum, which seems to indicate that theta
443 |     # is always centered...
444 |     for j in range(n_features):
445 |         if screened[j] or norms_X_col[j] == 0. or weights[j] == 0.:
446 |             prios[j] = INFINITY
447 |             continue
448 |         if is_sparse:
449 |             Xj_theta = 0
450 |             startptr = X_indptr[j]
451 |             endptr = X_indptr[j + 1]
452 |             for i in range(startptr, endptr):
453 |                 Xj_theta += theta[X_indices[i]] * X_data[i]
454 |         else:
455 |             Xj_theta = fdot(&n_samples, &theta[0], &inc, &X[0, j], &inc)
456 | 
457 |         norms_X_col_j = norms_X_col[j]
458 |         if l1_ratio != 1:
459 |             Xj_theta -= alpha * (1 - l1_ratio) * weights[j] * w[j]
460 | 
461 |         # column norm of the Enet-augmented design (reduces to norms_X_col[j] when l1_ratio == 1)
462 |         norms_X_col_j = sqrt(norms_X_col_j ** 2 + alpha * (1 - l1_ratio) * weights[j])
463 | 
464 |         if positive:
465 |             prios[j] = fabs(Xj_theta - alpha * l1_ratio * weights[j]) / norms_X_col_j
466 |         else:
467 |             prios[j] = (alpha * l1_ratio * weights[j] - fabs(Xj_theta)) / norms_X_col_j
468 | 
469 |         if prios[j] > radius:
470 |             screened[j] = True
471 |             n_screened[0] += 1
--------------------------------------------------------------------------------
/celer/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | CELER_PATH = str(Path.home()) + '/celer_data/'  # noqa
3 | 
4 | from .climate import fetch_climate
5 | from .libsvm import fetch_libsvm
6 | from .ml_uci import fetch_ml_uci
7 | from .simulated import make_correlated_data
--------------------------------------------------------------------------------
/celer/datasets/climate.py:
--------------------------------------------------------------------------------
1 | # Author : Eugene Ndiaye
2 | #          Mathurin Massias
3 | # BSD License
4 | 
5 | import os
6 | from os.path import join as pjoin
7 | 
8 | import numpy as np
9 | import xarray
10 | import download
11 | from scipy.signal import detrend
12 | 
13 | from celer.datasets import CELER_PATH
14 | 
15 | FILES = ["air.mon.mean.nc", 'pres.mon.mean.nc', 'pr_wtr.mon.mean.nc',
16 |          "rhum.mon.mean.nc", 'slp.mon.mean.nc', "uwnd.mon.mean.nc",
17 |          "vwnd.mon.mean.nc",
18 |          ]
19 | 
20 | 
21 | def _get_data(filename):
22 |     data = xarray.open_dataset(
23 |         pjoin(CELER_PATH, 'climate/surface', filename), decode_times=False)
24 | 
25 |     n_times = data[list(data.data_vars.keys())[0]].shape[0]
26 | 
27 |     X = np.array(data[list(data.data_vars.keys())[0]]).reshape(n_times, -1)
28 | 
29 |     # remove seasonality
30 |     period = 12
31 |     for m in range(period):
32 |         # TODO using sklearn for preprocessing would be an improvement
33 |         X[m::period] -= np.mean(X[m::period], axis=0)[None, :]
34 |         X[m::period] /= np.std(X[m::period], axis=0)[None, :]
35 |         if np.sum(np.isnan(X[m::period])) > 0:
36 |             X[m::period] = np.where(np.isnan(X[m::period]), 0, X[m::period])
37 | 
38 |     # remove trend
39 |     X = detrend(X, axis=0, type='linear')
40 | 
41 |     return X
42 | 
43 | 
44 | def _download_climate(replace=False):
45 |     prefix = "ftp://ftp.cdc.noaa.gov/Datasets/ncep.reanalysis.derived/"
46 | 
47 |     for fname in FILES:
48 |         target = pjoin(CELER_PATH, 'climate/surface', fname)
49 |         download.download(prefix + "surface/" + fname, target,
50 |                           replace=replace)
51 | 
52 | 
53 | def _target_region(lx, Lx):
54 | 
55 |     arrays = [_get_data(filename) for filename in FILES]
56 | 
57 |     n, p = arrays[0].shape
58 |     X = np.zeros((n, 7 * (p - 1)), order='F')
59 | 
60 |     pos_lx = int((90 - lx) / 2.5)
61 |     pos_Lx = (np.ceil(Lx / 2.5)).astype(int)
62 |     target = pos_lx * 144 + pos_Lx
63 | 
64 |     begin = 0
65 |     for j in range(p):
66 |         if j == target:
67 |             continue
68 |         X[:, begin:begin + 7] = np.vstack(
69 |             [arr[:, j] for arr in arrays]).T
70 |         begin += 7
71 | 
72 |     y = arrays[0][:, target].astype(np.float64)
73 | 
74 |     # np.save(pjoin(path, 'climate_data.npy'), X)
75 |     # np.save(pjoin(path, 'climate_target.npy'), y)
76 | 
77 |     return X, y
78 | 
79 | 
80 | def fetch_climate(replace=False):
81 |     """Get design matrix and observation for the climate dataset.
82 | 
83 |     Parameters
84 |     ----------
85 |     replace: bool (default=False)
86 |         Whether to redownload the files if already present on disk.
87 | 
88 |     Returns
89 |     -------
90 |     X: np.array, shape (n_samples, n_features)
91 |         Design matrix.
92 |     y: np.array, shape (n_samples,)
93 |         Observations.
94 |     """
95 |     path = pjoin(CELER_PATH, 'climate')
96 |     if not os.path.exists(path):
97 |         os.mkdir(path)
98 | 
99 |     _download_climate(replace=replace)
100 |     lx, Lx = 14, 17  # Dakar
101 |     print("Preprocessing and loading target region...")
102 |     X, y = _target_region(lx, Lx)
103 | 
104 |     return X, y
105 | 
106 | 
107 | if __name__ == "__main__":
108 |     X, y = fetch_climate(replace=True)
109 | 
--------------------------------------------------------------------------------
/celer/datasets/libsvm.py:
--------------------------------------------------------------------------------
1 | # Author: Mathurin Massias
2 | # License: BSD 3 clause
3 | 
4 | import warnings
5 | 
6 | import libsvmdata
7 | 
8 | 
9 | def fetch_libsvm(dataset, replace=False, normalize=True, min_nnz=3):
10 |     """
11 |     This function is deprecated; we now rely on the libsvmdata package.
12 | 
13 |     Parameters
14 |     ----------
15 |     dataset: string
16 |         Name of the dataset.
17 |     replace: bool
18 |         Whether to redownload the data.
19 |     normalize: bool
20 |         Whether to divide the columns by their norm.
21 |     min_nnz: int
22 |         Columns with strictly less than `min_nnz` non-zero entries are discarded.
23 |     """
24 |     warnings.simplefilter("always", FutureWarning)
25 |     warnings.warn("celer.datasets.fetch_libsvm is deprecated and will be "
26 |                   "removed in version 0.6. Use the lightweight "
Use the lightweight "
27 |                   "libsvmdata package instead.", FutureWarning)
28 |     return libsvmdata.fetch_libsvm(dataset, replace=replace,
29 |                                    normalize=normalize, min_nnz=min_nnz)
30 | 
31 | 
32 | if __name__ == "__main__":
33 |     for dataset in libsvmdata.datasets.NAMES:
34 |         fetch_libsvm(dataset, replace=False)
35 | 
--------------------------------------------------------------------------------
/celer/datasets/ml_uci.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import numpy as np
 3 | import pandas as pd
 4 | from scipy import sparse
 5 | 
 6 | from os.path import join as pjoin
 7 | 
 8 | from celer.datasets import CELER_PATH
 9 | 
10 | BASE = 'https://archive.ics.uci.edu/ml/machine-learning-databases/'
11 | 
12 | NAMES = {'gisette_train': 'gisette/GISETTE/gisette_train'}
13 | 
14 | 
15 | def fetch_ml_uci(dataset):
16 |     """Get a dataset from the UCI ML database.
17 | 
18 |     Parameters
19 |     ----------
20 |     dataset: string
21 |         Dataset name. Must be in NAMES.keys()
22 | 
23 |     Returns
24 |     -------
25 |     X: np.array, shape (n_samples, n_features)
26 |         Design matrix.
27 |     y: np.array, shape (n_samples)
28 |         Target vector.
29 |     """
30 |     if not os.path.exists(pjoin(CELER_PATH, "ml_uci")):
31 |         os.makedirs(pjoin(CELER_PATH, "ml_uci"))
32 | 
33 |     if dataset not in NAMES:
34 |         raise ValueError("Unsupported dataset %s" % dataset)
35 | 
36 |     X_path = pjoin(CELER_PATH, "ml_uci", dataset + '_data.npz')
37 |     y_path = pjoin(CELER_PATH, "ml_uci", dataset + '_target.npy')
38 |     try:
39 |         X = sparse.load_npz(X_path)
40 |         y = np.load(y_path)
41 |     except FileNotFoundError:
42 |         df = pd.read_csv(BASE + NAMES[dataset] + '.data', sep=' ', header=None)
43 |         # trailing whitespace creates an extra, empty last column
44 |         X = sparse.csc_matrix(df.values[:, :-1])
45 |         y = np.array(pd.read_csv(BASE + NAMES[dataset] + '.labels',
46 |                                  header=None)).ravel()
47 |         sparse.save_npz(X_path, X)
48 |         np.save(y_path, y)
49 | 
50 |     return X, y
--------------------------------------------------------------------------------
/celer/datasets/simulated.py:
--------------------------------------------------------------------------------
 1 | # Authors:
 2 | #          Mathurin Massias
 3 | #          Thomas Moreau
 4 | 
 5 | import numpy as np
 6 | from numpy.linalg import norm
 7 | from sklearn.utils import check_random_state
 8 | 
 9 | 
10 | def make_correlated_data(n_samples=100, n_features=50, corr=0.6, snr=3,
11 |                          density=0.2, w_true=None, random_state=None):
12 |     r"""Generate a correlated design matrix with decaying correlation rho**|i-j|,
13 |     according to
14 | 
15 |     .. math::
16 | 
17 |         y = X w^* + \epsilon
18 | 
19 |     such that :math:`||X w^*|| / ||\epsilon|| = snr`.
20 | 
21 |     The generated features have mean 0, variance 1 and the expected correlation
22 |     structure:
23 | 
24 |     .. math::
25 | 
26 |         \mathbb E[x_i] = 0~, \quad \mathbb E[x_i^2] = 1 \quad
27 |         \text{and} \quad \mathbb E[x_ix_j] = \rho^{|i-j|}
28 | 
29 | 
30 |     Parameters
31 |     ----------
32 |     n_samples: int
33 |         Number of samples in the design matrix.
34 |     n_features: int
35 |         Number of features in the design matrix.
36 |     corr: float
37 |         Correlation :math:`\rho` between successive features. The element
38 |         :math:`C_{i, j}` in the correlation matrix will be
39 |         :math:`\rho^{|i-j|}`. This parameter should be selected in
40 |         :math:`[0, 1[`.
41 |     snr: float or np.inf
42 |         Signal-to-noise ratio. If np.inf, no noise is added.
43 |     density: float
44 |         Proportion of non-zero elements in w_true if it must be simulated.
45 |     w_true: np.array, shape (n_features,) | None
46 |         True regression coefficients. If None, an array with `nnz` non-zero
47 |         standard Gaussian entries is simulated.
48 |     random_state: int | RandomState instance | None (default)
49 |         Determines random number generation for data generation. Use an int to
50 |         make the randomness deterministic.
51 | 
52 |     Returns
53 |     -------
54 |     X: ndarray, shape (n_samples, n_features)
55 |         A design matrix with Toeplitz covariance.
56 |     y: ndarray, shape (n_samples,)
57 |         Observation vector.
58 |     w_true: ndarray, shape (n_features,)
59 |         True regression vector of the model.
60 |     """
61 |     if not 0 <= corr < 1:
62 |         raise ValueError("The correlation `corr` should be chosen in [0, 1[.")
63 |     if not 0 < density <= 1:
64 |         raise ValueError("The density should be chosen in ]0, 1].")
65 |     rng = check_random_state(random_state)
66 |     nnz = int(density * n_features)
67 | 
68 |     if corr == 0:
69 |         X = np.asfortranarray(rng.randn(n_samples, n_features))
70 |     else:
71 |         # X is generated cleverly, using an AR model with coefficient corr and
72 |         # innovation sigma^2 = 1 - corr ** 2: X[:, j+1] = corr X[:, j] + eps_j
73 |         # where eps_j = sigma * rng.randn(n_samples)
74 |         sigma = np.sqrt(1 - corr ** 2)
75 |         U = rng.randn(n_samples)
76 | 
77 |         X = np.empty([n_samples, n_features], order='F')
78 |         X[:, 0] = U
79 |         for j in range(1, n_features):
80 |             U *= corr
81 |             U += sigma * rng.randn(n_samples)
82 |             X[:, j] = U
83 | 
84 |     if w_true is None:
85 |         w_true = np.zeros(n_features)
86 |         support = rng.choice(n_features, nnz, replace=False)
87 |         w_true[support] = rng.randn(nnz)
88 | 
89 |     y = X @ w_true
90 |     if snr != np.inf:
91 |         noise = rng.randn(n_samples)
92 |         y += noise / norm(noise) * norm(y) / snr
93 |     return X, y, w_true
--------------------------------------------------------------------------------
/celer/datasets/tests/test_datasets.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from numpy.linalg import norm
 3 | 
 4 | from celer.datasets import make_correlated_data
 5 | 
 6 | 
 7 | def test_correlated():
 8 |     X, y, w_true = make_correlated_data(snr=np.inf)
 9 |     np.testing.assert_allclose(y, X @ w_true)
10 | 
11 |     snr = 5
12 |     w_true = np.ones(50)
13 |     X, y, _ = make_correlated_data(n_features=w_true.shape[0], w_true=w_true, snr=snr)
14 |     np.testing.assert_allclose(snr, norm(X @ w_true) / norm(y - X @ w_true))
15 | 
16 |     np.testing.assert_raises(ValueError, make_correlated_data, corr=1.01)
17 |     np.testing.assert_raises(
18 |         ValueError, make_correlated_data, density=1.01)
--------------------------------------------------------------------------------
/celer/group_fast.pyx:
--------------------------------------------------------------------------------
 1 | #cython: language_level=3
 2 | # Author: Mathurin Massias
 3 | # License: BSD 3 clause
 4 | 
 5 | cimport cython
 6 | import numpy as np
 7 | cimport numpy as np
 8 | import warnings
 9 | 
10 | from cython cimport floating
11 | from libc.math cimport fabs, sqrt, INFINITY
12 | from sklearn.exceptions import ConvergenceWarning
13 | 
14 | from .cython_utils cimport (fdot, fasum, faxpy, fnrm2, fcopy, fscal, dual,
15 |                             LASSO, LOGREG, create_accel_pt)
16 | 
17 | 
18 | cdef:
19 |     int inc = 1
20 | 
21 | 
22 | @cython.boundscheck(False)
23 | @cython.wraparound(False)
24 | @cython.cdivision(True)
25 | cpdef floating primal_grplasso(
26 |         floating alpha, floating[:] R, int[::1] grp_ptr,
27 |         int[::1] grp_indices, floating[:] w, floating[:] weights):
28 |     cdef floating nrm = 0.
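# Editorial note on the function body that follows: the loop below
# evaluates the group Lasso primal objective
#     P(w) = ||R||_2^2 / (2 * n_samples)
#            + alpha * sum_g weights[g] * ||w_[g]||_2,
# where R = y - X w is the residual maintained by the caller, and groups
# with weights[g] == INFINITY are skipped (treated as excluded).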
29 |     cdef int j, k, g
30 |     cdef int n_samples = R.shape[0]
31 |     cdef int n_groups = grp_ptr.shape[0] - 1
32 |     cdef floating p_obj = fnrm2(&n_samples, &R[0], &inc) ** 2 / (2 * n_samples)
33 | 
34 |     for g in range(n_groups):
35 |         if weights[g] != INFINITY:
36 |             nrm = 0.
37 |             for k in range(grp_ptr[g], grp_ptr[g + 1]):
38 |                 j = grp_indices[k]
39 |                 nrm += w[j] ** 2
40 |             p_obj += alpha * sqrt(nrm) * weights[g]
41 | 
42 |     return p_obj
43 | 
44 | 
45 | @cython.boundscheck(False)
46 | @cython.wraparound(False)
47 | @cython.cdivision(True)
48 | cpdef floating dnorm_grp(
49 |         bint is_sparse, floating[::1] theta, int[::1] grp_ptr,
50 |         int[::1] grp_indices, floating[::1, :] X, floating[::1] X_data,
51 |         int[::1] X_indices, int[::1] X_indptr, floating[::1] X_mean,
52 |         floating[:] weights, int ws_size, int[:] C, bint center):
53 |     """Dual norm in the group case, i.e. L2/inf norm over groups."""
54 |     cdef floating Xj_theta, tmp
55 |     cdef floating dnorm_XTtheta = 0.
56 |     cdef floating theta_sum = 0.
57 |     cdef int i, j, g, g_idx, k, startptr, endptr
58 |     cdef int n_groups = grp_ptr.shape[0] - 1
59 |     cdef int n_samples = theta.shape[0]
60 | 
61 |     if is_sparse:
62 |         if center:
63 |             for i in range(n_samples):
64 |                 theta_sum += theta[i]
65 | 
66 |     if ws_size == n_groups:  # max over all groups
67 |         for g in range(n_groups):
68 |             if weights[g] == INFINITY:
69 |                 continue
70 | 
71 |             tmp = 0
72 |             for k in range(grp_ptr[g], grp_ptr[g + 1]):
73 |                 j = grp_indices[k]
74 |                 if is_sparse:
75 |                     startptr = X_indptr[j]
76 |                     endptr = X_indptr[j + 1]
77 |                     Xj_theta = 0.
78 |                     for i in range(startptr, endptr):
79 |                         Xj_theta += X_data[i] * theta[X_indices[i]]
80 |                     if center:
81 |                         Xj_theta -= theta_sum * X_mean[j]
82 |                 else:
83 |                     Xj_theta = fdot(&n_samples, &theta[0], &inc, &X[0, j],
84 |                                     &inc)
85 |                 tmp += Xj_theta ** 2
86 | 
87 |             dnorm_XTtheta = max(dnorm_XTtheta, sqrt(tmp) / weights[g])
88 | 
89 |     else:  # max only over groups in C
90 |         for g_idx in range(ws_size):
91 |             g = C[g_idx]
92 | 
93 |             if weights[g] == INFINITY:
94 |                 continue
95 | 
96 |             tmp = 0
97 |             for k in range(grp_ptr[g], grp_ptr[g + 1]):
98 |                 j = grp_indices[k]
99 |                 if is_sparse:
100 |                     startptr = X_indptr[j]
101 |                     endptr = X_indptr[j + 1]
102 |                     Xj_theta = 0.
103 |                     for i in range(startptr, endptr):
104 |                         Xj_theta += X_data[i] * theta[X_indices[i]]
105 |                     if center:
106 |                         Xj_theta -= theta_sum * X_mean[j]
107 |                 else:
108 |                     Xj_theta = fdot(&n_samples, &theta[0], &inc, &X[0, j],
109 |                                     &inc)
110 |                 tmp += Xj_theta ** 2
111 | 
112 |             dnorm_XTtheta = max(dnorm_XTtheta, sqrt(tmp) / weights[g])
113 |     return dnorm_XTtheta
114 | 
115 | 
116 | @cython.boundscheck(False)
117 | @cython.wraparound(False)
118 | @cython.cdivision(True)
119 | cdef void set_prios_grp(
120 |         bint is_sparse, int pb, floating[::1] theta, floating alpha, floating[::1, :] X,
121 |         floating[::1] X_data, int[::1] X_indices, int[::1] X_indptr,
122 |         floating[:] weights, floating[::1] norms_X_grp, int[::1] grp_ptr,
123 |         int[::1] grp_indices, floating[::1] prios, int[::1] screened,
124 |         floating radius, int * n_screened):
125 |     cdef int i, j, k, g, startptr, endptr
126 |     cdef floating nrm_Xgtheta, Xj_theta
127 |     cdef int n_groups = grp_ptr.shape[0] - 1
128 |     cdef int n_samples = theta.shape[0]
129 | 
130 |     for g in range(n_groups):
131 |         if screened[g] or norms_X_grp[g] == 0.:
132 |             prios[g] = INFINITY
133 |             continue
134 |         nrm_Xgtheta = 0
135 |         for k in range(grp_ptr[g], grp_ptr[g + 1]):
136 |             j = grp_indices[k]
137 |             if is_sparse:
138 |                 startptr = X_indptr[j]
139 |                 endptr = X_indptr[j + 1]
140 |                 Xj_theta = 0.
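# The accumulation below forms the j-th entry of X^T theta; the group
# priority computed further down is
#     prios[g] = (alpha - ||X_[g]^T theta||_2 / weights[g]) / norms_X_grp[g],
# an estimate of the distance of theta to the g-th dual constraint,
# which is used to rank groups when the working set is grown.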
141 | for i in range(startptr, endptr): 142 | Xj_theta += X_data[i] * theta[X_indices[i]] 143 | else: 144 | Xj_theta = fdot(&n_samples, &theta[0], &inc, &X[0, j], &inc) 145 | nrm_Xgtheta += Xj_theta ** 2 146 | nrm_Xgtheta = sqrt(nrm_Xgtheta) / weights[g] 147 | 148 | prios[g] = (alpha - nrm_Xgtheta) / norms_X_grp[g] 149 | 150 | if prios[g] > radius: 151 | pass 152 | # TODO check 153 | # screened[g] = True 154 | # n_screened[0] += 1 155 | 156 | @cython.boundscheck(False) 157 | @cython.wraparound(False) 158 | @cython.cdivision(True) 159 | cpdef celer_grp( 160 | bint is_sparse, int pb, floating[::1, :] X, int[::1] grp_indices, 161 | int[::1] grp_ptr, floating[::1] X_data, int[::1] X_indices, 162 | int[::1] X_indptr, floating[::1] X_mean, floating[:] y, floating alpha, 163 | floating[:] w, floating[:] R, floating[::1] theta, 164 | floating[::1] norms_X_grp, floating tol, floating[:] weights, int max_iter, 165 | int max_epochs, int gap_freq=10, floating tol_ratio_inner=0.3, int p0=100, 166 | bint prune=1, bint use_accel=1, 167 | bint verbose=0): 168 | 169 | pb = LASSO 170 | cdef int verbose_in = max(0, verbose - 1) 171 | cdef floating l1_ratio = 1.0 172 | cdef floating norm_w2 = 0. 173 | 174 | if floating is double: 175 | dtype = np.float64 176 | else: 177 | dtype = np.float32 178 | 179 | cdef int n_samples = y.shape[0] 180 | cdef int n_features = w.shape[0] 181 | cdef int n_groups = norms_X_grp.shape[0] 182 | 183 | cdef floating norm_y2 = fnrm2(&n_samples, &y[0], &inc) ** 2 184 | # scale tolerance to account for small or large y: 185 | tol *= norm_y2 / n_samples 186 | 187 | cdef floating[::1] lc_groups = np.square(norms_X_grp) 188 | cdef int[:] all_groups = np.arange(n_groups, dtype=np.int32) 189 | cdef int[:] dummy_C = np.zeros(1, dtype=np.int32) 190 | cdef int[:] C 191 | 192 | cdef int n_screened = 0 193 | cdef int i, j, g, g_idx, k, startptr, endptr, epoch, t 194 | cdef int nnz, ws_size 195 | cdef floating[::1] prios = np.empty(n_groups, dtype=dtype) 196 | cdef int[::1] screened = np.zeros(n_groups, dtype=np.int32) 197 | cdef int max_group_size = 0 198 | 199 | cdef bint center = False 200 | if is_sparse: 201 | # center = X_mean.any(): 202 | for j in range(n_features): 203 | if X_mean[j]: 204 | center = True 205 | break 206 | 207 | for g in range(n_groups): 208 | max_group_size = max(max_group_size, grp_ptr[g + 1] - grp_ptr[g]) 209 | 210 | cdef floating[:] old_w_g = np.zeros(max_group_size, dtype=dtype) 211 | 212 | cdef floating[::1] gaps = np.zeros(max_iter, dtype=dtype) 213 | cdef floating[::1] theta_inner = np.zeros(n_samples, dtype=dtype) 214 | cdef floating[::1] thetacc = np.empty(n_samples, dtype=dtype) 215 | 216 | cdef floating gap, p_obj, d_obj, dnorm_XTtheta, X_mean_j 217 | cdef floating gap_in, p_obj_in, d_obj_in, tol_in, d_obj_accel 218 | cdef floating d_obj_from_inner 219 | cdef floating highest_d_obj = 0. 220 | cdef floating highest_d_obj_in = 0. 221 | cdef floating tmp, theta_scaling, R_sum, norm_wg, bst_scal 222 | cdef floating radius = INFINITY 223 | 224 | # acceleration variables: 225 | cdef int K = 6 226 | cdef floating[:, :] last_K_R = np.empty([K, n_samples], dtype=dtype) 227 | cdef floating[:, :] U = np.empty([K - 1, n_samples], dtype=dtype) 228 | cdef floating[:, :] UtU = np.empty([K - 1, K - 1], dtype=dtype) 229 | cdef floating[:] onesK = np.ones(K - 1, dtype=dtype) 230 | 231 | cdef int info_dposv 232 | 233 | for t in range(max_iter): 234 | # if t != 0: TODO potential speedup at iteration 0 235 | fcopy(&n_samples, &R[0], &inc, &theta[0], &inc) 236 | 237 | tmp = 1. 
/ n_samples 238 | fscal(&n_samples, &tmp, &theta[0], &inc) 239 | 240 | dnorm_XTtheta = dnorm_grp( 241 | is_sparse, theta, grp_ptr, grp_indices, X, X_data, X_indices, 242 | X_indptr, X_mean, weights, n_groups, dummy_C, center) 243 | 244 | if dnorm_XTtheta > alpha: 245 | theta_scaling = alpha / dnorm_XTtheta 246 | fscal(&n_samples, &theta_scaling, &theta[0], &inc) 247 | 248 | d_obj = dual(pb, n_samples, alpha, l1_ratio, norm_y2, norm_w2, &theta[0], &y[0]) 249 | 250 | if t > 0: 251 | # also test dual point returned by inner solver after 1st iter: 252 | dnorm_XTtheta = dnorm_grp( 253 | is_sparse, theta_inner, grp_ptr, grp_indices, X, X_data, 254 | X_indices, X_indptr, X_mean, weights, n_groups, dummy_C, center) 255 | 256 | if dnorm_XTtheta > alpha: 257 | theta_scaling = alpha / dnorm_XTtheta 258 | fscal(&n_samples, &theta_scaling, &theta_inner[0], &inc) 259 | 260 | d_obj_from_inner = dual( 261 | pb, n_samples, alpha, l1_ratio, norm_y2, norm_w2, &theta_inner[0], &y[0]) 262 | 263 | if d_obj_from_inner > d_obj: 264 | d_obj = d_obj_from_inner 265 | fcopy(&n_samples, &theta_inner[0], &inc, &theta[0], &inc) 266 | 267 | if t == 0 or d_obj > highest_d_obj: 268 | highest_d_obj = d_obj 269 | # TODO implement a best_theta 270 | 271 | p_obj = primal_grplasso(alpha, R, grp_ptr, grp_indices, w, weights) 272 | gap = p_obj - highest_d_obj 273 | gaps[t] = gap 274 | 275 | if verbose: 276 | print("Iter %d: primal %.10f, gap %.2e" % (t, p_obj, gap), end="") 277 | 278 | if gap <= tol: 279 | if verbose: 280 | print("\nEarly exit, gap: %.2e < %.2e" % (gap, tol)) 281 | break 282 | 283 | # if pb == LASSO: 284 | radius = sqrt(2 * gap / n_samples) 285 | # elif pb == LOGREG: 286 | # radius = sqrt(gap / 2.) 287 | 288 | set_prios_grp( 289 | is_sparse, pb, theta, alpha, X, X_data, X_indices, X_indptr, 290 | weights, lc_groups, grp_ptr, grp_indices, prios, screened, 291 | radius, &n_screened) 292 | 293 | if prune: 294 | nnz = 0 295 | for g in range(n_groups): 296 | # TODO this is a hack, will fail for sparse group lasso 297 | if w[grp_indices[grp_ptr[g]]] != 0: 298 | prios[g] = -1. 299 | nnz += 1 300 | 301 | if t == 0: 302 | ws_size = p0 if nnz == 0 else nnz 303 | else: 304 | ws_size = 2 * nnz 305 | 306 | else: 307 | for g in range(n_groups): 308 | if w[grp_indices[grp_ptr[g]]] != 0: 309 | prios[g] = - 1 # include active features 310 | if t == 0: 311 | ws_size = p0 312 | else: 313 | for g in range(ws_size): 314 | if not screened[C[g]]: 315 | prios[C[g]] = -1 316 | ws_size = 2 * ws_size 317 | 318 | if ws_size > n_groups - n_screened: 319 | ws_size = n_groups - n_screened 320 | 321 | # if ws_size == n_groups then argpartition will break: 322 | if ws_size == n_groups: 323 | C = all_groups 324 | else: 325 | C = np.argpartition(np.asarray(prios), 326 | ws_size)[:ws_size].astype(np.int32) 327 | if prune: 328 | tol_in = 0.3 * gap 329 | else: 330 | tol_in = tol 331 | 332 | if verbose: 333 | print(", %d groups in subpb (%d left)" % 334 | (len(C), n_groups - n_screened)) 335 | 336 | highest_d_obj_in = 0. 337 | for epoch in range(max_epochs): 338 | if epoch != 0 and epoch % gap_freq == 0: 339 | fcopy(&n_samples, &R[0], &inc, &theta_inner[0], &inc) 340 | 341 | tmp = 1. 
/ n_samples 342 | fscal(&n_samples, &tmp, &theta_inner[0], &inc) 343 | 344 | dnorm_XTtheta = dnorm_grp( 345 | is_sparse, theta_inner, grp_ptr, grp_indices, X, X_data, 346 | X_indices, X_indptr, X_mean, weights, ws_size, C, center) 347 | 348 | if dnorm_XTtheta > alpha: 349 | theta_scaling = alpha / dnorm_XTtheta 350 | fscal(&n_samples, &theta_scaling, &theta_inner[0], &inc) 351 | 352 | # dual value is the same as for the Lasso 353 | d_obj_in = dual( 354 | pb, n_samples, alpha, l1_ratio, norm_y2, norm_w2, &theta_inner[0], &y[0]) 355 | 356 | if use_accel: # also compute accelerated dual_point 357 | info_dposv = create_accel_pt( 358 | LASSO, n_samples, epoch, gap_freq, &R[0], 359 | &thetacc[0], &last_K_R[0, 0], U, UtU, onesK, y) 360 | 361 | # if info_dposv != 0 and verbose: 362 | # print("linear system solving failed") 363 | 364 | if epoch // gap_freq >= K: 365 | dnorm_XTtheta = dnorm_grp( 366 | is_sparse, thetacc, grp_ptr, grp_indices, X, 367 | X_data, X_indices, X_indptr, X_mean, weights, 368 | ws_size, C, center) 369 | 370 | if dnorm_XTtheta > alpha: 371 | theta_scaling = alpha / dnorm_XTtheta 372 | fscal(&n_samples, &theta_scaling, &thetacc[0], &inc) 373 | 374 | d_obj_accel = dual(pb, n_samples, alpha, l1_ratio, norm_y2, 375 | norm_w2, &thetacc[0], &y[0]) 376 | if d_obj_accel > d_obj_in: 377 | d_obj_in = d_obj_accel 378 | fcopy(&n_samples, &thetacc[0], &inc, 379 | &theta_inner[0], &inc) 380 | 381 | 382 | if d_obj_in > highest_d_obj_in: 383 | highest_d_obj_in = d_obj_in 384 | 385 | p_obj_in = primal_grplasso(alpha, R, grp_ptr, grp_indices, w, weights) 386 | gap_in = p_obj_in - highest_d_obj_in 387 | 388 | if verbose_in: 389 | print("Epoch %d, primal %.10f, gap: %.2e" % 390 | (epoch, p_obj_in, gap_in)) 391 | if gap_in < tol_in: 392 | if verbose_in: 393 | print("Exit epoch %d, gap: %.2e < %.2e" % 394 | (epoch, gap_in, tol_in)) 395 | break 396 | 397 | for g_idx in range(ws_size): 398 | g = C[g_idx] 399 | if lc_groups[g] == 0.: 400 | continue 401 | norm_wg = 0. 402 | for k in range(grp_ptr[g + 1] - grp_ptr[g]): 403 | j = grp_indices[k + grp_ptr[g]] 404 | old_w_g[k] = w[j] 405 | 406 | if is_sparse: 407 | X_mean_j = X_mean[j] 408 | startptr, endptr = X_indptr[j], X_indptr[j + 1] 409 | for i in range(startptr, endptr): 410 | w[j] += R[X_indices[i]] * X_data[i] / lc_groups[g] 411 | if center: 412 | R_sum = 0. 413 | for i in range(n_samples): 414 | R_sum += R[i] 415 | w[j] -= R_sum * X_mean_j / lc_groups[g] 416 | else: 417 | w[j] += fdot(&n_samples, &X[0, j], &inc, &R[0], 418 | &inc) / lc_groups[g] 419 | norm_wg += w[j] ** 2 420 | norm_wg = sqrt(norm_wg) 421 | if norm_wg != 0.: 422 | bst_scal = max(0., 423 | 1. - alpha * weights[g] / lc_groups[g] * n_samples / norm_wg) 424 | else: 425 | bst_scal = 0. 426 | 427 | for k in range(grp_ptr[g + 1] - grp_ptr[g]): 428 | j = grp_indices[grp_ptr[g] + k] 429 | # perform BST: 430 | w[j] *= bst_scal 431 | # R -= (w_j - old_w_j) * (X[:, j] - X_mean[j]) 432 | tmp = old_w_g[k] - w[j] 433 | if tmp != 0.: 434 | if is_sparse: 435 | startptr, endptr = X_indptr[j], X_indptr[j + 1] 436 | for i in range(startptr, endptr): 437 | R[X_indices[i]] += tmp * X_data[i] 438 | if center: 439 | X_mean_j = X_mean[j] 440 | for i in range(n_samples): 441 | R[i] -= X_mean_j * tmp 442 | else: 443 | faxpy(&n_samples, &tmp, &X[0, j], &inc, &R[0], 444 | &inc) 445 | 446 | else: 447 | warnings.warn( 448 | 'Objective did not converge: duality ' + 449 | f'gap: {gap}, tolerance: {tol}. Increasing `tol` may make the' + 450 | ' solver faster without affecting the results much. 
\n' + 451 | 'Fitting data with very small alpha causes precision issues.', 452 | ConvergenceWarning) 453 | return np.asarray(w), np.asarray(theta), np.asarray(gaps[:t + 1]) 454 | -------------------------------------------------------------------------------- /celer/homotopy.py: -------------------------------------------------------------------------------- 1 | # Author: Mathurin Massias 2 | # Alexandre Gramfort 3 | # Joseph Salmon 4 | # License: BSD 3 clause 5 | 6 | import numpy as np 7 | 8 | from scipy import sparse 9 | from numpy.linalg import norm 10 | from sklearn.utils import check_array 11 | from sklearn.linear_model._base import _preprocess_data 12 | 13 | from .lasso_fast import celer 14 | from .group_fast import celer_grp, dnorm_grp 15 | from .cython_utils import compute_norms_X_col, compute_Xw 16 | from .cython_utils import dnorm_enet as dnorm_enet_cython 17 | from .multitask_fast import celer_mtl 18 | from .PN_logreg import newton_celer 19 | 20 | LASSO = 0 21 | LOGREG = 1 22 | GRPLASSO = 2 23 | 24 | 25 | def celer_path(X, y, pb, eps=1e-3, n_alphas=100, alphas=None, l1_ratio=1.0, 26 | coef_init=None, max_iter=20, max_epochs=50000, 27 | p0=10, verbose=0, tol=1e-6, prune=0, weights=None, 28 | groups=None, return_thetas=False, use_PN=False, X_offset=None, 29 | X_scale=None, return_n_iter=False, positive=False): 30 | r"""Compute optimization path with Celer as inner solver. 31 | 32 | With ``n = len(y)`` and ``p = len(w)`` the number of samples and features, 33 | the losses are: 34 | 35 | * Lasso: 36 | 37 | .. math:: 38 | 39 | \frac{\| y - X w \||_2^2}{2 n} + \alpha \sum_{j=1}^p weights_j |w_j| 40 | 41 | * ElasticNet: 42 | 43 | .. math:: 44 | 45 | \frac{\| y - X w \|_2^2}{2 n} + 46 | \alpha \sum_{j=1}^p weights_j (l1\_ratio |w_j| + (1-l1\_ratio) w_j^2) 47 | 48 | * Logreg: 49 | 50 | .. math:: 51 | 52 | \sum_{i=1}^n \text{log} \,(1 + e^{-y_i x_i^\top w}) + \alpha 53 | \sum_{j=1}^p weights_j |w_j| 54 | 55 | * GroupLasso, with `G` the number of groups and :math:`w_{[g]}` the subvector 56 | corresponding the group `g`: 57 | 58 | .. math:: 59 | 60 | \frac{\| y - X w \|_2^2}{2 n} + \alpha \sum_{g=1}^G weights_g \| w_{[g]} \|_2 61 | 62 | 63 | Parameters 64 | ---------- 65 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 66 | Training data. Pass directly as Fortran-contiguous data or column 67 | sparse format (CSC) to avoid unnecessary memory duplication. 68 | 69 | y : ndarray, shape (n_samples,) 70 | Target values. 71 | 72 | pb : "lasso" | "logreg" | "grouplasso" 73 | Optimization problem to solve. 74 | 75 | eps : float, optional 76 | Length of the path. ``eps=1e-3`` means that 77 | ``alpha_min = 1e-3 * alpha_max``. 78 | 79 | n_alphas : int, optional 80 | Number of alphas along the regularization path. 81 | 82 | alphas : ndarray, optional 83 | List of alphas where to compute the models. 84 | If ``None`` alphas are set automatically. 85 | 86 | l1_ratio : float, optional 87 | The ElasticNet mixing parameter, with ``0 < l1_ratio <= 1``. 88 | Defaults to 1.0 which corresponds to L1 penalty (Lasso). 89 | ``l1_ratio = 0`` (Ridge regression) is not supported. 90 | 91 | coef_init : ndarray, shape (n_features,) | None, optional, (default=None) 92 | Initial value of coefficients. If ``None``, ``np.zeros(n_features)`` is used. 93 | 94 | max_iter : int, optional 95 | The maximum number of iterations (definition of working set and 96 | resolution of problem restricted to features in working set). 
97 | 98 | max_epochs : int, optional 99 | Maximum number of (block) CD epochs on each subproblem. 100 | 101 | p0 : int, optional 102 | First working set size. 103 | 104 | verbose : bool or integer, optional 105 | Amount of verbosity. ``0`` or ``False`` is silent. 106 | 107 | tol : float, optional 108 | The tolerance for the optimization: the solver runs until the duality 109 | gap is smaller than ``tol`` or the maximum number of iteration is 110 | reached. 111 | 112 | prune : 0 | 1, optional 113 | Whether or not to use pruning when growing working sets. 114 | 115 | weights : ndarray, shape (n_features,) or (n_groups,), optional 116 | Feature/group weights used in the penalty. Default to array of ones. 117 | Features with weights equal to ``np.inf`` are ignored. 118 | 119 | groups : int or list of ints or list of list of ints, optional 120 | Used for the group Lasso only. See the documentation of the 121 | :ref:`celer.GroupLasso` class. 122 | 123 | return_thetas : bool, optional 124 | If ``True``, dual variables along the path are returned. 125 | 126 | use_PN : bool, optional 127 | If ``pb == "logreg"``, use ProxNewton solver instead of coordinate 128 | descent. 129 | 130 | X_offset : np.array, shape (n_features,), optional 131 | Used to center sparse X without breaking sparsity. Mean of each column. 132 | See `sklearn.linear_model.base._preprocess_data() 133 | `_. 135 | 136 | X_scale : np.array, shape (n_features,), optional 137 | Used to scale centered sparse X without breaking sparsity. Norm of each 138 | centered column. 139 | See `sklearn.linear_model.base._preprocess_data() 140 | `_. 142 | 143 | return_n_iter : bool, optional 144 | If ``True``, number of iterations along the path are returned. 145 | 146 | positive : bool, optional (default=False) 147 | If ``True`` and ``pb == "lasso"``, forces the coefficients to be positive. 148 | 149 | Returns 150 | ------- 151 | alphas : array, shape (n_alphas,) 152 | The alphas along the path where models are computed. 153 | 154 | coefs : array, shape (n_features, n_alphas) 155 | Coefficients along the path. 156 | 157 | dual_gaps : array, shape (n_alphas,) 158 | Duality gaps returned by the solver along the path. 159 | 160 | thetas : array, shape (n_alphas, n_samples) 161 | The dual variables along the path. 162 | (``thetas`` are returned if ``return_thetas`` is set to ``True``). 163 | """ 164 | 165 | if pb.lower() not in ("lasso", "logreg", "grouplasso"): 166 | raise ValueError("Unsupported problem %s" % pb) 167 | 168 | if pb.lower() != "lasso" and l1_ratio != 1.0: 169 | raise NotImplementedError( 170 | "Mix of l1 and l2 penalty not supported for %s" % pb 171 | ) 172 | 173 | n_groups = None # set n_groups to None for lasso and logreg 174 | if pb.lower() == "lasso": 175 | pb = LASSO 176 | elif pb.lower() == "logreg": 177 | pb = LOGREG 178 | if set(y) - set([-1.0, 1.0]): 179 | raise ValueError( 180 | "y must contain only -1. or 1 values. 
Got %s " % (set(y))) 181 | elif pb.lower() == "grouplasso": 182 | pb = GRPLASSO 183 | if groups is None: 184 | raise ValueError( 185 | "Groups must be specified for the group lasso problem.") 186 | grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) 187 | n_groups = len(grp_ptr) - 1 188 | else: 189 | raise ValueError("Unsupported problem: %s" % pb) 190 | 191 | is_sparse = sparse.issparse(X) 192 | 193 | X = check_array(X, 'csc', dtype=[np.float64, np.float32], 194 | order='F', copy=False, accept_large_sparse=False) 195 | y = check_array(y, 'csc', dtype=X.dtype.type, order='F', copy=False, 196 | ensure_2d=False) 197 | 198 | n_samples, n_features = X.shape 199 | 200 | if X_offset is not None: 201 | # As sparse matrices are not actually centered we need this 202 | # to be passed to the CD solver. 203 | X_sparse_scaling = X_offset / X_scale 204 | X_sparse_scaling = np.asarray(X_sparse_scaling, dtype=X.dtype) 205 | else: 206 | X_sparse_scaling = np.zeros(n_features, dtype=X.dtype) 207 | 208 | X_dense, X_data, X_indices, X_indptr = _sparse_and_dense(X) 209 | 210 | weights = _check_weights(weights, pb, X, n_groups) 211 | # to prevent ref before assignment in dnorm_enet 212 | w = np.zeros(n_features, dtype=X.dtype) 213 | 214 | if alphas is None: 215 | if pb == LASSO: 216 | alpha_max = dnorm_enet(X, y, w, weights, X_sparse_scaling, 217 | positive) / n_samples 218 | elif pb == LOGREG: 219 | alpha_max = dnorm_enet(X, y, w, weights, X_sparse_scaling, 220 | positive) / 2 221 | elif pb == GRPLASSO: 222 | # TODO compute it with dscal to handle centering sparse 223 | alpha_max = 0 224 | for g in range(n_groups): 225 | X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] 226 | alpha_max = max(alpha_max, norm(X_g.T @ y / weights[g], ord=2)) 227 | alpha_max /= n_samples 228 | 229 | alphas = alpha_max / l1_ratio * np.geomspace(1, eps, n_alphas, dtype=X.dtype) 230 | else: 231 | alphas = np.sort(alphas)[::-1] 232 | 233 | n_alphas = len(alphas) 234 | 235 | coefs = np.zeros((n_features, n_alphas), order='F', dtype=X.dtype) 236 | thetas = np.zeros((n_alphas, n_samples), dtype=X.dtype) 237 | dual_gaps = np.zeros(n_alphas) 238 | 239 | if return_n_iter: 240 | n_iters = np.zeros(n_alphas, dtype=int) 241 | 242 | if pb == GRPLASSO: 243 | # TODO this must be included in compute_norm_Xcols when centering 244 | norms_X_grp = np.zeros(n_groups, dtype=X_dense.dtype) 245 | for g in range(n_groups): 246 | X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] 247 | if is_sparse: 248 | gram = (X_g.T @ X_g).todense() 249 | # handle centering: 250 | for j1 in range(grp_ptr[g], grp_ptr[g + 1]): 251 | for j2 in range(grp_ptr[g], grp_ptr[g + 1]): 252 | gram[j1 - grp_ptr[g], j2 - grp_ptr[g]] += \ 253 | X_sparse_scaling[j1] * \ 254 | X_sparse_scaling[j2] * n_samples - \ 255 | X_sparse_scaling[j1] * \ 256 | X_data[X_indptr[j2]:X_indptr[j2+1]].sum() - \ 257 | X_sparse_scaling[j2] * \ 258 | X_data[X_indptr[j1]:X_indptr[j1+1]].sum() 259 | 260 | norms_X_grp[g] = np.sqrt(norm(gram, ord=2)) 261 | else: 262 | norms_X_grp[g] = norm(X_g, ord=2) 263 | else: 264 | # TODO harmonize names 265 | norms_X_col = np.zeros(n_features, dtype=X_dense.dtype) 266 | compute_norms_X_col( 267 | is_sparse, norms_X_col, n_samples, X_dense, X_data, 268 | X_indices, X_indptr, X_sparse_scaling) 269 | 270 | # do not skip alphas[0], it is not always alpha_max 271 | for t in range(n_alphas): 272 | alpha = alphas[t] 273 | 274 | if verbose: 275 | to_print = "##### Computing alpha %d/%d" % (t + 1, n_alphas) 276 | print("#" * len(to_print)) 277 | print(to_print) 278 | 
print("#" * len(to_print)) 279 | if t > 0: 280 | w = coefs[:, t - 1].copy() 281 | # theta was feasible for alphas[t-1], make it feasible for alphas[t] 282 | theta = thetas[t - 1] * (alphas[t] / alphas[t-1]) 283 | p0 = max(len(np.where(w != 0)[0]), 1) 284 | else: 285 | if coef_init is not None: 286 | w = coef_init.copy() 287 | p0 = max((w != 0.).sum(), p0) 288 | # y - Xw for Lasso, Xw for Logreg: 289 | Xw = np.zeros(n_samples, dtype=X.dtype) 290 | compute_Xw( 291 | is_sparse, pb, Xw, w, y, X_sparse_scaling.any(), X_dense, 292 | X_data, X_indices, X_indptr, X_sparse_scaling) 293 | else: 294 | Xw = np.zeros(n_samples, X.dtype) if pb == LOGREG else y.copy() 295 | 296 | # different link equations and normalization scal for dual point: 297 | if pb in (LASSO, LOGREG): 298 | if pb == LASSO: 299 | theta = Xw.copy() 300 | elif pb == LOGREG: 301 | theta = y / (1 + np .exp(y * Xw)) / alpha 302 | dnorm = dnorm_enet(X, theta, w, weights, X_sparse_scaling, 303 | positive, alpha, l1_ratio) 304 | elif pb == GRPLASSO: 305 | theta = Xw.copy() 306 | dnorm = dnorm_grp( 307 | is_sparse, theta, grp_ptr, grp_indices, X_dense, 308 | X_data, X_indices, X_indptr, X_sparse_scaling, 309 | weights, len(grp_ptr) - 1, np.zeros(1, dtype=np.int32), 310 | X_sparse_scaling.any()) 311 | 312 | theta /= max(dnorm / (alpha * l1_ratio), n_samples) 313 | 314 | # celer modifies w, Xw, and theta in place: 315 | if pb == GRPLASSO: 316 | # TODO this if else scheme is complicated 317 | sol = celer_grp( 318 | is_sparse, LASSO, X_dense, grp_indices, grp_ptr, X_data, 319 | X_indices, X_indptr, X_sparse_scaling, y, alpha, w, Xw, theta, 320 | norms_X_grp, tol, weights, max_iter, max_epochs, p0=p0, 321 | prune=prune, verbose=verbose) 322 | # TODO handle case of enet 323 | elif pb == LASSO or (pb == LOGREG and not use_PN): 324 | sol = celer( 325 | is_sparse, pb, 326 | X_dense, X_data, X_indices, X_indptr, X_sparse_scaling, y, 327 | alpha, l1_ratio, w, Xw, theta, norms_X_col, weights, 328 | max_iter=max_iter, max_epochs=max_epochs, 329 | p0=p0, verbose=verbose, use_accel=1, tol=tol, prune=prune, 330 | positive=positive) 331 | else: # pb == LOGREG and use_PN 332 | sol = newton_celer( 333 | is_sparse, X_dense, X_data, X_indices, X_indptr, y, alpha, w, 334 | max_iter, tol=tol, p0=p0, verbose=verbose, prune=prune) 335 | 336 | coefs[:, t], thetas[t], dual_gaps[t] = sol[0], sol[1], sol[2][-1] 337 | if return_n_iter: 338 | n_iters[t] = len(sol[2]) 339 | 340 | results = alphas, coefs, dual_gaps 341 | if return_thetas: 342 | results += (thetas,) 343 | if return_n_iter: 344 | results += (n_iters,) 345 | 346 | return results 347 | 348 | 349 | def _check_weights(weights, pb, X, n_groups): 350 | """Handle weights cases.""" 351 | if weights is None: 352 | n_weights = n_groups if pb == GRPLASSO else X.shape[1] 353 | weights = np.ones(n_weights, dtype=X.dtype) 354 | elif (weights <= 0).any(): 355 | raise ValueError("0 or negative weights are not supported.") 356 | else: 357 | expected_n_weights = n_groups if pb == GRPLASSO else X.shape[1] 358 | feat_or_grp = "groups" if pb == GRPLASSO else "features" 359 | 360 | if weights.shape[0] != expected_n_weights: 361 | raise ValueError( 362 | f"As many weights as {feat_or_grp} must be passed. " 363 | f"Expected {expected_n_weights}, got {weights.shape[0]}." 
364 | ) 365 | 366 | return weights 367 | 368 | 369 | def _sparse_and_dense(X): 370 | if sparse.issparse(X): 371 | X_dense = np.empty([1, 1], order='F', dtype=X.data.dtype) 372 | X_data = X.data 373 | X_indptr = X.indptr 374 | X_indices = X.indices 375 | else: 376 | X_dense = X 377 | X_data = np.empty([1], dtype=X.dtype) 378 | X_indices = np.empty([1], dtype=np.int32) 379 | X_indptr = np.empty([1], dtype=np.int32) 380 | return X_dense, X_data, X_indices, X_indptr 381 | 382 | 383 | def dnorm_enet(X, theta, w, weights, X_sparse_scaling, 384 | positive, alpha=1.0, l1_ratio=1.0): 385 | """Theta should be centered.""" 386 | X_dense, X_data, X_indices, X_indptr = _sparse_and_dense(X) 387 | skip = np.zeros(X.shape[1], dtype=np.int32) 388 | dnorm = dnorm_enet_cython( 389 | sparse.issparse(X), theta, w, X_dense, X_data, X_indices, X_indptr, 390 | skip, X_sparse_scaling, weights, X_sparse_scaling.any(), positive, 391 | alpha, l1_ratio) 392 | return dnorm 393 | 394 | 395 | def _alpha_max_grp(X, y, groups, center=False, normalize=False): 396 | """This costly function (copies X) should only be used for debug.""" 397 | grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) 398 | X, y, X_offset, _, X_scale = _preprocess_data( 399 | X, y, center, normalize, copy=True) 400 | 401 | X_mean = X_offset / X_scale 402 | X_dense, X_data, X_indices, X_indptr = _sparse_and_dense(X) 403 | alpha_max = dnorm_grp( 404 | sparse.issparse(X), y, grp_ptr, grp_indices, X_dense, X_data, 405 | X_indices, X_indptr, X_mean, len(grp_ptr) - 1, 406 | np.zeros(1, dtype=np.int32), X_mean.any()) / len(y) 407 | 408 | return alpha_max 409 | 410 | 411 | def _grp_converter(groups, n_features): 412 | if isinstance(groups, int): 413 | grp_size = groups 414 | if n_features % grp_size != 0: 415 | raise ValueError("n_features (%d) is not a multiple of the desired" 416 | " group size (%d)" % (n_features, grp_size)) 417 | n_groups = n_features // grp_size 418 | grp_ptr = grp_size * np.arange(n_groups + 1) 419 | grp_indices = np.arange(n_features) 420 | elif isinstance(groups, list) and isinstance(groups[0], int): 421 | grp_indices = np.arange(n_features).astype(np.int32) 422 | grp_ptr = np.cumsum(np.hstack([[0], groups])) 423 | elif isinstance(groups, list) and isinstance(groups[0], list): 424 | grp_sizes = np.array([len(ls) for ls in groups]) 425 | grp_ptr = np.cumsum(np.hstack([[0], grp_sizes])) 426 | grp_indices = np.array([idx for grp in groups for idx in grp]) 427 | else: 428 | raise ValueError("Unsupported group format.") 429 | return grp_ptr.astype(np.int32), grp_indices.astype(np.int32) 430 | 431 | 432 | def mtl_path( 433 | X, Y, eps=1e-2, n_alphas=100, alphas=None, max_iter=100, 434 | max_epochs=50_000, p0=10, verbose=0, tol=1e-6, 435 | prune=True, return_thetas=False, coef_init=None): 436 | X = check_array(X, "csc", dtype=[ 437 | np.float64, np.float32], order="F", copy=False) 438 | Y = check_array(Y, "csc", dtype=[ 439 | np.float64, np.float32], order="F", copy=False) 440 | n_samples, n_features = X.shape 441 | n_tasks = Y.shape[1] 442 | if alphas is None: 443 | alpha_max = np.max(norm(X.T @ Y, ord=2, axis=1)) / n_samples 444 | alphas = alpha_max * \ 445 | np.geomspace(1, eps, n_alphas, dtype=X.dtype) 446 | else: 447 | alphas = np.sort(alphas)[::-1] 448 | 449 | n_alphas = len(alphas) 450 | 451 | coefs = np.zeros((n_features, n_tasks, n_alphas), order="F", 452 | dtype=X.dtype) 453 | 454 | thetas = np.zeros((n_alphas, n_samples, n_tasks), dtype=X.dtype) 455 | gaps = np.zeros(n_alphas) 456 | 457 | norms_X_col = np.linalg.norm(X, 
axis=0) 458 | Y = np.asfortranarray(Y) 459 | R = Y.copy(order='F') 460 | theta = np.zeros_like(Y, order='F') 461 | 462 | # do not skip alphas[0], it is not always alpha_max 463 | for t in range(n_alphas): 464 | alpha = alphas[t] 465 | 466 | if verbose: 467 | msg = "##### Computing alpha %d/%d" % (t + 1, n_alphas) 468 | print("#" * len(msg)) 469 | print(msg) 470 | print("#" * len(msg)) 471 | if t > 0: 472 | W = coefs[:, :, t - 1].copy() 473 | p_t = max(len(np.where(W[:, 0] != 0)[0]), p0) 474 | else: 475 | if coef_init is not None: 476 | W = coef_init.T 477 | R = np.asfortranarray(Y - X @ W) 478 | p_t = max(len(np.where(W[:, 0] != 0)[0]), p0) 479 | else: 480 | W = np.zeros((n_features, n_tasks), dtype=X.dtype) 481 | p_t = 10 482 | 483 | sol = celer_mtl( 484 | X, Y, alpha, W, R, theta, norms_X_col, p0=p_t, tol=tol, 485 | prune=prune, max_iter=max_iter, max_epochs=max_epochs, 486 | verbose=verbose, use_accel=True) 487 | 488 | coefs[:, :, t], thetas[t], gaps[t] = sol[0], sol[1], sol[2] 489 | 490 | coefs = np.swapaxes(coefs, 0, 1).copy('F') 491 | 492 | if return_thetas: 493 | return alphas, coefs, gaps, thetas 494 | 495 | return alphas, coefs, gaps 496 | -------------------------------------------------------------------------------- /celer/lasso_fast.pyx: -------------------------------------------------------------------------------- 1 | #cython: language_level=3 2 | # Author: Mathurin Massias 3 | # License: BSD 3 clause 4 | 5 | import numpy as np 6 | cimport numpy as np 7 | cimport cython 8 | import warnings 9 | 10 | from cython cimport floating 11 | from libc.math cimport fabs, sqrt, exp, INFINITY 12 | from sklearn.exceptions import ConvergenceWarning 13 | 14 | from .cython_utils cimport fdot, fasum, faxpy, fnrm2, fcopy, fscal, fposv 15 | from .cython_utils cimport (primal, dual, create_dual_pt, create_accel_pt, 16 | sigmoid, ST, LASSO, LOGREG, dnorm_enet, 17 | set_prios, fweighted_norm_w2) 18 | 19 | 20 | @cython.boundscheck(False) 21 | @cython.wraparound(False) 22 | @cython.cdivision(True) 23 | def celer( 24 | bint is_sparse, int pb, floating[::1, :] X, floating[:] X_data, 25 | int[:] X_indices, int[:] X_indptr, floating[:] X_mean, 26 | floating[:] y, floating alpha, floating l1_ratio, floating[:] w, floating[:] Xw, 27 | floating[:] theta, floating[:] norms_X_col, floating[:] weights, 28 | int max_iter, int max_epochs, int gap_freq=10, 29 | float tol=1e-6, int p0=100, int verbose=0, 30 | int use_accel=1, int prune=0, bint positive=0, 31 | int better_lc=1): 32 | """R/Xw and w are modified in place and assumed to match. 33 | Weights must be > 0, features with weights equal to np.inf are ignored. 34 | WARNING for Logreg the datafit is a sum, while for Lasso it is a mean. 
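    Note: `tol` is rescaled below by the primal value at w = 0
    (||y||^2 / n_samples for the Lasso, n_samples * log(2) for logistic
    regression), so that the stopping criterion is invariant to the
    scale of y.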
35 | """ 36 | assert pb in (LASSO, LOGREG) 37 | 38 | if floating is double: 39 | dtype = np.float64 40 | else: 41 | dtype = np.float32 42 | 43 | cdef int inc = 1 44 | cdef int verbose_in = max(0, verbose - 1) 45 | cdef int n_features = w.shape[0] 46 | cdef int n_samples = y.shape[0] 47 | 48 | # scale stopping criterion: multiply tol by primal value at w = 0 49 | if pb == LASSO: 50 | # actually for Lasso, omit division by 2 to match sklearn 51 | tol *= fnrm2(&n_samples, &y[0], &inc) ** 2 / n_samples 52 | elif pb == LOGREG: 53 | tol *= n_samples * np.log(2) 54 | 55 | if p0 > n_features: 56 | p0 = n_features 57 | 58 | cdef int t = 0 59 | cdef int i, j, k, idx, startptr, endptr, epoch 60 | cdef int ws_size = 0 61 | cdef int nnz = 0 62 | cdef floating gap = -1 # initialized for the warning if max_iter=0 63 | cdef floating p_obj, d_obj, highest_d_obj, radius, tol_in 64 | cdef floating gap_in, p_obj_in, d_obj_in, d_obj_accel, highest_d_obj_in 65 | cdef floating theta_scaling, R_sum, tmp, tmp_exp, dnorm_XTtheta 66 | cdef int n_screened = 0 67 | cdef bint center = False 68 | cdef floating old_w_j, X_mean_j 69 | cdef floating[:] prios = np.empty(n_features, dtype=dtype) 70 | cdef int[:] screened = np.zeros(n_features, dtype=np.int32) 71 | cdef int[:] notin_ws = np.zeros(n_features, dtype=np.int32) 72 | 73 | 74 | # acceleration variables: 75 | cdef int K = 6 76 | cdef floating[:, :] last_K_Xw = np.empty([K, n_samples], dtype=dtype) 77 | cdef floating[:, :] U = np.empty([K - 1, n_samples], dtype=dtype) 78 | cdef floating[:, :] UtU = np.empty([K - 1, K - 1], dtype=dtype) 79 | cdef floating[:] onesK = np.ones(K - 1, dtype=dtype) 80 | cdef int info_dposv 81 | 82 | if is_sparse: 83 | # center = X_mean.any(): 84 | for j in range(n_features): 85 | if X_mean[j]: 86 | center = True 87 | break 88 | 89 | # TODO this is used only for logreg, L97 is misleading and deserves a comment/refactoring 90 | cdef floating[:] inv_lc = np.zeros(n_features, dtype=dtype) 91 | 92 | for j in range(n_features): 93 | # can have 0 features when performing CV on sparse X 94 | if norms_X_col[j]: 95 | if pb == LOGREG: 96 | inv_lc[j] = 4. / norms_X_col[j] ** 2 97 | else: 98 | inv_lc[j] = 1. / norms_X_col[j] ** 2 99 | 100 | cdef floating norm_y2 = fnrm2(&n_samples, &y[0], &inc) ** 2 101 | cdef floating weighted_norm_w2 = fweighted_norm_w2(w, weights) 102 | theta_scaling = 1.0 103 | 104 | # max_iter + 1 is to deal with max_iter=0 105 | cdef floating[:] gaps = np.zeros(max_iter + 1, dtype=dtype) 106 | gaps[0] = -1 107 | 108 | cdef floating[:] theta_in = np.zeros(n_samples, dtype=dtype) 109 | cdef floating[:] thetacc = np.zeros(n_samples, dtype=dtype) 110 | cdef floating d_obj_from_inner = 0. 111 | 112 | cdef int[:] ws 113 | cdef int[:] all_features = np.arange(n_features, dtype=np.int32) 114 | 115 | for t in range(max_iter): 116 | if t != 0: 117 | create_dual_pt(pb, n_samples, &theta[0], &Xw[0], &y[0]) 118 | 119 | dnorm_XTtheta = dnorm_enet( 120 | is_sparse, theta, w, X, X_data, X_indices, X_indptr, screened, 121 | X_mean, weights, center, positive, alpha, l1_ratio) 122 | 123 | if dnorm_XTtheta > alpha * l1_ratio: 124 | theta_scaling = alpha * l1_ratio / dnorm_XTtheta 125 | fscal(&n_samples, &theta_scaling, &theta[0], &inc) 126 | else: 127 | theta_scaling = 1. 
128 | 129 | # compute ||w||^2 only for Enet 130 | if l1_ratio != 1: 131 | weighted_norm_w2 = fweighted_norm_w2(w, weights) 132 | 133 | d_obj = dual(pb, n_samples, alpha, l1_ratio, norm_y2, 134 | theta_scaling**2*weighted_norm_w2, &theta[0], &y[0]) 135 | 136 | # also test dual point returned by inner solver after 1st iter: 137 | dnorm_XTtheta = dnorm_enet( 138 | is_sparse, theta_in, w, X, X_data, X_indices, X_indptr, 139 | screened, X_mean, weights, center, positive, alpha, l1_ratio) 140 | 141 | if dnorm_XTtheta > alpha * l1_ratio: 142 | theta_scaling = alpha * l1_ratio / dnorm_XTtheta 143 | fscal(&n_samples, &theta_scaling, &theta_in[0], &inc) 144 | else: 145 | theta_scaling = 1. 146 | 147 | d_obj_from_inner = dual(pb, n_samples, alpha, l1_ratio, norm_y2, 148 | theta_scaling**2*weighted_norm_w2, &theta_in[0], &y[0]) 149 | else: 150 | d_obj = dual(pb, n_samples, alpha, l1_ratio, norm_y2, 151 | theta_scaling**2*weighted_norm_w2, &theta[0], &y[0]) 152 | 153 | if d_obj_from_inner > d_obj: 154 | d_obj = d_obj_from_inner 155 | fcopy(&n_samples, &theta_in[0], &inc, &theta[0], &inc) 156 | 157 | highest_d_obj = d_obj # TODO monotonicity could be enforced but it 158 | # would add yet another variable, best_theta. I'm not sure it brings 159 | # anything. 160 | 161 | p_obj = primal(pb, alpha, l1_ratio, Xw, y, w, weights) 162 | gap = p_obj - highest_d_obj 163 | gaps[t] = gap 164 | if verbose: 165 | print("Iter %d: primal %.10f, gap %.2e" % (t, p_obj, gap), end="") 166 | 167 | if gap <= tol: 168 | if verbose: 169 | print("\nEarly exit, gap: %.2e < %.2e" % (gap, tol)) 170 | break 171 | 172 | if pb == LASSO: 173 | radius = sqrt(2 * gap / n_samples) 174 | else: 175 | radius = sqrt(gap / 2.) 176 | set_prios( 177 | is_sparse, theta, w, alpha, l1_ratio, X, X_data, X_indices, X_indptr, norms_X_col, 178 | weights, prios, screened, radius, &n_screened, positive) 179 | 180 | if prune: 181 | nnz = 0 182 | for j in range(n_features): 183 | if w[j] != 0: 184 | prios[j] = -1. 
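# a negative priority guarantees that currently active
# features are kept when the working set is rebuilt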
185 | nnz += 1 186 | 187 | if t == 0: 188 | ws_size = p0 if nnz == 0 else nnz 189 | else: 190 | ws_size = 2 * nnz 191 | 192 | else: 193 | for j in range(n_features): 194 | if w[j] != 0: 195 | prios[j] = - 1 # include active features 196 | if t == 0: 197 | ws_size = p0 198 | else: 199 | for j in range(ws_size): 200 | if not screened[ws[j]]: 201 | # include previous features, if not screened 202 | prios[ws[j]] = -1 203 | ws_size = 2 * ws_size 204 | if ws_size > n_features - n_screened: 205 | ws_size = n_features - n_screened 206 | 207 | 208 | # if ws_size === n_features then argpartition will break: 209 | if ws_size == n_features: 210 | ws = all_features 211 | else: 212 | ws = np.argpartition(np.asarray(prios), ws_size)[:ws_size].astype(np.int32) 213 | 214 | for j in range(n_features): 215 | notin_ws[j] = 1 216 | for idx in range(ws_size): 217 | notin_ws[ws[idx]] = 0 218 | 219 | if prune: 220 | tol_in = 0.3 * gap 221 | else: 222 | tol_in = tol 223 | 224 | if verbose: 225 | print(", %d feats in subpb (%d left)" % 226 | (len(ws), n_features - n_screened)) 227 | 228 | # calling inner solver which will modify w and R inplace 229 | highest_d_obj_in = 0 230 | for epoch in range(max_epochs): 231 | if epoch != 0 and epoch % gap_freq == 0: 232 | create_dual_pt( 233 | pb, n_samples, &theta_in[0], &Xw[0], &y[0]) 234 | 235 | dnorm_XTtheta = dnorm_enet( 236 | is_sparse, theta_in, w, X, X_data, X_indices, X_indptr, 237 | notin_ws, X_mean, weights, center, positive, alpha, l1_ratio) 238 | 239 | if dnorm_XTtheta > alpha * l1_ratio: 240 | theta_scaling = alpha * l1_ratio / dnorm_XTtheta 241 | fscal(&n_samples, &theta_scaling, &theta_in[0], &inc) 242 | else: 243 | theta_scaling = 1. 244 | 245 | # update norm_w2 in inner loop for Enet only 246 | if l1_ratio != 1: 247 | weighted_norm_w2 = fweighted_norm_w2(w, weights) 248 | d_obj_in = dual(pb, n_samples, alpha, l1_ratio, norm_y2, 249 | theta_scaling**2*weighted_norm_w2, &theta_in[0], &y[0]) 250 | 251 | if use_accel: # also compute accelerated dual_point 252 | info_dposv = create_accel_pt( 253 | pb, n_samples, epoch, gap_freq, &Xw[0], 254 | &thetacc[0], &last_K_Xw[0, 0], U, UtU, onesK, y) 255 | 256 | if info_dposv != 0 and verbose_in: 257 | pass 258 | # print("linear system solving failed") 259 | 260 | if epoch // gap_freq >= K: 261 | dnorm_XTtheta = dnorm_enet( 262 | is_sparse, thetacc, w, X, X_data, X_indices, 263 | X_indptr, notin_ws, X_mean, weights, center, 264 | positive, alpha, l1_ratio) 265 | 266 | if dnorm_XTtheta > alpha * l1_ratio: 267 | theta_scaling = alpha * l1_ratio / dnorm_XTtheta 268 | fscal(&n_samples, &theta_scaling, &thetacc[0], &inc) 269 | else: 270 | theta_scaling = 1. 271 | 272 | d_obj_accel = dual(pb, n_samples, alpha, l1_ratio, norm_y2, 273 | theta_scaling**2*weighted_norm_w2, &thetacc[0], &y[0]) 274 | if d_obj_accel > d_obj_in: 275 | d_obj_in = d_obj_accel 276 | fcopy(&n_samples, &thetacc[0], &inc, 277 | &theta_in[0], &inc) 278 | 279 | if d_obj_in > highest_d_obj_in: 280 | highest_d_obj_in = d_obj_in 281 | 282 | # CAUTION: code does not yet include a best_theta. 283 | # Can be an issue in screening: dgap and theta might disagree. 
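# With pruning on, tol_in was set above to 0.3 * gap: each working-set
# subproblem is solved only up to a fraction of the current outer
# duality gap, so inner solves become more accurate as the outer gap
# shrinks.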
284 | 285 | p_obj_in = primal(pb, alpha, l1_ratio, Xw, y, w, weights) 286 | gap_in = p_obj_in - highest_d_obj_in 287 | 288 | if verbose_in: 289 | print("Epoch %d, primal %.10f, gap: %.2e" % 290 | (epoch, p_obj_in, gap_in)) 291 | if gap_in < tol_in: 292 | if verbose_in: 293 | print("Exit epoch %d, gap: %.2e < %.2e" % \ 294 | (epoch, gap_in, tol_in)) 295 | break 296 | 297 | for k in range(ws_size): 298 | j = ws[k] 299 | if norms_X_col[j] == 0. or weights[j] == INFINITY: 300 | continue 301 | old_w_j = w[j] 302 | 303 | if pb == LASSO: 304 | if is_sparse: 305 | X_mean_j = X_mean[j] 306 | startptr, endptr = X_indptr[j], X_indptr[j + 1] 307 | for i in range(startptr, endptr): 308 | w[j] += Xw[X_indices[i]] * X_data[i] / \ 309 | norms_X_col[j] ** 2 310 | if center: 311 | R_sum = 0. 312 | for i in range(n_samples): 313 | R_sum += Xw[i] 314 | w[j] -= R_sum * X_mean_j / norms_X_col[j] ** 2 315 | else: 316 | w[j] += fdot(&n_samples, &X[0, j], &inc, &Xw[0], 317 | &inc) / norms_X_col[j] ** 2 318 | 319 | if positive and w[j] <= 0.: 320 | w[j] = 0. 321 | else: 322 | if l1_ratio != 1.: 323 | w[j] = ST( 324 | w[j], 325 | alpha * l1_ratio / norms_X_col[j] ** 2 * n_samples * weights[j]) / \ 326 | (1 + alpha * (1 - l1_ratio) * weights[j] / norms_X_col[j] ** 2 * n_samples) 327 | else: 328 | w[j] = ST( 329 | w[j], 330 | alpha / norms_X_col[j] ** 2 * n_samples * weights[j]) 331 | 332 | # R -= (w_j - old_w_j) * (X[:, j] - X_mean[j]) 333 | tmp = old_w_j - w[j] 334 | if tmp != 0.: 335 | if is_sparse: 336 | for i in range(startptr, endptr): 337 | Xw[X_indices[i]] += tmp * X_data[i] 338 | if center: 339 | for i in range(n_samples): 340 | Xw[i] -= X_mean_j * tmp 341 | else: 342 | faxpy(&n_samples, &tmp, &X[0, j], &inc, 343 | &Xw[0], &inc) 344 | else: 345 | if is_sparse: 346 | startptr = X_indptr[j] 347 | endptr = X_indptr[j + 1] 348 | if better_lc: 349 | tmp = 0. 350 | for i in range(startptr, endptr): 351 | tmp_exp = exp(Xw[X_indices[i]]) 352 | tmp += X_data[i] ** 2 * tmp_exp / \ 353 | (1. + tmp_exp) ** 2 354 | inv_lc[j] = 1. / tmp 355 | else: 356 | if better_lc: 357 | tmp = 0. 358 | for i in range(n_samples): 359 | tmp_exp = exp(Xw[i]) 360 | tmp += (X[i, j] ** 2) * tmp_exp / \ 361 | (1. + tmp_exp) ** 2 362 | inv_lc[j] = 1. / tmp 363 | 364 | tmp = 0. # tmp = dot(Xj, y * sigmoid(-y * w)) / lc[j] 365 | if is_sparse: 366 | for i in range(startptr, endptr): 367 | idx = X_indices[i] 368 | tmp += X_data[i] * y[idx] * \ 369 | sigmoid(- y[idx] * Xw[idx]) 370 | else: 371 | for i in range(n_samples): 372 | tmp += X[i, j] * y[i] * sigmoid(- y[i] * Xw[i]) 373 | 374 | w[j] = ST(w[j] + tmp * inv_lc[j], 375 | alpha * inv_lc[j] * weights[j]) 376 | 377 | tmp = w[j] - old_w_j 378 | if tmp != 0.: 379 | if is_sparse: 380 | for i in range(startptr, endptr): 381 | Xw[X_indices[i]] += tmp * X_data[i] 382 | else: 383 | faxpy(&n_samples, &tmp, &X[0, j], &inc, 384 | &Xw[0], &inc) 385 | else: 386 | warnings.warn( 387 | 'Inner solver did not converge at ' + 388 | f'epoch: {epoch}, gap: {gap_in:.2e} > {tol_in:.2e}', 389 | ConvergenceWarning) 390 | else: 391 | warnings.warn( 392 | 'Objective did not converge: duality ' + 393 | f'gap: {gap}, tolerance: {tol}. Increasing `tol` may make the' + 394 | ' solver faster without affecting the results much. 
\n' + 395 | 'Fitting data with very small alpha causes precision issues.', 396 | ConvergenceWarning) 397 | 398 | return np.asarray(w), np.asarray(theta), np.asarray(gaps[:t + 1]) 399 | 400 | -------------------------------------------------------------------------------- /celer/multitask_fast.pyx: -------------------------------------------------------------------------------- 1 | #cython: language_level=3 2 | cimport cython 3 | cimport numpy as np 4 | 5 | import numpy as np 6 | import warnings 7 | from cython cimport floating 8 | from libc.math cimport fabs, sqrt, INFINITY 9 | from sklearn.exceptions import ConvergenceWarning 10 | 11 | from .cython_utils cimport fscal, fcopy, fnrm2, fdot, faxpy 12 | from .cython_utils cimport LASSO, create_accel_pt 13 | 14 | @cython.boundscheck(False) 15 | @cython.wraparound(False) 16 | @cython.cdivision(True) 17 | cdef void BST(int n_tasks, floating * x, floating u) nogil: 18 | cdef int inc = 1 19 | cdef int k 20 | cdef floating tmp 21 | cdef floating norm_x = fnrm2(&n_tasks, x, &inc) 22 | if norm_x < u: 23 | for k in range(n_tasks): 24 | x[k] = 0. 25 | else: 26 | tmp = 1. - u / norm_x 27 | fscal(&n_tasks, &tmp, x, &inc) 28 | 29 | 30 | @cython.boundscheck(False) 31 | @cython.wraparound(False) 32 | @cython.cdivision(True) 33 | cdef floating dual_scaling_mtl( 34 | int n_features, int n_samples, int n_tasks, floating[::1, :] theta, 35 | floating[::1, :] X, int ws_size, int * C, int * screened, 36 | floating * Xj_theta) nogil: 37 | cdef int ind, j, k 38 | cdef int inc = 1 39 | cdef floating tmp 40 | cdef floating dnorm_XTtheta = 0. 41 | 42 | if ws_size == n_features: 43 | for j in range(n_features): 44 | if screened[j]: 45 | continue 46 | for k in range(n_tasks): 47 | Xj_theta[k] = fdot(&n_samples, &theta[0, k], &inc, &X[0, j], &inc) 48 | tmp = fnrm2(&n_tasks, &Xj_theta[0], &inc) 49 | if tmp > dnorm_XTtheta: 50 | dnorm_XTtheta = tmp 51 | else: 52 | for ind in range(ws_size): 53 | j = C[ind] 54 | for k in range(n_tasks): 55 | Xj_theta[k] = fdot(&n_samples, &theta[0, k], &inc, &X[0, j], &inc) 56 | tmp = fnrm2(&n_tasks, &Xj_theta[0], &inc) 57 | if tmp > dnorm_XTtheta: 58 | dnorm_XTtheta = tmp 59 | return dnorm_XTtheta 60 | 61 | 62 | @cython.boundscheck(False) 63 | @cython.wraparound(False) 64 | @cython.cdivision(True) 65 | cdef void set_prios_mtl( 66 | floating[:, ::1] W, int[:] screened, 67 | floating[::1, :] X, floating[::1, :] theta, floating alpha, floating[:] norms_X_col, 68 | floating[:] Xj_theta, floating[:] prios, floating radius, 69 | int * n_screened) nogil: 70 | cdef int j, k 71 | cdef int inc = 1 72 | cdef floating nrm = 0. 73 | cdef int n_samples = X.shape[0] 74 | cdef int n_features = X.shape[1] 75 | cdef int n_tasks = W.shape[1] 76 | 77 | for j in range(n_features): 78 | if screened[j]: 79 | prios[j] = INFINITY 80 | continue 81 | for k in range(n_tasks): 82 | Xj_theta[k] = fdot(&n_samples, &theta[0, k], &inc, &X[0, j], &inc) 83 | 84 | nrm = fnrm2(&n_tasks, &Xj_theta[0], &inc) 85 | prios[j] = (alpha - nrm) / norms_X_col[j] 86 | if prios[j] > radius: 87 | # screen only if W[j, :] is zero: 88 | for k in range(n_tasks): 89 | if W[j, k] != 0: 90 | break 91 | else: 92 | screened[j] = True 93 | n_screened[0] += 1 94 | 95 | 96 | @cython.boundscheck(False) 97 | @cython.wraparound(False) 98 | @cython.cdivision(True) 99 | cdef floating dual_mtl( 100 | int n_samples, int n_tasks, floating[::1, :] theta, floating[::1, :] Y, 101 | floating norm_Y2) nogil: 102 | cdef int inc = 1 103 | cdef int i, k 104 | cdef floating d_obj = 0. 
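# The loops below evaluate the multi-task Lasso dual objective
#     D(theta) = ||Y||_F^2 / (2 * n_samples)
#                - (n_samples / 2) * ||Y / n_samples - theta||_F^2.
# Equivalent NumPy sketch (illustrative only, with n = n_samples):
#     d_obj = norm(Y, 'fro')**2 / (2 * n) - n / 2 * norm(Y / n - theta, 'fro')**2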
105 | 106 | for k in range(n_tasks): 107 | for i in range(n_samples): 108 | d_obj -= (Y[i, k] / n_samples - theta[i, k]) ** 2 109 | d_obj *= 0.5 * n_samples 110 | d_obj += norm_Y2 / (2. * n_samples) 111 | return d_obj 112 | 113 | 114 | @cython.boundscheck(False) 115 | @cython.wraparound(False) 116 | @cython.cdivision(True) 117 | cdef floating primal_mtl( 118 | int n_samples, int n_features, int n_tasks, 119 | floating[:, ::1] W, floating alpha, floating[::1, :] R) nogil: 120 | cdef int inc = 1 121 | cdef int j, k 122 | cdef int n_obs = n_samples * n_tasks 123 | cdef floating p_obj = fnrm2(&n_obs, &R[0, 0], &inc) ** 2 / (2. * n_samples) 124 | 125 | for j in range(n_features): 126 | for k in range(n_tasks): 127 | if W[j, k]: 128 | p_obj += alpha * fnrm2(&n_tasks, &W[j, 0], &inc) 129 | break 130 | 131 | return p_obj 132 | 133 | 134 | @cython.boundscheck(False) 135 | @cython.wraparound(False) 136 | @cython.cdivision(True) 137 | def celer_mtl( 138 | floating[::1, :] X, floating[::1, :] Y, floating alpha, 139 | floating[:, ::1] W, floating[::1, :] R, floating[::1, :] theta, 140 | floating[:] norms_X_col, int max_iter, int max_epochs, 141 | int gap_freq=10, floating tol_ratio=0.3, float tol=1e-6, int p0=100, 142 | int verbose=0, bint use_accel=1, bint prune=1, 143 | int K=6): 144 | 145 | cdef int verbose_inner = max(0, verbose - 1) 146 | if floating is double: 147 | dtype = np.float64 148 | else: 149 | dtype = np.float32 150 | 151 | cdef int n_samples = Y.shape[0] 152 | cdef int n_tasks = Y.shape[1] 153 | cdef int n_features = W.shape[0] 154 | 155 | 156 | if p0 > n_features: 157 | p0 = n_features 158 | 159 | cdef int i, j, k, t 160 | cdef int inc = 1 161 | cdef floating tmp, theta_scaling 162 | cdef int n_obs = n_samples * n_tasks 163 | cdef int ws_size 164 | cdef int nnz = 0 165 | cdef floating p_obj, d_obj, highes_d_obj, gap, radius, dnorm_XTtheta 166 | cdef int n_screened = 0 167 | cdef floating[:] prios = np.empty(n_features, dtype=dtype) 168 | cdef int[:] screened = np.zeros(n_features, dtype=np.int32) 169 | cdef floating[:] Xj_theta = np.empty(n_tasks, dtype=dtype) 170 | 171 | cdef floating norm_Y2 = fnrm2(&n_obs, &Y[0, 0], &inc) ** 2 172 | # scale tolerance to account for small or large Y: 173 | tol *= norm_Y2 / n_samples 174 | 175 | cdef floating[::1, :] theta_inner = np.zeros((n_samples, n_tasks), 176 | dtype=dtype, order='F') 177 | cdef floating[::1, :] theta_to_use 178 | 179 | cdef floating d_obj_from_inner = 0 180 | cdef int[:] dummy_C = np.zeros(1, dtype=np.int32) 181 | cdef int[:] all_features = np.arange(n_features, dtype=np.int32) 182 | cdef int[:] C 183 | cdef floating tol_inner 184 | 185 | for t in range(max_iter): 186 | # if t != 0: TODO 187 | p_obj = primal_mtl(n_samples, n_features, n_tasks, W, alpha, R) 188 | # theta = R : 189 | fcopy(&n_obs, &R[0, 0], &inc, &theta[0, 0], &inc) 190 | 191 | dnorm_XTtheta = dual_scaling_mtl( 192 | n_features, n_samples, n_tasks, theta, X, n_features, 193 | &dummy_C[0], &screened[0], &Xj_theta[0]) 194 | 195 | if dnorm_XTtheta > alpha: 196 | theta_scaling = alpha / dnorm_XTtheta 197 | fscal(&n_obs, &theta_scaling, &theta[0, 0], &inc) 198 | d_obj = dual_mtl(n_samples, n_tasks, theta, Y, norm_Y2) 199 | 200 | if t > 0: 201 | dnorm_XTtheta = dual_scaling_mtl( 202 | n_features, n_samples, n_tasks, theta_inner, X, 203 | n_features, &dummy_C[0], &screened[0], &Xj_theta[0]) 204 | 205 | if dnorm_XTtheta > alpha: 206 | theta_scaling = alpha / dnorm_XTtheta 207 | fscal(&n_obs, &theta_scaling, &theta_inner[0, 0], &inc) 208 | d_obj_from_inner = dual_mtl( 209 
| n_samples, n_tasks, theta_inner, Y, norm_Y2) 210 | if d_obj_from_inner > d_obj: 211 | d_obj = d_obj_from_inner 212 | 213 | gap = p_obj - d_obj 214 | if verbose: 215 | print("Iter %d: primal %.10f, gap %.2e" % (t, p_obj, gap), end="") 216 | 217 | if gap <= tol + 1e-16: 218 | if verbose: 219 | print("\nEarly exit, gap %.2e < %.2e" % (gap, tol)) 220 | break 221 | 222 | radius = sqrt(2 * gap / n_samples) 223 | # TODO prios could be computed along with scaling 224 | set_prios_mtl( 225 | W, screened, X, theta, alpha, norms_X_col, Xj_theta, prios, radius, 226 | &n_screened) 227 | 228 | if t == 0: 229 | ws_size = p0 230 | # prios[j] = -1 if W[j, :].any() 231 | for j in range(n_features): 232 | for k in range(n_tasks): 233 | if W[j, k]: 234 | prios[j] = -1 235 | break 236 | else: 237 | nnz = 0 238 | if prune: 239 | for j in range(n_features): 240 | if W[j, 0]: 241 | prios[j] = -1 242 | nnz += 1 243 | ws_size = 2 * nnz 244 | else: 245 | for k in range(ws_size): 246 | if not screened[C[k]]: 247 | prios[C[k]] = -1 248 | ws_size = 2 * ws_size 249 | ws_size = min(n_features - n_screened, ws_size) 250 | 251 | if ws_size == n_features: 252 | C = all_features 253 | else: 254 | C = np.sort(np.argpartition(prios, ws_size)[:ws_size].astype(np.int32)) 255 | 256 | if prune: 257 | tol_inner = tol_ratio * gap 258 | else: 259 | tol_inner = tol 260 | if verbose: 261 | print(", %d feats in subpb (%d left)" % (len(C), n_features - n_screened)) 262 | 263 | inner_solver( 264 | n_samples, n_features, n_tasks, ws_size, X, Y, alpha, W, R, C, 265 | theta_inner, norms_X_col, norm_Y2, tol_inner, max_epochs, 266 | gap_freq, verbose_inner, use_accel, K) 267 | 268 | else: 269 | warnings.warn( 270 | 'Objective did not converge: duality ' + 271 | f'gap: {gap}, tolerance: {tol}. Increasing `tol` may make the' + 272 | ' solver faster without affecting the results much. \n' + 273 | 'Fitting data with very small alpha causes precision issues.', 274 | ConvergenceWarning) 275 | return (np.asarray(W), np.asarray(theta), gap) 276 | 277 | 278 | @cython.cdivision(True) 279 | @cython.boundscheck(False) 280 | @cython.wraparound(False) 281 | cpdef void inner_solver( 282 | int n_samples, int n_features, int n_tasks, int ws_size, 283 | floating[::1, :] X, floating[::1, :] Y, floating alpha, 284 | floating[:, ::1] W, floating[::1, :] R, int[:] C, 285 | floating[::1, :] theta, floating[:] norms_X_col, 286 | floating norm_Y2, floating eps, int max_epochs, 287 | int gap_freq, bint verbose, bint use_accel=1, 288 | int K=6): 289 | 290 | if floating is double: 291 | dtype = np.float64 292 | else: 293 | dtype = np.float32 294 | 295 | cdef floating p_obj, d_obj, gap 296 | cdef floating highest_d_obj = 0. 
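    # Added commentary (not in the original source): inner_solver below runs
    # block coordinate descent restricted to the working set C. Every gap_freq
    # epochs it rebuilds a feasible dual point from the residuals R (optionally
    # the extrapolated point produced by create_accel_pt when use_accel is
    # set), recomputes the duality gap, and exits early once gap < eps.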
297 | cdef int i, j, k, epoch, ind 298 | cdef floating[:] old_Wj = np.empty(n_tasks, dtype=dtype) 299 | cdef int inc = 1 300 | cdef int n_obs = n_samples * n_tasks 301 | cdef floating tmp, dnorm_XTtheta, theta_scaling 302 | cdef int[:] dummy_screened = np.zeros(1, dtype=np.int32) 303 | cdef floating[:] Xj_theta = np.empty(n_tasks, dtype=dtype) 304 | 305 | 306 | # acceleration: 307 | cdef floating[::1, :] theta_acc = np.empty([n_samples, n_tasks], 308 | dtype=dtype, order='F') 309 | cdef floating d_obj_acc = 0 310 | cdef floating[:, :] last_K_R = np.empty([K, n_obs], dtype=dtype) 311 | cdef floating[:, :] U = np.empty([K - 1, n_obs], dtype=dtype) 312 | cdef floating[:, :] UtU = np.empty([K - 1, K - 1], dtype=dtype) 313 | cdef floating[:] onesK = np.ones(K - 1, dtype=dtype) 314 | # doc at https://software.intel.com/en-us/node/468894 315 | cdef char * char_U = 'U' 316 | cdef int Kminus1 = K - 1 317 | cdef int one = 1 318 | cdef floating sum_z 319 | cdef int info_dposv 320 | #################### 321 | 322 | for epoch in range(max_epochs): 323 | if epoch > 0 and epoch % gap_freq == 0: 324 | p_obj = primal_mtl(n_samples, n_features, n_tasks, W, alpha, R) 325 | fcopy(&n_obs, &R[0, 0], &inc, &theta[0, 0], &inc) 326 | 327 | tmp = 1. / n_samples 328 | fscal(&n_obs, &tmp, &theta[0, 0], &inc) 329 | 330 | dnorm_XTtheta = dual_scaling_mtl( 331 | n_features, n_samples, n_tasks, theta, X, ws_size, 332 | &C[0], &dummy_screened[0], &Xj_theta[0]) 333 | 334 | if dnorm_XTtheta > alpha: 335 | theta_scaling = alpha / dnorm_XTtheta 336 | fscal(&n_obs, &theta_scaling, &theta[0, 0], &inc) 337 | d_obj = dual_mtl(n_samples, n_tasks, theta, Y, norm_Y2) 338 | 339 | if use_accel: 340 | create_accel_pt( 341 | LASSO, n_obs, epoch, gap_freq, 342 | &R[0, 0], &theta_acc[0, 0], &last_K_R[0, 0], U, UtU, 343 | onesK, onesK) # passing onesK as y which is ignored 344 | # account for wrong n_samples passed to create_accel_pt 345 | tmp = n_tasks 346 | fscal(&n_obs, &tmp, &theta_acc[0, 0], &inc) 347 | if epoch // gap_freq >= K: 348 | dnorm_XTtheta = dual_scaling_mtl( 349 | n_features, n_samples, n_tasks, theta_acc, X, ws_size, 350 | &C[0], &dummy_screened[0], &Xj_theta[0]) 351 | 352 | if dnorm_XTtheta > alpha: 353 | theta_scaling = alpha / dnorm_XTtheta 354 | fscal(&n_obs, &theta_scaling, &theta_acc[0, 0], &inc) 355 | d_obj_acc = dual_mtl( 356 | n_samples, n_tasks, theta_acc, Y, norm_Y2) 357 | if d_obj_acc > d_obj: 358 | d_obj = d_obj_acc 359 | fcopy(&n_obs, &theta_acc[0, 0], &inc, &theta[0, 0], 360 | &inc) 361 | highest_d_obj = max(highest_d_obj, d_obj) 362 | gap = p_obj - highest_d_obj 363 | if verbose: 364 | print("Inner epoch %d, primal %.10f, gap: %.2e" % (epoch, p_obj, gap)) 365 | if gap < eps: 366 | if verbose: 367 | print("Inner: early exit at epoch %d, gap: %.2e < %.2e" % \ 368 | (epoch, gap, eps)) 369 | break 370 | 371 | for ind in range(ws_size): 372 | j = C[ind] 373 | fcopy(&n_tasks, &W[j, 0], &inc, &old_Wj[0], &inc) 374 | 375 | for k in range(n_tasks): 376 | tmp = fdot(&n_samples, &X[0, j], &inc, &R[0, k], &inc) 377 | W[j, k] += tmp / norms_X_col[j] ** 2 378 | BST(n_tasks, &W[j, 0], alpha / norms_X_col[j] ** 2 * n_samples) 379 | 380 | for k in range(n_tasks): 381 | tmp = old_Wj[k] - W[j, k] 382 | if tmp != 0.: 383 | # for i in range(n_samples): 384 | # R[i, k] += tmp * X[i, j] 385 | faxpy(&n_samples, &tmp, &X[0, j], &inc, &R[0, k], &inc) 386 | 387 | -------------------------------------------------------------------------------- /celer/tests/test_docstring_parameters.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | import os.path as op 3 | import re 4 | import sys 5 | from unittest import SkipTest 6 | import warnings 7 | 8 | from pkgutil import walk_packages 9 | from inspect import getsource 10 | 11 | from numpydoc import docscrape 12 | import celer 13 | 14 | # copied from sklearn.fixes 15 | if hasattr(inspect, 'signature'): # py35 16 | def _get_args(function, varargs=False): 17 | params = inspect.signature(function).parameters 18 | args = [key for key, param in params.items() 19 | if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)] 20 | if varargs: 21 | varargs = [param.name for param in params.values() 22 | if param.kind == param.VAR_POSITIONAL] 23 | if len(varargs) == 0: 24 | varargs = None 25 | return args, varargs 26 | else: 27 | return args 28 | else: 29 | def _get_args(function, varargs=False): 30 | out = inspect.getargspec(function) # args, varargs, keywords, defaults 31 | if varargs: 32 | return out[:2] 33 | else: 34 | return out[0] 35 | 36 | 37 | public_modules = [ 38 | # the list of modules users need to access for all functionality 39 | 'celer', 40 | 'celer.datasets' 41 | ] 42 | 43 | 44 | def get_name(func): 45 | """Get the name.""" 46 | parts = [] 47 | module = inspect.getmodule(func) 48 | if module: 49 | parts.append(module.__name__) 50 | if hasattr(func, 'im_class'): 51 | parts.append(func.im_class.__name__) 52 | parts.append(func.__name__) 53 | return '.'.join(parts) 54 | 55 | 56 | # functions to ignore args / docstring of 57 | _docstring_ignores = [ 58 | "celer.dropin_sklearn.Lasso.path", 59 | "celer.dropin_sklearn.path", 60 | "celer.dropin_sklearn.LassoCV.path", 61 | "celer.dropin_sklearn.MultitaskLasso.fit", 62 | "celer.dropin_sklearn.fit", 63 | ] 64 | _tab_ignores = [] 65 | 66 | 67 | def check_parameters_match(func, doc=None): 68 | """Check docstring, return list of incorrect results.""" 69 | incorrect = [] 70 | name_ = get_name(func) 71 | if not name_.startswith('celer.'): 72 | return incorrect 73 | if inspect.isdatadescriptor(func): 74 | return incorrect 75 | args = _get_args(func) 76 | # drop self 77 | if len(args) > 0 and args[0] == 'self': 78 | args = args[1:] 79 | 80 | if doc is None: 81 | with warnings.catch_warnings(record=True) as w: 82 | try: 83 | doc = docscrape.FunctionDoc(func) 84 | except Exception as exp: 85 | incorrect += [name_ + ' parsing error: ' + str(exp)] 86 | return incorrect 87 | if len(w): 88 | raise RuntimeError('Error for %s:\n%s' % (name_, w[0])) 89 | # check set 90 | param_names = [name for name, _, _ in doc['Parameters']] 91 | # clean up some docscrape output: 92 | param_names = [name.split(':')[0].strip('` ') for name in param_names] 93 | param_names = [name for name in param_names if '*' not in name] 94 | if len(param_names) != len(args): 95 | bad = str(sorted(list(set(param_names) - set(args)) + 96 | list(set(args) - set(param_names)))) 97 | if not any(re.match(d, name_) for d in _docstring_ignores) and \ 98 | 'deprecation_wrapped' not in func.__code__.co_name: 99 | incorrect += [name_ + ' arg mismatch: ' + bad] 100 | else: 101 | for n1, n2 in zip(param_names, args): 102 | if n1 != n2: 103 | incorrect += [name_ + ' ' + n1 + ' != ' + n2] 104 | return incorrect 105 | 106 | 107 | # TODO: readd numpydoc 108 | # @requires_numpydoc 109 | def test_docstring_parameters(): 110 | """Test module docstring formatting.""" 111 | public_modules_ = public_modules[:] 112 | 113 | incorrect = [] 114 | for name in public_modules_: 115 | with 
warnings.catch_warnings(record=True): # traits warnings 116 | module = __import__(name, globals()) 117 | for submod in name.split('.')[1:]: 118 | module = getattr(module, submod) 119 | classes = inspect.getmembers(module, inspect.isclass) 120 | for cname, cls in classes: 121 | if cname.startswith('_'): 122 | continue 123 | with warnings.catch_warnings(record=True) as w: 124 | cdoc = docscrape.ClassDoc(cls) 125 | if len(w): 126 | raise RuntimeError('Error for __init__ of %s in %s:\n%s' 127 | % (cls, name, w[0])) 128 | if hasattr(cls, '__init__'): 129 | incorrect += check_parameters_match(cls.__init__, cdoc) 130 | for method_name in cdoc.methods: 131 | method = getattr(cls, method_name) 132 | incorrect += check_parameters_match(method) 133 | if hasattr(cls, '__call__'): 134 | incorrect += check_parameters_match(cls.__call__) 135 | functions = inspect.getmembers(module, inspect.isfunction) 136 | for fname, func in functions: 137 | if fname.startswith('_'): 138 | continue 139 | incorrect += check_parameters_match(func) 140 | msg = '\n' + '\n'.join(sorted(list(set(incorrect)))) 141 | if len(incorrect) > 0: 142 | raise AssertionError(msg) 143 | 144 | 145 | def test_tabs(): 146 | """Test that there are no tabs in our source files.""" 147 | ignore = _tab_ignores[:] 148 | 149 | for importer, modname, ispkg in walk_packages(celer.__path__, 150 | prefix='celer.'): 151 | if not ispkg and modname not in ignore: 152 | # mod = importlib.import_module(modname) # not py26 compatible! 153 | try: 154 | with warnings.catch_warnings(record=True): # traits 155 | __import__(modname) 156 | except Exception: # can't import properly 157 | continue 158 | mod = sys.modules[modname] 159 | try: 160 | source = getsource(mod) 161 | except IOError: # user probably should have run "make clean" 162 | continue 163 | assert '\t' not in source, ('"%s" has tabs, please remove them ' 164 | 'or add it to the ignore list' 165 | % modname) 166 | 167 | 168 | documented_ignored_mods = tuple() 169 | documented_ignored_names = """ 170 | """.split('\n') 171 | 172 | 173 | def test_documented(): 174 | """Test that public functions and classes are documented.""" 175 | public_modules_ = public_modules[:] 176 | 177 | doc_file = op.abspath(op.join(op.dirname(__file__), '..', '..', 'doc', 178 | 'api.rst')) 179 | if not op.isfile(doc_file): 180 | raise SkipTest('Documentation file not found: %s' % doc_file) 181 | known_names = list() 182 | with open(doc_file, 'rb') as fid: 183 | for line in fid: 184 | line = line.decode('utf-8') 185 | if not line.startswith(' '): # at least two spaces 186 | continue 187 | line = line.split() 188 | if len(line) == 1 and line[0] != ':': 189 | known_names.append(line[0].split('.')[-1]) 190 | known_names = set(known_names) 191 | 192 | missing = [] 193 | for name in public_modules_: 194 | with warnings.catch_warnings(record=True): # traits warnings 195 | module = __import__(name, globals()) 196 | for submod in name.split('.')[1:]: 197 | module = getattr(module, submod) 198 | classes = inspect.getmembers(module, inspect.isclass) 199 | functions = inspect.getmembers(module, inspect.isfunction) 200 | checks = list(classes) + list(functions) 201 | for name, cf in checks: 202 | if not name.startswith('_') and name not in known_names: 203 | from_mod = inspect.getmodule(cf).__name__ 204 | if (from_mod.startswith('celer') and 205 | from_mod not in documented_ignored_mods and 206 | name not in documented_ignored_names): 207 | missing.append('%s (%s.%s)' % (name, from_mod, name)) 208 | if len(missing) > 0: 209 | raise 
AssertionError('\n\nFound new public members missing from '
210 |                              'doc/api.rst:\n\n* ' +
211 |                              '\n* '.join(sorted(set(missing))))
212 | 
--------------------------------------------------------------------------------
/celer/tests/test_enet.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 | import pytest
3 | 
4 | import numpy as np
5 | from numpy.linalg import norm
6 | from numpy.testing import (assert_allclose, assert_array_less, assert_equal)
7 | 
8 | from sklearn.linear_model import (
9 |     enet_path, ElasticNet as sk_ElasticNet, ElasticNetCV as sk_ElasticNetCV)
10 | 
11 | from celer import Lasso, ElasticNet, celer_path, ElasticNetCV
12 | from celer.utils.testing import build_dataset
13 | 
14 | 
15 | def test_raise_errors_l1_ratio():
16 |     with np.testing.assert_raises(ValueError):
17 |         ElasticNet(l1_ratio=5.)
18 | 
19 |     with np.testing.assert_raises(NotImplementedError):
20 |         ElasticNet(l1_ratio=0.)
21 | 
22 |     with np.testing.assert_raises(NotImplementedError):
23 |         X, y = build_dataset(n_samples=30, n_features=50)
24 |         y = np.sign(y)
25 |         celer_path(X, y, 'logreg', l1_ratio=0.5)
26 | 
27 | 
28 | @pytest.mark.parametrize("sparse_X", (True, False))
29 | def test_ElasticNet_Lasso_equivalence(sparse_X):
30 |     n_samples, n_features = 50, 100
31 |     X, y = build_dataset(n_samples, n_features, sparse_X=sparse_X)
32 |     alpha_max = norm(X.T@y, ord=np.inf) / n_samples
33 | 
34 |     alpha = alpha_max / 100.
35 |     coef_lasso = Lasso(alpha=alpha).fit(X, y).coef_
36 |     coef_enet = ElasticNet(alpha=alpha, l1_ratio=1.0).fit(X, y).coef_
37 | 
38 |     assert_allclose(coef_lasso, coef_enet)
39 | 
40 |     np.random.seed(0)
41 |     weights = abs(np.random.randn(n_features))
42 |     alpha_max = norm(X.T@y / weights, ord=np.inf) / n_samples
43 | 
44 |     alpha = alpha_max / 100.
45 |     coef_lasso = Lasso(alpha=alpha, weights=weights).fit(X, y).coef_
46 |     coef_enet = ElasticNet(alpha=alpha, l1_ratio=1.0, weights=weights).fit(X, y).coef_
47 | 
48 |     assert_allclose(coef_lasso, coef_enet)
49 | 
50 | 
51 | @pytest.mark.parametrize("prune", (0, 1))
52 | def test_sk_enet_path_equivalence(prune):
53 |     """Test that celer_path matches sklearn enet_path."""
54 | 
55 |     n_samples, n_features = 40, 80
56 |     X, y = build_dataset(n_samples, n_features, sparse_X=False)
57 | 
58 |     tol = 1e-14
59 |     l1_ratio = 0.7
60 |     alpha_max = norm(X.T@y, ord=np.inf) / n_samples
61 |     params = dict(eps=1e-3, tol=tol, l1_ratio=l1_ratio)
62 | 
63 |     # one alpha
64 |     alpha = alpha_max / 100.
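    # Added commentary: celer scales the input tol by norm(y) ** 2 / n_samples
    # internally, which is why the duality gaps below are checked against
    # tol * norm(y) ** 2 / n_samples rather than against tol alone.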
65 | alphas1, coefs1, gaps1 = celer_path( 66 | X, y, "lasso", alphas=[alpha], 67 | prune=prune, max_iter=30, **params) 68 | 69 | alphas2, coefs2, _ = enet_path(X, y, max_iter=10000, 70 | alphas=[alpha], **params) 71 | 72 | assert_equal(alphas1, alphas2) 73 | assert_array_less(gaps1, tol * norm(y) ** 2 / n_samples) 74 | assert_allclose(coefs1, coefs2, rtol=1e-3, atol=1e-4) 75 | 76 | # many alphas 77 | n_alphas = 20 78 | alphas1, coefs1, gaps1 = celer_path( 79 | X, y, "lasso", n_alphas=n_alphas, 80 | prune=prune, max_iter=30, **params) 81 | 82 | alphas2, coefs2, _ = enet_path(X, y, max_iter=10000, 83 | n_alphas=n_alphas, **params) 84 | 85 | assert_allclose(alphas1, alphas2) 86 | assert_array_less(gaps1, tol * norm(y) ** 2 / n_samples) 87 | assert_allclose(coefs1, coefs2, rtol=1e-3, atol=1e-4) 88 | 89 | 90 | @pytest.mark.parametrize("sparse_X, fit_intercept, positive", 91 | product([False, True], [False, True], [False, True])) 92 | def test_sk_ElasticNet_equivalence(sparse_X, fit_intercept, positive): 93 | n_samples, n_features = 30, 50 94 | X, y = build_dataset(n_samples, n_features, sparse_X=sparse_X) 95 | 96 | params = {'l1_ratio': 0.5, 'tol': 1e-14, 97 | 'fit_intercept': fit_intercept, 'positive': positive} 98 | 99 | reg_celer = ElasticNet(**params).fit(X, y) 100 | reg_sk = sk_ElasticNet(**params).fit(X, y) 101 | 102 | assert_allclose(reg_celer.coef_, reg_sk.coef_, rtol=1e-3, atol=1e-3) 103 | if fit_intercept: 104 | assert_allclose(reg_celer.intercept_, reg_sk.intercept_) 105 | 106 | 107 | @pytest.mark.parametrize("sparse_X", (True, False)) 108 | def test_weighted_ElasticNet(sparse_X): 109 | n_samples, n_features = 30, 50 110 | X, y = build_dataset(n_samples, n_features, sparse_X) 111 | 112 | np.random.seed(0) 113 | weights = abs(np.random.randn(n_features)) 114 | l1_ratio = .7 115 | 116 | params = {'max_iter': 10000, 'tol': 1e-14, 'fit_intercept': False} 117 | 118 | alpha_max = norm(X.T@y / weights, ord=np.inf) / n_samples 119 | alpha = alpha_max / 100. 
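    # Added commentary: this test relies on the classical reduction of the
    # elastic net to a Lasso on augmented data, built below: stacking
    # sqrt(n_samples * mu) * Identity under X (and zeros under y) turns the
    # l2 part of the penalty into extra least-squares rows; lmbda rescales
    # alpha because the augmented problem has n_samples + n_features rows.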
120 | 121 | reg_enet = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, **params).fit(X, y) 122 | 123 | lmbda = alpha * l1_ratio * n_samples / (n_samples + n_features) 124 | mu = alpha * (1 - l1_ratio) 125 | X_tilde = np.vstack( 126 | (X, np.sqrt(n_samples*mu) * np.eye(n_features))) 127 | y_tilde = np.hstack((y, np.zeros(n_features))) 128 | 129 | reg_lasso = Lasso(alpha=lmbda, **params) 130 | reg_lasso.fit(X_tilde, y_tilde) 131 | 132 | assert_allclose(reg_enet.coef_, reg_lasso.coef_, rtol=1e-4, atol=1e-3) 133 | 134 | 135 | @pytest.mark.parametrize("fit_intercept", (False, True)) 136 | def test_infinite_weights(fit_intercept): 137 | n_samples, n_features = 30, 100 138 | X, y = build_dataset(n_samples, n_features, sparse_X=False) 139 | 140 | np.random.seed(42) 141 | weights = abs(np.random.rand(n_features)) 142 | n_inf = n_features // 5 143 | inf_indices = np.random.choice(n_features, size=n_inf, replace=False) 144 | weights[inf_indices] = np.inf 145 | 146 | reg = ElasticNet(l1_ratio=0.5, tol=1e-8, 147 | fit_intercept=fit_intercept, weights=weights) 148 | reg.fit(X, y) 149 | 150 | assert_equal(reg.coef_[inf_indices], 0) 151 | 152 | 153 | @pytest.mark.parametrize("fit_intercept", (False, True)) 154 | def test_ElasticNetCV(fit_intercept): 155 | n_samples, n_features = 30, 100 156 | X, y = build_dataset(n_samples, n_features, sparse_X=False) 157 | 158 | params = dict(l1_ratio=[0.7, 0.8, 0.5], eps=0.05, n_alphas=10, tol=1e-10, cv=2, 159 | fit_intercept=fit_intercept, n_jobs=-1) 160 | 161 | clf = ElasticNetCV(**params) 162 | clf.fit(X, y) 163 | 164 | clf2 = sk_ElasticNetCV(**params, max_iter=10000) 165 | clf2.fit(X, y) 166 | 167 | assert_allclose( 168 | clf.mse_path_, clf2.mse_path_, rtol=1e-3, atol=1e-4) 169 | assert_allclose(clf.alpha_, clf2.alpha_) 170 | assert_allclose(clf.coef_, clf2.coef_, atol=1e-5) 171 | assert_allclose(clf.l1_ratio_, clf2.l1_ratio_, atol=1e-5) 172 | 173 | 174 | if __name__ == '__main__': 175 | pass 176 | -------------------------------------------------------------------------------- /celer/tests/test_lasso.py: -------------------------------------------------------------------------------- 1 | # Author: Mathurin Massias 2 | # Alexandre Gramfort 3 | # Joseph Salmon 4 | # License: BSD 3 clause 5 | 6 | import warnings 7 | from itertools import product 8 | 9 | import numpy as np 10 | from numpy.linalg import norm 11 | from numpy.testing import assert_allclose, assert_array_less, assert_array_equal 12 | import pytest 13 | 14 | from sklearn.exceptions import ConvergenceWarning 15 | from sklearn.linear_model import (LassoCV as sklearn_LassoCV, 16 | Lasso as sklearn_Lasso, lasso_path) 17 | 18 | from celer import celer_path 19 | from celer.dropin_sklearn import Lasso, LassoCV 20 | from celer.utils.testing import build_dataset 21 | 22 | 23 | @pytest.mark.parametrize("sparse_X, alphas, pb, dtype", 24 | product([False, True], [None, 1], 25 | ["lasso", "logreg"], 26 | [np.float32, np.float64])) 27 | def test_celer_path(sparse_X, alphas, pb, dtype): 28 | """Test Lasso path convergence.""" 29 | X, y = build_dataset(n_samples=30, n_features=50, sparse_X=sparse_X) 30 | X = X.astype(dtype) 31 | y = y.astype(dtype) 32 | 33 | tol = 1e-6 34 | if pb == "logreg": 35 | y = np.sign(y) 36 | tol_scaled = tol * len(y) * np.log(2) 37 | else: 38 | tol_scaled = tol * norm(y) ** 2 / len(y) 39 | n_samples = X.shape[0] 40 | if alphas is not None: 41 | alpha_max = np.max(np.abs(X.T.dot(y))) / n_samples 42 | n_alphas = 10 43 | alphas = alpha_max * np.logspace(0, -2, n_alphas) 44 | 45 | alphas, _, gaps, _, 
n_iters = celer_path( 46 | X, y, pb, alphas=alphas, tol=tol, return_thetas=True, 47 | verbose=1, return_n_iter=True) 48 | assert_array_less(gaps, tol_scaled) 49 | # hack because array_less wants strict inequality 50 | assert_array_less(0.99, n_iters) 51 | 52 | 53 | def test_convergence_warning(): 54 | X, y = build_dataset(n_samples=10, n_features=10) 55 | tol = 1e-16 # very small, not enough iterations below 56 | alpha_max = np.max(np.abs(X.T.dot(y))) / X.shape[0] 57 | clf = Lasso(alpha_max / 100, max_iter=1, max_epochs=1, tol=tol) 58 | 59 | with warnings.catch_warnings(record=True) as w: 60 | # Cause all warnings to always be triggered. 61 | warnings.simplefilter("always") 62 | clf.fit(X, y) 63 | assert len(w) >= 1 64 | assert issubclass(w[-1].category, ConvergenceWarning) 65 | 66 | 67 | @pytest.mark.parametrize("sparse_X, prune", [(False, 0), (False, 1)]) 68 | def test_celer_path_vs_lasso_path(sparse_X, prune): 69 | """Test that celer_path matches sklearn lasso_path.""" 70 | X, y = build_dataset(n_samples=30, n_features=50, sparse_X=sparse_X) 71 | 72 | tol = 1e-14 73 | params = dict(eps=1e-3, n_alphas=10, tol=tol) 74 | alphas1, coefs1, gaps1 = celer_path( 75 | X, y, "lasso", return_thetas=False, verbose=1, prune=prune, 76 | max_iter=30, **params) 77 | 78 | alphas2, coefs2, _ = lasso_path(X, y, verbose=False, **params, 79 | max_iter=10000) 80 | 81 | assert_allclose(alphas1, alphas2) 82 | assert_array_less(gaps1, tol * norm(y) ** 2 / len(y)) 83 | assert_allclose(coefs1, coefs2, rtol=1e-03, atol=1e-4) 84 | 85 | 86 | @pytest.mark.parametrize("sparse_X, fit_intercept, positive", 87 | product([False, True], [False, True], [False, True])) 88 | def test_LassoCV(sparse_X, fit_intercept, positive): 89 | """Test that our LassoCV behaves like sklearn's LassoCV.""" 90 | 91 | X, y = build_dataset(n_samples=20, n_features=30, sparse_X=sparse_X) 92 | params = dict(eps=0.05, n_alphas=10, tol=1e-10, cv=2, 93 | fit_intercept=fit_intercept, positive=positive, n_jobs=-1) 94 | 95 | clf = LassoCV(**params) 96 | clf.fit(X, y) 97 | 98 | clf2 = sklearn_LassoCV(**params, max_iter=10000) 99 | clf2.fit(X, y) 100 | 101 | assert_allclose( 102 | clf.mse_path_, clf2.mse_path_, rtol=1e-3, atol=1e-4) 103 | assert_allclose(clf.alpha_, clf2.alpha_) 104 | assert_allclose(clf.coef_, clf2.coef_, atol=1e-5) 105 | 106 | # TODO this one is slow (3s * 8 tests). Pass an instance and increase tol 107 | # check_estimator(LassoCV) 108 | 109 | 110 | @pytest.mark.parametrize("sparse_X, fit_intercept, positive", 111 | product([False, True], [False, True], [False, True])) 112 | def test_Lasso(sparse_X, fit_intercept, positive): 113 | """Test that our Lasso class behaves as sklearn's Lasso.""" 114 | X, y = build_dataset(n_samples=20, n_features=30, sparse_X=sparse_X) 115 | if not positive: 116 | alpha_max = norm(X.T.dot(y), ord=np.inf) / X.shape[0] 117 | else: 118 | alpha_max = X.T.dot(y).max() / X.shape[0] 119 | 120 | alpha = alpha_max / 2. 
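    # Added commentary: with positive=True only positive correlations can
    # become active, which is why alpha_max above uses max(X.T @ y) instead
    # of the infinity norm of X.T @ y.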
121 | params = dict(alpha=alpha, fit_intercept=fit_intercept, tol=1e-10, 122 | positive=positive) 123 | clf = Lasso(**params) 124 | clf.fit(X, y) 125 | 126 | clf2 = sklearn_Lasso(**params) 127 | clf2.fit(X, y) 128 | assert_allclose(clf.coef_, clf2.coef_, rtol=1e-5) 129 | if fit_intercept: 130 | assert_allclose(clf.intercept_, clf2.intercept_) 131 | 132 | # TODO fix for sklearn 0.24, pass an instance instead (buffer type error) 133 | # check_estimator(Lasso) 134 | 135 | 136 | @pytest.mark.parametrize("sparse_X, pb", 137 | product([True, False], ["lasso", "logreg"])) 138 | def test_celer_single_alpha(sparse_X, pb): 139 | X, y = build_dataset(n_samples=20, n_features=100, sparse_X=sparse_X) 140 | tol = 1e-6 141 | 142 | if pb == "logreg": 143 | y = np.sign(y) 144 | tol_scaled = tol * np.log(2) * len(y) 145 | else: 146 | tol_scaled = tol * norm(y) ** 2 / len(y) 147 | 148 | alpha_max = norm(X.T.dot(y), ord=np.inf) / X.shape[0] 149 | _, _, gaps = celer_path(X, y, pb, alphas=[alpha_max / 10.], tol=tol) 150 | assert_array_less(gaps, tol_scaled) 151 | 152 | 153 | @pytest.mark.parametrize("sparse_X", [True, False]) 154 | def test_zero_column(sparse_X): 155 | X, y = build_dataset(n_samples=60, n_features=50, sparse_X=sparse_X) 156 | n_zero_columns = 20 157 | if sparse_X: 158 | X.data[:X.indptr[n_zero_columns]].fill(0.) 159 | else: 160 | X[:, :n_zero_columns].fill(0.) 161 | alpha_max = norm(X.T.dot(y), ord=np.inf) / X.shape[0] 162 | tol = 1e-6 163 | _, coefs, gaps = celer_path( 164 | X, y, "lasso", alphas=[alpha_max / 10.], tol=tol, p0=50, prune=0) 165 | w = coefs.T[0] 166 | assert_array_less(gaps, tol * norm(y) ** 2 / len(y)) 167 | np.testing.assert_equal(w.shape[0], X.shape[1]) 168 | 169 | 170 | def test_warm_start(): 171 | """Test Lasso path convergence.""" 172 | X, y = build_dataset( 173 | n_samples=100, n_features=100, sparse_X=True) 174 | n_samples, n_features = X.shape 175 | alpha_max = np.max(np.abs(X.T.dot(y))) / n_samples 176 | n_alphas = 10 177 | alphas = alpha_max * np.logspace(0, -2, n_alphas) 178 | 179 | reg1 = Lasso(tol=1e-6, warm_start=True, p0=10) 180 | reg1.coef_ = np.zeros(n_features) 181 | 182 | for alpha in alphas: 183 | reg1.set_params(alpha=alpha) 184 | reg1.fit(X, y) 185 | # refitting with warm start should take less than 2 iters: 186 | reg1.fit(X, y) 187 | # hack because assert_array_less does strict comparison... 188 | assert_array_less(reg1.n_iter_, 2.01) 189 | 190 | 191 | def test_weights_lasso(): 192 | X, y = build_dataset(n_samples=30, n_features=50, sparse_X=True) 193 | 194 | np.random.seed(0) 195 | weights = np.abs(np.random.randn(X.shape[1])) 196 | 197 | tol = 1e-14 198 | params = {'n_alphas': 10, 'tol': tol} 199 | alphas1, coefs1, gaps1 = celer_path( 200 | X, y, "lasso", weights=weights, verbose=1, **params) 201 | 202 | alphas2, coefs2, gaps2 = celer_path( 203 | X.multiply(1 / weights[None, :]), y, "lasso", **params) 204 | 205 | assert_allclose(alphas1, alphas2) 206 | assert_allclose(coefs1, coefs2 / weights[:, None], atol=1e-4, rtol=1e-3) 207 | assert_array_less(gaps1, tol * norm(y) ** 2 / len(y)) 208 | assert_array_less(gaps2, tol * norm(y) ** 2 / len(y)) 209 | 210 | alpha = 0.001 211 | clf1 = Lasso(alpha=alpha, weights=weights, fit_intercept=False).fit(X, y) 212 | clf2 = Lasso(alpha=alpha, fit_intercept=False).fit( 213 | X.multiply(1. / weights), y) 214 | 215 | assert_allclose(clf1.coef_, clf2.coef_ / weights) 216 | 217 | # weights must be > 0 218 | clf1.weights[0] = 0. 
219 | np.testing.assert_raises(ValueError, clf1.fit, X=X, y=y) 220 | # weights must be equal to X.shape[1] 221 | clf1.weights = np.ones(X.shape[1] + 1) 222 | np.testing.assert_raises(ValueError, clf1.fit, X=X, y=y) 223 | 224 | 225 | @pytest.mark.parametrize("pb", ["lasso", "logreg"]) 226 | def test_infinite_weights(pb): 227 | n_samples, n_features = 50, 100 228 | X, y = build_dataset(n_samples, n_features) 229 | if pb == "logreg": 230 | y = np.sign(y) 231 | 232 | np.random.seed(1) 233 | weights = np.abs(np.random.randn(n_features)) 234 | n_inf = n_features // 10 235 | inf_indices = np.random.choice(n_features, size=n_inf, replace=False) 236 | weights[inf_indices] = np.inf 237 | 238 | alpha = norm(X.T @ y / weights, ord=np.inf) / n_samples / 100 239 | 240 | tol = 1e-8 241 | _, coefs, dual_gaps = celer_path( 242 | X, y, pb=pb, alphas=[alpha], weights=weights, tol=tol) 243 | 244 | if pb == "logreg": 245 | assert_array_less(dual_gaps[0], tol * n_samples * np.log(2)) 246 | else: 247 | assert_array_less(dual_gaps[0], tol * norm(y) ** 2 / 2.) 248 | 249 | assert_array_equal(coefs[inf_indices], 0) 250 | 251 | 252 | def test_one_iteration_alpha_max(): 253 | n_samples, n_features = 100, 50 254 | X, y = build_dataset(n_samples, n_features) 255 | 256 | alpha_max = norm(X.T @ y, ord=np.inf) / n_samples 257 | m = 5 258 | model = Lasso(alpha=m*alpha_max, fit_intercept=False) 259 | model.fit(X, y) 260 | 261 | assert_array_equal(model.coef_, np.zeros(n_features)) 262 | # solver exits right after computing first duality gap: 263 | np.testing.assert_equal(model.n_iter_, 1) 264 | 265 | 266 | if __name__ == "__main__": 267 | pass 268 | -------------------------------------------------------------------------------- /celer/tests/test_logreg.py: -------------------------------------------------------------------------------- 1 | # Author: Mathurin Massias 2 | # License: BSD 3 clause 3 | import pytest 4 | import numpy as np 5 | from numpy.linalg import norm 6 | 7 | from numpy.testing import assert_allclose, assert_array_less 8 | from sklearn.linear_model._logistic import _logistic_regression_path 9 | from sklearn.utils.estimator_checks import check_estimator 10 | from sklearn.linear_model import LogisticRegression as sklearn_Logreg 11 | 12 | from celer import celer_path 13 | from celer.dropin_sklearn import LogisticRegression 14 | from celer.utils.testing import build_dataset 15 | 16 | 17 | @pytest.mark.parametrize("solver", ["celer", "celer-pn"]) 18 | def test_celer_path_logreg(solver): 19 | X, y = build_dataset( 20 | n_samples=60, n_features=100, sparse_X=True) 21 | y = np.sign(y) 22 | alpha_max = norm(X.T.dot(y), ord=np.inf) / 2 23 | alphas = alpha_max * np.geomspace(1, 1e-2, 10) 24 | 25 | tol = 1e-11 26 | coefs, Cs, n_iters = _logistic_regression_path( 27 | X, y, Cs=1. / alphas, fit_intercept=False, penalty='l1', 28 | solver='liblinear', tol=tol, max_iter=1000, random_state=0) 29 | 30 | _, coefs_c, gaps = celer_path( 31 | X, y, "logreg", alphas=alphas, tol=tol, verbose=0, 32 | use_PN=(solver == "celer-pn")) 33 | 34 | assert_array_less(gaps, tol * len(y) * np.log(2)) 35 | assert_allclose(coefs != 0, coefs_c.T != 0) 36 | assert_allclose(coefs, coefs_c.T, atol=1e-5, rtol=1e-3) 37 | 38 | 39 | @pytest.mark.parametrize("sparse_X", [True, False]) 40 | def test_binary(sparse_X): 41 | np.random.seed(1409) 42 | X, y = build_dataset( 43 | n_samples=30, n_features=60, sparse_X=sparse_X) 44 | y = np.sign(y) 45 | alpha_max = norm(X.T.dot(y), ord=np.inf) / 2 46 | C = 20. 
/ alpha_max 47 | 48 | clf = LogisticRegression(C=-1) 49 | np.testing.assert_raises(ValueError, clf.fit, X, y) 50 | tol = 1e-8 51 | clf = LogisticRegression(C=C, tol=tol, verbose=0) 52 | clf.fit(X, y) 53 | 54 | clf_sk = sklearn_Logreg( 55 | C=C, penalty='l1', solver='liblinear', fit_intercept=False, tol=tol) 56 | clf_sk.fit(X, y) 57 | assert_allclose(clf.coef_, clf_sk.coef_, rtol=1e-3, atol=1e-5) 58 | 59 | 60 | @pytest.mark.parametrize("sparse_X", [True, False]) 61 | def test_multinomial(sparse_X): 62 | np.random.seed(1409) 63 | X, y = build_dataset( 64 | n_samples=30, n_features=60, sparse_X=sparse_X) 65 | y = np.random.choice(4, len(y)) 66 | tol = 1e-8 67 | clf = LogisticRegression(C=1, tol=tol, verbose=0) 68 | clf.fit(X, y) 69 | 70 | clf_sk = sklearn_Logreg( 71 | C=1, penalty='l1', solver='liblinear', fit_intercept=False, tol=tol) 72 | clf_sk.fit(X, y) 73 | assert_allclose(clf.coef_, clf_sk.coef_, rtol=1e-3, atol=1e-3) 74 | 75 | 76 | @pytest.mark.parametrize("solver", ["celer-pn"]) 77 | def test_check_estimator(solver): 78 | # sklearn fits on unnormalized data for which there are convergence issues 79 | # fix with increased tolerance: 80 | clf = LogisticRegression(C=1, solver=solver, tol=0.1) 81 | check_estimator(clf) 82 | -------------------------------------------------------------------------------- /celer/tests/test_mtl.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import itertools 3 | 4 | import numpy as np 5 | from numpy.linalg import norm 6 | from numpy.testing import assert_allclose, assert_array_less, assert_array_equal 7 | 8 | from sklearn.utils.estimator_checks import check_estimator 9 | from sklearn.linear_model import MultiTaskLassoCV as sklearn_MultiTaskLassoCV 10 | from sklearn.linear_model import MultiTaskLasso as sklearn_MultiTaskLasso 11 | from sklearn.linear_model import lasso_path 12 | 13 | from celer import (Lasso, GroupLasso, GroupLassoCV, MultiTaskLasso, 14 | MultiTaskLassoCV) 15 | from celer.homotopy import celer_path, mtl_path, _grp_converter 16 | from celer.group_fast import dnorm_grp 17 | from celer.utils.testing import build_dataset 18 | 19 | 20 | @pytest.mark.parametrize("sparse_X, fit_intercept", 21 | itertools.product([0, 1], [False, True])) 22 | def test_GroupLasso_Lasso_equivalence(sparse_X, fit_intercept): 23 | """Check that GroupLasso with groups of size 1 gives Lasso.""" 24 | n_features = 1000 25 | X, y = build_dataset( 26 | n_samples=100, n_features=n_features, sparse_X=sparse_X) 27 | alpha_max = norm(X.T @ y, ord=np.inf) / len(y) 28 | alpha = alpha_max / 10 29 | clf = Lasso(alpha, tol=1e-12, fit_intercept=fit_intercept, 30 | verbose=0) 31 | clf.fit(X, y) 32 | # take groups of size 1: 33 | clf1 = GroupLasso(alpha=alpha, groups=1, tol=1e-12, 34 | fit_intercept=fit_intercept, verbose=0) 35 | clf1.fit(X, y) 36 | 37 | np.testing.assert_allclose(clf1.coef_, clf.coef_, atol=1e-6) 38 | np.testing.assert_allclose(clf1.intercept_, clf.intercept_, rtol=1e-4) 39 | 40 | 41 | def test_GroupLasso_MultitaskLasso_equivalence(): 42 | "GroupLasso and MultitaskLasso equivalence.""" 43 | n_samples, n_features = 30, 50 44 | X_, Y_ = build_dataset(n_samples, n_features, n_targets=3) 45 | y = Y_.reshape(-1, order='F') 46 | X = np.zeros([3 * n_samples, 3 * n_features], order='F') 47 | 48 | # block filling new design 49 | for i in range(3): 50 | X[i * n_samples:(i + 1) * n_samples, i * 51 | n_features:(i + 1) * n_features] = X_ 52 | 53 | grp_indices = np.arange( 54 | 3 * n_features).reshape(3, -1).reshape(-1, 
order='F').astype(np.int32) 55 | grp_ptr = 3 * np.arange(n_features + 1).astype(np.int32) 56 | 57 | alpha_max = np.max(norm(X_.T @ Y_, axis=1)) / len(Y_) 58 | 59 | X_data = np.empty([1], dtype=X.dtype) 60 | X_indices = np.empty([1], dtype=np.int32) 61 | X_indptr = np.empty([1], dtype=np.int32) 62 | weights = np.ones(len(grp_ptr) - 1) 63 | 64 | other = dnorm_grp( 65 | False, y, grp_ptr, grp_indices, X, X_data, 66 | X_indices, X_indptr, X_data, weights, len(grp_ptr) - 1, 67 | np.zeros(1, dtype=np.int32), False) 68 | np.testing.assert_allclose(alpha_max, other / len(Y_)) 69 | 70 | alpha = alpha_max / 10 71 | clf = MultiTaskLasso(alpha, fit_intercept=False, tol=1e-8, verbose=0) 72 | clf.fit(X_, Y_) 73 | 74 | groups = [grp.tolist() for grp in grp_indices.reshape(50, 3)] 75 | clf1 = GroupLasso(alpha=alpha / 3, groups=groups, 76 | fit_intercept=False, tol=1e-8, verbose=0) 77 | clf1.fit(X, y) 78 | 79 | np.testing.assert_allclose(clf1.coef_, clf.coef_.reshape(-1), atol=1e-4) 80 | 81 | 82 | def test_convert_groups(): 83 | n_features = 6 84 | grp_ptr, grp_indices = _grp_converter(3, n_features) 85 | np.testing.assert_equal(grp_ptr, [0, 3, 6]) 86 | np.testing.assert_equal(grp_indices, [0, 1, 2, 3, 4, 5]) 87 | 88 | grp_ptr, grp_indices = _grp_converter([1, 3, 2], 6) 89 | np.testing.assert_equal(grp_ptr, [0, 1, 4, 6]) 90 | np.testing.assert_equal(grp_indices, [0, 1, 2, 3, 4, 5]) 91 | 92 | groups = [[0, 2, 5], [1, 3], [4]] 93 | grp_ptr, grp_indices = _grp_converter(groups, 6) 94 | np.testing.assert_equal(grp_ptr, [0, 3, 5, 6]) 95 | np.testing.assert_equal(grp_indices, [0, 2, 5, 1, 3, 4]) 96 | 97 | 98 | def test_mtl_path(): 99 | X, Y = build_dataset(n_targets=3) 100 | tol = 1e-12 101 | params = dict(eps=0.01, tol=tol, n_alphas=10) 102 | alphas, coefs, gaps = mtl_path(X, Y, **params) 103 | np.testing.assert_array_less(gaps, tol * norm(Y) ** 2 / len(Y)) 104 | 105 | sk_alphas, sk_coefs, sk_gaps = lasso_path(X, Y, **params, max_iter=10000) 106 | np.testing.assert_array_less(sk_gaps, tol * np.linalg.norm(Y, 'fro')**2) 107 | np.testing.assert_array_almost_equal(coefs, sk_coefs, decimal=5) 108 | np.testing.assert_allclose(alphas, sk_alphas) 109 | 110 | 111 | def test_MultiTaskLassoCV(): 112 | """Test that our MultitaskLassoCV behaves like sklearn's.""" 113 | X, y = build_dataset(n_samples=30, n_features=50, n_targets=3) 114 | 115 | params = dict(eps=1e-2, n_alphas=10, tol=1e-12, cv=2, n_jobs=1, 116 | fit_intercept=False, verbose=0) 117 | 118 | clf = MultiTaskLassoCV(**params) 119 | clf.fit(X, y) 120 | 121 | clf2 = sklearn_MultiTaskLassoCV(**params) 122 | clf2.max_iter = 10000 # increase max_iter bc of low tol 123 | clf2.fit(X, y) 124 | 125 | np.testing.assert_allclose(clf.mse_path_, clf2.mse_path_, 126 | atol=1e-4, rtol=1e-04) 127 | np.testing.assert_allclose(clf.alpha_, clf2.alpha_, 128 | atol=1e-4, rtol=1e-04) 129 | np.testing.assert_allclose(clf.coef_, clf2.coef_, 130 | atol=1e-4, rtol=1e-04) 131 | 132 | # check_estimator tests float32 so using tol < 1e-7 causes precision 133 | # issues 134 | # we don't support sample_weights for MTL 135 | # clf.tol = 1e-5 136 | # check_estimator(clf) 137 | 138 | 139 | @pytest.mark.parametrize("fit_intercept", [True, False]) 140 | def test_MultiTaskLasso(fit_intercept): 141 | """Test that our MultiTaskLasso behaves as sklearn's.""" 142 | X, Y = build_dataset(n_samples=20, n_features=30, n_targets=10) 143 | alpha_max = np.max(norm(X.T.dot(Y), axis=1)) / X.shape[0] 144 | 145 | alpha = alpha_max / 2. 
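    # Added commentary: for the multitask Lasso, the alpha_max computed above
    # is the smallest alpha yielding an all-zero solution: the largest
    # row-wise l2 norm of X.T @ Y, divided by n_samples.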
146 | params = dict(alpha=alpha, fit_intercept=fit_intercept, tol=1e-10) 147 | clf = MultiTaskLasso(**params) 148 | clf.verbose = 2 149 | clf.fit(X, Y) 150 | 151 | clf2 = sklearn_MultiTaskLasso(**params) 152 | clf2.fit(X, Y) 153 | np.testing.assert_allclose(clf.coef_, clf2.coef_, rtol=1e-5) 154 | if fit_intercept: 155 | np.testing.assert_allclose(clf.intercept_, clf2.intercept_) 156 | 157 | # we don't support sample_weights for MTL 158 | # clf.tol = 1e-7 159 | # check_estimator(clf) 160 | 161 | 162 | @pytest.mark.parametrize("sparse_X", [True, False]) 163 | def test_group_lasso_path(sparse_X): 164 | n_features = 50 165 | X, y = build_dataset( 166 | n_samples=11, n_features=n_features, sparse_X=sparse_X) 167 | 168 | alphas, coefs, gaps = celer_path( 169 | X, y, "grouplasso", groups=5, eps=1e-2, n_alphas=10, tol=1e-8) 170 | tol = 1e-8 171 | np.testing.assert_array_less(gaps, tol * norm(y) ** 2 / len(y)) 172 | 173 | 174 | @pytest.mark.parametrize("sparse_X", [True, False]) 175 | def test_GroupLasso(sparse_X): 176 | n_features = 50 177 | X, y = build_dataset( 178 | n_samples=11, n_features=n_features, sparse_X=sparse_X) 179 | 180 | tol = 1e-8 181 | clf = GroupLasso(alpha=0.8, groups=10, tol=tol) 182 | clf.fit(X, y) 183 | np.testing.assert_array_less(clf.dual_gap_, tol * norm(y) ** 2 / len(y)) 184 | 185 | clf.tol = 1e-6 186 | clf.groups = 1 # unsatisfying but sklearn will fit out of 5 features 187 | check_estimator(clf) 188 | 189 | 190 | @pytest.mark.parametrize("sparse_X", [True, False]) 191 | def test_GroupLassoCV(sparse_X): 192 | n_features = 50 193 | X, y = build_dataset( 194 | n_samples=11, n_features=n_features, sparse_X=sparse_X) 195 | 196 | tol = 1e-8 197 | clf = GroupLassoCV(groups=10, tol=tol) 198 | clf.fit(X, y) 199 | np.testing.assert_array_less(clf.dual_gap_, tol * norm(y) ** 2 / len(y)) 200 | 201 | clf.tol = 1e-6 202 | clf.groups = 1 # unsatisfying but sklearn will fit with 5 features 203 | check_estimator(clf) 204 | 205 | 206 | def test_weights_group_lasso(): 207 | n_samples, n_features = 30, 50 208 | X, y = build_dataset(n_samples, n_features, sparse_X=True) 209 | 210 | groups = 5 211 | n_groups = n_features // groups 212 | np.random.seed(0) 213 | weights = np.abs(np.random.randn(n_groups)) 214 | 215 | tol = 1e-14 216 | params = {'n_alphas': 10, 'tol': tol, 'verbose': 1} 217 | augmented_weights = np.repeat(weights, groups) 218 | 219 | alphas1, coefs1, gaps1 = celer_path( 220 | X, y, "grouplasso", groups=groups, weights=weights, 221 | eps=1e-2, **params) 222 | alphas2, coefs2, gaps2 = celer_path( 223 | X.multiply(1 / augmented_weights[None, :]), y, "grouplasso", 224 | groups=groups, eps=1e-2, **params) 225 | 226 | assert_allclose(alphas1, alphas2) 227 | assert_allclose( 228 | coefs1, coefs2 / augmented_weights[:, None], rtol=1e-3) 229 | assert_array_less(gaps1, tol * norm(y) ** 2 / len(y)) 230 | assert_array_less(gaps2, tol * norm(y) ** 2 / len(y)) 231 | 232 | 233 | def test_check_weights(): 234 | X, y = build_dataset(30, 42) 235 | weights = np.ones(X.shape[1] // 7) 236 | weights[0] = 0 237 | clf = GroupLasso(weights=weights, groups=7) # groups of size 7 238 | # weights must be > 0 239 | np.testing.assert_raises(ValueError, clf.fit, X=X, y=y) 240 | # len(weights) must be equal to number of groups (6 here) 241 | clf.weights = np.ones(8) 242 | np.testing.assert_raises(ValueError, clf.fit, X=X, y=y) 243 | 244 | 245 | def test_infinite_weights_group(): 246 | n_samples, n_features = 50, 100 247 | X, y = build_dataset(n_samples, n_features) 248 | 249 | np.random.seed(1) 250 | 
group_size = 5 251 | weights = np.abs(np.random.randn(n_features // group_size)) 252 | n_inf = 3 253 | inf_indices = np.random.choice( 254 | n_features // group_size, size=n_inf, replace=False) 255 | weights[inf_indices] = np.inf 256 | alpha_max = np.max( 257 | norm((X.T @ y).reshape(-1, group_size), 2, axis=1) 258 | ) / n_samples 259 | 260 | clf = GroupLasso( 261 | alpha=alpha_max / 100., weights=weights, groups=group_size, tol=1e-8 262 | ).fit(X, y) 263 | 264 | assert_array_less(clf.dual_gap_, clf.tol * norm(y) ** 2 / 2) 265 | assert_array_equal( 266 | norm(clf.coef_.reshape(-1, group_size), axis=1)[inf_indices], 0) 267 | 268 | 269 | if __name__ == "__main__": 270 | pass 271 | -------------------------------------------------------------------------------- /celer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mathurinm/celer/eef3bb5f239cf1a0b3a6d0fc8997a20b5e73e21a/celer/utils/__init__.py -------------------------------------------------------------------------------- /celer/utils/testing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from scipy import sparse 4 | 5 | 6 | def build_dataset(n_samples=50, n_features=200, n_targets=1, sparse_X=False): 7 | """Build samples and observation for linear regression problem.""" 8 | random_state = np.random.RandomState(0) 9 | if n_targets > 1: 10 | w = random_state.randn(n_features, n_targets) 11 | else: 12 | w = random_state.randn(n_features) 13 | 14 | if sparse_X: 15 | X = sparse.random(n_samples, n_features, density=0.5, format='csc', 16 | random_state=random_state) 17 | 18 | else: 19 | X = np.asfortranarray(random_state.randn(n_samples, n_features)) 20 | 21 | y = X.dot(w) 22 | return X, y 23 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 2 3 | round: down 4 | range: "70...100" 5 | status: 6 | project: 7 | default: 8 | target: auto 9 | threshold: 0.01 10 | patch: false 11 | changes: false 12 | ignore: 13 | - "celer/celer/datasets/.*" 14 | - "celer/datasets/libsvm.py" 15 | - "datasets/libsvm.py" 16 | comment: 17 | layout: "header, diff, sunburst, uncovered" 18 | behavior: default 19 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | GITHUB_PAGES_BRANCH = gh-pages 11 | OUTPUTDIR = _build/html 12 | 13 | # User-friendly check for sphinx-build 14 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 15 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 16 | endif 17 | 18 | # Internal variables. 19 | PAPEROPT_a4 = -D latex_paper_size=a4 20 | PAPEROPT_letter = -D latex_paper_size=letter 21 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
22 | # the i18n builder cannot share the environment and doctrees with the others
23 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
24 | 
25 | .PHONY: help
26 | help:
27 | 	@echo "Please use \`make <target>' where <target> is one of"
28 | 	@echo " html-noplot to make standalone HTML files, without plotting anything"
29 | 	@echo " html to make standalone HTML files"
30 | 	@echo " dirhtml to make HTML files named index.html in directories"
31 | 	@echo " singlehtml to make a single large HTML file"
32 | 	@echo " pickle to make pickle files"
33 | 	@echo " htmlhelp to make HTML files and a HTML help project"
34 | 	@echo " qthelp to make HTML files and a qthelp project"
35 | 	@echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
36 | 	@echo " latexpdf to make LaTeX files and run them through pdflatex"
37 | 	@echo " changes to make an overview of all changed/added/deprecated items"
38 | 	@echo " linkcheck to check all external links for integrity"
39 | 	@echo " doctest to run all doctests embedded in the documentation (if enabled)"
40 | 	@echo " coverage to run coverage check of the documentation (if enabled)"
41 | 	@echo " install to make the html and push it online"
42 | 
43 | .PHONY: clean
44 | 
45 | clean:
46 | 	rm -rf $(BUILDDIR)/*
47 | 	rm -rf auto_examples/
48 | 	rm -rf generated/*
49 | 	rm -rf modules/*
50 | 
51 | html-noplot:
52 | 	$(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
53 | 	@echo
54 | 	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
55 | 
56 | .PHONY: html
57 | html:
58 | 	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
59 | 	@echo
60 | 	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
61 | 
62 | .PHONY: dirhtml
63 | dirhtml:
64 | 	$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
65 | 	@echo
66 | 	@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
67 | 
68 | .PHONY: singlehtml
69 | singlehtml:
70 | 	$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
71 | 	@echo
72 | 	@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
73 | 
74 | .PHONY: pickle
75 | pickle:
76 | 	$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
77 | 	@echo
78 | 	@echo "Build finished; now you can process the pickle files."
79 | 
80 | .PHONY: htmlhelp
81 | htmlhelp:
82 | 	$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
83 | 	@echo
84 | 	@echo "Build finished; now you can run HTML Help Workshop with the" \
85 | 	".hhp project file in $(BUILDDIR)/htmlhelp."
86 | 
87 | .PHONY: qthelp
88 | qthelp:
89 | 	$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
90 | 	@echo
91 | 	@echo "Build finished; now you can run "qcollectiongenerator" with the" \
92 | 	".qhcp project file in $(BUILDDIR)/qthelp, like this:"
93 | 	@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/celer.qhcp"
94 | 	@echo "To view the help file:"
95 | 	@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/celer.qhc"
96 | 
97 | .PHONY: latex
98 | latex:
99 | 	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
100 | 	@echo
101 | 	@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
102 | 	@echo "Run \`make' in that directory to run these through (pdf)latex" \
103 | 	"(use \`make latexpdf' here to do that automatically)."
104 | 
105 | .PHONY: latexpdf
106 | latexpdf:
107 | 	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
108 | 	@echo "Running LaTeX files through pdflatex..."
109 | 	$(MAKE) -C $(BUILDDIR)/latex all-pdf
110 | 	@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
111 | 112 | .PHONY: changes 113 | changes: 114 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 115 | @echo 116 | @echo "The overview file is in $(BUILDDIR)/changes." 117 | 118 | .PHONY: linkcheck 119 | linkcheck: 120 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 121 | @echo 122 | @echo "Link check complete; look for any errors in the above output " \ 123 | "or in $(BUILDDIR)/linkcheck/output.txt." 124 | 125 | .PHONY: doctest 126 | doctest: 127 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 128 | @echo "Testing of doctests in the sources finished, look at the " \ 129 | "results in $(BUILDDIR)/doctest/output.txt." 130 | 131 | .PHONY: coverage 132 | coverage: 133 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 134 | @echo "Testing of coverage in the sources finished, look at the " \ 135 | "results in $(BUILDDIR)/coverage/python.txt." 136 | 137 | install: 138 | touch $(OUTPUTDIR)/.nojekyll 139 | ghp-import -m "Generate Pelican site [ci skip]" -b $(GITHUB_PAGES_BRANCH) $(OUTPUTDIR) 140 | git push -f origin $(GITHUB_PAGES_BRANCH) 141 | -------------------------------------------------------------------------------- /doc/_static/style.css: -------------------------------------------------------------------------------- 1 | 2 | blockquote p { 3 | font-size: 14px !important; 4 | } 5 | 6 | blockquote { 7 | margin: 0 0 4px !important; 8 | } 9 | 10 | code { 11 | color: #49759c !important; 12 | background-color: #f3f5f9 !important; 13 | } 14 | 15 | .alert-info { 16 | background-color: #adb8cb !important; 17 | border-color: #adb8cb !important; 18 | color: #2c3e50 !important; 19 | } 20 | 21 | /* This breaks the sphinx.ext.linkcode in bootstrap theme (see PR 251) */ 22 | /* .function dt { 23 | padding-top: 150px; 24 | margin-top: -150px; 25 | -webkit-background-clip: content-box; 26 | background-clip: content-box; 27 | } */ -------------------------------------------------------------------------------- /doc/api.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | API Documentation 3 | ================= 4 | 5 | Estimators 6 | ========== 7 | 8 | .. currentmodule:: celer 9 | 10 | .. autosummary:: 11 | :toctree: generated/ 12 | 13 | ElasticNet 14 | ElasticNetCV 15 | GroupLasso 16 | GroupLassoCV 17 | Lasso 18 | LassoCV 19 | LogisticRegression 20 | MultiTaskLasso 21 | MultiTaskLassoCV 22 | 23 | 24 | Functions 25 | ========= 26 | 27 | .. autosummary:: 28 | :toctree: generated/ 29 | 30 | celer_path 31 | 32 | 33 | Datasets fetchers 34 | ================= 35 | 36 | :py:mod:`celer.datasets`: 37 | 38 | .. currentmodule:: celer.datasets 39 | 40 | .. automodule:: celer.datasets 41 | :no-members: 42 | :no-inherited-members: 43 | 44 | .. autosummary:: 45 | :toctree: generated/ 46 | 47 | make_correlated_data 48 | fetch_ml_uci 49 | fetch_libsvm 50 | fetch_climate 51 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # celer documentation build configuration file, created by 4 | # sphinx-quickstart on Thu Jun 1 00:35:01 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 
11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import os 16 | import sys 17 | import warnings 18 | import sphinx_gallery 19 | # import sphinx_bootstrap_theme 20 | from distutils.version import LooseVersion 21 | import matplotlib 22 | 23 | # If extensions (or modules to document with autodoc) are in another directory, 24 | # add these directories to sys.path here. If the directory is relative to the 25 | # documentation root, use os.path.abspath to make it absolute, like shown here. 26 | # sys.path.insert(0, os.path.abspath('.')) 27 | sys.path.insert(0, os.path.abspath(".")) 28 | from github_link import make_linkcode_resolve 29 | 30 | # Mathurin: disable agg warnings in doc 31 | warnings.filterwarnings("ignore", category=UserWarning, 32 | message='Matplotlib is currently using agg, which is a' 33 | ' non-GUI backend, so cannot show the figure.') 34 | 35 | 36 | # -- General configuration ------------------------------------------------ 37 | 38 | # If your documentation needs a minimal Sphinx version, state it here. 39 | # needs_sphinx = '1.0' 40 | 41 | # Add any Sphinx extension module names here, as strings. They can be 42 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 43 | # ones. 44 | extensions = [ 45 | 'sphinx.ext.autodoc', 46 | 'sphinx.ext.autosummary', 47 | 'sphinx.ext.autosectionlabel', 48 | 'sphinx_copybutton', 49 | 'sphinx.ext.doctest', 50 | 'sphinx.ext.intersphinx', 51 | 'sphinx.ext.mathjax', 52 | 'sphinx_gallery.gen_gallery', 53 | 'numpydoc', 54 | "sphinx.ext.linkcode", 55 | ] 56 | 57 | if LooseVersion(sphinx_gallery.__version__) < LooseVersion('0.2'): 58 | raise ImportError('Must have at least version 0.2 of sphinx-gallery, got ' 59 | '%s' % (sphinx_gallery.__version__,)) 60 | 61 | matplotlib.use('agg') 62 | 63 | 64 | # Add any paths that contain templates here, relative to this directory. 65 | templates_path = ['_templates'] 66 | 67 | # The suffix(es) of source filenames. 68 | # You can specify multiple suffix as a list of string: 69 | # 70 | # source_suffix = ['.rst', '.md'] 71 | source_suffix = '.rst' 72 | 73 | # The master toctree document. 74 | master_doc = 'index' 75 | 76 | # General information about the project. 77 | project = u'celer' 78 | copyright = u'2018-2022, Mathurin Massias' 79 | author = u'Mathurin Massias' 80 | 81 | # The version info for the project you're documenting, acts as replacement for 82 | # |version| and |release|, also used in various other places throughout the 83 | # built documents. 84 | # 85 | # The short X.Y version. 86 | from celer import __version__ as version # noqa 87 | # The full version, including alpha/beta/rc tags. 88 | release = version 89 | 90 | # The language for content autogenerated by Sphinx. Refer to documentation 91 | # for a list of supported languages. 92 | # 93 | # This is also used if you do content translation via gettext catalogs. 94 | # Usually you set "language" from the command line for these cases. 95 | language = 'en' 96 | 97 | # List of patterns, relative to source directory, that match files and 98 | # directories to ignore when looking for source files. 99 | # This patterns also effect to html_static_path and html_extra_path 100 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 101 | 102 | # The name of the Pygments (syntax highlighting) style to use. 103 | pygments_style = 'sphinx' 104 | 105 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
106 | todo_include_todos = False 107 | 108 | # generate autosummary even if no references 109 | autosummary_generate = True 110 | 111 | # remove warnings: "toctree contains reference to nonexisting document" 112 | numpydoc_show_class_members = False 113 | 114 | # -- Options for HTML output ---------------------------------------------- 115 | 116 | # The theme to use for HTML and HTML Help pages. See the documentation for 117 | # a list of builtin themes. 118 | html_theme = 'furo' 119 | 120 | # Theme options are theme-specific and customize the look and feel of a theme 121 | # further. For a list of options available for each theme, see the 122 | # documentation. 123 | 124 | html_theme_options = { 125 | "footer_icons": [ 126 | { 127 | "name": "GitHub", 128 | "url": "https://github.com/mathurinm/celer", 129 | "html": """ 130 | 131 | 132 | 133 | """, 134 | "class": "", 135 | }, 136 | ], 137 | } 138 | 139 | # Add any paths that contain custom themes here, relative to this directory. 140 | # html_theme_path = sphinx_bootstrap_theme.get_html_theme_path() 141 | 142 | # Add any paths that contain custom static files (such as style sheets) here, 143 | # relative to this directory. They are copied after the builtin static files, 144 | # so a file named "default.css" will overwrite the builtin "default.css". 145 | html_static_path = ['_static'] 146 | 147 | 148 | # -- Options for HTMLHelp output ------------------------------------------ 149 | 150 | # Output file base name for HTML help builder. 151 | htmlhelp_basename = 'celer_doc' 152 | 153 | 154 | # -- Options for copybutton --------------------------------------------- 155 | # complete explanation of the regex expression can be found here 156 | # https://sphinx-copybutton.readthedocs.io/en/latest/use.html#using-regexp-prompt-identifiers 157 | copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " 158 | copybutton_prompt_is_regexp = True 159 | 160 | 161 | # -- Options for LaTeX output --------------------------------------------- 162 | 163 | latex_elements = { 164 | # The paper size ('letterpaper' or 'a4paper'). 165 | # 166 | # 'papersize': 'letterpaper', 167 | 168 | # The font size ('10pt', '11pt' or '12pt'). 169 | # 170 | # 'pointsize': '10pt', 171 | 172 | # Additional stuff for the LaTeX preamble. 173 | # 174 | # 'preamble': '', 175 | 176 | # Latex figure (float) alignment 177 | # 178 | # 'figure_align': 'htbp', 179 | } 180 | 181 | # Grouping the document tree into LaTeX files. List of tuples 182 | # (source start file, target name, title, 183 | # author, documentclass [howto, manual, or own class]). 184 | latex_documents = [ 185 | (master_doc, 'celer.tex', u'celer Documentation', 186 | u'Mathurin Massias', 'manual'), 187 | ] 188 | 189 | 190 | # -- Options for manual page output --------------------------------------- 191 | 192 | # One entry per manual page. List of tuples 193 | # (source start file, name, description, authors, manual section). 194 | man_pages = [ 195 | (master_doc, 'celer', u'celer Documentation', 196 | [author], 1) 197 | ] 198 | 199 | 200 | # -- Options for Texinfo output ------------------------------------------- 201 | 202 | # Grouping the document tree into Texinfo files. 
List of tuples
203 | # (source start file, target name, title, author,
204 | #  dir menu entry, description, category)
205 | texinfo_documents = [
206 |     (master_doc, 'celer', u'celer Documentation',
207 |      author, 'celer', 'A fast solver for Lasso-like problems.',
208 |      'Miscellaneous'),
209 | ]
210 | 
211 | 
212 | intersphinx_mapping = {
213 |     # 'numpy': ('https://docs.scipy.org/doc/numpy/', None),
214 |     # 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None),
215 |     'matplotlib': ('https://matplotlib.org/', None),
216 |     'sklearn': ('http://scikit-learn.org/stable', None),
217 | }
218 | 
219 | sphinx_gallery_conf = {
220 |     'doc_module': ('celer', 'sklearn'),
221 |     'examples_dirs': '../examples',
222 |     'gallery_dirs': 'auto_examples',
223 |     'reference_url': {
224 |         'celer': None,
225 |     }
226 | }
227 | 
228 | # The following is used by sphinx.ext.linkcode to provide links to github
229 | linkcode_resolve = make_linkcode_resolve(
230 |     "celer",
231 |     "https://github.com/mathurinm/"
232 |     "celer/blob/{revision}/"
233 |     "{package}/{path}#L{lineno}",
234 | )
235 | 
236 | 
237 | def setup(app):
238 |     app.add_css_file('style.css')
239 | 
--------------------------------------------------------------------------------
/doc/contribute.rst:
--------------------------------------------------------------------------------
1 | ====================
2 | Become a contributor
3 | ====================
4 | 
5 | How to contribute to celer?
6 | ---------------------------
7 | 
8 | ``celer`` is an open source project and hence relies on community efforts to evolve.
9 | No matter how small your contribution is, we highly value it. Your contribution
10 | can come in three forms:
11 | 
12 | - **Bug report**
13 | 
14 |   ``celer`` continuously runs unit tests on the code base to prevent bugs. Help us tighten these tests by reporting
15 |   any bug that you encounter while using ``celer``. To do so, use the
16 |   `issue section <https://github.com/mathurinm/celer/issues>`_
17 |   available on the ``celer`` repository.
18 | 
19 | - **Feature request**
20 | 
21 |   We are constantly improving ``celer`` and we would like to align it with our users' needs.
22 |   Hence, we highly appreciate any suggestion to extend or add new features to ``celer``.
23 |   You can use the `issue section <https://github.com/mathurinm/celer/issues>`_ to make suggestions.
24 | 
25 | 
26 | - **Pull request**
27 | 
28 |   If you fixed a bug, added a new feature, or even corrected a small typo in the documentation,
29 |   you can submit a `pull request <https://github.com/mathurinm/celer/pulls>`_ to integrate your changes,
30 |   and we will reach out to you as soon as possible.
31 | 
32 | 
33 | 
34 | Setup ``celer`` on your local machine
35 | ---------------------------------------
36 | 
37 | Here are the key steps to set up ``celer`` on your local machine, in case you want to
38 | contribute code or documentation to celer.
39 | 
40 | 1. Fork the repository, then run the following command to clone it on your local machine
41 | 
42 | .. code-block:: shell
43 | 
44 |     $ git clone https://github.com/{YOUR_GITHUB_USERNAME}/celer.git
45 | 
46 | 
47 | 2. ``cd`` to the ``celer`` directory and install the package in editable mode by running
48 | 
49 | .. code-block:: shell
50 | 
51 |     $ cd celer
52 |     $ pip install -e .
53 | 
54 | 
55 | 3. To build the documentation and run the gallery examples, install the ``doc`` dependencies from the repository root, then run
56 | 
57 | .. code-block:: shell
58 | 
59 |     $ pip install -e .[doc]
60 |     $ cd doc
61 |     $ make html
62 | 
63 | 
64 | .. note::
65 |     You should have a C compiler such as `gcc <https://gcc.gnu.org/>`_
66 |     installed on your local machine, since ``celer`` uses Cython.
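67 | 
68 | To check that the editable install and the compiled Cython extensions work,
69 | you can run a quick smoke test. The snippet below is a minimal sketch; the
70 | dataset sizes and the ``alpha`` value are arbitrary choices:
71 | 
72 | .. code-block:: python
73 | 
74 |     # smoke test: import celer and fit a small Lasso problem
75 |     from celer import Lasso
76 |     from celer.datasets import make_correlated_data
77 | 
78 |     X, y, _ = make_correlated_data(n_samples=50, n_features=100)
79 |     model = Lasso(alpha=0.1).fit(X, y)
80 |     print(model.coef_.shape)  # expected: (100,)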
--------------------------------------------------------------------------------
/doc/get_started.rst:
--------------------------------------------------------------------------------
1 | ===========
2 | Get started
3 | ===========
4 | 
5 | In this starter example, we will fit a Lasso estimator on a toy dataset.
6 | 
7 | Beforehand, make sure to install ``celer``::
8 | 
9 |     $ pip install -U celer
10 | 
11 | 
12 | 
13 | Generate a toy dataset
14 | ----------------------
15 | 
16 | ``celer`` comes with a module, :ref:`Datasets fetchers`,
17 | that exposes several functions to fetch or generate datasets.
18 | We are going to use ``make_correlated_data`` to generate a toy dataset.
19 | 
20 | .. code-block:: python
21 | 
22 |     # imports
23 |     from celer.datasets import make_correlated_data
24 |     from sklearn.model_selection import train_test_split
25 | 
26 |     # generate the toy dataset
27 |     X, y, _ = make_correlated_data(n_samples=500, n_features=5000)
28 |     # split the dataset into training and test sets
29 |     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
30 |                                                         random_state=42)
31 | 
32 | 
33 | 
34 | Fit and score a Lasso estimator
35 | -------------------------------
36 | 
37 | ``celer`` exposes easy-to-use estimators, as it was designed following the ``scikit-learn``
38 | API. ``celer`` therefore integrates well with ``scikit-learn`` tools such as ``Pipeline``
39 | and ``GridSearchCV`` (see the sketch at the end of this page).
40 | 
41 | .. code-block:: python
42 | 
43 |     # import model
44 |     from celer import Lasso
45 | 
46 |     # init and fit
47 |     model = Lasso()
48 |     model.fit(X_train, y_train)
49 | 
50 |     # print R²
51 |     print(model.score(X_test, y_test))
52 | 
53 | 
54 | 
55 | Perform cross-validation
56 | ------------------------
57 | 
58 | ``celer``'s Lasso estimator comes with native cross-validation.
59 | The following snippet performs cross-validation on a grid of 100 ``alphas`` using 5 folds;
60 | note how fast ``celer`` is compared to ``scikit-learn``.
61 | 
62 | .. code-block:: python
63 | 
64 |     # imports
65 |     import time
66 |     from celer import LassoCV
67 |     from sklearn.linear_model import LassoCV as sk_LassoCV
68 | 
69 |     # fit for celer
70 |     start = time.time()
71 |     celer_lassoCV = LassoCV(n_alphas=100, cv=5)
72 |     celer_lassoCV.fit(X, y)
73 |     print(f"time elapsed for celer LassoCV: {time.time() - start}")
74 | 
75 |     # fit for scikit-learn
76 |     start = time.time()
77 |     sk_lassoCV = sk_LassoCV(n_alphas=100, cv=5)
78 |     sk_lassoCV.fit(X, y)
79 |     print(f"time elapsed for scikit-learn LassoCV: {time.time() - start}")
80 | 
81 | 
82 | .. rst-class:: sphx-glr-script-out
83 | 
84 |  Out:
85 | 
86 |  .. code-block:: none
87 | 
88 |     time elapsed for celer LassoCV: 5.062559127807617
89 |     time elapsed for scikit-learn LassoCV: 27.427260398864746
90 | 
91 | 
92 | 
93 | Further links
94 | -------------
95 | 
96 | This was just a starter example. Get familiar with ``celer`` by browsing its :ref:`API documentation` or
97 | explore the :ref:`Examples Gallery`, which includes examples on real-life datasets as well as
98 | timing comparisons with other solvers.
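99 | 
100 | 
101 | Use ``celer`` with ``GridSearchCV``
102 | -----------------------------------
103 | 
104 | As mentioned above, ``celer`` estimators can be plugged into ``scikit-learn``
105 | tools. The snippet below is a minimal sketch of a hyperparameter search with
106 | ``GridSearchCV``; the ``alpha`` grid is an arbitrary choice, and ``X_train``,
107 | ``X_test``, ``y_train`` and ``y_test`` are the arrays created earlier on this page.
108 | 
109 | .. code-block:: python
110 | 
111 |     import numpy as np
112 |     from sklearn.model_selection import GridSearchCV
113 | 
114 |     from celer import Lasso
115 | 
116 |     # search over a small, arbitrary grid of regularization strengths
117 |     param_grid = {"alpha": np.geomspace(1, 1e-3, num=5)}
118 |     search = GridSearchCV(Lasso(), param_grid, cv=5)
119 |     search.fit(X_train, y_train)
120 | 
121 |     print(search.best_params_)
122 |     print(search.score(X_test, y_test))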
-------------------------------------------------------------------------------- /doc/github_link.py: -------------------------------------------------------------------------------- 1 | # this code is a copy/paste of 2 | # https://github.com/scikit-learn/scikit-learn/blob/ 3 | # b0b8a39d8bb80611398e4c57895420d5cb1dfe09/doc/sphinxext/github_link.py 4 | 5 | from operator import attrgetter 6 | import inspect 7 | import subprocess 8 | import os 9 | import sys 10 | from functools import partial 11 | 12 | REVISION_CMD = "git rev-parse --short HEAD" 13 | 14 | 15 | def _get_git_revision(): 16 | try: 17 | revision = subprocess.check_output(REVISION_CMD.split()).strip() 18 | except (subprocess.CalledProcessError, OSError): 19 | print("Failed to execute git to get revision") 20 | return None 21 | return revision.decode("utf-8") 22 | 23 | 24 | def _linkcode_resolve(domain, info, package, url_fmt, revision): 25 | """Determine a link to online source for a class/method/function 26 | This is called by sphinx.ext.linkcode 27 | An example with a long-untouched module that everyone has 28 | >>> _linkcode_resolve('py', {'module': 'tty', 29 | ... 'fullname': 'setraw'}, 30 | ... package='tty', 31 | ... url_fmt='http://hg.python.org/cpython/file/' 32 | ... '{revision}/Lib/{package}/{path}#L{lineno}', 33 | ... revision='xxxx') 34 | 'http://hg.python.org/cpython/file/xxxx/Lib/tty/tty.py#L18' 35 | """ 36 | 37 | if revision is None: 38 | return 39 | if domain not in ("py", "pyx"): 40 | return 41 | if not info.get("module") or not info.get("fullname"): 42 | return 43 | 44 | class_name = info["fullname"].split(".")[0] 45 | module = __import__(info["module"], fromlist=[class_name]) 46 | obj = attrgetter(info["fullname"])(module) 47 | 48 | # Unwrap the object to get the correct source 49 | # file in case that is wrapped by a decorator 50 | obj = inspect.unwrap(obj) 51 | 52 | try: 53 | fn = inspect.getsourcefile(obj) 54 | except Exception: 55 | fn = None 56 | if not fn: 57 | try: 58 | fn = inspect.getsourcefile(sys.modules[obj.__module__]) 59 | except Exception: 60 | fn = None 61 | if not fn: 62 | return 63 | 64 | fn = os.path.relpath(fn, start=os.path.dirname(__import__(package).__file__)) 65 | try: 66 | lineno = inspect.getsourcelines(obj)[1] 67 | except Exception: 68 | lineno = "" 69 | return url_fmt.format(revision=revision, package=package, path=fn, lineno=lineno) 70 | 71 | 72 | def make_linkcode_resolve(package, url_fmt): 73 | """Returns a linkcode_resolve function for the given URL format 74 | revision is a git commit reference (hash or name) 75 | package is the name of the root module of the package 76 | url_fmt is along the lines of ('https://github.com/USER/PROJECT/' 77 | 'blob/{revision}/{package}/' 78 | '{path}#L{lineno}') 79 | """ 80 | revision = _get_git_revision() 81 | return partial( 82 | _linkcode_resolve, revision=revision, package=package, url_fmt=url_fmt 83 | ) 84 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | celer 2 | ===== 3 | 4 | 5 | A fast solver for Lasso-like problems 6 | ------------------------------------- 7 | 8 | ``celer`` is a Python package that solves Lasso-like problems and provides estimators 9 | that follow the ``scikit-learn`` API. Thanks to a tailored implementation, 10 | ``celer`` provides a fast solver that tackles large-scale datasets with millions of features 11 | **up to 100 times faster than** ``scikit-learn``. 
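12 | 
13 | To give a first flavor of the API, here is a minimal sketch of fitting a
14 | ``celer`` Lasso model on random data (the problem size and the ``alpha``
15 | value below are arbitrary choices):
16 | 
17 | .. code-block:: python
18 | 
19 |     import numpy as np
20 |     from celer import Lasso
21 | 
22 |     rng = np.random.RandomState(0)
23 |     X, y = rng.randn(100, 200), rng.randn(100)
24 | 
25 |     model = Lasso(alpha=0.01).fit(X, y)
26 |     print((model.coef_ != 0).sum(), "nonzero coefficients")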
27 | 
28 | Currently, the package handles the following problems:
29 | 
30 | .. list-table:: The supported lasso-like problems
31 |    :header-rows: 1
32 | 
33 |    * - Problem
34 |      - Support of weights
35 |      - Native cross-validation
36 |    * - Lasso
37 |      - ✓
38 |      - ✓
39 |    * - ElasticNet
40 |      - ✓
41 |      - ✓
42 |    * - Group Lasso
43 |      - ✓
44 |      - ✓
45 |    * - Multitask Lasso
46 |      - ✕
47 |      - ✓
48 |    * - Sparse Logistic regression
49 |      - ✕
50 |      - ✕
51 | 
52 | 
53 | Why ``celer``?
54 | --------------
55 | 
56 | ``celer`` is specially designed to handle Lasso-like problems, which enables it to solve them quickly.
57 | In particular, ``celer`` comes with
58 | 
59 | - automated parallel cross-validation
60 | - support of sparse and dense data
61 | - optional feature centering and normalization
62 | - unpenalized intercept fitting
63 | 
64 | ``celer`` also provides easy-to-use estimators, as it is designed following the ``scikit-learn`` API.
65 | 
66 | 
67 | Install ``celer``
68 | -----------------
69 | 
70 | ``celer`` can be easily installed through the Python package manager ``pip``.
71 | To get the latest version of the package, run::
72 | 
73 |     $ pip install -U celer
74 | 
75 | Head directly to the :ref:`Get started` page to get a hands-on example of how to use ``celer``.
76 | 
77 | 
78 | 
79 | Cite
80 | ----
81 | 
82 | ``celer`` is an open source package licensed under
83 | the `BSD 3-Clause License <https://opensource.org/licenses/BSD-3-Clause>`_.
84 | Hence, you are free to use it. If you do so, please do not forget to cite:
85 | 
86 | 
87 | .. code-block:: bibtex
88 | 
89 |     @InProceedings{pmlr-v80-massias18a,
90 |       title = {Celer: a Fast Solver for the Lasso with Dual Extrapolation},
91 |       author = {Massias, Mathurin and Gramfort, Alexandre and Salmon, Joseph},
92 |       booktitle = {Proceedings of the 35th International Conference on Machine Learning},
93 |       pages = {3321--3330},
94 |       year = {2018},
95 |       volume = {80},
96 |     }
97 | 
98 | 
99 | .. code-block:: bibtex
100 | 
101 |     @article{massias2020dual,
102 |       author = {Mathurin Massias and Samuel Vaiter and Alexandre Gramfort and Joseph Salmon},
103 |       title = {Dual Extrapolation for Sparse GLMs},
104 |       journal = {Journal of Machine Learning Research},
105 |       year = {2020},
106 |       volume = {21},
107 |       number = {234},
108 |       pages = {1-33},
109 |       url = {http://jmlr.org/papers/v21/19-587.html}
110 |     }
111 | 
112 | ``celer`` is the outcome of sustained research. Here are the links to the original papers:
113 | 
114 | - `Celer: a Fast Solver for the Lasso with Dual Extrapolation <http://proceedings.mlr.press/v80/massias18a.html>`_
115 | - `Dual Extrapolation for Sparse GLMs <http://jmlr.org/papers/v21/19-587.html>`_
116 | 
117 | 
118 | 
119 | Explore the documentation
120 | -------------------------
121 | 
122 | .. toctree::
123 |     :maxdepth: 1
124 | 
125 |     get_started.rst
126 |     api.rst
127 |     contribute.rst
128 |     auto_examples/index.rst
129 | 
--------------------------------------------------------------------------------
/examples/README.txt:
--------------------------------------------------------------------------------
1 | .. _general_examples:
2 | 
3 | Examples Gallery
4 | ================
5 | 
6 | .. contents:: Contents
7 |    :local:
8 |    :depth: 3
--------------------------------------------------------------------------------
/examples/plot_finance_path.py:
--------------------------------------------------------------------------------
1 | """
2 | =======================================================
3 | Lasso path computation on Finance/log1p dataset
4 | =======================================================
5 | 
6 | The example runs the Celer algorithm on the Finance dataset,
7 | which is a large sparse dataset.
8 | 
9 | Running time is not compared with the scikit-learn
10 | implementation, as it would make the example too long to run.
11 | """
12 | 
13 | import time
14 | 
15 | import numpy as np
16 | import pandas as pd
17 | import matplotlib.pyplot as plt
18 | from libsvmdata import fetch_libsvm
19 | 
20 | from celer import celer_path
21 | 
22 | print(__doc__)
23 | 
24 | print("*** Warning: this example may take more than 5 minutes to run ***")
25 | X, y = fetch_libsvm('finance')
26 | y -= np.mean(y)
27 | n_samples, n_features = X.shape
28 | alpha_max = np.max(np.abs(X.T.dot(y))) / n_samples
29 | print("Dataset size: %d samples, %d features" % X.shape)
30 | 
31 | # construct grid of regularization parameters alpha
32 | n_alphas = 11
33 | alphas = alpha_max * np.geomspace(1, 0.1, n_alphas)
34 | 
35 | ###############################################################################
36 | # Run Celer on a grid of regularization parameters, for various tolerances:
37 | tols = [1e-2, 1e-4, 1e-6]
38 | results = np.zeros([1, len(tols)])
39 | gaps = np.zeros((len(tols), len(alphas)))
40 | 
41 | print("Starting path computation...")
42 | for tol_ix, tol in enumerate(tols):
43 |     t0 = time.time()
44 |     res = celer_path(X, y, 'lasso', alphas=alphas,
45 |                      tol=tol, prune=True, verbose=1)
46 |     results[0, tol_ix] = time.time() - t0
47 |     _, coefs, gaps[tol_ix] = res
48 | 
49 | 
50 | labels = [r"\sc{Celer}"]
51 | figsize = (4, 3.5)
52 | 
53 | df = pd.DataFrame(results.T, columns=["Celer"])
54 | df.index = [str(tol) for tol in tols]
55 | df.plot.bar(rot=0, figsize=figsize)
56 | plt.xlabel("stopping tolerance")
57 | plt.ylabel("path computation time (s)")
58 | plt.tight_layout()
59 | plt.show(block=False)
60 | 
61 | ###############################################################################
62 | # Measure the influence of regularization on the sparsity of the solutions:
63 | 
64 | fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)
65 | plt.bar(np.arange(n_alphas), (coefs != 0).sum(axis=0))
66 | plt.title("Sparsity of solution along regularization path")
67 | ax.set_ylabel(r"$||\hat w||_0$")
68 | ax.set_xlabel(r"$\lambda / \lambda_{\mathrm{max}}$")
69 | ax.set_yscale('log')
70 | ax.set_xticks(np.arange(n_alphas)[::2])
71 | ax.set_xticklabels(map(lambda x: "%.2f" % x, alphas[::2] / alphas[0]))
72 | plt.show(block=False)
73 | 
74 | 
75 | ###############################################################################
76 | # Check convergence guarantees: the duality gap is below the tolerance
77 | 
78 | df = pd.DataFrame(gaps.T, columns=map(lambda x: r"tol=%.0e" % x, tols))
79 | df.index = map(lambda x: "%.2f" % x, alphas / alphas[0])
80 | ax = df.plot.bar(figsize=(7, 4))
81 | ax.set_ylabel("duality gap reached")
82 | ax.set_xlabel(r"$\lambda / \lambda_{\mathrm{max}}$")
83 | ax.set_yscale('log')
84 | ax.set_yticks(tols)
85 | plt.tight_layout()
86 | plt.show(block=False)
--------------------------------------------------------------------------------
/examples/plot_group_lasso.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================================================
3 | Run GroupLasso and GroupLasso CV for structured sparse recovery
4 | ===============================================================
5 | 
6 | The example runs the GroupLasso scikit-learn like estimators.
7 | """ 8 | 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from celer import GroupLassoCV, LassoCV 13 | from celer.datasets import make_correlated_data 14 | 15 | print(__doc__) 16 | 17 | # Generating X, y, and true regression coefs with 4 groups of 5 non-zero values 18 | 19 | n_samples, n_features = 100, 50 20 | 21 | w_true = np.zeros(n_features) 22 | w_true[:5] = 1 23 | w_true[10:15] = 1 24 | w_true[30:35] = -1 25 | w_true[45:] = 1 26 | X, y, w_true = make_correlated_data( 27 | n_samples, n_features, w_true=w_true, snr=5, random_state=0) 28 | 29 | ############################################################################### 30 | # Get group Lasso's optimal alpha for prediction by cross validation 31 | 32 | groups = 5 # groups are contiguous and of size 5 33 | # irregular groups are also supported, 34 | group_lasso = GroupLassoCV(groups=groups) 35 | group_lasso.fit(X, y) 36 | 37 | print("Estimated regularization parameter alpha: %s" % group_lasso.alpha_) 38 | 39 | fig = plt.figure(figsize=(6, 3), constrained_layout=True) 40 | plt.semilogx(group_lasso.alphas_, group_lasso.mse_path_, ':') 41 | plt.semilogx(group_lasso.alphas_, group_lasso.mse_path_.mean(axis=-1), 'k', 42 | label='Average across the folds', linewidth=2) 43 | plt.axvline(group_lasso.alpha_, linestyle='--', color='k', 44 | label='alpha: CV estimate') 45 | 46 | plt.legend() 47 | 48 | plt.xlabel(r'$\alpha$') 49 | plt.ylabel('Mean square prediction error') 50 | plt.show(block=False) 51 | 52 | 53 | lasso = LassoCV().fit(X, y) 54 | 55 | 56 | ############################################################################### 57 | # Show optimal regression vector for prediction, obtained by cross validation 58 | 59 | fig = plt.figure(figsize=(8, 3), constrained_layout=True) 60 | m, s, _ = plt.stem(np.where(w_true)[0], w_true[w_true != 0], 61 | label=r"true regression coefficients") 62 | labels = ["LassoCV-estimated regression coefficients", 63 | "GroupLassoCV-estimated regression coefficients"] 64 | colors = [u'#ff7f0e', u'#2ca02c'] 65 | 66 | for w, label, color in zip([lasso.coef_, group_lasso.coef_], labels, colors): 67 | m, s, _ = plt.stem(np.where(w)[0], w[w != 0], label=label, markerfmt='x') 68 | plt.setp([m, s], color=color) 69 | plt.xlabel("feature index") 70 | plt.legend(fontsize=12) 71 | plt.show(block=False) 72 | -------------------------------------------------------------------------------- /examples/plot_lasso_cv.py: -------------------------------------------------------------------------------- 1 | """ 2 | ======================================================== 3 | Run LassoCV for cross-validation on the Leukemia dataset 4 | ======================================================== 5 | 6 | The example runs the LassoCV scikit-learn like estimator 7 | using the Celer algorithm. 
8 | """ 9 | 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | 13 | from sklearn.datasets import fetch_openml 14 | from sklearn.model_selection import KFold 15 | 16 | from celer import LassoCV 17 | 18 | print(__doc__) 19 | 20 | print("Loading data...") 21 | dataset = fetch_openml("leukemia") 22 | X = np.asfortranarray(dataset.data.astype(float)) 23 | y = 2 * ((dataset.target == "AML") - 0.5) 24 | y -= np.mean(y) 25 | y /= np.std(y) 26 | 27 | kf = KFold(shuffle=True, n_splits=3, random_state=0) 28 | model = LassoCV(cv=kf, n_jobs=3) 29 | model.fit(X, y) 30 | 31 | print("Estimated regularization parameter alpha: %s" % model.alpha_) 32 | 33 | ############################################################################### 34 | # Display results 35 | 36 | plt.figure(figsize=(5, 3), constrained_layout=True) 37 | plt.semilogx(model.alphas_, model.mse_path_, ':') 38 | plt.semilogx(model.alphas_, model.mse_path_.mean(axis=-1), 'k', 39 | label='Average across the folds', linewidth=2) 40 | plt.axvline(model.alpha_, linestyle='--', color='k', 41 | label='alpha: CV estimate') 42 | 43 | plt.legend() 44 | 45 | plt.xlabel(r'$\alpha$') 46 | plt.ylabel('Mean square prediction error') 47 | plt.show(block=False) 48 | -------------------------------------------------------------------------------- /examples/plot_leukemia_path.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================== 3 | Lasso path computation on Leukemia dataset 4 | ========================================== 5 | 6 | The example runs the Celer algorithm for the Lasso on the Leukemia 7 | dataset which is a dense dataset. 8 | 9 | Running time is compared with the scikit-learn implementation. 10 | """ 11 | 12 | import time 13 | import pandas as pd 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | 17 | from sklearn.linear_model import lasso_path 18 | from sklearn.datasets import fetch_openml 19 | 20 | from celer import celer_path 21 | 22 | print(__doc__) 23 | 24 | print("Loading data...") 25 | dataset = fetch_openml("leukemia") 26 | X = np.asfortranarray(dataset.data.astype(float)) 27 | y = 2 * ((dataset.target != "AML") - 0.5) 28 | n_samples = len(y) 29 | 30 | y -= np.mean(y) 31 | y /= np.std(y) 32 | 33 | print("Starting path computation...") 34 | alpha_max = np.max(np.abs(X.T.dot(y))) / n_samples 35 | 36 | n_alphas = 100 37 | alphas = alpha_max * np.geomspace(1, 0.01, n_alphas) 38 | 39 | tols = [1e-2, 1e-3, 1e-4] 40 | results = np.zeros([2, len(tols)]) 41 | for tol_ix, tol in enumerate(tols): 42 | t0 = time.time() 43 | _, coefs, gaps = celer_path( 44 | X, y, pb='lasso', alphas=alphas, tol=tol, prune=True) 45 | results[0, tol_ix] = time.time() - t0 46 | print('Celer time: %.2f s' % results[0, tol_ix]) 47 | 48 | t0 = time.time() 49 | _, coefs, dual_gaps = lasso_path( 50 | X, y, tol=tol, alphas=alphas, max_iter=10_000) 51 | results[1, tol_ix] = time.time() - t0 52 | 53 | df = pd.DataFrame(results.T, columns=["Celer", "scikit-learn"]) 54 | df.index = [str(tol) for tol in tols] 55 | df.plot.bar(rot=0) 56 | plt.xlabel("stopping tolerance") 57 | plt.ylabel("path computation time (s)") 58 | plt.tight_layout() 59 | plt.show() 60 | -------------------------------------------------------------------------------- /examples/plot_logreg_timings.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================================================== 3 | Compare LogisticRegression solver with 
sklearn's liblinear backend 4 | ================================================================== 5 | """ 6 | 7 | import time 8 | import warnings 9 | import numpy as np 10 | from numpy.linalg import norm 11 | import matplotlib.pyplot as plt 12 | from sklearn import linear_model 13 | from libsvmdata import fetch_libsvm 14 | 15 | from celer import LogisticRegression 16 | 17 | warnings.filterwarnings("ignore", message="Objective did not converge") 18 | warnings.filterwarnings("ignore", message="Liblinear failed to converge") 19 | 20 | X, y = fetch_libsvm("news20.binary") 21 | 22 | C_min = 2 / norm(X.T @ y, ord=np.inf) 23 | C = 20 * C_min 24 | 25 | 26 | def pobj_logreg(w): 27 | return np.sum(np.log(1 + np.exp(-y * (X @ w)))) + 1. / C * norm(w, ord=1) 28 | 29 | 30 | pobj_celer = [] 31 | t_celer = [] 32 | 33 | for n_iter in range(10): 34 | t0 = time.time() 35 | clf = LogisticRegression( 36 | C=C, solver="celer-pn", max_iter=n_iter, tol=0).fit(X, y) 37 | t_celer.append(time.time() - t0) 38 | w_celer = clf.coef_.ravel() 39 | pobj_celer.append(pobj_logreg(w_celer)) 40 | 41 | pobj_celer = np.array(pobj_celer) 42 | 43 | 44 | pobj_libl = [] 45 | t_libl = [] 46 | 47 | for n_iter in np.arange(0, 50, 10): 48 | t0 = time.time() 49 | clf = linear_model.LogisticRegression( 50 | C=C, solver="liblinear", penalty='l1', fit_intercept=False, 51 | max_iter=n_iter, random_state=0, tol=1e-10).fit(X, y) 52 | t_libl.append(time.time() - t0) 53 | w_libl = clf.coef_.ravel() 54 | pobj_libl.append(pobj_logreg(w_libl)) 55 | 56 | pobj_libl = np.array(pobj_libl) 57 | 58 | p_star = min(pobj_celer.min(), pobj_libl.min()) 59 | 60 | plt.close("all") 61 | fig = plt.figure(figsize=(4, 2), constrained_layout=True) 62 | plt.semilogy(t_celer, pobj_celer - p_star, label="Celer-PN") 63 | plt.semilogy(t_libl, pobj_libl - p_star, label="liblinear") 64 | plt.legend() 65 | plt.xlabel("Time (s)") 66 | plt.ylabel("objective suboptimality") 67 | plt.show(block=False) 68 | -------------------------------------------------------------------------------- /examples/plot_multitask_lasso_cv.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================================================== 3 | Run MultitaskLassoCV and compare performance with scikit-learn 4 | ============================================================== 5 | 6 | The example runs the MultitaskLassoCV scikit-learn like estimator. 7 | """ 8 | import time 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from celer import MultiTaskLassoCV 13 | from numpy.linalg import norm 14 | from sklearn.utils import check_random_state 15 | from sklearn import linear_model 16 | 17 | rng = check_random_state(0) 18 | 19 | ############################################################################### 20 | # Generate some 2D coefficients with sine waves with random frequency and phase 21 | n_samples, n_features, n_tasks = 100, 500, 50 22 | n_relevant_features = 50 23 | support = rng.choice(n_features, n_relevant_features, replace=False) 24 | coef = np.zeros((n_tasks, n_features)) 25 | times = np.linspace(0, 2 * np.pi, n_tasks) 26 | for k in support: 27 | coef[:, k] = np.sin((1. 
+ rng.randn(1)) * times + 3 * rng.randn(1)) 28 | 29 | 30 | X = rng.randn(n_samples, n_features) 31 | Y = np.dot(X, coef.T) + rng.randn(n_samples, n_tasks) 32 | Y /= norm(Y, ord='fro') 33 | 34 | 35 | ############################################################################### 36 | # Fit with sklearn and celer, using the same API 37 | params = dict(tol=1e-6, cv=4, n_jobs=1, n_alphas=20) 38 | t0 = time.perf_counter() 39 | clf = MultiTaskLassoCV(**params).fit(X, Y) 40 | t_celer = time.perf_counter() - t0 41 | 42 | t0 = time.perf_counter() 43 | clf_sklearn = linear_model.MultiTaskLassoCV(**params).fit(X, Y) 44 | t_sklearn = time.perf_counter() - t0 45 | 46 | ############################################################################### 47 | # Celer is faster 48 | print("Time for celer : %.2f s" % t_celer) 49 | print("Time for sklearn: %.2f s" % t_sklearn) 50 | 51 | ############################################################################### 52 | # Both packages find the same solution 53 | print("Celer's optimal regularizer : %s" % clf.alpha_) 54 | print("Sklearn's optimal regularizer: %s" % clf_sklearn.alpha_) 55 | 56 | print("Relative norm difference between optimal coefs: %.2f %%" % 57 | (100 * norm(clf.coef_ - clf_sklearn.coef_) / norm(clf.coef_))) 58 | 59 | ############################################################################### 60 | fig, axarr = plt.subplots(2, 1, constrained_layout=True) 61 | axarr[0].spy(clf.coef_, aspect="auto") 62 | axarr[0].xaxis.tick_bottom() 63 | axarr[0].set_title("celer") 64 | axarr[0].set_ylabel("tasks") 65 | axarr[1].spy(clf_sklearn.coef_, aspect="auto") 66 | axarr[1].xaxis.tick_bottom() 67 | axarr[1].set_title("sklearn") 68 | plt.suptitle("Sparsity patterns") 69 | plt.ylabel("tasks") 70 | plt.xlabel("features") 71 | plt.show(block=False) 72 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "numpy>=1.12", "scipy>=0.18.0", "Cython>=0.26"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [project] 7 | name = "celer" 8 | dependencies =[ 9 | 'libsvmdata>=0.3', 'scikit-learn>=1.0', 'xarray', 'download', 'tqdm' 10 | ] 11 | description = "A fast algorithm with dual extrapolation for sparse problems" 12 | authors = [ 13 | {name = "Mathurin Massias", email = "mathurin.massias@gmail.com"}, 14 | ] 15 | readme = {file = "README.md", content-type = "text/markdown"} 16 | license = {text = "BSD (3-clause)"} 17 | dynamic = ["version"] 18 | 19 | 20 | [tool.setuptools.dynamic] 21 | version = {attr = "celer.__version__"} 22 | 23 | 24 | [project.urls] 25 | documentation = "https://mathurinm.github.io/celer/" 26 | repository = "https://github.com/mathurinm/celer.git" 27 | 28 | 29 | [project.optional-dependencies] 30 | doc = [ 31 | "numpydoc", "pandas", "pillow", "matplotlib", 32 | "furo", "sphinx-copybutton", "sphinx-gallery" 33 | ] 34 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = __init__.py, ./doc/* 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools.command.build_ext import build_ext 2 | from setuptools import setup, Extension, find_packages 3 | 4 | import numpy as np 5 | 6 | 7 | 
setup(packages=find_packages(),
8 |       cmdclass={'build_ext': build_ext},
9 |       # NumPy headers are needed to build the Cython solvers below,
10 |       # which are compiled as C++ extensions
11 |       ext_modules=[
12 |           Extension('celer.lasso_fast',
13 |                     sources=['celer/lasso_fast.pyx'],
14 |                     language='c++',
15 |                     include_dirs=[np.get_include()],
16 |                     extra_compile_args=["-O3"]),
17 |           Extension('celer.cython_utils',
18 |                     sources=['celer/cython_utils.pyx'],
19 |                     language='c++',
20 |                     include_dirs=[np.get_include()],
21 |                     extra_compile_args=["-O3"]),
22 |           Extension('celer.PN_logreg',
23 |                     sources=['celer/PN_logreg.pyx'],
24 |                     language='c++',
25 |                     include_dirs=[np.get_include()],
26 |                     extra_compile_args=["-O3"]),
27 |           Extension('celer.multitask_fast',
28 |                     sources=['celer/multitask_fast.pyx'],
29 |                     language='c++',
30 |                     include_dirs=[np.get_include()],
31 |                     extra_compile_args=["-O3"]),
32 |           Extension('celer.group_fast',
33 |                     sources=['celer/group_fast.pyx'],
34 |                     language='c++',
35 |                     include_dirs=[np.get_include()],
36 |                     extra_compile_args=["-O3"]),
37 |       ],
38 |       )
--------------------------------------------------------------------------------
/whats_new.rst:
--------------------------------------------------------------------------------
1 | Version 0.6.1
2 | -------------
3 | 
4 | Changelog
5 | ~~~~~~~~~
6 | - Rely on the libsvmdata package to download datasets from LIBSVM.
7 | 
8 | 
9 | Version 0.6
10 | -----------
11 | 
12 | Changelog
13 | ~~~~~~~~~
14 | - Added `weights` to the Lasso estimator.
15 | - Added `make_correlated_data` to the `datasets` module, to generate simulations with Toeplitz-correlated design.
16 | 
17 | 
18 | Version 0.5
19 | -----------
20 | 
21 | Changelog
22 | ~~~~~~~~~
23 | - Structure of the `~/celer_data/` folder changed; consider deleting it and redownloading the datasets.
24 | - Added module `datasets`, supporting more datasets (climate data for the Sparse Group Lasso).
25 | - Removed the `celer_logreg` function; use `celer` instead with `pb="logreg"`.
26 | - Added sklearn-like `LogisticRegression` class.
27 | 
28 | Version 0.4
29 | -----------
30 | 
31 | Changelog
32 | ~~~~~~~~~
33 | - Faster homotopy by precomputing `norms_X_col` and passing residuals from one alpha to the next.
34 | 
35 | 
36 | Version 0.3.1
37 | -------------
38 | 
39 | Changelog
40 | ~~~~~~~~~
41 | - Fixed bugs in screening.
--------------------------------------------------------------------------------