├── requirements.txt ├── .gitignore ├── pcax ├── __init__.py ├── pca_test.py └── pca.py ├── .github ├── dependabot.yml └── workflows │ ├── pytest.yml │ └── python-publish.yml ├── pyproject.toml ├── LICENSE └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | jaxlib 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore python files 2 | **/__pycache__/** 3 | 4 | -------------------------------------------------------------------------------- /pcax/__init__.py: -------------------------------------------------------------------------------- 1 | from .pca import fit, transform, recover 2 | 3 | __version__ = '0.1.0' 4 | 5 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: weekly 7 | timezone: CET 8 | open-pull-requests-limit: 10 9 | 10 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | name: "Python ${{ matrix.python-version }}" 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.11", "3.12", "3.13"] 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | - uses: actions/setup-python@v4 22 | with: 23 | python-version: "${{ matrix.python-version }}" 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install -e . 
pytest 29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 30 | 31 | - name: Test with pytest 32 | run: | 33 | pytest -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "jax"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "pcax" 7 | dynamic = ["version"] 8 | description = "Minimal Principal Component Analysis (PCA) implementation using JAX." 9 | readme = "README.md" 10 | authors = [ 11 | {name = "Albert Alonso", email = "alonfnt@pm.me"}, 12 | ] 13 | license = {file = "LICENSE"} 14 | keywords = ["jax", "pca", "machine-learning"] 15 | classifiers = [ 16 | "Development Status :: 3 - Alpha", 17 | "Intended Audience :: Developers", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3", 20 | ] 21 | dependencies = ["jax",] 22 | 23 | [tool.setuptools.dynamic] 24 | version = {attr = "pcax.__version__"} 25 | 26 | [project.urls] 27 | "Homepage" = "https://github.com/alonfnt/pcax" 28 | "Documentation" = "https://github.com/alonfnt/pcax" 29 | "Source" = "https://github.com/alonfnt/pcax" 30 | "Bug Tracker" = "https://github.com/alonfnt/pcax/issues" 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Albert Alonso 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the 
following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | [![tests](https://github.com/alonfnt/pcax/actions/workflows/pytest.yml/badge.svg)](https://github.com/alonfnt/pcax/actions/workflows/pytest.yml) 6 | [![PyPI](https://img.shields.io/pypi/v/pcax.svg)](https://pypi.org/project/pcax/) 7 | 8 | `pcax` is a minimal PCA implementation in [JAX](https://github.com/jax-ml/jax) that’s both GPU/TPU/CPU‑native and fully differentiable. 9 | It keeps data and computation on-device with zero-copy transfers, lets you backpropagate through your dimensionality reduction step, and plugs directly into your JAX workflows for seamless, efficient model integration. 10 | 11 | ## Usage 12 | ```python 13 | import pcax 14 | 15 | # Fit the PCA model with 3 components on your data X 16 | state = pcax.fit(X, n_components=3) 17 | 18 | # Transform X to its principal components 19 | X_pca = pcax.transform(state, X) 20 | 21 | # Recover the original X from its principal components 22 | X_recover = pcax.recover(state, X_pca) 23 | ``` 24 | 25 | ## Installation 26 | `pcax` can be installed from PyPI via `pip` 27 | ``` 28 | pip install pcax 29 | ``` 30 | 31 | ## Citation 32 | If you use `pcax` in your research and need to reference it, please cite it as follows: 33 | ``` 34 | @software{alonso_pcax, 35 | author = {Alonso, Albert}, 36 | title = {pcax: Minimal Principal Component Analysis (PCA) Implementation in JAX}, 37 | url = {https://github.com/alonfnt/pcax}, 38 | version = {0.1.0} 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /pcax/pca_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | from pcax import fit, transform, recover 6 | 7 | KEY = jax.random.PRNGKey(42) 8 | 9 | 10 | def test_fit_invalid_solver(): 11 | with pytest.raises(ValueError): 12 | fit(jnp.zeros((10, 5)), n_components=2, solver="invalid_solver") 13 | 14 | 15 | @pytest.mark.parametrize("n_components", [1, 2, 5, 10]) 16 | 
@pytest.mark.parametrize("n_entries", [100, 200, 300]) 17 | @pytest.mark.parametrize("solver", ["full", "randomized"]) 18 | def test_fit_output_shapes(n_entries, n_components, solver): 19 | x = jax.random.normal(KEY, shape=(n_entries, 50)) 20 | rng, _ = jax.random.split(KEY) 21 | 22 | state = fit(x, n_components=n_components, solver=solver, rng=rng) 23 | 24 | assert state.components.shape == (n_components, x.shape[1]) 25 | assert state.means.shape == (1, x.shape[1]) 26 | assert state.explained_variance.shape == (n_components,) 27 | 28 | 29 | def test_fit_zero_mean(): 30 | x = jax.random.normal(KEY, shape=(100, 50)) 31 | n_components = 5 32 | state = fit(x, n_components=n_components, solver="full") 33 | x_zero_mean = x - state.means 34 | x_pca = jnp.dot(x_zero_mean, state.components.T) 35 | x_pca2 = transform(state, x) 36 | assert jnp.allclose(x_pca.mean(axis=0), jnp.zeros(n_components), atol=1e-5) 37 | assert jnp.allclose(x_pca, x_pca2) 38 | 39 | 40 | @pytest.mark.parametrize("n_components", [50]) 41 | @pytest.mark.parametrize("n_entries", [300, 500]) 42 | @pytest.mark.parametrize("solver", ['full', 'randomized']) 43 | def test_reconstruction(n_entries, n_components, solver): 44 | x = jax.random.normal(KEY, shape=(n_entries, 50)) 45 | state = fit(x, n_components=n_components, solver=solver) 46 | x_pca = transform(state, x) 47 | x_recovered = recover(state, x_pca) 48 | assert x_recovered.shape == x.shape 49 | assert jnp.allclose(x, x_recovered, atol=1e-1) 50 | -------------------------------------------------------------------------------- /pcax/pca.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import NamedTuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | 8 | class PCAState(NamedTuple): 9 | """Stores the state of a fitted PCA model. 10 | 11 | Attributes: 12 | components: Principal components (right singular vectors). 
13 | means: Mean of each feature in the original data. 14 | explained_variance: Variance explained by each component. 15 | """ 16 | 17 | components: jax.Array 18 | means: jax.Array 19 | explained_variance: jax.Array 20 | 21 | 22 | def transform(state: PCAState, x: jax.Array) -> jax.Array: 23 | """ 24 | Projects data into the PCA space defined by the fitted components. 25 | 26 | Args: 27 | state: The fitted PCA state. 28 | x: Input data of shape (n_samples, n_features). 29 | 30 | Returns: 31 | Transformed data in the reduced PCA space. 32 | """ 33 | x = x - state.means 34 | return jnp.dot(x, jnp.transpose(state.components)) 35 | 36 | 37 | def recover(state: PCAState, x: jax.Array) -> jax.Array: 38 | """ 39 | Reconstructs data from its PCA-projected representation. 40 | 41 | Args: 42 | state: The fitted PCA state. 43 | x: Transformed data of shape (n_samples, n_components). 44 | 45 | Returns: 46 | Approximate reconstruction of the original input. 47 | """ 48 | return jnp.dot(x, state.components) + state.means 49 | 50 | 51 | def fit( 52 | x: jax.Array, n_components: int, solver: str = "full", rng: jax.Array | None = None 53 | ) -> PCAState: 54 | """ 55 | Fits PCA on the input data using either full or randomized solver. 56 | 57 | Args: 58 | x: Input data of shape (n_samples, n_features). 59 | n_components: Number of principal components to retain. 60 | solver: Either "full" or "randomized". 61 | rng: PRNG key for randomized solver. 62 | 63 | Returns: 64 | The learned PCA transformation state. 65 | 66 | Raises: 67 | ValueError: If an invalid solver name is provided. 
68 | """ 69 | if solver == "full": 70 | return _fit_full(x, n_components) 71 | elif solver == "randomized": 72 | if rng is None: 73 | rng = jax.random.PRNGKey(n_components) 74 | return _fit_randomized(x, n_components, rng) 75 | else: 76 | raise ValueError("Invalid solver: must be 'full' or 'randomized'") 77 | 78 | 79 | @partial(jax.jit, static_argnames="n_components") 80 | def _fit_full(x: jax.Array, n_components: int) -> PCAState: 81 | """ 82 | Performs exact PCA using full SVD on centered input. 83 | Used internally when `solver='full'`. 84 | """ 85 | 86 | n_samples, n_features = x.shape 87 | 88 | # Subtract the mean of the input data 89 | means = x.mean(axis=0, keepdims=True) 90 | x = x - means 91 | 92 | # Factorize the data matrix with singular value decomposition. 93 | U, S, Vt = jax.scipy.linalg.svd(x, full_matrices=False) 94 | 95 | # Compute the explained variance 96 | explained_variance = (S[:n_components] ** 2) / (n_samples - 1) 97 | 98 | # Return the transformation matrix 99 | A = Vt[:n_components] 100 | return PCAState(components=A, means=means, explained_variance=explained_variance) 101 | 102 | 103 | def _fit_randomized( 104 | x: jax.Array, n_components: int, rng: jax.Array, n_iter: int = 5 105 | ) -> PCAState: 106 | """ 107 | Randomized PCA approximation using power iterations. 108 | Based on Halko et al., [https://doi.org/10.48550/arXiv.1007.5510]. 109 | Used internally when `solver='randomized'`. 
110 | """ 111 | n_samples, n_features = x.shape 112 | means = jnp.mean(x, axis=0, keepdims=True) 113 | x = x - means 114 | 115 | # Generate n_features normal vectors of the given size 116 | size = jnp.minimum(2 * n_components, n_features) 117 | Q = jax.random.normal(rng, shape=(n_features, size)) 118 | 119 | def step_fn(q, _): 120 | q, _ = jax.scipy.linalg.lu(x @ q, permute_l=True) 121 | q, _ = jax.scipy.linalg.lu(x.T @ q, permute_l=True) 122 | return q, None 123 | 124 | Q, _ = jax.lax.scan(step_fn, init=Q, xs=None, length=n_iter) 125 | Q, _ = jax.scipy.linalg.qr(x @ Q, mode="economic") 126 | B = Q.T @ x 127 | 128 | _, S, Vt = jax.scipy.linalg.svd(B, full_matrices=False) 129 | 130 | explained_variance = (S[:n_components] ** 2) / (n_samples - 1) 131 | A = Vt[:n_components] 132 | return PCAState(components=A, means=means, explained_variance=explained_variance) 133 | --------------------------------------------------------------------------------