├── tests
│   ├── __init__.py
│   ├── utils.py
│   ├── test_experimental.py
│   ├── test_linear_model_ppd.py
│   ├── test_decomposition_dct.py
│   ├── test_decomposition_dict_learning.py
│   ├── test_preprocessing_whitening.py
│   ├── test_notebook.py
│   ├── test_linear_model_hmlasso.py
│   ├── test_feature_extraction_image.py
│   ├── test_decomposition_ksvd.py
│   └── test_linear_model_admm.py
├── MANIFEST.in
├── spmimage
│   ├── __init__.py
│   ├── feature_extraction
│   │   ├── __init__.py
│   │   └── image.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   └── data.py
│   ├── experimental
│   │   ├── __init__.py
│   │   └── enable_ppd.py
│   ├── linear_model
│   │   ├── __init__.py
│   │   ├── _ppd.py
│   │   ├── admm.py
│   │   └── hmlasso.py
│   └── decomposition
│       ├── __init__.py
│       ├── dct.py
│       ├── dict_learning.py
│       └── ksvd.py
├── requirements.txt
├── examples_requirements.txt
├── .gitignore
├── .github
│   ├── pull_request_template.md
│   ├── issue_template.md
│   └── workflows
│       └── unittest.yml
├── setup.py
├── CHANGELOG.md
├── README.md
└── examples
    └── group-lasso.ipynb

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include *.txt

--------------------------------------------------------------------------------
/spmimage/__init__.py:
--------------------------------------------------------------------------------
__version__ = '0.0.11'

--------------------------------------------------------------------------------
/spmimage/feature_extraction/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
scipy
scikit-learn>=0.24.0
joblib

--------------------------------------------------------------------------------
/spmimage/preprocessing/__init__.py:
--------------------------------------------------------------------------------
from .data import WhiteningScaler

__all__ = ['WhiteningScaler']

--------------------------------------------------------------------------------
/examples_requirements.txt:
--------------------------------------------------------------------------------
-r requirements.txt
jupyter
jupyter_contrib_nbextensions
scikit-image >= 0.18
seaborn

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
**/*.pyc
**/*.egg-info
/venv
/examples_venv
/build
/dist
/.idea
**/.ipynb_checkpoints
/.python-version

--------------------------------------------------------------------------------
/spmimage/experimental/__init__.py:
--------------------------------------------------------------------------------
"""
The :mod:`spmimage.experimental` module provides importable modules that enable
the use of experimental features or estimators.
The features and estimators that are experimental aren't subject to
deprecation cycles. Use them at your own risk!
"""
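An editor's note on usage: the experimental flag above follows scikit-learn's `enable_*` pattern, registering `LassoPPD` in `spmimage.linear_model` as an import side effect. A minimal sketch, based on the calls exercised in the test suite (the signal here is random, and the parameter values are arbitrary):

```python
import numpy as np
from spmimage.experimental import enable_ppd  # noqa: F401 -- registers LassoPPD
from spmimage.linear_model import LassoPPD

# LassoPPD ignores X and denoises y directly; `params` defines the penalty matrix D.
y = np.random.rand(100)
model = LassoPPD(alpha=1.0, max_iter=100, params=[1, -1])
model.fit(None, y)
print(model.coef_.shape)  # smoothed signal, same length as y
```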
--------------------------------------------------------------------------------
/spmimage/linear_model/__init__.py:
--------------------------------------------------------------------------------
from .admm import LassoADMM, FusedLassoADMM, TrendFilteringADMM, QuadraticTrendFilteringADMM
from .hmlasso import HMLasso

__all__ = [
    'LassoADMM',
    'FusedLassoADMM',
    'TrendFilteringADMM',
    'QuadraticTrendFilteringADMM',
    'HMLasso',
]

--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
# Summary

* Write a brief summary of the changes in the code
* You can include details of how the code was changed
* Include all information regarding the changes

# Related Issues/Wiki/Resources

* List related issues
* You can also link related wiki pages or external resources

--------------------------------------------------------------------------------
/spmimage/decomposition/__init__.py:
--------------------------------------------------------------------------------
from .dct import generate_dct_dictionary
from .dict_learning import sparse_encode_with_mask
from .dict_learning import sparse_encode_with_l21_norm
from .ksvd import KSVD

__all__ = [
    'KSVD',
    'sparse_encode_with_mask',
    'sparse_encode_with_l21_norm',
    'generate_dct_dictionary'
]

--------------------------------------------------------------------------------
/.github/issue_template.md:
--------------------------------------------------------------------------------
# Issue Description

Describe the issue here. You can make a list as follows to clarify what should be done in this issue.

* [ ] Put up a simple task, like
* [ ] fix bugs in spmimage.xxx.yyy
* [ ] You can also use this to include deliverables and other requirements

# References

* [Put link to external resources related to this issue](https://github.com/hacarus/spm-lib)

--------------------------------------------------------------------------------
/spmimage/experimental/enable_ppd.py:
--------------------------------------------------------------------------------
"""Enables LassoPPD
The API and results of this estimator might change without any deprecation cycle.
Importing this file dynamically sets :class:`spmimage.linear_model.LassoPPD`
as an attribute of the linear_model module::
    >>> # explicitly require this experimental feature
    >>> from spmimage.experimental import enable_ppd  # noqa
    >>> # now you can import normally from linear_model
    >>> from spmimage.linear_model import LassoPPD
"""

from ..linear_model._ppd import LassoPPD
from .. import linear_model

linear_model.LassoPPD = LassoPPD
linear_model.__all__ += ['LassoPPD']
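For orientation: as I read `LassoPPD._Du`, the `params` argument encodes the penalty matrix D as a sliding filter over the signal, so `[1, -1]` penalizes first differences (fused lasso) and `[1, -2, 1]` second differences (trend filtering). A small numpy sketch of that interpretation, without touching the private API:

```python
import numpy as np

u = np.arange(10.0)
params = np.array([1.0, -1.0])
# D @ u as a sliding correlation with the filter `params`
Du = np.convolve(u, params[::-1], mode='valid')  # == u[:-1] - u[1:]
assert np.allclose(Du, -np.diff(u))  # constant signals map to zero, as in test_Du
```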
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
from typing import Tuple

import numpy as np


def generate_dictionary_and_samples(n_samples: int, n_features: int, n_components: int, n_nonzero_coefs: int) \
        -> Tuple[np.ndarray, np.ndarray]:
    # random dictionary base
    A0 = np.random.randn(n_components, n_features)
    A0 = np.dot(A0, np.diag(1. / np.sqrt(np.diag(np.dot(A0.T, A0)))))

    X = np.zeros((n_samples, n_features))
    for i in range(n_samples):
        # select n_nonzero_coefs components from dictionary
        X[i, :] = np.dot(np.random.randn(n_nonzero_coefs),
                         A0[np.random.permutation(range(n_components))[:n_nonzero_coefs], :])
    return A0, X

--------------------------------------------------------------------------------
/tests/test_experimental.py:
--------------------------------------------------------------------------------
import unittest

import spmimage.linear_model
import importlib


class TestExperimental(unittest.TestCase):

    def test_ppd_disabled(self):
        # if this method fails, check whether LassoPPD is imported at module level in other test cases
        importlib.reload(spmimage.linear_model)
        with self.assertRaises(ImportError) as ctx:
            from spmimage.linear_model import LassoPPD
        self.assertIsNotNone(ctx.exception)

    def test_ppd_enabled(self):
        importlib.reload(spmimage.linear_model)
        from spmimage.experimental import enable_ppd
        from spmimage.linear_model import LassoPPD
        self.assertIsNotNone(LassoPPD)


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages
from pathlib import Path
import sys

from spmimage import __version__ as version

LICENSE = 'Modified BSD'

if sys.version_info[:2] < (3, 5):
    raise Exception('spm-image requires Python 3.5 or later. \n Now running on {0}'.format(sys.version))

with Path('requirements.txt').open() as f:
    INSTALL_REQUIRES = [line.strip() for line in f if line.strip()]

setup(
    name='spm-image',
    author='Takashi Someda',
    author_email='takashi@hacarus.com',
    url='https://github.com/hacarus/spm-image',
    description='Sparse modeling and Compressive sensing in Python',
    version=version,
    packages=find_packages(),
    install_requires=INSTALL_REQUIRES,
    test_suite='tests',
    license=LICENSE
)

--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
# Change Log

## v0.0.11

* Implement notebook unittest ([#115](https://github.com/hacarus/spm-image/issues/115))
* Prepare change log ([#104](https://github.com/hacarus/spm-image/issues/104))
* Implement HMLasso ([#73](https://github.com/hacarus/spm-image/issues/73))
* Fix ksvd inpainting example ([#76](https://github.com/hacarus/spm-image/issues/76))
* Implement DCT dictionary generator ([#79](https://github.com/hacarus/spm-image/issues/79))
* Implement Dictionary initialization for KSVD ([#80](https://github.com/hacarus/spm-image/issues/80))
* Fix masked ksvd bug ([#83](https://github.com/hacarus/spm-image/issues/83))
* Implement L2,1 norm constrained regression ([#88](https://github.com/hacarus/spm-image/issues/88))
* Support `scikit-learn >= 0.24.0` ([#98](https://github.com/hacarus/spm-image/pull/98))
* Fix fused lasso ([#99](https://github.com/hacarus/spm-image/issues/99))

## v0.0.1 ~ v0.0.10

* available for `scikit-learn <= 0.23`
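As a quick illustration of two of the v0.0.11 additions above, here is a hedged sketch of `generate_dct_dictionary` and `HMLasso` on data with missing entries; the data is random and the hyperparameters are arbitrary, not recommendations:

```python
import numpy as np
from spmimage.decomposition import generate_dct_dictionary
from spmimage.linear_model import HMLasso

D = generate_dct_dictionary(n_components=16, patch_size=5)
print(D.shape)  # (16, 25): each row is a flattened 5x5 DCT atom

X = np.random.randn(50, 100)
X[np.random.rand(*X.shape) > 0.99] = np.nan  # HMLasso tolerates missing entries
y = np.random.randn(50)
model = HMLasso(alpha=0.1).fit(X, y)
print(model.coef_.shape)  # (100,)
```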
--------------------------------------------------------------------------------
/.github/workflows/unittest.yml:
--------------------------------------------------------------------------------
on:
  push:
    branches: [main, master, development]
  pull_request:
    branches: [main, master, development]

name: Test

jobs:
  build:
    runs-on: ${{ matrix.config.os }}
    strategy:
      fail-fast: false
      matrix:
        config:
          - {os: ubuntu-latest, pip: ~/.cache/pip}
          - {os: macos-latest, pip: ~/Library/Caches/pip}
          - {os: windows-latest, pip: ~\AppData\Local\pip\Cache}
        python: [ '3.8', '3.9' ]
    name: "${{ matrix.config.os }} Python ${{ matrix.python }}"
    steps:
      - uses: actions/checkout@v2
      - name: Setup python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Cache
        uses: actions/cache@v2
        with:
          path: ${{ matrix.config.pip }}
          key: ${{ runner.os }}-pip-${{ hashFiles('**/examples_requirements.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r examples_requirements.txt
      - name: Test
        run: python -m unittest discover

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# spm-image : Sparse modeling and Compressive sensing in Python [![GitHubActions](https://github.com/hacarus/spm-image/actions/workflows/unittest.yml/badge.svg)](https://github.com/hacarus/spm-image/actions)

spm-image is a Python library for image analysis using sparse modeling and compressive sensing.

## Requirements

* Python 3.5 or later

## Install

    pip install spm-image

## For developers

To set up a development environment, run the following commands.

```
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

### Examples

If you want to run the examples, create a separate venv from the development one above.

```
python -m venv examples_venv
source examples_venv/bin/activate
pip install -r examples_requirements.txt
```

Then register it as a Jupyter kernel like this.

```
python -m ipykernel install --user --name spm-image-examples --display-name "spm-image Examples"
```

Thereafter, you can run Jupyter Notebook as follows.

```
jupyter notebook
```

### Testing

You can run all test cases like this

```
python -m unittest discover
```

Or run a specific test case as follows

```
python -m unittest tests.test_decomposition_ksvd.TestKSVD
```
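To complement the README, a minimal end-to-end sketch of the intended patch-based workflow: extract non-overlapping patches, learn a dictionary with KSVD, and reconstruct. The image, sizes, and hyperparameters below are placeholders, not recommendations:

```python
import numpy as np
from spmimage.decomposition import KSVD
from spmimage.feature_extraction.image import (
    extract_simple_patches_2d, reconstruct_from_simple_patches_2d)

img = np.random.rand(64, 64)
patches = extract_simple_patches_2d(img, (8, 8))   # (64, 8, 8), non-overlapping
X = patches.reshape(len(patches), -1)              # one flattened patch per row
model = KSVD(n_components=32, transform_n_nonzero_coefs=4, max_iter=5,
             method='approximate')
model.fit(X)
code = model.transform(X)                          # sparse codes, one row per patch
approx = code.dot(model.components_).reshape(-1, 8, 8)
recon = reconstruct_from_simple_patches_2d(approx, (64, 64))
```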
--------------------------------------------------------------------------------
/tests/test_linear_model_ppd.py:
--------------------------------------------------------------------------------
import unittest

import numpy as np
from numpy.testing import assert_array_almost_equal


class TestLassoPPD(unittest.TestCase):
    def setUp(self) -> None:
        # NOTE: we should import LassoPPD here, otherwise `test_experimental.py` will fail.
        from spmimage.experimental import enable_ppd
        from spmimage.linear_model import LassoPPD

        self.lasso = LassoPPD(alpha=1, max_iter=10, params=[1])
        self.fused = LassoPPD(alpha=1, max_iter=10, params=[1, -1])
        self.trend = LassoPPD(alpha=1, max_iter=10, params=[1, -2, 1])

    def test_Du(self):
        u = np.ones(10)
        assert_array_almost_equal(u, self.lasso._Du(u))
        assert_array_almost_equal(np.zeros(9), self.fused._Du(u))
        assert_array_almost_equal(np.zeros(8), self.trend._Du(u))

    def test_DTv(self):
        v = np.ones(10)
        v1 = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1])
        v2 = np.array([1, -1, 0, 0, 0, 0, 0, 0, 0, 0, -1, 1])
        assert_array_almost_equal(v, self.lasso._DTv(v))
        assert_array_almost_equal(v1, self.fused._DTv(v))
        assert_array_almost_equal(v2, self.trend._DTv(v))

    def test_lasso_ppd_fit(self):
        # Check that the LassoPPD variants can fit data without crashing
        y = np.random.rand(100)
        self.lasso.fit(None, y)
        self.fused.fit(None, y)
        self.trend.fit(None, y)


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/tests/test_decomposition_dct.py:
--------------------------------------------------------------------------------
import unittest

import numpy as np
import numpy.testing as npt

from spmimage.decomposition import generate_dct_dictionary
from spmimage.decomposition.dct import zig_zag_index


class TestDCT(unittest.TestCase):
    def setUp(self):
        np.random.seed(0)

    def test_zig_zag_index(self):
        n = 5
        true_matrix = np.array([[0., 1., 5., 6., 14.],
                                [2., 4., 7., 13., 15.],
                                [3., 8., 12., 16., 21.],
                                [9., 11., 17., 20., 22.],
                                [10., 18., 19., 23., 24.]])

        M = np.empty((n, n))
        for k in range(n * n):
            M[zig_zag_index(k, n)] = k

        npt.assert_array_equal(M, true_matrix)

    def test_dct_complete(self):
        n_components = 4
        patch_size = 2
        D = generate_dct_dictionary(n_components, patch_size)
        D22 = np.array([
            [1.0, 1.0, 1.0, 1.0],
            [0.707, -0.707, 0.707, -0.707],
            [0.707, 0.707, -0.707, -0.707],
            [0.5, -0.5, -0.5, 0.5]
        ])
        npt.assert_array_almost_equal(D, D22, 3)

    def test_dct_less(self):
        patch_size = 5

        n_components = 4
        D = generate_dct_dictionary(n_components, patch_size)
        npt.assert_array_equal((n_components, patch_size ** 2), D.shape)

        n_components = 15
        D = generate_dct_dictionary(n_components, patch_size)
        npt.assert_array_equal((n_components, patch_size ** 2), D.shape)

        n_components = 24
        D = generate_dct_dictionary(n_components, patch_size)
        npt.assert_array_equal((n_components, patch_size ** 2), D.shape)


if __name__ == '__main__':
    unittest.main()
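For intuition about the zig-zag scan verified above: it enumerates matrix cells along anti-diagonals, which is how `generate_dct_dictionary` orders atoms when `n_components < patch_size ** 2`. A tiny sketch; the expected output is inferred from the 5x5 matrix in the test:

```python
from spmimage.decomposition.dct import zig_zag_index

# first few positions of the zig-zag scan on a 3x3 grid
print([zig_zag_index(k, 3) for k in range(4)])
# expected: [(0, 0), (0, 1), (1, 0), (2, 0)]
```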
--------------------------------------------------------------------------------
/spmimage/feature_extraction/image.py:
--------------------------------------------------------------------------------
from logging import getLogger

import numpy as np
from itertools import product

from typing import Tuple

__all__ = [
    'extract_simple_patches_2d',
    'reconstruct_from_simple_patches_2d',
]

logger = getLogger(__name__)


def extract_simple_patches_2d(image: np.ndarray, patch_size: Tuple[int, int]) -> np.ndarray:
    """Reshape a 2D image into a collection of non-overlapping patches.
    """

    i_h, i_w = image.shape[:2]
    p_h, p_w = patch_size

    if i_h % p_h != 0 or i_w % p_w != 0:
        logger.warning(
            'image size %s is not divisible by patch size %s; the remainder will be dropped', image.shape[:2], patch_size)

    image = image.reshape((i_h, i_w, -1))
    n_colors = image.shape[-1]

    patches = []

    n_h = i_h // p_h
    n_w = i_w // p_w

    for i in range(n_h):
        for j in range(n_w):
            patch = image[p_h * i:p_h * i + p_h, p_w * j:p_w * j + p_w]
            patches.append(patch.flatten())

    n_patches = len(patches)

    patches_ret = np.asarray(patches).flatten().reshape(-1, p_h, p_w, n_colors)
    if patches_ret.shape[-1] == 1:
        return patches_ret.reshape((n_patches, p_h, p_w))
    else:
        return patches_ret


def reconstruct_from_simple_patches_2d(patches: np.ndarray, image_size: Tuple[int, int]) -> np.ndarray:
    """Reconstruct the image from all of its patches.
    """
    i_h, i_w = image_size[:2]
    p_h, p_w = patches.shape[1:3]
    image = np.zeros(image_size, dtype=patches.dtype)

    n_h = i_h // p_h
    n_w = i_w // p_w
    for p, (i, j) in zip(patches, product(range(n_h), range(n_w))):
        image[p_h * i:p_h * i + p_h, p_w * j:p_w * j + p_w] += p
    return image

--------------------------------------------------------------------------------
/tests/test_decomposition_dict_learning.py:
--------------------------------------------------------------------------------
import unittest

from sklearn.decomposition import sparse_encode
from spmimage.decomposition import sparse_encode_with_mask, sparse_encode_with_l21_norm

import numpy as np

from tests.utils import generate_dictionary_and_samples


class TestDictLearning(unittest.TestCase):
    def setUp(self):
        np.random.seed(0)

    def test_sparse_encode_with_no_mask(self):
        k0 = 3
        n_samples = 64
        n_features = 32
        n_components = 10

        A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, k0)
        mask = np.ones(X.shape)

        W1 = sparse_encode(X, A0, algorithm='omp', n_nonzero_coefs=k0)
        W2 = sparse_encode_with_mask(X, A0, mask, algorithm='omp', n_nonzero_coefs=k0)

        # check that W1 and W2 are almost the same
        self.assertTrue(abs(np.linalg.norm(X - W1.dot(A0), 'fro') - np.linalg.norm(X - W2.dot(A0), 'fro')) < 1e-8)

    def test_sparse_encode_with_mask(self):
        k0 = 5
        n_samples = 128
        n_features = 64
        n_components = 32

        A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, k0)
        mask = np.random.rand(X.shape[0], X.shape[1])
        mask = np.where(mask < 0.8, 1, 0)

        W = sparse_encode_with_mask(X, A0, mask, algorithm='omp', n_nonzero_coefs=k0)

        # check error of learning
        # print(np.linalg.norm(mask*(X-W.dot(A0)), 'fro'))
        self.assertTrue(np.linalg.norm(mask * (X - W.dot(A0)), 'fro') < 50)

    def test_sparse_encode_with_l21_norm(self):
        k0 = 5
        n_samples = 128
        n_features = 64
        n_components = 32

        A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, k0)

        W = sparse_encode_with_l21_norm(X, A0)

        # check error of learning
        self.assertTrue(np.linalg.norm(X - W.dot(A0), 'fro') < 50)


if __name__ == '__main__':
    unittest.main()
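The next file implements `WhiteningScaler`. A hedged usage sketch mirroring its tests; the data is random, and the `eps`/`thresholding` values follow those used there:

```python
import numpy as np
from spmimage.preprocessing import WhiteningScaler

X = np.random.randn(100, 5)
scaler = WhiteningScaler(eps=1e-6, thresholding='normalize')
Xw = scaler.fit_transform(X)
# whitened features are decorrelated with (roughly) unit variance
print(np.round(np.cov(Xw.T), 2))
X_back = scaler.inverse_transform(Xw)
print(np.allclose(X, X_back))  # -> True (round trip recovers the input)
```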
-------------------------------------------------------------------------------- /spmimage/preprocessing/data.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | import numpy as np 4 | 5 | from scipy import sparse 6 | from sklearn.base import BaseEstimator, TransformerMixin 7 | from sklearn.utils import check_array 8 | 9 | logger = getLogger(__name__) 10 | 11 | 12 | class WhiteningScaler(BaseEstimator, TransformerMixin): 13 | def __init__(self, copy=True, eps=1e-9, thresholding=None, unbiased=True, apply_zca=False): 14 | self.copy = copy 15 | self.eps = eps 16 | self.thresholding = thresholding 17 | self.unbiased = unbiased 18 | self.apply_zca = apply_zca 19 | 20 | def _fit(self, X): 21 | if sparse.issparse(X): 22 | raise ValueError(""" 23 | WhiteningScaler does not support sparse input. See TruncatedSVD for a possible alternative. 24 | """) 25 | 26 | if self.eps <= 0 and self.thresholding in ['normalize', 'drop_minute']: 27 | raise ValueError('Threshold eps must be positive: eps={0}.'.format(self.eps)) 28 | 29 | X = check_array(X, dtype=[np.float64, np.float32], ensure_2d=True, copy=self.copy) 30 | 31 | # Center data 32 | mean = np.mean(X, axis=0) 33 | X -= mean 34 | 35 | # SVD 36 | n_samples = X.shape[0] - 1 if self.unbiased else X.shape[0] 37 | _, s, V = np.linalg.svd(X / np.sqrt(n_samples), full_matrices=False) 38 | 39 | if self.thresholding is None: 40 | if np.any(np.isclose(s, np.zeros(s.shape), atol=1e-10)): 41 | raise ValueError(""" 42 | Eigenvalues of X' are degenerated: X'=X-np.mean(X,axis=0), \ 43 | try thresholding='normalize' or thresholding='drop_minute'. 44 | """) 45 | elif self.thresholding == 'normalize': 46 | s += self.eps 47 | elif self.thresholding == 'drop_minute': 48 | s = s[s > self.eps] 49 | V = V[:s.shape[0]] 50 | else: 51 | raise ValueError('No such parameter: thresholding={0}.'.format(self.thresholding)) 52 | 53 | return mean, s, V 54 | 55 | def fit(self, X): 56 | self.mean_, self.var_, self.unitary_ = self._fit(X) 57 | return self 58 | 59 | def transform(self, X): 60 | # Decorrelation & Whitening 61 | S_inv = np.diag(np.ones(self.var_.shape[0]) / self.var_) 62 | X_transformed = (X - self.mean_).dot(self.unitary_.T.dot(S_inv)) 63 | 64 | # ZCA(Zero-phase Component Analysis) Whitening 65 | if self.apply_zca: 66 | return X_transformed.dot(self.unitary_) 67 | return X_transformed 68 | 69 | def inverse_transform(self, X): 70 | S = np.diag(self.var_) 71 | X_itransformed = np.copy(X) 72 | if self.apply_zca: 73 | X_itransformed = X_itransformed.dot(self.unitary_.T) 74 | return X_itransformed.dot(S.dot(self.unitary_)) + self.mean_ 75 | -------------------------------------------------------------------------------- /spmimage/decomposition/dct.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | 6 | 7 | def zig_zag_index(k: int, n: int) -> Tuple[int, int]: 8 | """ 9 | get k-th index i and j on (n, n)-matrix according to zig-zag scan. 
10 | 11 | Parameters: 12 | ----------- 13 | k : int 14 | a ranking of element, which we want to know the index i and j 15 | 16 | n : int 17 | a size of square matrix 18 | 19 | Returns: 20 | ----------- 21 | (i, j) : Tuple[int, int] 22 | the tuple which represents the height and width index of k-th elements 23 | 24 | Reference 25 | ---------- 26 | https://medium.com/100-days-of-algorithms/day-63-zig-zag-51a41127f31 27 | """ 28 | # upper side of interval 29 | if k >= n * (n + 1) // 2: 30 | i, j = zig_zag_index(n * n - 1 - k, n) 31 | return n - 1 - i, n - 1 - j 32 | 33 | # lower side of interval 34 | i = int((np.sqrt(1 + 8 * k) - 1) / 2) 35 | j = k - i * (i + 1) // 2 36 | return (j, i - j) if i & 1 else (i - j, j) 37 | 38 | 39 | def generate_dct_atom(u, v, n) -> np.ndarray: 40 | """ 41 | generate an (u, v)-th atom of DCT dictionary with size n by n. 42 | 43 | Parameters: 44 | ----------- 45 | u : int 46 | an index for height 47 | 48 | v : int 49 | an index for width 50 | 51 | n : int 52 | a size of DCT 53 | 54 | Returns: 55 | ----------- 56 | atom : np.ndarray 57 | (n, n) matrix which represents (u,v)-th atom of DCT dictionary 58 | """ 59 | atom = np.empty((n, n)) 60 | for i, j in itertools.product(range(n), range(n)): 61 | atom[i, j] = np.cos(((i+0.5)*u*np.pi)/n) * np.cos(((j+0.5)*v*np.pi)/n) 62 | return atom 63 | 64 | def generate_dct_dictionary(n_components: int, patch_size: int) -> np.ndarray: 65 | """generate_dct_dictionary 66 | Generate a DCT dictionary. 67 | An atom is a (patch_size, patch_size) image, and total number of atoms is 68 | n_components. 69 | 70 | The result D is a matrix whose shape is (n_components, patch_size ** 2). 71 | Note that, a row of the result D shows an atom (flatten). 72 | 73 | Parameters: 74 | ------------ 75 | n_components: int 76 | a number of atom, where n_components <= patch_size ** 2. 
77 | 78 | patch_size : int 79 | size of atom of DCT dictionary 80 | 81 | Returns: 82 | ------------ 83 | D : np.ndarray, shape (n_components, patch_size ** 2) 84 | DCT dictionary 85 | """ 86 | D = np.empty((n_components, patch_size ** 2)) 87 | 88 | if n_components > patch_size ** 2: 89 | raise ValueError("n_components must be smaller than patch_size ** 2") 90 | 91 | elif n_components == patch_size ** 2: 92 | for i, j in itertools.product(range(patch_size), range(patch_size)): 93 | D[i*patch_size + j] = generate_dct_atom(i, j, patch_size).flatten() 94 | else: 95 | for k in range(n_components): 96 | i, j = zig_zag_index(k, patch_size) 97 | D[k, :] = generate_dct_atom(i, j, patch_size).flatten() 98 | return D 99 | -------------------------------------------------------------------------------- /tests/test_preprocessing_whitening.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from spmimage.preprocessing import WhiteningScaler 5 | from numpy.testing import assert_array_almost_equal 6 | 7 | 8 | class TestWhitening(unittest.TestCase): 9 | def setUp(self): 10 | self.identity_matrix = np.identity(3, dtype=int) 11 | 12 | def test_fit_error(self): 13 | X = self.identity_matrix 14 | self.assertEqual(np.linalg.matrix_rank(X - np.mean(X, axis=0)), 2) 15 | 16 | model = WhiteningScaler() 17 | self.assertRaises(ValueError, model.fit, X) 18 | 19 | model = WhiteningScaler(eps=-1, thresholding='normalize') 20 | self.assertRaises(ValueError, model.fit, X) 21 | 22 | model = WhiteningScaler(eps=-1, thresholding='drop_minute') 23 | self.assertRaises(ValueError, model.fit, X) 24 | 25 | def test_fit_normalize(self): 26 | X = self.identity_matrix 27 | 28 | model = WhiteningScaler(eps=1e-6, thresholding='normalize') 29 | actual = model.fit_transform(X) 30 | expected = np.array([[-1.15206e+00, -7.80281e-02, 2.04709e-10], 31 | [6.43604e-01, -9.58699e-01, 5.91895e-11], 32 | [5.08455e-01, 1.03673e+00, -5.42803e-11]]) 33 | # assert_array_almost_equal(actual, expected, decimal=4) 34 | assert_array_almost_equal(np.cov(actual.T), np.diag([1, 1, 0]), decimal=4) 35 | assert_array_almost_equal(X, model.inverse_transform(actual)) 36 | 37 | def test_fit_drop_minute(self): 38 | X = self.identity_matrix 39 | 40 | model = WhiteningScaler(eps=1e-8, thresholding='drop_minute') 41 | actual = model.fit_transform(X) 42 | expected = np.array([[-1.15206, -0.07803], 43 | [0.64361, -0.95870], 44 | [0.50846, 1.03673]]) 45 | # assert_array_almost_equal(actual, expected, decimal=4) 46 | assert_array_almost_equal(np.cov(actual.T), np.diag([1, 1]), decimal=4) 47 | assert_array_almost_equal(X, model.inverse_transform(actual)) 48 | 49 | def test_fit_normalize_biased(self): 50 | X = self.identity_matrix 51 | 52 | model = WhiteningScaler(eps=1e-8, thresholding='normalize', unbiased=False) 53 | actual = model.fit_transform(X) 54 | expected = np.array([[-1.4142e+00, -1.4345e-17, 1.8453e-08], 55 | [7.0711e-01, -1.2247e+00, -3.8987e-09], 56 | [7.0711e-01, 1.2247e+00, 6.7568e-09]]) 57 | # assert_array_almost_equal(actual, expected, decimal=4) 58 | assert_array_almost_equal(np.cov(actual.T), np.diag([1.5, 1.5, 0]), decimal=4) 59 | assert_array_almost_equal(np.cov(actual.T) * (2 / 3), np.diag([1, 1, 0]), decimal=4) 60 | assert_array_almost_equal(X, model.inverse_transform(actual)) 61 | 62 | def test_fit_normalize_drop_minute_zca(self): 63 | X = self.identity_matrix 64 | 65 | model = WhiteningScaler(eps=1e-8, thresholding='drop_minute', apply_zca=True) 66 | actual = 
model.fit_transform(X)
        expected = np.array([[0.94281, -0.4714, -0.4714],
                             [-0.47140, 0.94281, -0.47140],
                             [-0.47140, -0.47140, 0.94281]])
        assert_array_almost_equal(actual, expected, decimal=4)
        assert_array_almost_equal(np.cov(actual.T),
                                  np.array([[0.6667, -0.3333, -0.3333],
                                            [-0.3333, 0.6667, -0.3333],
                                            [-0.3333, -0.3333, 0.6667]]),
                                  decimal=4)
        assert_array_almost_equal(X, model.inverse_transform(actual))


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/tests/test_notebook.py:
--------------------------------------------------------------------------------
import json
import os
import pathlib
import subprocess
import sys
import time
import unittest
from logging import getLogger

logger = getLogger(__name__)


class TestJupyterNotebook(unittest.TestCase):
    def run_notebook(self, notebook) -> int:
        """
        execute notebook and write the result into `out_notebook`
        Parameter:
            notebook: a path of .ipynb
        Return:
            int: return code of the nbconvert process
        """
        out_notebook = os.path.join(notebook.parent, self.out_notebook_name)
        result = subprocess.run(
            [
                "jupyter",
                "nbconvert",
                "--to",
                "notebook",
                "--execute",
                notebook,
                "--output",
                out_notebook,
                "--ExecutePreprocessor.timeout=86400",
                "--allow-errors",
                "--debug",
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
        )
        return result.returncode

    def parse_notebook(self, notebook) -> bool:
        """
        parse notebook and report errors
        Parameter:
            notebook: a path of .ipynb
        Return:
            True: an error exists
            False: no error
        """
        is_error_exists = False
        out_notebook = os.path.join(notebook.parent, self.out_notebook_name)
        with open(out_notebook, "r") as f:
            json_dict = json.load(f)
            for cell in json_dict["cells"]:
                if (
                    "outputs" in cell
                    and len(cell["outputs"]) > 0
                    and "output_type" in cell["outputs"][0]
                    and cell["outputs"][0]["output_type"] == "error"
                ):
                    msg_notebook_info = (
                        "notebook name: {name}\n"
                        "execution_count: {cnt}\n".format(
                            name=notebook, cnt=cell["execution_count"]
                        )
                    )
                    msg_code = (
                        "-----code-----\n"
                        "{code}\n"
                        "--------------\n".format(code=("\n").join(cell["source"]))
                    )
                    msg_traceback = (
                        "---traceback--\n"
                        "{traceback}"
                        "\n--------------\n".format(
                            traceback=("\n").join(cell["outputs"][0]["traceback"])
                        )
                    )
                    error_msg = msg_notebook_info + msg_code + msg_traceback
                    logger.error(error_msg)

                    is_error_exists = True

        if os.path.exists(out_notebook):
            os.remove(out_notebook)

        return is_error_exists

    def setUp(self):
        self.out_notebook_name = "tmp.ipynb"
        self.DIR_NOTEBOOK = "../examples"
        sys.path = sys.path + [self.DIR_NOTEBOOK]

        # notebooks under these directories are skipped in unittest
        self.EXCLUDE_DIR = [
            ".ipynb_checkpoints",
        ]

    # @unittest.skip("due to high computation")
    def test_all_notebooks(self):
        # get directory of notebooks
        path_notebook = pathlib.Path(
            os.path.join(os.path.dirname(__file__), self.DIR_NOTEBOOK)
        )

        is_error = []
        times = {}
        for notebook in path_notebook.glob("**/*.ipynb"):
            skip_flag = False
            for exclude_dir in self.EXCLUDE_DIR:
                if exclude_dir in str(notebook.parent):
                    skip_flag = True
            if skip_flag:
                continue
            print("Run: ", notebook.name)
            start = time.time()
            status = self.run_notebook(notebook)
            if status != 0:
                print("error:", notebook)
            else:
                elapsed_time = time.time() - start
                times[notebook.name] = elapsed_time
                result = self.parse_notebook(notebook)
                is_error.append(result)

        print("---computation time for each notebook---")
        for notebook, t in times.items():
            print(f"{notebook:<50}: {t:<8.2f} [sec]")

        self.assertNotIn(True, is_error)
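The module that follows provides masked sparse coding. A sketch of `sparse_encode_with_mask` mirroring its tests (random dictionary and data; OMP with a small sparsity budget):

```python
import numpy as np
from spmimage.decomposition import sparse_encode_with_mask

rng = np.random.RandomState(0)
D = rng.randn(32, 64)                            # dictionary: one atom per row
X = rng.randn(128, 64)
mask = (rng.rand(*X.shape) < 0.8).astype(int)    # 0 marks a missing entry
code = sparse_encode_with_mask(X, D, mask, algorithm='omp', n_nonzero_coefs=5)
print(code.shape)  # (128, 32): one sparse code per sample
```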
--------------------------------------------------------------------------------
/spmimage/decomposition/dict_learning.py:
--------------------------------------------------------------------------------
import numpy as np

from sklearn.utils import check_array
from sklearn.decomposition import sparse_encode


def sparse_encode_with_mask(X, dictionary, mask, **kwargs):
    """sparse_encode_with_mask
    Find a sparse code that represents the data with the given dictionary.

        X ~= code * dictionary

    Parameters:
    ------------
        X : array-like, shape (n_samples, n_features)
            Training vector, where n_samples is the number of samples
            and n_features is the number of features.

        dictionary : array of shape (n_components, n_features),
            The dictionary factor

        mask : array-like, shape (n_samples, n_features),
            a value other than 1 at (i, j) indicates that X[i, j] is missing

        **kwargs :
            algorithm : {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'}
                lars: uses the least angle regression method (linear_model.lars_path)
                lasso_lars: uses Lars to compute the Lasso solution
                lasso_cd: uses the coordinate descent method to compute the
                    Lasso solution (linear_model.Lasso). lasso_lars will be faster if
                    the estimated components are sparse.
                omp: uses orthogonal matching pursuit to estimate the sparse solution
                threshold: squashes to zero all coefficients less than regularization
                    from the projection dictionary * data'
            n_nonzero_coefs : int,
                number of non-zero elements of sparse coding
            verbose : bool
                Degree of verbosity of the printed output.
            n_jobs : int, optional
                Number of parallel jobs to run.

    Returns:
    ---------
        code : array of shape (n_samples, n_components)
            The sparse codes
    """
    code = np.zeros((X.shape[0], dictionary.shape[0]))
    for idx in range(X.shape[0]):
        code[idx, :] = sparse_encode(X[idx, :][mask[idx, :] == 1].reshape(1, -1),
                                     dictionary[:, mask[idx, :] == 1],
                                     **kwargs)
    return code


def _shrinkage_map_for_l21_norm(X, gamma):
    """
    Shrinkage mapping for l2,1-norm minimization.
    """
    norm_X = np.linalg.norm(X, axis=0).reshape(1, -1)
    norm_X[norm_X == 0] = gamma
    return np.maximum(1 - gamma / norm_X, 0) * X


def sparse_encode_with_l21_norm(X, dictionary, max_iter=30, alpha=1.0, tau=1.0, check_input=True):
    """
    Find a sparse code that represents the data with the given dictionary.

        X ~= code * dictionary

    Minimizes the following objective function:

        1 / (2 * n_samples) * ||X - WD||^2_F + alpha * ||W||_2,1

    To solve this problem, ADMM uses the augmented Lagrangian

        1 / (2 * n_samples) * ||X - WD||^2_F + alpha * ||Y||_2,1
            + U^T (W - Y) + tau / (2 * n_samples) * ||W - Y||^2_F

    where U is the Lagrange multiplier and tau is a tuning parameter.

    Parameters:
    ------------
        X : array-like, shape (n_samples, n_features)
            Training matrix

        dictionary : array of shape (n_components, n_features)
            The dictionary factor

        max_iter : int, optional (default=30)
            Maximum number of iterations

        alpha : float, optional (default=1.0)
            The penalty applied to the L2,1 norm

        tau : float, optional (default=1.0)
            The penalty applied to the augmented Lagrangian function

        check_input : boolean, optional (default=True)
            If False, the input arrays X and dictionary will not be checked.

    Returns:
    ---------
        Y : array of shape (n_samples, n_components)
            The sparse codes
    """
    if check_input:
        dictionary = check_array(dictionary)
        X = check_array(X)

    n_components = dictionary.shape[0]
    n_samples = X.shape[0]
    tau /= n_samples
    inv_matrix = np.linalg.inv(dictionary @ dictionary.T / n_samples + tau * np.identity(n_components))
    XD = X @ dictionary.T / n_samples
    tau_inv = 1 / tau
    alpha_tau = alpha * tau_inv

    # initialize
    W = XD @ inv_matrix
    Y = W.copy()
    U = np.zeros_like(W)

    for _ in range(max_iter):
        W = (XD + tau * Y - U) @ inv_matrix
        Y = _shrinkage_map_for_l21_norm(W + tau_inv * U, alpha_tau)
        U = U + tau * (W - Y)
    return Y
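The L2,1 update above is a group soft-threshold applied column-wise by `_shrinkage_map_for_l21_norm`, so entire columns of the code matrix (one dictionary atom's coefficients across all samples) can vanish at once. A hedged sketch of the public encoder on random data with the default `tau`:

```python
import numpy as np
from spmimage.decomposition import sparse_encode_with_l21_norm

rng = np.random.RandomState(0)
D = rng.randn(32, 64)
X = rng.randn(128, 64)
W = sparse_encode_with_l21_norm(X, D, alpha=1.0)
print(W.shape)  # (128, 32)
# columns whose l2 norm fell below the threshold are zeroed out entirely
print(np.sum(np.linalg.norm(W, axis=0) == 0), "of", W.shape[1], "atoms unused")
```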
--------------------------------------------------------------------------------
/tests/test_linear_model_hmlasso.py:
--------------------------------------------------------------------------------
import unittest

import numpy as np
from spmimage.linear_model import HMLasso
from numpy.testing import assert_array_almost_equal


def build_dataset(n_samples=50, n_features=100, n_informative_features=10,
                  n_targets=1):
    """
    build an ill-posed linear regression problem with many noisy features and
    comparatively few samples
    """
    random_state = np.random.RandomState(0)
    if n_targets > 1:
        w = random_state.randn(n_features, n_targets)
    else:
        w = random_state.randn(n_features)
    w[n_informative_features:] = 0.0
    X = random_state.randn(n_samples, n_features)
    y = np.dot(X, w)
    rand = random_state.rand(n_samples, n_features)
    X[rand > 0.99] = np.nan
    X_test = random_state.randn(n_samples, n_features)
    y_test = np.dot(X_test, w)
    return X, y, X_test, y_test


class TestHMLasso(unittest.TestCase):
    def test_lasso_admm_zero(self):
        # Check that HMLasso can handle all-zero data without crashing
        X = [[0], [0], [0]]
        y = [0, 0, 0]
        clf = HMLasso(alpha=0.1).fit(X, y)
        pred = clf.predict([[1], [2], [3]])
        assert_array_almost_equal(clf.coef_, [0])
        assert_array_almost_equal(pred, [0, 0, 0])

    def test_lasso_admm_toy(self):
        # Test HMLasso for various values of alpha and mu_coef, using
        # the same test cases as the Lasso implementation of sklearn.
42 | # (see https://github.com/scikit-learn/scikit-learn/blob/master 43 | # /sklearn/linear_model/tests/test_coordinate_descent.py) 44 | # Actually, the parameters alpha = 0 should not be allowed. However, 45 | # we test it as a border case. 46 | # WARNING: 47 | # HMLasso can't check the case which is not converged 48 | # because HMLasso doesn't check dual gap yet. 49 | # This problem will be fixed in future. 50 | 51 | X = np.array([[-1.], [0.], [1.]]) 52 | Y = [-1, 0, 1] # just a straight line 53 | T = [[2.], [3.], [4.]] # test sample 54 | 55 | clf = HMLasso(alpha=1e-8, tol_coef=1e-8) 56 | clf.fit(X, Y) 57 | pred = clf.predict(T) 58 | assert_array_almost_equal(clf.coef_, [1], decimal=3) 59 | assert_array_almost_equal(pred, [2, 3, 4], decimal=3) 60 | 61 | clf = HMLasso(alpha=0.1, tol_coef=1e-8) 62 | clf.fit(X, Y) 63 | pred = clf.predict(T) 64 | assert_array_almost_equal(clf.coef_, [.85], decimal=3) 65 | assert_array_almost_equal(pred, [1.7, 2.55, 3.4], decimal=3) 66 | 67 | clf = HMLasso(alpha=0.5, tol_coef=1e-8) 68 | clf.fit(X, Y) 69 | pred = clf.predict(T) 70 | assert_array_almost_equal(clf.coef_, [.25], decimal=3) 71 | assert_array_almost_equal(pred, [0.5, 0.75, 1.0], decimal=3) 72 | 73 | clf = HMLasso(alpha=1, tol_coef=1e-8) 74 | clf.fit(X, Y) 75 | pred = clf.predict(T) 76 | assert_array_almost_equal(clf.coef_, [.0], decimal=3) 77 | assert_array_almost_equal(pred, [0, 0, 0], decimal=3) 78 | 79 | # this is the same test case as the case alpha=1e-8 80 | # because the default mu_coef parameter equals 1.0 81 | clf = HMLasso(alpha=1e-8, mu_coef=1.0, tol_coef=1e-8) 82 | clf.fit(X, Y) 83 | pred = clf.predict(T) 84 | assert_array_almost_equal(clf.coef_, [1], decimal=3) 85 | assert_array_almost_equal(pred, [2, 3, 4], decimal=3) 86 | 87 | clf = HMLasso(alpha=0.5, mu_coef=0.5, tol_coef=1e-8) 88 | clf.fit(X, Y) 89 | pred = clf.predict(T) 90 | assert_array_almost_equal(clf.coef_, [0.25], decimal=3) 91 | assert_array_almost_equal(pred, [0.5, 0.75, 1.0], decimal=3) 92 | 93 | def test_lasso_admm_toy_multi(self): 94 | # for issue #39 95 | X = np.eye(4) 96 | y = np.array([[1, 1, 0], 97 | [1, 0, 1], 98 | [0, 1, 0], 99 | [0, 0, 1]]) 100 | 101 | clf = HMLasso(alpha=0.05, tol_coef=1e-8).fit(X, y) 102 | assert_array_almost_equal(clf.coef_[0], [0.29999988, 0.29999988, -0.29999988, -0.29999988], decimal=3) 103 | assert_array_almost_equal(clf.coef_[1], [0.29999988, -0.29999988, 0.29999988, -0.29999988], decimal=3) 104 | assert_array_almost_equal(clf.coef_[2], [-0.29999988, 0.29999988, -0.29999988, 0.29999988], decimal=3) 105 | 106 | def test_lasso_admm(self): 107 | X, y, X_test, y_test = build_dataset() 108 | 109 | clf = HMLasso(alpha=0.05, mu_cov=0.1, tol_coef=1e-8).fit(X, y) 110 | self.assertGreater(clf.score(X_test, y_test), 0.9) 111 | 112 | clf = HMLasso(alpha=0.05, mu_cov=0.1).fit(X, y) 113 | self.assertGreater(clf.score(X_test, y_test), 0.9) 114 | 115 | clf = HMLasso(alpha=0.144, mu_cov=0.1, normalize=True).fit(X, y) 116 | self.assertGreater(clf.score(X_test, y_test), 0.9) 117 | 118 | def test_lasso_admm_multi(self): 119 | X, y, X_test, y_test = build_dataset(n_targets=3) 120 | 121 | clf = HMLasso(alpha=0.05, mu_cov=0.1, tol_coef=1e-8).fit(X, y) 122 | self.assertGreater(clf.score(X_test, y_test), 0.9) 123 | 124 | clf = HMLasso(alpha=0.05, mu_cov=0.1, ).fit(X, y) 125 | self.assertGreater(clf.score(X_test, y_test), 0.9) 126 | 127 | clf = HMLasso(alpha=0.144, mu_cov=0.1, normalize=True).fit(X, y) 128 | self.assertGreater(clf.score(X_test, y_test), 0.9) 129 | 130 | 131 | if __name__ == '__main__': 132 | 
unittest.main() 133 | -------------------------------------------------------------------------------- /tests/test_feature_extraction_image.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from spmimage.feature_extraction.image import extract_simple_patches_2d, reconstruct_from_simple_patches_2d 4 | 5 | import numpy as np 6 | from numpy.testing import assert_array_almost_equal 7 | 8 | class TestImage(TestCase): 9 | def test_extract_simple_patches_2d_gray(self): 10 | image = np.arange(16).reshape((4, 4)) 11 | 12 | actual = extract_simple_patches_2d(image, (2, 2,)) 13 | self.assertEqual((4, 2, 2), actual.shape) 14 | self.assertEqual(image.dtype, actual.dtype) 15 | 16 | self.assertTrue((np.array([0, 1, 4, 5]) == actual[0].flatten()).all()) 17 | self.assertTrue((np.array([2, 3, 6, 7]) == actual[1].flatten()).all()) 18 | self.assertTrue((np.array([8, 9, 12, 13]) == actual[2].flatten()).all()) 19 | self.assertTrue((np.array([10, 11, 14, 15]) == actual[3].flatten()).all()) 20 | 21 | def test_extract_simple_patches_2d_color(self): 22 | image = np.arange(48).reshape((4, 4, 3)) 23 | 24 | actual = extract_simple_patches_2d(image, (2, 2,)) 25 | self.assertEqual((4, 2, 2, 3), actual.shape) 26 | self.assertEqual(image.dtype, actual.dtype) 27 | 28 | self.assertTrue((np.array([0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 17]) == actual[0].flatten()).all()) 29 | self.assertTrue((np.array([6, 7, 8, 9, 10, 11, 18, 19, 20, 21, 22, 23]) == actual[1].flatten()).all()) 30 | self.assertTrue((np.array([24, 25, 26, 27, 28, 29, 36, 37, 38, 39, 40, 41]) == actual[2].flatten()).all()) 31 | self.assertTrue((np.array([30, 31, 32, 33, 34, 35, 42, 43, 44, 45, 46, 47]) == actual[3].flatten()).all()) 32 | 33 | def test_reconstruct_from_simple_patches_2d_gray(self): 34 | patches = np.stack(( 35 | np.array([[0, 1], [4, 5]]), 36 | np.array([[2, 3], [6, 7]]), 37 | np.array([[8, 9], [12, 13]]), 38 | np.array([[10, 11], [14, 15]]), 39 | )) 40 | actual = reconstruct_from_simple_patches_2d(patches, (4, 4)) 41 | self.assertTrue((np.arange(16).reshape((4, 4)) == actual).all()) 42 | self.assertEqual(patches.dtype, actual.dtype) 43 | 44 | def test_reconstruct_from_simple_patches_2d_color(self): 45 | patches = np.stack(( 46 | np.array([[[0, 1, 2], [3, 4, 5]], [[12, 13, 14], [15, 16, 17]]]), 47 | np.array([[[6, 7, 8], [9, 10, 11]], [[18, 19, 20], [21, 22, 23]]]), 48 | np.array([[[24, 25, 26], [27, 28, 29]], [[36, 37, 38], [39, 40, 41]]]), 49 | np.array([[[30, 31, 32], [33, 34, 35]], [[42, 43, 44], [45, 46, 47]]]), 50 | )) 51 | actual = reconstruct_from_simple_patches_2d(patches, (4, 4, 3)) 52 | self.assertTrue((np.arange(48).reshape((4, 4, 3)) == actual).all()) 53 | self.assertEqual(patches.dtype, actual.dtype) 54 | 55 | def test_extract_simple_patches_2d_gray_float(self): 56 | image = np.arange(0.0, 1.6, 0.1, dtype='float').reshape((4, 4)) 57 | 58 | actual = extract_simple_patches_2d(image, (2, 2,)) 59 | self.assertEqual((4, 2, 2), actual.shape) 60 | self.assertEqual(image.dtype, actual.dtype) 61 | 62 | assert_array_almost_equal(np.array([0.0, 0.1, 0.4, 0.5]), actual[0].flatten()) 63 | assert_array_almost_equal(np.array([0.2, 0.3, 0.6, 0.7]), actual[1].flatten()) 64 | assert_array_almost_equal(np.array([0.8, 0.9, 1.2, 1.3]), actual[2].flatten()) 65 | assert_array_almost_equal(np.array([1.0, 1.1, 1.4, 1.5]), actual[3].flatten()) 66 | 67 | def test_extract_simple_patches_2d_color_float(self): 68 | image = np.arange(0.0, 4.8, 0.1, dtype='float').reshape((4, 4, 3)) 69 | 
actual = extract_simple_patches_2d(image, (2, 2,))
        self.assertEqual((4, 2, 2, 3), actual.shape)
        self.assertEqual(image.dtype, actual.dtype)

        assert_array_almost_equal(np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7]), actual[0].flatten())
        assert_array_almost_equal(np.array([0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3]), actual[1].flatten())
        assert_array_almost_equal(np.array([2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.6, 3.7, 3.8, 3.9, 4.0, 4.1]), actual[2].flatten())
        assert_array_almost_equal(np.array([3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7]), actual[3].flatten())

    def test_reconstruct_from_simple_patches_2d_gray_float(self):
        patches = np.stack((
            np.array([[0.0, 0.1], [0.4, 0.5]]),
            np.array([[0.2, 0.3], [0.6, 0.7]]),
            np.array([[0.8, 0.9], [1.2, 1.3]]),
            np.array([[1.0, 1.1], [1.4, 1.5]]),
        ))
        actual = reconstruct_from_simple_patches_2d(patches, (4, 4))
        assert_array_almost_equal(np.arange(0.0, 1.6, 0.1, dtype='float').reshape((4, 4)), actual)
        self.assertEqual(patches.dtype, actual.dtype)

    def test_reconstruct_from_simple_patches_2d_color_float(self):
        patches = np.stack((
            np.array([[[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]], [[1.2, 1.3, 1.4], [1.5, 1.6, 1.7]]]),
            np.array([[[0.6, 0.7, 0.8], [0.9, 1.0, 1.1]], [[1.8, 1.9, 2.0], [2.1, 2.2, 2.3]]]),
            np.array([[[2.4, 2.5, 2.6], [2.7, 2.8, 2.9]], [[3.6, 3.7, 3.8], [3.9, 4.0, 4.1]]]),
            np.array([[[3.0, 3.1, 3.2], [3.3, 3.4, 3.5]], [[4.2, 4.3, 4.4], [4.5, 4.6, 4.7]]]),
        ))
        actual = reconstruct_from_simple_patches_2d(patches, (4, 4, 3))
        assert_array_almost_equal(np.arange(0.0, 4.8, 0.1, dtype='float').reshape((4, 4, 3)), actual)
        self.assertEqual(patches.dtype, actual.dtype)

--------------------------------------------------------------------------------
/spmimage/linear_model/_ppd.py:
--------------------------------------------------------------------------------
from logging import getLogger

from abc import abstractmethod, ABC

from sklearn.base import RegressorMixin
from sklearn.linear_model._base import LinearModel

from typing import Tuple
import numpy as np

logger = getLogger(__name__)


class PreconditionedPrimalDual(LinearModel, RegressorMixin, ABC):
    """
    Abstract class for Preconditioned Primal Dual algorithm
    """

    def __init__(self,
                 alpha: float = 1.0,
                 max_iter: int = 1000):
        """
        Preconditioned Primal Dual algorithm

        Parameters
        ----------
        alpha : float
            A regularization parameter.
        max_iter : int
            The maximum number of iterations.
        """
        self.alpha = alpha
        self.max_iter = max_iter

    @abstractmethod
    def _Du(self, u: np.ndarray):
        """
        calc Du, where D is a matrix and u is a vector

        Parameters
        --------------
        u : np.ndarray
            a vector
        """
        raise NotImplementedError()

    @abstractmethod
    def _DTv(self, v: np.ndarray):
        """
        calc D^Tv, where D is a matrix and v is a vector

        Parameters
        --------------
        v : np.ndarray
            a vector
        """
        raise NotImplementedError()

    @abstractmethod
    def _step_size(self):
        """
        define tau and sigma
        """
        raise NotImplementedError()

    @abstractmethod
    def _init_dual(self, y, sigma):
        """
        define dual vector
        """
        raise NotImplementedError()

    def fit(self, X, y, check_input=False):
        """
        Parameters
        ----------
        X : ignored
        y : np.ndarray
            input data

        Attributes
        ----------
        coef_ : np.ndarray
            estimated result

        Returns
        --------
        self : PreconditionedPrimalDual
            for method chaining
        """
        if self.alpha == 0:
            logger.warning(
                """With alpha=0, this algorithm does not converge well. You are advised to use the LinearRegression estimator""")
            raise ValueError('alpha=0 is not supported')

        self.shape = y.shape
        self.dim = np.prod(self.shape)
        tau, sigma = self._step_size()

        # initialize
        u_pre = np.copy(y)
        dual = self._init_dual(u_pre, sigma)

        # main loop
        for _ in range(self.max_iter):
            u = u_pre - (tau * (self._DTv(dual[:-self.dim]) + self.alpha * dual[-self.dim:]))
            u_bar = 2 * u - u_pre
            dual[:-self.dim] += sigma[:-self.dim] * self._Du(u_bar)
            dual[-self.dim:] += sigma[-self.dim:] * self.alpha * (u_bar - y)
            dual = np.clip(dual, -1, 1)
            u_pre = u

        self.coef_ = u_pre
        return self


class LassoPPD(PreconditionedPrimalDual):
    def __init__(self,
                 alpha: float = 1.0,
                 max_iter: int = 1000,
                 params: np.ndarray = None):
        """
        Lasso Preconditioned Primal Dual algorithm

        Parameters
        ----------
        alpha : float
            A regularization parameter.
        max_iter : int
            The maximum number of iterations.
        params : np.ndarray, default None
            filter coefficients that define the penalty matrix D
        """
        super().__init__(alpha, max_iter)
        if params is None:
            self.params = np.array([1])
        else:
            self.params = np.array(params)

    def _Du(self, u: np.ndarray) -> np.ndarray:
        """
        calc Du

        Parameters
        --------------
        u : np.ndarray
            a vector

        Return
        --------------
        x : np.ndarray
            x = Du
        """
        x = np.zeros(len(u) - len(self.params) + 1)
        for i, k in enumerate(self.params):
            x += k * u[i:len(u) - len(self.params) + i + 1]
        return x

    def _DTv(self, v: np.ndarray) -> np.ndarray:
        """
        calc D^Tv

        Parameters
        --------------
        v : np.ndarray
            a vector

        Return
        --------------
        x : np.ndarray
            x = D^T v
        """
        x = np.zeros(len(v) + len(self.params) - 1)
        for i, k in enumerate(self.params):
            x[i:len(v) + i] += k * v
        return x

    def _step_size(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        define tau and sigma

        Return
        -------
        tau : np.ndarray
            step parameter tau

        sigma : np.ndarray
            step parameter sigma
        """
        tau = np.ones(self.dim) * np.sum(np.abs(self.params))
        for i in range(len(self.params) - 1):
            tau[i] = np.sum(np.abs(self.params[:i + 1]))
            tau[-i] = np.sum(np.abs(self.params[-i - 1:]))
        tau += self.alpha
        tau = 1. / tau

        sigma = np.ones(2 * self.dim - len(self.params) + 1)
        sigma[:self.dim - len(self.params) + 1] *= np.sum(np.abs(self.params))
        sigma[self.dim - len(self.params) + 1:] *= self.alpha
        sigma = 1.
/ sigma 206 | return tau, sigma 207 | 208 | def _init_dual(self, y, sigma) -> np.ndarray: 209 | """ 210 | define dual vector 211 | 212 | Return 213 | ------ 214 | dual : np.ndarray 215 | """ 216 | dual = np.zeros(2 * self.dim - len(self.params) + 1) 217 | dual[:-self.dim] = sigma[:-self.dim] * self._Du(y) 218 | dual = np.clip(dual, -1, 1) 219 | return dual 220 | -------------------------------------------------------------------------------- /tests/test_decomposition_ksvd.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from spmimage.decomposition import KSVD 4 | 5 | import numpy as np 6 | import numpy.testing as npt 7 | 8 | from tests.utils import generate_dictionary_and_samples 9 | 10 | 11 | class TestKSVD(unittest.TestCase): 12 | def setUp(self): 13 | np.random.seed(0) 14 | 15 | def test_ksvd_attributes(self): 16 | D = np.random.rand(10, 100) 17 | model = KSVD(n_components=10, transform_n_nonzero_coefs=5, max_iter=1, method='normal', dict_init=D) 18 | self.assertIsInstance(model.get_params(), dict) 19 | 20 | def test_ksvd_normal_input(self): 21 | n_nonzero_coefs = 3 22 | n_samples = 512 23 | n_features = 32 24 | n_components = 24 25 | max_iter = 10 26 | 27 | A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, n_nonzero_coefs) 28 | model = KSVD(n_components=n_components, transform_n_nonzero_coefs=n_nonzero_coefs, max_iter=max_iter, 29 | method='normal') 30 | model.fit(X) 31 | 32 | # check error of learning 33 | self.assertTrue(model.error_[-1] < model.error_[0]) 34 | self.assertTrue(model.n_iter_ <= max_iter) 35 | 36 | # check estimated dictionary 37 | norm = np.linalg.norm(model.components_ - A0, ord='fro') 38 | self.assertTrue(norm < 15) 39 | 40 | # check reconstructed data 41 | code = model.transform(X) 42 | reconstructed = np.dot(code, model.components_) 43 | reconstruct_error = np.linalg.norm(reconstructed - X, ord='fro') 44 | self.assertTrue(reconstruct_error < 15) 45 | 46 | def test_ksvd_input_with_missing_values(self): 47 | n_nonzero_coefs = 4 48 | n_samples = 128 49 | n_features = 32 50 | n_components = 16 51 | max_iter = 10 52 | missing_value = 0 53 | 54 | A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, n_nonzero_coefs) 55 | X[X < 0.1] = missing_value 56 | model = KSVD(n_components=n_components, transform_n_nonzero_coefs=n_nonzero_coefs, max_iter=max_iter, 57 | missing_value=missing_value, method='normal') 58 | model.fit(X) 59 | 60 | # check error of learning 61 | self.assertTrue(model.error_[-1] <= model.error_[0]) 62 | self.assertTrue(model.n_iter_ <= max_iter) 63 | 64 | def test_ksvd_warm_start(self): 65 | n_nonzero_coefs = 3 66 | n_samples = 128 67 | n_features = 32 68 | n_components = 16 69 | max_iter = 1 70 | 71 | A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, n_nonzero_coefs) 72 | model = KSVD(n_components=n_components, transform_n_nonzero_coefs=n_nonzero_coefs, max_iter=max_iter, 73 | method='normal') 74 | 75 | prev_error = np.linalg.norm(X, 'fro') 76 | for i in range(10): 77 | model.fit(X) 78 | # print(model.error_) 79 | self.assertTrue(model.error_[-1] <= prev_error) 80 | prev_error = model.error_[-1] 81 | 82 | def test_ksvd_dict_init(self): 83 | D = np.random.rand(10, 100) 84 | model = KSVD(n_components=10, transform_n_nonzero_coefs=5, max_iter=1, method='normal', dict_init=D) 85 | npt.assert_array_equal(model.dict_init, D) 86 | 87 | # shape of X is invalid against initial dictionary 88 | X = np.random.rand(20, 200) 89 | 
with self.assertRaises(ValueError): 90 | model.fit(X) 91 | 92 | # n_components is invalid against initial dictionary 93 | X = np.random.rand(20, 100) 94 | model = KSVD(n_components=20, transform_n_nonzero_coefs=5, max_iter=1, method='normal', dict_init=D) 95 | with self.assertRaises(ValueError): 96 | model.fit(X) 97 | 98 | def test_approximate_ksvd(self): 99 | n_nonzero_coefs = 3 100 | n_samples = 128 101 | n_features = 32 102 | n_components = 16 103 | max_iter = 10 104 | 105 | A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, n_nonzero_coefs) 106 | model = KSVD(n_components=n_components, transform_n_nonzero_coefs=n_nonzero_coefs, max_iter=max_iter, 107 | method='approximate') 108 | model.fit(X) 109 | 110 | # check error of learning 111 | self.assertTrue(model.error_[-1] <= model.error_[0]) 112 | self.assertTrue(model.n_iter_ <= max_iter) 113 | 114 | def test_approximate_ksvd_warm_start(self): 115 | n_nonzero_coefs = 3 116 | n_samples = 128 117 | n_features = 32 118 | n_components = 16 119 | max_iter = 1 120 | 121 | A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, n_nonzero_coefs) 122 | model = KSVD(n_components=n_components, transform_n_nonzero_coefs=n_nonzero_coefs, max_iter=max_iter, 123 | method='approximate') 124 | 125 | prev_error = np.linalg.norm(X, 'fro') 126 | for i in range(10): 127 | model.fit(X) 128 | # print(model.error_) 129 | self.assertTrue(model.error_[-1] <= prev_error) 130 | prev_error = model.error_[-1] 131 | 132 | def test_transform(self): 133 | n_nonzero_coefs = 3 134 | n_samples = 128 135 | n_features = 32 136 | n_components = 24 137 | max_iter = 10 138 | 139 | A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, n_nonzero_coefs) 140 | model = KSVD(n_components=n_components, transform_n_nonzero_coefs=n_nonzero_coefs, max_iter=max_iter, 141 | method='normal') 142 | model.fit(X) 143 | 144 | # check error of learning 145 | code = model.transform(X) 146 | err = np.linalg.norm(X - code.dot(model.components_), 'fro') 147 | self.assertTrue(err <= model.error_[0]) 148 | self.assertTrue(model.n_iter_ <= max_iter) 149 | 150 | def test_transform_with_mask(self): 151 | n_nonzero_coefs = 4 152 | n_samples = 128 153 | n_features = 32 154 | n_components = 16 155 | max_iter = 10 156 | missing_value = 0 157 | 158 | A0, X = generate_dictionary_and_samples(n_samples, n_features, n_components, n_nonzero_coefs) 159 | X[X < 0.1] = missing_value 160 | mask = np.where(X == missing_value, 0, 1) 161 | 162 | model = KSVD(n_components=n_components, transform_n_nonzero_coefs=n_nonzero_coefs, max_iter=max_iter, 163 | missing_value=missing_value, method='normal') 164 | model.fit(X) 165 | 166 | # check error of learning 167 | code = model.transform(X) 168 | err = np.linalg.norm(mask * (X - code.dot(model.components_)), 'fro') 169 | self.assertTrue(err <= model.error_[0]) 170 | self.assertTrue(model.n_iter_ <= max_iter) 171 | 172 | 173 | if __name__ == '__main__': 174 | unittest.main() 175 | -------------------------------------------------------------------------------- /spmimage/linear_model/admm.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from abc import abstractmethod 3 | 4 | import numpy as np 5 | 6 | from sklearn.utils import check_array, check_X_y 7 | from sklearn.base import RegressorMixin 8 | from sklearn.linear_model._base import LinearModel 9 | from sklearn.linear_model._coordinate_descent import _alpha_grid 10 | from joblib import 
Parallel, delayed 11 | 12 | logger = getLogger(__name__) 13 | 14 | 15 | def _soft_threshold(X: np.ndarray, thresh: float) -> np.ndarray: 16 | return np.where(np.abs(X) <= thresh, 0, X - thresh * np.sign(X)) 17 | 18 | 19 | def _cost_function(X, y, w, z, alpha): 20 | n_samples = X.shape[0] 21 | return np.linalg.norm(y - X.dot(w)) / n_samples + alpha * np.sum(np.abs(z)) 22 | 23 | 24 | def _update(X, y_k, D, coef_matrix, inv_Xy_k, inv_D, alpha, rho, max_iter, 25 | tol): 26 | # Initialize ADMM parameters 27 | n_samples = X.shape[0] 28 | 29 | w_k = X.T.dot(y_k) / n_samples 30 | z_k = D.dot(w_k) 31 | h_k = np.zeros_like(z_k) 32 | 33 | cost = _cost_function(X, y_k, w_k, z_k, alpha) 34 | threshold = alpha / rho 35 | for t in range(max_iter): 36 | # Update 37 | w_k = inv_Xy_k + inv_D.dot(z_k - h_k / rho) 38 | Dw_t = D.dot(w_k) 39 | z_k = _soft_threshold(Dw_t + h_k / rho, threshold) 40 | h_k += rho * (Dw_t - z_k) 41 | 42 | # after cost 43 | pre_cost = cost 44 | cost = _cost_function(X, y_k, w_k, z_k, alpha) 45 | gap = np.abs(cost - pre_cost) 46 | if gap < tol: 47 | break 48 | # should return z_k as well since it's sparse by soft threshold ?! 49 | return w_k, t 50 | 51 | 52 | def admm_path(X, y, Xy=None, alphas=None, eps=1e-3, n_alphas=100, rho=1.0, 53 | max_iter=1000, tol=1e-04): 54 | _, n_features = X.shape 55 | multi_output = False 56 | n_iters = [] 57 | 58 | if y.ndim != 1: 59 | multi_output = True 60 | _, n_outputs = y.shape 61 | 62 | if alphas is None: 63 | alphas = _alpha_grid(X, y, Xy=Xy, l1_ratio=1.0, eps=eps, 64 | n_alphas=n_alphas) 65 | else: 66 | alphas = np.sort(alphas)[::-1] 67 | n_alphas = len(alphas) 68 | 69 | if not multi_output: 70 | coefs = np.zeros((n_features, n_alphas), dtype=X.dtype) 71 | else: 72 | coefs = np.zeros((n_features, n_outputs, n_alphas), dtype=X.dtype) 73 | 74 | for i, alpha in enumerate(alphas): 75 | clf = LassoADMM(alpha=alpha, rho=rho, max_iter=max_iter, tol=tol) 76 | clf.fit(X, y) 77 | coefs[..., i] = clf.coef_ 78 | n_iters.append(clf.n_iter_) 79 | 80 | return alphas, coefs, n_iters 81 | 82 | 83 | def _admm( 84 | X: np.ndarray, y: np.ndarray, D: np.ndarray, alpha: float, 85 | rho: float, tol: float, max_iter: int): 86 | """Alternate Direction Multiplier Method(ADMM) for Generalized Lasso. 87 | 88 | Minimizes the objective function:: 89 | 90 | 1 / (2 * n_samples) * ||y - Xw||^2_2 + alpha * ||z||_1 91 | 92 | where:: 93 | 94 | Dw = z 95 | 96 | To solve this problem, ADMM uses augmented Lagrangian 97 | 98 | 1 / (2 * n_samples) * ||y - Xw||^2_2 + alpha * ||z||_1 99 | + h^T (Dw - z) + rho / 2 * ||Dw - z||^2_2 100 | 101 | where h is Lagrange multiplier and rho is tuning parameter. 
102 | """ 103 | n_samples, n_features = X.shape 104 | n_targets = y.shape[1] 105 | 106 | w_t = np.empty((n_features, n_targets), dtype=X.dtype) 107 | 108 | # Calculate inverse matrix 109 | coef_matrix = X.T.dot(X) / n_samples + rho * D.T.dot(D) 110 | inv_matrix = np.linalg.inv(coef_matrix) 111 | inv_Xy = inv_matrix.dot(X.T).dot(y) / n_samples 112 | inv_D = inv_matrix.dot(rho * D.T) 113 | 114 | # Update ADMM parameters by columns 115 | n_iter_ = np.empty((n_targets,), dtype=int) 116 | if n_targets == 1: 117 | w_t, n_iter_[0] = _update(X, y, D, coef_matrix, inv_Xy, inv_D, alpha, 118 | rho, max_iter, tol) 119 | else: 120 | results = Parallel(n_jobs=-1, backend='threading')( 121 | delayed(_update)(X, y[:, k], D, coef_matrix, inv_Xy[:, k], inv_D, 122 | alpha, rho, max_iter, tol) 123 | for k in range(n_targets) 124 | ) 125 | for k in range(n_targets): 126 | w_t[:, k], n_iter_[k] = results[k] 127 | 128 | return np.squeeze(w_t.T), n_iter_.tolist() 129 | 130 | 131 | class GeneralizedLasso(LinearModel, RegressorMixin): 132 | """Alternate Direction Multiplier Method(ADMM) for Generalized Lasso. 133 | """ 134 | 135 | def __init__(self, alpha=1.0, rho=1.0, fit_intercept=True, 136 | normalize=False, copy_X=True, max_iter=1000, 137 | tol=1e-4): 138 | self.alpha = alpha 139 | self.rho = rho 140 | self.fit_intercept = fit_intercept 141 | self.normalize = normalize 142 | self.copy_X = copy_X 143 | self.max_iter = max_iter 144 | self.tol = tol 145 | 146 | def fit(self, X, y, check_input=False): 147 | if self.alpha == 0: 148 | logger.warning(""" 149 | With alpha=0, this algorithm does not converge well. 150 | You are advised to use the LinearRegression estimator 151 | """) 152 | 153 | if check_input: 154 | X, y = check_X_y(X, y, accept_sparse='csc', 155 | order='F', dtype=[np.float64, np.float32], 156 | copy=self.copy_X and self.fit_intercept, 157 | multi_output=True, y_numeric=True) 158 | y = check_array(y, order='F', copy=False, dtype=X.dtype.type, 159 | ensure_2d=False) 160 | 161 | (X, y, 162 | X_offset, 163 | y_offset, 164 | X_scale) = self._preprocess_data(X, y, 165 | fit_intercept=self.fit_intercept, 166 | normalize=self.normalize, 167 | copy=self.copy_X and not check_input) 168 | 169 | if y.ndim == 1: 170 | y = y[:, np.newaxis] 171 | 172 | n_features = X.shape[1] 173 | D = self.generate_transform_matrix(n_features) 174 | self.coef_, self.n_iter_ = _admm(X, y, D, self.alpha, self.rho, 175 | self.tol, self.max_iter) 176 | 177 | if y.shape[1] == 1: 178 | self.n_iter_ = self.n_iter_[0] 179 | 180 | self._set_intercept(X_offset, y_offset, X_scale) 181 | 182 | # workaround since _set_intercept will cast self.coef_ into X.dtype 183 | self.coef_ = np.asarray(self.coef_, dtype=X.dtype) 184 | 185 | return self 186 | 187 | @abstractmethod 188 | def generate_transform_matrix(self, n_features: int) -> np.ndarray: 189 | """ 190 | :return: 191 | """ 192 | 193 | 194 | class LassoADMM(GeneralizedLasso): 195 | """Linear Model trained with L1 prior as regularizer (aka the Lasso) 196 | The optimization objective for Lasso is:: 197 | (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1 198 | """ 199 | 200 | def generate_transform_matrix(self, n_features: int) -> np.ndarray: 201 | return np.eye(n_features) 202 | 203 | 204 | class FusedLassoADMM(GeneralizedLasso): 205 | """Fused Lasso minimises the following objective function. 
206 | 207 | 1/(2 * n_samples) * ||y - Xw||^2_2 + \lambda_1 \sum_{j=1}^p |w_j| 208 | + \lambda_2 \sum_{j=2}^p |w_j - w_{j-1}| 209 | """ 210 | 211 | def __init__(self, alpha=1.0, sparse_coef=1.0, trend_coef=1.0, rho=1.0, 212 | fit_intercept=True, normalize=False, copy_X=True, 213 | max_iter=1000, tol=1e-4): 214 | super().__init__(alpha=alpha, rho=rho, fit_intercept=fit_intercept, 215 | normalize=normalize, copy_X=copy_X, max_iter=max_iter, 216 | tol=tol) 217 | self.sparse_coef = sparse_coef 218 | self.trend_coef = trend_coef 219 | 220 | def generate_transform_matrix(self, n_features: int) -> np.ndarray: 221 | fused = np.eye(n_features - 1, n_features, k=1) \ 222 | - np.eye(n_features - 1, n_features) 223 | return self.merge_matrix(n_features, fused) 224 | 225 | def merge_matrix(self, n_features: int, 226 | trend_matrix: np.ndarray) -> np.ndarray: 227 | if self.sparse_coef == 0: 228 | return self.trend_coef * trend_matrix 229 | elif self.trend_coef == 0: 230 | return self.sparse_coef * np.identity(n_features) 231 | else: 232 | generated = np.vstack([self.sparse_coef * np.identity(n_features), 233 | self.trend_coef * trend_matrix]) 234 | return generated 235 | 236 | 237 | class TrendFilteringADMM(FusedLassoADMM): 238 | 239 | def generate_transform_matrix(self, n_features: int) -> np.ndarray: 240 | trend = - np.eye(n_features - 2, n_features) \ 241 | + 2 * np.eye(n_features - 2, n_features, k=1) \ 242 | - np.eye(n_features - 2, n_features, k=2) 243 | return self.merge_matrix(n_features, trend) 244 | 245 | 246 | class QuadraticTrendFilteringADMM(TrendFilteringADMM): 247 | 248 | def generate_transform_matrix(self, n_features: int) -> np.ndarray: 249 | if n_features < 3: 250 | trend = np.zeros((0, n_features)) 251 | else: 252 | trend = np.eye(n_features - 3, n_features) \ 253 | - 3 * np.eye(n_features - 3, n_features, k=1) \ 254 | + 3 * np.eye(n_features - 3, n_features, k=2) \ 255 | - np.eye(n_features - 3, n_features, k=3) 256 | return self.merge_matrix(n_features, trend) 257 | -------------------------------------------------------------------------------- /spmimage/decomposition/ksvd.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | import numpy as np 4 | from sklearn.base import BaseEstimator 5 | from sklearn.decomposition._dict_learning import _BaseSparseCoding, sparse_encode 6 | from sklearn.utils import check_array 7 | from sklearn.utils.validation import check_is_fitted 8 | 9 | from .dict_learning import sparse_encode_with_mask 10 | 11 | logger = getLogger(__name__) 12 | 13 | 14 | def _ksvd(Y: np.ndarray, n_components: int, n_nonzero_coefs: int, max_iter: int, tol: float, 15 | dict_init: np.ndarray = None, mask: np.ndarray = None, n_jobs: int = 1, method: str = None): 16 | """_ksvd 17 | Finds a dictionary that can be used to represent data using a sparse code. 18 | Solves the optimization problem: 19 | argmin \sum_{i=1}^M || y_i - w_iH ||_2^2 such that ||w_i||_0 <= k_0 for all 1 <= i <= M 20 | ({w_i}_{i=1}^M, H) 21 | 22 | **Note** 23 | Y ~= WH = code * dictionary 24 | 25 | Parameters: 26 | ------------ 27 | Y : array-like, shape (n_samples, n_features) 28 | Training vector, where n_samples in the number of samples 29 | and n_features is the number of features. 
30 | 31 | n_components : int, 32 | number of dictionary elements to extract 33 | 34 | n_nonzero_coefs : int, 35 | number of non-zero elements of sparse coding 36 | 37 | max_iter : int, 38 | maximum number of iterations to perform 39 | 40 | tol : float, 41 | tolerance for numerical error 42 | 43 | dict_init : array of shape (n_components, n_features), 44 | initial values for the dictionary, for warm restart 45 | 46 | mask : array-like, shape (n_samples, n_features), 47 | value at (i,j) in mask is not 1 indicates value at (i,j) in Y is missing 48 | 49 | n_jobs : int, optional 50 | Number of parallel jobs to run. 51 | Returns: 52 | --------- 53 | code : array of shape (n_samples, n_components) 54 | The sparse code factor in the matrix factorization. 55 | 56 | dictionary : array of shape (n_components, n_features), 57 | The dictionary factor in the matrix factorization. 58 | 59 | errors : array 60 | Vector of errors at each iteration. 61 | 62 | n_iter : int 63 | Number of iterations. 64 | """ 65 | 66 | W = np.zeros((Y.shape[0], n_components)) 67 | if dict_init is None: 68 | H = np.random.rand(n_components, Y.shape[1]) 69 | H = np.dot(H, np.diag(1. / np.sqrt(np.diag(np.dot(H.T, H))))) 70 | else: 71 | H = dict_init 72 | 73 | if mask is None: 74 | errors = [np.linalg.norm(Y - W.dot(H), 'fro')] 75 | else: 76 | errors = [np.linalg.norm(mask * (Y - W.dot(H)), 'fro')] 77 | 78 | k = -1 79 | for k in range(max_iter): 80 | if mask is None: 81 | W = sparse_encode(Y, H, algorithm='omp', n_nonzero_coefs=n_nonzero_coefs, n_jobs=n_jobs) 82 | else: 83 | W = sparse_encode_with_mask(Y, H, mask, algorithm='omp', n_nonzero_coefs=n_nonzero_coefs, n_jobs=n_jobs) 84 | Y[mask == 0] = W.dot(H)[mask == 0] 85 | 86 | for j in range(n_components): 87 | x = W[:, j] != 0 88 | if np.sum(x) == 0: 89 | continue 90 | 91 | if method == 'approximate': 92 | H[j, :] = 0 93 | error = Y[x, :] - np.dot(W[x, :], H) 94 | g = W[x, j].T 95 | d = error.T.dot(g) 96 | d /= np.linalg.norm(d) 97 | g = error.dot(d) 98 | W[x, j] = g.T 99 | H[j, :] = d 100 | else: 101 | # normal ksvd 102 | W[x, j] = 0 103 | error = Y[x, :] - np.dot(W[x, :], H) 104 | U, s, V = np.linalg.svd(error) 105 | W[x, j] = U[:, 0] * s[0] 106 | H[j, :] = V.T[:, 0] 107 | 108 | if mask is None: 109 | errors.append(np.linalg.norm(Y - W.dot(H), 'fro')) 110 | else: 111 | errors.append(np.linalg.norm(mask * (Y - W.dot(H)), 'fro')) 112 | 113 | if np.abs(errors[-1] - errors[-2]) < tol: 114 | break 115 | 116 | return W, H, errors, k + 1 117 | 118 | 119 | class KSVD(BaseEstimator, _BaseSparseCoding): 120 | """ K-SVD 121 | Finds a dictionary that can be used to represent data using a sparse code. 122 | Solves the optimization problem: 123 | argmin \sum_{i=1}^M || y_i - Ax_i ||_2^2 such that ||x_i||_0 <= k_0 for all 1 <= i <= M 124 | (A,{x_i}_{i=1}^M) 125 | 126 | Parameters 127 | ---------- 128 | n_components : int, 129 | number of dictionary elements to extract 130 | 131 | max_iter : int, 132 | maximum number of iterations to perform 133 | 134 | tol : float, 135 | tolerance for numerical error 136 | 137 | missing_value : float, 138 | missing value in the data 139 | 140 | transform_algorithm : {'lasso_lars', 'lasso_cd', 'lars', 'omp', 'threshold'} 141 | Algorithm used to transform the data 142 | lars: uses the least angle regression method (linear_model.lars_path) 143 | lasso_lars: uses Lars to compute the Lasso solution 144 | lasso_cd: uses the coordinate descent method to compute the 145 | Lasso solution (linear_model.Lasso). 
lasso_lars will be faster if 146 | the estimated components are sparse. 147 | omp: uses orthogonal matching pursuit to estimate the sparse solution 148 | threshold: squashes to zero all coefficients less than alpha from 149 | the projection ``dictionary * X'`` 150 | .. versionadded:: 0.17 151 | *lasso_cd* coordinate descent method to improve speed. 152 | 153 | transform_n_nonzero_coefs : int, ``0.1 * n_features`` by default 154 | Number of nonzero coefficients to target in each column of the 155 | solution. This is only used by `algorithm='lars'` and `algorithm='omp'` 156 | and is overridden by `alpha` in the `omp` case. 157 | 158 | transform_alpha : float, 1. by default 159 | If `algorithm='lasso_lars'` or `algorithm='lasso_cd'`, `alpha` is the 160 | penalty applied to the L1 norm. 161 | If `algorithm='threshold'`, `alpha` is the absolute value of the 162 | threshold below which coefficients will be squashed to zero. 163 | If `algorithm='omp'`, `alpha` is the tolerance parameter: the value of 164 | the reconstruction error targeted. In this case, it overrides 165 | `n_nonzero_coefs`. 166 | 167 | n_jobs : int, 168 | number of parallel jobs to run 169 | 170 | split_sign : bool, False by default 171 | Whether to split the sparse feature vector into the concatenation of 172 | its negative part and its positive part. This can improve the 173 | performance of downstream classifiers. 174 | 175 | random_state : int, RandomState instance or None, optional (default=None) 176 | If int, random_state is the seed used by the random number generator; 177 | If RandomState instance, random_state is the random number generator; 178 | If None, the random number generator is the RandomState instance used 179 | by `np.random`. 180 | 181 | method : {'approximate': Approximate KSVD, 'normal': normal KSVD}, 'approximate' by default 182 | 183 | Attributes 184 | ---------- 185 | components_ : array, [n_components, n_features] 186 | dictionary atoms extracted from the data 187 | 188 | error_ : array 189 | vector of errors at each iteration 190 | 191 | n_iter_ : int 192 | Number of iterations run. 193 | 194 | **References:** 195 | Elad, Michael, and Michal Aharon. 196 | "Image denoising via sparse and redundant representations over learned dictionaries." 197 | IEEE Transactions on Image processing 15.12 (2006): 3736-3745. 198 | ---------- 199 | 200 | """ 201 | 202 | def __init__(self, n_components=None, max_iter=1000, tol=1e-8, 203 | missing_value=None, transform_algorithm='omp', 204 | transform_n_nonzero_coefs=None, 205 | transform_max_iter=None, 206 | positive_code=False, 207 | transform_alpha=None, n_jobs=1, 208 | split_sign=False, random_state=None, method='approximate', dict_init=None): 209 | 210 | self.n_components = n_components 211 | self.transform_algorithm = transform_algorithm 212 | self.transform_n_nonzero_coefs = transform_n_nonzero_coefs 213 | self.transform_max_iter = transform_max_iter 214 | self.transform_alpha = transform_alpha 215 | self.split_sign = split_sign 216 | self.n_jobs = n_jobs 217 | self.max_iter = max_iter 218 | self.tol = tol 219 | self.positive_code = positive_code 220 | self.missing_value = missing_value 221 | self.random_state = random_state 222 | self.method = method 223 | self.dict_init = dict_init 224 | self.components_ = None 225 | 226 | def fit(self, X, y=None): 227 | """Fit the model from data in X. 
228 | Parameters 229 | ---------- 230 | X : array-like, shape (n_samples, n_features) 231 | Training vector, where n_samples in the number of samples 232 | and n_features is the number of features. 233 | 234 | y : Ignored 235 | 236 | Returns 237 | ------- 238 | self : object 239 | Returns the object itself 240 | """ 241 | 242 | # Input validation on an array, list, sparse matrix or similar. 243 | # By default, the input is converted to an at least 2D numpy array. If the dtype of the array is object, attempt converting to float, raising on failure. 244 | X = check_array(X) 245 | n_samples, n_features = X.shape 246 | if self.n_components is None: 247 | n_components = X.shape[1] 248 | else: 249 | n_components = self.n_components 250 | 251 | mask = None 252 | if self.missing_value is not None: 253 | mask = np.where(X == self.missing_value, 0, 1) 254 | 255 | # initialize dictionary 256 | dict_init = None 257 | if self.components_ is not None: 258 | # Warm Start 259 | logger.info("KSVD fit - warm start") 260 | dict_init = self.components_ 261 | elif self.dict_init is not None: 262 | logger.info("KSVD fit - init start") 263 | dict_init = self.dict_init 264 | else: 265 | logger.info("KSVD fit - cold start") 266 | 267 | if (dict_init is not None) and (dict_init.shape[1] != n_features): 268 | raise ValueError("Found input variables with inconsistent numbers of n_features") 269 | elif (dict_init is not None) and (dict_init.shape[0] != n_components): 270 | raise ValueError("Found input variables with inconsistent numbers of n_components") 271 | 272 | code, self.components_, self.error_, self.n_iter_ = _ksvd( 273 | X, n_components, self.transform_n_nonzero_coefs, 274 | max_iter=self.max_iter, tol=self.tol, 275 | dict_init=dict_init, mask=mask, n_jobs=self.n_jobs, method=self.method) 276 | 277 | return self 278 | 279 | def transform(self, X): 280 | """Encode the data as a sparse combination of the dictionary atoms. 281 | Coding method is determined by the object parameter 282 | `transform_algorithm`. 283 | 284 | Parameters 285 | ---------- 286 | X : array of shape (n_samples, n_features) 287 | Test data to be transformed, must have the same number of 288 | features as the data used to train the model. 
289 | 290 | Returns 291 | ------- 292 | code : array, shape (n_samples, n_components) 293 | Transformed data 294 | """ 295 | if self.missing_value is not None: 296 | check_is_fitted(self, 'components_') 297 | 298 | X = check_array(X) 299 | 300 | mask = np.where(X == self.missing_value, 0, 1) 301 | 302 | code = sparse_encode_with_mask( 303 | X, self.components_, mask, algorithm=self.transform_algorithm, 304 | n_nonzero_coefs=self.transform_n_nonzero_coefs, 305 | alpha=self.transform_alpha, n_jobs=self.n_jobs) 306 | 307 | if self.split_sign: 308 | # feature vector is split into a positive and negative side 309 | n_samples, n_features = code.shape 310 | split_code = np.empty((n_samples, 2 * n_features)) 311 | split_code[:, :n_features] = np.maximum(code, 0) 312 | split_code[:, n_features:] = -np.minimum(code, 0) 313 | code = split_code 314 | return code 315 | 316 | else: 317 | return super().transform(X) 318 | -------------------------------------------------------------------------------- /spmimage/linear_model/hmlasso.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | import numpy as np 4 | from sklearn.utils import check_array, check_X_y 5 | from sklearn.utils.validation import FLOAT_DTYPES 6 | from sklearn.base import RegressorMixin 7 | from sklearn.linear_model._base import LinearModel 8 | from sklearn.preprocessing import StandardScaler 9 | from joblib import Parallel, delayed 10 | 11 | logger = getLogger(__name__) 12 | 13 | 14 | def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True, check_input=True): 15 | """ 16 | Apply preprocess to data. 17 | 18 | Parameters 19 | ---------- 20 | X : np.ndarray, shape = (n_samples, n_features) 21 | Data. 22 | 23 | y : np.ndarray, shape = (n_samples, ) or (n_samples, n_targets) 24 | Target. 25 | 26 | fit_intercept : boolean 27 | Whether to calculate the intercept for this model. 28 | If set to False, no intercept will be used in calculations (i.e. data is expected to be centered). 29 | 30 | normalize : boolean, optional (default=False) 31 | This parameter is ignored when fit_intercept is set to False. 32 | If True, the regressors X will be standardized before regression. 33 | 34 | copy : boolean, optional (default=True) 35 | If True, X will be copied; else, it may be overwritten. 36 | 37 | check_input : boolean, optional (default=True) 38 | Allow to bypass several input checking. 39 | 40 | Returns 41 | ------- 42 | X : np.ndarray, shape = (n_samples, n_features) 43 | Processed data. 44 | 45 | y : np.ndarray, shape = (n_samples, ) or (n_samples, n_targets) 46 | Processed target. 47 | 48 | X_offset : np.ndarray, shape = (n_features, ) 49 | Mean of data. 50 | 51 | y_offset : np.ndarray or int 52 | Mean of target. 53 | 54 | X_scale : np.ndarray, shape = (n_features, ) 55 | Standard deviation of data. 
56 | """ 57 | if check_input: 58 | X = check_array(X, copy=copy, accept_sparse='csc', 59 | dtype=FLOAT_DTYPES, force_all_finite='allow-nan') 60 | if copy: 61 | X = X.copy() 62 | 63 | y = np.asarray(y, dtype=X.dtype) 64 | 65 | if fit_intercept: 66 | # standardize 67 | if normalize: 68 | scaler = StandardScaler() 69 | X = scaler.fit_transform(X) 70 | X_offset, X_scale = scaler.mean_, scaler.var_ 71 | X_scale = np.sqrt(X_scale) 72 | X_scale[X_scale == 0] = 1 73 | else: 74 | scaler = StandardScaler(with_std=False) 75 | X = scaler.fit_transform(X) 76 | X_offset = scaler.mean_ 77 | X_scale = np.ones(X.shape[1], dtype=X.dtype) 78 | y_offset = np.average(y, axis=0) 79 | y = y - y_offset 80 | else: 81 | X_offset = np.zeros(X.shape[1], dtype=X.dtype) 82 | X_scale = np.ones(X.shape[1], dtype=X.dtype) 83 | if y.ndim == 1: 84 | y_offset = X.dtype.type(0) 85 | else: 86 | y_offset = np.zeros(y.shape[1], dtype=X.dtype) 87 | 88 | return X, y, X_offset, y_offset, X_scale 89 | 90 | 91 | def _proximal_map(x: np.ndarray, alpha: float) -> np.ndarray: 92 | """ 93 | Proximal mapping. 94 | 95 | Parameters 96 | ---------- 97 | x : np.ndarray, shape = (n, ) 98 | 99 | Returns 100 | ------- 101 | : np.ndarray, shape = (n, ) 102 | """ 103 | return np.sign(x) * np.maximum((np.abs(x) - alpha), 0) 104 | 105 | 106 | def _symm(x: np.ndarray) -> np.ndarray: 107 | """ 108 | Return symmetric matrix. 109 | 110 | Parameters 111 | ---------- 112 | x : np.ndarray, shape = (n, n) 113 | 114 | Returns 115 | ------- 116 | : np.ndarray, shape = (n, n) 117 | """ 118 | return 0.5 * (x + x.T) 119 | 120 | 121 | def _projection(x: np.ndarray, eps: float) -> np.ndarray: 122 | """ 123 | Projection onto PSD. 124 | 125 | Parameters 126 | ---------- 127 | x : np.ndarray, shape = (n, n) 128 | symmetric matrix 129 | 130 | Returns 131 | ------- 132 | : np.ndarray, shape = (n, n) 133 | symmetric positive definite matrix 134 | """ 135 | w, v = np.linalg.eigh(x) 136 | return v.dot(np.diag(np.maximum(w, eps)).dot(v.T)) 137 | 138 | 139 | def _cost_function_qpl1(x: np.ndarray, Q: np.ndarray, p: np.ndarray, alpha: float) -> float: 140 | """ 141 | Return 1 / 2 * x^TQx + p^Tx + alpha * ||x||_1. 142 | """ 143 | return 0.5 * x.dot(Q.dot(x)) + p.dot(x) + alpha * np.sum(np.abs(x)) 144 | 145 | 146 | def _admm_qpl1(Q: np.ndarray, p: np.ndarray, alpha: float, mu: float, tol: float, max_iter: int) -> (np.ndarray, int): 147 | """ 148 | Alternate Direction Multiplier Method (ADMM) for quadratic programming with l1 regularization. 149 | 150 | Minimizes the objective function:: 151 | 1 / 2 * x^TQx + p^Tx + alpha * ||z||_1 152 | 153 | To solve this problem, ADMM uses augmented Lagrangian 154 | 1 / 2 * x^TQx + p^Tx + alpha * ||z||_1 155 | + h^T (x - z) + mu / 2 * ||x - z||^2_2 156 | where h is Lagrange multiplier and mu is tuning parameter. 
157 | """ 158 | n_features = p.shape[0] 159 | 160 | coef_matrix = Q + mu * np.eye(n_features) 161 | inv_matrix = np.linalg.inv(coef_matrix) 162 | alpha_mu = alpha / mu 163 | inv_mu = 1 / mu 164 | x = -np.copy(p) 165 | z = -np.copy(p) 166 | h = np.zeros_like(p) 167 | cost = _cost_function_qpl1(x, Q, p, alpha) 168 | t = 0 169 | for t in range(max_iter): 170 | x = inv_matrix.dot(mu * z - p - h) 171 | z = _proximal_map(x + inv_mu * h, alpha_mu) 172 | h = h + mu * (x - z) 173 | 174 | pre_cost = cost 175 | cost = _cost_function_qpl1(x, Q, p, alpha) 176 | gap = np.abs(cost - pre_cost) 177 | if gap < tol: 178 | break 179 | return x, t 180 | 181 | 182 | def _cost_function_psd(A: np.ndarray, S: np.ndarray, W: np.ndarray) -> float: 183 | """ 184 | Return 1 / 2 * ||W * (A - S)||^2_F. 185 | """ 186 | return 0.5 * np.sum((W * (A - S)) ** 2) 187 | 188 | 189 | def _admm_psd(S: np.ndarray, W: np.ndarray, mu: float, tol: float, max_iter: int, eps: float) -> np.ndarray: 190 | """ 191 | Alternate Direction Multiplier Method (ADMM) for following minimization. 192 | 193 | Minimizes the objective function:: 194 | 1 / 2 * ||W * (A - S)||^2_F 195 | where 196 | A ≥ O 197 | 198 | To solve this problem, ADMM uses augmented Lagrangian 199 | 1 / 2 * ||W * B||^2_F - + mu / 2 * ||A - B - S||^2_F 200 | where Lambda is Lagrange multiplier and mu is tuning parameter. 201 | """ 202 | # initialize 203 | A = np.zeros_like(S) 204 | B = np.zeros_like(S) 205 | Lambda = np.zeros_like(S) 206 | 207 | cost = _cost_function_psd(B, S, W) 208 | weight = W * W / mu + np.eye(S.shape[0]) 209 | for _ in range(max_iter): 210 | A = _projection(_symm(B + S + Lambda), eps) 211 | B = (A - S - Lambda) / weight 212 | Lambda = Lambda - (A - B - S) 213 | 214 | pre_cost = cost 215 | cost = _cost_function_psd(B, S, W) 216 | gap = np.abs(cost - pre_cost) 217 | if gap < tol: 218 | break 219 | return A 220 | 221 | 222 | def _update( 223 | S: np.ndarray, W: np.ndarray, mu_cov: float, tol_cov: float, max_iter_cov: int, eps: float, 224 | p: np.ndarray, alpha: float, mu_coef: float, tol_coef: float, max_iter_coef: int 225 | ) -> (np.ndarray, int): 226 | """ 227 | Update. 228 | """ 229 | Sigma = _admm_psd(S=S, W=W, mu=mu_cov, tol=tol_cov, max_iter=max_iter_cov, eps=eps) 230 | coef = _admm_qpl1(Q=Sigma, p=p, alpha=alpha, mu=mu_coef, tol=tol_coef, max_iter=max_iter_coef) 231 | return coef 232 | 233 | 234 | class HMLasso(LinearModel, RegressorMixin): 235 | """ 236 | Lasso with High Missing Rate. 237 | 238 | Parameters 239 | ---------- 240 | alpha : float, optional (default=1.0) 241 | Constant that multiplies the L1 term. 242 | 243 | mu_coef : float, optional (default=1.0) 244 | Constant that used in augmented Lagrangian function for Lasso. 245 | 246 | mu_cov : float, optional (default=1.0) 247 | Constant that used in augmented Lagrangian function for covariance estimation. 248 | 249 | normalize : boolean, optional (default=False) 250 | This parameter is ignored when fit_intercept is set to False. 251 | If True, the regressors X will be standardized before regression. 252 | 253 | copy_X : boolean, optional (default=True) 254 | If True, X will be copied; else, it may be overwritten. 255 | 256 | tol_coef : float, optional (default=1e-4) 257 | The tolerance for Lasso. 258 | 259 | tol_cov : float, optional (default=1e-4) 260 | The tolerance for covariance estimation. 261 | 262 | max_iter_coef : int, optional (default=1000) 263 | The maximum number of iterations of Lasso. 
264 | 265 | max_iter_cov : int, optional (default=100) 266 | The maximum number of iterations of covariance estimation. 267 | 268 | eps : float, optional (default=1e-8) 269 | small positive value used in projection onto PSD. 270 | 271 | Attributes 272 | ---------- 273 | coef_ : array, shape (n_features,) | (n_targets, n_features) 274 | parameter vector 275 | 276 | intercept_ : float | array, shape (n_targets,) 277 | independent term in decision function. 278 | 279 | n_iter_ : int | array-like, shape (n_targets,) 280 | number of iterations run by admm solver to reach 281 | the specified tolerance. 282 | """ 283 | def __init__( 284 | self, 285 | alpha: float = 1.0, 286 | mu_coef: float = 1.0, 287 | mu_cov: float = 1.0, 288 | normalize: bool = False, 289 | copy_X: bool = True, 290 | tol_coef: float = 1e-4, 291 | tol_cov: float = 1e-4, 292 | max_iter_coef: int = 1000, 293 | max_iter_cov: int = 100, 294 | eps: float = 1e-8 295 | ): 296 | self.alpha = alpha 297 | self.mu_coef = mu_coef 298 | self.mu_cov = mu_cov 299 | self.normalize = normalize 300 | self.copy_X = copy_X 301 | self.tol_coef = tol_coef 302 | self.tol_cov = tol_cov 303 | self.max_iter_coef = max_iter_coef 304 | self.max_iter_cov = max_iter_cov 305 | self.eps = eps 306 | 307 | self.coef_ = None 308 | self.n_iter_ = None 309 | 310 | def _set_intercept(self, X_offset, y_offset, X_scale): 311 | """ 312 | Set the intercept_ 313 | """ 314 | self.coef_ = self.coef_ / X_scale 315 | self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T) 316 | 317 | _preprocess_data = staticmethod(_preprocess_data) 318 | 319 | def fit(self, X: np.ndarray, y: np.ndarray, check_input: bool = False): 320 | """ 321 | fit. 322 | 323 | Parameters 324 | ---------- 325 | X : np.ndarray, shape = (n_samples, n_features) 326 | Missed data. 327 | 328 | y : np.ndarray, shape = (n_samples, ) or (n_samples, n_targets) 329 | Target. 330 | 331 | check_input : boolean, optional (default=False) 332 | Allow to bypass several input checking. 333 | """ 334 | if self.alpha == 0: 335 | logger.warning(""" 336 | With alpha=0, this algorithm does not converge well. 337 | You are advised to use the LinearRegression estimator. 
338 | """) 339 | 340 | if check_input: 341 | X, y = check_X_y( 342 | X, y, accept_sparse='csc', accept_large_sparse=False, 343 | order='F', dtype=[np.float64, np.float32], 344 | copy=False, force_all_finite='allow-nan', 345 | multi_output=True, y_numeric=True 346 | ) 347 | y = check_array( 348 | y, order='F', copy=False, dtype=X.dtype.type, 349 | ensure_2d=False 350 | ) 351 | X, y, X_offset, y_offset, X_scale = self._preprocess_data(X, y, fit_intercept=True, 352 | normalize=self.normalize, 353 | copy=self.copy_X and not check_input) 354 | 355 | n_samples, n_features = X.shape[:2] 356 | 357 | missed = np.isnan(X) 358 | not_missed = (~missed).astype(np.float) 359 | R = not_missed.T.dot(not_missed) 360 | 361 | # centralize 362 | X[missed] = 0 363 | if y.ndim == 1: 364 | y = y[:, np.newaxis] 365 | n_targets = y.shape[1] 366 | 367 | S = X.T.dot(X) / R 368 | rho = X.T.dot(y) / np.diag(R).reshape(-1, 1) 369 | 370 | # Update by columns 371 | w_t = np.empty((n_targets, n_features), dtype=X.dtype) 372 | n_iter_ = np.empty((n_targets,), dtype=int) 373 | if n_targets == 1: 374 | w_t, n_iter_[0] = _update( 375 | S=S, W=R/n_samples, mu_cov=self.mu_cov, tol_cov=self.tol_cov, max_iter_cov=self.max_iter_cov, 376 | eps=self.eps, p=-rho[:, 0], alpha=self.alpha, mu_coef=self.mu_coef, tol_coef=self.tol_coef, 377 | max_iter_coef=self.max_iter_coef 378 | ) 379 | else: 380 | results = Parallel(n_jobs=-1, backend='threading')( 381 | delayed(_update)( 382 | S=S, W=R/n_samples, mu_cov=self.mu_cov, tol_cov=self.tol_cov, max_iter_cov=self.max_iter_cov, 383 | eps=self.eps, p=-rho[:, k], alpha=self.alpha, mu_coef=self.mu_coef, tol_coef=self.tol_coef, 384 | max_iter_coef=self.max_iter_coef 385 | ) 386 | for k in range(n_targets) 387 | ) 388 | for k in range(n_targets): 389 | w_t[k], n_iter_[k] = results[k] 390 | 391 | self.coef_, self.n_iter_ = np.squeeze(w_t), n_iter_.tolist() 392 | 393 | if y.shape[1] == 1: 394 | self.n_iter_ = self.n_iter_[0] 395 | 396 | self._set_intercept(X_offset, y_offset, X_scale) 397 | 398 | # workaround since _set_intercept will cast self.coef_ into X.dtype 399 | self.coef_ = np.asarray(self.coef_, dtype=X.dtype) 400 | 401 | return self 402 | -------------------------------------------------------------------------------- /tests/test_linear_model_admm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from spmimage.linear_model import (LassoADMM, 5 | FusedLassoADMM, 6 | TrendFilteringADMM, 7 | QuadraticTrendFilteringADMM) 8 | from spmimage.linear_model.admm import admm_path 9 | from numpy.testing import assert_array_almost_equal 10 | 11 | 12 | def build_dataset(n_samples=50, n_features=200, n_informative_features=10, 13 | n_targets=1): 14 | """ 15 | build an ill-posed linear regression problem with many noisy features and 16 | comparatively few samples 17 | this is the same dataset builder as in sklearn implementation 18 | (see https://github.com/scikit-learn/scikit-learn/blob/master 19 | /sklearn/linear_model/tests/test_coordinate_descent.py) 20 | """ 21 | random_state = np.random.RandomState(0) 22 | if n_targets > 1: 23 | w = random_state.randn(n_features, n_targets) 24 | else: 25 | w = random_state.randn(n_features) 26 | w[n_informative_features:] = 0.0 27 | X = random_state.randn(n_samples, n_features) 28 | y = np.dot(X, w) 29 | X_test = random_state.randn(n_samples, n_features) 30 | y_test = np.dot(X_test, w) 31 | return X, y, X_test, y_test 32 | 33 | 34 | class TestLassoADMM(unittest.TestCase): 35 | def 
test_lasso_admm_zero(self): 36 | # Check that lasso by admm can handle zero data without crashing 37 | X = [[0], [0], [0]] 38 | y = [0, 0, 0] 39 | clf = LassoADMM(alpha=0.1).fit(X, y) 40 | pred = clf.predict([[1], [2], [3]]) 41 | assert_array_almost_equal(clf.coef_, [0]) 42 | assert_array_almost_equal(pred, [0, 0, 0]) 43 | 44 | def test_lasso_admm_toy(self): 45 | # Test LassoADMM for various parameters of alpha and rho, using 46 | # the same test case as Lasso implementation of sklearn. 47 | # (see https://github.com/scikit-learn/scikit-learn/blob/master 48 | # /sklearn/linear_model/tests/test_coordinate_descent.py) 49 | # Actually, the parameters alpha = 0 should not be allowed. However, 50 | # we test it as a border case. 51 | # WARNING: 52 | # LassoADMM can't check the case which is not converged 53 | # because LassoADMM doesn't check dual gap yet. 54 | # This problem will be fixed in future. 55 | 56 | X = np.array([[-1.], [0.], [1.]]) 57 | Y = [-1, 0, 1] # just a straight line 58 | T = [[2.], [3.], [4.]] # test sample 59 | 60 | clf = LassoADMM(alpha=1e-8) 61 | clf.fit(X, Y) 62 | pred = clf.predict(T) 63 | assert_array_almost_equal(clf.coef_, [1], decimal=3) 64 | assert_array_almost_equal(pred, [2, 3, 4], decimal=3) 65 | 66 | clf = LassoADMM(alpha=0.1) 67 | clf.fit(X, Y) 68 | pred = clf.predict(T) 69 | assert_array_almost_equal(clf.coef_, [.85], decimal=3) 70 | assert_array_almost_equal(pred, [1.7, 2.55, 3.4], decimal=3) 71 | 72 | clf = LassoADMM(alpha=0.5) 73 | clf.fit(X, Y) 74 | pred = clf.predict(T) 75 | assert_array_almost_equal(clf.coef_, [.254], decimal=3) 76 | assert_array_almost_equal(pred, [0.508, 0.762, 1.016], decimal=3) 77 | 78 | clf = LassoADMM(alpha=1) 79 | clf.fit(X, Y) 80 | pred = clf.predict(T) 81 | assert_array_almost_equal(clf.coef_, [.0], decimal=3) 82 | assert_array_almost_equal(pred, [0, 0, 0], decimal=3) 83 | 84 | # this is the same test case as the case alpha=1e-8 85 | # because the default rho parameter equals 1.0 86 | clf = LassoADMM(alpha=1e-8, rho=1.0) 87 | clf.fit(X, Y) 88 | pred = clf.predict(T) 89 | assert_array_almost_equal(clf.coef_, [1], decimal=3) 90 | assert_array_almost_equal(pred, [2, 3, 4], decimal=3) 91 | 92 | clf = LassoADMM(alpha=0.5, rho=0.3, max_iter=50) 93 | clf.fit(X, Y) 94 | pred = clf.predict(T) 95 | assert_array_almost_equal(clf.coef_, [0.249], decimal=3) 96 | assert_array_almost_equal(pred, [0.498, 0.746, 0.995], decimal=3) 97 | 98 | clf = LassoADMM(alpha=0.5, rho=0.5) 99 | clf.fit(X, Y) 100 | pred = clf.predict(T) 101 | assert_array_almost_equal(clf.coef_, [0.249], decimal=3) 102 | assert_array_almost_equal(pred, [0.498, 0.746, 0.995], decimal=3) 103 | 104 | def test_lasso_admm_toy_multi(self): 105 | # for issue #39 106 | X = np.eye(4) 107 | y = np.array([[1, 1, 0], 108 | [1, 0, 1], 109 | [0, 1, 0], 110 | [0, 0, 1]]) 111 | 112 | clf = LassoADMM(alpha=0.05, tol=1e-8).fit(X, y) 113 | assert_array_almost_equal(clf.coef_[0], [0.3, 0.3, -0.3, -0.3], 114 | decimal=3) 115 | assert_array_almost_equal(clf.coef_[1], [0.3, -0.3, 0.3, -0.3], 116 | decimal=3) 117 | assert_array_almost_equal(clf.coef_[2], [-0.3, 0.3, -0.3, 0.3], 118 | decimal=3) 119 | 120 | def test_lasso_admm(self): 121 | X, y, X_test, y_test = build_dataset() 122 | 123 | clf = LassoADMM(alpha=0.05, tol=1e-8).fit(X, y) 124 | self.assertGreater(clf.score(X_test, y_test), 0.99) 125 | self.assertLess(clf.n_iter_, 150) 126 | 127 | clf = LassoADMM(alpha=0.05, fit_intercept=False).fit(X, y) 128 | self.assertGreater(clf.score(X_test, y_test), 0.99) 129 | 130 | # normalize doesn't seem to 
work well 131 | clf = LassoADMM(alpha=0.144, rho=0.1, normalize=True).fit(X, y) 132 | self.assertGreater(clf.score(X_test, y_test), 0.60) 133 | 134 | def test_lasso_admm_multi(self): 135 | X, y, X_test, y_test = build_dataset(n_targets=3) 136 | 137 | clf = LassoADMM(alpha=0.05, tol=1e-8).fit(X, y) 138 | self.assertGreater(clf.score(X_test, y_test), 0.99) 139 | self.assertLess(clf.n_iter_[0], 150) 140 | 141 | clf = LassoADMM(alpha=0.05, fit_intercept=False).fit(X, y) 142 | self.assertGreater(clf.score(X_test, y_test), 0.99) 143 | 144 | # normalize doesn't seem to work well 145 | clf = LassoADMM(alpha=0.144, rho=0.1, normalize=True).fit(X, y) 146 | self.assertGreater(clf.score(X_test, y_test), 0.60) 147 | 148 | 149 | class TestFusedLassoADMM(unittest.TestCase): 150 | def setUp(self): 151 | np.random.seed(0) 152 | self.X = np.random.normal(0.0, 1.0, (8, 4)) 153 | 154 | def test_fused_lasso_alpha(self): 155 | beta = np.array([4, 4, 0, 0]) 156 | y = self.X.dot(beta) 157 | T = np.array([[5., 6., 7., 8.], 158 | [9., 10., 11., 12.], 159 | [13., 14., 15., 16.]]) # test sample 160 | 161 | # small regularization parameter 162 | clf = FusedLassoADMM(alpha=1e-8).fit(self.X, y) 163 | actual = clf.predict(T) 164 | assert_array_almost_equal(clf.coef_, [3.998, 3.998, -0.022, 0.003], 165 | decimal=3) 166 | assert_array_almost_equal(actual, [43.862, 75.773, 107.683], decimal=3) 167 | self.assertLess(clf.n_iter_, 100) 168 | 169 | # default 170 | clf = FusedLassoADMM(alpha=1).fit(self.X, y) 171 | actual = clf.predict(T) 172 | assert_array_almost_equal(clf.coef_, [2.428, 1.506, 0., 0.], decimal=3) 173 | assert_array_almost_equal(actual, [22.69, 38.427, 54.164], decimal=3) 174 | self.assertLess(clf.n_iter_, 100) 175 | 176 | # all coefs will be zero 177 | clf = FusedLassoADMM(alpha=10).fit(self.X, y) 178 | actual = clf.predict(T) 179 | assert_array_almost_equal(clf.coef_, [0, 0, 0, 0], decimal=3) 180 | assert_array_almost_equal(actual, [3.724, 3.722, 3.72], decimal=3) 181 | 182 | self.assertLess(clf.n_iter_, 20) 183 | 184 | def test_fused_lasso_coef(self): 185 | beta = np.array([4, 4, 0, 0]) 186 | y = self.X.dot(beta) 187 | T = np.array([[5., 6., 7., 8.], 188 | [9., 10., 11., 12.], 189 | [13., 14., 15., 16.]]) # test sample 190 | 191 | # small trend_coef 192 | clf = FusedLassoADMM(alpha=1e-8, trend_coef=1e-4).fit(self.X, y) 193 | actual = clf.predict(T) 194 | assert_array_almost_equal(clf.coef_, [3.999, 3.999, -0.007, 0.001], 195 | decimal=3) 196 | assert_array_almost_equal(actual, [43.95, 75.916, 107.883], decimal=3) 197 | self.assertLess(clf.n_iter_, 100) 198 | 199 | # large trend_coef 200 | clf = FusedLassoADMM(alpha=1e-8, trend_coef=10).fit(self.X, y) 201 | actual = clf.predict(T) 202 | assert_array_almost_equal(clf.coef_, [3.938, 3.755, -0.079, 0.141], 203 | decimal=3) 204 | assert_array_almost_equal(actual, [42.862, 73.885, 104.908], decimal=3) 205 | self.assertLess(clf.n_iter_, clf.max_iter) 206 | 207 | # small sparse_coef 208 | clf = FusedLassoADMM(alpha=1e-8, sparse_coef=1e-4).fit(self.X, y) 209 | actual = clf.predict(T) 210 | assert_array_almost_equal(clf.coef_, [3.999, 3.999, -0.011, 0.002], 211 | decimal=3) 212 | assert_array_almost_equal(actual, [43.931, 75.885, 107.839], decimal=3) 213 | self.assertLess(clf.n_iter_, 100) 214 | 215 | # large sparse_coef 216 | clf = FusedLassoADMM(alpha=1e-8, sparse_coef=10).fit(self.X, y) 217 | actual = clf.predict(T) 218 | assert_array_almost_equal(clf.coef_, [3.913, 3.837, -0.349, 0.113], 219 | decimal=3) 220 | assert_array_almost_equal(actual, [41.265, 71.32, 
101.374], decimal=3) 221 | self.assertLess(clf.n_iter_, clf.max_iter) 222 | 223 | def test_simple_lasso(self): 224 | X, y, X_test, y_test = build_dataset() 225 | 226 | # check if FusedLasso generates the same result of LassoAdmm 227 | # when trend_coef is zero 228 | clf = FusedLassoADMM(alpha=0.05, sparse_coef=1, 229 | trend_coef=0, tol=1e-8).fit(X, y) 230 | self.assertGreater(clf.score(X_test, y_test), 0.99) 231 | self.assertLess(clf.n_iter_, 150) 232 | 233 | 234 | class TestTrendFilteringADMM(unittest.TestCase): 235 | def setUp(self): 236 | np.random.seed(0) 237 | self.X = np.random.normal(0.0, 1.0, (8, 5)) 238 | 239 | def test_generate_transform_matrix(self): 240 | D = np.array([[-1, 2, -1, 0, 0], 241 | [0, -1, 2, -1, 0], 242 | [0, 0, -1, 2, -1]]) 243 | clf = TrendFilteringADMM(sparse_coef=1, trend_coef=0) 244 | assert_array_almost_equal(np.eye(5), clf.generate_transform_matrix(5)) 245 | 246 | clf = TrendFilteringADMM(sparse_coef=0, trend_coef=1) 247 | assert_array_almost_equal(D, clf.generate_transform_matrix(5)) 248 | 249 | clf = TrendFilteringADMM(sparse_coef=1, trend_coef=1) 250 | assert_array_almost_equal(np.vstack([np.eye(5), D]), 251 | clf.generate_transform_matrix(5)) 252 | 253 | def test_trend_filtering(self): 254 | beta = np.array([0., 10., 20., 10., 0.]) 255 | y = self.X.dot(beta) 256 | 257 | # small regularization parameter 258 | clf = TrendFilteringADMM(alpha=1e-8).fit(self.X, y) 259 | assert_array_almost_equal(np.round(clf.coef_), [0, 10, 20, 10, 0]) 260 | 261 | # default 262 | clf = TrendFilteringADMM(alpha=0.01).fit(self.X, y) 263 | assert_array_almost_equal(clf.coef_, [0.015, 10.017, 19.932, 264 | 9.989, 0.033], decimal=3) 265 | 266 | # all coefs will be zero 267 | clf = TrendFilteringADMM(alpha=1e5).fit(self.X, y) 268 | assert_array_almost_equal(clf.coef_, [0., 0., 0., 0., 0.], decimal=1) 269 | 270 | 271 | class TestQuadraticTrendFilteringADMM(unittest.TestCase): 272 | def setUp(self): 273 | np.random.seed(0) 274 | self.X = np.random.normal(0.0, 1.0, (8, 5)) 275 | 276 | def test_generate_transform_matrix(self): 277 | D = np.array([[1., -3., 3., -1., 0.], 278 | [0., 1., -3., 3., -1.]]) 279 | 280 | clf = QuadraticTrendFilteringADMM(sparse_coef=1, trend_coef=0) 281 | assert_array_almost_equal(np.eye(5), clf.generate_transform_matrix(5)) 282 | 283 | clf = QuadraticTrendFilteringADMM(sparse_coef=0, trend_coef=1) 284 | assert_array_almost_equal(D, clf.generate_transform_matrix(5)) 285 | 286 | clf = QuadraticTrendFilteringADMM(sparse_coef=1, trend_coef=1) 287 | assert_array_almost_equal(np.vstack([np.eye(5), D]), 288 | clf.generate_transform_matrix(5)) 289 | 290 | # boundary check 291 | assert_array_almost_equal(np.eye(1), clf.generate_transform_matrix(1)) 292 | assert_array_almost_equal(np.eye(2), clf.generate_transform_matrix(2)) 293 | assert_array_almost_equal(np.eye(3), clf.generate_transform_matrix(3)) 294 | 295 | def test_trend_filtering(self): 296 | beta = np.array([0., 1., 2., 1., 0.]) 297 | y = self.X.dot(beta) 298 | 299 | # small regularization parameter 300 | clf = QuadraticTrendFilteringADMM(alpha=1e-8).fit(self.X, y) 301 | assert_array_almost_equal(np.round(clf.coef_), [0, 1, 2, 1, 0]) 302 | 303 | # all coefs will be zero 304 | clf = QuadraticTrendFilteringADMM(alpha=1e5).fit(self.X, y) 305 | assert_array_almost_equal(clf.coef_, [0., 0., 0., 0., 0.], decimal=1) 306 | 307 | 308 | class TestAdmmPath(unittest.TestCase): 309 | def test_admm_path_alphas(self): 310 | # check if input alphas are sorted 311 | X, y, X_test, y_test = build_dataset() 312 | 313 | alphas = 
[0.1, 0.3, 0.5, -0.1, -0.2] 314 | actual_alphas, _, _ = admm_path(X, y, alphas=alphas) 315 | assert_array_almost_equal(actual_alphas, [0.5, 0.3, 0.1, -0.1, -0.2]) 316 | 317 | def test_admm_path_coefs(self): 318 | # check if we can get correct coefs 319 | 320 | X = np.array([[-1.], [0.], [1.]]) 321 | y = np.array([-1, 0, 1]) # just a straight line 322 | 323 | _, coefs_actual, _ = admm_path(X, y, 324 | alphas=[1e-8, 0.1, 0.5, 1], rho=1.0) 325 | assert_array_almost_equal(coefs_actual[0], [-1.31072000e-04, 326 | 2.53888000e-01, 327 | 8.49673483e-01, 328 | 9.99738771e-01], decimal=3) 329 | 330 | 331 | if __name__ == '__main__': 332 | unittest.main() 333 | -------------------------------------------------------------------------------- /examples/group-lasso.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Joint $\\ell_{2,1}$-norm minimization\n", 8 | "$$\\min_W\\quad\\frac12\\|X - WD\\|_F^2 + \\alpha\\|W\\|_{2,1}$$" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [ 16 | { 17 | "name": "stderr", 18 | "output_type": "stream", 19 | "text": [ 20 | "\n", 21 | "Bad key \"backend in file /Users/masui/.matplotlib/matplotlibrc, line 1 ('\"backend : Tkagg\"')\n", 22 | "You probably need to get an updated matplotlibrc file from\n", 23 | "https://github.com/matplotlib/matplotlib/blob/v3.4.3/matplotlibrc.template\n", 24 | "or from the matplotlib source distribution\n" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "import numpy as np\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "import seaborn as sns" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "## Toy data" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "np.random.seed(0)\n", 48 | "n_samples = 60\n", 49 | "n_features = 30\n", 50 | "n_components = 50\n", 51 | "transform_n_nonzero_coefs = 5\n", 52 | "\n", 53 | "dictionary = np.random.randn(n_components, n_features)\n", 54 | "W = np.random.rand(n_samples, n_components) * 2 - 1\n", 55 | "for i in range(n_components - transform_n_nonzero_coefs):\n", 56 | " W[:, -(i+1)] = 0\n", 57 | "X = W @ dictionary + np.random.randn(n_samples, n_features) * 0.1" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "$$X \\simeq W D$$" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "## Efficient and robust feature selection via joint ℓ2, 1-norms minimization\n", 72 | "[Nie, F., Huang, H., Cai, X., & Ding, C. H. (2010). Efficient and robust feature selection via joint ℓ2, 1-norms minimization. In Advances in neural information processing systems (pp. 
1813-1821).](https://papers.nips.cc/paper/3988-efficient-and-robust-feature-selection-via-joint-l21-norms-minimization.pdf)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "def sparse_code_via_l21(X, dictionary, alpha=1.0, max_iter=30):\n", 82 | " m, n = dictionary.shape\n", 83 | " A = np.vstack((dictionary, alpha * np.identity(n)))\n", 84 | " D_inv = np.identity(m+n)\n", 85 | " for _ in range(max_iter):\n", 86 | " DA = D_inv @ A\n", 87 | " ADA = np.linalg.inv(A.T @ DA)\n", 88 | " U = X @ ADA @ DA.T\n", 89 | " D_inv = 2 * np.diag(np.linalg.norm(U, axis=0))\n", 90 | " return U[:, :m]" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 4, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "CPU times: user 20.3 ms, sys: 7.35 ms, total: 27.6 ms\n", 103 | "Wall time: 7.8 ms\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "%time coef1 = sparse_code_via_l21(X, dictionary, alpha=.1)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "metadata": { 115 | "scrolled": true 116 | }, 117 | "outputs": [ 118 | { 119 | "data": { 120 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdgAAADyCAYAAADumdR9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyd0lEQVR4nO2deZhcVZn/v2919o1AWJNAwr7KMrLI4DgozAzyQ0ARBXEQlYkboP5GEQaQRcfBcWZABxSjoDIugAoSxiCLgoCskU22LISEdFaykpVOd73zR1XOPbe6zu1TdW91p9/+fp6nnn6rznLPvdXfOve+55z3iKqCEEIIIcVS6usGEEIIIRZhB0sIIYS0AHawhBBCSAtgB0sIIYS0AHawhBBCSAtgB0sIIYS0AHawvYSITBYRFZFBfXDseSJyfG8fl5CtGRH5GxGZ2dftqIeIHCsi7X3dDpIPUx2siJwhIk+IyHoRWVa1Pysi0tdty0JE1nmvsohs9N6f1WBdPxaRr7eqrQMFEdlVRB4QkZdE5EUR+XydPCIi3xGROSLyvIj8VV+0dSBSvWncWKOd63oooyKy15b3qvqwqu7bovZRh1sRfaXnXn+aahUi8s8ALgTwOQD3AFgH4FAAXwJwI4C36pRpU9WuXmxmXVR11BZbROYBOFdV76/NJyKDVLWzN9s2gOkE8M+q+rSIjAbwZxG5T1Vf8vK8F8De1ddRAL5X/Ut6h/fV0wkhdegTPZt4ghWRbQBcBeCzqvorVV2rFZ5R1bNU9a1qvh+LyPdEZLqIrAfwbhHZX0QeFJHV1Tubk716HxSRc73354jII957FZFPi8jsavnrtzwti0ibiPyHiCwXkbkA/l8T53WsiLSLyFdEZAmAH9W2wWvHXiIyBcBZAC6s3tHf5WU7tHpXtkZEbhWRYY22ZyChqotV9emqvRbAywAm1GQ7BcDN1f+1xwGMFZFdermpxKOqgz9W/8+Xi8it1c8fqmZ5rqqND9e6YatPxV+u6mS9iNwoIjuJyN0islZE7heRbb38vxSRJdVjPSQiB1Y/r6tDERkvIr8WkTdE5DURucCra3j192mViLwE4IjWX62BQ1/p2coT7NEAhgK4MyLvRwCcCOAkACMBPAPgJgB/D+CdAO4UkcNVNXZs5iRUxDAGwJ8B3AXgdwD+qZp2GID1AH4dezI17AxgOwCTULkh+nAoo6pOFZG/BtCuqpfWJH8IwAkANgH4E4BzANzQZJu2Xl78dXTsTznog58CMMX7aKqqTu2WT2QyKt/jEzVJEwAs8N63Vz9bHNsGUjhfA3AvgHcDGALgcABQ1XeJiAI4RFXnAJUb2DrlTwPwd6j8Nj6Dyvf+SVR+kKcDuADAldW8dwP4BIAOAN8E8DMAh9bToYiUUPltuBPAmQAmArhfRGaq6j0ALgewZ/U1slr3wKYFWgZ6V89WOtjtASz33aci8iiAA1DpeP9BVbfcwd6pqn+q5jkUwCgAV6tqGcAfROR/URHAFZHHvlpVVwNYLSIPoOKW/h0qHdq1qrqgeqx/A3BsE+dWBnC59xTeRBUAgO+o6qJqHXdV22kO7Yr3+FcFWFeEWxCRUajcHH1BVd/M1zpSML8REX/I5MsANqNyMzpeVdsBPFK3ZJj/VtWlACAiDwNYpqrPVN/fAeC4LRlV9aYttohcAWCViGyjqmvq1HsEgB1U9arq+7ki8gMAZ6AypPUhVDxwKwGsFJHvAPhqg203RdFaBnpfz1Y62BUAtvfHKFX1rwGg6gLyXeH+Hcp4AAuqnesW5qO76yCLJZ69AZUO29VdU28zvKGqm5os61PbzvEF1Ln10VXcELWIDEZFjD9T1dvrZFkIYFfv/cTqZ6R3OLV2DLZ68/g1AE+KyCoA/+l3hBEs9eyNdd6Pqh6nDcC/AjgdwA6o3AgDlZv9eh3sJADjRWS191kbgIerdlG/F3YoUMtA3+jZxBgsgMdQmcR0SkRe3+2wCMCuVffNFnZDclHXAxjhpe3cQJsWI/1l7dZAWZ9aN0mqTSJS26YBvT2SljujX1lUx9JvBPCyqv5XINs0AGdLhXcAWKOqdA/3Iaq6RFX/SVXHA/gUgO+KN3O4QD6Cyu/N8QC2ATC5+vkWF1OtDhcAeE1Vx3qv0ap6YjW9qN8LMxSlZaDv9GziCVZVV4vIlaiISVBxuawHcDAq4xkhnkDlae5CEflPAMcAeB+SCQbPAviAiPwQlTvMTyJ9R5vFbQAuqLqc1wO4qKGTCvMcgAOr7u1X0N2VvRTAHgUdq//RgFupB44B
... (base64 PNG data elided: rendered figure with heatmap panels 'Ground Truth' (W) and 'Estimated' (coef1)) ...\n",
121 | "text/plain": [
122 | "<Figure size 576x288 with 4 Axes>"
" 123 | ] 124 | }, 125 | "metadata": { 126 | "needs_background": "light" 127 | }, 128 | "output_type": "display_data" 129 | } 130 | ], 131 | "source": [ 132 | "plt.figure(figsize=(8, 4))\n", 133 | "plt.subplot(1, 2, 1)\n", 134 | "sns.heatmap(W, xticklabels=False, yticklabels=False, annot=False, square=True, vmin=-2, vmax=2, center=0)\n", 135 | "plt.title('Ground Truth')\n", 136 | "plt.subplot(1, 2, 2)\n", 137 | "sns.heatmap(coef1, xticklabels=False, yticklabels=False, annot=False, square=True, vmin=-2, vmax=2, center=0)\n", 138 | "plt.title('Estimated')\n", 139 | "plt.show()" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "## ADMM\n", 147 | "\\begin{align}\n", 148 | " \\min_{W, Y} &\\quad \\frac1{2N}\\|X - WD\\|_F^2 + \\alpha \\sum_i\\|y_i\\|_2\\\\\n", 149 | " \\text{s.t.} &\\quad W - Y = 0\n", 150 | "\\end{align}\n", 151 | "\n", 152 | "Augmented Lagrangian Function:\n", 153 | "\\begin{align}\n", 154 | " \\mathcal{L}(W, Y, U) = \\frac1{2N}\\left\\|X - WD\\right\\|_F^2 + \\alpha \\sum_i\\|y_i\\|_2 + U^\\top(W - Y) + \\frac\\tau2\\|W - Y\\|_F^2\n", 155 | "\\end{align}\n", 156 | "\n", 157 | "Optimal Condition:\n", 158 | "\\begin{align}\n", 159 | " \\frac{\\partial \\mathcal{L}}{\\partial W} = \\frac1N(WD - X)D^\\top + U + \\tau(W - Y) = 0\n", 160 | "\\end{align}\n", 161 | "\n", 162 | "\\begin{align}\n", 163 | " \\partial_{y_i} \\mathcal{L} = \\alpha\\partial \\|y_i\\|_2 + u_i + \\tau(y_i - w_i) \\ni 0\n", 164 | "\\end{align}\n", 165 | "\n", 166 | "ADMM Algorithm:\n", 167 | "\\begin{align}\n", 168 | " W^{t+1} &= \\left(\\frac1NXD^\\top + \\tau Y^{t} - U^{t}\\right)\\left(\\frac1NDD^\\top + \\tau I\\right)^{-1}\\\\\n", 169 | " y_i^{t+1} &= S\\left(w_i^{t+1} + \\frac1\\tau u_i^t, \\frac{\\alpha}{\\tau}\\right)\\\\\n", 170 | " U^{t+1} &= U^t + \\tau\\left(W^{t+1} - Y^{t+1}\\right)\n", 171 | "\\end{align}\n", 172 | "\n", 173 | "where the shrinkage mapping $S$ is defined:\n", 174 | "\\begin{align}\n", 175 | " S(x, \\alpha) = \\max\\left\\{1 - \\frac{\\alpha}{\\|x\\|_2}, 0\\right\\}x\n", 176 | "\\end{align}" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 6, 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "name": "stdout", 186 | "output_type": "stream", 187 | "text": [ 188 | "CPU times: user 6.86 ms, sys: 817 µs, total: 7.67 ms\n", 189 | "Wall time: 2.17 ms\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "import sys\n", 195 | "sys.path = ['..'] + sys.path\n", 196 | "from spmimage.decomposition import sparse_encode_with_l21_norm\n", 197 | "\n", 198 | "%time coef2 = sparse_encode_with_l21_norm(X, dictionary, alpha=.1)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 7, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAdgAAADyCAYAAADumdR9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAmZUlEQVR4nO3deZwcZZkH8N/Tk8lkcocEyH0SSLgMCmFZXOUSgYUERQRkRVSIqCjigWBAwuEadlEQQSEcAqvIoSBBgxDQADECCUlAQq7JPSH3OcnMJJnpZ//oyltVnZ6e7umq7plnft/Ppz883VXV9XaYX79Vb1VXiaqCiIiIopUodQOIiIgsYgdLREQUA3awREREMWAHS0REFAN2sERERDFgB0tERBQDdrBFIiJDRURFpEMJ1r1SRM4o9nqJWjMR+Q8RWVzqdmQiIqeISHWp20GFMdXBisjFIvKWiOwWkY1e/Q0RkVK3LRsR2RV4JEWkLvD80jzf61ERuT2utrYXIjJIRP4uIh+IyAIRuSbDPCIi94hIlYi8JyIfLUVb2yNvo7EuLTv3NrOMishh+5+r6huqekRM7WMOW5FS5bnoe1NxEZHvAbgOwDcBvARgF4AxAL4P4GEAezIsU6aqjUVsZkaq2nV/LSIrAVyhqq+kzyciHVS1oZhta8caAHxPVeeKSDcA74jIdFX9IDDP2QBGeo8TAfza+y8Vx3mZckKUQUnybGIPVkR6ALgVwDdU9Q+qWqMp81T1UlXd4833qIj8WkSmichuAKeKyGgRmSEi270tm3GB950hIlcEnl8uIjMDz1VErhKRpd7y9+3fWxaRMhG5U0Q2i8hyAP/Zgs91iohUi8gPRWQ9gN+ktyHQjsNEZAKASwFc523RvxCYbYy3VbZDRJ4SkU75tqc9UdV1qjrXq2sALAQwIG228QAe9/7W3gTQU0T6FbmpFODl4DXv73yziDzlvf66N8u7XjYuSh+G9faKf+DlZLeIPCwih4rIiyJSIyKviEivwPzPiMh6b12vi8hR3usZcygi/UXkjyKySURWiMi3A+9V6X0/bRORDwCcEP+/VvtRqjxb2YM9CUAFgOdzmPcLAM4BcC6ALgDmAXgEwJkAPg7geRE5XlVzPTZzLlJh6A7gHQAvAPgrgCu9accB2A3gj7l+mDR9ARwEYAhSG0QXNTWjqk4RkX8HUK2qN6ZN/jyAswDUA/gHgMsB3N/CNrVeC/6Y87U/5ejPfQ3AhMBLU1R1ygHziQxF6v/jW2mTBgBYE3he7b22Ltc2UORuA/AygFMBdARwPACo6idERAF8RFWrgNQGbIblLwDwKaS+G+ch9f/9q0h9IU8D8G0At3jzvgjgKwD2ArgDwO8AjMmUQxFJIPXd8DyASwAMBPCKiCxW1ZcA3AxghPfo4r13+xZDloHi5tlKB9sHwObg8KmIzAJwJFId76dVdf8W7POq+g9vnjEAugKYrKpJAH8TkT8jFYBJOa57sqpuB7BdRP6O1LD0X5Hq0O5W1TXeun4K4JQWfLYkgJsDe+EteAsAwD2q+qH3Hi947TRHG3Mf8fcCmDGE+4lIV6Q2jr6jqjsLax1F7E8iEjxk8gMA+5DaGO2vqtUAZmZcsmm/VNUNACAibwDYqKrzvOfPATh9/4yq+sj+WkQmAdgmIj1UdUeG9z0BwMGqeqv3fLmIPAjgYqQOaX0eqRG4rQC2isg9AH6cZ9tNiTrLQPHzbKWD3QKgT/AYpar+OwB4Q0DBofDgFkp/AGu8znW/VThw6CCb9YG6FqkO27132vu2xCZVrW/hskHp7ewfwXu2Po3RHaIWkXKkwvg7VX02wyxrAQwKPB/ovUbFcX76MVhv4/E2AG+LyDYAPwt2hDnYEKjrMjzv6q2nDMBPAFwI4GCkNoSB1MZ+pg52CID+IrI98FoZgDe8OqrvCzsizDJQmjybOAYL4J9IncQ0Pod5g8MOHwIY5A3f7DcY/j/qbgCdA9P65tGmdQj/zxqcx7JB6cMkoTaJSHqb2vXtkTTZkPMjG+9Y+sMAFqrqz5uYbSqAyyTl3wDsUFUOD5eQqq5X1StVtT+ArwH4lQTOHI7QF5D6vjkDQA8AQ73X9w8xpedwDYAVqtoz8Oimqud406P6vjAjqiwDpcuziT1YVd0uIrcgFSZBashlN4BjkTqe0ZS3kNqbu05EfgbgZADnwT/BYD6Az4rIQ0htYX4V4S3abJ4G8G1vyHk3gOvz+lBNexfAUd7w9iIcOJS9AcDwiNbV9uQxrNSMkwF8EcC/RGS+99qP4H3xqer9SB2TOwdAFVJ/R1+OauXUMiJyIYB/esPD25Dq6PbvXe7PRlUEq+qG1Eb9FqQ2eP87bXp6Dt8GUCMiPwRwD1LHbUcDqFTV2Uh9X9wgIm8h9Z31rQja2LZFl2WgRHk20cECgKr+j4isReqnOo8j1aktB/BDALOaWGaviJwH4FcAbkBqz/UyVV3kzXIXUp3tBgDvIXUSQ64XbHgQwOFIdYg7AdwJ4LT8P9kBbV4iIrcCeAWpIasbkNpS3+9hAM94Q1EzVPX8QtfZlujeKEbTAVWdCX9vpKl5FKmfhVFpvCAiwW/h6QCWArhbUr8s2ADgGlVd7k2fBOAxEalE6oSYjQWs+3EAn0bqO2MrgJsAfD0w/YAcisi5AH4GYAVS54YsBrD/ZMRbkDrpcAVSI2u/AXDAbzXbk6iyDJQuz8IbrpMlDa//Muc/6A6f+FarvgAJUXtmIctm9mCJAEQ9rEREpWIgy+xgyRSN+MxDIioNC1lmB0u2GAglEcFEltnBkimabPvDSkRkI8vsYMmUKM88JKLSsZDlrB3saVe96c7iWvvixaFpI4692dWVS3/p6vrAVseCPdsLbiARAKxctSqnswQtbPXG5cwvv+7yvPpvl7nXhx31Q1d3Xv5gaJltDf6X3PJ9u+NsHrUT7SnL3IMlWwwctyEimMgyO1gyxcKZh0RkI8tZO9hTZv3I1U92HhiatmndX13d57ifuPqY9ye7mkPEVHQGQhmXk97xD+usqfTv9bDhQ//OaP1Gh6/oefTSe1zNIWIqKgNZ5h4smaL79pS6CUQUAQtZZgdLpuRzD0kiar0sZDlrBzvzY5P8GefcHpq2b/McVy/cFLgxfMdu0bSMqAUsHLeJy+zR33O1vOcfymnY/r6rF8z9QWiZ3R06g6gULGSZe7BkSw73hiSiNsBAltnBkikWhpWIyEaWs3awn5l7m6vvrA3f2L2yy2BX31vp/2548tYPo2obUf4MhDIupy6629UP7N3p6rKuw1x9R3kitMxd21fF3i6ijAxkmXuwZIrubftnHhKRjSyzgyVTLAwrEZGNLCean4Wo7dBkY86P5ojIIyKyUUTeb2L6KSKyQ0Tme48fR/6BiNopC1nOugf7hxNvcXWv6ZeFpg1p3OHqmxuHuPrGnn6ffe2WpQU3kCgv0W71PgrgXgCPZ5nnDVU9N8qVxmX6R/3vjJ4zvuLqgfu2unpyWf/QMtf08LM9aduSGFtHlMZAljlETKZEOaykqq+LyNDI3pCIcmYhyxwiJlO0MZ
nzIyInici7IvKiiBwV1ZsStXcWspx1D/brs/0hpXsGnBOatnvYeFdvf90fbqrrdXgU7SJqEd27L+d5RWQCgAmBl6ao6pQ8VjcXwBBV3SUi5wD4E4CReSxfVF+c5//s7oG+p7t6z9DzXb1j1lWhZcqYZyoRC1nmEDGZoo3a/Ez7500FMJ8Qpi+/M1BPE5FfiUgfVd3c0vckohQLWWYHS6ZEOFzULBHpC2CDqqqIjEXqkMuWojWAyDALWc7awU6Wnq7es+bPoWmrq6cFGtfR1fMb9xbaJqIWizKUIvJ7AKcA6CMi1QBuBlAOAKp6P4DPAfi6iDQAqANwsarmvtldZHdLV1fXffiyq6vXvepqSVSElpnbUB9/w4gysJBl7sGSKZqMrn9T1UuamX4vUqf+E1HELGSZHSyZktzbancgiSgPFrKctYPdvmW2q5OJ8H0he/cc5eotW+e6+uBEeVRtI8pbPidGtDc1W+e5uiFR6eqDevgnS27d9m5omV7MM5WIhSxzD5ZMSRbvvAgiipGFLLODJVO07V8fnIhgI8vsYMkUC6EkIhtZztrBdut5tKtr6zaGpn1Hd7l6amUfV/96Oy/wT6VjYVgpLl16+ld/21233tVXaZ2rn6/oGVrmsR3LYm8XUSYWssw9WDKlcZ+UuglEFAELWWYHS6ZY2OolIhtZztrBHnTQca7+yIbpoWlPjP6uqxf+85uuHn72U65e/uJFBTeQKB8WjtvEpXfvE1x99Hr/SmzPD7/S1QvmXR9a5ohTHnX14hmXx9Y2onQWssw9WDIlmWz7w0pEZCPL7GDJlKSBrV4ispHlrB1s2dAzXf3ByqdC03Tpo64+asztrl48fQKISsXCVm9cyg890dVLVzzh6r0rfuvq0Uf+ILRM1azvgqgULGSZe7BkSmND2w8lEdnIMjtYMsXCVi8R2chy1g52yd8vd3VFx96haTc2bHX1lE0zXV3W4G4Mj4ZCW0eUJwuhjEvwbP9OFQe7+ruBC008vm1+aJnE3m2xt4soEwtZ5h4smdJoIJREZCPL7GDJFAtbvURkI8vsYMmUhsa2H0oispHlrB3sN3sMd/XfD/tGaNpNc6519eF9TnO1dDrUn6l+PYiKqVGjC6WIPALgXAAbVfXoDNMFwC8AnAOgFsDlqjo3sgZELJTnIZe5+qfv3uTqEb1PDi2jFYFzL/Zsjq9xRGksZDlR6BsQtSbJpOT8yMGjAM7KMv1sACO9xwQAvy74AxARABtZ5hAxmRLlVq+qvi4iQ7PMMh7A46qqAN4UkZ4i0k9V10XWCKJ2ykKWs3awz+xc5eorqsId+ujuQ1z9x6X3u7q7lLm6vpCWEbVAMsJQ5mAAgDWB59Xea62ygw3m+dJV/+fqI7oNdvVzy34TWuYgKXf1VhAVj4Uscw+WTMlnq1dEJiA1HLTfFFWdEnmjiChvFrLMDpZM2ZfHqf1eAAsJ4VoAgwLPB3qvEVGBLGQ5awd7QpdDXF2V3BuaNrZDhaundjzI1bp3R6FtImqxKI/b5GAqgKtF5EkAJwLY0ZqPvx5T6Z8RvLBht6tPKO/q6r9U9Akts28Pr+REpWEhy9yDJVMaNbr3EpHfAzgFQB8RqQZwM4ByAFDV+wFMQ+q0/iqkTu3/cnRrJ2rfLGSZHSyZ0ohIzzy8pJnpCuCb2eYhopaxkOWsHezupH+5/iX9wj8herbqIVd/tvtQV88dNN7VNWlnJBLFLcqtXmvq1b+D9Zr+fk5fCfwK4Jxug0LLvNfXv4hMzaqnY2wdUZiFLHMPlkxpbH4WImoDLGSZHSyZsjfCYSUiKh0LWWYHS6Y0qoFxJSIykeWsHez7nfzjMV32hq/j8sku/Vw9Lelf0nhwlyEgKhULw0pxWVjhZ7bTrmWu/niXvq5+Je0fcHBl/9jbRZSJhSxzD5ZMsRBKIrKRZXawZIqFUBKRjSxn7WBP3OtfyOL1veEruqxvqHN17T7/6k1fqf6dq79fcPOI8rMXbf+4TVzG7vPv5/qPwJWcNuyrdfWexvDX2oUbXnT1bTG2jSidhSxzD5ZMsXBiBBHZyDI7WDLFwrASEdnIctYOdkbSv7fr6Yk9oWmLT7rb1aNqd7r6lrevjahpRPlrNDCsFJfXAt9YnyzzD/GsHnO7q0ft2x1cBD+fNzH2dhFlYiHL3IMlUyyEkohsZJkdLJliYViJiGxkOWsHW1Zb7eq3Dzk5NG3njK/4b9J9pKv7H/F1Vy9efG/BDSTKx15NlroJrVZ5/XpXv9PnBFfXvH2Nqzt2Pzy0TP/D/JwvDdzggyhuFrLMPVgyxcKwEhHZyDI7WDIlaeDUfiKykWV2sGSKha1eIrKR5awd7KDjJrt6Q9rNlmsDtxK6APWufnbJlKjaRpQ3C6GMy8CjbnD1uuo/uboe/s06xiH8c7xnlj0ee7uIMrGQ5UTzsxC1HY2qOT+aIyJnichiEakSkeszTL9cRDaJyHzvcUUsH4qoHYoyy0Bp8swhYjJlX0RnHopIGYD7AHwKQDWA2SIyVVU/SJv1KVW9OpKVEpETVZaB0uU5awd76YoHXb0rbSPh2VHfcvXCrv49YM9ffJern6tZXWj7iPIS4bDSWABVqrocAETkSQDjAaQHss24YO0zrt4V+Hd6YcRlrl6Q9jOd86p+5eqpNWtibB1RWMRDxCXJM4eIyZRGaM6PZgwAEOxRqr3X0l0gIu+JyB9EZFBUn4OovYswy0CJ8swOlkxJqub8EJEJIjIn8JiQ5+peADBUVY8FMB3AY9F/IqL2qchZBmLIc9Yh4k2BDYN7azeHpo045FhXV711nasXBe4tSVRs+QwrqeoUAE2d9r4WQHALdqD3WnD5LYGnDwH4n5xXXgJbA8e0Hqrz7+E8ovdxrl4y78ehZRYl98XfMKIMIswyUKI88yQnMiXCEyNmAxgpIsOQCuLFAL4QnEFE+qnqOu/pOAALo1o5UXsX5UlOKFGe2cGSKVHdpFlVG0TkagAvASgD8IiqLhCRWwHMUdWpAL4tIuMANADYCuDySFZORJHecL1UeRbN8iGOHDLMTTxkyOdC0zZvmePqXbv9Y8ejTvbPOlw088pC20cEAFi5apU0Pxcw7vAxOady6pL5Ob2nFaOGDHX/NocO/ox7feu291xds2tFaJkjP3anqz+Yw3s9U+HaU5a5B0umWLj6CxHZyDI7WDIlaeAWV0RkI8vsYMmUpIGtXiKykeWsHWzwgv4dOnQJTdu1e5Wryzt0c/WSWd8CUalEfOahKfVNvF6zy89yWVnn0LRFcyfG2CKiplnIMvdgyRQL95AkIhtZZgdLprT9bV4iAmxkOWsHe1jgIuCXbJ0VmvbkiC+7evWaqa4e08F/y9l1TQ1KEcXDwlZvXEYM/y9Xj9v6lqufC/xk58P1fwstc0y5/+/5Th2v0kbFYyHL3IMlUyycGEFENrLMDpZMaTBwYgQR2chy1g5214onXP1m596haUurHnb1wLJOri6r6BpV24jyZmGrNy7bVjzp6nmBPK9c5d8ndnCHytAy+1ARf8OIMrCQZe7BkinJtp9JIoKNLLODJVMsb
PUSkY0sZ+1gNwYuNDG+Mjx0NK/WX3RHZ//G8N32bIyqbUR5sxDKuGyThKvHd/IvDvNB3VZXb+88KLRMx7rQLTOJisZClrkHS6YYOLOfiGAjy+xgyRQLd+AgIhtZZgdLplgYViIiG1nO2sF2SO5x9W3SNzStc2Kdq2tqlrm6olOvqNpGlLe2H8n4lCf3uvon6OnqTlLt6h07l4SWqajoEXu7iDKxkGXuwZIpFkJJRDayzA6WTLEwrERENrKctYMdFLg4eMfKAaFpyzb6F/8fGRhGuqRzH1fPr98KomKKMpIichaAXwAoA/CQqk5Om14B4HEAHwOwBcBFqroywiZEqu+gca6uDPwcZ+WWOa4+oqJnaJlzKw929YI922NrG1G6qLvXUuQ50fwsRG1HMo9HNiJSBuA+AGcDOBLAJSJyZNpsXwWwTVUPA3AXgDsi+hhE7V5UWQZKl2d2sGSK5vFoxlgAVaq6XFX3AngSwPi0ecYDeMyr/wDgdBEREFHBIswyUKI8Zx0ibmz0zyLevuWt0LTzug/2pzX6ZyfesL0aRG2BiEwAMCHw0hRVneLVAwCsCUyrBnBi2lu4eVS1QUR2AOgNYHM8LS5MaiM+ZeeOBa7+dLeBrt4SyDwA3FHDKzlR69dMloES5ZknOZExuW9wegGc0uyMRFQCbT/LHCImYySPR1ZrAQQvzDvQey3jPCLSAUAPpE6OIKKCRZZloER5zroH27mHfwx4/Xu3hKa91HWYq3fV+nvQCSkvpD1EhYnuEOhsACNFZBhSwbsYwBfS5pkK4EsA/gngcwD+ptp6r6Ba2WW4qzev/rmrXwncrKO2Lvx9UpboBKKSiPZ0hpLkmUPEZEw0gzLeMZirAbyE1Gn9j6jqAhG5FcAcVZ0K4GEA/yciVQC2IhVaIopEdAOspcozO1gyRfI4btMcVZ0GYFraaz8O1PUALoxshUTkRJlloDR5ZgdLtvBXMkQ2GMhy1g5WyipcPa77kNC0OXUbXN0p8O/Q53D/TOlFi+4usHlE+Yl6q9eSRJl/PPX0bv5x1/mB4661Eh6W6zn0866uWvabGFtHFGYhy9yDJWN4YjyRDW0/y+xgyRRJlDU/ExG1ehaynLWDPXaZ/7vdf/YaG5rWedgxrq7o4l8QvOu8G6NqG1HexMBWb1xOWPOEq1/tdpSruw4Y5epulYeGlilf+L/xN4woAwtZ5h4smcJLARPZYCHL7GDJFmn7W71EBBNZztrBvnWyf7u8k/5xfWjaG9UvuPqITj1d/WrdtoiaRpQ/MRDKuMz8t5+4+qOzrnP12+v/7uqRlT1Dy7xRvz3uZhFlZCHL3IMlUywctyEiG1lmB0umJBK8FjaRBRaynLWDXf/a9109b/AFoWn9Vv3W1W+io6uPPPoGV3/w/u0FN5AoHxaGleKy9rVrXS0Dznb1IdV/cvWcZPgrYdTo77p60cI742scURoLWeYeLJliIZREZCPL7GDJFJG2/+N0IrKRZXawZIqFrV4ispHlrB1sff0mV9fVrg5NWxG4QHi33v7N15fwuCuVUCLRsfmZ2qk9e7a6ura22tVr9mx3dZdeg0PLLOFxVyoRC1nmHiyZkjCw1UtENrLMDpZMsXDchohsZDlrB3tQL/+C/tUfTg9NGzL0YlevWjPV1QM6dHb16obaghtIlA8LoYxLr56jXb1h4yxXDxg0ztXVa18KLXNQ4LeIm5P7YmwdUZiFLHMPlkwpVihF5CAATwEYCmAlgM+r6gHXCRWRRgD/8p6uVtVx6fMQ0YEsZLntD3ITBUiiLOdHga4H8KqqjgTwqvc8kzpVHeM92LkS5chClrPuwR5Rt8rVh1T2Ck17s369qysD95DcW7s2l/USxaKseGcejgdwilc/BmAGgB8Wa+UtMaLez2bvwEX95+3Z7OrOnfuHlkkEzjYmKiYLWeYeLJkiUpbHQyaIyJzAY0IeqzpUVdd59XoAhzYxXyfvvd8UkfML+nBE7YiFLPMYLJkikvuftKpOATCl6feSVwD0zTBpYtr7qIhoE28zRFXXishwAH8TkX+p6rKcG0nUTlnIcvaL/f/HL1x95Js3haatC9xDMrgbPPjYSf7y700CUTElIjwxQlXPaGqaiGwQkX6quk5E+gHY2MR7rPX+u1xEZgA4DkBJOthtHw/k+W0/zxs2/sPV6UNag4++0dUbeREZKiILWeYQMZlSxBMjpgL4kld/CcDzB7RFpJeIVHh1HwAnA/ig0BUTtQcWsswhYjIln2GlAk0G8LSIfBXAKgCfT61fjgdwlapeAWA0gAdEJInUxuxkVWUHS5QDC1lmB0umJBIVRVmPqm4BcHqG1+cAuMKrZwE4Jn0eImqehSxn7WA7bfJP0R+WCI8mn1DZx9Xv1Pmn+SPZkG8biCIT5XEbayo2+zfsGJbwo39URU9XLwhc+B8Akg274m4WUUYWssw9WDJFEvyTJrLAQpbb/icgCijicRsiipGFLGf9BMvmXufquw4+MTTtpi5JV8+p3+nqtWuejaptRHmzcIHwuCyb9yNX39PrWFd/v/Mhrr5xX31ombVr/xx/w4gysJDltr+JQBSQKOtU6iYQUQQsZJkdLJliYViJiGxkOesn6F/u39u1Jm3ad7Z+6OrDP+lfoarDjnWu3jHvhgKbR5QnA6GMS98Ola7eFbiQ+o3bFrt61NifhZZJ7ljh6kUL74yxdURpDGS57X8CogALZx4SkY0st/1PQBRgYViJiGxkOesn+FZ3/96QDxzx1dC0AXP82+Vd8O5/u/qJ/p+Nqm1E+TOw1RuXCd0Guvq3I77o6t7z/RuKnL/oF6Fl/njomfE3jCgTA1lu+5+AKEDKinN5NSKKl4Uss4MlWwwMKxERTGS57X8CogA1MKxERDaynPUTTNzi30u266xvhqZd13OIq3/V/WOu3rHu5ajaRpS/wu8Nadbt26pc3XnO91x9TY9hrn6o29GhZWo2vhF/w4gyMZDltr+JQBRkIJREBBNZZgdLpiQ7dGx+JiJq9SxkOWsH23foRa5ev/Kp0LQHd613dffal1w9oKzc1RxcoqIzsNUbl0OGfM7V61Y94+rHdvlXZetauzG8TOA42NYY20Z0AANZ5h4smaKJRKmbQEQRsJDltv8JiAI0UZbzoxAicqGILBCRpIgcn2W+s0RksYhUicj1Ba2UqB2xkOWse7BrVz7p6kt7DA9Nm9XvPFcvC9wzssvuNbmslygWybKibTO+D+CzAB5oagZJ3dDyPgCfAlANYLaITFXVD4rTxLDgsPCFgTOH3zr4NFev3jAjtAzzTKViIcscIiZTijWspKoLAUBEss02FkCVqi735n0SwHgAJelgidoSC1lmB0umNHbI/U9aRCYAmBB4aYqqTmlq/hYYACC4C1gN4MQI35/ILAtZzvoJzug2wNW/rd8Rmvbt9dNc/WzvE1x9WQd/q2PStiXNrZ8oUprHsJIXwCZDKCKvAOibYdJEVX0+/9aV1qld/Zt3PF2/09VXbn7N1S/0PCq0zCWBf887tjPPVDwWssw9WDJFE1mHefJ7L9UzCnyLtQAGBZ4P
9F4jomZYyDI7WDIlWRZdKCMwG8BIERmGVBgvBvCF0jaJqG2wkGX+TIdM0YTk/CiEiHxGRKoBnATgLyLykvd6fxGZBgCq2gDgagAvAVgI4GlVXVDQionaCQtZzroHu3LkNa4+aP6NoWl37dns6sbgsZlehze3TqLYRDmslHU9qs8BeC7D6x8COCfwfBqAaenzlcLq4Ve6+qD3f+rq+wNZTm5fGlpmV8+R8TeMKAMLWeYQMZmSLG9Vw0pE1EIWsswOlmzhQQ8iGwxkOWsH27B7las3qIamHTH6+67uvepxV0/aFh5iIiqqtn998Njsq6129XpNuvqII652da814Zt63Ld9efwNI8rEQJa5B0u2GNjqJSKYyDI7WLLFQCiJCCaynLWDXbr4Xld/uceI0LTpNQtdXRW4h+S1Pf2ziO/ilV+oyBIdtPmZ2qllVQ+5+kuBm3e8vMP/tUFV7abQMlf2OMzVD+6oirF1RGEWssw9WDJFDGz1EpGNLLODJVMKvDUkEbUSFrKctYMdWt7F1VNrVocXrPHPSPx0t4GufhQdo2obUd6KdIerNqlvWSdXP7vD/4VAZSDLpwVuCAAATyZBVBIWssw9WDJFEm3/uA0R2cgyO1gyxcJWLxHZyDI7WDIlj3s0E1ErZiHLWT/Cpn31rh5V2Ss0reZj/sXCX5zl3xRgYg//uO1tBTePKD8Wtnrjsiu5z9XDKrq7uuGYH7v61bk3hJa5pntXV98VY9uI0lnIsoFtBCKfhVASkY0ss4MlU8oMnBhBRDaynLWDvbjHMFdvSO4JTRvzrztcvbv/qa5+9uCP+zNtCw83EcXNwlZvXMZ1H+LqdQ11ru639Jeu3tn3lNAyf+01xn+yfXJcTSM6gIUscw+WTLEQSiKykWV2sGRKuYGrvxCRjSxn7WDf6Heuq5cuujs0beQnHnH1xreud3X5hhnRtIyoBcqKtNUrIhcCmARgNICxqjqniflWAqgB0AigQVWPL04LD/Rmn0+6unrZo64ePvYXrt4wf1J4oQ2vxdwqoswsZJl7sGRKEYeV3gfwWQAP5DDvqaq6Oeb2EJliIcvsYMmUYm31qupCABCR4qyQqJ2xkOWsHWzv1U+4OjnyqtC0JTOvdnV/8U+nXpFsjKptRHkrVijzoABeFhEF8ICqTilVQ/qs+7PfqOH/5epFc65zdf+0L5nqZEP8DSPKwEKWuQdLppTn8RctIhMATAi8NCUYGhF5BUDfDItOVNXnc1zNx1V1rYgcAmC6iCxS1ddzbyVR+2Qhy+xgyZSyPEZ5vAA2uRWqqmcU2h5VXev9d6OIPAdgLAB2sETNsJDlrB1sI/yh37JVT4SmjTre/9H5uiX+5xp20k2uXv7iRTk0myg6rWlYSUS6AEioao1Xnwng1lK1p7bRH+7VNc+6etQxE129duXvQ8uMOuZaVy+aeWWMrSMKs5DlVvQRiApXlsj9UQgR+YyIVAM4CcBfROQl7/X+IjLNm+1QADNF5F0AbwP4i6r+tbA1E7UPFrLMIWIypUOiOGf1qupzAJ7L8PqHAM7x6uUAPlKUBhEZYyHL7GDJlNY0rERELWchy1k72J1H+8dTNyx9ODSt+1z/6k3juvgnZz07cyKISqWjgcurxSV439dtVX6euy3w7+18TudDQ8tMnXd7/A0jysBClrkHS6ZY2OolIhtZZgdLppQV6bgNEcXLQpaz/0xnxxJXj90XvvziR7oOdPWSxnpXny3+ZsfTBTePKD8Wtnrj0lizzNUfCeT5+ECWFzbWhpY5M3BB2APOAiGKkYUscw+WTMnnx+lE1HpZyDI7WDLFwrASEdnIctYONhkY+j25vEto2sM11a5eo/4/xB09+ruaQ8RUbB07tP1QxqWxYberT+/Y3dVNZRkAftTdHz7mEDEVk4Uscw+WTLFw3IaIbGSZHSyZkuD9WYlMsJBlUdXm5yJqIx6Y+6+c/6C/9tFj2n6CiYyykGXuwZIpFk6MICIbWWYHS6ZYODGCiGxkmR0smWJhq5eIbGSZHSyZYuHMQyKykWV2sGRKwsBWLxHZyDI7WDLFwrASEdnIMjtYMsXCsBIR2ciygY9A5CvvkMj5UQgR+V8RWSQi74nIcyLSs4n5zhKRxSJSJSLXF7RSonbEQpbZwZIpZYncHwWaDuBoVT0WwBIAN6TPICJlAO4DcDaAIwFcIiJHFrxmonbAQpbZwZIpZQnJ+VEIVX1ZVRu8p28CGJhhtrEAqlR1uaruBfAkgPEFrZionbCQZR6DJVNKdGLEVwA8leH1AQDWBJ5XAzixKC0iauMsZJkdLJlyWt/Dck6liEwAMCHw0hRVnRKY/gqAvhkWnaiqz3vzTATQAOB3LWsxEWViIcvsYKnd8gI4Jcv0M7ItLyKXAzgXwOma+a4ZawEMCjwf6L1GRBFqrVnmMViiFhCRswBcB2CcqtY2MdtsACNFZJiIdARwMYCpxWojETUvziyzgyVqmXsBdAMwXUTmi8j9ACAi/UVkGgB4J05cDeAlAAsBPK2qC0rVYCLKKLYs836wREREMeAeLBERUQzYwRIREcWAHSwREVEM2MESERHFgB0sERFRDNjBEhERxYAdLBERUQzYwRIREcXg/wEy8bIgDkEMPwAAAABJRU5ErkJggg==\n", 209 | "text/plain": [ 210 | "
" 211 | ] 212 | }, 213 | "metadata": { 214 | "needs_background": "light" 215 | }, 216 | "output_type": "display_data" 217 | } 218 | ], 219 | "source": [ 220 | "plt.figure(figsize=(8, 4))\n", 221 | "plt.subplot(1, 2, 1)\n", 222 | "sns.heatmap(W, xticklabels=False, yticklabels=False, annot=False, square=True, vmin=-2, vmax=2, center=0)\n", 223 | "plt.title('Ground Truth')\n", 224 | "plt.subplot(1, 2, 2)\n", 225 | "sns.heatmap(coef2, xticklabels=False, yticklabels=False, annot=False, square=True, vmin=-2, vmax=2, center=0)\n", 226 | "plt.title('Estimated')\n", 227 | "plt.show()" 228 | ] 229 | } 230 | ], 231 | "metadata": { 232 | "kernelspec": { 233 | "display_name": "Python 3 (ipykernel)", 234 | "language": "python", 235 | "name": "python3" 236 | }, 237 | "language_info": { 238 | "codemirror_mode": { 239 | "name": "ipython", 240 | "version": 3 241 | }, 242 | "file_extension": ".py", 243 | "mimetype": "text/x-python", 244 | "name": "python", 245 | "nbconvert_exporter": "python", 246 | "pygments_lexer": "ipython3", 247 | "version": "3.7.3" 248 | } 249 | }, 250 | "nbformat": 4, 251 | "nbformat_minor": 4 252 | } 253 | --------------------------------------------------------------------------------