├── tests ├── dsp │ ├── __init__.py │ ├── test_fwht.py │ ├── test_dct.py │ └── test_signals.py ├── linear_systems │ ├── ls_setup.py │ ├── test_jacobi.py │ ├── test_triangular.py │ └── test_solve_submatrix.py ├── svd │ ├── svd_setup.py │ ├── test_lansvd.py │ ├── test_svd_utils.py │ └── test_subspaces.py ├── core │ ├── test_special.py │ ├── test_matrix.py │ ├── test_distance.py │ ├── test_ndarray.py │ ├── test_similarity.py │ ├── test_util.py │ ├── test_metrics.py │ ├── test_norms.py │ └── test_vector.py ├── decomposition │ └── test_cholesky.py ├── geom2d │ └── test_linear.py ├── io │ ├── test_sphinx.py │ └── test_resource.py ├── orthogonalization │ ├── test_rq.py │ └── test_householder.py ├── discrete │ └── test_number.py ├── affine │ └── test_affine.py └── signal │ └── test_signal.py ├── docs ├── _static │ ├── js │ │ ├── custom.js │ │ └── mathconf.js │ └── css │ │ └── custom.css ├── changelog.md ├── zzzreference.rst ├── autobuild.sh ├── references.bib ├── source │ ├── index.rst │ ├── metrics.rst │ ├── data.rst │ ├── array.rst │ ├── compression.rst │ ├── util.rst │ ├── vector.rst │ ├── matrix.rst │ ├── la.rst │ └── dsp.rst ├── requirements.txt ├── _templates │ └── namedtuple.rst ├── index.rst ├── Makefile ├── make2.bat └── start.rst ├── src └── cr │ └── nimble │ ├── version.py │ ├── _src │ ├── io │ │ ├── sphinx.py │ │ └── resource.py │ ├── svdpack │ │ ├── lansvd_utils.py │ │ ├── bdsqr.py │ │ ├── reorth.py │ │ └── lansvd.py │ ├── dsp │ │ ├── convolution.py │ │ ├── util.py │ │ ├── interpolation.py │ │ ├── quantization.py │ │ ├── energy.py │ │ ├── features.py │ │ ├── spectrum.py │ │ ├── scaling.py │ │ ├── wht.py │ │ ├── dct.py │ │ └── thresholding.py │ ├── affine.py │ ├── dls.py │ ├── similarity.py │ ├── array.py │ ├── latex.py │ ├── triangular.py │ ├── chol.py │ ├── linear.py │ ├── noise.py │ ├── compression │ │ ├── run_length.py │ │ ├── bits.py │ │ ├── fixed_length.py │ │ └── binary_arrs.py │ ├── standard_matrices.py │ ├── rq.py │ ├── discrete │ │ └── number.py │ ├── 
spd │ │ └── jacobi.py │ ├── toeplitz.py │ ├── ndarray.py │ ├── signalcomparison.py │ ├── util.py │ ├── norm.py │ └── householder.py │ ├── io │ ├── __init__.py │ └── resource.py │ ├── spd.py │ ├── test_setup.py │ ├── rq.py │ ├── affine.py │ ├── data.py │ ├── dsp │ ├── signals.py │ └── __init__.py │ ├── subspaces.py │ ├── compression.py │ └── svd.py ├── setup.cfg ├── requirements ├── requirements-tests.txt └── requirements.txt ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── release.yml │ ├── ci.yml │ └── sphinx.yaml ├── .readthedocs.yaml ├── README.md ├── .gitignore ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── CHANGELOG.md └── setup.py /tests/dsp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/js/custom.js: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/cr/nimble/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.3.2' -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CHANGELOG.md 2 | ``` -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /requirements/requirements-tests.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-cov 3 | -------------------------------------------------------------------------------- /docs/zzzreference.rst: 
-------------------------------------------------------------------------------- 1 | References 2 | =================== 3 | 4 | .. bibliography:: 5 | :cited: 6 | -------------------------------------------------------------------------------- /tests/linear_systems/ls_setup.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | import cr.nimble.data as data 4 | import cr.nimble.spd as spd 5 | -------------------------------------------------------------------------------- /docs/autobuild.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sphinx-autobuild --host=0.0.0.0 --port=9400 -N . _build/html --watch ../src/cr/nimble --re-ignore "gallery.*" 3 | 4 | -------------------------------------------------------------------------------- /tests/svd/svd_setup.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | import cr.nimble.svd 4 | import cr.nimble.data 5 | import cr.nimble.subspaces as subspaces 6 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=1.0.0 2 | requests>=2.20.0 3 | jax>=0.3.14 4 | jaxlib>=0.3.14 5 | numpy>=1.20.0 6 | scipy>=1.6 7 | matplotlib 8 | sympy>=1.6 9 | bitarray>=2.6 10 | 11 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | 2 | @book{golub2012matrix, 3 | title={Matrix computations}, 4 | author={Golub, Gene H and Van Loan, Charles F}, 5 | volume={3}, 6 | year={2012}, 7 | publisher={JHU Press} 8 | } 9 | 10 | -------------------------------------------------------------------------------- /docs/source/index.rst: 
-------------------------------------------------------------------------------- 1 | API Docs 2 | ============================= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | :caption: Contents: 8 | 9 | array 10 | vector 11 | matrix 12 | metrics 13 | la 14 | dsp 15 | compression 16 | data 17 | util 18 | -------------------------------------------------------------------------------- /tests/linear_systems/test_jacobi.py: -------------------------------------------------------------------------------- 1 | from ls_setup import * 2 | 3 | def test_jacobi(): 4 | A = jnp.array([[3., 2], [2, 6]]) 5 | b = jnp.array([2., -8]) 6 | sol = spd.jacobi_solve_jit(A, b) 7 | r = A @ sol.x - b 8 | r_norm_sqr = r.T @ r 9 | assert jnp.isclose(0, r_norm_sqr, atol=1e-4) 10 | -------------------------------------------------------------------------------- /tests/core/test_special.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | def test_pascal_1(): 4 | n = 4 5 | A = cnb.pascal_jit(n) 6 | assert A.shape == (n,n) 7 | assert A[n-1,n-1] == 1 8 | 9 | def test_pascal_2(): 10 | n = 4 11 | A = cnb.pascal_jit(n, True) 12 | assert A.shape == (n,n) 13 | assert A[n-1,n-1] == 20 14 | -------------------------------------------------------------------------------- /tests/decomposition/test_cholesky.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | cholesky_build_factor = jit(cnb.cholesky_build_factor) 4 | 5 | def test_cholesky_update(): 6 | key = random.PRNGKey(0) 7 | A = random.normal(key, (4,4)) 8 | L = cholesky_build_factor(A) 9 | G1 = A.T @ A 10 | G2 = L @ L.T 11 | assert jnp.allclose(G1, G2) -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | chex>=0.0.4 2 | jax>=0.1.55 3 | jaxlib>=0.1.37 4 | 
numpy>=1.18.0 5 | scipy 6 | matplotlib 7 | sphinx==4.0.0 8 | sphinxcontrib-katex>=0.7.1 9 | sphinxcontrib-bibtex>=1.0.0 10 | sphinx-autodoc-typehints>=1.11.1 11 | IPython>=7.16.1 12 | ipykernel>=5.3.4 13 | nbsphinx>=0.8.0 14 | sphinx-gallery>=0.8.0 15 | sphinx-panels 16 | myst-parser 17 | -------------------------------------------------------------------------------- /docs/_templates/namedtuple.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | 8 | {% block attributes %} 9 | {% if attributes %} 10 | .. rubric:: {{ _('Attributes') }} 11 | 12 | .. autosummary:: 13 | {% for item in attributes %} 14 | ~{{ name }}.{{ item }} 15 | {%- endfor %} 16 | {% endif %} 17 | {% endblock %} -------------------------------------------------------------------------------- /tests/geom2d/test_linear.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | def test_point(): 4 | x = cnb.point2d(1, 2) 5 | 6 | def test_vec(): 7 | x = cnb.vec2d(1, 2) 8 | 9 | def test_rotate2d_cw(): 10 | theta = jnp.pi 11 | R = cnb.rotate2d_cw(theta) 12 | 13 | def test_rotate2d_ccw(): 14 | theta = jnp.pi 15 | R = cnb.rotate2d_ccw(theta) 16 | 17 | def test_reflect2d(): 18 | theta = jnp.pi 19 | R = cnb.reflect2d(theta) 20 | -------------------------------------------------------------------------------- /tests/io/test_sphinx.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | 4 | from cr.nimble import io 5 | import jax.numpy as jnp 6 | 7 | 8 | def test_print_dataframe_as_list_table(): 9 | d = { 10 | "one": pd.Series([1.0, 2.0, 3.0], index=["a", "b", "c"]), 11 | "two": pd.Series([1.0, 2.0, 3.0], index=["a", "b", "c"]), 12 | "three": pd.Series([1, 2, 3], index=["a", "b", "c"]), 13 | } 14 | df = 
pd.DataFrame(d) 15 | io.print_dataframe_as_list_table(df, 'abc') -------------------------------------------------------------------------------- /tests/dsp/test_fwht.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import math 3 | # jax imports 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | # crs imports 8 | import cr.nimble as cnb 9 | from cr.nimble.dsp import * 10 | 11 | atol = 1e-6 12 | rtol = 1e-6 13 | 14 | def test_fwht1(): 15 | for n in [4,8,16]: 16 | fact = 1/math.sqrt(n) 17 | for i in range(n): 18 | print(n, i) 19 | y = cnb.vec_unit(n, i) 20 | a = fact * fwht(y) 21 | x = fact * fwht(a) 22 | assert jnp.allclose(x, y, rtol=rtol, atol=atol) -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | CR-Nimble 2 | ===================================== 3 | 4 | CR-Nimble consists of fast linear algebra 5 | and signal processing routines. 6 | Most of the routines have been implemented using 7 | Google JAX. Thus, they can be easily run on 8 | a variety of hardware (CPU, GPU, TPU). 9 | 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Contents: 14 | 15 | start 16 | source/index 17 | zzzreference 18 | changelog 19 | 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. 
Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | 24 | **Additional context** 25 | Add any other context about the problem here. 26 | -------------------------------------------------------------------------------- /tests/svd/test_lansvd.py: -------------------------------------------------------------------------------- 1 | from svd_setup import * 2 | 3 | 4 | def test_lanbpro1(): 5 | A = jnp.eye(4) 6 | r = cnb.svd.lanbpro_random_start(cnb.KEYS[0], A) 7 | state = cnb.svd.lanbpro_jit(A, 4, r) 8 | assert_allclose(state.alpha, 1., atol=atol) 9 | assert_allclose(state.beta[1:], 0., atol=atol) 10 | 11 | 12 | def test_lansvd1(): 13 | A = jnp.eye(4) 14 | r = cnb.svd.lanbpro_random_start(cnb.KEYS[0], A) 15 | U, S, V, bnd, n_converged, state = cnb.svd.lansvd_simple_jit(A, 4, r) 16 | assert_allclose(state.alpha, 1., atol=atol) 17 | assert_allclose(state.beta[1:], 0., atol=atol) 18 | -------------------------------------------------------------------------------- /docs/source/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ================ 3 | 4 | .. contents:: 5 | :depth: 2 6 | :local: 7 | 8 | .. currentmodule:: cr.nimble 9 | 10 | 11 | Error 12 | -------------- 13 | 14 | .. autosummary:: 15 | :toctree: _autosummary 16 | 17 | mean_squared 18 | mean_squared_error 19 | root_mean_squared 20 | root_mse 21 | normalized_root_mse 22 | normalized_mse 23 | percent_rms_diff 24 | percent_space_saving 25 | compression_ratio 26 | prd_to_snr 27 | 28 | 29 | SNR 30 | ------------- 31 | 32 | .. 
autosummary:: 33 | :toctree: _autosummary 34 | 35 | peak_signal_noise_ratio 36 | signal_noise_ratio 37 | -------------------------------------------------------------------------------- /tests/core/test_matrix.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | from cr.nimble import * 3 | 4 | def test_is_symmetric(): 5 | x = jnp.array([1,2]) 6 | assert not is_symmetric(x) 7 | A = jnp.array([[1,2], [2, 1]]) 8 | assert is_symmetric(A) 9 | 10 | def test_is_hermitian(): 11 | x = jnp.array([1,2]) 12 | assert not is_hermitian(x) 13 | A = jnp.array([[1,2], [2, 1]]) 14 | assert is_hermitian(A) 15 | 16 | def test_is_positive_definite(): 17 | x = jnp.array([1,2]) 18 | assert not is_positive_definite(x) 19 | A = jnp.array([[1.,0], [0, 1]]) 20 | if is_cpu(): 21 | assert is_positive_definite(A) 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /tests/orthogonalization/test_rq.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | import cr.nimble.rq as rq 4 | 5 | def test_1(): 6 | A = jnp.eye(3) 7 | R, Q = rq.factor_mgs(A) 8 | 9 | def test_2(): 10 | with pytest.raises(Exception): 11 | R, Q = rq.factor_mgs(jnp.zeros((5, 3))) 12 | 13 | def test_3(): 14 | A = jnp.eye(3) 15 | n, m = A.shape 16 | Q = jnp.empty([n, m]) 17 | R = jnp.zeros([n, n]) 18 | R, Q = rq.update(R, Q, A[0], 0) 19 | R, Q = rq.update(R, Q, A[1], 1) 20 | 21 | 22 | def test_4(): 23 | A = jnp.eye(3) 24 | n, m = A.shape 25 | Q = jnp.empty([n, m]) 26 | R = jnp.zeros([n, n]) 27 | x = jnp.zeros(n) 28 | rq.solve(R, Q, x) 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | 
labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /docs/source/data.rst: -------------------------------------------------------------------------------- 1 | Test Data Generation 2 | ============================ 3 | 4 | .. contents:: 5 | :depth: 2 6 | :local: 7 | 8 | Random Keys 9 | ------------------------ 10 | 11 | .. currentmodule:: cr.nimble 12 | 13 | .. autosummary:: 14 | :toctree: _autosummary 15 | 16 | KEY0 17 | KEYS 18 | 19 | 20 | Random Subspaces 21 | ------------------------ 22 | 23 | .. currentmodule:: cr.nimble.data 24 | 25 | .. autosummary:: 26 | :toctree: _autosummary 27 | 28 | random_subspaces 29 | random_subspaces_jit 30 | uniform_points_on_subspaces 31 | uniform_points_on_subspaces_jit 32 | two_subspaces_at_angle 33 | two_subspaces_at_angle_jit 34 | three_subspaces_at_angle 35 | three_subspaces_at_angle_jit 36 | -------------------------------------------------------------------------------- /docs/source/array.rst: -------------------------------------------------------------------------------- 1 | Arrays 2 | =============== 3 | 4 | These functions are applicable for general ND-arrays. 5 | 6 | .. contents:: 7 | :depth: 2 8 | :local: 9 | 10 | 11 | .. currentmodule:: cr.nimble 12 | 13 | 14 | Inner products 15 | ------------------------ 16 | 17 | .. 
autosummary:: 18 | :toctree: _autosummary 19 | 20 | arr_vdot 21 | arr_rdot 22 | 23 | 24 | Norms 25 | --------------------- 26 | 27 | .. autosummary:: 28 | :toctree: _autosummary 29 | 30 | 31 | arr_l1norm 32 | arr_l2norm 33 | arr_l2norm_sqr 34 | 35 | 36 | Utilities 37 | ------------------------ 38 | 39 | .. autosummary:: 40 | :toctree: _autosummary 41 | 42 | hermitian 43 | arr_largest_index 44 | arr2vec 45 | check_shapes_are_equal 46 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.8" 13 | 14 | 15 | # Build documentation in the docs/ directory with Sphinx 16 | sphinx: 17 | configuration: docs/conf.py 18 | 19 | # Optionally build your docs in additional formats such as PDF 20 | formats: 21 | - epub 22 | - htmlzip 23 | 24 | # Optionally set the version of Python and requirements required to build your docs 25 | python: 26 | install: 27 | - requirements: docs/requirements.txt 28 | - requirements: requirements/requirements.txt -------------------------------------------------------------------------------- /tests/io/test_resource.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from cr.nimble.io.resource import * 3 | 4 | 5 | def test_get_uri(): 6 | assert get_uri('abc') is None 7 | assert get_uri("haarcascade_frontalface_default.xml") is not None 8 | 9 | def test_ensure_resource(): 10 | fname = "sst_nino3.dat" 11 | path = CACHE_DIR / fname 12 | if path.is_file(): 13 | path.unlink() 14 | path = ensure_resource(fname, 
"http://paos.colorado.edu/research/wavelets/wave_idl/sst_nino3.dat") 15 | assert path.is_file() 16 | path = ensure_resource("http://paos.colorado.edu/research/wavelets/wave_idl/sst_nino3.dat") 17 | assert path.is_file() 18 | path = ensure_resource(None) 19 | assert path is None 20 | path = ensure_resource('abc') 21 | assert path is None 22 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/io/sphinx.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | def format_value(value, dtype=None): 5 | if dtype == 'int64': 6 | return f'{int(value)}' 7 | if dtype == np.float64: 8 | return f'{value:.2f}' 9 | return f'{value}' 10 | def print_dataframe_as_list_table(df, title): 11 | print(f'.. list-table:: {title}') 12 | print(' :header-rows: 1\n') 13 | cols = df.columns 14 | dtypes = df.dtypes 15 | print(f' * - {df.index.name}') 16 | for i, col in enumerate(cols): 17 | print(f' - {col}') 18 | for index, row in df.iterrows(): 19 | print(f' * - {format_value(index)}') 20 | for i, col in enumerate(cols): 21 | dtype = dtypes[col] 22 | print(f' - {format_value(row[col], dtype)}') 23 | -------------------------------------------------------------------------------- /tests/dsp/test_dct.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | # jax imports 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | # crs imports 8 | import cr.nimble as cnb 9 | from cr.nimble.dsp import * 10 | 11 | atol = 1e-6 12 | rtol = 1e-6 13 | 14 | def test_dct1(): 15 | for n in [4,8,16]: 16 | for i in range(n): 17 | print(n, i) 18 | y = cnb.vec_unit(n, i) 19 | a = dct(y) 20 | x = idct(a) 21 | assert jnp.allclose(x, y, rtol=rtol, atol=atol) 22 | 23 | def test_orthonormal_dct1(): 24 | for n in [4,8,16]: 25 | for i in range(n): 26 | print(n, i) 27 | y = cnb.vec_unit(n, i) 28 | a = orthonormal_dct(y) 29 | x = orthonormal_idct(a) 30 | assert 
jnp.allclose(x, y, rtol=rtol, atol=atol) -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | clean: 18 | rm -rf $(BUILDDIR)/* 19 | rm -rf source/_autosummary/* 20 | rm -rf gallery/* 21 | 22 | # Catch-all target: route all unknown targets to Sphinx using the new 23 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 24 | %: Makefile 25 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 26 | -------------------------------------------------------------------------------- /tests/discrete/test_number.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | from cr.nimble import * 5 | import jax.numpy as jnp 6 | 7 | def test_numbers(): 8 | assert is_integer(10) 9 | assert is_positive_integer(10) 10 | assert is_negative_integer(-10) 11 | assert is_odd(-3) 12 | assert is_even(40) 13 | assert is_odd_natural(5) 14 | assert is_even_natural(6) 15 | assert is_power_of_2(1024) 16 | assert is_perfect_square(81) 17 | a,b = integer_factors_close_to_sqr_root(20) 18 | assert a == 4 19 | assert b == 5 20 | a,b = integer_factors_close_to_sqr_root(25) 21 | assert a == 5 22 | assert b == 5 23 | a,b = integer_factors_close_to_sqr_root(90) 24 | assert a == 9 25 | assert b == 10 26 | a,b = integer_factors_close_to_sqr_root(77) 27 | assert a == 7 28 | assert b == 11 29 | 30 | 
-------------------------------------------------------------------------------- /src/cr/nimble/io/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | IO Utilities 17 | """ 18 | 19 | # pylint: disable=W0611 20 | 21 | 22 | 23 | from cr.nimble._src.io.sphinx import ( 24 | print_dataframe_as_list_table 25 | ) 26 | -------------------------------------------------------------------------------- /src/cr/nimble/spd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ 16 | Jacobi Iterations 17 | """ 18 | 19 | # pylint: disable=W0611 20 | 21 | 22 | from cr.nimble._src.spd.jacobi import ( 23 | jacobi_solve, 24 | jacobi_solve_jit 25 | ) 26 | -------------------------------------------------------------------------------- /src/cr/nimble/test_setup.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | from numpy.testing import (assert_almost_equal, assert_allclose, assert_, 5 | assert_equal, assert_raises, assert_raises_regex, 6 | assert_array_equal, assert_warns) 7 | 8 | 9 | from jax.config import config 10 | config.update("jax_enable_x64", True) 11 | 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | from jax import random, lax, vmap, jit 16 | 17 | import cr.nimble as cnb 18 | 19 | rtol = 1e-8 if jax.config.jax_enable_x64 else 1e-6 20 | atol = 1e-7 if jax.config.jax_enable_x64 else 1e-5 21 | decimal_cmp = 7 if jax.config.jax_enable_x64 else 6 22 | 23 | float_type = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32 24 | complex_type = jnp.complex128 if jax.config.jax_enable_x64 else jnp.complex64 25 | -------------------------------------------------------------------------------- /src/cr/nimble/rq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """ 16 | RQ decomposition for row major matrices 17 | """ 18 | # pylint: disable=W0611 19 | 20 | 21 | from cr.nimble._src.rq import ( 22 | factor_mgs, 23 | update, 24 | solve 25 | ) -------------------------------------------------------------------------------- /tests/linear_systems/test_triangular.py: -------------------------------------------------------------------------------- 1 | from ls_setup import * 2 | 3 | def test_lxb(): 4 | n = 5 5 | L = jnp.eye(n) 6 | b = jnp.ones(n) 7 | x = cnb.solve_Lx_b(L, b) 8 | assert jnp.allclose(b, x) 9 | 10 | def test_ltxb(): 11 | n = 5 12 | L = jnp.eye(n) 13 | b = jnp.ones(n) 14 | x = cnb.solve_LTx_b(L, b) 15 | assert jnp.allclose(b, x) 16 | 17 | def test_uxb(): 18 | n = 5 19 | L = jnp.eye(n) 20 | b = jnp.ones(n) 21 | x = cnb.solve_Ux_b(L, b) 22 | assert jnp.allclose(b, x) 23 | 24 | def test_utxb(): 25 | n = 5 26 | L = jnp.eye(n) 27 | b = jnp.ones(n) 28 | x = cnb.solve_UTx_b(L, b) 29 | assert jnp.allclose(b, x) 30 | 31 | def test_spd_chol(): 32 | n = 5 33 | L = jnp.eye(n) 34 | b = jnp.ones(n) 35 | x = cnb.solve_spd_chol(L, b) 36 | assert jnp.allclose(b, x) 37 | 38 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | .wy-nav-content { 2 | max-width: 1200px; 3 | } 4 | 5 | .theoremenv { 6 | border: 1px solid #e1e4e5; 7 | margin: 1px 0 24px 0; 8 | padding-bottom: 10px; 9 | } 10 | 11 | .theoremenv_caption { 12 | background-color: #e8e4e4; 13 | padding-bottom: 5px; 14 | margin-bottom: 5px; 15 | padding-left: 10px; 16 | padding-top: 10px; 17 | } 18 | 19 | .theoremenv_counter:before{ 20 | content:" "; 21 | display:inline-block; 22 | width:5px; 23 | } 24 | 25 | .theoremenv_title:before{ 26 | content:" "; 27 | display:inline-block; 28 | width:5px; 29 | } 30 | 31 | .theoremenv_body { 32 | padding-left: 10px; 33 | } 34 | 35 | .think{ 36 | background: #e7f2fa; 37 | } 38 | 39 | 
.think_body{ 40 | background: #e7f2fa; 41 | } 42 | 43 | span.small_caps { 44 | font-variant: small-caps; 45 | } 46 | 47 | -------------------------------------------------------------------------------- /src/cr/nimble/affine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Utility functions for working with Affine spaces 17 | """ 18 | # pylint: disable=W0611 19 | 20 | from cr.nimble._src.affine import ( 21 | homogenize, 22 | homogenize_vec, 23 | homogenize_cols, 24 | ) -------------------------------------------------------------------------------- /src/cr/nimble/io/resource.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Resource Caching 17 | """ 18 | 19 | # pylint: disable=W0611 20 | 21 | 22 | from cr.nimble._src.io.resource import ( 23 | is_valid_url, 24 | CACHE_DIR, 25 | ensure_resource, 26 | ensure_cr_suite_resource, 27 | get_uri 28 | ) 29 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/svdpack/lansvd_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax import lax, jit, vmap, random 16 | import jax.numpy as jnp 17 | from jax.numpy.linalg import norm 18 | 19 | 20 | def refine_bounds(D, bnd, tol): 21 | """Examine the gap structure of the Ritz values and refine their error bounds (currently a placeholder that returns ``bnd`` unchanged). 22 | """ 23 | return bnd -------------------------------------------------------------------------------- /docs/source/compression.rst: -------------------------------------------------------------------------------- 1 | .. _api:compression: 2 | 3 | Data Compression 4 | =============================== 5 | 6 | .. contents:: 7 | :depth: 2 8 | :local: 9 | 10 | 11 | The library comes with some basic data compression 12 | routines.
They are helpful in simple use cases like: 13 | 14 | - compression of binary arrays 15 | - run length encoding 16 | - fixed length encoding of integers 17 | 18 | These routines are primarily based on 19 | `numpy` arrays and `bitarray` based 20 | compressed bit arrays. This module 21 | doesn't use JAX. 22 | 23 | 24 | .. currentmodule:: cr.nimble.compression 25 | 26 | 27 | .. autosummary:: 28 | :toctree: _autosummary 29 | 30 | count_runs_values 31 | expand_runs_values 32 | encode_int_arr_sgn_mag_fl 33 | decode_int_arr_sgn_mag_fl 34 | count_binary_runs 35 | encode_binary_arr 36 | decode_binary_arr 37 | binary_compression_ratio 38 | binary_space_saving_ratio 39 | 40 | -------------------------------------------------------------------------------- /docs/make2.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /tests/affine/test_affine.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | import cr.nimble.affine 4 | 5 | homogenize = jax.jit(cnb.affine.homogenize) # the tests exercise the jit-compiled versions of the public API 6 | homogenize_vec = jax.jit(cnb.affine.homogenize_vec) 7 | homogenize_cols = jax.jit(cnb.affine.homogenize_cols) 8 | 9 | def test_homogenize_vec(): 10 | x = jnp.array([1,2,3]) 11 | y = homogenize_vec(x) 12 | assert len(x) + 1 == len(y) # exactly one entry appended 13 | assert y[-1] == 1 # the appended entry is the default value 1 14 | 15 | def test_homogenize_vec2(): # same checks through the rank-dispatching homogenize 16 | x = jnp.array([1,2,3]) 17 | y = homogenize(x) 18 | assert len(x) + 1 == len(y) 19 | assert y[-1] == 1 20 | 21 | def test_homogenize_cols(): 22 | x = jnp.array([[1,2,3],[4,5,6]]) 23 | y = homogenize_cols(x) 24 | assert x.shape[0] + 1 == y.shape[0] # exactly one row appended 25 | assert jnp.allclose(y[-1, :] , 1) # the appended row is filled with the default value 1 26 | 27 | 28 | def test_homogenize_cols2(): # same checks through the rank-dispatching homogenize 29 | x = jnp.array([[1,2,3],[4,5,6]]) 30 | y = homogenize(x) 31 | assert x.shape[0] + 1 == y.shape[0] 32 | assert jnp.allclose(y[-1, :] , 1) 33 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Upload Package to PyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [created] 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Install Python 3 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: 3.8 17 | - name: Install dependencies 18 | run: | 19 | python --version 20 | python -m pip install
--upgrade pip 21 | python -m pip --version 22 | python -m pip install -r requirements/requirements.txt 23 | python -m pip install -r requirements/requirements-tests.txt 24 | python -m pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine check dist/* 32 | twine upload dist/* 33 | -------------------------------------------------------------------------------- /docs/start.rst: -------------------------------------------------------------------------------- 1 | Quick Start 2 | =================== 3 | 4 | Platform Support 5 | ---------------------- 6 | 7 | ``cr-nimble`` can run on any platform supported by ``JAX``. 8 | We have tested ``cr-nimble`` on Mac and Linux platforms. 9 | 10 | ``JAX`` is not officially supported on Windows platforms at the moment. 11 | However, it can be built from source under the Windows Subsystem for Linux (WSL). 12 | 13 | 14 | Installation 15 | ------------------------------- 16 | 17 | Installation from PyPI: 18 | 19 | .. code:: shell 20 | 21 | python -m pip install cr-nimble 22 | 23 | 24 | 25 | Directly from our GitHub repository: 26 | 27 | .. code:: shell 28 | 29 | python -m pip install git+https://github.com/carnotresearch/cr-nimble.git 30 | 31 | .. note:: 32 | 33 | If you are on Windows, JAX is not yet officially supported. 34 | However, you can install an unofficial JAX build for Windows 35 | from https://github.com/cloudhan/jax-windows-builder. 36 | This works quite well for development purposes.
37 | 38 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/dsp/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR.Sparse Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax import jit 16 | from jax.scipy import signal 17 | 18 | 19 | def vec_convolve(x, h): 20 | """1D full convolution based on a hack suggested by Jake Vanderplas 21 | 22 | See https://github.com/google/jax/discussions/7961 for details 23 | For 1-D ``x`` and ``h`` the output length is ``len(x) + len(h) - 1``. """ 24 | return signal.convolve(x[None], h[None])[0] # add a leading axis to both inputs, convolve, then drop that axis again 25 | 26 | vec_convolve_jit = jit(vec_convolve) # JIT-compiled variant 27 | 28 | -------------------------------------------------------------------------------- /docs/source/util.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ============================== 3 | 4 | .. contents:: 5 | :depth: 2 6 | :local: 7 | 8 | 9 | .. currentmodule:: cr.nimble 10 | 11 | 12 | Data Type Management 13 | --------------------------- 14 | 15 | 16 | .. autosummary:: 17 | :toctree: _autosummary 18 | 19 | promote_arg_dtypes 20 | check_shapes_are_equal 21 | canonicalize_dtype 22 | promote_to_complex 23 | promote_to_real 24 | 25 | System Information 26 | ----------------------------- 27 | 28 | ..
autosummary:: 29 | :toctree: _autosummary 30 | 31 | platform 32 | is_cpu 33 | is_gpu 34 | is_tpu 35 | nbytes_live_buffers 36 | 37 | 38 | 39 | 2D Geometry 40 | ---------------------------- 41 | 42 | .. currentmodule:: cr.nimble 43 | 44 | .. rubric:: Points and Vectors 45 | 46 | 47 | .. autosummary:: 48 | :toctree: _autosummary 49 | 50 | point2d 51 | vec2d 52 | 53 | .. rubric:: Transformations 54 | 55 | .. autosummary:: 56 | :toctree: _autosummary 57 | 58 | rotate2d_cw 59 | rotate2d_ccw 60 | reflect2d 61 | -------------------------------------------------------------------------------- /src/cr/nimble/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ 16 | Test data generators for linear subspaces (random subspaces, points on subspaces, subspaces at given angles) 17 | """ 18 | # pylint: disable=W0611 19 | 20 | from cr.nimble._src.data.subspaces import ( 21 | random_subspaces, 22 | random_subspaces_jit, 23 | uniform_points_on_subspaces, 24 | uniform_points_on_subspaces_jit, 25 | two_subspaces_at_angle, 26 | two_subspaces_at_angle_jit, 27 | three_subspaces_at_angle, 28 | three_subspaces_at_angle_jit 29 | ) -------------------------------------------------------------------------------- /src/cr/nimble/_src/affine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import jax.numpy as jnp 16 | 17 | def homogenize_vec(x, value=1): # Append the scalar ``value`` as one extra trailing entry of the 1-D vector x 18 | assert x.ndim == 1 19 | return jnp.hstack((x, value)) 20 | 21 | def homogenize_cols(X, value=1): # Append one extra row filled with ``value`` below the 2-D matrix X 22 | assert X.ndim == 2 23 | n = X.shape[-1] 24 | o = value * jnp.ones(n) # one entry per column of X 25 | return jnp.vstack((X, o)) 26 | 27 | 28 | def homogenize(X, value=1): # Dispatch on rank: 1-D input -> homogenize_vec, anything else -> homogenize_cols (which asserts 2-D) 29 | if X.ndim == 1: 30 | return homogenize_vec(X, value) 31 | return homogenize_cols(X, value) -------------------------------------------------------------------------------- /tests/core/test_distance.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | M = 10 4 | p = 3 5 | N = 5 # NOTE(review): not used by the tests below 6 | 7 | A = jnp.zeros([M, p]) 8 | B = jnp.ones([M, p]) 9 | 10 | C = A.T 11 | D = B.T 12 | 13 | def test_1(): # smoke tests: each test only checks that the call runs, without asserting on the result 14 | cnb.pairwise_sqr_l2_distances_rw(A, B) 15 | 16 | def test_2(): 17 | cnb.pairwise_sqr_l2_distances_cw(A, B) 18 | 19 | def test_a(): 20 | cnb.pairwise_l2_distances_rw(A, B) 21 | 22 | def test_b(): 23 | cnb.pairwise_l2_distances_cw(A, B) 24 | 25 | def test_c(): 26 | cnb.pdist_sqr_l2_rw(A) 27 | 28 | def test_d(): 29 | cnb.pdist_sqr_l2_cw(A) 30 | 31 | def test_e(): 32 | cnb.pdist_l2_rw(A) 33 | 34 | def test_f(): 35 | cnb.pdist_l2_cw(A) 36 | 37 | 38 | def test_3(): 39 | cnb.pairwise_l1_distances_rw(C, D) 40 | 41 | def test_4(): 42 | cnb.pairwise_l1_distances_cw(A, B) 43 | 44 | def test_5(): 45 | cnb.pdist_l1_rw(C) 46 | 47 | def test_6(): 48 | cnb.pdist_l1_cw(A) 49 | 50 | def test_7(): 51 | cnb.pairwise_linf_distances_rw(C, D) 52 | 53 | 54 | def test_8(): 55 | cnb.pairwise_linf_distances_cw(A, B) 56 | 57 | def test_9(): 58 | cnb.pdist_linf_rw(C) 59 | 60 | def test_10(): 61 | cnb.pdist_linf_cw(B) 62 | -------------------------------------------------------------------------------- /tests/core/test_ndarray.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | def test_arr_largest_index(): 4 | x = jnp.array([1,
 -2, -3, 2]) 5 | assert_equal(cnb.arr_largest_index(x), (jnp.array([2]), )) # -3 at position 2 has the largest magnitude 6 | x = jnp.reshape(x, (2,2)) 7 | idx = cnb.arr_largest_index(x) 8 | print(idx) 9 | assert_equal(idx, (jnp.array([1]), jnp.array([0]))) # the same entry, now at row 1, column 0 10 | 11 | 12 | def test_arr_l2norm(): 13 | x = jnp.zeros(10) 14 | assert_allclose(cnb.arr_l2norm(x), 0.) 15 | 16 | def test_arr_l2norm_sqr(): 17 | x = jnp.zeros(10) 18 | assert_allclose(cnb.arr_l2norm_sqr(x), 0.) 19 | 20 | def test_arr_vdot(): 21 | x = jnp.zeros(10) 22 | assert_allclose(cnb.arr_vdot(x, x), 0.) 23 | 24 | 25 | 26 | 27 | @pytest.mark.parametrize("x,y", [ ([1, 0], [1, 0]), 28 | ([1], [1, 0]), 29 | ([1 + 0j, 0], [1+0j, 0]), 30 | ([1 + 1j, 0], [1+1j, 0]), 31 | ([1 + 1j, 2-3j], [1, 0]), 32 | ([1, -4], [1+2j, -3j]), 33 | ]) 34 | def test_rdot(x, y): 35 | x = jnp.asarray(x) 36 | y = jnp.asarray(y) 37 | x1 = cnb.arr2vec(x) 38 | y1 = cnb.arr2vec(y) 39 | expected = jnp.sum(jnp.conjugate(x1) * y1) # reference: real part of the conjugated inner product 40 | expected = jnp.real(expected) 41 | assert_almost_equal(cnb.arr_rdot(x, y), expected) 42 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/dls.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Helper functions for solving dense linear systems 16 | """ 17 | 18 | from jax.scipy.linalg import solve # NOTE(review): not used by the functions below 19 | from jax.numpy.linalg import lstsq 20 | 21 | def mult_with_submatrix(A, columns, x): 22 | """Computes :math:`b = A[:, I] x` where I is the index set ``columns`` 23 | """ 24 | A = A[:, columns] 25 | return A @ x 26 | 27 | 28 | def solve_on_submatrix(A, columns, b): 29 | """Solves the problem :math:`A[:, I] x = b` in the least squares sense (via ``lstsq``) where I is an 30 | index set of selected columns (``columns``) 31 | Returns the solution ``x`` and the residual ``r = b - A[:, I] x``. """ 32 | A = A[:, columns] 33 | x, r_norms, rank, s = lstsq(A, b) # only the solution is used; residual norms, rank and singular values are discarded 34 | r = b - A @ x 35 | return x, r 36 | 37 | -------------------------------------------------------------------------------- /src/cr/nimble/dsp/signals.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-Present CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """ 16 | Synthetic Signals 17 | """ 18 | # pylint: disable=W0611 19 | 20 | from cr.nimble._src.dsp.synthetic_signals import ( 21 | chirp, 22 | chirp_centered, 23 | pulse, 24 | gaussian, 25 | gaussian_pulse, 26 | decaying_sine_wave, 27 | transient_sine_wave, 28 | picket_fence, 29 | heavi_sine, 30 | bumps, 31 | blocks, 32 | doppler, 33 | ramp, 34 | cusp, 35 | sing, 36 | hi_sine, 37 | lo_sine, 38 | lin_chirp, 39 | two_chirp, 40 | quad_chirp, 41 | mish_mash, 42 | werner_sorrows, 43 | leopold, 44 | ) 45 | -------------------------------------------------------------------------------- /tests/svd/test_svd_utils.py: -------------------------------------------------------------------------------- 1 | from svd_setup import * 2 | 3 | # ambient dimension 4 | n = 10 5 | # subspace dimension 6 | d = 4 7 | # number of subspaces 8 | k = 2 9 | 10 | bases = cnb.data.random_subspaces_jit(cnb.KEYS[0], n, d, k) # presumably k orthonormal n x d bases; confirm against random_subspaces 11 | 12 | 13 | def test_orth(): 14 | A = jnp.array([[2, 0, 0], [0, 5, 0]]) # rank 2 array 15 | Q, rank = cnb.orth_jit(A) 16 | Q0 = jnp.array([[0., 1.], [1., 0.]]) 17 | assert_allclose(Q, Q0) 18 | 19 | def test_row_space(): 20 | A = jnp.array([[2, 0, 0], [0, 5, 0]]) # rank 2 array 21 | Q, rank = cnb.row_space_jit(A) 22 | Q0 = jnp.array([[0., 1.], [1., 0.], [0, 0]]) 23 | assert_allclose(Q, Q0) 24 | 25 | def test_null_space(): 26 | A = bases[0] 27 | Z, r = cnb.null_space_jit(A) 28 | Z = Z[:, r:] 29 | assert_allclose(A @ Z, 0) # NOTE(review): default atol=0 demands exact zeros; if bases[0] has full column rank, Z is empty and this check is vacuous - confirm 30 | 31 | def test_left_null_space(): 32 | A = bases[0] 33 | Z, r = cnb.left_null_space_jit(A) 34 | Z = Z[:, r:] 35 | assert_allclose(Z.T @ A, 0, atol=atol) 36 | 37 | def test_effective_rank(): 38 | A = random.normal(cnb.KEYS[0], (3, 5)) 39 | r = cnb.effective_rank_jit(A) 40 | assert_array_equal(r, 3) 41 | 42 | def test_singular_values(): 43 | A = bases[0] 44 | s = cnb.singular_values(A) 45 | assert_allclose(s, 1., atol=atol) -------------------------------------------------------------------------------- /.github/workflows/ci.yml:
-------------------------------------------------------------------------------- 1 | name: Unit Tests 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - 'tests/**' 9 | - 'src/**' 10 | pull_request: 11 | branches: 12 | - main 13 | 14 | jobs: 15 | build: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Install Python 3 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: 3.7 23 | - name: Setup timezone 24 | uses: zcong1993/setup-timezone@master 25 | with: 26 | timezone: UTC 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install -r requirements/requirements.txt 31 | pip install -r requirements/requirements-tests.txt 32 | - name: Install the development package itself 33 | run: | 34 | python -m pip install -e . 35 | - name: Run tests with pytest and generate coverage report 36 | run: pytest --cov=cr.nimble --cov-report=xml 37 | - name: "Upload coverage to Codecov" 38 | uses: codecov/codecov-action@v1 39 | with: 40 | fail_ci_if_error: true 41 | verbose: true 42 | -------------------------------------------------------------------------------- /tests/core/test_similarity.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | import numpy as np 5 | from numpy.testing import (assert_almost_equal, assert_allclose, assert_, 6 | assert_equal, assert_raises, assert_raises_regex, 7 | assert_array_equal, assert_warns) 8 | 9 | 10 | import cr.nimble as cnb 11 | import jax.numpy as jnp 12 | 13 | def test_dist_to_gaussian_sim(): 14 | n = 10 15 | I = jnp.eye(n) 16 | distances = jnp.ones((n, n)) - I 17 | sim = cnb.dist_to_gaussian_sim(distances, sigma=1.)
 18 | expected = 0.606531 * jnp.ones((n, n)) # exp(-1/2): unit distance with sigma = 1 19 | expected = cnb.set_diagonal(expected, 1) 20 | assert_allclose(sim, expected, atol=1e-5) 21 | 22 | def test_sqr_dist_to_gaussian_sim(): 23 | n = 10 24 | I = jnp.eye(n) 25 | distances = jnp.ones((n, n)) - I 26 | sqr_dist = distances ** 2 27 | sim = cnb.sqr_dist_to_gaussian_sim(sqr_dist, sigma=1.) 28 | expected = 0.606531 * jnp.ones((n, n)) # exp(-1/2): unit squared distance with sigma = 1 29 | expected = cnb.set_diagonal(expected, 1) 30 | assert_allclose(sim, expected, atol=1e-5) 31 | 32 | def test_eps_neighborhood_sim(): 33 | n = 10 34 | I = jnp.eye(n) 35 | distances = jnp.ones((n, n)) - I 36 | sim = cnb.eps_neighborhood_sim(distances, 0.5) 37 | assert_allclose(sim, I) # boolean result is True only on the diagonal (distance 0 < 0.5) 38 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/dsp/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-Present CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import jax 16 | import jax.numpy as jnp 17 | from jax import jit 18 | 19 | def sliding_windows_rw(x, wlen, overlap): 20 | """Converts a signal into sliding windows (per row) with the specified overlap 21 | For a 1-D ``x`` the result has shape ``(num_windows, wlen)``; trailing samples that do not fill a complete window are dropped. """ 22 | step = wlen - overlap # NOTE(review): assumes 0 <= overlap < wlen so that the step is positive 23 | starts = jnp.arange(0, len(x) - wlen + 1, step) # first index of every complete window 24 | block = jnp.arange(wlen) 25 | idx = starts[:, None] + block[None, :] # (num_windows, wlen) gather indices 26 | return x[idx] 27 | 28 | def sliding_windows_cw(x, wlen, overlap): 29 | """Converts a signal into sliding windows (per column) with the specified overlap 30 | Transpose of ``sliding_windows_rw``: each column is one window. """ 31 | return sliding_windows_rw(x, wlen, overlap).T 32 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/similarity.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR.Sparse Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Similarity measures 16 | """ 17 | 18 | from jax import jit 19 | import jax.numpy as jnp 20 | 21 | def dist_to_gaussian_sim(dist, sigma): 22 | """Computes the Gaussian similarities exp(-dist**2 / (2 * sigma**2)) for given distances 23 | """ 24 | d = dist**2 / (2 * sigma**2) 25 | return jnp.exp(-d) 26 | 27 | def sqr_dist_to_gaussian_sim(sqr_dist, sigma): 28 | """Computes the Gaussian similarities exp(-sqr_dist / (2 * sigma**2)) for given squared distances 29 | """ 30 | d = sqr_dist / (2 * sigma**2) 31 | return jnp.exp(-d) 32 | 33 | def eps_neighborhood_sim(dist, threshold): 34 | """Computes the epsilon neighborhood similarity: a boolean array that is True where dist < threshold 35 | """ 36 | return dist < threshold -------------------------------------------------------------------------------- /src/cr/nimble/_src/array.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import jax.numpy as jnp 16 | 17 | def hermitian(A): 18 | r"""Returns the Hermitian transpose of an array 19 | 20 | Args: 21 | A (jax.numpy.ndarray): An array 22 | 23 | Returns: 24 | (jax.numpy.ndarray): An array: :math:`A^H` 25 | """ 26 | return jnp.conjugate(A.T) # for N-D input, .T reverses all axes 27 | 28 | def check_shapes_are_equal(array1, array2): 29 | """Raise an error if the shapes of the two arrays do not match.
 30 | 31 | Raises: 32 | ValueError: if the two arrays do not have the same shape 33 | """ 34 | if not array1.shape == array2.shape: 35 | raise ValueError('Input arrays must have the same shape.') 36 | return 37 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/latex.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | 16 | import numpy as np 17 | 18 | def to_tex_matrix(a, env='bmatrix'): 19 | a = np.asarray(a) 20 | assert a.ndim == 2, 'Input must be a matrix' 21 | # use numpy string conversion first 22 | text = str(a) 23 | # remove the brackets 24 | text = text.replace('[', '') 25 | text = text.replace(']', '') 26 | # split the text into lines 27 | lines = text.splitlines() 28 | # fill in the ampersands 29 | lines = [' & '.join(line.split()) for line in lines] 30 | # combine the lines 31 | body = '\n'.join([line + '\\\\' for line in lines]) 32 | # add the env block 33 | body = f'\\begin{{{env}}}\n{body}\n\\end{{{env}}}' 34 | return body 35 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/dsp/interpolation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR.Sparse Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import jax 16 | import jax.numpy as jnp 17 | from jax import jit 18 | import jax.numpy.fft as jfft 19 | 20 | def interpft(x, N): 21 | """Interpolates x to n points in Fourier Transform domain 22 | """ 23 | n = len(x) 24 | assert n < N 25 | a = jfft.fft(x) 26 | nyqst = (n + 1) // 2 27 | z = jnp.zeros(N -n) 28 | a1 = a[:nyqst+1] 29 | a2 = a[nyqst+1:] 30 | b = jnp.concatenate((a1, z, a2)) 31 | if n % 2 == 0: 32 | b = b.at[nyqst].set(b[nyqst] /2 ) 33 | b = b.at[nyqst + N -n].set(b[nyqst]) 34 | y = jfft.ifft(b) 35 | if jnp.isrealobj(x): 36 | y = jnp.real(y) 37 | # scale it up 38 | y = y * (N / n) 39 | return y 40 | 41 | -------------------------------------------------------------------------------- /tests/core/test_util.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | def test_promote_arg_dtypes(): 4 | res = cnb.promote_arg_dtypes(jnp.array(1), jnp.array(2)) 5 | expected = jnp.array([1.0, 2.0]) 6 | assert jnp.array_equal(res, expected) 7 | assert jnp.array_equal(cnb.promote_arg_dtypes(jnp.array(1)), jnp.array(1.)) 8 | cnb.promote_arg_dtypes(jnp.array(1), jnp.array(2.)) 9 | 10 | def test_canonicalize_dtype(): 11 | cnb.canonicalize_dtype(None) 12 | cnb.canonicalize_dtype(jnp.int32) 13 | 14 | 15 | def test_is_cpu(): 16 | assert_equal(cnb.is_cpu(), cnb.platform == 'cpu') 17 | 18 | def test_is_gpu(): 19 | assert_equal(cnb.is_gpu(), cnb.platform == 'gpu') 20 | 21 | def test_is_tpu(): 22 | assert_equal(cnb.is_tpu(), cnb.platform == 'tpu') 23 | 24 | def test_check_shapes_are_equal(): 25 | z = jnp.zeros(4) 26 | o = jnp.ones(4) 27 | cnb.check_shapes_are_equal(z, o) 28 | o = jnp.ones(5) 29 | with assert_raises(ValueError): 30 | cnb.check_shapes_are_equal(z, o) 31 | 32 | def test_promote_to_complex(): 33 | z = jnp.zeros(4) 34 | z = cnb.promote_to_complex(z) 35 | assert z.dtype == np.complex128 36 | 37 | def test_promote_to_real(): 38 | z = jnp.zeros(4, dtype=int) 39 | z = 
cnb.promote_to_real(z) 40 | assert z.dtype == np.float32 41 | 42 | 43 | def test_nbytes_live_buffers(): 44 | nbytes = cnb.nbytes_live_buffers() 45 | assert nbytes > 0 -------------------------------------------------------------------------------- /tests/linear_systems/test_solve_submatrix.py: -------------------------------------------------------------------------------- 1 | from ls_setup import * 2 | 3 | @pytest.mark.parametrize("K", [1, 2, 4, 8]) 4 | def test_solve1(K): 5 | M = 20 6 | N = 40 7 | Phi = cnb.gaussian_mtx(cnb.KEYS[0], M, N) 8 | cols = random.permutation(cnb.KEYS[1], jnp.arange(N))[:K] 9 | X = random.normal(cnb.KEYS[2], (K, 1)) 10 | Phi_I = Phi[:, cols] 11 | B_ref = Phi_I @ X 12 | B = cnb.mult_with_submatrix(Phi, cols, X) 13 | assert_allclose(B_ref, B) 14 | Z, R = cnb.solve_on_submatrix(Phi, cols, B) 15 | assert_allclose(Z, X, atol=atol, rtol=rtol) 16 | 17 | 18 | submat_multiplier = vmap(cnb.mult_with_submatrix, (None, 1, 1), 1) 19 | submat_solver = vmap(cnb.solve_on_submatrix, (None, 1, 1), (1, 1,)) 20 | 21 | @pytest.mark.parametrize("K", [1, 2, 4]) 22 | def test_solve2(K): 23 | M = 20 24 | N = 40 25 | Phi = cnb.gaussian_mtx(cnb.KEYS[0], M, N) 26 | # Number of signals 27 | S = 4 28 | # index sets for each signal 29 | omega = jnp.arange(N) 30 | keys = random.split(cnb.KEYS[1], S) 31 | set_gen = lambda key : random.permutation(key, omega)[:K] 32 | cols = vmap(set_gen, 0, 1)(keys) 33 | # signals [column wise] 34 | X = random.normal(cnb.KEYS[2], (K, S)) 35 | # measurements 36 | B = submat_multiplier(Phi, cols, X) 37 | # solutions 38 | Z, R = submat_solver(Phi, cols, B) 39 | # verify 40 | assert_allclose(Z, X, atol=atol, rtol=rtol) 41 | -------------------------------------------------------------------------------- /tests/core/test_metrics.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | n = 16 4 | 5 | z = jnp.zeros(n) 6 | o = jnp.ones(n) 7 | 8 | def test_mean_squared(): 
9 | assert_almost_equal(cnb.mean_squared(z), 0) 10 | assert_almost_equal(cnb.mean_squared(o), 1) 11 | 12 | def test_mean_squared_error(): 13 | assert_almost_equal(cnb.mean_squared_error(z, z), 0) 14 | assert_almost_equal(cnb.mean_squared_error(z, o), 1) 15 | 16 | def test_root_mean_squared(): 17 | assert_almost_equal(cnb.root_mean_squared(z), 0) 18 | assert_almost_equal(cnb.root_mean_squared(o), 1) 19 | 20 | def test_root_mse(): 21 | assert_almost_equal(cnb.root_mse(z, z), 0) 22 | assert_almost_equal(cnb.root_mse(z, o), 1) 23 | 24 | 25 | def test_normalized_root_mse(): 26 | assert_almost_equal(cnb.normalized_root_mse(z, z, 'euclidean'), 0) 27 | assert_almost_equal(cnb.normalized_root_mse(z, z, 'min-max'), 0) 28 | assert_almost_equal(cnb.normalized_root_mse(z, z, 'mean'), 0) 29 | assert_almost_equal(cnb.normalized_root_mse(z, z, 'median'), 0) 30 | with assert_raises(ValueError): 31 | cnb.normalized_root_mse(z, z, 'abcd') 32 | 33 | 34 | def test_peak_signal_noise_ratio(): 35 | assert_almost_equal(cnb.peak_signal_noise_ratio(z, z), 156.5356, decimal=2) 36 | 37 | 38 | @pytest.mark.parametrize("x,expected", [(o, 168.5768), (z, 0)]) 39 | def test_signal_noise_ratio(x, expected): 40 | assert_almost_equal(cnb.signal_noise_ratio(x, x), expected, decimal=3) 41 | -------------------------------------------------------------------------------- /src/cr/nimble/subspaces.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from cr.nimble._src.subspaces import ( 16 | orth_complement, 17 | project_to_subspace, 18 | is_in_subspace, 19 | principal_angles_cos, 20 | principal_angles_cos_jit, 21 | principal_angles_rad, 22 | principal_angles_rad_jit, 23 | principal_angles_deg, 24 | principal_angles_deg_jit, 25 | smallest_principal_angle_cos, 26 | smallest_principal_angle_cos_jit, 27 | smallest_principal_angle_rad, 28 | smallest_principal_angle_rad_jit, 29 | smallest_principal_angle_deg, 30 | smallest_principal_angle_deg_jit, 31 | smallest_principal_angles_cos, 32 | smallest_principal_angles_cos_jit, 33 | smallest_principal_angles_rad, 34 | smallest_principal_angles_rad_jit, 35 | smallest_principal_angles_deg, 36 | smallest_principal_angles_deg_jit, 37 | subspace_distance, 38 | ) -------------------------------------------------------------------------------- /src/cr/nimble/compression.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ 16 | Basic data compression routines 17 | """ 18 | # pylint: disable=W0611 19 | 20 | 21 | from cr.nimble._src.compression.binary_arrs import ( 22 | count_binary_runs, 23 | encode_binary_arr, 24 | decode_binary_arr, 25 | binary_compression_ratio, 26 | binary_space_saving_ratio, 27 | ) 28 | 29 | from cr.nimble._src.compression.fixed_length import ( 30 | encode_uint_arr_fl, 31 | decode_uint_arr_fl, 32 | encode_int_arr_sgn_mag_fl, 33 | decode_int_arr_sgn_mag_fl 34 | ) 35 | 36 | from cr.nimble._src.compression.run_length import ( 37 | count_runs_values, 38 | expand_runs_values 39 | ) 40 | 41 | from cr.nimble._src.compression.bits import ( 42 | float_to_int, 43 | int_to_float, 44 | int_to_bitarray, 45 | read_int_from_bitarray, 46 | float_to_bitarray, 47 | bitarray_to_float, 48 | read_float_from_bitarray 49 | ) 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CR-Nimble 2 | 3 | `CR-Nimble` consists of fast linear algebra 4 | and signal processing routines. 5 | Most of the routines have been implemented using 6 | Google JAX. Thus, they can be easily run on 7 | a variety of hardware (CPU, GPU, TPU). 8 | 9 | Functionality includes: 10 | 11 | * Utility functions for working with vectors, matrices and arrays 12 | * Linear algebra functions 13 | * Digital signal processing functions 14 | * Data compression functions 15 | * Test data generation functions 16 | 17 | 18 | Installation 19 | 20 | ```{shell} 21 | python -m pip install cr-nimble 22 | ``` 23 | 24 | For Windows, you can use unofficial JAX builds 25 | from [here](https://github.com/cloudhan/jax-windows-builder). 26 | 27 | Import 28 | 29 | ```{python} 30 | import cr.nimble as crn 31 | ``` 32 | 33 | See [documentation](https://cr-nimble.readthedocs.io) 34 | for library usage. 35 | 36 | `CR-Nimble` is part of 37 | [CR-Suite](https://carnotresearch.github.io/cr-suite/). 
38 | 39 | Related libraries: 40 | 41 | * [CR-Wavelets](https://cr-wavelets.readthedocs.io) 42 | * [CR-Sparse](https://cr-sparse.readthedocs.io) 43 | 44 | 45 | [![codecov](https://codecov.io/gh/carnotresearch/cr-nimble/branch/main/graph/badge.svg?token=PX1MGTZ7VL)](https://codecov.io/gh/carnotresearch/cr-nimble) 46 | [![Unit Tests](https://github.com/carnotresearch/cr-nimble/actions/workflows/ci.yml/badge.svg)](https://github.com/carnotresearch/cr-nimble/actions/workflows/ci.yml) 47 | [![Documentation Status](https://readthedocs.org/projects/cr-nimble/badge/?version=latest)](https://cr-nimble.readthedocs.io/en/latest/?badge=latest) 48 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/svdpack/bdsqr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import NamedTuple
import math

from jax import lax, jit, vmap
import jax.numpy as jnp
from jax.numpy.linalg import norm
from jax.scipy.linalg import svd


def bdsqr(alpha, beta, k):
    """Computes the SVD of the ``(k+1) x k`` lower bidiagonal matrix built from ``alpha`` and ``beta``

    The matrix ``B`` is formed densely with ``alpha[:k]`` on its main diagonal
    and ``beta[1:k+1]`` on its first subdiagonal (starting from the second row)::

        B[i, i]     = alpha[i]      for i in 0..k-1
        B[i + 1, i] = beta[i + 1]   for i in 0..k-1

    and then decomposed with a dense economy-size SVD.

    Args:
        alpha (jax.numpy.ndarray): diagonal entries; the first ``k`` are used
        beta (jax.numpy.ndarray): subdiagonal entries; ``beta[1:k+1]`` are used
            (``beta[0]`` is ignored)
        k (int): number of columns of ``B``; must be static under ``jit``

    Returns:
        A tuple ``(U, s, Vh, bnd)`` with ``U`` of shape ``(k+1, k)``, the ``k``
        singular values ``s``, ``Vh`` of shape ``(k, k)`` such that
        ``B = U @ diag(s) @ Vh``, and ``bnd``, the last row of ``U``.
    """
    alpha = jnp.asarray(alpha)
    beta = jnp.asarray(beta)
    # Build B in the (inexact) dtype of the inputs instead of a hard-coded
    # float32, so float64 (with x64 enabled) or complex data is not downcast.
    dtype = jnp.promote_types(jnp.result_type(alpha, beta), jnp.float32)
    B = jnp.zeros((k + 1, k), dtype=dtype)
    rows, cols = jnp.diag_indices(k)
    # main diagonal: k alpha entries
    B = B.at[rows, cols].set(alpha[:k])
    # first subdiagonal: k beta entries, starting from the second row
    B = B.at[rows + 1, cols].set(beta[1:k + 1])
    # economy SVD: U keeps only k columns
    U, s, Vh = svd(B, full_matrices=False, compute_uv=True)
    # the last row of U is returned as the bounds vector
    bnd = U[-1, :k]
    return U, s, Vh, bnd


bdsqr_jit = jit(bdsqr, static_argnums=(2,))

# --------------------------------------------------------------------------------
# /src/cr/nimble/svd.py:
# --------------------------------------------------------------------------------
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | """ 16 | Algorithms and utilities for solving SVD problems 17 | """ 18 | 19 | from cr.nimble._src.svdpack.reorth import ( 20 | reorth_mgs, 21 | reorth_mgs_jit, 22 | reorth_noop 23 | ) 24 | 25 | from cr.nimble._src.svdpack.bdsqr import ( 26 | bdsqr, 27 | bdsqr_jit 28 | ) 29 | 30 | from cr.nimble._src.svdpack.lanbpro_utils import ( 31 | LanBDOptions, 32 | LanBProState, 33 | lanbpro_options_init, 34 | do_elr, 35 | lanbpro_random_start, 36 | update_nu, 37 | update_mu, 38 | compute_ind, 39 | bpro_norm_estimate, 40 | ) 41 | 42 | from cr.nimble._src.svdpack.lanbpro import ( 43 | lanbpro_init, 44 | lanbpro_iteration, 45 | lanbpro_iteration_jit, 46 | lanbpro, 47 | lanbpro_jit, 48 | new_r_vec, 49 | new_p_vec, 50 | ) 51 | from cr.nimble._src.svdpack.lansvd_utils import ( 52 | refine_bounds, 53 | ) 54 | 55 | from cr.nimble._src.svdpack.lansvd import ( 56 | lansvd_simple, 57 | lansvd_simple_jit 58 | ) 59 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/dsp/quantization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR.Sparse Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import jax
import jax.numpy as jnp
from jax import jit


def quantize_1(x, n):
    """Quantizes a signal to n bits where signal values are bounded by 1

    Every value is scaled by ``2**n - 1`` and rounded to the nearest integer
    (ties round to even, as with :func:`jax.numpy.round`). Values in
    ``[0, 1]`` map to ``0 .. 2**n - 1``; values in ``[-1, 1]`` map to
    ``-(2**n - 1) .. 2**n - 1``, which needs one extra bit for the sign.
    No clipping is performed, so inputs outside ``[-1, 1]`` give codes
    outside that range.

    Args:
        x (jax.numpy.ndarray): A signal to be quantized (array-likes are accepted)
        n (int): number of bits for quantization

    Returns:
        (jax.numpy.ndarray): Quantized signal with integer values
    """
    # Convert first: for a Python list, ``factor * x`` would be list
    # repetition instead of element-wise scaling.
    x = jnp.asarray(x)
    # scaling
    factor = 2**n - 1
    x = factor * x
    # quantization
    x = jnp.round(x)
    # type conversion from float to int
    return x.astype(int)


def inv_quantize_1(x, n):
    """Inverse quantizes a signal from n bits

    Maps integer codes produced by :func:`quantize_1` back to the
    ``[-1, 1]`` scale by dividing by ``2**n - 1``.

    Args:
        x (jax.numpy.ndarray): Quantized integer codes (array-likes are accepted)
        n (int): number of bits for quantization

    Returns:
        (jax.numpy.ndarray): Reconstructed signal with floating point values
    """
    # Convert first: a Python list does not support ``/``.
    x = jnp.asarray(x)
    factor = 2**n - 1
    return x / factor

# --------------------------------------------------------------------------------
# /src/cr/nimble/_src/triangular.py:
# --------------------------------------------------------------------------------
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.scipy.linalg import solve_triangular


def solve_Lx_b(L, b):
    """
    Solves the system L x = b for lower triangular L using forward substitution
    """
    return solve_triangular(L, b, lower=True)


def solve_LTx_b(L, b):
    """
    Solves the system L^T x = b for lower triangular L using back substitution

    ``L^T`` is upper triangular. For complex ``L`` this is the plain
    transpose (``trans='T'``), not the conjugate transpose.
    """
    return solve_triangular(L, b, lower=True, trans='T')


def solve_Ux_b(U, b):
    """
    Solves the system U x = b for upper triangular U using back substitution
    """
    return solve_triangular(U, b)


def solve_UTx_b(U, b):
    """
    Solves the system U^T x = b for upper triangular U using forward substitution

    ``U^T`` is lower triangular. For complex ``U`` this is the plain
    transpose (``trans='T'``), not the conjugate transpose.
    """
    return solve_triangular(U, b, trans='T')


def solve_spd_chol(L, b):
    """
    Solves a symmetric (Hermitian) positive definite system A x = b
    where A = L L' and L' is the conjugate transpose of the lower
    triangular Cholesky factor L (the plain transpose for real L)
    """
    # We have to solve L L' x = b
    # We first solve L u = b (forward substitution)
    u = solve_Lx_b(L, b)
    # We now solve L' x = u (back substitution). trans='C' uses the
    # conjugate transpose, which equals trans='T' for real L but is the
    # correct choice for complex Hermitian A = L L^H.
    x = solve_triangular(L, u, lower=True, trans='C')
    return x

# --------------------------------------------------------------------------------
# /src/cr/nimble/_src/chol.py:
# --------------------------------------------------------------------------------
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp

from cr.nimble import solve_Lx_b, solve_LTx_b, solve_Ux_b, solve_UTx_b


def cholesky_update_on_add_column(A, L, I, i):
    r"""Incrementally updates a Cholesky factor when a new column is added

    Given the lower triangular factor :math:`L` of the Gram matrix
    :math:`G = A_I^H A_I = L L^H` of the columns of ``A`` indexed by ``I``,
    returns the factor of the Gram matrix of those columns followed by the
    new column :math:`A_i`. For real ``A`` the conjugate transpose
    :math:`^H` is the ordinary transpose.

    Args:
        A: matrix whose columns are being selected
        L: current lower triangular factor; its size must equal ``len(I)``
            since it is used to solve against ``A[:, I]^H A[:, i]``
        I: indices of the columns already covered by ``L``
        i: index of the new column

    Returns:
        The factor extended by one row and one column

    Note:
        Assumes that the first iteration is already done, i.e. ``L`` is at
        least ``1 x 1`` (see :func:`cholesky_build_factor`).
    """
    a = A[:, i]
    # cross Gram terms between the existing columns and the new one: A_I^H a
    v = A[:, I].conj().T @ a
    m = L.shape[0]
    z = jnp.zeros((m, 1), dtype=L.dtype)
    # the new row of the factor is w^H where L w = A_I^H a
    w = solve_Lx_b(L, v)
    # new diagonal entry sqrt(||a||^2 - ||w||^2); both terms are real
    s = jnp.sqrt(jnp.real(jnp.vdot(a, a)) - jnp.real(jnp.vdot(w, w)))
    # [[L, 0], [w^H, s]]
    L0 = jnp.hstack((L, z))
    L1 = jnp.hstack((w.conj(), s))
    return jnp.vstack((L0, L1))


def cholesky_build_factor(A):
    r"""Builds the Cholesky factor :math:`L` of the Gram matrix :math:`G = A^H A = L L^H` incrementally column wise

    For real ``A`` this is :math:`G = A^T A = L L^T`. The factor is grown
    one row/column per column of ``A`` in a Python loop, using
    :func:`cholesky_update_on_add_column`.
    """
    a = A[:, 0]
    # 1 x 1 factor for the first column: its Euclidean norm
    L = jnp.reshape(jnp.sqrt(jnp.real(jnp.vdot(a, a))), (1, 1))
    for j in range(1, A.shape[1]):
        I = jnp.arange(0, j)
        L = cholesky_update_on_add_column(A, L, I, j)
    return L

# --------------------------------------------------------------------------------
# /src/cr/nimble/_src/dsp/energy.py:
# --------------------------------------------------------------------------------
# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
from jax import jit

from cr.nimble import (
    is_matrix,
    sqr_norms_l2_rw,
    sqr_norms_l2_cw)


def _first_index_with_energy_le(energies, energy):
    """Returns the first index ``i`` with ``energies[i] <= energy``, or -1 if there is none."""
    # argmax over a boolean mask yields the first True entry (0 when all are False)
    index = jnp.argmax(energies <= energy)
    # jnp.where instead of a Python ``if`` so this also works on traced values under jit
    return jnp.where(energies[index] <= energy, index, -1)


def find_first_signal_with_energy_le_rw(X, energy):
    """Returns the index of the first row which has energy less than or equal to the specified threshold

    Args:
        X (jax.numpy.ndarray): A matrix whose rows are signals
        energy (float): Energy threshold; row energies are computed with ``sqr_norms_l2_rw``

    Returns:
        (jax.numpy.ndarray): Index of the first qualifying row, or -1 if no row qualifies
    """
    assert is_matrix(X)
    return _first_index_with_energy_le(sqr_norms_l2_rw(X), energy)


def find_first_signal_with_energy_le_cw(X, energy):
    """Returns the index of the first column which has energy less than or equal to the specified threshold

    Args:
        X (jax.numpy.ndarray): A matrix whose columns are signals
        energy (float): Energy threshold; column energies are computed with ``sqr_norms_l2_cw``

    Returns:
        (jax.numpy.ndarray): Index of the first qualifying column, or -1 if no column qualifies
    """
    assert is_matrix(X)
    return _first_index_with_energy_le(sqr_norms_l2_cw(X), energy)


def energy(data, axis=-1):
    """
    Computes the energy of the signal along the specified axis

    Energy is the sum of squared magnitudes (``abs`` is taken first, so
    complex data is supported).
    """
    power = jnp.abs(data) ** 2
    return jnp.sum(power, axis)

# --------------------------------------------------------------------------------
# /src/cr/nimble/_src/dsp/features.py:
# --------------------------------------------------------------------------------
# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
from jax import jit


def dynamic_range(x):
    """Returns the ratio of largest and smallest values (by magnitude) in x (dB)

    Computed as ``20 * log10(max|x| / min|x|)`` over all entries of ``x``;
    multi-dimensional input is treated as a flat collection of values.

    Args:
        x (jax.numpy.ndarray): A signal

    Returns:
        (float): The dynamic range between largest and smallest value

    Note:
        This function is not suitable for sparse signals where some values are actually 0:
        a zero magnitude makes the ratio infinite (or NaN for an all-zero signal).

    See Also:
        :func:`nonzero_dynamic_range`
    """
    mags = jnp.ravel(jnp.abs(x))
    return 20 * jnp.log10(jnp.max(mags) / jnp.min(mags))


def nonzero_dynamic_range(x):
    """Returns the ratio of largest and smallest non-zero values (by magnitude) in x (dB)

    Multi-dimensional input is treated as a flat collection of values.

    Args:
        x (jax.numpy.ndarray): A sparse/compressible signal

    Returns:
        (float): The dynamic range between largest and smallest nonzero value
        (NaN if every value is zero)

    See Also:
        :func:`dynamic_range`
    """
    # flatten so that sorting and indexing act on all entries, not on rows
    mags = jnp.sort(jnp.ravel(jnp.abs(x)))
    # first non-zero entry of the ascending sort; argmax gives 0 if none exists
    idx = jnp.argmax(mags != 0)
    return 20 * jnp.log10(mags[-1] / mags[idx])

# --------------------------------------------------------------------------------
# /src/cr/nimble/_src/linear.py:
# --------------------------------------------------------------------------------
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Some basic linear transformations
"""

import jax
import jax.numpy as jnp

from cr.nimble import promote_arg_dtypes


def point2d(x, y):
    """A point in 2D vector space

    Adding ``0.`` to ``x`` makes the array floating point even when both
    coordinates are integers.
    """
    return jnp.array([x + 0., y])


def vec2d(x, y):
    """A vector in 2D vector space

    Adding ``0.`` to ``x`` makes the array floating point even when both
    coordinates are integers.
    """
    return jnp.array([x + 0., y])


def rotate2d_cw(theta):
    r"""Construct an operator that rotates a 2D vector by angle :math:`\theta` clock-wise

    .. math::

        Q = \begin{bmatrix} \cos\theta & \sin\theta \\ -\sin\theta & \cos\theta \end{bmatrix}

    Args:
        theta: rotation angle in radians

    Returns:
        (jax.numpy.ndarray): The 2x2 matrix ``Q``; ``Q @ v`` is the rotated vector
    """
    Q = jnp.array([[jnp.cos(theta), jnp.sin(theta)],
                   [-jnp.sin(theta), jnp.cos(theta)]])
    return Q


def rotate2d_ccw(theta):
    r"""Construct an operator that rotates a 2D vector by angle :math:`\theta` counter-clock-wise

    .. math::

        Q = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}

    Args:
        theta: rotation angle in radians

    Returns:
        (jax.numpy.ndarray): The 2x2 matrix ``Q``; ``Q @ v`` is the rotated vector
    """
    Q = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                   [jnp.sin(theta), jnp.cos(theta)]])
    return Q


def reflect2d(theta):
    r"""Construct an operator that reflects a 2D vector across a line defined at angle :math:`\theta/2`

    .. math::

        R = \begin{bmatrix} \cos\theta & \sin\theta \\ \sin\theta & -\cos\theta \end{bmatrix}

    Args:
        theta: twice the angle (in radians) of the reflection line through the origin

    Returns:
        (jax.numpy.ndarray): The 2x2 matrix ``R``; ``R @ v`` is the reflected vector
    """
    R = jnp.array([[jnp.cos(theta), jnp.sin(theta)],
                   [jnp.sin(theta), -jnp.cos(theta)]])
    return R

# --------------------------------------------------------------------------------
# /src/cr/nimble/_src/noise.py:
# --------------------------------------------------------------------------------
# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
from jax import random, jit

from cr.nimble import promote_arg_dtypes
from cr.nimble import arr_l2norm_sqr


def awgn_at_snr_ms(key, signal, snr):
    """Generates noise for the signal at the specified SNR based on signal energy

    The per-sample signal energy (``arr_l2norm_sqr`` of the signal divided by
    its number of entries) is converted to dB. The noise level is set
    ``snr`` dB below it, and that level is turned into a standard deviation
    for Gaussian noise.

    Args:
        key: JAX PRNG key used to draw the noise
        signal: reference signal (array-like)
        snr: desired signal to noise ratio in dB

    Returns:
        (jax.numpy.ndarray): noise with the same shape as ``signal``, drawn with
        ``random.normal`` using its default (real) dtype
    """
    sig = promote_arg_dtypes(jnp.asarray(signal))
    # per-sample energy of the signal, in dB
    per_sample_db = 10 * jnp.log10(arr_l2norm_sqr(sig) / sig.size)
    # target noise level, snr dB below the signal
    noise_db = per_sample_db - snr
    # dB power level -> amplitude (standard deviation)
    sigma = 10 ** (noise_db / 20)
    return sigma * random.normal(key, sig.shape)


def awgn_at_snr_std(key, signal, snr):
    """Generates noise for the signal at the specified SNR based on std ratio

    The noise standard deviation is ``std(signal) * 10**(-snr/20)``.

    Note:
        Signal is expected to be zero mean
    """
    sig = promote_arg_dtypes(jnp.asarray(signal))
    sigma_n = jnp.std(sig) * 10 ** (-snr / 20)
    return sigma_n * random.normal(key, sig.shape)


# default AWGN generator: the signal-energy based variant
awgn_at_snr = awgn_at_snr_ms

# --------------------------------------------------------------------------------
# /src/cr/nimble/_src/compression/run_length.py:
# --------------------------------------------------------------------------------
# Copyright 2022 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Run Length Encoding


References:

    https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi
"""


import numpy as np


def count_runs_values(input_arr):
    """Computes run lengths of an array of integers

    Args:
        input_arr: 1-D array-like of values

    Returns:
        A tuple ``(runs, values)`` of 1-D arrays. ``runs[j]`` is the length of
        the j-th run of identical consecutive entries and ``values[j]`` is
        the value repeated in it. ``runs`` is always an integer array and
        ``values`` keeps the dtype of ``input_arr``, including for empty and
        single-element input.

    Example:
        >>> count_runs_values([1, 1, 2, 2, 2, 3])
        (array([2, 3, 1]), array([1, 2, 3]))
    """
    # make sure that input is a numpy array
    input_arr = np.asarray(input_arr)
    n = len(input_arr)
    if n == 0:
        # no runs: integer run lengths, values keep the input dtype
        return np.empty(0, dtype=np.intp), input_arr[:0]
    # i is the last index of a run whenever input_arr[i] != input_arr[i+1]
    (ends,) = np.nonzero(input_arr[1:] != input_arr[:-1])
    # the last position always closes the final run
    ends = np.append(ends, n - 1)
    values = input_arr[ends]
    # run lengths are the gaps between consecutive run ends (virtual end at -1)
    runs = np.diff(ends, prepend=-1)
    return runs, values


def expand_runs_values(runs, values):
    """Decodes run lengths to form an array of integers

    Inverse of :func:`count_runs_values`: ``values[j]`` is repeated
    ``runs[j]`` times.

    Args:
        runs: 1-D array-like of non-negative run lengths
        values: 1-D array-like of values, one per run

    Returns:
        (numpy.ndarray): the expanded 1-D array with the dtype of ``values``;
        empty when there are no runs

    Raises:
        ValueError: if ``runs`` and ``values`` have incompatible lengths
    """
    values = np.asarray(values)
    runs = np.asarray(runs, dtype=np.intp)
    return np.repeat(values, runs)

# --------------------------------------------------------------------------------
# /src/cr/nimble/_src/dsp/spectrum.py:
# --------------------------------------------------------------------------------
# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | from jax import jit 18 | import jax.numpy.fft as jfft 19 | 20 | from cr.nimble import next_pow_of_2 21 | 22 | def norm_freq(frequency, sampling_rate): 23 | """Returns the normalized frequency 24 | 25 | The Nyquist frequency is the half of the sampling rate. 26 | In normalized range, the Nyquist frequency has a value of 1. 27 | If sampling rate is 200 Hz and signal frequency is 28 | 20 Hz, then Nyquist frequency is 100 Hz and the 29 | normalized frequency is 0.2. 30 | 31 | Args: 32 | frequency (float): Frequency in Hz. 33 | sampling_rate (float): Sampling rate of signal in Hz. 
34 | 35 | Returns: 36 | float: Normalized sampling frequency 37 | """ 38 | return 2.0 * frequency / sampling_rate 39 | 40 | 41 | def frequency_spectrum(x, dt=1.): 42 | """Frequency spectrum of 1D data using FFT 43 | """ 44 | n = len(x) 45 | nn = next_pow_of_2(n) 46 | X = jfft.fft(x, nn) 47 | f = jfft.fftfreq(nn, d=dt) 48 | X = jfft.fftshift(X) 49 | f = jfft.fftshift(f) 50 | return f, X 51 | 52 | def power_spectrum(x, dt=1.): 53 | """Power spectrum of 1D data using FFT 54 | """ 55 | n = len(x) 56 | T = dt * n 57 | f, X = frequency_spectrum(x, dt) 58 | nn = len(f) 59 | n2 = nn // 2 60 | f = f[n2:] 61 | X = X[n2:] 62 | sxx = (X * jnp.conj(X)) / T 63 | sxx = jnp.abs(sxx) 64 | return f, sxx 65 | -------------------------------------------------------------------------------- /tests/dsp/test_signals.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import math 3 | # jax imports 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | # crs imports 8 | import cr.nimble as cnb 9 | from cr.nimble.dsp import signals 10 | 11 | 12 | def test_chirp(): 13 | fs = 100 14 | f0 = 1 15 | f1 = 5 16 | T = 5 17 | initial_phase = 0 18 | t, sig = signals.chirp(fs, T, f0, f1, initial_phase) 19 | assert len(t) == len(sig) 20 | 21 | def test_chirp_centered(): 22 | fs = 100 23 | f0 = 1 24 | f1 = 5 25 | T = 5 26 | fc = (f0 + f1) / 2 27 | bw = f1 - f0 28 | initial_phase = 0 29 | t, sig = signals.chirp_centered(fs, T, fc, bw, initial_phase) 30 | assert len(t) == len(sig) 31 | 32 | def test_pulse(): 33 | fs = 100 34 | T = 16 35 | begin = 4 36 | end = 6 37 | init = -4 38 | t, sig = signals.pulse(fs, T, begin, end, init) 39 | assert len(t) == len(sig) 40 | 41 | 42 | def test_gaussian_pulse(): 43 | fs = 1000 44 | T = 4 45 | b = T/2 46 | fc = 5 47 | t, sig = signals.gaussian_pulse(fs, T, b, fc) 48 | assert len(t) == len(sig) 49 | t, real, imag = signals.gaussian_pulse(fs, T, b, fc, retquad=True) 50 | assert len(t) == len(real) 51 | t, sig, env = 
signals.gaussian_pulse(fs, T, b, fc, retenv=True) 52 | assert len(t) == len(sig) 53 | t, real, imag, env = signals.gaussian_pulse(fs, T, b, fc, retenv=True, retquad=True) 54 | assert len(t) == len(real) 55 | 56 | 57 | def test_decaying_sine_wave(): 58 | fs = 100 59 | T = 10 60 | f = 2 61 | alpha = 0.5 62 | t, sig = signals.decaying_sine_wave(fs, T, f, alpha) 63 | assert len(t) == len(sig) 64 | 65 | 66 | def test_transient_sine_wave(): 67 | fs = 100 68 | T = 10 69 | f = 2 70 | T = 16 71 | begin = 2 72 | end = 6 73 | init = -4 74 | t, sig = signals.transient_sine_wave(fs, T, f, begin, end, initial_time=init) 75 | assert len(t) == len(sig) 76 | 77 | 78 | def test_gaussian(): 79 | fs = 1000 80 | T = 20 81 | a = 1 82 | b = T/2 83 | t, sig = signals.gaussian(fs, T, b, a=a) 84 | assert len(t) == len(sig) 85 | -------------------------------------------------------------------------------- /.github/workflows/sphinx.yaml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - 'docs/**' 9 | - 'src/**' 10 | - 'examples/**' 11 | pull_request: 12 | branches: 13 | - main 14 | 15 | 16 | jobs: 17 | build: 18 | name: Sphinx Build and Publish 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v1 22 | - name: Install system packages 23 | run: | 24 | sudo apt-get update -y 25 | sudo apt-get install pandoc libgl1-mesa-dev optipng 26 | - name: Install Python 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: 3.8 30 | - name: Cache pip 31 | uses: actions/cache@v2 32 | with: 33 | path: ~/.cache/pip 34 | key: ${{ runner.os }}-pip-${{ hashFiles('docs/requirements.txt') }} 35 | restore-keys: | 36 | ${{ runner.os }}-pip- 37 | ${{ runner.os }}- 38 | - name: Install dependencies 39 | run: | 40 | python -m pip install --upgrade pip 41 | pip install -r requirements/requirements.txt 42 | pip install -r docs/requirements.txt 43 | - 
name: Install the package itself in development mode 44 | run: | 45 | pip install -e . 46 | - name: Debugging information 47 | run: | 48 | pandoc --version 49 | echo "github.ref:" ${{github.ref}} 50 | echo "github.event_name:" ${{github.event_name}} 51 | echo "github.head_ref:" ${{github.head_ref}} 52 | echo "github.base_ref:" ${{github.base_ref}} 53 | set -x 54 | git rev-parse --abbrev-ref HEAD 55 | git branch 56 | git branch -a 57 | git remote -v 58 | python --version 59 | pip --version 60 | pip list --not-required 61 | pip list 62 | - uses: ammaraskar/sphinx-problem-matcher@master 63 | - name: Build Sphinx docs 64 | run: | 65 | cd docs 66 | make html 67 | - name: Deploy 68 | uses: peaceiris/actions-gh-pages@v3 69 | with: 70 | github_token: ${{ secrets.GITHUB_TOKEN }} 71 | publish_dir: ./docs/_build/html 72 | -------------------------------------------------------------------------------- /tests/core/test_norms.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from cr.nimble.test_setup import * 3 | from cr.nimble import * 4 | 5 | # Integers 6 | X = jnp.array([[1, 2], 7 | [3, 4]]) 8 | 9 | # Floats 10 | X2 = jnp.array([[1, 2], 11 | [3, 4]], dtype=jnp.float32) 12 | 13 | def check_equal(actual, expected): 14 | print(actual, expected) 15 | x = jnp.equal(actual, expected) 16 | print(x) 17 | x = jnp.all(x) 18 | print(x) 19 | assert x 20 | 21 | 22 | def check_approx_equal(actual, expected, abs_err=1e-6): 23 | result = jnp.allclose(actual, expected, atol=abs_err) 24 | print(result) 25 | assert result 26 | 27 | 28 | def test_l1_norm_cw(): 29 | check_equal(norms_l1_cw(X), jnp.array([4, 6])) 30 | 31 | def test_l1_norm_rw(): 32 | check_equal(norms_l1_rw(X), jnp.array([3, 7])) 33 | 34 | 35 | def test_l2_norm_cw(): 36 | check_approx_equal(norms_l2_cw(X2), jnp.array([sqrt(10), sqrt(20)])) 37 | 38 | def test_l2_norm_rw(): 39 | check_approx_equal(norms_l2_rw(X2), jnp.array([sqrt(5), sqrt(25)])) 40 | 41 | 42 | def 
test_linf_norm_cw(): 43 | check_equal(norms_linf_cw(X), jnp.array([3, 4])) 44 | 45 | def test_linf_norm_rw(): 46 | check_equal(norms_linf_rw(X), jnp.array([2, 4])) 47 | 48 | 49 | def test_normalize_l1_cw(): 50 | Y = normalize_l1_cw(X2) 51 | check_equal(norms_l1_cw(Y), jnp.array([1., 1.])) 52 | 53 | def test_normalize_l1_rw(): 54 | Y = normalize_l1_rw(X2) 55 | check_equal(norms_l1_rw(Y), jnp.array([1., 1.])) 56 | 57 | 58 | def test_normalize_l2_cw(): 59 | Y = normalize_l2_cw(X2) 60 | check_approx_equal(norms_l2_cw(Y), jnp.array([1., 1.])) 61 | 62 | def test_normalize_l2_rw(): 63 | Y = normalize_l2_rw(X2) 64 | check_approx_equal(norms_l2_rw(Y), jnp.array([1., 1.])) 65 | 66 | 67 | def test_norm_l1(): 68 | n = 32 69 | x = jnp.arange(n) 70 | assert jnp.sum(x) == norm_l1(x) 71 | 72 | def test_sqr_norm_l2(): 73 | n = 32 74 | x = jnp.arange(n) 75 | assert jnp.sum(x**2) == sqr_norm_l2(x) 76 | 77 | def test_norm_l2(): 78 | n = 32 79 | x = jnp.arange(n) 80 | assert jnp.sum(x**2) == norm_l2(x)**2 81 | 82 | def test_norm_inf(): 83 | n = 32 84 | x = jnp.arange(n) 85 | assert n-1 == norm_linf(x) -------------------------------------------------------------------------------- /docs/source/vector.rst: -------------------------------------------------------------------------------- 1 | Vectors 2 | ======================== 3 | 4 | .. contents:: 5 | :depth: 2 6 | :local: 7 | 8 | 9 | Predicates 10 | ------------------- 11 | 12 | .. currentmodule:: cr.nimble 13 | 14 | .. autosummary:: 15 | :toctree: _autosummary 16 | 17 | is_scalar 18 | is_vec 19 | is_line_vec 20 | is_row_vec 21 | is_col_vec 22 | is_increasing_vec 23 | is_decreasing_vec 24 | is_nonincreasing_vec 25 | is_nondecreasing_vec 26 | has_equal_values_vec 27 | 28 | Unary Operations 29 | ---------------------------------- 30 | 31 | .. 
autosummary:: 32 | :toctree: _autosummary 33 | 34 | to_row_vec 35 | to_col_vec 36 | vec_unit 37 | vec_shift_right 38 | vec_rotate_right 39 | vec_shift_left 40 | vec_rotate_left 41 | vec_shift_right_n 42 | vec_rotate_right_n 43 | vec_shift_left_n 44 | vec_rotate_left_n 45 | vec_repeat_at_end 46 | vec_repeat_at_start 47 | vec_centered 48 | vec_unit_jit 49 | vec_repeat_at_end_jit 50 | vec_repeat_at_start_jit 51 | vec_centered_jit 52 | vec_swap_entries 53 | 54 | Norm 55 | ---------------------------------- 56 | 57 | .. autosummary:: 58 | :toctree: _autosummary 59 | 60 | norm_l1 61 | norm_l2 62 | norm_linf 63 | sqr_norm_l2 64 | normalize_l1 65 | normalize_l2 66 | normalize_linf 67 | 68 | 69 | Circular Buffer 70 | ------------------------- 71 | 72 | 73 | .. autosummary:: 74 | :toctree: _autosummary 75 | 76 | cbuf_push_left 77 | cbuf_push_right 78 | 79 | 80 | Binary Heap 81 | ------------------------- 82 | 83 | 84 | .. autosummary:: 85 | :toctree: _autosummary 86 | 87 | is_min_heap 88 | is_max_heap 89 | left_child_idx 90 | right_child_idx 91 | parent_idx 92 | build_max_heap 93 | largest_plr 94 | heapify_subtree 95 | delete_top_from_max_heap 96 | 97 | 98 | Miscellaneous 99 | ------------------------- 100 | 101 | 102 | .. autosummary:: 103 | :toctree: _autosummary 104 | 105 | vec_mag_desc 106 | vec_to_pmf 107 | vec_to_cmf 108 | cmf_find_quantile_index 109 | num_largest_coeffs_for_energy_percent -------------------------------------------------------------------------------- /src/cr/nimble/_src/compression/bits.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import struct 17 | from bitarray import bitarray 18 | from bitarray.util import int2ba, ba2int 19 | 20 | 21 | def float_to_int(value): 22 | """Constructs an integer representation of a floating point number 23 | """ 24 | s = struct.pack('>f', value) 25 | return struct.unpack('>l', s)[0] 26 | 27 | 28 | def int_to_float(rep): 29 | """Constructs a floating point value from an integer representation 30 | """ 31 | s = struct.pack('>l', rep) 32 | return struct.unpack('>f', s)[0] 33 | 34 | 35 | def int_to_bitarray(value, len_bits=5): 36 | n = value.bit_length() 37 | ba = int2ba(value, length=n+1, signed=True) 38 | # some-bits to encode the length of integer value 39 | output = int2ba(n+1, length=len_bits) 40 | # now add the integer bit array 41 | output.extend(ba) 42 | return output 43 | 44 | def read_int_from_bitarray(a: bitarray, pos:int, len_bits:int =5): 45 | # read the six bit prefix 46 | e = pos + len_bits 47 | prefix = a[pos:e] 48 | n = ba2int(prefix) 49 | pos = e 50 | e = e + n 51 | suffix = a[pos:e] 52 | value = ba2int(suffix, signed=True) 53 | return value, e 54 | 55 | 56 | def float_to_bitarray(value): 57 | s = struct.pack('>f', value) 58 | ba = bitarray() 59 | ba.frombytes(s) 60 | return ba 61 | 62 | def bitarray_to_float(a: bitarray): 63 | bytes = a.tobytes() 64 | value = struct.unpack('>f', bytes)[0] 65 | return value 66 | 67 | def read_float_from_bitarray(a: bitarray, pos: int): 68 | e = pos+32 69 | value = a[pos:e] 70 | return bitarray_to_float(value), e 71 | 72 | 
-------------------------------------------------------------------------------- /docs/source/matrix.rst: -------------------------------------------------------------------------------- 1 | Matrices 2 | =================== 3 | 4 | .. contents:: 5 | :depth: 2 6 | :local: 7 | 8 | 9 | .. currentmodule:: cr.nimble 10 | 11 | Predicates 12 | ---------------------------------------------------------------------- 13 | 14 | .. autosummary:: 15 | :toctree: _autosummary 16 | 17 | is_matrix 18 | is_square 19 | is_symmetric 20 | is_hermitian 21 | is_positive_definite 22 | has_orthogonal_columns 23 | has_orthogonal_rows 24 | has_unitary_columns 25 | has_unitary_rows 26 | 27 | Matrix Multiplication 28 | ---------------------------- 29 | 30 | .. autosummary:: 31 | :toctree: _autosummary 32 | 33 | AH_v 34 | mat_transpose 35 | mat_hermitian 36 | diag_premultiply 37 | diag_postmultiply 38 | 39 | Matrix Parts 40 | ------------------------ 41 | 42 | .. autosummary:: 43 | :toctree: _autosummary 44 | 45 | off_diagonal_elements 46 | off_diagonal_min 47 | off_diagonal_max 48 | off_diagonal_mean 49 | block_diag 50 | mat_column_blocks 51 | 52 | 53 | Matrix Operations 54 | -------------------------- 55 | 56 | .. autosummary:: 57 | :toctree: _autosummary 58 | 59 | set_diagonal 60 | add_to_diagonal 61 | 62 | 63 | 64 | 65 | Row wise and column wise norms 66 | ----------------------------------- 67 | 68 | .. autosummary:: 69 | :toctree: _autosummary 70 | 71 | norms_l1_cw 72 | norms_l1_rw 73 | norms_l2_cw 74 | norms_l2_rw 75 | norms_linf_cw 76 | norms_linf_rw 77 | sqr_norms_l2_cw 78 | sqr_norms_l2_rw 79 | normalize_l1_cw 80 | normalize_l1_rw 81 | normalize_l2_cw 82 | normalize_l2_rw 83 | 84 | 85 | Pairwise Distances 86 | ------------------------- 87 | 88 | .. 
autosummary:: 89 | :toctree: _autosummary 90 | 91 | pairwise_sqr_l2_distances_rw 92 | pairwise_sqr_l2_distances_cw 93 | pairwise_l2_distances_rw 94 | pairwise_l2_distances_cw 95 | pdist_sqr_l2_rw 96 | pdist_sqr_l2_cw 97 | pdist_l2_rw 98 | pdist_l2_cw 99 | pairwise_l1_distances_rw 100 | pairwise_l1_distances_cw 101 | pdist_l1_rw 102 | pdist_l1_cw 103 | pairwise_linf_distances_rw 104 | pairwise_linf_distances_cw 105 | pdist_linf_rw 106 | pdist_linf_cw 107 | 108 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/standard_matrices.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import numpy as np 5 | import scipy 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | from jax import random 10 | from jax import jit 11 | 12 | from .norm import normalize_l2_cw 13 | from .util import promote_arg_dtypes 14 | from .array import hermitian 15 | 16 | def gaussian_mtx(key, N, D, normalize_atoms=True): 17 | """A dictionary/sensing matrix where entries are drawn independently from normal distribution. 18 | 19 | Args: 20 | key: a PRNG key used as the random key. 21 | N (int): Number of rows of the sensing matrix 22 | D (int): Number of columns of the sensing matrix 23 | normalize_atoms (bool): Whether the columns of sensing matrix are normalized 24 | (default True) 25 | 26 | Returns: 27 | (jax.numpy.ndarray): A Gaussian sensing matrix of shape (N, D) 28 | 29 | Example: 30 | 31 | >>> from jax import random 32 | >>> import cr.nimble as cnb 33 | >>> m, n = 8, 16 34 | >>> Phi = cnb.gaussian_mtx(random.PRNGKey(0), m, n) 35 | >>> print(Phi.shape) 36 | (8, 16) 37 | >>> print(cnb.norms_l2_cw(Phi)) 38 | [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] 
39 | """ 40 | shape = (N, D) 41 | dict = random.normal(key, shape) 42 | if normalize_atoms: 43 | dict = normalize_l2_cw(dict) 44 | else: 45 | sigma = math.sqrt(N) 46 | dict = dict / sigma 47 | return dict 48 | 49 | 50 | def _pascal_lower(n): 51 | A = jnp.empty((n, n), dtype=jnp.int32) 52 | A = A.at[0, :].set(0) 53 | A = A.at[:, 0].set(1) 54 | for i in range(1, n): 55 | for j in range(1, i+1): 56 | A = A.at[i, j].set(A[i-1, j] + A[i-1, j-1]) 57 | return A 58 | 59 | def _pascal_sym(n): 60 | A = jnp.empty((n, n), dtype=jnp.int32) 61 | A = A.at[0, :].set(1) 62 | A = A.at[:, 0].set(1) 63 | for i in range(1, n): 64 | for j in range(1, n): 65 | A = A.at[i, j].set(A[i-1, j] + A[i, j-1]) 66 | return A 67 | 68 | def pascal(n, symmetric=False): 69 | """Returns a pascal matrix of size n \times n 70 | """ 71 | if symmetric: 72 | return _pascal_sym(n) 73 | else: 74 | return _pascal_lower(n) 75 | 76 | pascal_jit = jax.jit(pascal, static_argnums=(0, 1)) 77 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/dsp/scaling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR.Sparse Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import jax 17 | import jax.numpy as jnp 18 | from jax import jit 19 | 20 | 21 | def scale_0_mean_1_var(data, axis=-1): 22 | """Normalizes a data vector (data - mu) / sigma 23 | 24 | Args: 25 | data (jax.numpy.ndarray): A data vector or array 26 | axis (int): For nd arrays, the axis along which the data normalization will be done 27 | 28 | Returns: 29 | (jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray): A tuple comprising of: 30 | * Normalized data vector/array 31 | * Mean value(s) 32 | * Standard deviation value(s) 33 | """ 34 | mu = jnp.mean(data, axis) 35 | data = data - mu 36 | variance = jnp.var(data, axis) 37 | sigma = jnp.sqrt(variance) 38 | data = data / sigma 39 | return data, mu, sigma 40 | 41 | scale_0_mean_1_var_jit = jit(scale_0_mean_1_var, static_argnums=(1,)) 42 | 43 | 44 | def scale_to_0_1(x): 45 | """Scales a signal to the range of 0 and 1 46 | 47 | Args: 48 | x (jax.numpy.ndarray): A signal to be scaled 49 | 50 | Returns: 51 | (jax.numpy.ndarray, float, float): A tuple comprising of: 52 | * Scaled signal 53 | * The amount of shift 54 | * The scale factor 55 | """ 56 | shift = jnp.min(x) 57 | x = x - shift 58 | scale = jnp.max(x) 59 | x = x / scale 60 | return x, shift, scale 61 | 62 | def descale_from_0_1(x, shift, scale): 63 | """Reverses the scaling of a signal from the range of 0 and 1 64 | 65 | Args: 66 | x (jax.numpy.ndarray): A signal to be scaled 67 | shift (float): The amount of shift 68 | scale (float): The scale factor 69 | 70 | Returns: 71 | jax.numpy.ndarray: Descaled signal 72 | """ 73 | x = x * scale 74 | x = x + shift 75 | return x 76 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/rq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the 
License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import jax.numpy as jnp 17 | from jax.numpy.linalg import norm 18 | from jax import jit 19 | 20 | @jit 21 | def factor_mgs(A): 22 | n, m = A.shape 23 | # n rows 24 | # m dimension 25 | if n > m: 26 | raise Exception("Number of rows is larger than dimension") 27 | Q = jnp.empty([n, m]) 28 | R = jnp.zeros([n, n]) 29 | for k in range(0, n): 30 | # fill the k-th diagonal entry in R 31 | atom = A[k] 32 | norm_a = norm(atom) 33 | R = R.at[k, k].set(norm_a) 34 | # Initialize the k-th vector in Q 35 | q = atom / norm_a 36 | Q = Q.at[k].set(q) 37 | # Compute the inner product of new q vector with each of the remaining rows in A 38 | products = A[k+1:n, :] @ q.T 39 | # Place in k-th column of R 40 | R = R.at[k+1:n, k].set(products) 41 | # Subtract the contribution of previous q vector from all remaining rows of A. 
42 | rr = R[k+1:n, k:k+1] 43 | update = -rr @ jnp.expand_dims(q, 0) 44 | A = A.at[k+1:n].add(update) 45 | return R, Q 46 | 47 | 48 | def update(R, Q, a, k): 49 | if k > 0: 50 | # make it a column vector 51 | b = jnp.expand_dims(a, 1) 52 | # Compute the projection of a on each of the previous rows in Q 53 | h = Q[:k, :] @ b 54 | # Store in the k-th row of R 55 | R = R.at[k, :k].set(jnp.squeeze(h)) 56 | # subtract the projections 57 | proj = h.T @ Q[:k, :] 58 | a = a - jnp.squeeze(proj) 59 | # compute norm 60 | a_norm = norm(a) 61 | # save it in the diagonal entry of R 62 | R = R.at[k,k].set(a_norm) 63 | # place the new normalized vector 64 | a = jnp.squeeze(a) 65 | a = a /a_norm 66 | Q = Q.at[k, :].set(a) 67 | return R, Q 68 | 69 | update = jit(update, static_argnums=(3,)) 70 | 71 | 72 | def solve(R, Q, x): 73 | pass -------------------------------------------------------------------------------- /src/cr/nimble/_src/discrete/number.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR.Sparse Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import math 16 | import jax.numpy as jnp 17 | from sympy.ntheory import factorint 18 | 19 | 20 | def next_pow_of_2(n): 21 | """ 22 | Returns the smallest integer greater than or equal to n which is a power of 2 23 | """ 24 | return 2**int(math.ceil(math.log2(n))) 25 | 26 | def is_integer(x): 27 | return jnp.mod(x, 1) == 0 28 | 29 | def is_positive_integer(x): 30 | return jnp.logical_and(x > 0, jnp.mod(x, 1) == 0) 31 | 32 | def is_negative_integer(x): 33 | return jnp.logical_and(x < 0, jnp.mod(x, 1) == 0) 34 | 35 | 36 | def is_odd(x): 37 | return jnp.mod(x, 2) == 1 38 | 39 | def is_even(x): 40 | return jnp.mod(x, 2) == 0 41 | 42 | def is_odd_natural(x): 43 | return jnp.logical_and(x > 0, jnp.mod(x, 2) == 1) 44 | 45 | 46 | def is_even_natural(x): 47 | return jnp.logical_and(x > 0, jnp.mod(x, 2) == 0) 48 | 49 | def is_power_of_2(x): 50 | return jnp.logical_not(jnp.bitwise_and(x, x - 1)) 51 | 52 | def is_perfect_square(x): 53 | return is_integer(jnp.sqrt(x)) 54 | 55 | 56 | def integer_factors_close_to_sqr_root(n): 57 | assert isinstance(n, int) 58 | a_max = math.floor(math.sqrt(n)) 59 | if n % a_max == 0: 60 | a = a_max 61 | b = n // a 62 | return a,b 63 | # get the prime factors 64 | factors_map = factorint(n) 65 | factors = factors_map.keys() 66 | candidates = {1} 67 | #print(a_max) 68 | for key in factors_map: 69 | for count in range(factors_map[key]): 70 | new_candidates = {key*c for c in candidates} 71 | candidates = candidates.union(new_candidates) 72 | # filter out larger candidates 73 | candidates = {c for c in candidates if c <= a_max} 74 | #print(candidates) 75 | # a is the last candidate 76 | candidates = list(candidates) 77 | candidates.sort() 78 | a = candidates[-1] 79 | b = n // a 80 | return a, b -------------------------------------------------------------------------------- /tests/svd/test_subspaces.py: -------------------------------------------------------------------------------- 1 | from svd_setup import * 2 | 3 | # ambient dimension 4 | 
n = 10 5 | # subspace dimension 6 | d = 4 7 | # number of subspaces 8 | k = 2 9 | 10 | bases = cnb.data.random_subspaces_jit(cnb.KEYS[0], n, d, k) 11 | 12 | 13 | def test_orth_complement(): 14 | A = bases[0] 15 | B = bases[1] 16 | C = subspaces.orth_complement(A, B) 17 | G = A.T @ C 18 | # all vectors in C must be orthogonal to vectors in A 19 | assert_allclose(G, 0, atol=atol) 20 | 21 | 22 | def test_principal_angles(): 23 | A = bases[0] 24 | B = bases[1] 25 | C = subspaces.orth_complement(A, B) 26 | angles = subspaces.principal_angles_cos_jit(A, C) 27 | # all vectors in C must be orthogonal to vectors in A 28 | assert_allclose(angles, 0, atol=atol) 29 | angles = subspaces.principal_angles_rad_jit(A, C) 30 | assert_allclose(angles, jnp.pi/2, atol=atol) 31 | angles = subspaces.principal_angles_deg_jit(A, C) 32 | assert_allclose(angles, 90, atol=atol) 33 | 34 | def test_smallest_principal_angle(): 35 | A = bases[0] 36 | B = bases[1] 37 | C = subspaces.orth_complement(A, B) 38 | angle = subspaces.smallest_principal_angle_cos_jit(A, C) 39 | assert_allclose(angle, 0, atol=atol) 40 | angle = subspaces.smallest_principal_angle_rad_jit(A, C) 41 | assert_allclose(angle, jnp.pi/2, atol=atol) 42 | angle = subspaces.smallest_principal_angle_deg_jit(A, C) 43 | assert_allclose(angle, 90, atol=atol) 44 | 45 | 46 | def test_smallest_principal_angles(): 47 | A = bases[0] 48 | lst = jnp.stack((A, A, A)) 49 | angles = subspaces.smallest_principal_angles_cos_jit(lst) 50 | o = jnp.ones((3,3)) 51 | z = jnp.zeros((3,3)) 52 | assert_allclose(o, angles) 53 | angles = subspaces.smallest_principal_angles_rad_jit(lst) 54 | # allow for 32-bit floating point errors 55 | assert_allclose(z, angles, atol=1e-1) 56 | angles = subspaces.smallest_principal_angles_deg_jit(lst) 57 | # allow for 32-bit floating point errors 58 | assert_allclose(z, angles, atol=1e-1) 59 | 60 | 61 | def test_project_to_subspace(): 62 | A = jnp.eye(6)[:, :3] 63 | v = jnp.arange(6) + 0. 
64 | u = subspaces.project_to_subspace(A, v) 65 | u0 = jnp.array([0., 1., 2., 0., 0., 0.]) 66 | assert_allclose(u, u0) 67 | 68 | def test_is_in_subspace(): 69 | A = jnp.eye(6)[:, :3] 70 | v = jnp.arange(6) + 0. 71 | u = subspaces.project_to_subspace(A, v) 72 | assert subspaces.is_in_subspace(A, u) 73 | assert not subspaces.is_in_subspace(A, v) 74 | 75 | def test_subspace_distance(): 76 | A = jnp.eye(6)[:, :3] 77 | d = subspaces.subspace_distance(A, A) 78 | assert_array_equal(d, 0) 79 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/svdpack/reorth.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax import lax, jit, vmap, random 16 | import jax.numpy as jnp 17 | from jax.numpy.linalg import norm 18 | 19 | from cr.nimble import promote_arg_dtypes 20 | 21 | def reorth_mgs(Q, r, r_norm, indices, alpha=0.5): 22 | """Reorthogonalizes r against subset of columns in Q indexed by indices 23 | 24 | If norm of r reduces significantly, then a second reorthogonalization is performed. 25 | If the norm of r reduces again significantly, then it is assumed that r is 26 | numerically in the column span of Q and a zero vector is returned. 
27 | """ 28 | n, k = Q.shape 29 | k2 = len(indices) 30 | assert k2 <= k 31 | Q, r, r_norm = promote_arg_dtypes(Q, r, r_norm) 32 | 33 | def orthogonalize_with_col(i, r): 34 | # pick the corresponding column 35 | q = Q[:, i] 36 | # compute the dot product 37 | t = jnp.dot(q, r) 38 | # subtract the projection from r 39 | r = r - t * q 40 | return r 41 | 42 | def for_body(i, r): 43 | # orthogonalize against a column only if it is selected 44 | return lax.cond(indices[i], 45 | lambda r : orthogonalize_with_col(i, r), 46 | lambda r : r, 47 | r 48 | ) 49 | # orthogonalize r against Q 50 | r = lax.fori_loop(0, k2, for_body, r) 51 | old_norm = r_norm 52 | r_norm = norm(r) 53 | 54 | def while_cond(state): 55 | r, r_norm, old_norm, iterations = state 56 | return r_norm < alpha * old_norm 57 | 58 | def while_body(state): 59 | r, r_norm, old_norm, iterations = state 60 | # orthogonalize r against Q 61 | r = lax.fori_loop(0, k2, for_body, r) 62 | old_norm = r_norm 63 | r_norm = norm(r) 64 | return r, r_norm, old_norm, iterations+1 65 | 66 | state = r, r_norm, old_norm, 1 67 | state = lax.while_loop(while_cond, while_body, state) 68 | r, r_norm, old_norm, iterations = state 69 | return r, r_norm, iterations 70 | 71 | reorth_mgs_jit = jit(reorth_mgs) 72 | 73 | def reorth_noop(r, r_norm): 74 | return r, r_norm, 0 75 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/spd/jacobi.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Solve a symmatric positive definite system :math:`A x = b` 17 | using Jacobi iterations 18 | """ 19 | 20 | from typing import NamedTuple 21 | 22 | from jax import lax, jit 23 | import jax.numpy as jnp 24 | 25 | 26 | class State(NamedTuple): 27 | x: jnp.ndarray 28 | """The solution""" 29 | e_norm_sqr: jnp.ndarray 30 | """The norm squared of the error between successive x approximations""" 31 | iterations: int 32 | """The number of iterations it took to complete""" 33 | 34 | def jacobi_solve(A, b, max_iters=None, res_norm_rtol=1e-4): 35 | """Solves the problem :math:`Ax = b` for a symmetric positive definite :math:`A` via conjugate gradients iterations 36 | """ 37 | # Boyd Conjugate Gradients slide 22 38 | m, n = A.shape 39 | 40 | # Get the diagonal elements 41 | D = jnp.diag(A) 42 | # Invert them 43 | Dinv = 1 / D 44 | # Set the diagonal elements to 0 to get E 45 | E = A.at[jnp.diag_indices(m)].set(0) 46 | # Compute B = D^{-1} E 47 | B = -jnp.multiply(Dinv[:, None], E) 48 | # Compute z D^{-1} b 49 | z = jnp.multiply(Dinv, b) 50 | 51 | b_norm_sqr = b.T @ b 52 | 53 | max_e_norm_sqr = b_norm_sqr * (res_norm_rtol ** 2) 54 | if max_iters is None: 55 | max_iters = 500 56 | 57 | def init(): 58 | x = z 59 | e_norm_sqr = x.T @ x 60 | return State(x=x, 61 | e_norm_sqr=e_norm_sqr, 62 | iterations=1) 63 | 64 | def iteration(state): 65 | # update the solution x 66 | x = B @ state.x + z 67 | # update the residual r 68 | r = x - state.x 69 | e_norm_sqr = r.T @ r 70 | # update state 71 | return State(x=x, 72 | 
e_norm_sqr=e_norm_sqr, 73 | iterations=state.iterations+1) 74 | 75 | def cond(state): 76 | # limit on residual norm 77 | a = state.e_norm_sqr > max_e_norm_sqr 78 | # limit on number of iterations 79 | b = state.iterations < max_iters 80 | c = jnp.logical_and(a, b) 81 | return c 82 | 83 | state = lax.while_loop(cond, iteration, init()) 84 | return state 85 | 86 | jacobi_solve_jit = jit(jacobi_solve, 87 | static_argnames=("max_iters", "res_norm_rtol")) 88 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/dsp/wht.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-Present CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ 16 | Fast Walsh Hadamard Transforms 17 | """ 18 | 19 | from jax import lax, jit 20 | import jax.numpy as jnp 21 | 22 | 23 | 24 | @jit 25 | def fwht(X): 26 | """Computes the Fast Walsh Hadamard Transform over columns 27 | 28 | Args: 29 | X (jax.numpy.ndarray): The 1D real signal or 2D matrix where each column is a signal whose transform is to be computed 30 | 31 | Returns: 32 | jax.numpy.ndarray: The Fast Walsh Hadamard Transform coefficients of (columns of) X 33 | """ 34 | n = X.shape[0] 35 | # number of stages 36 | s = (n-1).bit_length() 37 | 38 | def init1(): 39 | Y = jnp.empty(X.shape, dtype=X.dtype) 40 | A = X[0::2] 41 | B = X[1::2] 42 | Y = Y.at[0::2].set(A + B) 43 | Y = Y.at[1::2].set(A - B) 44 | return (Y, 1, 2, 4) 45 | 46 | def body1(state): 47 | # gap between x entries 48 | # number of x entries 49 | X, count, gap, step = state 50 | Y = jnp.empty(X.shape, dtype=X.dtype) 51 | J = 0 52 | k = 0 53 | def body2(state): 54 | Y, J, k = state 55 | def body3(state): 56 | Y, j, k = state 57 | # compute the four parts 58 | a = X[j] 59 | b = X[j+gap] 60 | c = X[j+1] 61 | d = X[j+1+gap] 62 | Y = Y.at[k].set(a+b) 63 | Y = Y.at[k+1].set(a-b) 64 | Y = Y.at[k+2].set(c-d) 65 | Y = Y.at[k+3].set(c+d) 66 | return (Y, j+2, k+4) 67 | def cond3(state): 68 | j = state[1] 69 | return j < J+gap-1 70 | # the loop 71 | init3 = (Y, J, k) 72 | Y, j, k = lax.while_loop(cond3, body3, init3) 73 | return (Y, J + step, k) 74 | 75 | def cond2(state): 76 | k = state[2] 77 | return k < n - 1 78 | 79 | init2 = Y, J, 0 80 | Y, J, k = lax.while_loop(cond2, body2, init2) 81 | 82 | return (Y, count+1, 2*gap, 2*step) 83 | 84 | def cond1(state): 85 | count = state[1] 86 | return count < s 87 | 88 | state = lax.while_loop(cond1, body1, init1()) 89 | return state[0] 90 | -------------------------------------------------------------------------------- /tests/core/test_vector.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup 
import * 2 | from cr.nimble import * 3 | 4 | def test_scalar(): 5 | x = jnp.array(2) 6 | assert is_scalar(x) 7 | 8 | def test_line_vec(): 9 | x = jnp.array([2,3]) 10 | assert is_line_vec(x) 11 | assert is_vec(x) 12 | 13 | 14 | def test_row_vec(): 15 | x = jnp.array([[2,3]]) 16 | assert is_row_vec(x) 17 | assert is_vec(x) 18 | 19 | def test_col_vec(): 20 | x = jnp.array([[2],[3]]) 21 | assert is_col_vec(x) 22 | assert is_vec(x) 23 | 24 | def test_to_row_vec(): 25 | x = jnp.array([2,3]) 26 | assert is_line_vec(x) 27 | x = to_row_vec(x) 28 | assert is_row_vec(x) 29 | 30 | def test_to_col_vec(): 31 | x = jnp.array([2,3]) 32 | assert is_line_vec(x) 33 | x = to_col_vec(x) 34 | assert is_col_vec(x) 35 | 36 | 37 | def test_vec_shift_right(): 38 | n = 10 39 | x = jnp.arange(n) 40 | y = vec_shift_right(x) 41 | assert_allclose(y[1:], x[:-1]) 42 | assert y[0] == 0 43 | 44 | def test_vec_rotate_right(): 45 | n = 10 46 | x = jnp.arange(n) 47 | y = vec_rotate_right(x) 48 | assert_allclose(y[1:], x[:-1]) 49 | assert y[0] == x[-1] 50 | 51 | def test_vec_shift_left(): 52 | n = 10 53 | x = jnp.arange(n) 54 | y = vec_shift_left(x) 55 | assert_allclose(y[:-1], x[1:]) 56 | assert y[-1] == 0 57 | 58 | def test_vec_rotate_left(): 59 | n = 10 60 | x = jnp.arange(n) 61 | y = vec_rotate_left(x) 62 | assert_allclose(y[:-1], x[1:]) 63 | assert y[-1] == x[0] 64 | 65 | def test_vec_shift_right_n(): 66 | n = 10 67 | x = jnp.arange(n) 68 | y = vec_shift_right_n(x, 1) 69 | assert_allclose(y[1:], x[:-1]) 70 | assert y[0] == 0 71 | 72 | def test_vec_rotate_right_n(): 73 | n = 10 74 | x = jnp.arange(n) 75 | y = vec_rotate_right_n(x, 1) 76 | assert_allclose(y[1:], x[:-1]) 77 | assert y[0] == x[-1] 78 | 79 | def test_vec_shift_left_n(): 80 | n = 10 81 | x = jnp.arange(n) 82 | y = vec_shift_left_n(x, 1) 83 | assert_allclose(y[:-1], x[1:]) 84 | assert y[-1] == 0 85 | 86 | def test_vec_rotate_left_n(): 87 | n = 10 88 | x = jnp.arange(n) 89 | y = vec_rotate_left_n(x, 1) 90 | assert_allclose(y[:-1], x[1:]) 
91 | assert y[-1] == x[0] 92 | 93 | 94 | def test_vec_repeat_at_end(): 95 | n = 10 96 | x = jnp.arange(n) 97 | p = 4 98 | y = vec_repeat_at_end(x, p) 99 | z = jnp.arange(p) 100 | assert_array_equal(y, jnp.concatenate((x, z))) 101 | 102 | def test_vec_repeat_at_start(): 103 | n = 10 104 | x = jnp.arange(n) 105 | p = 4 106 | y = vec_repeat_at_start(x, p) 107 | z = jnp.arange(n-p, n) 108 | assert_array_equal(y, jnp.concatenate((z, x))) 109 | 110 | def test_vec_centered(): 111 | n = 10 112 | x = jnp.arange(n) 113 | y = vec_centered(x, 8) 114 | assert_array_equal(y, x[1:-1]) 115 | 116 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Blocked stuff 132 | .vscode/ 133 | data/videos/*.mp4 134 | *.avi 135 | *.mp3 136 | *.webm 137 | *.ogg 138 | *.mp4 139 | 140 | examples/junk/* 141 | workbook.ipynb 142 | test_*.ipynb 143 | demo_*.ipynb 144 | workbook_*.ipynb 145 | 146 | demo_*.py 147 | _autosummary 148 | docs/source/_autosummary/* 149 | 150 | # Data files 151 | record_*.csv 152 | 153 | *.mat 154 | 155 | tmp/* 156 | junk/* 157 | 158 | docs/gallery/ 159 | 160 | # airspeed velocity benchmark results 161 | .asv/* 162 | paper/paper.pdf 163 | 164 | *.html 165 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/svdpack/lansvd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import NamedTuple 16 | import math 17 | 18 | from jax import lax, jit, vmap, random 19 | import jax.numpy as jnp 20 | from jax.numpy.linalg import norm 21 | 22 | 23 | import cr.nimble as cnb 24 | 25 | from .lanbpro import (LanBDOptions, LanBProState, 26 | lanbpro_options_init, lanbpro_init, lanbpro_iteration, lanbpro) 27 | 28 | from .bdsqr import bdsqr 29 | from .lansvd_utils import refine_bounds 30 | 31 | class LanSVDState(NamedTuple): 32 | """State for Lan SVD iterations 33 | """ 34 | bpro_state: LanBProState 35 | "State for the LanBPro algorithm" 36 | n_converged : int = 0 37 | "Number of converged eigen values" 38 | 39 | def lansvd_simple(A, k, p0): 40 | """Returns the k largest singular values and corresponding vectors 41 | """ 42 | m, n = A.shape 43 | # maximum number of U/V columns 44 | lanmax = min(m, n) 45 | lanmax = min (lanmax, max(20, k*4)) 46 | assert k <= lanmax, "k must be less than lanmax" 47 | 48 | eps = jnp.finfo(float).eps 49 | tol = 16*eps 50 | n_converged = 0 51 | # number of iterations for LanBPro algorithm 52 | j = min(k + max(8, k), lanmax) 53 | options = lanbpro_options_init(lanmax) 54 | state = lanbpro_init(A, lanmax, p0, options) 55 | # carry out the lanbpro iterations 56 | state = lax.fori_loop(1, j, 57 | lambda i, state: lanbpro_iteration(A, state, options), 58 | state) 59 | # norm of the residual 60 | res_norm = norm(state.p) 61 | # compute SVD of the bidiagonal matrix ritz values and vectors 62 | P, S, Qh, bot = bdsqr(state.alpha, state.beta, j) 63 | # estimate of A norm 64 | anorm = S[0] 65 | # simple error bounds on singular values 66 | bnd = res_norm * jnp.abs(bot) 67 | # now refine the bounds 68 | bnd = refine_bounds(S**2, bnd, n*eps*anorm) 69 | # count the number of converged singular values 70 | converged = jnp.less(bnd, jnp.abs(S)) 71 | # make sure that all indices beyond min(j,k) are marked as non-converged 72 | converged = converged.at[min(j,k):].set(False) 73 | # find the index of first non-converged 
singular value 74 | n_converged = jnp.argmin(converged) 75 | n_converged = jnp.where(converged[n_converged], len(converged), n_converged) 76 | U = state.U[:, :j] 77 | V = state.V[:, :j] 78 | # keep only the first k ritz vectors 79 | P = P[:, :k] 80 | Q = Qh.T[:, :k] 81 | U = state.U[:, :j] @ P[:j, :] 82 | V = state.V[:, :j] @ Q 83 | return U, S[:k], V, bnd, n_converged, state 84 | 85 | 86 | 87 | lansvd_simple_jit = jit(lansvd_simple, static_argnums=(1,)) 88 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/toeplitz.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from jax import jit 16 | import jax.numpy as jnp 17 | import jax.numpy.fft as jfft 18 | 19 | def toeplitz_mat(c, r): 20 | """Constructs a Toeplitz matrix 21 | 22 | 23 | """ 24 | c = jnp.asarray(c) 25 | r = jnp.asarray(r) 26 | m = len(c) 27 | n = len(r) 28 | # assert c[0] == r[0] 29 | w = jnp.concatenate((c[::-1], r[1:])) 30 | # backwards indices 31 | a = -jnp.arange(m, dtype=int) 32 | # print(a) 33 | # forwards indices 34 | b = jnp.arange(m-1,m+n-1, dtype=int) 35 | # print(b) 36 | # combine indices for the toeplitz matrix 37 | indices = a[:, None] + b[None, :] 38 | # print(indices) 39 | # form the toeplitz matrix 40 | mat = w[indices] 41 | return mat 42 | 43 | 44 | def toeplitz_mult(w, x): 45 | """Multiplies a Toeplitz matrix with a vector 46 | 47 | Note: 48 | Only real matrices and vectors are supported 49 | """ 50 | c, r = w 51 | m = len(c) 52 | n = len(r) 53 | p = m + n - 1 54 | if x.ndim == 1: 55 | x = x.reshape(-1, 1) 56 | ww = jnp.concatenate((c, r[-1:0:-1])) 57 | wf = jfft.rfft(ww).reshape(-1, 1) 58 | xf = jfft.rfft(x, n=p, axis=0) 59 | yf = wf * xf 60 | y = jfft.irfft(yf, n=p, axis=0) 61 | # drop extra values 62 | y = y[:m, :] 63 | # drop extra dimension if required 64 | return jnp.squeeze(y) 65 | 66 | def circulant_mat(c): 67 | """Constructs a circulant matrix 68 | """ 69 | # make sure that the array is flattened 70 | c = jnp.asarray(c).ravel() 71 | m = len(c) 72 | # extend c for the toeplitz structure 73 | cc = jnp.concatenate((c[::-1], c[:0:-1])) 74 | # backwards indices 75 | a = -jnp.arange(m, dtype=int) 76 | # forwards indices 77 | b = jnp.arange(m-1,m+m-1, dtype=int) 78 | # combine indices for the toeplitz matrix 79 | indices = a[:, None] + b[None, :] 80 | # form the circulant matrix 81 | mat = cc[indices] 82 | return mat 83 | 84 | 85 | 86 | def circulant_mult(c, x): 87 | """Multiplies a circulant matrix with a vector 88 | 89 | Note: 90 | Only real matrices and vectors are supported 91 | """ 92 | if x.ndim == 1: 93 | x = x.reshape(-1, 1) 94 
| # make sure that the array is flattened 95 | c = jnp.asarray(c).ravel() 96 | m = len(c) 97 | cf = jfft.rfft(c).reshape(-1, 1) 98 | xf = jfft.rfft(x, n=m, axis=0) 99 | yf = xf * cf 100 | y = jfft.irfft(yf, n=m, axis=0) 101 | # drop extra dimension if required 102 | return jnp.squeeze(y) 103 | 104 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/compression/fixed_length.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | """ 17 | Fixed Length Encoding of arrays 18 | """ 19 | 20 | from bitarray import bitarray 21 | from bitarray.util import int2ba, ba2int 22 | import numpy as np 23 | 24 | 25 | def encode_uint_arr_fl(input_arr, bits_per_sample: int): 26 | """Encodes an array of unsigned integers to a bit array using a fixed number of bits per sample 27 | """ 28 | a = bitarray() 29 | max_val = (1 << bits_per_sample) - 1 30 | # make sure that the values are clipped 31 | input_arr = np.clip(input_arr, 0, max_val) 32 | for value in input_arr: 33 | value = int(value) 34 | a.extend(int2ba(value, bits_per_sample)) 35 | return a 36 | 37 | def decode_uint_arr_fl(input_bit_arr : bitarray, bits_per_sample: int): 38 | """Decodes an array of unsigned integers from a bit array using a fixed number of bits per sample 39 | """ 40 | a = input_bit_arr 41 | # number of bits 42 | nbits = len(a) 43 | # number of samples 44 | n = nbits // bits_per_sample 45 | output = np.empty(n, dtype=np.int) 46 | idx = 0 47 | for i in range(n): 48 | # read the value 49 | value = ba2int(a[idx:idx+bits_per_sample]) 50 | idx += bits_per_sample 51 | output[i] = value 52 | return output 53 | 54 | 55 | def encode_int_arr_sgn_mag_fl(input_arr, bits_per_sample: int): 56 | """Encodes an array of integers to a bit array using a sign bit and a fixed number of bits per sample for magnitude 57 | """ 58 | a = bitarray() 59 | max_val = (1 << bits_per_sample) - 1 60 | for value in input_arr: 61 | value = int(value) 62 | sign, value = (0, value) if value >= 0 else (1, -value) 63 | # make sure that the values are clipped 64 | value = max_val if value > max_val else value 65 | a.append(sign) 66 | a.extend(int2ba(value, bits_per_sample)) 67 | return a 68 | 69 | 70 | def decode_int_arr_sgn_mag_fl(input_bit_arr : bitarray, bits_per_sample: int): 71 | """Decodes an array of integers from a bit array using a sign bit and a fixed number of bits per sample for magnitude 72 | """ 73 | a = input_bit_arr 74 | # number of bits 75 | 
nbits = len(a) 76 | # number of samples 77 | n = nbits // (bits_per_sample + 1) 78 | output = np.empty(n, dtype=np.int) 79 | idx = 0 80 | for i in range(n): 81 | # read the sign bit 82 | s = a[idx] 83 | idx += 1 84 | # read the value 85 | value = ba2int(a[idx:idx+bits_per_sample]) 86 | idx += bits_per_sample 87 | # combine sign and value 88 | value = -value if s else value 89 | output[i] = value 90 | return output 91 | 92 | -------------------------------------------------------------------------------- /tests/orthogonalization/test_householder.py: -------------------------------------------------------------------------------- 1 | from cr.nimble.test_setup import * 2 | 3 | householder_vec = jit(cnb.householder_vec) 4 | householder_matrix = jit(cnb.householder_matrix) 5 | householder_premultiply = jit(cnb.householder_premultiply) 6 | householder_postmultiply = jit(cnb.householder_postmultiply) 7 | householder_ffm_backward_accum = jit(cnb.householder_ffm_backward_accum) 8 | householder_qr = jit(cnb.householder_qr) 9 | householder_qr_packed = jit(cnb.householder_qr_packed) 10 | householder_split_qf_r = jit(cnb.householder_split_qf_r) 11 | householder_ffm_premultiply = jit(cnb.householder_ffm_premultiply) 12 | householder_ffm_to_wy = jit(cnb.householder_ffm_to_wy) 13 | A = jnp.array([[12.0,-51, 4], [6, 167, -68], [-4, 24, -41]]) 14 | 15 | def test_vec(): 16 | x = A[:, 0] 17 | v, beta = householder_vec(x) 18 | v_expected = jnp.array([ 1., -3., 2.]) 19 | assert jnp.allclose(v, v_expected) 20 | assert jnp.isclose(beta, 0.14285715) 21 | 22 | def test_vec_(): 23 | x = A[:, 0] 24 | v, beta = cnb.householder_vec_(x) 25 | v_expected = jnp.array([ 1., -3., 2.]) 26 | assert jnp.allclose(v, v_expected) 27 | assert jnp.isclose(beta, 0.14285715) 28 | 29 | def test_vec2(): 30 | x = jnp.array([1.0, 0, 0]) 31 | v, beta = cnb.householder_vec_(x) 32 | 33 | def test_vec3(): 34 | x = jnp.array([0, 1., 0]) 35 | v, beta = cnb.householder_vec_(x) 36 | 37 | def test_vec4(): 38 | x = 
jnp.array([-2, 0., 0]) 39 | v, beta = cnb.householder_vec_(x) 40 | 41 | def test_vec5(): 42 | x = jnp.array([1, 1., 0]) 43 | v, beta = cnb.householder_vec_(x) 44 | 45 | 46 | def test_vec6(): 47 | x = jnp.array([-1, 1., 0]) 48 | v, beta = cnb.householder_vec_(x) 49 | 50 | def test_vec7(): 51 | x = jnp.array([-1]) 52 | v, beta = cnb.householder_vec_(x) 53 | 54 | def test_vec8(): 55 | x = jnp.array([-1]) 56 | v, beta = householder_vec(x) 57 | 58 | def test_matrix(): 59 | x = A[:, 0] 60 | Q = jnp.array([[ 6., 3., -2.], 61 | [ 3., -2., 6.], 62 | [-2., 6., 3.]]) / 7 63 | P = householder_matrix(x) 64 | assert jnp.allclose(Q, P) 65 | 66 | 67 | def test_premultiply(): 68 | x = A[:, 0] 69 | v, beta = householder_vec(x) 70 | A2 = householder_premultiply(v, beta, A) 71 | expected = jnp.array([[ 14., 21., -14.], 72 | [ 0., -49., -14.], 73 | [ 0., 168., -77.]]) 74 | assert jnp.allclose(expected, A2, atol=1e-3) 75 | 76 | def test_postmultiply(): 77 | x = A[:, 0] 78 | v, beta = householder_vec(x) 79 | A2 = householder_postmultiply(v, beta, A.T) 80 | expected = jnp.array([[ 14., 21., -14.], 81 | [ 0., -49., -14.], 82 | [ 0., 168., -77.]]) 83 | assert jnp.allclose(expected.T, A2, atol=1e-3) 84 | 85 | 86 | def test_qr(): 87 | Q, R = householder_qr(A) 88 | assert jnp.allclose(A, Q @ R, atol=1e-3) 89 | 90 | 91 | def test_fftm_to_wy(): 92 | Q, R = householder_qr(A) 93 | A2 = householder_qr_packed(A) 94 | W, Y = householder_ffm_to_wy(A2) 95 | I = jnp.eye(A.shape[0]) 96 | Qb = I - W @ Y.T 97 | assert jnp.allclose(Q, Qb) 98 | 99 | def test_ffm_premultiply(): 100 | Q, R = householder_qr(A) 101 | A2 = householder_qr_packed(A) 102 | QF, R = householder_split_qf_r(A2) 103 | I = jnp.eye(A.shape[0]) 104 | C = householder_ffm_premultiply(QF, I) 105 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Guidelines for contributing code 2 | 3 | ## What do I need to 
know to help? 4 | 5 | If you are looking to help to with a code contribution our project is primarily written in Python 6 | and uses Numpy, JAX and related technologies. 7 | If you don't feel ready to make a code contribution yet, no problem! 8 | You can also check out the documentation issues 9 | or the design issues that we have. 10 | 11 | If you are interested in making a code contribution 12 | and would like to learn more about the technologies that we use, 13 | check out the list below. 14 | 15 | * [JAX reference documentation](https://jax.readthedocs.io/en/latest/) 16 | * [SciPy Lectures](https://scipy-lectures.org/) 17 | * [CR-Nimble reference documentation](https://carnotresearch.github.io/cr-nimble/) 18 | 19 | ## How do I make a contribution? 20 | 21 | Never made an open source contribution before? Wondering how contributions work in the in our project? Here's a quick rundown! 22 | 23 | * Find an issue that you are interested in addressing or a feature that you would like to add. 24 | 25 | * Fork the repository associated with the issue to your local GitHub organization. 26 | This means that you will have a copy of the repository under your-GitHub-username/cr-nimble. 27 | 28 | * Clone the repository to your local machine using 29 | `git clone https://github.com/your-user-name/cr-nimble.git`. 30 | 31 | * Create a new branch for your fix using `git checkout -b branch-name-here`. 32 | 33 | * Make the appropriate changes for the issue you are trying to address or the feature that you want to add. 34 | 35 | * Use `git add insert-paths-of-changed-files-here` to add the file contents of the changed files to the 36 | "snapshot" git uses to manage the state of the project, also known as the index. 37 | 38 | * Use `git commit -m "Insert a short message of the changes made here"` to store the contents of the index with a descriptive message. 39 | 40 | * Push the changes to the remote repository using `git push origin branch-name-here`. 
41 | 42 | * Submit a pull request to the upstream repository. 43 | 44 | * Title the pull request with a short description of the changes made and the issue or bug number associated with your change. 45 | For example, you can title an issue like so "Added more log outputting to resolve #4352". 46 | 47 | * In the description of the pull request, explain the changes that you made, any issues you think exist 48 | with the pull request you made, and any questions you have for the maintainer. 49 | It's OK if your pull request is not perfect (no pull request is), the reviewer will be able to help you 50 | fix any problems and improve it! 51 | 52 | * Wait for the pull request to be reviewed by a maintainer. 53 | 54 | * Make changes to the pull request if the reviewing maintainer recommends them. 55 | 56 | * Celebrate your success after your pull request is merged! 57 | 58 | ## Where can I go for help? 59 | 60 | If you need help, you can ask questions on our [Discussions](https://github.com/carnotresearch/cr-nimble/discussions) forum for this project. 61 | 62 | ## What does the Code of Conduct mean for me? 63 | 64 | Our Code of Conduct means that you are responsible for treating everyone 65 | on the project with respect and courtesy regardless of their identity. 66 | If you are the victim of any inappropriate behavior or 67 | comments as described in our Code of Conduct, 68 | we are here for you and will do the best to ensure that 69 | the abuser is reprimanded appropriately, per our code. 70 | -------------------------------------------------------------------------------- /src/cr/nimble/dsp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Signal Processing Utilities 17 | """ 18 | 19 | # pylint: disable=W0611 20 | 21 | 22 | from cr.nimble._src.dsp.util import ( 23 | sliding_windows_rw, 24 | sliding_windows_cw 25 | ) 26 | 27 | from cr.nimble._src.dsp.convolution import ( 28 | # convolution 29 | vec_convolve, 30 | vec_convolve_jit, 31 | ) 32 | 33 | # Energy 34 | from cr.nimble._src.dsp.energy import ( 35 | 36 | # energy of a signal 37 | energy, 38 | find_first_signal_with_energy_le_rw, 39 | find_first_signal_with_energy_le_cw, 40 | ) 41 | 42 | # Thresholding 43 | from cr.nimble._src.dsp.thresholding import ( 44 | 45 | hard_threshold, 46 | hard_threshold_sorted, 47 | hard_threshold_by, 48 | largest_indices_by, 49 | energy_threshold, 50 | ) 51 | 52 | # Scaling 53 | from cr.nimble._src.dsp.scaling import ( 54 | 55 | scale_to_0_1, 56 | descale_from_0_1, 57 | # statistical normalization of data 58 | scale_0_mean_1_var, 59 | scale_0_mean_1_var_jit, 60 | ) 61 | 62 | # Quantization 63 | from cr.nimble._src.dsp.quantization import ( 64 | quantize_1, 65 | inv_quantize_1, 66 | ) 67 | 68 | 69 | # Spectrum 70 | from cr.nimble._src.dsp.spectrum import ( 71 | norm_freq, 72 | frequency_spectrum, 73 | power_spectrum 74 | ) 75 | 76 | # Interpolation 77 | from cr.nimble._src.dsp.interpolation import ( 78 | # interpolate via fourier transform 79 | interpft, 80 | ) 81 | 82 | # Signal Features 83 | from cr.nimble._src.dsp.features import ( 84 | dynamic_range, 85 | nonzero_dynamic_range, 86 | ) 87 | 88 | # Sparse Signals 89 | from cr.nimble._src.dsp.sparse import ( 
90 | nonzero_values, 91 | nonzero_indices, 92 | support, 93 | largest_indices, 94 | sparse_approximation, 95 | build_signal_from_indices_and_values, 96 | ) 97 | 98 | 99 | # Sparse Signal Matrices 100 | from cr.nimble._src.dsp.sparse import ( 101 | randomize_rows, 102 | randomize_cols, 103 | # row wise 104 | take_along_rows, 105 | largest_indices_rw, 106 | sparse_approximation_rw, 107 | # column wise 108 | take_along_cols, 109 | largest_indices_cw, 110 | sparse_approximation_cw, 111 | ) 112 | 113 | # Signal Comparison 114 | from cr.nimble._src.signalcomparison import ( 115 | SignalsComparison, 116 | snrs_cw, 117 | snrs_rw, 118 | snr 119 | ) 120 | 121 | # Noise 122 | from cr.nimble._src.noise import ( 123 | awgn_at_snr_ms, 124 | awgn_at_snr_std, 125 | awgn_at_snr 126 | ) 127 | 128 | # Discrete Cosine Transform 129 | from cr.nimble._src.dsp.dct import ( 130 | dct, 131 | idct, 132 | orthonormal_dct, 133 | orthonormal_idct 134 | ) 135 | 136 | 137 | # Walsh Hadamard 138 | from cr.nimble._src.dsp.wht import ( 139 | fwht, 140 | ) 141 | 142 | from cr.nimble._src.dsp.synthetic_signals import ( 143 | time_values, 144 | ) 145 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at contact@carnotresearch.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/ndarray.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Utility functions for ND arrays 17 | """ 18 | 19 | from jax import jit 20 | import jax.numpy as jnp 21 | 22 | from cr.nimble import promote_arg_dtypes 23 | 24 | 25 | def arr_largest_index(x): 26 | """Returns the unraveled index of the largest entry (by magnitude) in an n-d array 27 | 28 | Args: 29 | x (jax.numpy.ndarray): An nd-array 30 | 31 | Returns: 32 | tuple: n-dim index of the largest entry in x 33 | """ 34 | x = jnp.asarray(x) 35 | return jnp.unravel_index(jnp.argmax(jnp.abs(x)), x.shape) 36 | 37 | def arr_l1norm(x): 38 | """Returns the l1-norm of an array by flattening it 39 | 40 | Args: 41 | x (jax.numpy.ndarray): An nd-array 42 | 43 | Returns: 44 | (float): l1 norm of x 45 | """ 46 | x = jnp.asarray(x) 47 | x = promote_arg_dtypes(x) 48 | return jnp.sum(jnp.abs(x)) 49 | 50 | 51 | def arr_l2norm(x): 52 | """Returns the l2-norm of an array by flattening it 53 | """ 54 | x = jnp.asarray(x) 55 | x = promote_arg_dtypes(x) 56 | return jnp.sqrt(jnp.abs(jnp.vdot(x, x))) 57 | 58 | def arr_l2norm_sqr(x): 59 | """Returns the squared l2-norm of an array by flattening it 60 | """ 61 | x = jnp.asarray(x) 62 | x = promote_arg_dtypes(x) 63 | return jnp.vdot(x, x) 64 | 65 | def arr_vdot(x, y): 66 | """Returns the inner product of two arrays by flattening it 67 | """ 68 | x = jnp.asarray(x) 69 | y = jnp.asarray(y) 70 | x, y = promote_arg_dtypes(x, y) 71 | return jnp.vdot(x, y) 72 | 73 | @jit 74 | def arr_rdot(x, y): 75 | """Returns the inner product Re(x^H, y) on two arrays by flattening them 76 | """ 77 | x = jnp.asarray(x) 78 | y = jnp.asarray(y) 79 | x = jnp.ravel(x) 80 | y = jnp.ravel(y) 81 | if jnp.isrealobj(x) and jnp.isrealobj(y): 82 | # we can fall back to real inner product 83 | return jnp.sum(x * y) 84 | if jnp.isrealobj(x) or jnp.isrealobj(y): 85 | # 86 | x = jnp.real(x) 87 | y = jnp.real(y) 88 | return jnp.sum(x * y) 89 | # both x 
and y are complex 90 | # compute x^H 91 | x = jnp.conjugate(x) 92 | # compute x^H y 93 | prod = jnp.sum(x * y) 94 | # take the real part 95 | return jnp.real(prod) 96 | 97 | @jit 98 | def arr_rnorm_sqr(x): 99 | """Returns the squared norm of x using the real inner product Re(x^H, x) 100 | """ 101 | return arr_rdot(x, x) 102 | 103 | @jit 104 | def arr_rnorm(x): 105 | """Returns the norm of x using the real inner product Re(x^H, x) 106 | """ 107 | return jnp.sqrt(arr_rdot(x, x)) 108 | 109 | @jit 110 | def arr2vec(x): 111 | """Converts an nd array to a vector 112 | """ 113 | x = jnp.asarray(x) 114 | return jnp.ravel(x) 115 | 116 | 117 | @jit 118 | def log_pos(x): 119 | """Computes log with the assumption that x values are positive. 120 | """ 121 | return jnp.log(jnp.maximum(x, jnp.finfo(float).eps)) 122 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/io/resource.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from pathlib import Path 17 | import re 18 | import urllib 19 | from urllib.parse import urlparse 20 | import requests 21 | from dataclasses import dataclass 22 | 23 | _INITIALIZED = False 24 | 25 | CACHE_DIR = '' 26 | 27 | 28 | def is_valid_url(url): 29 | regex = re.compile( 30 | r'^https?://' # http:// or https:// 31 | r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' # domain... 32 | r'localhost|' # localhost... 33 | r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip 34 | r'(?::\d+)?' # optional port 35 | r'(?:/?|[/?]\S+)$', re.IGNORECASE) 36 | return url is not None and regex.search(url) 37 | 38 | def _initialize(): 39 | global _INITIALIZED 40 | if _INITIALIZED: 41 | return 42 | # print("Initializing CR-VISION") 43 | home_dir = Path.home() 44 | cache_dir = home_dir / '.cr-suite' 45 | # Make sure that lib directory exists 46 | cache_dir.mkdir(parents=True, exist_ok=True) 47 | global CACHE_DIR 48 | CACHE_DIR = cache_dir 49 | _INITIALIZED = True 50 | 51 | _initialize() 52 | 53 | def ensure_resource(name, uri=None): 54 | # technically this line is not required. But it helps in unit test coverage 55 | _initialize() 56 | if is_valid_url(name): 57 | # it seems uri has been passed first 58 | name, uri = uri, name 59 | if name is None: 60 | if uri is None: 61 | return None 62 | # let's construct name from uri 63 | p = urlparse(uri) 64 | name = p.path.split('/')[-1] 65 | 66 | path = CACHE_DIR / name 67 | if path.is_file(): 68 | # It's already downloaded, nothing to do. 
69 | return path 70 | if uri is None: 71 | uri = get_uri(name) 72 | if uri is None: 73 | # We could not find the download URL 74 | return None 75 | r = requests.get(uri, stream=True) 76 | CHUNK_SIZE = 1024 77 | print(f"Downloading {name}") 78 | with path.open('wb') as o: 79 | for chunk in r.iter_content(chunk_size=CHUNK_SIZE): 80 | o.write(chunk) 81 | print("Download complete for {}".format(name)) 82 | return path 83 | 84 | 85 | def ensure_cr_suite_resource(path): 86 | uri = f'https://raw.githubusercontent.com/carnotresearch/cr-suite/main/data/{path}' 87 | name = path.replace('/', '_') 88 | return ensure_resource(name, uri) 89 | 90 | @dataclass 91 | class _Resource: 92 | name: str 93 | uri: str 94 | 95 | _KNOWN_RESOURCES = [ 96 | _Resource(name="haarcascade_frontalface_default.xml", 97 | uri="https://raw.githubusercontent.com/opencv/opencv/master/data/haarcascades/haarcascade_frontalface_default.xml"), 98 | _Resource(name="lbfmodel.yaml", 99 | uri="https://raw.githubusercontent.com/kurnianggoro/GSOC2017/master/data/lbfmodel.yaml") 100 | ] 101 | 102 | 103 | def get_uri(name): 104 | for res in _KNOWN_RESOURCES: 105 | if res.name == name: 106 | return res.uri 107 | return None -------------------------------------------------------------------------------- /src/cr/nimble/_src/dsp/dct.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022-Present CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Discrete Cosine Transforms 16 | 17 | Adapted from: 18 | 19 | * http://www-personal.umich.edu/~mejn/computational-physics/dcst.py 20 | * https://dsp.stackexchange.com/questions/2807/fast-cosine-transform-via-fft 21 | """ 22 | 23 | from jax import jit 24 | import jax.numpy as jnp 25 | import jax.numpy.fft as jfft 26 | 27 | def dct(y): 28 | """Computes the 1D Type-II DCT transform 29 | 30 | Args: 31 | y (jax.numpy.ndarray): The 1D real signal 32 | 33 | Returns: 34 | jax.numpy.ndarray: The Type-II Discrete Cosine Transform coefficients of y 35 | """ 36 | n = y.shape[0] 37 | y2 = jnp.concatenate( (y[:], y[::-1])) 38 | c = jfft.rfft(y2, axis=0)[:n] 39 | ks = jnp.arange(n) 40 | phi = jnp.exp(-1j*jnp.pi*ks/(2*n)) 41 | prod = (phi*c.T).T 42 | return jnp.real(prod) 43 | 44 | 45 | def idct(a): 46 | """Computes the 1D Type-II Inverse DCT transform 47 | 48 | Args: 49 | a (jax.numpy.ndarray): The Type-II DCT transform coefficients of a 1D real signal 50 | 51 | Returns: 52 | jax.numpy.ndarray: The 1D real signal y s.t. a = dct(y) 53 | """ 54 | n = a.shape[0] 55 | shape = (1,)+a.shape[1:] 56 | ks = jnp.arange(n) 57 | phi = jnp.exp(1j*jnp.pi*ks/(2*n)) 58 | upper = (phi*a.T).T 59 | lower = jnp.zeros(shape) 60 | c = jnp.concatenate((upper, lower)) 61 | return jfft.irfft(c, axis=0)[:n] 62 | 63 | 64 | def orthonormal_dct(y): 65 | """Computes the 1D Type-II DCT transform such that the transform is orthonormal 66 | 67 | Args: 68 | y (jax.numpy.ndarray): The 1D real signal 69 | 70 | Returns: 71 | jax.numpy.ndarray: The orthonormal Type-II Discrete Cosine Transform coefficients of y 72 | 73 | Orthonormality ensures that 74 | 75 | .. 
math:: 76 | 77 | \\langle a, a \\rangle = \\langle y, y \\rangle 78 | """ 79 | n = y.shape[0] 80 | factor = jnp.sqrt(1/(2*n)) 81 | ks = jnp.arange(n) 82 | phi = jnp.exp(-1j*jnp.pi*ks/(2*n)) 83 | # scaling to make the transform orthonormal 84 | phi = phi.at[0].set(phi[0]*1/jnp.sqrt(2)) 85 | phi = phi * factor 86 | 87 | y2 = jnp.concatenate( (y[:], y[::-1])) 88 | c = jfft.rfft(y2, axis=0)[:n] 89 | prod = jnp.real(phi*c.T).T 90 | # phi = phi*jnp.sqrt(2)/n 91 | return prod 92 | 93 | def orthonormal_idct(a): 94 | """Computes the 1D Type-II IDCT transform such that the transform is orthonormal 95 | 96 | Args: 97 | a (jax.numpy.ndarray): The orthonormal Type-II DCT transform coefficients of a 1D real signal 98 | 99 | Returns: 100 | jax.numpy.ndarray: The 1D real signal y s.t. a = orthonormal_dct(y) 101 | """ 102 | n = a.shape[0] 103 | factor = jnp.sqrt(2*n) 104 | ks = jnp.arange(n) 105 | 106 | phi = jnp.exp(1j*jnp.pi*ks/(2*n)) 107 | # scaling to make the transform orthonormal 108 | phi = phi*factor 109 | phi = phi.at[0].set(phi[0]*jnp.sqrt(2)) 110 | 111 | upper = (phi*a.T).T 112 | lower = jnp.zeros((1,)+a.shape[1:]) 113 | c = jnp.concatenate((upper, lower)) 114 | return jfft.irfft(c, axis=0)[:n] 115 | -------------------------------------------------------------------------------- /docs/source/la.rst: -------------------------------------------------------------------------------- 1 | Linear Algebra 2 | ============================ 3 | 4 | .. contents:: 5 | :depth: 2 6 | :local: 7 | 8 | .. currentmodule:: cr.nimble 9 | 10 | 11 | 12 | Linear Systems 13 | ------------------------ 14 | 15 | .. rubric:: Triangular Systems 16 | 17 | 18 | .. autosummary:: 19 | :toctree: _autosummary 20 | 21 | solve_Lx_b 22 | solve_LTx_b 23 | solve_Ux_b 24 | solve_UTx_b 25 | solve_spd_chol 26 | 27 | 28 | 29 | .. rubric:: Special Dense Linear Systems 30 | 31 | 32 | .. 
autosummary:: 33 | :toctree: _autosummary 34 | 35 | mult_with_submatrix 36 | solve_on_submatrix 37 | 38 | 39 | Singular Value Decomposition 40 | -------------------------------- 41 | 42 | .. rubric:: Fundamental Subspaces 43 | 44 | .. autosummary:: 45 | :toctree: _autosummary 46 | 47 | orth 48 | row_space 49 | null_space 50 | left_null_space 51 | effective_rank 52 | effective_rank_from_svd 53 | singular_values 54 | 55 | 56 | .. rubric:: SVD for Bidiagonal Matrices 57 | 58 | .. currentmodule:: cr.nimble.svd 59 | 60 | .. autosummary:: 61 | :toctree: _autosummary 62 | 63 | bdsqr 64 | bdsqr_jit 65 | 66 | 67 | .. rubric:: Truncated SVD 68 | 69 | 70 | .. autosummary:: 71 | :toctree: _autosummary 72 | 73 | lansvd_simple 74 | lansvd_simple_jit 75 | lanbpro_init 76 | lanbpro_iteration 77 | lanbpro_iteration_jit 78 | lanbpro 79 | lanbpro_jit 80 | 81 | 82 | 83 | Orthogonalization 84 | ------------------------ 85 | 86 | .. currentmodule:: cr.nimble 87 | 88 | .. rubric:: Householder Reflections 89 | 90 | .. autosummary:: 91 | :toctree: _autosummary 92 | 93 | householder_vec 94 | householder_matrix 95 | householder_premultiply 96 | householder_postmultiply 97 | householder_ffm_jth_v_beta 98 | householder_ffm_premultiply 99 | householder_ffm_backward_accum 100 | householder_ffm_to_wy 101 | householder_qr_packed 102 | householder_split_qf_r 103 | householder_qr 104 | 105 | 106 | 107 | Subspaces 108 | --------------------------- 109 | 110 | .. currentmodule:: cr.nimble.subspaces 111 | 112 | .. rubric:: Projection 113 | 114 | .. autosummary:: 115 | :toctree: _autosummary 116 | 117 | project_to_subspace 118 | is_in_subspace 119 | 120 | 121 | .. rubric:: Principal Angles 122 | 123 | .. 
autosummary:: 124 | :toctree: _autosummary 125 | 126 | principal_angles_cos 127 | principal_angles_rad 128 | principal_angles_deg 129 | smallest_principal_angle_cos 130 | smallest_principal_angle_rad 131 | smallest_principal_angle_deg 132 | smallest_principal_angles_cos 133 | smallest_principal_angles_rad 134 | smallest_principal_angles_deg 135 | subspace_distance 136 | 137 | 138 | Affine Spaces 139 | ------------------------------------ 140 | 141 | .. contents:: 142 | :depth: 2 143 | :local: 144 | 145 | .. currentmodule:: cr.nimble.affine 146 | 147 | 148 | .. rubric:: Homogeneous Coordinate System 149 | 150 | .. autosummary:: 151 | :toctree: _autosummary 152 | 153 | homogenize 154 | homogenize_vec 155 | homogenize_cols 156 | 157 | Standard Matrices 158 | ------------------------------------ 159 | 160 | .. currentmodule:: cr.nimble 161 | 162 | .. rubric:: Random matrices 163 | 164 | .. autosummary:: 165 | :toctree: _autosummary 166 | 167 | gaussian_mtx 168 | 169 | 170 | .. rubric:: Special Matrices 171 | 172 | 173 | .. autosummary:: 174 | :toctree: _autosummary 175 | 176 | pascal 177 | pascal_jit 178 | 179 | 180 | 181 | Toeplitz Matrices 182 | ------------------------- 183 | 184 | 185 | .. autosummary:: 186 | :toctree: _autosummary 187 | 188 | toeplitz_mat 189 | toeplitz_mult 190 | 191 | 192 | Circulant Matrices 193 | ------------------------- 194 | 195 | 196 | .. autosummary:: 197 | :toctree: _autosummary 198 | 199 | circulant_mat 200 | circulant_mult 201 | 202 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/compression/binary_arrs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Run Length Encoding of Binary Maps 17 | 18 | 19 | References: 20 | 21 | https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi 22 | """ 23 | 24 | from bitarray import bitarray 25 | from bitarray.util import int2ba, ba2int 26 | import numpy as np 27 | 28 | def count_binary_runs(input_arr): 29 | """Returns the runs of 0s and 1s in a binary map 30 | """ 31 | # make sure that it is a numpy array 32 | input_arr = np.asarray(input_arr) 33 | # the first bit 34 | b = input_arr[0] 35 | # the last bit 36 | e = input_arr[-1] 37 | # extend the array 38 | extended = np.hstack(([1-b], input_arr, [1 - e])) 39 | # locate the changes 40 | diffs = np.diff(extended) 41 | markers, = np.where(diffs) 42 | runs = np.diff(markers) 43 | return runs 44 | 45 | B00 = bitarray('00') 46 | B01 = bitarray('01') 47 | B10 = bitarray('10') 48 | B11 = bitarray('11') 49 | NUM_BITS_RUN_LEN = 4 50 | 51 | 52 | def encode_binary_arr(input_arr): 53 | """Encodes a binary array into a bit array via run length encoding 54 | """ 55 | # the first bit 56 | b = input_arr[0] 57 | # the runs 58 | runs = count_binary_runs(input_arr) 59 | # build the bit array 60 | a = bitarray() 61 | a.append(b) 62 | for run in runs: 63 | run = int(run) 64 | if run == 1: 65 | a.extend(B00) 66 | continue 67 | if run == 2: 68 | a.extend(B01) 69 | continue 70 | if run == 3: 71 | a.extend(B10) 72 | continue 73 | # run is 4 or more 74 | a.extend(B11) 75 | # now record number of bits for the run 76 | bl = 
run.bit_length() 77 | a.extend(int2ba(bl, NUM_BITS_RUN_LEN)) 78 | # now record the run 79 | a.extend(int2ba(run)) 80 | return a 81 | 82 | 83 | 84 | def decode_binary_arr(input_bit_arr : bitarray): 85 | """Decodes a binary array from a bit array via run length decoding 86 | """ 87 | a = input_bit_arr 88 | result = [] 89 | # The first bit 90 | b = a[0] 91 | idx = 1 92 | # number of bits in the encoded bit array 93 | n = len(a) 94 | while idx < n: 95 | # read the next 2 bits 96 | code = a[idx:idx+2] 97 | idx += 2 98 | code = ba2int(code) 99 | run = code + 1 100 | if code == 3: 101 | # we need to decode run from the stream 102 | bl = ba2int(a[idx:idx+NUM_BITS_RUN_LEN]) 103 | idx += NUM_BITS_RUN_LEN 104 | run = ba2int(a[idx:idx+bl]) 105 | idx += bl 106 | for i in range(run): 107 | result.append(b) 108 | b = 1 - b 109 | return np.array(result) 110 | 111 | 112 | def binary_compression_ratio(input_arr, output_arr, bits_per_sample=1): 113 | """Returns the compression ratio of binary array compression algorithm 114 | """ 115 | out_len = output_arr.nbytes * 8 116 | ratio = len(input_arr) * bits_per_sample / out_len 117 | return ratio 118 | 119 | def binary_space_saving_ratio(input_arr, output_arr, bits_per_sample=1): 120 | """Returns the space saving ratio of binary array compression algorithm 121 | """ 122 | out_len = output_arr.nbytes * 8 123 | return 1 - out_len / (len(input_arr) * bits_per_sample) 124 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change Log 2 | All notable changes to this project will be documented in this file. 3 | 4 | * This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 5 | * The format of this log is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 
6 | 7 | 8 | ## [Unreleased] 9 | 10 | [Documentation](https://cr-nimble.readthedocs.io/en/latest/) 11 | 12 | 13 | ## [0.3.2] - 2022-10-08 14 | 15 | [Documentation](https://cr-nimble.readthedocs.io/en/v0.3.2/) 16 | 17 | ### Added 18 | 19 | Vectors: 20 | 21 | - Circular buffers 22 | - Binary heap 23 | 24 | DSP 25 | 26 | - sliding_windows 27 | 28 | Synthetic Signals 29 | 30 | - picket_fence 31 | - heavi_sine 32 | - bumps 33 | - blocks 34 | - doppler 35 | - ramp 36 | - cusp 37 | - sing 38 | - hi_sine 39 | - lo_sine 40 | - lin_chirp 41 | - two_chirp 42 | - quad_chirp 43 | - mish_mash 44 | - werner_sorrows 45 | - leopold 46 | 47 | 48 | Miscellaneous 49 | 50 | - `cr.sparse.io` moved to `cr.nimble.io` 51 | 52 | 53 | ### Fixed 54 | 55 | - Handling of complex signals in `build_signal_from_indices_and_values` 56 | - Handling of 100 percentile in `num_largest_coeffs_for_energy_percent` 57 | 58 | ## [0.3.1] - 2022-09-10 59 | 60 | [Documentation](https://cr-nimble.readthedocs.io/en/v0.3.1/) 61 | 62 | ### Added 63 | 64 | Metrics 65 | 66 | - normalized_mse 67 | - percent_rms_diff 68 | - compression_ratio 69 | - cr_to_pss 70 | - pss_to_cr 71 | 72 | Noise 73 | 74 | - awgn_at_snr_std 75 | 76 | Matrices 77 | 78 | - mat_column_blocks 79 | - block_diag 80 | 81 | Vectors 82 | 83 | - has_equal_values_vec 84 | 85 | Special matrices 86 | 87 | - toeplitz_mat 88 | - toeplitz_mult 89 | - circulant_mat 90 | - circulant_mult 91 | 92 | 93 | Misc 94 | 95 | - to_tex_matrix 96 | 97 | 98 | 99 | ## [0.3.0] - 2022-08-27 100 | 101 | [Documentation](https://cr-nimble.readthedocs.io/en/v0.3.0/) 102 | 103 | ### Added 104 | 105 | Data Compression 106 | 107 | - Binary data encoding/decoding 108 | - Run length encoding/decoding 109 | - Fixed length encoding/decoding 110 | 111 | Digital Signal Processing 112 | 113 | - Scaling functions 114 | - Quantized 115 | - Energy fraction based thresholding 116 | 117 | Metrics 118 | 119 | - Percentage root mean square difference 120 | 121 | ### Removed 122 | 123 | ### 
Changed 124 | 125 | - Statistical normalization renamed with changes in return type 126 | - Digital signal processing related functions moved under 127 | `cr.nimble.dsp` 128 | 129 | 130 | ### Improved 131 | 132 | - Documentation improved 133 | - API organization improved 134 | 135 | 136 | ## [0.2.4] - 2022-08-17 137 | 138 | [Documentation](https://cr-nimble.readthedocs.io/en/v0.2.4/) 139 | 140 | 141 | ### Added 142 | 143 | - Digital signal processing utilities moved from cr-sparse to cr-nimble 144 | - Moved discrete number related functions from cr-sparse. 145 | - Some sparse vector and matrix processing functionality moved from cr-sparse. 146 | 147 | ### Removed 148 | 149 | - Unnecessary `__init__.py` files removed. 150 | 151 | ### Notes 152 | 153 | - JAX 0.3.14 compatibility 154 | - Aligning version numbering across sister projects. 155 | 156 | ## [0.1.1] - 2021-12-07 157 | 158 | ### Added 159 | 160 | - The distance, matrix, ndarray, and metrics modules were moved from `cr-sparse` to `cr-nimble` 161 | - Some more vector functions were moved from `cr-sparse` to `cr-nimble` 162 | 163 | ### Improved 164 | 165 | - All unit tests were moved to 64-bit floating point data. 166 | 167 | 168 | ## [0.1.0] - 2021-12-07 169 | 170 | Initial release by refactoring code from `cr-sparse`.
171 | 172 | 173 | [Unreleased]: https://github.com/carnotresearch/cr-nimble/compare/v0.3.2...HEAD 174 | [0.3.2]: https://github.com/carnotresearch/cr-nimble/compare/v0.3.1...v0.3.2 175 | [0.3.1]: https://github.com/carnotresearch/cr-nimble/compare/v0.3.0...v0.3.1 176 | [0.3.0]: https://github.com/carnotresearch/cr-nimble/compare/v0.2.4...v0.3.0 177 | [0.2.4]: https://github.com/carnotresearch/cr-nimble/compare/v0.1.1...v0.2.4 178 | [0.1.1]: https://github.com/carnotresearch/cr-nimble/compare/v0.1.0...v0.1.1 179 | [0.1.0]: https://github.com/carnotresearch/cr-nimble/releases/tag/v0.1.0 -------------------------------------------------------------------------------- /src/cr/nimble/_src/signalcomparison.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR.Sparse Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import jax 17 | import jax.numpy as jnp 18 | 19 | from cr.nimble import sqr_norms_l2_cw, sqr_norms_l2_rw, norms_l2_cw, sqr_norm_l2 20 | 21 | norm = jnp.linalg.norm 22 | 23 | 24 | class SignalsComparison: 25 | 26 | def __init__(self, references, estimates): 27 | if references.ndim == 1: 28 | references = jnp.expand_dims(references, 1) 29 | if estimates.ndim == 1: 30 | estimates = jnp.expand_dims(estimates, 1) 31 | self.references = references 32 | self.estimates = estimates 33 | self.differences = references - estimates 34 | 35 | @property 36 | def reference_norms(self): 37 | return norms_l2_cw(self.references) 38 | 39 | @property 40 | def estimate_norms(self): 41 | return norms_l2_cw(self.estimates) 42 | 43 | @property 44 | def difference_norms(self): 45 | return norms_l2_cw(self.differences) 46 | 47 | @property 48 | def reference_energies(self): 49 | return sqr_norms_l2_cw(self.references) 50 | 51 | @property 52 | def estimate_energies(self): 53 | return sqr_norms_l2_cw(self.estimates) 54 | 55 | @property 56 | def difference_energies(self): 57 | return sqr_norms_l2_cw(self.differences) 58 | 59 | @property 60 | def error_to_signal_norms(self): 61 | a_norms = self.reference_norms 62 | diff_norms = self.difference_norms 63 | return diff_norms / a_norms 64 | 65 | @property 66 | def signal_to_noise_ratios(self): 67 | a_energies = self.reference_energies 68 | diff_energies = self.difference_energies 69 | ratios = a_energies / diff_energies 70 | return 10*jnp.log10(ratios) 71 | 72 | @property 73 | def cum_reference_norm(self): 74 | return norm(self.references, 'fro') 75 | 76 | @property 77 | def cum_estimate_norm(self): 78 | return norm(self.estimates, 'fro') 79 | 80 | @property 81 | def cum_difference_norm(self): 82 | return norm(self.differences, 'fro') 83 | 84 | @property 85 | def cum_error_to_signal_norm(self): 86 | a_norm = self.cum_reference_norm 87 | err_norm = self.cum_difference_norm 88 | ratio = err_norm / a_norm 89 | return ratio 90 | 91 | 
@property 92 | def cum_signal_to_noise_ratio(self): 93 | a_norm = self.cum_reference_norm 94 | err_norm = self.cum_difference_norm 95 | ratio = a_norm / err_norm 96 | return 20 * jnp.log10(ratio) 97 | 98 | def summarize(self): 99 | # if self.references.ndim == 2: 100 | # 101 | # else: 102 | # n, s = self.references.shape, 1 103 | n, s = self.references.shape 104 | print(f'Dimensions: {n}') 105 | print(f'Signals: {s}') 106 | print(f'Combined reference norm: {self.cum_reference_norm:.3f}') 107 | print(f'Combined estimate norm: {self.cum_estimate_norm:.3f}') 108 | print(f'Combined difference norm: {self.cum_difference_norm:.3f}') 109 | print(f'Combined SNR: {self.cum_signal_to_noise_ratio:.2f} dB') 110 | 111 | 112 | 113 | def snrs_cw(A, B): 114 | ref_energies = sqr_norms_l2_cw(A) 115 | diff_energies = sqr_norms_l2_cw(A-B) 116 | ratios = ref_energies / diff_energies 117 | return 10*jnp.log10(ratios) 118 | 119 | def snrs_rw(A, B): 120 | ref_energies = sqr_norms_l2_rw(A) 121 | diff_energies = sqr_norms_l2_rw(A-B) 122 | ratios = ref_energies / diff_energies 123 | return 10*jnp.log10(ratios) 124 | 125 | def snr(signal, noise): 126 | s = sqr_norm_l2(signal) 127 | n = sqr_norm_l2(noise) 128 | ratio = s/n 129 | return 10*jnp.log10(ratio) 130 | 131 | -------------------------------------------------------------------------------- /docs/source/dsp.rst: -------------------------------------------------------------------------------- 1 | .. _api:dsp: 2 | 3 | Digital Signal Processing 4 | =============================== 5 | 6 | .. contents:: 7 | :depth: 2 8 | :local: 9 | 10 | The ``CR-Nimble`` library has some handy digital signal processing routines 11 | implemented in JAX. They are available as part of the ``cr.nimble.dsp`` 12 | package. 13 | 14 | 15 | .. currentmodule:: cr.nimble.dsp 16 | 17 | Signal Energy 18 | ------------------------------- 19 | 20 | .. 
autosummary:: 21 | :toctree: _autosummary 22 | 23 | energy 24 | 25 | Thresholding 26 | ------------------------------- 27 | 28 | .. autosummary:: 29 | :toctree: _autosummary 30 | 31 | hard_threshold 32 | hard_threshold_sorted 33 | hard_threshold_by 34 | largest_indices_by 35 | energy_threshold 36 | 37 | 38 | 39 | Scaling 40 | ------------------------------- 41 | 42 | .. autosummary:: 43 | :toctree: _autosummary 44 | 45 | scale_to_0_1 46 | scale_0_mean_1_var 47 | 48 | Quantization 49 | ------------------------------- 50 | 51 | .. autosummary:: 52 | :toctree: _autosummary 53 | 54 | quantize_1 55 | 56 | 57 | Spectrum Analysis 58 | ------------------------------- 59 | 60 | .. autosummary:: 61 | :toctree: _autosummary 62 | 63 | norm_freq 64 | frequency_spectrum 65 | power_spectrum 66 | 67 | Interpolation 68 | ------------------------------- 69 | 70 | .. autosummary:: 71 | :toctree: _autosummary 72 | 73 | interpft 74 | 75 | 76 | 77 | Windowing 78 | ------------------------------- 79 | 80 | .. autosummary:: 81 | :toctree: _autosummary 82 | 83 | sliding_windows_rw 84 | sliding_windows_cw 85 | 86 | 87 | 88 | Sparse Signals 89 | ------------------------------------ 90 | 91 | The following functions analyze, transform, or construct signals 92 | which are known to be sparse. 93 | 94 | .. autosummary:: 95 | :toctree: _autosummary 96 | 97 | nonzero_values 98 | nonzero_indices 99 | support 100 | largest_indices 101 | sparse_approximation 102 | build_signal_from_indices_and_values 103 | 104 | 105 | Matrices of Sparse Signals 106 | ------------------------------------ 107 | 108 | The following functions analyze, transform, or construct 109 | collections of sparse signals organized as matrices. 110 | 111 | .. autosummary:: 112 | :toctree: _autosummary 113 | 114 | randomize_rows 115 | randomize_cols 116 | 117 | .. rubric:: Sparse representation matrices (row-wise) 118 | 119 | ..
autosummary:: 120 | :toctree: _autosummary 121 | 122 | take_along_rows 123 | largest_indices_rw 124 | sparse_approximation_rw 125 | 126 | .. rubric:: Sparse representation matrices (column-wise) 127 | 128 | .. autosummary:: 129 | :toctree: _autosummary 130 | 131 | take_along_cols 132 | largest_indices_cw 133 | sparse_approximation_cw 134 | 135 | 136 | 137 | 138 | 139 | Artificial Noise 140 | ----------------------------------- 141 | 142 | 143 | .. autosummary:: 144 | :toctree: _autosummary 145 | 146 | awgn_at_snr_ms 147 | awgn_at_snr_std 148 | 149 | 150 | Synthetic Signals 151 | ----------------------- 152 | 153 | .. currentmodule:: cr.nimble.dsp.signals 154 | 155 | .. autosummary:: 156 | :nosignatures: 157 | :toctree: _autosummary 158 | 159 | pulse 160 | transient_sine_wave 161 | decaying_sine_wave 162 | chirp 163 | chirp_centered 164 | gaussian_pulse 165 | picket_fence 166 | heavi_sine 167 | bumps 168 | blocks 169 | doppler 170 | ramp 171 | cusp 172 | sing 173 | hi_sine 174 | lo_sine 175 | lin_chirp 176 | two_chirp 177 | quad_chirp 178 | mish_mash 179 | werner_sorrows 180 | leopold 181 | 182 | 183 | .. currentmodule:: cr.nimble.dsp 184 | 185 | 186 | 187 | Discrete Cosine Transform 188 | ------------------------------------- 189 | 190 | .. autosummary:: 191 | :nosignatures: 192 | :toctree: _autosummary 193 | 194 | dct 195 | idct 196 | orthonormal_dct 197 | orthonormal_idct 198 | 199 | .. currentmodule:: cr.nimble.dsp 200 | 201 | Fast Walsh Hadamard Transform 202 | ------------------------------ 203 | 204 | There is no separate inverse Fast Walsh Hadamard Transform, because the FWHT is its own 205 | inverse up to a normalization factor. 206 | In other words, ``x == fwht(fwht(x)) / n`` where n is the length of x. 207 | 208 | .. autosummary:: 209 | :nosignatures: 210 | :toctree: _autosummary 211 | 212 | fwht 213 | 214 | Utilities 215 | ----------------------- 216 | 217 | 218 | ..
autosummary:: 219 | :nosignatures: 220 | :toctree: _autosummary 221 | 222 | time_values 223 | 224 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import jax.numpy as jnp 17 | 18 | import jax 19 | from jax import lax, random 20 | from jax._src import dtypes 21 | 22 | from jax.lib import xla_bridge 23 | platform = xla_bridge.get_backend().platform 24 | 25 | 26 | def is_cpu(): 27 | """Returns True if the code is running on a CPU platform 28 | """ 29 | return platform == 'cpu' 30 | 31 | def is_gpu(): 32 | """Returns True if the code is running on a GPU platform 33 | """ 34 | return platform == 'gpu' 35 | 36 | def is_tpu(): 37 | """Returns True if the code is running on a TPU platform 38 | """ 39 | return platform == 'tpu' 40 | 41 | KEY0 = random.PRNGKey(0) 42 | KEYS = random.split(KEY0, 64) 43 | 44 | def promote_arg_dtypes(*args): 45 | """Promotes `args` to a common inexact type. 
46 | 47 | Args: 48 | *args: list of JAX ndarrays to be promoted to common inexact type 49 | 50 | Returns: 51 | The same list of arrays with their dtype promoted to a common inexact type 52 | 53 | Example: 54 | Promoting a single argument:: 55 | 56 | >>> cr.nimble.promote_arg_dtypes(jnp.arange(2)) 57 | DeviceArray([0., 1.], dtype=float32) 58 | >>> from jax.config import config 59 | >>> config.update("jax_enable_x64", True) 60 | >>> cr.nimble.promote_arg_dtypes(jnp.arange(2)) 61 | DeviceArray([0., 1.], dtype=float64) 62 | 63 | Promoting two arguments to common floating point type:: 64 | 65 | >>> a = jnp.arange(2) 66 | >>> b = jnp.arange(1.5, 3.5) 67 | >>> a, b = cr.nimble.promote_arg_dtypes(a, b) 68 | >>> print(a) 69 | >>> print(b) 70 | [0. 1.] 71 | [1.5 2.5] 72 | 73 | A mix of real and complex types:: 74 | 75 | >>> a = jnp.arange(2) + 0.j 76 | >>> b = jnp.arange(1.5, 3.5) 77 | >>> a, b = cr.nimble.promote_arg_dtypes(a, b) 78 | >>> print(a) 79 | >>> print(b) 80 | [0.+0.j 1.+0.j] 81 | [1.5+0.j 2.5+0.j] 82 | """ 83 | def _to_inexact_type(type): 84 | return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_ 85 | inexact_types = [_to_inexact_type(arg.dtype) for arg in args] 86 | dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types)) 87 | args = [lax.convert_element_type(arg, dtype) for arg in args] 88 | if len(args) == 1: 89 | return args[0] 90 | else: 91 | return args 92 | 93 | 94 | def canonicalize_dtype(dtype): 95 | """Wrapper function on dtypes.canonicalize_dtype with None handling 96 | """ 97 | if dtype is None: 98 | return dtype 99 | return dtypes.canonicalize_dtype(dtype) 100 | 101 | 102 | def promote_to_complex(arg): 103 | """Promotes an argument to complex type""" 104 | dtype = dtypes.result_type(arg, np.complex64) 105 | return lax.convert_element_type(arg, dtype) 106 | 107 | def promote_to_real(arg): 108 | """Promotes an argument to real type""" 109 | dtype = dtypes.result_type(arg, np.float32) 110 | return lax.convert_element_type(arg, 
dtype) 111 | 112 | 113 | # Integer types 114 | integer_types = ( 115 | jnp.uint8.dtype, 116 | jnp.uint16.dtype, 117 | jnp.uint32.dtype, 118 | jnp.uint64.dtype, 119 | jnp.int8.dtype, 120 | jnp.int16.dtype, 121 | jnp.int32.dtype, 122 | jnp.int64.dtype, 123 | ) 124 | 125 | # Ranges of values for integer types 126 | integer_ranges = {t: (jnp.iinfo(t).min, jnp.iinfo(t).max) 127 | for t in integer_types} 128 | 129 | # Ranges of values for floating point types 130 | dtype_ranges = { 131 | bool: (False, True), 132 | float: (-1, 1), 133 | jnp.bool_.dtype: (False, True), 134 | jnp.float_.dtype: (-1, 1), 135 | jnp.float16.dtype: (-1, 1), 136 | jnp.float32.dtype: (-1, 1), 137 | jnp.complex64.dtype: (-1, 1), 138 | jnp.complex128.dtype: (-1, 1), 139 | } 140 | 141 | dtype_ranges.update(integer_ranges) 142 | 143 | 144 | def nbytes_live_buffers(): 145 | """Returns the number of bytes consumed by the live buffers 146 | """ 147 | backend = jax.lib.xla_bridge.get_backend() 148 | nbytes = [buf.nbytes for buf in backend.live_buffers()] 149 | return np.sum(nbytes) 150 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/dsp/thresholding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR.Sparse Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import jax 16 | import jax.numpy as jnp 17 | from jax import jit 18 | 19 | def hard_threshold(x, K): 20 | """Returns the indices and corresponding values of largest K non-zero entries in a vector x 21 | 22 | Args: 23 | x (jax.numpy.ndarray): A sparse/compressible signal 24 | K (int): The number of largest entries to be kept in x 25 | 26 | Returns: 27 | (jax.numpy.ndarray, jax.numpy.ndarray): A tuple comprising of: 28 | * The indices of K largest entries in x 29 | * Corresponding entries in x 30 | 31 | See Also: 32 | :func:`hard_threshold_sorted` 33 | :func:`hard_threshold_by` 34 | """ 35 | indices = jnp.argsort(jnp.abs(x)) 36 | I = indices[:-K-1:-1] 37 | x_I = x[I] 38 | return I, x_I 39 | 40 | def hard_threshold_sorted(x, K): 41 | """Returns the sorted indices and corresponding values of largest K non-zero entries in a vector x 42 | 43 | Args: 44 | x (jax.numpy.ndarray): A sparse/compressible signal 45 | K (int): The number of largest entries to be kept in x 46 | 47 | Returns: 48 | (jax.numpy.ndarray, jax.numpy.ndarray): A tuple comprising of: 49 | * The indices of K largest entries in x sorted in ascending order 50 | * Corresponding entries in x 51 | 52 | See Also: 53 | :func:`hard_threshold` 54 | """ 55 | # Sort entries in x by their magnitude 56 | indices = jnp.argsort(jnp.abs(x)) 57 | # Pick the indices of K-largest (magnitude) entries in x (from behind) 58 | I = indices[:-K-1:-1] 59 | # Make sure that indices are sorted in ascending order 60 | I = jnp.sort(I) 61 | # Pick corresponding values 62 | x_I = x[I] 63 | return I, x_I 64 | 65 | def hard_threshold_by(x, t): 66 | """ 67 | Sets all entries in x to be zero which are less than t in magnitude 68 | 69 | Args: 70 | x (jax.numpy.ndarray): A sparse/compressible signal 71 | t (float): The threshold value 72 | 73 | Returns: 74 | (jax.numpy.ndarray): x modified such that all values below t are set to 0 75 | 76 | Note: 77 | This function doesn't change the length of x and can be JIT compiled 78 | 79 | See 
Also: 80 | :func:`hard_threshold` 81 | """ 82 | valid = jnp.abs(x) >= t 83 | return x * valid 84 | 85 | def largest_indices_by(x, t): 86 | """ 87 | Returns the locations of all entries in x which are larger than t in magnitude 88 | 89 | Args: 90 | x (jax.numpy.ndarray): A sparse/compressible signal 91 | t (float): The threshold value 92 | 93 | Returns: 94 | (jax.numpy.ndarray): An index vector of all entries in x which are above the threshold 95 | 96 | Note: 97 | This function cannot be JIT compiled as the length of output is data dependent 98 | 99 | See Also: 100 | :func:`hard_threshold_by` 101 | """ 102 | return jnp.where(jnp.abs(x) >= t)[0] 103 | 104 | 105 | def energy_threshold(signal, fraction): 106 | """ 107 | Keeps only as much coefficients in signal so as to capture a fraction of signal energy 108 | 109 | Args: 110 | x (jax.numpy.ndarray): A signal 111 | fraction (float): The fraction of energy to be preserved 112 | 113 | Returns: 114 | (jax.numpy.ndarray, jax.numpy.ndarray): A tuple comprising of: 115 | * Signal after thresholding 116 | * A binary mask of the indices to be kept 117 | 118 | Note: 119 | This function doesn't change the length of signal and can be JIT compiled 120 | 121 | See Also: 122 | :func:`hard_threshold` 123 | """ 124 | # signal length 125 | n = signal.size 126 | # compute energies 127 | energies = signal ** 2 128 | # sort in descending order 129 | idx = jnp.argsort(energies)[::-1] 130 | energies = energies[idx] 131 | # total energy 132 | s = jnp.sum(energies) * 1. 
133 | # normalize 134 | energies = energies / s 135 | # convert to a cmf 136 | cmf = jnp.cumsum(energies) 137 | # find the index 138 | index = jnp.argmax(cmf >= fraction) 139 | # build the mask 140 | idx2 = jnp.arange(n) 141 | mask = jnp.where(idx2 <= index, 1, 0) 142 | # reshuffle the mask 143 | mask = mask.at[idx].set(mask) 144 | signal = signal * mask 145 | return signal, mask 146 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/norm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ 16 | 17 | References 18 | 19 | - https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html 20 | - https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.norm.html 21 | 22 | """ 23 | 24 | import jax.numpy as jnp 25 | 26 | EPS = jnp.finfo(jnp.float32).eps 27 | 28 | norm = jnp.linalg.norm 29 | 30 | from .util import promote_arg_dtypes 31 | 32 | def norm_l1(x): 33 | """ 34 | Computes the l_1 norm of a vector 35 | """ 36 | return jnp.sum(jnp.abs(x)) 37 | 38 | def sqr_norm_l2(x): 39 | """ 40 | Computes the squared l_2 norm of a vector 41 | """ 42 | return x.T @ x 43 | 44 | def norm_l2(x): 45 | """ 46 | Computes the l_2 norm of a vector 47 | """ 48 | return jnp.sqrt(x.T @ x) 49 | 50 | def norm_linf(x): 51 | """ 52 | Computes the l_inf norm of a vector 53 | """ 54 | return jnp.max(jnp.abs(x)) 55 | 56 | def norms_l1_cw(X): 57 | """ 58 | Computes the l_1 norm of each column of a matrix 59 | """ 60 | return norm(X, ord=1, axis=0) 61 | 62 | def norms_l1_rw(X): 63 | """ 64 | Computes the l_1 norm of each row of a matrix 65 | """ 66 | return norm(X, ord=1, axis=1) 67 | 68 | def norms_l2_cw(X): 69 | """ 70 | Computes the l_2 norm of each column of a matrix 71 | """ 72 | return norm(X, ord=2, axis=0, keepdims=False) 73 | 74 | def norms_l2_rw(X): 75 | """ 76 | Computes the l_2 norm of each row of a matrix 77 | """ 78 | return norm(X, ord=2, axis=1, keepdims=False) 79 | 80 | 81 | def norms_linf_cw(X): 82 | """ 83 | Computes the l_inf norm of each column of a matrix 84 | """ 85 | return norm(X, ord=jnp.inf, axis=0) 86 | 87 | def norms_linf_rw(X): 88 | """ 89 | Computes the l_inf norm of each row of a matrix 90 | """ 91 | return norm(X, ord=jnp.inf, axis=1) 92 | 93 | 94 | def sqr_norms_l2_cw(X): 95 | """ 96 | Computes the squared l_2 norm of each column of a matrix 97 | """ 98 | return jnp.sum(X * X, axis=0) 99 | 100 | def sqr_norms_l2_rw(X): 101 | """ 102 | Computes the l_2 norm of each row of a matrix 103 | """ 104 | return jnp.sum(X * X, 
axis=1) 105 | 106 | 107 | ###################################### 108 | # Normalization of vectors 109 | ###################################### 110 | 111 | def normalize_l1(x): 112 | """Normalizes a vector by its l_1-norm 113 | """ 114 | x = promote_arg_dtypes(x) 115 | x2 = jnp.abs(x) 116 | s = jnp.sum(x) + EPS 117 | return jnp.divide(x, s) 118 | 119 | def normalize_l2(x): 120 | """Normalizes a vector by its l_2-norm 121 | """ 122 | x = promote_arg_dtypes(x) 123 | s = jnp.sqrt(jnp.sum(x ** 2)) + EPS 124 | return jnp.divide(x, s) 125 | 126 | def normalize_linf(x): 127 | """Normalizes a vector by its l_inf-norm 128 | """ 129 | x = promote_arg_dtypes(x) 130 | s = jnp.max(jnp.abs(x)) + EPS 131 | return jnp.divide(x, s) 132 | 133 | ###################################### 134 | # Normalization of rows and columns 135 | ###################################### 136 | 137 | 138 | def normalize_l1_cw(X): 139 | """ 140 | Normalize each column of X per l_1-norm 141 | """ 142 | X = promote_arg_dtypes(X) 143 | X2 = jnp.abs(X) 144 | sums = jnp.sum(X2, axis=0) + EPS 145 | return jnp.divide(X, sums) 146 | 147 | def normalize_l1_rw(X): 148 | """ 149 | Normalize each row of X per l_1-norm 150 | """ 151 | X = promote_arg_dtypes(X) 152 | X2 = jnp.abs(X) 153 | sums = jnp.sum(X2, axis=1) + EPS 154 | # row wise sum should be a column vector 155 | sums = jnp.expand_dims(sums, axis=-1) 156 | # now broadcasting works well 157 | return jnp.divide(X, sums) 158 | 159 | def normalize_l2_cw(X): 160 | """ 161 | Normalize each column of X per l_2-norm 162 | """ 163 | X = promote_arg_dtypes(X) 164 | X2 = jnp.square(X) 165 | sums = jnp.sum(X2, axis=0) 166 | sums = jnp.sqrt(sums) 167 | return jnp.divide(X, sums) 168 | 169 | def normalize_l2_rw(X): 170 | """ 171 | Normalize each row of X per l_2-norm 172 | """ 173 | X = promote_arg_dtypes(X) 174 | X2 = jnp.square(X) 175 | sums = jnp.sum(X2, axis=1) 176 | sums = jnp.sqrt(sums) 177 | # row wise sum should be a column vector 178 | sums = jnp.expand_dims(sums, 
axis=-1) 179 | # now broadcasting works well 180 | return jnp.divide(X, sums) 181 | 182 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """A setuptools based setup module. 2 | """ 3 | import io 4 | import re 5 | from glob import glob 6 | from os.path import basename 7 | from os.path import dirname 8 | from os.path import join 9 | from os.path import splitext 10 | 11 | # Always prefer setuptools over distutils 12 | from setuptools import setup, find_namespace_packages 13 | # To use a consistent encoding 14 | from codecs import open 15 | from os import path 16 | 17 | exec(open('src/cr/nimble/version.py').read()) 18 | here = path.abspath(path.dirname(__file__)) 19 | 20 | def read(*names, **kwargs): 21 | with io.open( 22 | join(dirname(__file__), *names), 23 | encoding=kwargs.get('encoding', 'utf8') 24 | ) as fh: 25 | return fh.read() 26 | 27 | # Get the long description from the README file 28 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 29 | long_description = f.read() 30 | 31 | def _parse_requirements(filename): 32 | with open(path.join(here, 'requirements', filename)) as f: 33 | return [ 34 | line.rstrip() 35 | for line in f 36 | if not (line.isspace() or line.startswith('#')) 37 | ] 38 | 39 | setup( 40 | name='cr-nimble', 41 | 42 | version=__version__, 43 | 44 | description='Iterative algorithms for numerical linear algebra with JAX', 45 | long_description=long_description, 46 | long_description_content_type="text/markdown", 47 | 48 | # The project's main homepage. 
49 | url='https://carnotresearch.github.io/cr-nimble', 50 | download_url=f"https://github.com/carnotresearch/cr-nimble/archive/v{__version__}.tar.gz", 51 | 52 | # Author details 53 | author='CR-Suite Development Team', 54 | author_email='contact@carnotresearch.com', 55 | 56 | # Choose your license 57 | license='Apache 2.0: http://www.apache.org/licenses/LICENSE-2.0', 58 | 59 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 60 | classifiers=[ 61 | # How mature is this project? Common values are 62 | # 3 - Alpha 63 | # 4 - Beta 64 | # 5 - Production/Stable 65 | 'Development Status :: 3 - Alpha', 66 | 67 | # Indicate who your project is intended for 68 | 'Intended Audience :: Developers', 69 | 'Topic :: Multimedia', 70 | 'Topic :: Multimedia :: Video', 71 | 'Topic :: Scientific/Engineering :: Image Processing', 72 | 'Topic :: Scientific/Engineering :: Image Recognition', 73 | # License 74 | 'License :: OSI Approved :: Apache Software License', 75 | # OS Support 76 | 'Operating System :: Unix', 77 | 'Operating System :: POSIX', 78 | # 'Operating System :: Microsoft :: Windows', 79 | # Specify the Python versions you support here. In particular, ensure 80 | # that you indicate whether you support Python 2, Python 3 or both. 81 | # 'Programming Language :: Python :: 3.6', 82 | 'Programming Language :: Python :: 3.7', 83 | 'Programming Language :: Python :: 3.8', 84 | 'Programming Language :: Python :: 3.9', 85 | 'Programming Language :: Python :: Implementation :: CPython', 86 | ], 87 | project_urls={ 88 | 'Issue Tracker': "https://github.com/carnotresearch/cr-nimble/issues" 89 | }, 90 | # What does your project relate to? 91 | keywords='Linear Algebra', 92 | 93 | # You can just specify the packages manually here if your project is 94 | # simple. Or you can use find_packages(). 
95 | packages=find_namespace_packages('src', include=['cr.*']), 96 | package_dir={'': 'src'}, 97 | py_modules=[splitext(basename(path))[0] for path in glob('src/*.py')], 98 | python_requires=">=3.7", 99 | # Alternatively, if you want to distribute just a my_module.py, uncomment 100 | # this: 101 | # py_modules=["my_module"], 102 | 103 | # List run-time dependencies here. These will be installed by pip when 104 | # your project is installed. For an analysis of "install_requires" vs pip's 105 | # requirements files see: 106 | # https://packaging.python.org/en/latest/requirements.html 107 | install_requires=_parse_requirements('requirements.txt'), 108 | tests_require=_parse_requirements('requirements-tests.txt'), 109 | # List additional groups of dependencies here (e.g. development 110 | # dependencies). You can install these using the following syntax, 111 | # for example: 112 | # $ pip install -e .[dev,test] 113 | extras_require={ 114 | 'dev': [ ], 115 | 'test': _parse_requirements('requirements-tests.txt'), 116 | }, 117 | include_package_data=True, 118 | zip_safe=False, 119 | 120 | # If there are data files included in your packages that need to be 121 | # installed, specify them here. If using Python 2.6 or less, then these 122 | # have to be included in MANIFEST.in as well. 123 | package_data={ 124 | }, 125 | 126 | # Although 'package_data' is the preferred approach, in some case you may 127 | # need to place data files outside of your packages. See: 128 | # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa 129 | # In this case, 'data_file' will be installed into '/my_data' 130 | data_files=[], 131 | 132 | # To provide executable scripts, use entry points in preference to the 133 | # "scripts" keyword. Entry points provide cross-platform support and allow 134 | # pip to create the appropriate form of executable for the target platform. 
135 | entry_points={ 136 | 'console_scripts': [ 137 | ], 138 | }, 139 | ) 140 | -------------------------------------------------------------------------------- /src/cr/nimble/_src/householder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 CR-Suite Development Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Some basic linear transformations 17 | """ 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | from cr.nimble import promote_arg_dtypes 23 | from cr.nimble import to_row_vec, to_col_vec 24 | 25 | 26 | def householder_vec(x): 27 | """Computes a Householder vector for :math:`x` 28 | 29 | GVL4: Algorithm 5.1.1 30 | """ 31 | x = promote_arg_dtypes(x) 32 | m = len(x) 33 | if m == 1: 34 | return jnp.array(0), jnp.array(0) 35 | x_1 = x[0] 36 | x_rest = x[1:] 37 | sigma = x_rest.T @ x_rest 38 | v = jnp.hstack((1, x_rest)) 39 | 40 | def non_zero_sigma(v): 41 | mu = jnp.sqrt(x_1*x_1 + sigma) 42 | v_1 = jax.lax.cond(x_1 >= 0, 43 | lambda _: x_1 - mu, 44 | lambda _: -sigma/(x_1 + mu), 45 | operand=None) 46 | v = v.at[0].set(v_1) 47 | beta = 2. 
* v_1 * v_1 / (sigma + v_1 * v_1) 48 | v = v / v_1 49 | return v, beta 50 | 51 | def zero_sigma(v): 52 | beta = jax.lax.cond(x_1 >= 0, lambda _: 0., lambda _: -2., operand=None) 53 | return v, beta 54 | 55 | v, beta = jax.lax.cond(sigma == 0, zero_sigma, non_zero_sigma, operand=v) 56 | return v , beta 57 | 58 | 59 | def householder_matrix(x): 60 | """Computes a Householder refection matrix for :math:`x` 61 | """ 62 | v, beta = householder_vec(x) 63 | return jnp.eye(len(x)) - beta * jnp.outer(v, v) 64 | 65 | 66 | def householder_premultiply(v, beta, A): 67 | """Pre-multiplies a Householder reflection defined by :math:`v, beta` to a matrix A, PA 68 | """ 69 | assert v.ndim == 1 70 | assert A.ndim == 2 71 | vt = to_row_vec(v) 72 | v = to_col_vec(v) 73 | return A - (beta * v) @(vt @ A) 74 | 75 | def householder_postmultiply(v, beta, A): 76 | """Post-multiplies a Householder reflection defined by :math:`v, beta` to a matrix A, AP 77 | """ 78 | assert v.ndim == 1 79 | assert A.ndim == 2 80 | vt = to_row_vec(v) 81 | v = to_col_vec(v) 82 | return A - (A @ v) @ (beta * vt) 83 | 84 | 85 | def householder_vec_(x): 86 | """Computes a Householder vector for :math:`x` 87 | """ 88 | x = promote_arg_dtypes(x) 89 | m = len(x) 90 | if m == 1: 91 | return jnp.array(0), jnp.array(0) 92 | x_1 = x[0] 93 | x_rest = x[1:] 94 | sigma = x_rest.T @ x_rest 95 | v = jnp.hstack((1, x_rest)) 96 | 97 | if sigma == 0: 98 | if x_1 >= 0: 99 | beta = 0 100 | else: 101 | beta = -2 102 | else: 103 | mu = jnp.sqrt(x_1*x_1 + sigma) 104 | if x_1 <= 0: 105 | v = v.at[0].set(x_1 - mu) 106 | else: 107 | v = v.at[0].set(-sigma/(x_1 + mu)) 108 | v_1 = v[0] 109 | beta = 2 * v_1 * v_1 / (sigma + v_1 * v_1) 110 | v = v / v_1 111 | return v , beta 112 | 113 | 114 | def householder_ffm_jth_v_beta(A,j): 115 | """GVL4 EQ 5.1.4 v, beta calculation 116 | """ 117 | v = A[j+1:, j] 118 | ms = v.T @ v 119 | beta = 2/(1 + ms) 120 | v = jnp.hstack((1, v)) 121 | return v, beta 122 | 123 | def householder_ffm_premultiply(A, 
C): 124 | """ 125 | Computes Q^T C where Q is stored in its factored form in A. 126 | 127 | Each column j, of A contains the essential part of the j-th 128 | Householder vector. 129 | 130 | GVL4 EQ 5.1.4 131 | """ 132 | m, n = A.shape 133 | for j in range(n): 134 | v, beta = householder_ffm_jth_v_beta(A, j) 135 | C2 = householder_premultiply(v, beta, C[j:,:]) 136 | C = C.at[j:, :].set(C2) 137 | return C 138 | 139 | def householder_ffm_backward_accum(A, k): 140 | """ 141 | Computes k columns of Q from the factored form representation of Q stored in A. 142 | 143 | GVL4 EQ 5.1.5 144 | """ 145 | m, n = A.shape 146 | Q = jnp.eye(m,k) 147 | for j in range(n-1, -1, -1): 148 | v, beta = householder_ffm_jth_v_beta(A, j) 149 | QQ = householder_premultiply(v, beta, Q[j:,j:]) 150 | Q = Q.at[j:, j:].set(QQ) 151 | return Q 152 | 153 | 154 | def householder_ffm_to_wy(A): 155 | """ 156 | Computes the WY representation of Q such that Q = I_m - W Y^T from the factored form representation 157 | 158 | GVL4 algorithm 5.1.2 159 | """ 160 | m, r = A.shape 161 | v, beta = householder_ffm_jth_v_beta(A, 0) 162 | v = to_col_vec(v) 163 | Y = v 164 | W = beta * v 165 | for j in range (1, r-1): 166 | v, beta = householder_ffm_jth_v_beta(A, j) 167 | v = to_col_vec(v) 168 | v2 = jnp.vstack((jnp.zeros(j), v)) 169 | z = beta * (v2 - (W @ Y[j:,:].T)@v) 170 | W = jnp.hstack((W, z)) 171 | Y = jnp.hstack((Y, v2)) 172 | return W, Y 173 | 174 | def householder_qr_packed(A): 175 | """Computes the QR = A factorization of A using Householder reflections. Returns packed factorization. 
176 | 177 | Algorithm 5.2.1 178 | """ 179 | A = promote_arg_dtypes(A) 180 | m, n = A.shape 181 | assert m >= n 182 | for j in range(n-1): 183 | x = A[j:, j] 184 | v, beta = householder_vec(x) 185 | A2 = householder_premultiply(v, beta, A[j:, j:]) 186 | A = A.at[j:, j:].set(A2) 187 | # place the essential part of the Householder vector 188 | A = A.at[j+1:,j].set(v[1:]) 189 | return A 190 | 191 | def householder_split_qf_r(A): 192 | """Splits a packed QR factorization into QF and R 193 | """ 194 | # The upper triangular part is R 195 | R = jnp.triu(A) 196 | # The remaining lower triangular part of A is the factored form representation of Q 197 | QF = jnp.tril(A[:,:-1], -1) 198 | return QF, R 199 | 200 | def householder_qr(A): 201 | """Computes the QR = A factorization of A using Householder reflections 202 | 203 | Algorithm 5.2.1 204 | """ 205 | m, n = A.shape 206 | A = householder_qr_packed(A) 207 | QF , R = householder_split_qf_r(A) 208 | Q = householder_ffm_backward_accum(QF, n) 209 | return Q, R 210 | 211 | 212 | -------------------------------------------------------------------------------- /docs/_static/js/mathconf.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | macros: { 4 | AA: '{\\mathbb{A}}', 5 | BB: '{\\mathbb{B}}', // Complex space symbol 6 | CC: '{\\mathbb{C}}', // A dictionary 7 | DD: '{\\mathbb{D}}', // Expectation operator 8 | EE: '{\\mathbb{E}}', // A field 9 | FF: '{\\mathbb{F}}', // A group 10 | GG: '{\\mathbb{G}}', // A Hilbert space 11 | HH: '{\\mathbb{H}}', // Irrational numbers 12 | II: '{\\mathbb{I}}', 13 | JJ: '{\\mathbb{J}}', // Real or complex space symbol 14 | KK: '{\\mathbb{K}}', // Natural numbers 15 | NN: '{\\mathbb{N}}', 16 | Nat: '{\\mathbb{N}}', // Probability set symbol 17 | PP: '{\\mathbb{P}}', // Rational numbers 18 | QQ: '{\\mathbb{Q}}', // Real line symbol 19 | RR: '{\\mathbb{R}}', 20 | RRMN: '{\\mathbb{R}^{M \\times N} }', // A linear operator 21 | TT: 
'{\\mathbb{T}}', // Another linear operator 22 | UU: '{\\mathbb{U}}', // A vector space 23 | VV: '{\\mathbb{V}}', // A subspace 24 | WW: '{\\mathbb{W}}', // An inner product space 25 | XX: '{\\mathbb{X}}', // Integers 26 | ZZ: '{\\mathbb{Z}}', // All mathcal shortcuts 27 | AAA: '{\\mathcal{A}}', 28 | BBB: '{\\mathcal{B}}', 29 | CCC: '{\\mathcal{C}}', 30 | DDD: '{\\mathcal{D}}', 31 | EEE: '{\\mathcal{E}}', 32 | FFF: '{\\mathcal{F}}', 33 | GGG: '{\\mathcal{G}}', 34 | HHH: '{\\mathcal{H}}', 35 | III: '{\\mathcal{I}}', 36 | JJJ: '{\\mathcal{J}}', 37 | KKK: '{\\mathcal{K}}', 38 | LLL: '{\\mathcal{L}}', 39 | MMM: '{\\mathcal{M}}', 40 | NNN: '{\\mathcal{N}}', 41 | OOO: '{\\mathcal{O}}', 42 | PPP: '{\\mathcal{P}}', 43 | QQQ: '{\\mathcal{Q}}', 44 | RRR: '{\\mathcal{R}}', 45 | SSS: '{\\mathcal{S}}', 46 | TTT: '{\\mathcal{T}}', 47 | UUU: '{\\mathcal{U}}', 48 | VVV: '{\\mathcal{V}}', 49 | WWW: '{\\mathcal{W}}', 50 | XXX: '{\\mathcal{X}}', 51 | YYY: '{\\mathcal{Y}}', 52 | ZZZ: '{\\mathcal{Z}}', 53 | Tau: '{\\mathcal{T}}', 54 | Chi: '{\\mathcal{X}}', 55 | Eta: '{\\mathcal{H}}', // Real part of a complex number 56 | Re: '\\operatorname{Re}', 57 | Im: '\\operatorname{Im}', // Null space 58 | NullSpace: '{\\mathcal{N}}', // Column space 59 | ColSpace: '{\\mathcal{C}}', // Row space 60 | RowSpace: '{\\mathcal{R}}', // Power set 61 | Power: '{\\mathop{\\mathcal{P}}}', 62 | LinTSpace: '{\\mathcal{L}}', // Range 63 | Range: '{\\mathrm{R}}', // image 64 | Image: '{\\mathrm{im}}', // Kernel 65 | Kernel: '{\\mathrm{ker}}', // Span 66 | Span: '{\\mathrm{span}}', // Nullity of an operator 67 | Nullity: '{\\mathrm{nullity}}', // Dimension of a vector space 68 | Dim: '{\\mathrm{dim}}', // Rank of a matrix 69 | Rank: '{\\mathrm{rank}}', // Trace of a matrix 70 | Trace: '{\\mathrm{tr}}', // Diagonal of a matrix 71 | Diag: '{\\mathrm{diag}}', // Signum function 72 | sgn: '{\\mathrm{sgn}}', // Support function 73 | supp: '{\\mathrm{supp}}', // Row support 74 | rowsupp: 
'{\\mathop{\\mathrm{rowsupp}}}', // Entry wise absolute value function 75 | abs: '{\\mathop{\\mathrm{abs}}}', // error function 76 | erf: '{\\mathop{\\mathrm{erf}}}', // complementary error function 77 | erfc: '{\\mathop{\\mathrm{erfc}}}', // Sub Gaussian function 78 | Sub: '{\\mathop{\\mathrm{Sub}}}', // Strictly sub Gaussian function 79 | SSub: '{\\mathop{\\mathrm{SSub}}}', // Variance function 80 | Var: '{\\mathop{\\mathrm{Var}}}', // Covariance matrix 81 | Cov: '{\\mathop{\\mathrm{Cov}}}', // Affine hull of a set 82 | AffineHull: '{\\mathop{\\mathrm{aff}}}', // Convex hull of a set 83 | ConvexHull: '{\\mathop{\\mathrm{conv}}}', // Set theory related stuff 84 | Card: ['\\mathrm{card}\\,{#1}', 1], 85 | argmin: '\\mathrm{arg}\\,\\mathrm{min}', 86 | argmax: '\\mathrm{arg}\\,\\mathrm{max}', 87 | EmptySet: '\\varnothing', // Forall operator with some space 88 | Forall: '\\; \\forall \\;', // Topology related stuff 89 | Interior: ['\\mathring{#1}', 1], 90 | Closure: ['\\overline{#1}', 1], // Probability distributions 91 | Gaussian: '{\\mathcal{N}}', // Sparse representations related stuff 92 | spark: '{\\mathop{\\mathrm{spark}}}', // Exact Recovery Criterion 93 | ERC: '{\\mathop{\\mathrm{ERC}}}', // Maximum correlation 94 | Maxcor: '{\\mathop{\\mathrm{maxcor}}}', // pseudo-inverse 95 | dag: '\\dagger', // bracket operator 96 | Bracket: '\\left [ \\; \\right ]', 97 | bold: ['{\\bf #1}', 1], // OneVec 98 | OneVec: '\\mathbb{1}', 99 | ZeroVec: '0', 100 | OneMat: '\\mathbf{1}', 101 | bigO: ['\\mathop{}\\mathopen{}\\mathcal{O}\\mathopen{}\\left(#1\\right)', 1], 102 | smallO: ['\\scriptstyle\\mathcal{O}\\left(#1\\right)', 1] 103 | } 104 | } 105 | }; 106 | // A $( document ).ready() block. 107 | $(document).ready(function () { 108 | 109 | var on_proof_caption_click = function () { 110 | 111 | $header = $(this); 112 | //getting the next element 113 | $content = $header.next(); 114 | //open up the content needed - toggle the slide- if visible, slide up, if not slidedown. 
115 | $content.slideToggle(500, function () { 116 | //execute this after slideToggle is done 117 | //change text of header based on visibility of content div 118 | $header.text(function () { 119 | //change text based on condition 120 | return $content.is(":visible") ? "Proof" : "Click to see proof"; 121 | }); 122 | }); 123 | 124 | } 125 | // Attach the on click handler to each proof element 126 | $(".proof_caption").click(on_proof_caption_click); 127 | 128 | // MathJax.Hub.Queue(function () { 129 | // // Collapse all proof elements in the beginning 130 | // $(".proof_caption").each(on_proof_caption_click); 131 | // // We want to disable the proofs initially 132 | // $(".proof_caption").each(function () { 133 | // $header = $(this); 134 | // $header.text("Click to see proof"); 135 | // }); 136 | // }); 137 | 138 | }); 139 | 140 | -------------------------------------------------------------------------------- /tests/signal/test_signal.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | from cr.nimble import * 5 | from cr.nimble.dsp import * 6 | import jax.numpy as jnp 7 | from jax import random 8 | 9 | from numpy.testing import (assert_almost_equal, assert_allclose, assert_, 10 | assert_equal, assert_raises, assert_raises_regex, 11 | assert_array_equal, assert_warns) 12 | 13 | 14 | def test_find_first_signal_with_energy_le_rw(): 15 | X = jnp.eye(10) 16 | X = X.at[5,5].set(.5) 17 | index = find_first_signal_with_energy_le_rw(X, 0.3) 18 | assert index == 5 19 | index = find_first_signal_with_energy_le_rw(X, 0.2) 20 | assert index == -1 21 | 22 | def test_find_first_signal_with_energy_le_cw(): 23 | X = jnp.eye(10) 24 | X = X.at[5,5].set(.5) 25 | index = find_first_signal_with_energy_le_cw(X, 0.3) 26 | assert index == 5 27 | index = find_first_signal_with_energy_le_cw(X, 0.2) 28 | assert index == -1 29 | 30 | def test_randomize_rows(): 31 | key = random.PRNGKey(0) 32 | X = jnp.ones((4, 4)) 33 | Y = 
randomize_rows(key, X) 34 | assert jnp.allclose(X, Y) 35 | 36 | def test_randomize_cols(): 37 | key = random.PRNGKey(0) 38 | X = jnp.ones((4, 4)) 39 | Y = randomize_cols(key, X) 40 | assert jnp.allclose(X, Y) 41 | 42 | 43 | def test_largest_indices(): 44 | x = jnp.array([5, 1, 3, 4, 2]) 45 | indices = largest_indices(x, 2) 46 | assert jnp.array_equal(indices, jnp.array([0, 3])) 47 | 48 | def test_largest_indices_rw(): 49 | x = jnp.array([[5, 1, 3, 4, 2]]) 50 | indices = largest_indices_rw(x, 2) 51 | assert jnp.array_equal(indices, jnp.array([[0, 3]])) 52 | 53 | def test_largest_indices_cw(): 54 | x = jnp.array([[5, 1, 3, 4, 2]]) 55 | indices = largest_indices_cw(x.T, 2) 56 | expected = jnp.array([[0, 3]]).T 57 | assert jnp.array_equal(indices, expected) 58 | 59 | def test_take_along_rows(): 60 | x = jnp.array([[5, 1, 3, 4, 2]]) 61 | indices = largest_indices_rw(x, 2) 62 | y = take_along_rows(x, indices) 63 | assert jnp.array_equal(y, jnp.array([[5, 4]])) 64 | 65 | def test_take_along_cols(): 66 | x = jnp.array([[5, 1, 3, 4, 2]]).T 67 | indices = largest_indices_cw(x, 2) 68 | y = take_along_cols(x, indices) 69 | expected = jnp.array([[5, 4]]).T 70 | assert jnp.array_equal(y, expected) 71 | 72 | 73 | def test_sparse_approximation(): 74 | x = jnp.array([5, 1, 3, 4, 2]) 75 | y = sparse_approximation(x, 2) 76 | expected = jnp.array([5, 0, 0, 4, 0]) 77 | assert jnp.array_equal(y, expected) 78 | 79 | def test_sparse_approximation_0(): 80 | x = jnp.array([5, 1, 3, 4, 2]) 81 | y = sparse_approximation(x, 0) 82 | expected = jnp.array([0, 0, 0, 0, 0]) 83 | assert jnp.array_equal(y, expected) 84 | 85 | def test_sparse_approximation_cw(): 86 | x = jnp.array([[5, 1, 3, 4, 2]]).T 87 | y = sparse_approximation_cw(x, 2) 88 | expected = jnp.array([[5, 0, 0, 4, 0]]).T 89 | assert jnp.array_equal(y, expected) 90 | y = sparse_approximation_cw(x, 0) 91 | expected = jnp.array([[0, 0, 0, 0, 0]]).T 92 | assert jnp.array_equal(y, expected) 93 | 94 | 95 | def test_sparse_approximation_rw(): 
96 | x = jnp.array([[5, 1, 3, 4, 2]]) 97 | y = sparse_approximation_rw(x, 2) 98 | expected = jnp.array([[5, 0, 0, 4, 0]]) 99 | assert jnp.array_equal(y, expected) 100 | y = sparse_approximation_rw(x, 0) 101 | expected = jnp.array([[0, 0, 0, 0, 0]]) 102 | assert jnp.array_equal(y, expected) 103 | 104 | def test_build_signal_from_indices_and_values(): 105 | n = 4 106 | indices = jnp.array([1, 3]) 107 | values = jnp.array([9, 15]) 108 | x = build_signal_from_indices_and_values(n, indices, values) 109 | expected = jnp.array([0, 9, 0, 15]) 110 | assert jnp.array_equal(x, expected) 111 | 112 | def test_nonzero_values(): 113 | x = jnp.array([0, 9, 0, 15]) 114 | y = nonzero_values(x) 115 | expected = jnp.array([9, 15]) 116 | assert jnp.array_equal(y, expected) 117 | 118 | def test_nonzero_indices(): 119 | x = jnp.array([0, 9, 0, 15]) 120 | y = nonzero_indices(x) 121 | expected = jnp.array([1, 3]) 122 | assert jnp.array_equal(y, expected) 123 | 124 | def test_hard_threshold(): 125 | x = jnp.array([5, 1, 3, 6, 2]) 126 | I, x_I = hard_threshold(x, 2) 127 | assert I.size == 2 128 | assert x_I.size == 2 129 | assert jnp.array_equal(I, jnp.array([3, 0])) 130 | assert jnp.array_equal(x_I, jnp.array([6, 5])) 131 | 132 | def test_hard_threshold_sorted(): 133 | x = jnp.array([5, 1, 3, 6, 2]) 134 | I, x_I = hard_threshold_sorted(x, 2) 135 | assert I.size == 2 136 | assert x_I.size == 2 137 | assert jnp.array_equal(I, jnp.array([0, 3])) 138 | assert jnp.array_equal(x_I, jnp.array([5, 6])) 139 | 140 | def test_dynamic_range(): 141 | x = jnp.array([4, -2, 3, 3, 8]) 142 | dr = dynamic_range(x) 143 | assert dr >= 12 and dr <= 12.1 144 | 145 | def test_nonzero_dynamic_range(): 146 | x = jnp.array([4, -2, 3, 3, 8]) 147 | dr = nonzero_dynamic_range(x) 148 | assert dr >= 12 and dr <= 12.1 149 | x = jnp.array([4, -2, 0, 0, 8]) 150 | dr = nonzero_dynamic_range(x) 151 | assert dr >= 12 and dr <= 12.1 152 | 153 | 154 | def test_SignalsComparison(): 155 | n = 80 156 | s = 10 157 | X = 
jnp.ones((n,s)) 158 | Y = 1.1 * jnp.ones((n, s)) 159 | cmp = SignalsComparison(X, Y) 160 | assert len(cmp.reference_norms) == s 161 | assert len(cmp.estimate_norms) == s 162 | assert len(cmp.difference_norms) == s 163 | assert len(cmp.reference_energies) == s 164 | assert len(cmp.estimate_energies) == s 165 | assert len(cmp.difference_energies) == s 166 | assert len(cmp.error_to_signal_norms) == s 167 | assert len(cmp.signal_to_noise_ratios) == s 168 | assert cmp.cum_reference_norm 169 | assert cmp.cum_estimate_norm 170 | assert cmp.cum_difference_norm 171 | assert cmp.cum_error_to_signal_norm 172 | assert cmp.cum_signal_to_noise_ratio 173 | cmp.summarize() 174 | assert(len(snrs_cw(X, Y)) == s) 175 | assert(len(snrs_rw(X, Y)) == n) 176 | cmp = SignalsComparison(X[:,0], Y[:,0]) 177 | 178 | def test_support(): 179 | x = jnp.concatenate((jnp.zeros(5), jnp.ones(5))) 180 | i = support(x) 181 | assert len(i) == 5 182 | 183 | def test_hard_threshold_by(): 184 | x = jnp.arange(10) 185 | y = hard_threshold_by(x, 5) 186 | i = support(y) 187 | assert len(i) == 5 188 | 189 | def test_largest_indices_by(): 190 | x = jnp.arange(10) 191 | y = largest_indices_by(x, 5) 192 | assert len(y) == 5 193 | 194 | def test_normalize(): 195 | x = jnp.arange(10) * 1. 196 | y, mu, sigma = scale_0_mean_1_var(x) 197 | assert_almost_equal(jnp.mean(y), 0) 198 | assert_almost_equal(jnp.var(y), 1.) 199 | 200 | def test_power_spectrum(): 201 | t = jnp.linspace(0, 10, 1024) 202 | f = 1 203 | x = jnp.cos(2*jnp.pi*t) 204 | n = len(x) 205 | n2 = n // 2 206 | f, sxx = power_spectrum(x) 207 | assert len(f) == n2 208 | assert len(sxx) == n2 209 | assert jnp.all(sxx >= 0) 210 | --------------------------------------------------------------------------------