├── requirements.txt
├── .gitignore
├── pcax
│   ├── __init__.py
│   ├── pca_test.py
│   └── pca.py
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── pytest.yml
│       └── python-publish.yml
├── pyproject.toml
├── LICENSE
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | jax
2 | jaxlib
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore Python bytecode caches
2 | **/__pycache__/**
3 |
4 |
--------------------------------------------------------------------------------
/pcax/__init__.py:
--------------------------------------------------------------------------------
1 | from .pca import fit, transform, recover
2 |
# Single source of the package version: pyproject.toml reads it via
# [tool.setuptools.dynamic] (attr = "pcax.__version__"). The citation block in
# README.md repeats this number, so keep the two in sync when bumping.
__version__ = '0.1.0'
4 |
5 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: pip
4 | directory: "/"
5 | schedule:
6 | interval: weekly
7 | timezone: CET
8 | open-pull-requests-limit: 10
9 |
10 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | build:
11 | name: "Python ${{ matrix.python-version }}"
12 | runs-on: ubuntu-latest
13 |
14 | strategy:
15 | matrix:
16 | python-version: ["3.11", "3.12", "3.13"]
17 |
18 | steps:
19 | - uses: actions/checkout@v4
20 |
21 | - uses: actions/setup-python@v4
22 | with:
23 | python-version: "${{ matrix.python-version }}"
24 |
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | python -m pip install -e . pytest
29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
30 |
31 | - name: Test with pytest
32 | run: |
33 | pytest
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools", "wheel", "jax"]
build-backend = "setuptools.build_meta"

[project]
name = "pcax"
dynamic = ["version"]
description = "Minimal Principal Component Analysis (PCA) implementation using JAX."
readme = "README.md"
# pcax/pca.py evaluates `jax.Array | None` in a function signature at import
# time (PEP 604 union syntax), which requires Python 3.10 or newer.
requires-python = ">=3.10"
authors = [
    {name = "Albert Alonso", email = "alonfnt@pm.me"},
]
license = {file = "LICENSE"}
keywords = ["jax", "pca", "machine-learning"]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
]
dependencies = ["jax",]

[tool.setuptools.dynamic]
version = {attr = "pcax.__version__"}

[project.urls]
"Homepage" = "https://github.com/alonfnt/pcax"
"Documentation" = "https://github.com/alonfnt/pcax"
"Source" = "https://github.com/alonfnt/pcax"
"Bug Tracker" = "https://github.com/alonfnt/pcax/issues"
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Albert Alonso
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v3
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
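5 | [![tests](https://github.com/alonfnt/pcax/actions/workflows/pytest.yml/badge.svg)](https://github.com/alonfnt/pcax/actions/workflows/pytest.yml)
6 | [![PyPI](https://img.shields.io/pypi/v/pcax)](https://pypi.org/project/pcax/)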
7 |
8 | `pcax` is a minimal PCA implementation in [JAX](https://github.com/jax-ml/jax) that’s both GPU/TPU/CPU‑native and fully differentiable.
9 | It keeps data and computation on-device with zero-copy transfers, lets you backpropagate through your dimensionality reduction step, and plugs directly into your JAX workflows for seamless, efficient model integration.
10 |
11 | ## Usage
12 | ```python
13 | import pcax
14 |
15 | # Fit a PCA model with 3 components on your data X, an array of shape (n_samples, n_features)
16 | state = pcax.fit(X, n_components=3)
17 |
18 | # Transform X to its principal components
19 | X_pca = pcax.transform(state, X)
20 |
21 | # Recover the original X from its principal components
22 | X_recover = pcax.recover(state, X_pca)
23 | ```
24 |
25 | ## Installation
26 | `pcax` can be installed from PyPI via `pip`
27 | ```
28 | pip install pcax
29 | ```
30 |
31 | ## Citation
32 | If you use `pcax` in your research and need to reference it, please cite it as follows:
33 | ```
34 | @software{alonso_pcax,
35 | author = {Alonso, Albert},
36 | title = {pcax: Minimal Principal Component Analysis (PCA) Implementation in JAX},
37 | url = {https://github.com/alonfnt/pcax},
38 | version = {0.1.0}
39 | }
40 | ```
41 |
--------------------------------------------------------------------------------
/pcax/pca_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import jax
3 | import jax.numpy as jnp
4 |
5 | from pcax import fit, transform, recover
6 |
# Fixed PRNG key shared by all tests so generated data (and therefore the
# assertions below) are reproducible across runs.
KEY = jax.random.PRNGKey(42)
8 |
9 |
def test_fit_invalid_solver():
    """An unrecognised solver name must be rejected with ValueError."""
    data = jnp.zeros((10, 5))
    with pytest.raises(ValueError):
        fit(data, n_components=2, solver="invalid_solver")
13 |
14 |
@pytest.mark.parametrize("n_components", [1, 2, 5, 10])
@pytest.mark.parametrize("n_entries", [100, 200, 300])
@pytest.mark.parametrize("solver", ["full", "randomized"])
def test_fit_output_shapes(n_entries, n_components, solver):
    """Every field of the fitted state has the documented shape."""
    n_features = 50
    data = jax.random.normal(KEY, shape=(n_entries, n_features))
    # First half of a split of the shared key seeds the randomized solver.
    solver_key = jax.random.split(KEY)[0]

    state = fit(data, n_components=n_components, solver=solver, rng=solver_key)

    expected = {
        "components": (n_components, n_features),
        "means": (1, n_features),
        "explained_variance": (n_components,),
    }
    assert state.components.shape == expected["components"]
    assert state.means.shape == expected["means"]
    assert state.explained_variance.shape == expected["explained_variance"]
28 |
def test_fit_zero_mean():
    """Projected training data is centered and agrees with ``transform``."""
    data = jax.random.normal(KEY, shape=(100, 50))
    k = 5
    state = fit(data, n_components=k, solver="full")

    # Project by hand (center, then multiply by the transposed axes) ...
    by_hand = jnp.dot(data - state.means, state.components.T)
    # ... and through the public API; both must agree.
    via_api = transform(state, data)

    assert jnp.allclose(by_hand.mean(axis=0), jnp.zeros(k), atol=1e-5)
    assert jnp.allclose(by_hand, via_api)
38 |
39 |
@pytest.mark.parametrize("n_components", [50])
@pytest.mark.parametrize("n_entries", [300, 500])
@pytest.mark.parametrize("solver", ["full", "randomized"])
def test_reconstruction(n_entries, n_components, solver):
    """Keeping every component lets ``recover`` undo ``transform``."""
    data = jax.random.normal(KEY, shape=(n_entries, 50))
    state = fit(data, n_components=n_components, solver=solver)

    roundtrip = recover(state, transform(state, data))

    assert roundtrip.shape == data.shape
    assert jnp.allclose(data, roundtrip, atol=1e-1)
50 |
--------------------------------------------------------------------------------
/pcax/pca.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import NamedTuple
3 |
4 | import jax
5 | import jax.numpy as jnp
6 |
7 |
class PCAState(NamedTuple):
    """Immutable bundle of everything a fitted PCA model needs.

    Returned by ``fit`` and consumed by ``transform`` and ``recover``.

    Attributes:
        components: Principal axes stored as rows (the leading right
            singular vectors of the centered data), shape
            (n_components, n_features).
        means: Per-feature mean of the training data, kept 2-D with shape
            (1, n_features) so it broadcasts against rows of samples.
        explained_variance: Variance of the training data along each
            principal axis, shape (n_components,).
    """

    # Leading right singular vectors, one principal axis per row.
    components: jax.Array
    # Training-data column means, subtracted before projecting.
    means: jax.Array
    # Squared singular values divided by (n_samples - 1).
    explained_variance: jax.Array
20 |
21 |
def transform(state: PCAState, x: jax.Array) -> jax.Array:
    """Project samples onto the principal axes of a fitted model.

    The training means stored in ``state`` are subtracted first, so the
    projection uses the same centering that ``fit`` applied.

    Args:
        state: Fitted model returned by ``fit``.
        x: Samples to project, shape (n_samples, n_features).

    Returns:
        Scores in the reduced space, shape (n_samples, n_components) for a
        2-D input.
    """
    centered = x - state.means
    axes = state.components.T  # (n_features, n_components)
    return jnp.dot(centered, axes)
35 |
36 |
def recover(state: PCAState, x: jax.Array) -> jax.Array:
    """Map PCA scores back into the original feature space.

    Multiplies the scores by the stored principal axes and re-adds the
    training means. Information along components that were not retained
    cannot be restored, so the result is generally an approximation.

    Args:
        state: Fitted model returned by ``fit``.
        x: Scores as produced by ``transform``, shape
            (n_samples, n_components).

    Returns:
        Reconstructed samples, shape (n_samples, n_features).
    """
    back_projected = jnp.dot(x, state.components)
    return back_projected + state.means
49 |
50 |
def fit(
    x: jax.Array, n_components: int, solver: str = "full", rng: jax.Array | None = None
) -> PCAState:
    """
    Fits PCA on the input data using either full or randomized solver.

    Args:
        x: Input data of shape (n_samples, n_features).
        n_components: Number of principal components to retain. Must satisfy
            1 <= n_components <= min(n_samples, n_features).
        solver: Either "full" (exact SVD of the centered data) or
            "randomized" (approximate, based on a random sketch).
        rng: PRNG key for the randomized solver; ignored by "full". When
            omitted, ``jax.random.PRNGKey(n_components)`` is used, so repeated
            calls with the same arguments give the same result.

    Returns:
        The learned PCA transformation state.

    Raises:
        ValueError: If an invalid solver name is provided, if ``x`` is not
            2-D, or if ``n_components`` is outside the valid range.
    """
    if solver not in ("full", "randomized"):
        raise ValueError("Invalid solver: must be 'full' or 'randomized'")

    # Validate shapes up front. Both solvers slice the SVD output to
    # n_components rows, so requesting more components than
    # min(n_samples, n_features) would otherwise either silently return
    # fewer components than asked for or fail deep inside the solver.
    if x.ndim != 2:
        raise ValueError(
            f"Expected x of shape (n_samples, n_features), got shape {x.shape}"
        )
    max_components = min(x.shape)
    if not 1 <= n_components <= max_components:
        raise ValueError(
            "n_components must be between 1 and min(n_samples, n_features)="
            f"{max_components}, got {n_components}"
        )

    if solver == "full":
        return _fit_full(x, n_components)

    if rng is None:
        # Deterministic default key, so results are reproducible without
        # requiring the caller to manage randomness.
        rng = jax.random.PRNGKey(n_components)
    return _fit_randomized(x, n_components, rng)
77 |
78 |
@partial(jax.jit, static_argnames="n_components")
def _fit_full(x: jax.Array, n_components: int) -> PCAState:
    """Exact PCA via a thin SVD of the mean-centered data.

    Backs ``fit(..., solver="full")``. Compiled with ``jax.jit``;
    ``n_components`` is static because it is used to slice the SVD outputs
    and therefore fixes the output shapes.

    Args:
        x: Data matrix of shape (n_samples, n_features).
        n_components: Number of leading singular directions to keep.

    Returns:
        PCAState holding the top right singular vectors, the column means,
        and the per-component variance S**2 / (n_samples - 1).
    """
    n_samples, _ = x.shape  # unpacking also rejects non-2-D input

    feature_means = x.mean(axis=0, keepdims=True)
    centered = x - feature_means

    # Thin SVD: centered = U @ diag(S) @ Vt; rows of Vt are the principal axes.
    _, singular_values, vt = jax.scipy.linalg.svd(centered, full_matrices=False)

    top_singular_values = singular_values[:n_components]
    return PCAState(
        components=vt[:n_components],
        means=feature_means,
        explained_variance=(top_singular_values**2) / (n_samples - 1),
    )
101 |
102 |
@partial(jax.jit, static_argnames=("n_components", "n_iter"))
def _fit_randomized(
    x: jax.Array, n_components: int, rng: jax.Array, n_iter: int = 5
) -> PCAState:
    """
    Randomized PCA approximation using power iterations.
    Based on Halko et al., [https://doi.org/10.48550/arXiv.1007.5510].
    Used internally when `solver='randomized'`.

    Compiled with ``jax.jit`` like the full solver; ``n_components`` and
    ``n_iter`` are static because they determine array shapes and the scan
    length.

    Args:
        x: Data matrix of shape (n_samples, n_features).
        n_components: Number of leading components to keep.
        rng: PRNG key used to draw the random probe matrix.
        n_iter: Number of power iterations applied to the probe matrix.

    Returns:
        PCAState with approximate principal axes, the column means, and the
        per-component variance S**2 / (n_samples - 1).
    """
    n_samples, n_features = x.shape
    means = jnp.mean(x, axis=0, keepdims=True)
    x = x - means

    # Sketch width: oversample to 2 * n_components, capped by both matrix
    # dimensions. It must be a plain Python int, because it is used as an
    # array shape and must be static under jit. The n_samples cap keeps the
    # LU factors below at a constant shape, so the lax.scan carry keeps its
    # shape across iterations.
    size = min(2 * n_components, n_features, n_samples)

    # `size` Gaussian probe vectors, each of length n_features (one per column).
    Q = jax.random.normal(rng, shape=(n_features, size))

    def step_fn(q, _):
        # One power iteration applying x^T x to q. LU is a cheap
        # re-normalization between the two products that limits round-off
        # growth over iterations.
        q, _ = jax.scipy.linalg.lu(x @ q, permute_l=True)
        q, _ = jax.scipy.linalg.lu(x.T @ q, permute_l=True)
        return q, None

    Q, _ = jax.lax.scan(step_fn, init=Q, xs=None, length=n_iter)
    # Orthonormal basis for the (approximate) range of x, then project x onto
    # it. This gives a small (size, n_features) matrix whose SVD is cheap.
    Q, _ = jax.scipy.linalg.qr(x @ Q, mode="economic")
    B = Q.T @ x

    _, S, Vt = jax.scipy.linalg.svd(B, full_matrices=False)

    explained_variance = (S[:n_components] ** 2) / (n_samples - 1)
    A = Vt[:n_components]
    return PCAState(components=A, means=means, explained_variance=explained_variance)
133 |
--------------------------------------------------------------------------------