├── .gitignore ├── .travis.yml ├── README.md ├── pyproject.toml ├── scripts ├── benchmark_RBF.py └── benchmark_string_kernel.py ├── sklearn_jax_kernels ├── __init__.py ├── base_kernels.py ├── config.py ├── gpc.py └── structured │ ├── string_utils.py │ └── strings.py └── tests ├── __init__.py ├── test_base_kernels.py ├── test_gpc.py ├── test_kernels.py └── test_structured_strings.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | sklearn_jax_kernels.egg-info/ 3 | pip-wheel-metadata/ 4 | poetry.lock 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - 3.6 4 | - 3.7 5 | - 3.8 6 | install: 7 | - pip install poetry && poetry install 8 | script: 9 | - poetry run pytest tests 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sklearn-jax-kernels 2 | 3 | [![PyPI version](https://badge.fury.io/py/sklearn-jax-kernels.svg)](https://badge.fury.io/py/sklearn-jax-kernels) [![Build Status](https://travis-ci.com/ExpectationMax/sklearn-jax-kernels.svg?token=3sUUnmMzs9wxN3Qapssj&branch=master)](https://travis-ci.com/ExpectationMax/sklearn-jax-kernels) [![Downloads](https://pepy.tech/badge/sklearn-jax-kernels)](https://pepy.tech/project/sklearn-jax-kernels) 4 | 5 | **Warning: This project is still in an early stage and the API may change in 6 | the future. Furthermore, functionality is currently limited to the use cases 7 | which motivated the creation of the project (applications to DNA 8 | sequences in biology).** 9 | 10 | ## Why? 11 | Ever wanted to run a kernel-based model from 12 | [scikit-learn](https://scikit-learn.org/) on a relatively large dataset? If so, 13 | you will have noticed that this can take extraordinarily long and require huge 14 | amounts of memory, especially if you are using compositions of kernels (such 15 | as, for example, `k1 * k2 + k3`). This is due to the way kernels are computed in 16 | scikit-learn: for each kernel, the complete kernel matrix is computed, and the 17 | compositions are then computed from the kernel matrices. Further, 18 | `scikit-learn` does not rely on an automatic differentiation framework for the 19 | computation of gradients through kernel operations. 20 | 21 | ## Introduction 22 | 23 | `sklearn-jax-kernels` was designed to circumvent these issues: 24 | 25 | - The utilization of [JAX](https://github.com/google/jax) allows accelerating 26 | kernel computations through [XLA](https://www.tensorflow.org/xla) 27 | optimizations and computation on GPUs, and simplifies the computation of 28 | gradients through kernels 29 | - The composition of kernels takes place on a per-element basis, such that 30 | unnecessary copies can be optimized away by JAX compilation (see the sketch below) 31 | 32 | The goal of `sklearn-jax-kernels` is to provide the same flexibility and ease 33 | of use as known from `scikit-learn` kernels while improving speed and allowing 34 | the faster design of new kernels through Automatic Differentiation. 35 | 36 | The kernels in this package follow the [scikit-learn kernel 37 | API](https://scikit-learn.org/stable/modules/gaussian_process.html#gaussian-process-kernel-api).
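To make the per-element composition mentioned above concrete, here is a minimal, self-contained sketch. The names `rbf`, `composed` and `kernel_matrix` are illustrative stand-ins and not part of the package's API: kernels are composed as scalar functions of a single pair of inputs, and the full matrix is built once from the composed function, so JAX can fuse the whole expression instead of materialising one matrix per sub-kernel.

```python
import jax.numpy as jnp
from jax import jit, vmap

def rbf(x, y, length_scale=1.0):
    """Scalar RBF kernel value for a single pair of inputs."""
    d = jnp.sum(((x - y) / length_scale) ** 2)
    return jnp.exp(-0.5 * d)

def composed(x, y):
    """Composition `1 + 2 * rbf(x, y)`, evaluated per element pair."""
    return 1.0 + 2.0 * rbf(x, y)

def kernel_matrix(kernel_fn, X):
    """Build the full kernel matrix once, from the composed scalar function."""
    return vmap(lambda a: vmap(lambda b: kernel_fn(a, b))(X))(X)

X = jnp.arange(12.0).reshape(6, 2)
K = jit(lambda data: kernel_matrix(composed, data))(X)  # shape (6, 6)
```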
38 | 39 | ## Installation 40 | 41 | `sklearn-jax-kernels` can simply be installed via `pip`: 42 | 43 | ```bash 44 | pip install sklearn-jax-kernels 45 | ``` 46 | 47 | ## Quickstart 48 | 49 | Here is a short demonstration of how the kernels can be used, inspired by the 50 | [scikit-learn 51 | documentation](https://scikit-learn.org/stable/auto_examples/gaussian_process/plot_gpc_iris.html): 52 | 53 | ```python 54 | from sklearn import datasets 55 | import jax.numpy as jnp 56 | from sklearn_jax_kernels import RBF, GaussianProcessClassifier 57 | 58 | iris = datasets.load_iris() 59 | X = jnp.asarray(iris.data) 60 | y = jnp.array(iris.target, dtype=int) 61 | 62 | kernel = 1. + RBF(length_scale=1.0) 63 | gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y) 64 | ``` 65 | 66 | Here is a further example demonstrating how kernels can be combined: 67 | 68 | ```python 69 | from sklearn_jax_kernels.base_kernels import RBF, NormalizedKernel 70 | from sklearn_jax_kernels.structured.strings import SpectrumKernel 71 | 72 | my_kernel = RBF(1.) * SpectrumKernel(n_gram_length=3) 73 | my_kernel_2 = RBF(1.) + RBF(2.) 74 | my_kernel_2 = NormalizedKernel(my_kernel_2) 75 | ``` 76 | 77 | Some further inspiration can be taken from the tests in the subfolder `tests`. 78 | 79 | ## Implemented Kernels 80 | 81 | - Kernel compositions ($+,-,*,/$, exponentiation) 82 | - Kernels for real valued data: 83 | - RBF kernel 84 | - Kernels for same-length strings: 85 | - SpectrumKernel 86 | - DistanceSpectrumKernel, a SpectrumKernel with distance weighting between 87 | matching substrings 88 | - ReverseComplement Spectrum kernel (relevant for applications in biology 89 | when working with DNA sequences) 90 | 91 | ## TODOs 92 | 93 | - Implement more fundamental kernels 94 | - Implement a JAX-compatible version of GaussianProcessRegressor 95 | - Optimize GaussianProcessClassifier for performance 96 | - Run benchmarks to show benefits in speed 97 | - Add a fake "split" kernel which allows applying different kernels to different 98 | parts of the input 99 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "sklearn-jax-kernels" 3 | version = "0.0.2" 4 | description = "Composable kernels for scikit-learn implemented in JAX."
5 | authors = ["Max Horn "] 6 | readme = "README.md" 7 | repository = "https://github.com/ExpectationMax/sklearn-jax-kernels" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.6" 11 | jax = "^0.1.59" 12 | jaxlib = "^0.1.40" 13 | scikit-learn = "^0.23.0" 14 | 15 | [tool.poetry.dev-dependencies] 16 | pytest = "^4.6" 17 | ipython = "^7.13.0" 18 | 19 | [build-system] 20 | requires = ["poetry>=0.12"] 21 | build-backend = "poetry.masonry.api" 22 | -------------------------------------------------------------------------------- /scripts/benchmark_RBF.py: -------------------------------------------------------------------------------- 1 | """Compare performance of jax kernel with default sklearn kernel.""" 2 | import timeit 3 | 4 | import numpy as np 5 | import jax.numpy as jnp 6 | from sklearn import datasets 7 | from sklearn.gaussian_process import GaussianProcessClassifier as sk_GPC 8 | from sklearn.gaussian_process.kernels import RBF as sklearn_RBF 9 | from sklearn_jax_kernels import RBF as jax_RBF 10 | from sklearn_jax_kernels import GaussianProcessClassifier as jax_GPC 11 | 12 | # import some data to play with 13 | digits = datasets.load_digits() 14 | X = digits.data 15 | y = np.array(digits.target, dtype=int) 16 | 17 | X_jax = jnp.asarray(X) 18 | y_jax = jnp.asarray(y) 19 | 20 | sk_kernel = 1.0 * sklearn_RBF([1.0]) 21 | jax_kernel = 1.0 * jax_RBF([1.0]) 22 | 23 | sk_clf = sk_GPC(kernel=sk_kernel, copy_X_train=False) 24 | jax_clf = jax_GPC(kernel=jax_kernel, copy_X_train=False) 25 | 26 | sk_clf.fit(X, y) 27 | jax_clf.fit(X_jax, y_jax) 28 | 29 | def fit_with_sklearn_kernel(): 30 | sk_clf.fit(X, y) 31 | 32 | def fit_with_jax_kernel(): 33 | jax_clf.fit(X_jax, y_jax) 34 | 35 | time_sk = timeit.timeit(fit_with_sklearn_kernel, number=1) 36 | print(time_sk) 37 | time_jax = timeit.timeit(fit_with_jax_kernel, number=1) 38 | print(time_jax) 39 | -------------------------------------------------------------------------------- /scripts/benchmark_string_kernel.py: -------------------------------------------------------------------------------- 1 | import random 2 | import timeit 3 | from sklearn_jax_kernels.structured.string_utils import ( 4 | AsciiBytesTransformer, NGramTransformer) 5 | from sklearn_jax_kernels.structured.strings import SpectrumKernel 6 | from sklearn.pipeline import Pipeline 7 | 8 | strings = [ 9 | "".join(random.choices(['A', 'T', 'G', 'C'], k=115)) for i in range(1000)] 10 | pipeline = Pipeline([ 11 | ('bytes', AsciiBytesTransformer()), 12 | ('ngrams', NGramTransformer(2)) 13 | ]) 14 | transformed = pipeline.transform(strings).block_until_ready() 15 | kernel = SpectrumKernel(n_gram_length=None) 16 | 17 | kernel(transformed).block_until_ready() 18 | print('Compiled') 19 | time = timeit.timeit(lambda: kernel(transformed).block_until_ready(), number=5) 20 | print(time) 21 | 22 | -------------------------------------------------------------------------------- /sklearn_jax_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | """Jax implementation of sklearn kernels.""" 2 | from .base_kernels import ( 3 | ConstantKernel, 4 | Exponentiation, 5 | Kernel, 6 | KernelOperator, 7 | NormalizedKernel, 8 | Product, 9 | RBF, 10 | Sum 11 | ) 12 | from .gpc import GaussianProcessClassifier 13 | 14 | __all__ = [ 15 | 'ConstantKernel', 'Exponentiation', 'GaussianProcessClassifier', 'Kernel', 16 | 'KernelOperator', 'NormalizedKernel', 'Product', 'RBF', 'Sum' 17 | ] 18 | __version__ = '0.1.0' 19 | 
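Before the kernel implementations that follow, here is a brief usage sketch of the `get_kernel_matrix_fn` interface documented in `base_kernels.py` below (the returned function has the signature `f(theta, X, Y)`, and passing `Y=None` computes the symmetric matrix over `X`). The toy input array and the shapes in the comments are illustrative assumptions, not output produced by the project:

```python
import jax.numpy as jnp
from sklearn_jax_kernels import RBF

# Toy data: 5 samples with 2 features each.
X = jnp.linspace(0.0, 1.0, 10).reshape(5, 2)

kernel = 1.0 * RBF(length_scale=1.0)

# f(theta, X, Y): evaluate the kernel matrix and its gradient w.r.t. theta.
kernel_matrix_fn = kernel.get_kernel_matrix_fn(eval_gradient=True)
K, K_gradient = kernel_matrix_fn(kernel.theta, X, None)

print(K.shape)           # expected: (5, 5)
print(K_gradient.shape)  # expected: (5, 5, len(kernel.theta))
```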
-------------------------------------------------------------------------------- /sklearn_jax_kernels/base_kernels.py: -------------------------------------------------------------------------------- 1 | """Base classes of Kernel implementation compatible with JAX.""" 2 | import abc 3 | import numpy 4 | from functools import partial 5 | from sklearn.gaussian_process.kernels import Kernel as sklearn_kernel 6 | from sklearn.gaussian_process.kernels import ( 7 | Hyperparameter, 8 | StationaryKernelMixin, 9 | NormalizedKernelMixin 10 | ) 11 | from jax import jit, vmap, value_and_grad 12 | import jax.numpy as np 13 | import jax.ops as ops 14 | from jax.experimental import loops 15 | 16 | from .config import config_value 17 | 18 | 19 | class Kernel(sklearn_kernel, metaclass=abc.ABCMeta): 20 | """Kernel object similar to sklearn implementation but supporting JAX. 21 | 22 | Contains additional methods: 23 | - pure_kernel_function 24 | 25 | """ 26 | 27 | @property 28 | @abc.abstractmethod 29 | def pure_kernel_fn(self): 30 | """Get a pure function which applies the kernel. 31 | 32 | Returned function should have the signature: 33 | fn(theta, x1, x2). 34 | """ 35 | 36 | @staticmethod 37 | def _kernel_matrix_without_gradients(kernel_fn, theta, X, Y): 38 | kernel_fn = partial(kernel_fn, theta) 39 | if Y is None or (Y is X): 40 | if config_value('KERNEL_MATRIX_USE_LOOP'): 41 | n = len(X) 42 | with loops.Scope() as s: 43 | # s.scattered_values = np.empty((n, n)) 44 | s.index1, s.index2 = np.tril_indices(n, k=0) 45 | s.output = np.empty(len(s.index1)) 46 | for i in s.range(s.index1.shape[0]): 47 | i1, i2 = s.index1[i], s.index2[i] 48 | s.output = ops.index_update( 49 | s.output, 50 | i, 51 | kernel_fn(X[i1], X[i2]) 52 | ) 53 | first_update = ops.index_update( 54 | np.empty((n, n)), (s.index1, s.index2), s.output) 55 | second_update = ops.index_update( 56 | first_update, (s.index2, s.index1), s.output) 57 | return second_update 58 | else: 59 | n = len(X) 60 | values_scattered = np.empty((n, n)) 61 | index1, index2 = np.tril_indices(n, k=-1) 62 | inst1, inst2 = X[index1], X[index2] 63 | values = vmap(kernel_fn)(inst1, inst2) 64 | values_scattered = ops.index_update( 65 | values_scattered, (index1, index2), values) 66 | values_scattered = ops.index_update( 67 | values_scattered, (index2, index1), values) 68 | values_scattered = ops.index_update( 69 | values_scattered, 70 | np.diag_indices(n), 71 | vmap(lambda x: kernel_fn(x, x))(X) 72 | ) 73 | return values_scattered 74 | else: 75 | if config_value('KERNEL_MATRIX_USE_LOOP'): 76 | with loops.Scope() as s: 77 | s.output = np.empty((X.shape[0], Y.shape[0])) 78 | for i in s.range(X.shape[0]): 79 | x = X[i] 80 | s.output = ops.index_update( 81 | s.output, 82 | i, 83 | vmap(lambda y: kernel_fn(x, y))(Y) 84 | ) 85 | return s.output 86 | else: 87 | return vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(Y))(X) 88 | 89 | @staticmethod 90 | def _kernel_matrix_with_gradients(kernel_fn, theta, X, Y): 91 | kernel_fn = value_and_grad(kernel_fn) 92 | kernel_fn = partial(kernel_fn, theta) 93 | if Y is None or (Y is X): 94 | if config_value('KERNEL_MATRIX_USE_LOOP'): 95 | n = len(X) 96 | with loops.Scope() as s: 97 | s.scattered_values = np.empty((n, n)) 98 | s.scattered_grads = np.empty((n, n, len(theta))) 99 | index1, index2 = np.tril_indices(n, k=0) 100 | for i in s.range(index1.shape[0]): 101 | i1, i2 = index1[i], index2[i] 102 | value, grads = kernel_fn(X[i1], X[i2]) 103 | indexes = (np.stack([i1, i2]), np.stack([i2, i1])) 104 | s.scattered_values = ops.index_update( 105 
| s.scattered_values, 106 | indexes, 107 | value 108 | ) 109 | s.scattered_grads = ops.index_update( 110 | s.scattered_grads, indexes, grads) 111 | return s.scattered_values, s.scattered_grads 112 | else: 113 | n = len(X) 114 | values_scattered = np.empty((n, n)) 115 | grads_scattered = np.empty((n, n, len(theta))) 116 | index1, index2 = np.tril_indices(n, k=-1) 117 | inst1, inst2 = X[index1], X[index2] 118 | values, grads = vmap(kernel_fn)(inst1, inst2) 119 | # Scatter computed values into matrix 120 | values_scattered = ops.index_update( 121 | values_scattered, (index1, index2), values) 122 | values_scattered = ops.index_update( 123 | values_scattered, (index2, index1), values) 124 | grads_scattered = ops.index_update( 125 | grads_scattered, (index1, index2), grads) 126 | grads_scattered = ops.index_update( 127 | grads_scattered, (index2, index1), grads) 128 | diag_values, diag_grads = vmap( 129 | lambda x: kernel_fn(x, x))(X) 130 | diag_indices = np.diag_indices(n) 131 | values_scattered = ops.index_update( 132 | values_scattered, diag_indices, diag_values) 133 | grads_scattered = ops.index_update( 134 | grads_scattered, diag_indices, diag_grads) 135 | return values_scattered, grads_scattered 136 | else: 137 | return vmap( 138 | lambda x: vmap(lambda y: kernel_fn(x, y))(Y))(X) 139 | 140 | def get_kernel_matrix_fn(self, eval_gradient): 141 | """Return pure function for computing kernel matrix and gradients. 142 | 143 | We do some internal caching in order to avoid recompiling the resulting 144 | function. 145 | 146 | Returned function has the signature: `f(theta, X, Y)` 147 | """ 148 | cache_name = ( 149 | '_cached_kernel_matrix_fn' + ('_grad' if eval_gradient else '')) 150 | if not hasattr(self, cache_name): 151 | pure_kernel_fn = self.pure_kernel_fn 152 | 153 | if eval_gradient: 154 | kernel_matrix_fn = jit(partial( 155 | self._kernel_matrix_with_gradients, 156 | pure_kernel_fn 157 | )) 158 | else: 159 | kernel_matrix_fn = jit(partial( 160 | self._kernel_matrix_without_gradients, 161 | pure_kernel_fn 162 | )) 163 | setattr(self, cache_name, kernel_matrix_fn) 164 | 165 | return getattr(self, cache_name) 166 | 167 | def __call__(self, X, Y=None, eval_gradient=False): 168 | """Build kernel matrix from input data X and Y. 169 | 170 | Optionally also compute the gradient with respect to the parameters.
171 | """ 172 | X = np.asarray(X) 173 | if Y is not None: 174 | Y = np.asarray(Y) 175 | return self.get_kernel_matrix_fn(eval_gradient)(self.theta, X, Y) 176 | 177 | def diag(self, X): 178 | """Get diagonal of kernel matrix.""" 179 | return vmap(lambda x: self.pure_kernel_fn(self.theta, x, x))(X) 180 | 181 | def __add__(self, b): 182 | """Add kernel to constant or other kernel.""" 183 | if not isinstance(b, Kernel): 184 | return Sum(self, ConstantKernel(b)) 185 | return Sum(self, b) 186 | 187 | def __radd__(self, b): 188 | """Add kernel to constant or other kernel.""" 189 | if not isinstance(b, Kernel): 190 | return Sum(ConstantKernel(b), self) 191 | return Sum(b, self) 192 | 193 | def __mul__(self, b): 194 | """Mulitply kernel with constant or other kernel.""" 195 | if not isinstance(b, Kernel): 196 | return Product(self, ConstantKernel(b)) 197 | return Product(self, b) 198 | 199 | def __rmul__(self, b): 200 | """Mulitply kernel with constant or other kernel.""" 201 | if not isinstance(b, Kernel): 202 | return Product(ConstantKernel(b), self) 203 | return Product(b, self) 204 | 205 | def __pow__(self, b): 206 | """Exponentiate kernel.""" 207 | return Exponentiation(self, b) 208 | 209 | 210 | class NormalizedKernel(NormalizedKernelMixin, Kernel): 211 | """Kernel wrapper which computes a normalized version of the kernel.""" 212 | 213 | def __init__(self, kernel): 214 | self.kernel = kernel 215 | 216 | @property 217 | def pure_kernel_fn(self): 218 | """Not really needed in this particular case.""" 219 | kernel_fn = self.kernel.pure_kernel_fn 220 | 221 | def wrapped(theta, x, y): 222 | K_xy = kernel_fn(theta, x, y) 223 | K_xx = kernel_fn(theta, x, x) 224 | K_yy = kernel_fn(theta, y, y) 225 | return K_xy / np.sqrt(K_xx * K_yy) 226 | return wrapped 227 | 228 | def get_kernel_matrix_fn(self, eval_gradient): 229 | """Return pure function for computing kernel matrix and gradients. 230 | 231 | We do some internal caching in order to avoid recompiling the resulting 232 | function. Further, we compute the output of the normalized kernel 233 | matrix at this stage in order to avoid recomputing the self similarity 234 | on each kernel evaluation. 235 | 236 | Returned function has the signature: `f(theta, X, Y)` 237 | """ 238 | if config_value('NORMALIZED_KERNEL_PUSH_DOWN'): 239 | # In this case compute the normalization for each instance 240 | # inside the kernel fn. This recomputes the self similarities many 241 | # times, but does not require keeping multiple tensors of the size 242 | # of the kernel matrix in memory for computing normalization. This 243 | # is particularly the case when computing gradients with respect to 244 | # kernel parameters. 
245 | return super().get_kernel_matrix_fn(eval_gradient) 246 | 247 | cache_name = '_kernel_matrix_fn' + ('_grad' if eval_gradient else '') 248 | if not hasattr(self, cache_name): 249 | pure_kernel_fn = self.kernel.pure_kernel_fn 250 | if eval_gradient: 251 | kernel_matrix_with_grad = \ 252 | self._kernel_matrix_with_gradients 253 | 254 | def wrapped(theta, X, Y): 255 | """Compute normalized kernel matrix and do chain rule.""" 256 | kmatrix, grads = kernel_matrix_with_grad( 257 | pure_kernel_fn, theta, X, Y) 258 | if Y is None: 259 | diag = np.diag(kmatrix) 260 | grad_diag_indices = np.diag_indices(kmatrix.shape[0]) 261 | diag_grad = grads[grad_diag_indices] 262 | normalizer = np.sqrt(diag[:, None] * diag[None, :]) 263 | # Add dimensions for broadcasting 264 | K_xx = diag[:, None, None] 265 | K_yy = diag[None, :, None] 266 | K_xx_grad = diag_grad[:, None, :] 267 | K_yy_grad = diag_grad[None, :, :] 268 | 269 | # Do the chain rule 270 | grads = ( 271 | ( 272 | 2 * K_xx * K_yy * grads - 273 | kmatrix[..., None] * (K_xx_grad * K_yy + K_xx * K_yy_grad) 274 | ) / (2 * (K_xx * K_yy) ** (3/2)) 275 | ) 276 | return kmatrix / normalizer, grads 277 | else: 278 | # If Y is given we need to compute the self 279 | # similarity of each instance separately 280 | kernel_fn_with_grad = partial( 281 | value_and_grad(pure_kernel_fn), theta) 282 | K_xx, K_xx_grad = vmap( 283 | lambda x: kernel_fn_with_grad(x, x))(X) 284 | K_yy, K_yy_grad = vmap( 285 | lambda y: kernel_fn_with_grad(y, y))(Y) 286 | normalizer = np.sqrt(K_xx[:, None] * K_yy[None, :]) 287 | # Add dimensions for broadcasting 288 | K_xx = K_xx[:, None, None] 289 | K_yy = K_yy[None, :, None] 290 | K_xx_grad = K_xx_grad[:, None, :] 291 | K_yy_grad = K_yy_grad[None, :, :] 292 | 293 | # d/dw(k(x, y, w)/sqrt(k(x, x, w) k(y, y, w))) = (2 294 | # k(x, x, w) k(y, y, w) k^(0, 0, 1)(x, y, w) - k(x, y, 295 | # w) (k^(0, 0, 1)(x, x, w) k(y, y, w) + k(x, x, w) 296 | # k^(0, 0, 1)(y, y, w)))/(2 (k(x, x, w) k(y, y, 297 | # w))^(3/2)) 298 | grads = ( 299 | ( 300 | 2 * K_xx * K_yy * grads - 301 | kmatrix[..., None] * (K_xx_grad * K_yy + K_xx * K_yy_grad) 302 | ) / 303 | (2 * (K_xx * K_yy) ** (3/2)) 304 | ) 305 | 306 | return kmatrix / normalizer, grads 307 | 308 | kernel_matrix_fn = jit(wrapped) 309 | else: 310 | kernel_matrix = self._kernel_matrix_without_gradients 311 | 312 | def wrapped(theta, X, Y): 313 | """Compute normalized kernel matrix.""" 314 | kmatrix = kernel_matrix(pure_kernel_fn, theta, X, Y) 315 | if Y is None: 316 | diag = np.diag(kmatrix) 317 | normalizer = np.sqrt(diag[:, None] * diag[None, :]) 318 | return kmatrix / normalizer 319 | else: 320 | # If Y is given we need to compute the self 321 | # similarity of each instance separately 322 | K_xx = vmap(lambda x: pure_kernel_fn(theta, x, x))(X) 323 | K_yy = vmap(lambda y: pure_kernel_fn(theta, y, y))(Y) 324 | normalizer = np.sqrt(K_xx[:, None] * K_yy[None, :]) 325 | return kmatrix / normalizer 326 | 327 | kernel_matrix_fn = jit(wrapped) 328 | setattr(self, cache_name, kernel_matrix_fn) 329 | return getattr(self, cache_name) 330 | 331 | def is_stationary(self): 332 | """Whether kernel is stationary.""" 333 | return self.kernel.is_stationary() 334 | 335 | def get_params(self, deep=True): 336 | """Get parameters of this kernel. 337 | 338 | Parameters: 339 | deep : boolean, optional 340 | If True, will return the parameters for this estimator and 341 | contained subobjects that are estimators. 342 | 343 | Returns: 344 | params : mapping of string to any 345 | Parameter names mapped to their values.
346 | 347 | """ 348 | params = dict(kernel=self.kernel) 349 | if deep: 350 | deep_items = self.kernel.get_params().items() 351 | params.update(('kernel__' + k, val) for k, val in deep_items) 352 | return params 353 | 354 | @property 355 | def hyperparameters(self): 356 | """Return a list of all hyperparameter.""" 357 | r = [] 358 | for hyperparameter in self.kernel.hyperparameters: 359 | r.append(Hyperparameter("kernel__" + hyperparameter.name, 360 | hyperparameter.value_type, 361 | hyperparameter.bounds, 362 | hyperparameter.n_elements)) 363 | return r 364 | 365 | @property 366 | def theta(self): 367 | """Return the (flattened, log-transformed) non-fixed hyperparameters. 368 | 369 | Note that theta are typically the log-transformed values of the 370 | kernel's hyperparameters as this representation of the search space 371 | is more amenable for hyperparameter search, as hyperparameters like 372 | length-scales naturally live on a log-scale. 373 | 374 | Returns: 375 | theta : array, shape (n_dims,) 376 | The non-fixed, log-transformed hyperparameters of the kernel 377 | 378 | """ 379 | return self.kernel.theta 380 | 381 | @theta.setter 382 | def theta(self, theta): 383 | """Set the (flattened, log-transformed) non-fixed hyperparameters. 384 | 385 | Parameters: 386 | theta : array, shape (n_dims,) 387 | The non-fixed, log-transformed hyperparameters of the kernel 388 | 389 | """ 390 | self.kernel.theta = theta 391 | 392 | @property 393 | def bounds(self): 394 | """Return the log-transformed bounds on the theta. 395 | 396 | Returns: 397 | bounds : array, shape (n_dims, 2) 398 | The log-transformed bounds on the kernel's hyperparameters 399 | theta 400 | 401 | """ 402 | return self.kernel.bounds 403 | 404 | def __eq__(self, b): 405 | """Whether two instances are considered equal.""" 406 | if type(self) != type(b): 407 | return False 408 | return self.kernel == b.kernel 409 | 410 | 411 | class KernelOperator(Kernel): 412 | """Base class for all kernel operators.""" 413 | 414 | def __init__(self, k1, k2): 415 | self.k1 = k1 416 | self.k2 = k2 417 | 418 | def get_params(self, deep=True): 419 | """Get parameters of this kernel. 420 | 421 | Parameters: 422 | deep : boolean, optional 423 | If True, will return the parameters for this estimator and 424 | contained subobjects that are estimators. 425 | 426 | Returns: 427 | params : mapping of string to any 428 | Parameter names mapped to their values. 429 | 430 | """ 431 | params = dict(k1=self.k1, k2=self.k2) 432 | if deep: 433 | deep_items = self.k1.get_params().items() 434 | params.update(('k1__' + k, val) for k, val in deep_items) 435 | deep_items = self.k2.get_params().items() 436 | params.update(('k2__' + k, val) for k, val in deep_items) 437 | 438 | return params 439 | 440 | @property 441 | def hyperparameters(self): 442 | """Return a list of all hyperparameter.""" 443 | r = [Hyperparameter("k1__" + hyperparameter.name, 444 | hyperparameter.value_type, 445 | hyperparameter.bounds, hyperparameter.n_elements) 446 | for hyperparameter in self.k1.hyperparameters] 447 | 448 | for hyperparameter in self.k2.hyperparameters: 449 | r.append(Hyperparameter("k2__" + hyperparameter.name, 450 | hyperparameter.value_type, 451 | hyperparameter.bounds, 452 | hyperparameter.n_elements)) 453 | return r 454 | 455 | @property 456 | def theta(self): 457 | """Return the (flattened, log-transformed) non-fixed hyperparameters. 
458 | 459 | Note that theta are typically the log-transformed values of the 460 | kernel's hyperparameters as this representation of the search space 461 | is more amenable for hyperparameter search, as hyperparameters like 462 | length-scales naturally live on a log-scale. 463 | 464 | Returns: 465 | theta : array, shape (n_dims,) 466 | The non-fixed, log-transformed hyperparameters of the kernel 467 | 468 | """ 469 | return np.append(self.k1.theta, self.k2.theta) 470 | 471 | @theta.setter 472 | def theta(self, theta): 473 | """Set the (flattened, log-transformed) non-fixed hyperparameters. 474 | 475 | Parameters: 476 | theta : array, shape (n_dims,) 477 | The non-fixed, log-transformed hyperparameters of the kernel 478 | 479 | """ 480 | k1_dims = self.k1.n_dims 481 | self.k1.theta = theta[:k1_dims] 482 | self.k2.theta = theta[k1_dims:] 483 | 484 | @property 485 | def bounds(self): 486 | """Return the log-transformed bounds on the theta. 487 | 488 | Returns: 489 | bounds : array, shape (n_dims, 2) 490 | The log-transformed bounds on the kernel's hyperparameters 491 | theta 492 | 493 | """ 494 | if self.k1.bounds.size == 0: 495 | return self.k2.bounds 496 | if self.k2.bounds.size == 0: 497 | return self.k1.bounds 498 | return np.vstack((self.k1.bounds, self.k2.bounds)) 499 | 500 | def __eq__(self, b): 501 | """Check for equality between kernels.""" 502 | if type(self) != type(b): 503 | return False 504 | return (self.k1 == b.k1 and self.k2 == b.k2) \ 505 | or (self.k1 == b.k2 and self.k2 == b.k1) 506 | 507 | def is_stationary(self): 508 | """Return whether the kernel is stationary.""" 509 | return self.k1.is_stationary() and self.k2.is_stationary() 510 | 511 | @property 512 | def requires_vector_input(self): 513 | """Return whether the kernel is stationary. 
""" 514 | return (self.k1.requires_vector_input or 515 | self.k2.requires_vector_input) 516 | 517 | 518 | class Sum(KernelOperator): 519 | """Sum of two kernels.""" 520 | 521 | @property 522 | def pure_kernel_fn(self): 523 | """Kernel function of the two added kernels.""" 524 | k1_fn = self.k1.pure_kernel_fn 525 | k2_fn = self.k2.pure_kernel_fn 526 | k1_dims = self.k1.n_dims 527 | 528 | def kernel_fn(theta, x, y): 529 | return k1_fn(theta[:k1_dims], x, y) + k2_fn(theta[k1_dims:], x, y) 530 | 531 | return kernel_fn 532 | 533 | 534 | class Product(KernelOperator): 535 | """Product of two kernels.""" 536 | 537 | @property 538 | def pure_kernel_fn(self): 539 | """Kernel function of the two added kernels.""" 540 | k1_fn = self.k1.pure_kernel_fn 541 | k2_fn = self.k2.pure_kernel_fn 542 | k1_dims = self.k1.n_dims 543 | 544 | def kernel_fn(theta, x, y): 545 | return k1_fn(theta[:k1_dims], x, y) * k2_fn(theta[k1_dims:], x, y) 546 | 547 | return kernel_fn 548 | 549 | def __repr__(self): 550 | """Return representation of kernel.""" 551 | return "{0} * {1}".format(self.k1, self.k2) 552 | 553 | 554 | class ConstantKernel(StationaryKernelMixin, Kernel): 555 | """Kernel which always returns a constant.""" 556 | 557 | def __init__(self, constant_value=1.0, constant_value_bounds=(1e-5, 1e5)): 558 | """Init kernel with constant_value.""" 559 | self.constant_value = constant_value 560 | self.constant_value_bounds = constant_value_bounds 561 | 562 | @property 563 | def hyperparameter_constant_value(self): 564 | return Hyperparameter( 565 | "constant_value", "numeric", self.constant_value_bounds) 566 | 567 | @property 568 | def pure_kernel_fn(self): 569 | """Return the kernel fn.""" 570 | if self.hyperparameter_constant_value.fixed: 571 | value = self.constant_value 572 | 573 | def kernel_fn(theta, x, y): 574 | return value 575 | else: 576 | def kernel_fn(theta, x, y): 577 | return np.exp(theta[0]) # Theta is in log domain and array 578 | 579 | return kernel_fn 580 | 581 | 582 | class Exponentiation(Kernel): 583 | """Exponentiation of a kernel.""" 584 | 585 | def __init__(self, kernel: Kernel, exponent): 586 | """Init kernel exponentiation of kernel with exponent.""" 587 | self.kernel = kernel 588 | self.exponent = exponent 589 | 590 | def is_stationary(self): 591 | """Whether kernel is stationary.""" 592 | return self.kernel.is_stationary() 593 | 594 | def get_params(self, deep=True): 595 | """Get parameters of this kernel. 596 | 597 | Parameters: 598 | deep : boolean, optional 599 | If True, will return the parameters for this estimator and 600 | contained subobjects that are estimators. 601 | 602 | Returns: 603 | params : mapping of string to any 604 | Parameter names mapped to their values. 605 | 606 | """ 607 | params = dict(kernel=self.kernel, exponent=self.exponent) 608 | if deep: 609 | deep_items = self.kernel.get_params().items() 610 | params.update(('kernel__' + k, val) for k, val in deep_items) 611 | return params 612 | 613 | @property 614 | def hyperparameters(self): 615 | """Return a list of all hyperparameter.""" 616 | r = [] 617 | for hyperparameter in self.kernel.hyperparameters: 618 | r.append(Hyperparameter("kernel__" + hyperparameter.name, 619 | hyperparameter.value_type, 620 | hyperparameter.bounds, 621 | hyperparameter.n_elements)) 622 | return r 623 | 624 | @property 625 | def theta(self): 626 | """Return the (flattened, log-transformed) non-fixed hyperparameters. 
627 | 628 | Note that theta are typically the log-transformed values of the 629 | kernel's hyperparameters as this representation of the search space 630 | is more amenable for hyperparameter search, as hyperparameters like 631 | length-scales naturally live on a log-scale. 632 | 633 | Returns: 634 | theta : array, shape (n_dims,) 635 | The non-fixed, log-transformed hyperparameters of the kernel 636 | 637 | """ 638 | return self.kernel.theta 639 | 640 | @theta.setter 641 | def theta(self, theta): 642 | """Set the (flattened, log-transformed) non-fixed hyperparameters. 643 | 644 | Parameters: 645 | theta : array, shape (n_dims,) 646 | The non-fixed, log-transformed hyperparameters of the kernel 647 | 648 | """ 649 | self.kernel.theta = theta 650 | 651 | @property 652 | def bounds(self): 653 | """Return the log-transformed bounds on the theta. 654 | 655 | Returns: 656 | bounds : array, shape (n_dims, 2) 657 | The log-transformed bounds on the kernel's hyperparameters 658 | theta 659 | 660 | """ 661 | return self.kernel.bounds 662 | 663 | def __eq__(self, b): 664 | """Whether two instances are considered equal.""" 665 | if type(self) != type(b): 666 | return False 667 | return (self.kernel == b.kernel and self.exponent == b.exponent) 668 | 669 | @property 670 | def pure_kernel_fn(self): 671 | """Pure kernel fn of exponentiated kernel.""" 672 | exponent = self.exponent 673 | kernel_fn = self.kernel.pure_kernel_fn 674 | 675 | def exp_kernel_fn(theta, x, y): 676 | return np.power(kernel_fn(theta, x, y), exponent) 677 | return exp_kernel_fn 678 | 679 | 680 | class RBF(StationaryKernelMixin, NormalizedKernelMixin, Kernel): 681 | """RBF Kernel.""" 682 | 683 | def __init__(self, length_scale=1.0, length_scale_bounds=(1e-5, 1e5)): 684 | """Initialize RBF kernel with length_scale and bounds.""" 685 | self.length_scale = length_scale 686 | self.length_scale_bounds = length_scale_bounds 687 | 688 | @property 689 | def anisotropic(self): 690 | return np.iterable(self.length_scale) and len(self.length_scale) > 1 691 | 692 | @property 693 | def hyperparameter_length_scale(self): 694 | if self.anisotropic: 695 | return Hyperparameter("length_scale", "numeric", 696 | self.length_scale_bounds, 697 | len(self.length_scale)) 698 | return Hyperparameter( 699 | "length_scale", "numeric", self.length_scale_bounds) 700 | 701 | @property 702 | def pure_kernel_fn(self): 703 | """Pure kernel fn of RBF kernel.""" 704 | if self.hyperparameter_length_scale.fixed: 705 | length_scale = self.length_scale 706 | if np.iterable(length_scale): 707 | # handle case when length scale is fixed and provided as list 708 | length_scale = np.asarray(length_scale) 709 | 710 | def kernel_fn(theta, x, y): 711 | # as we get a log-transformed theta as input, we need to transform 712 | # it back. 713 | diff = (x - y) / length_scale 714 | d = np.sum(diff ** 2, axis=-1) 715 | return np.exp(-0.5 * d) 716 | else: 717 | def kernel_fn(theta, x, y): 718 | # as we get a log-transformed theta as input, we need to transform 719 | # it back. 
720 | diff = (x - y) / np.exp(theta) 721 | d = np.sum(diff ** 2, axis=-1) 722 | return np.exp(-0.5 * d) 723 | return kernel_fn 724 | 725 | def __repr__(self): 726 | if self.anisotropic: 727 | return "{0}(length_scale=[{1}])".format( 728 | self.__class__.__name__, ", ".join(map("{0:.3g}".format, 729 | self.length_scale))) 730 | else: # isotropic 731 | return "{0}(length_scale={1:.3g})".format( 732 | self.__class__.__name__, np.ravel(self.length_scale)[0]) 733 | -------------------------------------------------------------------------------- /sklearn_jax_kernels/config.py: -------------------------------------------------------------------------------- 1 | """Configuration values for jax kernels.""" 2 | 3 | MEMORY_SAVING_CONFIGS = [ 4 | 'KERNEL_MATRIX_USE_LOOP', 5 | 'NORMALIZED_KERNEL_PUSH_DOWN' 6 | ] 7 | 8 | SAVE_MEMORY = False 9 | """bool: Apply multiple tricks to reduce memory overhead. 10 | 11 | This could slow down computation significantly but may allow the direct 12 | application to larger problems when running into GPU memory issues. 13 | 14 | This variable activates all configs in `MEMORY_SAVING_CONFIGS`. 15 | """ 16 | 17 | KERNEL_MATRIX_USE_LOOP = True 18 | """bool: Use loop when constructing kernel matrix 19 | 20 | Compute kernel matrix using a dynamically unrolled loop instead of 21 | vectorization. Reduces memory overhead as not all intermediate tensors have to 22 | be allocated at once. 23 | """ 24 | 25 | NORMALIZED_KERNEL_PUSH_DOWN = False 26 | """bool: Push down the computation of normalization factor 27 | 28 | Moves the computation of normalization to each individual instance. This can 29 | reduce memory overhead, especially when computing gradients with respect to 30 | kernel parameters at the cost of higher computation time (approx. 3 fold). In 31 | this case k(x,x) and k(y,y) needed to be recomputed for each instance. 32 | """ 33 | 34 | 35 | def config_value(config_var_name): 36 | """Get value of a config variable. 37 | 38 | Also accounts for config variables that activate other variables such as 39 | `SAVE_MEMORY`. 
40 | """ 41 | if config_var_name in MEMORY_SAVING_CONFIGS and SAVE_MEMORY: 42 | return True 43 | return globals()[config_var_name] 44 | -------------------------------------------------------------------------------- /sklearn_jax_kernels/gpc.py: -------------------------------------------------------------------------------- 1 | """Subclass of sklearn Gaussian process classifier using JAX.""" 2 | from functools import partial 3 | from sklearn.gaussian_process import GaussianProcessClassifier as GPC 4 | from sklearn.gaussian_process._gpc import ( 5 | _BinaryGaussianProcessClassifierLaplace) 6 | from sklearn.multiclass import OneVsRestClassifier, OneVsOneClassifier 7 | from sklearn.utils.validation import check_X_y 8 | 9 | import numpy 10 | 11 | import jax.numpy as np 12 | from jax.scipy.linalg import cholesky, cho_solve, solve 13 | from jax.scipy.special import expit 14 | from jax import ops 15 | from jax import jit 16 | 17 | 18 | @partial(jit, static_argnums=0) 19 | def _newton_iteration(y_train, K, f): 20 | pi = expit(f) 21 | W = pi * (1 - pi) 22 | # Line 5 23 | W_sr = np.sqrt(W) 24 | W_sr_K = W_sr[:, np.newaxis] * K 25 | B = np.eye(W.shape[0]) + W_sr_K * W_sr 26 | L = cholesky(B, lower=True) 27 | # Line 6 28 | b = W * f + (y_train - pi) 29 | # Line 7 30 | a = b - W_sr * cho_solve((L, True), W_sr_K.dot(b)) 31 | # Line 8 32 | f = K.dot(a) 33 | 34 | # Line 10: Compute log marginal likelihood in loop and use as 35 | # convergence criterion 36 | lml = -0.5 * a.T.dot(f) \ 37 | - np.log1p(np.exp(-(y_train * 2 - 1) * f)).sum() \ 38 | - np.log(np.diag(L)).sum() 39 | return lml, f, (pi, W_sr, L, b, a) 40 | 41 | 42 | class BinaryGaussianProcessClassifier(_BinaryGaussianProcessClassifierLaplace): 43 | def _posterior_mode(self, K, return_temporaries=False): 44 | """Mode-finding for binary Laplace GPC and fixed kernel. 45 | This approximates the posterior of the latent function values for given 46 | inputs and target observations with a Gaussian approximation and uses 47 | Newton's iteration to find the mode of this approximation. 48 | """ 49 | # Based on Algorithm 3.1 of GPML 50 | 51 | # If warm_start are enabled, we reuse the last solution for the 52 | # posterior mode as initialization; otherwise, we initialize with 0 53 | if self.warm_start and hasattr(self, "f_cached") \ 54 | and self.f_cached.shape == self.y_train_.shape: 55 | f = self.f_cached 56 | else: 57 | f = np.zeros_like(self.y_train_, dtype=np.float32) 58 | 59 | # Use Newton's iteration method to find mode of Laplace approximation 60 | log_marginal_likelihood = -np.inf 61 | newton_iteration = partial(_newton_iteration, self.y_train_, K) 62 | 63 | for _ in range(self.max_iter_predict): 64 | lml, f, (pi, W_sr, L, b, a) = newton_iteration(f) 65 | # Check if we have converged (log marginal likelihood does 66 | # not decrease) 67 | # XXX: more complex convergence criterion 68 | if lml - log_marginal_likelihood < 1e-10: 69 | break 70 | log_marginal_likelihood = lml 71 | 72 | self.f_cached = f # Remember solution for later warm-starts 73 | if return_temporaries: 74 | return log_marginal_likelihood, (pi, W_sr, L, b, a) 75 | else: 76 | return log_marginal_likelihood 77 | 78 | def log_marginal_likelihood(self, theta=None, eval_gradient=False, 79 | clone_kernel=False): 80 | """Returns log-marginal likelihood of theta for training data. 81 | 82 | Parameters 83 | ---------- 84 | theta : array-like of shape (n_kernel_params,) or None 85 | Kernel hyperparameters for which the log-marginal likelihood is 86 | evaluated. 
If None, the precomputed log_marginal_likelihood 87 | of ``self.kernel_.theta`` is returned. 88 | eval_gradient : bool, default: False 89 | If True, the gradient of the log-marginal likelihood with respect 90 | to the kernel hyperparameters at position theta is returned 91 | additionally. If True, theta must not be None. 92 | clone_kernel : bool, default=True 93 | If True, the kernel attribute is copied. If False, the kernel 94 | attribute is modified, but may result in a performance improvement. 95 | Returns 96 | ------- 97 | log_likelihood : float 98 | Log-marginal likelihood of theta for training data. 99 | log_likelihood_gradient : array, shape = (n_kernel_params,), optional 100 | Gradient of the log-marginal likelihood with respect to the kernel 101 | hyperparameters at position theta. 102 | Only returned when eval_gradient is True. 103 | """ 104 | 105 | if theta is None: 106 | if eval_gradient: 107 | raise ValueError( 108 | "Gradient can only be evaluated for theta!=None") 109 | return self.log_marginal_likelihood_value_ 110 | 111 | kernel_matrix_fn = self.kernel_.get_kernel_matrix_fn(eval_gradient) 112 | 113 | if eval_gradient: 114 | K, K_gradient = kernel_matrix_fn(theta, self.X_train_, None) 115 | else: 116 | K = kernel_matrix_fn(theta, self.X_train_, None) 117 | 118 | # Compute log-marginal-likelihood Z and also store some temporaries 119 | # which can be reused for computing Z's gradient 120 | Z, (pi, W_sr, L, b, a) = \ 121 | self._posterior_mode(K, return_temporaries=True) 122 | 123 | if not eval_gradient: 124 | return Z 125 | 126 | # Compute gradient based on Algorithm 5.1 of GPML 127 | 128 | d_Z = np.empty(theta.shape[0]) 129 | # XXX: Get rid of the np.diag() in the next line 130 | R = W_sr[:, np.newaxis] * cho_solve((L, True), np.diag(W_sr)) # Line 7 131 | C = solve(L, W_sr[:, np.newaxis] * K) # Line 8 132 | # Line 9: (use einsum to compute np.diag(C.T.dot(C)))) 133 | s_2 = -0.5 * (np.diag(K) - np.einsum('ij, ij -> j', C, C)) \ 134 | * (pi * (1 - pi) * (1 - 2 * pi)) # third derivative 135 | 136 | for j in range(d_Z.shape[0]): 137 | C = K_gradient[:, :, j] # Line 11 138 | # Line 12: (R.T.ravel().dot(C.ravel()) = np.trace(R.dot(C))) 139 | s_1 = .5 * a.T.dot(C).dot(a) - .5 * R.T.ravel().dot(C.ravel()) 140 | 141 | b = C.dot(self.y_train_ - pi) # Line 13 142 | s_3 = b - K.dot(R.dot(b)) # Line 14 143 | 144 | d_Z = ops.index_update(d_Z, j, s_1 + s_2.T.dot(s_3)) # Line 15 145 | 146 | return ( 147 | numpy.asarray(Z, dtype=numpy.float64), 148 | numpy.asarray(d_Z, dtype=numpy.float64) 149 | ) 150 | 151 | 152 | class GaussianProcessClassifier(GPC): 153 | def fit(self, X, y): 154 | """Fit Gaussian process classification model 155 | Parameters 156 | ---------- 157 | X : sequence of length n_samples 158 | Feature vectors or other representations of training data. 159 | Could either be array-like with shape = (n_samples, n_features) 160 | or a list of objects. 161 | y : array-like of shape (n_samples,) 162 | Target values, must be binary 163 | Returns 164 | ------- 165 | self : returns an instance of self. 
166 | """ 167 | if self.kernel is None or self.kernel.requires_vector_input: 168 | X, y = check_X_y(X, y, multi_output=False, 169 | ensure_2d=True, dtype="numeric") 170 | else: 171 | X, y = check_X_y(X, y, multi_output=False, 172 | ensure_2d=False, dtype=None) 173 | 174 | self.base_estimator_ = BinaryGaussianProcessClassifier( 175 | self.kernel, self.optimizer, self.n_restarts_optimizer, 176 | self.max_iter_predict, self.warm_start, self.copy_X_train, 177 | self.random_state) 178 | 179 | self.classes_ = numpy.unique(y) 180 | self.n_classes_ = self.classes_.size 181 | if self.n_classes_ == 1: 182 | raise ValueError("GaussianProcessClassifier requires 2 or more " 183 | "distinct classes; got %d class (only class %s " 184 | "is present)" 185 | % (self.n_classes_, self.classes_[0])) 186 | if self.n_classes_ > 2: 187 | if self.multi_class == "one_vs_rest": 188 | self.base_estimator_ = \ 189 | OneVsRestClassifier(self.base_estimator_, 190 | n_jobs=self.n_jobs) 191 | elif self.multi_class == "one_vs_one": 192 | self.base_estimator_ = \ 193 | OneVsOneClassifier(self.base_estimator_, 194 | n_jobs=self.n_jobs) 195 | else: 196 | raise ValueError("Unknown multi-class mode %s" 197 | % self.multi_class) 198 | 199 | self.base_estimator_.fit(X, y) 200 | 201 | if self.n_classes_ > 2: 202 | self.log_marginal_likelihood_value_ = numpy.mean( 203 | [estimator.log_marginal_likelihood() 204 | for estimator in self.base_estimator_.estimators_]) 205 | else: 206 | self.log_marginal_likelihood_value_ = \ 207 | self.base_estimator_.log_marginal_likelihood() 208 | 209 | return self 210 | -------------------------------------------------------------------------------- /sklearn_jax_kernels/structured/string_utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for usage with strings.""" 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.base import BaseEstimator, TransformerMixin 6 | 7 | 8 | class AsciiBytesTransformer(TransformerMixin): 9 | """Convert python strings into ascii byte arrays. 10 | 11 | This allows them to be used with jax. 
12 | """ 13 | 14 | def fit(self, X, y=None): 15 | return self 16 | 17 | def transform(self, X): 18 | if len(np.unique([len(x) for x in X])) == 1: 19 | # All strings are of same length 20 | return np.asarray([bytearray(x, 'ascii', 'strict') for x in X]) 21 | 22 | # Varying length 23 | return [np.asarray(bytearray(a, 'ascii', 'strict')) for a in X] 24 | 25 | def inverse_transform(self, X): 26 | return [str(bytes(a), 'ascii', 'strict') for a in X] 27 | 28 | 29 | class CompressAlphabetTransformer(TransformerMixin, BaseEstimator): 30 | """Determines which chars are used and maps them to values.""" 31 | 32 | def __init__(self, output_dtype=np.uint8): 33 | self.output_dtype = output_dtype 34 | self._unique_chars = None 35 | self._mapping = None 36 | 37 | def fit(self, X, y=None): 38 | single_str = ''.join(X) 39 | self._unique_chars = list(set(single_str)) 40 | self._unique_chars.sort() 41 | n_alphabet = len(self._unique_chars) 42 | self._mapping = defaultdict( 43 | lambda: n_alphabet, 44 | [(char, i) for i, char in enumerate(self._unique_chars)] 45 | ) 46 | return self 47 | 48 | def transform(self, X, y=None): 49 | X_transf = [ 50 | [self._mapping[char] for char in x] 51 | for x in X 52 | ] 53 | return np.asarray(X_transf, dtype=self.output_dtype) 54 | 55 | 56 | def get_translation_table(input_symbols, output_symbols, mapping, 57 | dtype=np.uint8): 58 | """Build a translation table for characters using mapping. 59 | 60 | Returns a array, where index i corresponding to input_symbol contains the 61 | matching value from output_symbols. 62 | 63 | Example: 64 | >>> get_translation_table( 65 | input_symbols=['a', 'b', 'c'], 66 | output_symbols=['c', 'b', 'a'], 67 | mapping = {'a': 0, 'b': 1, 'c': 2} 68 | ) 69 | array([2, 1, 0]) 70 | 71 | """ 72 | input_transf = np.array([mapping[symb] for symb in input_symbols], dtype=dtype) 73 | assert len(np.unique(input_transf)) == len(input_symbols) 74 | output_transf = np.array([mapping[symb] for symb in output_symbols], dtype=dtype) 75 | input_order = np.argsort(input_transf) 76 | return output_transf[input_order] 77 | 78 | 79 | 80 | class NGramTransformer(TransformerMixin, BaseEstimator): 81 | def __init__(self, ngram_length): 82 | self.ngram_length = ngram_length 83 | 84 | def fit(self, X, y=None): 85 | return self 86 | 87 | def transform(self, X): 88 | n = X.shape[1] 89 | ngram_slices = [ 90 | X[:, i:n+1-self.ngram_length+i] 91 | for i in range(0, self.ngram_length) 92 | ] 93 | return np.stack(ngram_slices, axis=2) 94 | -------------------------------------------------------------------------------- /sklearn_jax_kernels/structured/strings.py: -------------------------------------------------------------------------------- 1 | """Implementation of string kernels.""" 2 | from functools import partial 3 | from jax import device_put, vmap 4 | import jax.numpy as np 5 | from jax.lax import dynamic_slice_in_dim 6 | from jax.experimental import loops 7 | 8 | from sklearn_jax_kernels import Kernel 9 | from sklearn.gaussian_process.kernels import GenericKernelMixin, Hyperparameter 10 | 11 | 12 | class SpectrumKernel(GenericKernelMixin, Kernel): 13 | """Spectrum string kernel. 
14 | 15 | As described in: 16 | ``` 17 | @incollection{leslie2001spectrum, 18 | title={ 19 | The spectrum kernel: A string kernel for SVM protein classification}, 20 | author={ 21 | Leslie, Christina and Eskin, Eleazar and Noble, William Stafford}, 22 | booktitle={Biocomputing 2002}, 23 | pages={564--575}, 24 | year={2001}, 25 | publisher={World Scientific} 26 | } 27 | ``` 28 | """ 29 | 30 | def __init__(self, n_gram_length): 31 | """Spectrum kernel on strings. 32 | 33 | Assumes input was transformed via `AsciiBytesTransformer` or a similar 34 | transformation into a jax-compatible datatype. 35 | 36 | Parameters: 37 | n_gram_length: Length of ngrams to compare. If `None` it is assumed 38 | that the input is 2d where the final axis is the n_grams. 39 | 40 | """ 41 | self.n_gram_length = n_gram_length 42 | 43 | @property 44 | def pure_kernel_fn(self): 45 | """Return the pure function for computing the kernel.""" 46 | n_gram_length = self.n_gram_length 47 | 48 | def kmer_kernel_fn(theta, kmers1, kmers2): 49 | same_kmer = np.all(kmers1[None, :, :] == kmers2[:, None, :], axis=2) 50 | return np.sum(same_kmer) 51 | 52 | if n_gram_length is None: 53 | # Assume input is kmer transformed 54 | kernel_fn = kmer_kernel_fn 55 | else: 56 | def kernel_fn(theta, string1, string2): 57 | def make_to_kmers(string): 58 | ngram_slices = [ 59 | string[i:len(string)+1-n_gram_length+i] 60 | for i in range(0, n_gram_length) 61 | ] 62 | return np.stack(ngram_slices, axis=1) 63 | kmers1 = make_to_kmers(string1) 64 | kmers2 = make_to_kmers(string2) 65 | return kmer_kernel_fn(theta, kmers1, kmers2) 66 | return kernel_fn 67 | 68 | @property 69 | def hyperparameters(self): 70 | """Return a list of all hyperparameters.""" 71 | return [] 72 | 73 | @property 74 | def theta(self): 75 | """Return the (flattened, log-transformed) non-fixed hyperparameters. 76 | 77 | Note that theta are typically the log-transformed values of the 78 | kernel's hyperparameters as this representation of the search space 79 | is more amenable for hyperparameter search, as hyperparameters like 80 | length-scales naturally live on a log-scale. 81 | 82 | Returns: 83 | theta : array, shape (n_dims,) 84 | The non-fixed, log-transformed hyperparameters of the kernel 85 | 86 | """ 87 | return np.empty((0,)) 88 | 89 | @theta.setter 90 | def theta(self, theta): 91 | """Set the (flattened, log-transformed) non-fixed hyperparameters. 92 | 93 | Parameters: 94 | theta : array, shape (n_dims,) 95 | The non-fixed, log-transformed hyperparameters of the kernel 96 | 97 | """ 98 | 99 | @property 100 | def bounds(self): 101 | """Return the log-transformed bounds on the theta. 102 | 103 | Returns: 104 | bounds : array, shape (n_dims, 2) 105 | The log-transformed bounds on the kernel's hyperparameters 106 | theta 107 | 108 | """ 109 | return np.empty((0, 2)) 110 | 111 | def is_stationary(self): 112 | """Whether this kernel is stationary.""" 113 | return False 114 | 115 | 116 | class DistanceSpectrumKernel(Kernel): 117 | """Spectrum kernel weighting ngrams by their distance in the sequence.""" 118 | 119 | def __init__(self, distance_kernel: Kernel, n_gram_length): 120 | """Initialize DistanceSpectrumKernel using distance_kernel. 121 | 122 | Args: 123 | distance_kernel (Kernel): Kernel used to quantify distance of kmers 124 | in the string.
125 | 126 | """ 127 | self.distance_kernel = distance_kernel 128 | self.n_gram_length = n_gram_length 129 | 130 | @property 131 | def pure_kernel_fn(self): 132 | """Return the pure function for computing the kernel.""" 133 | n_gram_length = self.n_gram_length 134 | 135 | distance_kernel = self.distance_kernel.pure_kernel_fn 136 | 137 | def kmer_kernel_fn(theta, kmers1, kmers2): 138 | pos_kernel = partial(distance_kernel, theta) 139 | same_kmer = np.all( 140 | kmers1[:, None, :] == kmers2[None, :, :], 141 | axis=2 142 | ) 143 | offsets1 = np.arange(kmers1.shape[0]) 144 | offsets2 = np.arange(kmers2.shape[0]) 145 | distance_weight = vmap( 146 | lambda i: vmap(lambda j: pos_kernel(i, j))(offsets2))(offsets1) 147 | return np.sum(same_kmer * distance_weight) 148 | 149 | if n_gram_length is None: 150 | # Assume input is kmer transformed 151 | kernel_fn = kmer_kernel_fn 152 | else: 153 | def kernel_fn(theta, string1, string2): 154 | def make_to_kmers(string): 155 | ngram_slices = [ 156 | string[i:len(string)+1-n_gram_length+i] 157 | for i in range(0, n_gram_length) 158 | ] 159 | return np.stack(ngram_slices, axis=1) 160 | kmers1 = make_to_kmers(string1) 161 | kmers2 = make_to_kmers(string2) 162 | return kmer_kernel_fn(theta, kmers1, kmers2) 163 | 164 | return kernel_fn 165 | 166 | @property 167 | def hyperparameters(self): 168 | """Return a list of all hyperparameter.""" 169 | r = [] 170 | for hyperparameter in self.distance_kernel.hyperparameters: 171 | r.append(Hyperparameter("distance_kernel__" + hyperparameter.name, 172 | hyperparameter.value_type, 173 | hyperparameter.bounds, 174 | hyperparameter.n_elements)) 175 | return r 176 | 177 | @property 178 | def theta(self): 179 | """Return the (flattened, log-transformed) non-fixed hyperparameters. 180 | 181 | Note that theta are typically the log-transformed values of the 182 | kernel's hyperparameters as this representation of the search space 183 | is more amenable for hyperparameter search, as hyperparameters like 184 | length-scales naturally live on a log-scale. 185 | 186 | Returns: 187 | theta : array, shape (n_dims,) 188 | The non-fixed, log-transformed hyperparameters of the kernel 189 | 190 | """ 191 | return self.distance_kernel.theta 192 | 193 | @theta.setter 194 | def theta(self, theta): 195 | """Set the (flattened, log-transformed) non-fixed hyperparameters. 196 | 197 | Parameters: 198 | theta : array, shape (n_dims,) 199 | The non-fixed, log-transformed hyperparameters of the kernel 200 | 201 | """ 202 | self.distance_kernel.theta = theta 203 | 204 | @property 205 | def bounds(self): 206 | """Return the log-transformed bounds on the theta. 207 | 208 | Returns: 209 | bounds : array, shape (n_dims, 2) 210 | The log-transformed bounds on the kernel's hyperparameters 211 | theta 212 | 213 | """ 214 | return self.distance_kernel.bounds 215 | 216 | def get_params(self, deep=True): 217 | """Get parameters of this kernel. 218 | 219 | Parameters: 220 | deep : boolean, optional 221 | If True, will return the parameters for this estimator and 222 | contained subobjects that are estimators. 223 | 224 | Returns: 225 | params : mapping of string to any 226 | Parameter names mapped to their values. 
227 | 228 | """ 229 | params = dict( 230 | distance_kernel=self.distance_kernel, 231 | n_gram_length=self.n_gram_length 232 | ) 233 | if deep: 234 | deep_items = self.distance_kernel.get_params().items() 235 | params.update( 236 | ('distance_kernel__' + k, val) for k, val in deep_items) 237 | return params 238 | 239 | def __eq__(self, b): 240 | """Whether two instances are considered equal.""" 241 | if type(self) != type(b): 242 | return False 243 | return ( 244 | self.distance_kernel == b.distance_kernel and 245 | self.n_gram_length == b.n_gram_length 246 | ) 247 | 248 | def is_stationary(self): 249 | """Whether this kernel is stationary.""" 250 | return False 251 | 252 | 253 | class DistanceFromEndSpectrumKernel(Kernel): 254 | """Spectrum kernel weighting ngrams by their distance in the sequence.""" 255 | 256 | def __init__(self, distance_kernel: Kernel, n_gram_length): 257 | """Initialize DistanceSpectrumKernel using distance_kernel. 258 | 259 | Args: 260 | distance_kernel (Kernel): Kernel used to quantify distance of kmers 261 | in the string. 262 | 263 | """ 264 | self.distance_kernel = distance_kernel 265 | self.n_gram_length = n_gram_length 266 | 267 | @property 268 | def pure_kernel_fn(self): 269 | """Return the pure fuction for computing the kernel.""" 270 | n_gram_length = self.n_gram_length 271 | 272 | distance_kernel = self.distance_kernel.pure_kernel_fn 273 | 274 | def kmer_kernel_fn(theta, kmers1, kmers2): 275 | pos_kernel = partial(distance_kernel, theta) 276 | kmer2_offsets = np.arange(kmers1.shape[0], dtype=np.uint32) 277 | distances_from_end = np.min( 278 | np.stack([kmer2_offsets, kmer2_offsets[::-1]], axis=0), axis=0) 279 | 280 | with loops.Scope() as s: 281 | s.out = 0. 282 | for i in s.range(kmers1.shape[0]): 283 | distance_from_end = distances_from_end[i] 284 | kmer = kmers1[i] 285 | distances = vmap( 286 | lambda j: pos_kernel(distance_from_end, j))(distances_from_end) 287 | is_same = np.all(kmer[None, :] == kmers2, axis=1) 288 | n_matches = np.sum(is_same * distances) 289 | s.out += n_matches 290 | return s.out 291 | if n_gram_length is None: 292 | # Assume input is kmer transformed 293 | kernel_fn = kmer_kernel_fn 294 | else: 295 | def kernel_fn(theta, string1, string2): 296 | def make_to_kmers(string): 297 | ngram_slices = [ 298 | string[i:len(string)+1-n_gram_length+i] 299 | for i in range(0, n_gram_length) 300 | ] 301 | return np.stack(ngram_slices, axis=1) 302 | kmers1 = make_to_kmers(string1) 303 | kmers2 = make_to_kmers(string2) 304 | return kmer_kernel_fn(theta, kmers1, kmers2) 305 | 306 | return kernel_fn 307 | 308 | @property 309 | def hyperparameters(self): 310 | """Return a list of all hyperparameter.""" 311 | r = [] 312 | for hyperparameter in self.distance_kernel.hyperparameters: 313 | r.append(Hyperparameter("distance_kernel__" + hyperparameter.name, 314 | hyperparameter.value_type, 315 | hyperparameter.bounds, 316 | hyperparameter.n_elements)) 317 | return r 318 | 319 | @property 320 | def theta(self): 321 | """Return the (flattened, log-transformed) non-fixed hyperparameters. 322 | 323 | Note that theta are typically the log-transformed values of the 324 | kernel's hyperparameters as this representation of the search space 325 | is more amenable for hyperparameter search, as hyperparameters like 326 | length-scales naturally live on a log-scale. 
327 | 328 | Returns: 329 | theta : array, shape (n_dims,) 330 | The non-fixed, log-transformed hyperparameters of the kernel 331 | 332 | """ 333 | return self.distance_kernel.theta 334 | 335 | @theta.setter 336 | def theta(self, theta): 337 | """Set the (flattened, log-transformed) non-fixed hyperparameters. 338 | 339 | Parameters: 340 | theta : array, shape (n_dims,) 341 | The non-fixed, log-transformed hyperparameters of the kernel 342 | 343 | """ 344 | self.distance_kernel.theta = theta 345 | 346 | @property 347 | def bounds(self): 348 | """Return the log-transformed bounds on the theta. 349 | 350 | Returns: 351 | bounds : array, shape (n_dims, 2) 352 | The log-transformed bounds on the kernel's hyperparameters 353 | theta 354 | 355 | """ 356 | return self.distance_kernel.bounds 357 | 358 | def get_params(self, deep=True): 359 | """Get parameters of this kernel. 360 | 361 | Parameters: 362 | deep : boolean, optional 363 | If True, will return the parameters for this estimator and 364 | contained subobjects that are estimators. 365 | 366 | Returns: 367 | params : mapping of string to any 368 | Parameter names mapped to their values. 369 | 370 | """ 371 | params = dict( 372 | distance_kernel=self.distance_kernel, 373 | n_gram_length=self.n_gram_length 374 | ) 375 | if deep: 376 | deep_items = self.distance_kernel.get_params().items() 377 | params.update( 378 | ('distance_kernel__' + k, val) for k, val in deep_items) 379 | return params 380 | 381 | def __eq__(self, b): 382 | """Whether two instances are considered equal.""" 383 | if type(self) != type(b): 384 | return False 385 | return ( 386 | self.distance_kernel == b.distance_kernel and 387 | self.n_gram_length == b.n_gram_length 388 | ) 389 | 390 | def is_stationary(self): 391 | """Whether this kernel is stationary.""" 392 | return False 393 | 394 | 395 | class RevComplementSpectrumKernel(GenericKernelMixin, Kernel): 396 | """Spectrum string kernel which also count reverse complement matches.""" 397 | 398 | def __init__(self, n_gram_length, mapping): 399 | """Spectrum kernel on strings. 400 | 401 | Assumes input was transformed via `AcsiiBytesTransformer` or similar 402 | tranformation into a jax compatible datatype. 403 | 404 | Parameters: 405 | n_gram_length: Length of ngrams to compare. If `None` it is assumed 406 | that the input is 2d where the final axis is the n_grams. 407 | mapping: Array of length of alphabet which defines what is 408 | considered the complement to a particular character 409 | 410 | """ 411 | self.n_gram_length = n_gram_length 412 | self.mapping = mapping 413 | 414 | @property 415 | def pure_kernel_fn(self): 416 | """Return the pure fuction for computing the kernel.""" 417 | n_gram_length = self.n_gram_length 418 | mapping = device_put(np.array(self.mapping)) 419 | 420 | def kmer_kernel_fn(theta, kmers1, kmers2): 421 | with loops.Scope() as s: 422 | s.out = 0. 
423 |                 for i in s.range(kmers1.shape[0]):
424 |                     kmer = kmers1[i]
425 |                     rev_comp = mapping[kmer][::-1]
426 |                     is_same_fw = np.all(kmer[None, :] == kmers2, axis=1)
427 |                     is_same_rev_comp = np.all(
428 |                         rev_comp[None, :] == kmers2, axis=1)
429 |                     n_matches = np.sum(is_same_fw) + np.sum(is_same_rev_comp)
430 |                     s.out += n_matches
431 |                 return s.out
432 | 
433 |         if n_gram_length is None:
434 |             # Assume input is kmer transformed
435 |             kernel_fn = kmer_kernel_fn
436 |         else:
437 |             def kernel_fn(theta, string1, string2):
438 |                 def make_to_kmers(string):
439 |                     ngram_slices = [
440 |                         string[i:len(string)+1-n_gram_length+i]
441 |                         for i in range(0, n_gram_length)
442 |                     ]
443 |                     return np.stack(ngram_slices, axis=1)
444 |                 kmers1 = make_to_kmers(string1)
445 |                 kmers2 = make_to_kmers(string2)
446 |                 return kmer_kernel_fn(theta, kmers1, kmers2)
447 |         return kernel_fn
448 | 
449 |     @property
450 |     def hyperparameters(self):
451 |         """Return a list of all hyperparameters."""
452 |         return []
453 | 
454 |     @property
455 |     def theta(self):
456 |         """Return the (flattened, log-transformed) non-fixed hyperparameters.
457 | 
458 |         Note that theta are typically the log-transformed values of the
459 |         kernel's hyperparameters as this representation of the search space
460 |         is more amenable for hyperparameter search, as hyperparameters like
461 |         length-scales naturally live on a log-scale.
462 | 
463 |         Returns:
464 |             theta : array, shape (n_dims,)
465 |                 The non-fixed, log-transformed hyperparameters of the kernel
466 | 
467 |         """
468 |         return np.empty((0,))
469 | 
470 |     @theta.setter
471 |     def theta(self, theta):
472 |         """Set the (flattened, log-transformed) non-fixed hyperparameters.
473 | 
474 |         Parameters:
475 |             theta : array, shape (n_dims,)
476 |                 The non-fixed, log-transformed hyperparameters of the kernel
477 | 
478 |         """
479 | 
480 |     @property
481 |     def bounds(self):
482 |         """Return the log-transformed bounds on the theta.
483 | 
484 |         Returns:
485 |             bounds : array, shape (n_dims, 2)
486 |                 The log-transformed bounds on the kernel's hyperparameters
487 |                 theta
488 | 
489 |         """
490 |         return np.empty((0, 2))
491 | 
492 |     def is_stationary(self):
493 |         """Whether this kernel is stationary."""
494 |         return False
495 | 
496 | 
497 | class DistanceRevComplementSpectrumKernel(Kernel):
498 |     """Distance-weighted spectrum kernel with reverse complement matches."""
499 | 
500 |     def __init__(self, distance_kernel: Kernel, n_gram_length, mapping):
501 |         """Initialize DistanceRevComplementSpectrumKernel using distance_kernel.
502 | 
503 |         Args:
504 |             distance_kernel (Kernel): Kernel used to quantify distance of kmers
505 |                 in the string.
506 | 507 | """ 508 | self.distance_kernel = distance_kernel 509 | self.n_gram_length = n_gram_length 510 | self.mapping = mapping 511 | 512 | @property 513 | def pure_kernel_fn(self): 514 | """Return the pure function for computing the kernel.""" 515 | n_gram_length = self.n_gram_length 516 | 517 | distance_kernel = self.distance_kernel.pure_kernel_fn 518 | 519 | def kmer_kernel_fn(theta, kmers1, kmers2): 520 | mapping = device_put(self.mapping) 521 | pos_kernel = partial(distance_kernel, theta) 522 | 523 | rev_complement = np.reshape( 524 | mapping[np.ravel(kmers2)], kmers2.shape)[:, ::-1] 525 | 526 | same_kmer = np.all( 527 | kmers1[:, None, :] == kmers2[None, :, :], 528 | axis=2 529 | ) 530 | same_rev_comp = np.all( 531 | kmers1[:, None, :] == rev_complement[None, :, :], 532 | axis=2 533 | ) 534 | offsets1 = np.arange(kmers1.shape[0]) 535 | offsets2 = np.arange(kmers2.shape[0]) 536 | weight = vmap( 537 | lambda i: vmap(lambda j: pos_kernel(i, j))(offsets2))(offsets1) 538 | weight_rev_comp = vmap( 539 | lambda i: vmap(lambda j: pos_kernel(i, j))(offsets2[::-1]))(offsets1) 540 | 541 | return ( 542 | np.sum(same_kmer * weight) + 543 | np.sum(same_rev_comp * weight_rev_comp) 544 | ) 545 | 546 | if n_gram_length is None: 547 | # Assume input is kmer transformed 548 | kernel_fn = kmer_kernel_fn 549 | else: 550 | def kernel_fn(theta, string1, string2): 551 | def make_to_kmers(string): 552 | ngram_slices = [ 553 | string[i:len(string)+1-n_gram_length+i] 554 | for i in range(0, n_gram_length) 555 | ] 556 | return np.stack(ngram_slices, axis=1) 557 | kmers1 = make_to_kmers(string1) 558 | kmers2 = make_to_kmers(string2) 559 | return kmer_kernel_fn(theta, kmers1, kmers2) 560 | 561 | return kernel_fn 562 | 563 | @property 564 | def hyperparameters(self): 565 | """Return a list of all hyperparameter.""" 566 | r = [] 567 | for hyperparameter in self.distance_kernel.hyperparameters: 568 | r.append(Hyperparameter("distance_kernel__" + hyperparameter.name, 569 | hyperparameter.value_type, 570 | hyperparameter.bounds, 571 | hyperparameter.n_elements)) 572 | return r 573 | 574 | @property 575 | def theta(self): 576 | """Return the (flattened, log-transformed) non-fixed hyperparameters. 577 | 578 | Note that theta are typically the log-transformed values of the 579 | kernel's hyperparameters as this representation of the search space 580 | is more amenable for hyperparameter search, as hyperparameters like 581 | length-scales naturally live on a log-scale. 582 | 583 | Returns: 584 | theta : array, shape (n_dims,) 585 | The non-fixed, log-transformed hyperparameters of the kernel 586 | 587 | """ 588 | return self.distance_kernel.theta 589 | 590 | @theta.setter 591 | def theta(self, theta): 592 | """Set the (flattened, log-transformed) non-fixed hyperparameters. 593 | 594 | Parameters: 595 | theta : array, shape (n_dims,) 596 | The non-fixed, log-transformed hyperparameters of the kernel 597 | 598 | """ 599 | self.distance_kernel.theta = theta 600 | 601 | @property 602 | def bounds(self): 603 | """Return the log-transformed bounds on the theta. 604 | 605 | Returns: 606 | bounds : array, shape (n_dims, 2) 607 | The log-transformed bounds on the kernel's hyperparameters 608 | theta 609 | 610 | """ 611 | return self.distance_kernel.bounds 612 | 613 | def get_params(self, deep=True): 614 | """Get parameters of this kernel. 615 | 616 | Parameters: 617 | deep : boolean, optional 618 | If True, will return the parameters for this estimator and 619 | contained subobjects that are estimators. 
620 | 621 | Returns: 622 | params : mapping of string to any 623 | Parameter names mapped to their values. 624 | 625 | """ 626 | params = dict( 627 | distance_kernel=self.distance_kernel, 628 | n_gram_length=self.n_gram_length, 629 | mapping=self.mapping 630 | ) 631 | if deep: 632 | deep_items = self.distance_kernel.get_params().items() 633 | params.update( 634 | ('distance_kernel__' + k, val) for k, val in deep_items) 635 | return params 636 | 637 | def __eq__(self, b): 638 | """Whether two instances are considered equal.""" 639 | if type(self) != type(b): 640 | return False 641 | return ( 642 | self.distance_kernel == b.distance_kernel and 643 | self.n_gram_length == b.n_gram_length 644 | ) 645 | 646 | def is_stationary(self): 647 | """Whether this kernel is stationary.""" 648 | return False 649 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExpectationMax/sklearn-jax-kernels/6a351b7d9406de26439918af77b330259127e254/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_base_kernels.py: -------------------------------------------------------------------------------- 1 | """Test base kernels and compositions.""" 2 | import pytest 3 | from functools import partial 4 | import numpy as np 5 | from sklearn.gaussian_process.kernels import RBF as sklearn_RBF 6 | from sklearn.gaussian_process.kernels import ConstantKernel as sklearn_C 7 | 8 | from sklearn_jax_kernels import RBF, ConstantKernel, NormalizedKernel 9 | from sklearn_jax_kernels import config 10 | 11 | 12 | class TestRBF: 13 | @pytest.mark.parametrize("save_memory", [True, False]) 14 | def test_value(self, save_memory): 15 | config.SAVE_MEMORY = save_memory 16 | 17 | lengthscale = 15. 18 | X = np.random.normal(size=(10, 20)) 19 | 20 | sk_rbf = sklearn_RBF(lengthscale) 21 | rbf = RBF(lengthscale) 22 | assert np.allclose(sk_rbf(X), rbf(X)) 23 | 24 | @pytest.mark.parametrize("save_memory", [True, False]) 25 | def test_gradient(self, save_memory): 26 | config.SAVE_MEMORY = save_memory 27 | 28 | lengthscale = 1. 
29 |         X = np.random.normal(size=(5, 2))
30 | 
31 |         sk_rbf = sklearn_RBF(lengthscale)
32 |         _, sk_grad = sk_rbf(X, eval_gradient=True)
33 |         rbf = RBF(lengthscale)
34 |         _, grad = rbf(X, eval_gradient=True)
35 |         assert np.allclose(sk_grad, grad)
36 | 
37 | 
38 | class TestNormalizedKernel:
39 |     def test_RBF_value_same(self):
40 |         X = np.random.normal(size=(10, 20))
41 |         kernel = NormalizedKernel(RBF(1.))
42 |         K = kernel(X)
43 | 
44 |         # Compute the kernel using instance wise formulation
45 |         from jax import vmap
46 |         kernel_fn = partial(kernel.pure_kernel_fn, kernel.theta)
47 |         K_instance_wise = \
48 |             vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(X))(X)
49 | 
50 |         assert np.allclose(K, K_instance_wise)
51 | 
52 |     def test_RBF_grad_same_XX(self):
53 |         X = np.random.normal(size=(3, 20))
54 |         kernel = NormalizedKernel(RBF(1.))
55 |         K, K_grad = kernel(X, eval_gradient=True)
56 | 
57 |         # Compute the kernel using instance wise formulation
58 |         from jax import vmap, grad
59 |         kernel_fn = partial(grad(kernel.pure_kernel_fn), kernel.theta)
60 |         K_grad_instance_wise = \
61 |             vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(X))(X)
62 | 
63 |         assert np.allclose(K_grad, K_grad_instance_wise)
64 | 
65 |     def test_RBF_grad_same_XY(self):
66 |         X = np.random.normal(size=(3, 20))
67 |         kernel = NormalizedKernel(RBF(1.))
68 |         K, K_grad = kernel(X, X, eval_gradient=True)
69 | 
70 |         # Compute the kernel using instance wise formulation
71 |         from jax import vmap, grad
72 |         kernel_fn = partial(grad(kernel.pure_kernel_fn), kernel.theta)
73 |         K_grad_instance_wise = \
74 |             vmap(lambda x: vmap(lambda y: kernel_fn(x, y))(X))(X)
75 | 
76 |         assert np.allclose(K_grad, K_grad_instance_wise)
77 | 
78 | 
79 | class TestConstant:
80 |     def test_value(self):
81 |         val = 5.
82 |         X = np.random.normal(size=(10, 20))
83 | 
84 |         k = ConstantKernel(val)
85 |         assert np.allclose(k(X), np.full((10, 10), val))
86 | 
87 |     def test_gradient(self):
88 |         val = 5.
89 |         X = np.random.normal(size=(10, 20))
90 | 
91 |         sk_c = sklearn_C(val)
92 |         c = ConstantKernel(val)
93 |         _, sk_grad = sk_c(X, eval_gradient=True)
94 |         _, grad = c(X, eval_gradient=True)
95 | 
96 |         assert np.allclose(sk_grad, grad)
97 | 
98 | 
99 | class TestCompositions:
100 |     def test_sum(self):
101 |         lengthscale1, lengthscale2 = 5., 10.
102 |         X = np.random.normal(size=(10, 20))
103 | 
104 |         sk_sum = sklearn_RBF(lengthscale1) + sklearn_RBF(lengthscale2)
105 |         ours_sum = RBF(lengthscale1) + RBF(lengthscale2)
106 |         assert np.allclose(sk_sum(X), ours_sum(X))
107 | 
108 |     def test_product(self):
109 |         lengthscale1, lengthscale2 = 5., 10.
110 |         X = np.random.normal(size=(10, 20))
111 | 
112 |         sk_prod = sklearn_RBF(lengthscale1) * sklearn_RBF(lengthscale2)
113 |         ours_prod = RBF(lengthscale1) * RBF(lengthscale2)
114 |         assert np.allclose(sk_prod(X), ours_prod(X))
115 | 
116 |     def test_exponentiation(self):
117 |         lengthscale = 5.
118 |         exponent = 2.
119 | X = np.random.normal(size=(10, 20)) 120 | 121 | sk = sklearn_RBF(lengthscale) ** exponent 122 | ours = RBF(lengthscale) ** exponent 123 | assert np.allclose(sk(X), ours(X)) 124 | -------------------------------------------------------------------------------- /tests/test_gpc.py: -------------------------------------------------------------------------------- 1 | from sklearn import datasets 2 | from sklearn_jax_kernels import RBF, GaussianProcessClassifier 3 | import jax.numpy as jnp 4 | import numpy as np 5 | 6 | from scipy.optimize import approx_fprime 7 | 8 | import pytest 9 | 10 | from sklearn_jax_kernels import GaussianProcessClassifier 11 | from sklearn_jax_kernels import RBF, ConstantKernel as C 12 | 13 | from sklearn.utils._testing import assert_almost_equal, assert_array_equal 14 | 15 | from jax.config import config 16 | # Required for numerical gradients checks 17 | config.update("jax_enable_x64", True) 18 | 19 | 20 | def f(x): 21 | return np.sin(x) 22 | 23 | 24 | X = np.atleast_2d(np.linspace(0, 10, 30)).T 25 | X2 = np.atleast_2d([2., 4., 5.5, 6.5, 7.5]).T 26 | y = np.array(f(X).ravel() > 0, dtype=int) 27 | fX = f(X).ravel() 28 | y_mc = np.empty(y.shape, dtype=int) # multi-class 29 | y_mc[fX < -0.35] = 0 30 | y_mc[(fX >= -0.35) & (fX < 0.35)] = 1 31 | y_mc[fX > 0.35] = 2 32 | 33 | 34 | fixed_kernel = RBF(length_scale=1.0, length_scale_bounds="fixed") 35 | kernels = [RBF(length_scale=0.1), fixed_kernel, 36 | RBF(length_scale=1.0, length_scale_bounds=(1e-3, 1e3)), 37 | C(1.0, (1e-2, 1e2)) * 38 | RBF(length_scale=1.0, length_scale_bounds=(1e-3, 1e3))] 39 | non_fixed_kernels = [kernel for kernel in kernels 40 | if kernel != fixed_kernel] 41 | 42 | 43 | @pytest.mark.parametrize('kernel', kernels) 44 | def test_predict_consistent(kernel): 45 | # Check binary predict decision has also predicted probability above 0.5. 46 | gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y) 47 | assert_array_equal(gpc.predict(X), 48 | gpc.predict_proba(X)[:, 1] >= 0.5) 49 | 50 | 51 | @pytest.mark.parametrize('kernel', non_fixed_kernels) 52 | def test_lml_improving(kernel): 53 | # Test that hyperparameter-tuning improves log-marginal likelihood. 54 | gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y) 55 | assert (gpc.log_marginal_likelihood(gpc.kernel_.theta) > 56 | gpc.log_marginal_likelihood(kernel.theta)) 57 | 58 | 59 | @pytest.mark.parametrize('kernel', kernels) 60 | def test_lml_precomputed(kernel): 61 | # Test that lml of optimized kernel is stored correctly. 62 | gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y) 63 | assert_almost_equal(gpc.log_marginal_likelihood(gpc.kernel_.theta), 64 | gpc.log_marginal_likelihood(), 5) 65 | 66 | 67 | # @pytest.mark.parametrize('kernel', kernels) 68 | # def test_lml_without_cloning_kernel(kernel): 69 | # # Test that clone_kernel=False has side-effects of kernel.theta. 70 | # gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y) 71 | # input_theta = np.ones(gpc.kernel_.theta.shape, dtype=np.float64) 72 | # 73 | # gpc.log_marginal_likelihood(input_theta, clone_kernel=False) 74 | # assert_almost_equal(gpc.kernel_.theta, input_theta, 7) 75 | 76 | 77 | @pytest.mark.parametrize('kernel', non_fixed_kernels) 78 | def test_converged_to_local_maximum(kernel): 79 | # Test that we are in local maximum after hyperparameter-optimization. 
80 | gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y) 81 | 82 | lml, lml_gradient = \ 83 | gpc.log_marginal_likelihood(gpc.kernel_.theta, True) 84 | 85 | assert np.all((np.abs(lml_gradient) < 1e-4) | 86 | (gpc.kernel_.theta == gpc.kernel_.bounds[:, 0]) | 87 | (gpc.kernel_.theta == gpc.kernel_.bounds[:, 1])) 88 | 89 | 90 | @pytest.mark.parametrize('kernel', kernels) 91 | def test_lml_gradient(kernel): 92 | # Compare analytic and numeric gradient of log marginal likelihood. 93 | gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y) 94 | 95 | lml, lml_gradient = gpc.log_marginal_likelihood(kernel.theta, True) 96 | lml_gradient_approx = \ 97 | approx_fprime(kernel.theta, 98 | lambda theta: gpc.log_marginal_likelihood(theta, 99 | False), 100 | 1e-10) 101 | 102 | assert_almost_equal(lml_gradient, lml_gradient_approx, 3) 103 | 104 | 105 | def test_random_starts(): 106 | # Test that an increasing number of random-starts of GP fitting only 107 | # increases the log marginal likelihood of the chosen theta. 108 | n_samples, n_features = 25, 2 109 | rng = np.random.RandomState(0) 110 | X = rng.randn(n_samples, n_features) * 2 - 1 111 | y = (np.sin(X).sum(axis=1) + np.sin(3 * X).sum(axis=1)) > 0 112 | 113 | kernel = C(1.0, (1e-2, 1e2)) \ 114 | * RBF(length_scale=[1e-3] * n_features, 115 | length_scale_bounds=[(1e-4, 1e+2)] * n_features) 116 | last_lml = -np.inf 117 | for n_restarts_optimizer in range(5): 118 | gp = GaussianProcessClassifier( 119 | kernel=kernel, n_restarts_optimizer=n_restarts_optimizer, 120 | random_state=0).fit(X, y) 121 | lml = gp.log_marginal_likelihood(gp.kernel_.theta) 122 | assert lml > last_lml - np.finfo(np.float32).eps 123 | last_lml = lml 124 | 125 | 126 | @pytest.mark.parametrize('kernel', non_fixed_kernels) 127 | def test_custom_optimizer(kernel): 128 | # Test that GPC can use externally defined optimizers. 129 | # Define a dummy optimizer that simply tests 10 random hyperparameters 130 | def optimizer(obj_func, initial_theta, bounds): 131 | rng = np.random.RandomState(0) 132 | theta_opt, func_min = \ 133 | initial_theta, obj_func(initial_theta, eval_gradient=False) 134 | for _ in range(10): 135 | theta = np.atleast_1d(rng.uniform(np.maximum(-2, bounds[:, 0]), 136 | np.minimum(1, bounds[:, 1]))) 137 | f = obj_func(theta, eval_gradient=False) 138 | if f < func_min: 139 | theta_opt, func_min = theta, f 140 | return theta_opt, func_min 141 | 142 | gpc = GaussianProcessClassifier(kernel=kernel, optimizer=optimizer) 143 | gpc.fit(X, y_mc) 144 | # Checks that optimizer improved marginal likelihood 145 | assert (gpc.log_marginal_likelihood(gpc.kernel_.theta) > 146 | gpc.log_marginal_likelihood(kernel.theta)) 147 | 148 | 149 | @pytest.mark.parametrize('kernel', kernels) 150 | def test_multi_class(kernel): 151 | # Test GPC for multi-class classification problems. 152 | gpc = GaussianProcessClassifier(kernel=kernel) 153 | gpc.fit(X, y_mc) 154 | 155 | y_prob = gpc.predict_proba(X2) 156 | assert_almost_equal(y_prob.sum(1), 1) 157 | 158 | y_pred = gpc.predict(X2) 159 | assert_array_equal(np.argmax(y_prob, 1), y_pred) 160 | 161 | 162 | def test_iris_example(): 163 | iris = datasets.load_iris() 164 | X = jnp.asarray(iris.data) 165 | y = jnp.array(iris.target, dtype=int) 166 | 167 | kernel = 1. 
+ RBF(length_scale=1.0) 168 | gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y) 169 | -------------------------------------------------------------------------------- /tests/test_kernels.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import numpy as np 4 | from inspect import signature 5 | 6 | from sklearn.gaussian_process.kernels import _approx_fprime 7 | from sklearn.base import clone 8 | from sklearn.utils._testing import (assert_almost_equal, assert_array_equal, 9 | assert_array_almost_equal) 10 | 11 | from sklearn_jax_kernels import ( 12 | RBF, Kernel, KernelOperator, ConstantKernel, Exponentiation) 13 | 14 | from jax.config import config 15 | config.update("jax_enable_x64", True) # Required for numerical gradients checks 16 | 17 | X = np.random.RandomState(0).normal(0, 1, (5, 2)) 18 | Y = np.random.RandomState(0).normal(0, 1, (6, 2)) 19 | 20 | kernels = [RBF(length_scale=2.0), RBF(length_scale_bounds=(0.5, 2.0)), 21 | ConstantKernel(constant_value=10.0), 22 | 2.0 * RBF(length_scale=0.33, length_scale_bounds="fixed"), 23 | 2.0 * RBF(length_scale=0.5), 24 | 2.0 * RBF(length_scale=[0.5, 2.0]), 25 | RBF(length_scale=[2.0])] 26 | 27 | 28 | @pytest.mark.parametrize('kernel', kernels) 29 | def test_kernel_gradient(kernel): 30 | # Compare analytic and numeric gradient of kernels. 31 | K, K_gradient = kernel(X, eval_gradient=True) 32 | 33 | assert K_gradient.shape[0] == X.shape[0] 34 | assert K_gradient.shape[1] == X.shape[0] 35 | assert K_gradient.shape[2] == kernel.theta.shape[0] 36 | 37 | def eval_kernel_for_theta(theta): 38 | kernel_clone = kernel.clone_with_theta(theta) 39 | K = kernel_clone(X, eval_gradient=False) 40 | return K 41 | 42 | K_gradient_approx = \ 43 | _approx_fprime(kernel.theta, eval_kernel_for_theta, 1e-10) 44 | 45 | assert_almost_equal(K_gradient, K_gradient_approx, 5) 46 | 47 | 48 | @pytest.mark.parametrize( 49 | 'kernel', 50 | [kernel for kernel in kernels 51 | # skip non-basic kernels 52 | if not (isinstance(kernel, KernelOperator) 53 | or isinstance(kernel, Exponentiation))]) 54 | def test_kernel_theta(kernel): 55 | # Check that parameter vector theta of kernel is set correctly. 56 | theta = kernel.theta 57 | _, K_gradient = kernel(X, eval_gradient=True) 58 | 59 | # Determine kernel parameters that contribute to theta 60 | init_sign = signature(kernel.__class__.__init__).parameters.values() 61 | args = [p.name for p in init_sign if p.name != 'self'] 62 | theta_vars = map(lambda s: s[0:-len("_bounds")], 63 | filter(lambda s: s.endswith("_bounds"), args)) 64 | assert ( 65 | set(hyperparameter.name 66 | for hyperparameter in kernel.hyperparameters) == 67 | set(theta_vars)) 68 | 69 | # Check that values returned in theta are consistent with 70 | # hyperparameter values (being their logarithms) 71 | for i, hyperparameter in enumerate(kernel.hyperparameters): 72 | assert (theta[i] == np.log(getattr(kernel, hyperparameter.name))) 73 | 74 | # Fixed kernel parameters must be excluded from theta and gradient. 
75 | for i, hyperparameter in enumerate(kernel.hyperparameters): 76 | # create copy with certain hyperparameter fixed 77 | params = kernel.get_params() 78 | params[hyperparameter.name + "_bounds"] = "fixed" 79 | kernel_class = kernel.__class__ 80 | new_kernel = kernel_class(**params) 81 | # Check that theta and K_gradient are identical with the fixed 82 | # dimension left out 83 | _, K_gradient_new = new_kernel(X, eval_gradient=True) 84 | assert theta.shape[0] == new_kernel.theta.shape[0] + 1 85 | assert K_gradient.shape[2] == K_gradient_new.shape[2] + 1 86 | if i > 0: 87 | assert theta[:i] == new_kernel.theta[:i] 88 | assert_array_equal(K_gradient[..., :i], 89 | K_gradient_new[..., :i]) 90 | if i + 1 < len(kernel.hyperparameters): 91 | assert theta[i + 1:] == new_kernel.theta[i:] 92 | assert_array_equal(K_gradient[..., i + 1:], 93 | K_gradient_new[..., i:]) 94 | 95 | # Check that values of theta are modified correctly 96 | for i, hyperparameter in enumerate(kernel.hyperparameters): 97 | theta[i] = np.log(42) 98 | kernel.theta = theta 99 | assert_almost_equal(getattr(kernel, hyperparameter.name), 42) 100 | 101 | setattr(kernel, hyperparameter.name, 43) 102 | assert_almost_equal(kernel.theta[i], np.log(43)) 103 | 104 | 105 | @pytest.mark.parametrize('kernel', kernels) 106 | def test_auto_vs_cross(kernel): 107 | # Auto-correlation and cross-correlation should be consistent. 108 | K_auto = kernel(X) 109 | K_cross = kernel(X, X) 110 | print(K_auto, K_cross) 111 | assert_array_almost_equal(K_auto, K_cross, 5) 112 | 113 | 114 | @pytest.mark.parametrize('kernel', kernels) 115 | def test_kernel_diag(kernel): 116 | # Test that diag method of kernel returns consistent results. 117 | K_call_diag = np.diag(kernel(X)) 118 | K_diag = kernel.diag(X) 119 | assert_array_almost_equal(K_call_diag, K_diag, 5) 120 | 121 | 122 | def test_kernel_operator_commutative(): 123 | # Adding kernels and multiplying kernels should be commutative. 124 | # Check addition 125 | assert_array_almost_equal((RBF(2.0) + 1.0)(X), (1.0 + RBF(2.0))(X)) 126 | 127 | # Check multiplication 128 | assert_array_almost_equal((3.0 * RBF(2.0))(X), (RBF(2.0) * 3.0)(X)) 129 | 130 | 131 | def test_kernel_anisotropic(): 132 | # Anisotropic kernel should be consistent with isotropic kernels. 133 | kernel = 3.0 * RBF([0.5, 2.0]) 134 | 135 | K = kernel(X) 136 | X1 = np.array(X) 137 | X1[:, 0] *= 4 138 | K1 = 3.0 * RBF(2.0)(X1) 139 | assert_array_almost_equal(K, K1) 140 | 141 | X2 = np.array(X) 142 | X2[:, 1] /= 4 143 | K2 = 3.0 * RBF(0.5)(X2) 144 | assert_array_almost_equal(K, K2) 145 | 146 | # Check getting and setting via theta 147 | kernel.theta = kernel.theta + np.log(2) 148 | assert_array_equal(kernel.theta, np.log([6.0, 1.0, 4.0])) 149 | assert_array_equal(kernel.k2.length_scale, [1.0, 4.0]) 150 | 151 | 152 | @pytest.mark.parametrize('kernel', 153 | [kernel for kernel in kernels 154 | if kernel.is_stationary()]) 155 | def test_kernel_stationary(kernel): 156 | # Test stationarity of kernels. 
157 | K = kernel(X, X + 1) 158 | assert_almost_equal(K[0, 0], np.diag(K)) 159 | 160 | 161 | @pytest.mark.parametrize('kernel', kernels) 162 | def test_kernel_input_type(kernel): 163 | # Test whether kernels is for vectors or structured data 164 | if isinstance(kernel, Exponentiation): 165 | assert(kernel.requires_vector_input == 166 | kernel.kernel.requires_vector_input) 167 | if isinstance(kernel, KernelOperator): 168 | assert(kernel.requires_vector_input == 169 | (kernel.k1.requires_vector_input or 170 | kernel.k2.requires_vector_input)) 171 | 172 | 173 | def check_hyperparameters_equal(kernel1, kernel2): 174 | # Check that hyperparameters of two kernels are equal 175 | for attr in set(dir(kernel1) + dir(kernel2)): 176 | if attr.startswith("hyperparameter_"): 177 | attr_value1 = getattr(kernel1, attr) 178 | attr_value2 = getattr(kernel2, attr) 179 | assert attr_value1 == attr_value2 180 | 181 | 182 | @pytest.mark.parametrize("kernel", kernels) 183 | def test_kernel_clone(kernel): 184 | # Test that sklearn's clone works correctly on kernels. 185 | kernel_cloned = clone(kernel) 186 | 187 | # XXX: Should this be fixed? 188 | # This differs from the sklearn's estimators equality check. 189 | assert kernel == kernel_cloned 190 | assert id(kernel) != id(kernel_cloned) 191 | 192 | # Check that all constructor parameters are equal. 193 | assert kernel.get_params() == kernel_cloned.get_params() 194 | 195 | # Check that all hyperparameters are equal. 196 | check_hyperparameters_equal(kernel, kernel_cloned) 197 | 198 | 199 | @pytest.mark.parametrize('kernel', kernels) 200 | def test_kernel_clone_after_set_params(kernel): 201 | # This test is to verify that using set_params does not 202 | # break clone on kernels. 203 | # This used to break because in kernels such as the RBF, non-trivial 204 | # logic that modified the length scale used to be in the constructor 205 | # See https://github.com/scikit-learn/scikit-learn/issues/6961 206 | # for more details. 207 | bounds = (1e-5, 1e5) 208 | kernel_cloned = clone(kernel) 209 | params = kernel.get_params() 210 | # RationalQuadratic kernel is isotropic. 211 | isotropic_kernels = () 212 | if 'length_scale' in params and not isinstance(kernel, 213 | isotropic_kernels): 214 | length_scale = params['length_scale'] 215 | if np.iterable(length_scale): 216 | # XXX unreached code as of v0.22 217 | params['length_scale'] = length_scale[0] 218 | params['length_scale_bounds'] = bounds 219 | else: 220 | params['length_scale'] = [length_scale] * 2 221 | params['length_scale_bounds'] = bounds * 2 222 | kernel_cloned.set_params(**params) 223 | kernel_cloned_clone = clone(kernel_cloned) 224 | assert (kernel_cloned_clone.get_params() == kernel_cloned.get_params()) 225 | assert id(kernel_cloned_clone) != id(kernel_cloned) 226 | check_hyperparameters_equal(kernel_cloned, kernel_cloned_clone) 227 | 228 | 229 | @pytest.mark.parametrize("kernel", kernels) 230 | def test_set_get_params(kernel): 231 | # Check that set_params()/get_params() is consistent with kernel.theta. 
232 | 233 | # Test get_params() 234 | index = 0 235 | params = kernel.get_params() 236 | for hyperparameter in kernel.hyperparameters: 237 | if isinstance("string", type(hyperparameter.bounds)): 238 | if hyperparameter.bounds == "fixed": 239 | continue 240 | size = hyperparameter.n_elements 241 | if size > 1: # anisotropic kernels 242 | assert_almost_equal(np.exp(kernel.theta[index:index + size]), 243 | params[hyperparameter.name]) 244 | index += size 245 | else: 246 | assert_almost_equal(np.exp(kernel.theta[index]), 247 | params[hyperparameter.name]) 248 | index += 1 249 | # Test set_params() 250 | index = 0 251 | value = 10 # arbitrary value 252 | for hyperparameter in kernel.hyperparameters: 253 | if isinstance("string", type(hyperparameter.bounds)): 254 | if hyperparameter.bounds == "fixed": 255 | continue 256 | size = hyperparameter.n_elements 257 | if size > 1: # anisotropic kernels 258 | kernel.set_params(**{hyperparameter.name: [value] * size}) 259 | assert_almost_equal(np.exp(kernel.theta[index:index + size]), 260 | [value] * size) 261 | index += size 262 | else: 263 | kernel.set_params(**{hyperparameter.name: value}) 264 | assert_almost_equal(np.exp(kernel.theta[index]), value) 265 | index += 1 266 | 267 | 268 | @pytest.mark.parametrize("kernel", kernels) 269 | def test_repr_kernels(kernel): 270 | # Smoke-test for repr in kernels. 271 | 272 | repr(kernel) 273 | 274 | 275 | def test_warns_on_get_params_non_attribute(): 276 | class MyKernel(Kernel): 277 | def __init__(self, param=5): 278 | pass 279 | 280 | def __call__(self, X, Y=None, eval_gradient=False): 281 | return X 282 | 283 | def diag(self, X): 284 | return np.ones(X.shape[0]) 285 | 286 | def is_stationary(self): 287 | return False 288 | 289 | @property 290 | def pure_kernel_fn(self): 291 | pass 292 | 293 | est = MyKernel() 294 | with pytest.warns(FutureWarning, match='AttributeError'): 295 | params = est.get_params() 296 | 297 | assert params['param'] is None 298 | -------------------------------------------------------------------------------- /tests/test_structured_strings.py: -------------------------------------------------------------------------------- 1 | """Test string kernels and utilities associated with them.""" 2 | import numpy as np 3 | from sklearn_jax_kernels.structured.string_utils import ( 4 | AsciiBytesTransformer, CompressAlphabetTransformer, NGramTransformer, 5 | get_translation_table 6 | ) 7 | from sklearn_jax_kernels import RBF 8 | from sklearn_jax_kernels.structured.strings import ( 9 | DistanceSpectrumKernel, 10 | DistanceFromEndSpectrumKernel, 11 | DistanceRevComplementSpectrumKernel, 12 | RevComplementSpectrumKernel, 13 | SpectrumKernel 14 | ) 15 | # from jax import config 16 | # config.update('jax_disable_jit', True) 17 | 18 | 19 | class TestUtils: 20 | def test_ascii_bytes_transformer(self): 21 | strings = ['abc', 'def'] 22 | transformer = AsciiBytesTransformer() 23 | trans = transformer.transform(strings) 24 | inverse = transformer.inverse_transform(trans) 25 | assert all([s1 == s2 for s1, s2 in zip(strings, inverse)]) 26 | 27 | def test_ngram_transformer(self): 28 | strings = np.asarray([list('abcde')]) 29 | ngrams = np.asarray([[ 30 | list('abc'), 31 | list('bcd'), 32 | list('cde') 33 | ]]) 34 | transformer = NGramTransformer(3) 35 | transformed = transformer.transform(strings) 36 | assert np.all(np.ravel(ngrams) == np.ravel(transformed)) 37 | 38 | def test_compress_alphabet_transformer(self): 39 | strings = np.asarray(['abc']) 40 | transf = CompressAlphabetTransformer() 41 | 
transf.fit(strings) 42 | out = transf.transform(np.asarray(['cbad'])) 43 | assert np.all(np.array([[2, 1, 0, 3]]) == out) 44 | 45 | 46 | class TestKernels: 47 | def test_spectrum_kernel_example(self): 48 | strings = ['aabbcc', 'aaabac'] 49 | strings_transformed = AsciiBytesTransformer().transform(strings) 50 | kernel = SpectrumKernel(n_gram_length=2) 51 | K = kernel(strings_transformed) 52 | assert np.allclose(K, np.array([[5., 3.], [3., 7.]])) 53 | 54 | def test_spectrum_kernel_ngram_transform(self): 55 | n_gram_length = 2 56 | strings = ['aabbcc', 'aaabac'] 57 | strings_transformed = AsciiBytesTransformer().transform(strings) 58 | ngrams = NGramTransformer(n_gram_length).transform(strings_transformed) 59 | 60 | kernel_strings = SpectrumKernel(n_gram_length=n_gram_length) 61 | kernel_ngrams = SpectrumKernel(n_gram_length=None) 62 | K_strings = kernel_strings(strings_transformed) 63 | K_ngrams = kernel_ngrams(ngrams) 64 | assert np.allclose(K_strings, K_ngrams) 65 | 66 | def test_distance_spectrum_kernel_ngram_transform(self): 67 | n_gram_length = 2 68 | distance_kernel = RBF(1.0) 69 | strings = ['aabbcc', 'aaabac'] 70 | strings_transformed = AsciiBytesTransformer().transform(strings) 71 | ngrams = NGramTransformer(n_gram_length).transform(strings_transformed) 72 | 73 | kernel_strings = DistanceSpectrumKernel(distance_kernel, n_gram_length) 74 | kernel_ngrams = DistanceSpectrumKernel(distance_kernel, None) 75 | K_strings = kernel_strings(strings_transformed) 76 | K_ngrams = kernel_ngrams(ngrams) 77 | assert np.allclose(K_strings, K_ngrams) 78 | 79 | def test_distance_spectrum_kernel(self): 80 | distance_kernel = RBF(1.0) 81 | strings = ['aabbcc', 'aaabac'] 82 | strings_transformed = AsciiBytesTransformer().transform(strings) 83 | kernel = DistanceSpectrumKernel(distance_kernel, 2) 84 | K = kernel(strings_transformed) 85 | K_gt = np.array([ 86 | [5., 2.2130613], 87 | [2.2130613, 6.2130613] 88 | ]) 89 | assert np.allclose(K, K_gt) 90 | 91 | def test_distance_from_end_spectrum_kernel(self): 92 | distance_kernel = RBF(1.0) 93 | strings = ['abc', 'cba'] 94 | strings_transformed = AsciiBytesTransformer().transform(strings) 95 | kernel = DistanceFromEndSpectrumKernel(distance_kernel, 1) 96 | K = kernel(strings_transformed) 97 | K_gt = np.array([ 98 | [3., 3.], 99 | [3., 3.] 100 | ]) 101 | assert np.allclose(K, K_gt) 102 | 103 | def test_rev_comp_spectrum_kernel(self): 104 | mapping = np.array([1, 0, 3, 2], np.uint8) 105 | strings = np.array([[0, 1, 2, 3], [2, 3, 0, 1]], np.uint8) 106 | kernel = RevComplementSpectrumKernel(2, mapping) 107 | K = kernel(strings) 108 | print(K) 109 | K_gt = np.array([ 110 | [5., 5.], 111 | [5., 5.] 112 | ]) 113 | assert np.allclose(K, K_gt) 114 | 115 | def test_get_translation_table(self): 116 | table = get_translation_table( 117 | input_symbols=['a', 'b', 'c'], 118 | output_symbols=['c', 'b', 'a'], 119 | mapping={'a': 0, 'b': 1, 'c': 2} 120 | ) 121 | assert np.all(np.array([2, 1, 0]) == table) 122 | 123 | def test_rev_comp_spectrum_kernel_string(self): 124 | strings = ['ATGCCG', 'CGGCAT'] 125 | transformer = CompressAlphabetTransformer() 126 | strings_transf = transformer.fit_transform(strings) 127 | table = get_translation_table( 128 | ['A', 'T', 'G', 'C'], 129 | ['T', 'A', 'C', 'G'], 130 | transformer._mapping 131 | ) 132 | kernel = RevComplementSpectrumKernel(2, table) 133 | K = kernel(strings_transf) 134 | K_gt = np.array([ 135 | [8., 8.], 136 | [8., 8.] 
137 | ]) 138 | assert np.allclose(K, K_gt) 139 | 140 | def test_rev_comp_distance_spectrum_kernel_string(self): 141 | distance_kernel = RBF(1.) 142 | strings = ['ATGC', 'GCAT'] 143 | transformer = CompressAlphabetTransformer() 144 | strings_transf = transformer.fit_transform(strings) 145 | table = get_translation_table( 146 | ['A', 'T', 'G', 'C'], 147 | ['T', 'A', 'C', 'G'], 148 | transformer._mapping 149 | ) 150 | kernel = DistanceRevComplementSpectrumKernel(distance_kernel, 2, table) 151 | K = kernel(strings_transf) 152 | K_gt = np.array([ 153 | [3.2706707, 3.2706707], 154 | [3.2706707, 3.2706707] 155 | ]) 156 | assert np.allclose(K, K_gt) 157 | 158 | def test_rev_comp_distance_spectrum_kernel_string_mismatch(self): 159 | distance_kernel = RBF(1.) 160 | strings = ['ATGC', 'CGAT'] 161 | transformer = CompressAlphabetTransformer() 162 | strings_transf = transformer.fit_transform(strings) 163 | table = get_translation_table( 164 | ['A', 'T', 'G', 'C'], 165 | ['T', 'A', 'C', 'G'], 166 | transformer._mapping 167 | ) 168 | kernel = DistanceRevComplementSpectrumKernel(distance_kernel, 2, table) 169 | K = kernel(strings_transf) 170 | K_gt = np.array([ 171 | [3.2706707, 1.1353353], 172 | [1.1353353, 3.2706707] 173 | ]) 174 | assert np.allclose(K, K_gt) 175 | --------------------------------------------------------------------------------
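
The tests above also serve as usage documentation for the string kernels. As a compact illustration, the sketch below strings together `AsciiBytesTransformer`, `DistanceSpectrumKernel` and `NormalizedKernel` on a pair of equal-length DNA-like sequences. The sequences, the RBF length scale and the n-gram length are illustrative choices only, and wrapping a string kernel in `NormalizedKernel` is assumed to compose the same way it does for the RBF kernels exercised in `tests/test_base_kernels.py`; this is a sketch, not part of the test suite.

```python
"""Sketch: normalized, distance-weighted spectrum kernel on equal-length strings."""
import numpy as np

from sklearn_jax_kernels import RBF, NormalizedKernel
from sklearn_jax_kernels.structured.string_utils import AsciiBytesTransformer
from sklearn_jax_kernels.structured.strings import DistanceSpectrumKernel

# Two equal-length sequences (the string kernels assume equal lengths).
strings = ['ATGCCG', 'CGGCAT']

# Map characters to a JAX-compatible integer representation.
strings_transformed = AsciiBytesTransformer().transform(strings)

# Weight matching 2-grams by an RBF kernel over their positions in the strings,
# then normalize the result so the diagonal of the kernel matrix becomes 1.
kernel = NormalizedKernel(DistanceSpectrumKernel(RBF(1.0), 2))

K = kernel(strings_transformed)
print(np.asarray(K))  # 2x2 symmetric kernel matrix
```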