├── .DS_Store ├── .github └── workflows │ └── unit_test.yml ├── LICENSE.txt ├── README.md ├── assets ├── logo_resized.png ├── scikit-jax-logo-1-blue_blackground.jpg ├── scikit-jax-logo-1-transparent.jpg └── scikit-jax-logo1.jpg ├── examples ├── data │ ├── Naive-Bayes-Classification-Data.csv │ ├── Student_performance_data _.csv │ └── multiple_linear_regression_dataset.csv └── notebooks │ ├── 2_gaussian_nb.ipynb │ ├── 2_k_means.ipynb │ ├── experimenting_clustering.ipynb │ ├── experimenting_decomposing.ipynb │ ├── experimenting_gaussiannb.ipynb │ ├── experimenting_multinomialnb.ipynb │ ├── linear_model.ipynb │ └── loadingimage.ipynb ├── mkdocs.yml ├── readthedocs.yaml ├── requirements.txt ├── requirements_dev.txt ├── setup.py ├── skjax ├── .DS_Store ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── clustering.cpython-311.pyc │ ├── decomposition.cpython-311.pyc │ ├── linear_model.cpython-311.pyc │ └── naive_bayes.cpython-311.pyc ├── _utils │ ├── .DS_Store │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── _helper_functions.cpython-311.pyc │ │ └── config.cpython-311.pyc │ └── helpers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── _clustering.cpython-311.pyc │ │ └── _helper_functions.cpython-311.pyc │ │ ├── _clustering.py │ │ ├── _data.py │ │ ├── _helper_functions.py │ │ ├── _linear_model.py │ │ └── _naive_bayes.py ├── clustering.py ├── decomposition.py ├── linear_model.py └── naive_bayes.py └── tests ├── .DS_Store ├── __pycache__ ├── test_PCA.cpython-311-pytest-8.3.2.pyc ├── test_PCA.cpython-311.pyc ├── test_clustering.cpython-311.pyc ├── test_decomposition.cpython-311.pyc ├── test_gaussian_naive_bayes.cpython-311-pytest-8.3.2.pyc ├── test_gaussian_naive_bayes.cpython-311.pyc ├── test_k_means.cpython-311-pytest-8.3.2.pyc ├── test_k_means.cpython-311.pyc ├── test_linear_model.cpython-311.pyc ├── test_linear_regression.cpython-311-pytest-8.3.2.pyc ├── test_linear_regression.cpython-311.pyc ├── test_linear_regression_sgd.cpython-311-pytest-8.3.2.pyc ├── test_multinomial_naive_bayes.cpython-311-pytest-8.3.2.pyc └── test_multinomial_naive_bayes.cpython-311.pyc ├── files ├── basic5.csv └── multiple_linear_regression_dataset.csv ├── test_PCA.py ├── test_gaussian_naive_bayes.py ├── test_k_means.py ├── test_linear_regression.py ├── test_linear_regression_sgd.py └── test_multinomial_naive_bayes.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/.DS_Store -------------------------------------------------------------------------------- /.github/workflows/unit_test.yml: -------------------------------------------------------------------------------- 1 | name: Python Unit Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 13 | uses: actions/setup-python@v3 14 | with: 15 | python-version: 3.11 16 | - name: Install dependencies 17 | run: | 18 | pip install --upgrade pip 19 | pip install -r requirements.txt 20 | pip install -r requirements_dev.txt 21 | python setup.py bdist_wheel sdist 22 | pip install . 
23 | - name: Run tests 24 | run: | 25 | pytest -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Liiban Mohamud 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | ![Scikit-JAX logo](assets/logo_resized.png)
3 |
4 |
5 | # Scikit-JAX: Classical Machine Learning on the GPU
6 |
7 | Welcome to **Scikit-JAX**, a machine learning library that leverages GPUs through JAX to provide efficient, scalable implementations of classical machine learning algorithms, optimized for performance and ease of use.
8 |
9 | ## Features
10 |
11 | - **Linear Regression**: Implemented with a choice of weight initialization methods and optional dropout regularization.
12 | - **KMeans**: Partitions data points into a user-specified number of clusters.
13 | - **Principal Component Analysis (PCA)**: Dimensionality reduction that simplifies data while preserving its essential structure.
14 | - **Multinomial Naive Bayes**: Classifier suited to discrete data, such as text classification tasks.
15 | - **Gaussian Naive Bayes**: Classifier for continuous data, assuming features are normally distributed within each class.
16 |
17 | ## Installation
18 |
19 | To install Scikit-JAX, use pip. The package is available on PyPI:
20 |
21 | ```bash
22 | pip install scikit-jax==0.0.3dev1
23 | ```
24 |
25 | ## Usage
26 |
27 | Here is a quick guide to the key components of Scikit-JAX; sketches for PCA and Multinomial Naive Bayes appear under Further Examples below.
28 |
29 | ### Linear Regression
30 | ```python
31 | from skjax.linear_model import LinearRegression
32 |
33 | # Initialize the model
34 | model = LinearRegression(weights_init='xavier', epochs=100, learning_rate=0.01)
35 |
36 | # Fit the model
37 | model.fit(X_train, y_train)
38 |
39 | # Make predictions
40 | predictions = model.predict(X_test)
41 |
42 | # Plot losses
43 | model.plot_losses()
44 | ```
45 |
46 | ### K-Means
47 | ```python
48 | from skjax.clustering import KMeans
49 |
50 | # Initialize the model
51 | kmeans = KMeans(num_clusters=3)
52 |
53 | # Fit the model
54 | kmeans.fit(X_train)
55 | ```
56 |
57 | ### Gaussian Naive Bayes
58 | ```python
59 | from skjax.naive_bayes import GaussianNaiveBayes
60 |
61 | # Initialize the model
62 | nb = GaussianNaiveBayes()
63 |
64 | # Fit the model
65 | nb.fit(X_train, y_train)
66 |
67 | # Make predictions
68 | predictions = nb.predict(X_test)
69 | ```
70 |
71 | ## License
72 |
73 | Scikit-JAX is licensed under the [MIT License](LICENSE.txt).
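
## Further Examples

The feature list above also mentions PCA and Multinomial Naive Bayes. The sketches below show the expected interface, assuming `PCA` lives in `skjax.decomposition` and `MultinomialNaiveBayes` in `skjax.naive_bayes` (matching the module and test names in the source tree); argument names such as `num_components` follow the constructor used in the example notebooks and may differ in a given release.

### Principal Component Analysis (PCA)
```python
from skjax.decomposition import PCA

# Initialize the model with the number of components to keep
pca = PCA(num_components=2)

# Fit on the training data, then project it onto the principal components
pca.fit(X_train)
X_reduced = pca.transform(X_train)

# Map the reduced data back to the original feature space
X_restored = pca.inverse_transform(X_reduced)
```

### Multinomial Naive Bayes
```python
from skjax.naive_bayes import MultinomialNaiveBayes

# Initialize the model
mnb = MultinomialNaiveBayes()

# Fit on integer-encoded categorical features
mnb.fit(X_train, y_train)

# Make predictions
predictions = mnb.predict(X_test)
```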
74 | -------------------------------------------------------------------------------- /assets/logo_resized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/assets/logo_resized.png -------------------------------------------------------------------------------- /assets/scikit-jax-logo-1-blue_blackground.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/assets/scikit-jax-logo-1-blue_blackground.jpg -------------------------------------------------------------------------------- /assets/scikit-jax-logo-1-transparent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/assets/scikit-jax-logo-1-transparent.jpg -------------------------------------------------------------------------------- /assets/scikit-jax-logo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/assets/scikit-jax-logo1.jpg -------------------------------------------------------------------------------- /examples/data/Naive-Bayes-Classification-Data.csv: -------------------------------------------------------------------------------- 1 | glucose,bloodpressure,diabetes 2 | 40,85,0 3 | 40,92,0 4 | 45,63,1 5 | 45,80,0 6 | 40,73,1 7 | 45,82,0 8 | 40,85,0 9 | 30,63,1 10 | 65,65,1 11 | 45,82,0 12 | 35,73,1 13 | 45,90,0 14 | 50,68,1 15 | 40,93,0 16 | 35,80,1 17 | 50,70,1 18 | 40,73,1 19 | 40,67,1 20 | 40,75,1 21 | 40,80,1 22 | 40,72,1 23 | 40,88,0 24 | 40,78,1 25 | 45,98,0 26 | 40,88,0 27 | 60,67,1 28 | 40,85,0 29 | 40,88,0 30 | 45,78,0 31 | 55,73,1 32 | 45,77,1 33 | 50,68,1 34 | 45,77,1 35 | 40,85,0 36 | 45,70,1 37 | 45,72,1 38 | 45,90,0 39 | 40,65,1 40 | 45,88,0 41 | 45,88,0 42 | 40,68,1 43 | 40,73,1 44 | 45,88,0 45 | 45,78,0 46 | 45,85,0 47 | 40,83,0 48 | 40,63,1 49 | 45,73,1 50 | 45,90,0 51 | 45,87,0 52 | 40,90,0 53 | 45,93,0 54 | 50,73,1 55 | 40,68,1 56 | 50,68,1 57 | 50,90,0 58 | 50,75,1 59 | 50,85,0 60 | 45,83,0 61 | 50,65,1 62 | 45,80,0 63 | 40,75,1 64 | 35,77,1 65 | 55,68,1 66 | 45,85,0 67 | 45,87,0 68 | 25,82,1 69 | 40,90,0 70 | 45,82,0 71 | 45,80,0 72 | 45,88,0 73 | 40,87,0 74 | 45,70,1 75 | 45,88,0 76 | 45,88,0 77 | 50,67,1 78 | 45,93,0 79 | 45,90,0 80 | 55,77,1 81 | 35,62,1 82 | 30,78,1 83 | 50,82,0 84 | 45,83,0 85 | 40,83,0 86 | 40,82,0 87 | 50,70,1 88 | 45,92,0 89 | 40,90,0 90 | 40,85,0 91 | 40,73,1 92 | 50,88,0 93 | 30,80,1 94 | 40,85,1 95 | 30,75,1 96 | 45,85,0 97 | 40,85,0 98 | 60,80,1 99 | 40,90,0 100 | 35,65,1 101 | 40,87,0 102 | 45,87,0 103 | 45,92,0 104 | 45,95,0 105 | 50,67,1 106 | 40,90,0 107 | 45,85,0 108 | 45,75,1 109 | 50,73,1 110 | 40,70,1 111 | 40,92,0 112 | 30,72,1 113 | 50,87,0 114 | 40,93,0 115 | 55,60,1 116 | 40,82,1 117 | 40,95,0 118 | 40,72,1 119 | 60,68,1 120 | 45,92,0 121 | 40,70,1 122 | 45,97,0 123 | 50,83,0 124 | 60,63,1 125 | 35,65,1 126 | 45,90,0 127 | 40,65,1 128 | 40,95,0 129 | 30,68,1 130 | 45,83,0 131 | 45,92,0 132 | 45,87,0 133 | 45,72,1 134 | 45,83,0 135 | 45,85,0 136 | 45,88,0 137 | 55,68,1 138 | 60,65,1 139 | 40,73,1 140 | 35,73,1 141 | 40,83,0 142 | 50,67,1 143 | 45,80,0 144 | 55,68,1 145 | 45,80,0 146 | 50,80,0 147 | 40,92,0 148 | 50,77,1 149 | 50,92,0 150 | 45,85,0 151 | 
65,70,1 152 | 45,77,1 153 | 45,82,0 154 | 40,95,0 155 | 45,75,1 156 | 45,78,1 157 | 45,87,0 158 | 50,83,0 159 | 45,92,0 160 | 40,73,1 161 | 45,88,0 162 | 45,72,1 163 | 40,67,1 164 | 45,78,0 165 | 40,78,0 166 | 45,87,0 167 | 50,78,1 168 | 45,75,1 169 | 55,73,1 170 | 45,87,0 171 | 45,80,0 172 | 45,73,1 173 | 45,93,0 174 | 45,73,1 175 | 40,87,0 176 | 40,87,0 177 | 45,87,0 178 | 25,72,1 179 | 45,88,0 180 | 55,68,1 181 | 40,90,0 182 | 40,93,0 183 | 45,82,0 184 | 35,77,1 185 | 50,72,1 186 | 40,100,0 187 | 25,83,1 188 | 55,72,1 189 | 45,82,0 190 | 40,75,1 191 | 35,80,1 192 | 40,90,0 193 | 45,90,0 194 | 45,85,0 195 | 55,63,1 196 | 45,92,0 197 | 40,87,0 198 | 45,93,0 199 | 45,85,0 200 | 35,70,1 201 | 55,73,1 202 | 50,67,1 203 | 50,65,1 204 | 55,75,1 205 | 45,85,0 206 | 35,68,1 207 | 40,80,0 208 | 40,63,1 209 | 40,90,0 210 | 50,90,0 211 | 25,67,1 212 | 55,67,1 213 | 60,67,1 214 | 50,92,0 215 | 45,80,0 216 | 50,77,1 217 | 40,88,0 218 | 45,93,0 219 | 40,93,0 220 | 45,77,0 221 | 40,77,0 222 | 55,75,1 223 | 45,87,0 224 | 60,67,1 225 | 45,90,0 226 | 30,73,1 227 | 45,87,0 228 | 40,88,0 229 | 45,95,0 230 | 45,77,1 231 | 55,68,1 232 | 45,83,0 233 | 40,90,0 234 | 40,83,0 235 | 45,97,0 236 | 45,85,0 237 | 45,83,0 238 | 45,72,1 239 | 45,68,1 240 | 40,93,0 241 | 40,87,0 242 | 40,87,0 243 | 55,70,1 244 | 60,68,1 245 | 50,90,0 246 | 40,90,0 247 | 40,78,1 248 | 50,80,1 249 | 55,75,1 250 | 40,72,1 251 | 50,73,1 252 | 45,58,1 253 | 55,68,1 254 | 40,90,0 255 | 55,72,1 256 | 35,82,0 257 | 40,70,1 258 | 55,57,1 259 | 50,80,0 260 | 45,83,0 261 | 45,85,0 262 | 45,72,1 263 | 40,75,1 264 | 40,85,0 265 | 40,83,0 266 | 40,72,1 267 | 35,63,1 268 | 20,70,1 269 | 40,92,0 270 | 45,87,0 271 | 45,83,0 272 | 55,67,1 273 | 45,80,0 274 | 45,75,1 275 | 40,70,1 276 | 40,88,0 277 | 35,78,1 278 | 55,63,1 279 | 40,82,0 280 | 40,65,1 281 | 45,90,0 282 | 40,72,1 283 | 55,62,1 284 | 50,83,0 285 | 50,58,1 286 | 45,72,1 287 | 50,68,1 288 | 65,60,1 289 | 25,73,1 290 | 35,68,1 291 | 45,58,1 292 | 45,92,0 293 | 45,67,1 294 | 50,72,1 295 | 40,87,0 296 | 35,77,1 297 | 50,65,1 298 | 60,77,1 299 | 40,68,1 300 | 45,88,1 301 | 50,77,1 302 | 45,82,0 303 | 50,73,1 304 | 35,68,1 305 | 40,92,0 306 | 55,65,1 307 | 45,83,0 308 | 50,67,1 309 | 40,68,1 310 | 45,83,0 311 | 45,90,0 312 | 45,83,0 313 | 40,72,1 314 | 45,78,1 315 | 55,68,1 316 | 35,82,1 317 | 50,87,0 318 | 50,83,0 319 | 45,73,0 320 | 45,83,0 321 | 30,73,1 322 | 45,83,0 323 | 40,68,1 324 | 35,77,1 325 | 45,85,0 326 | 45,78,1 327 | 25,73,1 328 | 40,88,0 329 | 45,82,1 330 | 60,68,1 331 | 70,65,1 332 | 40,87,0 333 | 35,70,1 334 | 55,68,1 335 | 35,90,0 336 | 40,65,1 337 | 40,65,1 338 | 55,60,1 339 | 50,83,0 340 | 40,87,0 341 | 40,82,0 342 | 45,85,0 343 | 40,85,0 344 | 55,68,1 345 | 40,83,0 346 | 50,88,0 347 | 40,88,0 348 | 50,85,0 349 | 35,62,1 350 | 40,75,1 351 | 40,75,1 352 | 45,90,0 353 | 60,85,1 354 | 50,85,0 355 | 40,82,0 356 | 40,63,1 357 | 40,88,0 358 | 30,82,1 359 | 45,83,0 360 | 50,77,1 361 | 45,97,0 362 | 45,93,0 363 | 50,68,1 364 | 40,87,0 365 | 45,87,0 366 | 40,67,1 367 | 50,85,0 368 | 50,90,0 369 | 35,70,1 370 | 45,92,0 371 | 30,78,1 372 | 45,88,0 373 | 55,70,1 374 | 45,88,0 375 | 50,78,1 376 | 40,70,1 377 | 45,73,1 378 | 40,88,0 379 | 35,75,1 380 | 45,82,0 381 | 50,68,1 382 | 35,77,1 383 | 40,73,1 384 | 45,75,1 385 | 45,82,0 386 | 45,78,1 387 | 40,70,1 388 | 45,88,0 389 | 35,77,1 390 | 50,65,1 391 | 40,90,0 392 | 45,83,0 393 | 50,67,1 394 | 45,78,1 395 | 45,82,0 396 | 45,85,0 397 | 40,70,1 398 | 45,68,1 399 | 45,73,1 400 | 40,82,0 401 | 50,78,1 402 | 50,92,0 403 | 45,82,0 404 | 45,92,0 405 
| 55,65,1 406 | 40,72,1 407 | 50,85,0 408 | 50,62,1 409 | 45,92,0 410 | 55,72,1 411 | 40,83,0 412 | 45,67,1 413 | 55,65,1 414 | 45,73,1 415 | 50,85,0 416 | 45,90,0 417 | 40,72,1 418 | 50,92,0 419 | 45,87,0 420 | 45,75,1 421 | 45,78,0 422 | 55,73,1 423 | 35,90,0 424 | 40,70,1 425 | 40,88,0 426 | 45,95,0 427 | 40,77,1 428 | 25,88,1 429 | 40,88,0 430 | 65,62,1 431 | 40,85,0 432 | 30,83,1 433 | 50,52,1 434 | 50,75,1 435 | 35,78,1 436 | 45,87,0 437 | 40,93,0 438 | 45,82,0 439 | 45,67,1 440 | 55,70,1 441 | 40,82,0 442 | 40,90,0 443 | 45,67,1 444 | 40,80,1 445 | 40,60,1 446 | 40,83,0 447 | 40,88,0 448 | 50,90,0 449 | 50,83,0 450 | 50,68,1 451 | 45,82,0 452 | 55,70,1 453 | 35,72,1 454 | 50,87,0 455 | 45,90,0 456 | 45,90,0 457 | 45,92,0 458 | 45,68,1 459 | 45,90,0 460 | 50,88,0 461 | 45,92,0 462 | 45,88,0 463 | 45,80,0 464 | 55,72,1 465 | 35,83,0 466 | 50,85,0 467 | 50,70,1 468 | 40,83,0 469 | 40,92,0 470 | 50,88,0 471 | 40,100,0 472 | 40,77,1 473 | 35,70,1 474 | 35,85,1 475 | 45,88,0 476 | 40,73,1 477 | 40,65,1 478 | 40,97,0 479 | 35,87,1 480 | 40,83,0 481 | 50,75,1 482 | 45,78,1 483 | 50,95,0 484 | 50,90,0 485 | 40,78,1 486 | 30,75,1 487 | 45,67,1 488 | 50,83,0 489 | 45,80,0 490 | 45,85,0 491 | 60,68,1 492 | 55,67,1 493 | 30,82,1 494 | 45,92,0 495 | 45,62,1 496 | 40,88,0 497 | 35,78,1 498 | 40,75,1 499 | 30,70,1 500 | 30,78,1 501 | 30,78,1 502 | 45,85,1 503 | 50,60,1 504 | 40,92,0 505 | 45,73,1 506 | 40,78,0 507 | 50,72,1 508 | 45,73,1 509 | 40,88,0 510 | 45,90,0 511 | 40,83,0 512 | 45,73,1 513 | 45,68,1 514 | 55,65,1 515 | 45,85,0 516 | 50,63,1 517 | 40,70,1 518 | 50,65,1 519 | 50,75,1 520 | 40,88,0 521 | 45,77,1 522 | 40,93,0 523 | 45,87,0 524 | 45,77,0 525 | 40,87,0 526 | 35,73,1 527 | 40,75,1 528 | 45,87,0 529 | 30,77,1 530 | 40,72,1 531 | 45,77,1 532 | 40,93,0 533 | 35,68,1 534 | 40,75,1 535 | 25,70,1 536 | 40,85,0 537 | 50,77,1 538 | 45,88,0 539 | 45,78,1 540 | 50,68,1 541 | 40,65,1 542 | 50,78,0 543 | 40,60,1 544 | 40,82,0 545 | 40,82,1 546 | 50,80,1 547 | 50,83,0 548 | 35,87,1 549 | 40,92,0 550 | 45,88,0 551 | 55,68,1 552 | 50,80,0 553 | 50,72,1 554 | 65,72,1 555 | 40,85,0 556 | 50,63,1 557 | 45,92,0 558 | 30,78,1 559 | 50,88,0 560 | 40,85,0 561 | 50,90,0 562 | 45,73,1 563 | 50,60,1 564 | 45,85,0 565 | 55,70,1 566 | 35,70,1 567 | 50,80,0 568 | 45,87,0 569 | 45,65,1 570 | 45,70,1 571 | 45,85,0 572 | 40,63,1 573 | 40,87,1 574 | 45,83,0 575 | 50,87,0 576 | 45,82,1 577 | 50,90,0 578 | 50,80,0 579 | 35,88,0 580 | 40,87,0 581 | 45,83,0 582 | 45,80,0 583 | 40,83,1 584 | 45,87,0 585 | 40,95,0 586 | 40,88,0 587 | 45,88,0 588 | 45,83,0 589 | 45,78,1 590 | 45,57,1 591 | 50,83,0 592 | 45,82,0 593 | 45,57,1 594 | 45,83,0 595 | 45,77,1 596 | 30,83,1 597 | 50,75,0 598 | 40,87,0 599 | 35,62,1 600 | 45,78,0 601 | 55,75,1 602 | 45,88,0 603 | 45,68,1 604 | 30,73,1 605 | 45,73,1 606 | 50,73,1 607 | 40,68,1 608 | 35,80,1 609 | 40,85,0 610 | 45,73,1 611 | 40,92,0 612 | 35,77,1 613 | 40,93,0 614 | 55,78,1 615 | 45,73,1 616 | 55,67,1 617 | 45,68,1 618 | 45,93,0 619 | 40,92,0 620 | 50,92,0 621 | 65,63,1 622 | 20,80,1 623 | 45,70,1 624 | 50,88,0 625 | 45,72,1 626 | 65,73,1 627 | 50,85,0 628 | 30,80,1 629 | 45,63,1 630 | 45,87,0 631 | 40,83,0 632 | 45,80,0 633 | 55,75,1 634 | 50,78,0 635 | 45,88,0 636 | 45,77,1 637 | 45,90,0 638 | 50,88,0 639 | 40,87,0 640 | 50,83,0 641 | 45,83,0 642 | 40,92,0 643 | 45,67,1 644 | 50,85,1 645 | 70,62,1 646 | 40,87,0 647 | 45,88,0 648 | 35,77,1 649 | 45,85,0 650 | 45,90,0 651 | 45,88,0 652 | 55,60,1 653 | 45,70,1 654 | 45,75,1 655 | 45,92,0 656 | 50,82,0 657 | 50,80,1 658 | 40,67,1 
659 | 45,93,0 660 | 45,92,0 661 | 40,87,0 662 | 45,87,1 663 | 50,85,0 664 | 55,73,1 665 | 45,73,1 666 | 35,73,1 667 | 45,88,0 668 | 50,90,0 669 | 30,80,1 670 | 25,73,1 671 | 35,80,1 672 | 40,65,1 673 | 45,73,0 674 | 35,75,1 675 | 45,70,1 676 | 60,70,1 677 | 55,75,1 678 | 45,68,1 679 | 35,83,1 680 | 45,67,1 681 | 45,85,0 682 | 55,63,1 683 | 45,88,1 684 | 45,57,1 685 | 50,58,1 686 | 50,75,1 687 | 45,92,0 688 | 40,85,0 689 | 45,80,0 690 | 40,88,0 691 | 50,83,0 692 | 45,90,0 693 | 45,82,1 694 | 40,82,1 695 | 45,75,1 696 | 50,93,0 697 | 45,75,1 698 | 45,93,0 699 | 50,72,1 700 | 50,73,1 701 | 40,82,0 702 | 40,90,0 703 | 35,67,1 704 | 40,93,0 705 | 50,70,1 706 | 50,85,0 707 | 40,90,0 708 | 40,80,1 709 | 40,87,0 710 | 50,85,0 711 | 40,88,0 712 | 40,82,1 713 | 50,85,0 714 | 40,85,1 715 | 45,68,1 716 | 60,60,1 717 | 50,83,0 718 | 50,67,1 719 | 50,78,0 720 | 40,70,1 721 | 40,70,1 722 | 45,77,0 723 | 45,88,0 724 | 45,87,0 725 | 50,62,1 726 | 50,63,1 727 | 40,72,1 728 | 40,90,0 729 | 65,73,1 730 | 55,73,1 731 | 40,67,1 732 | 45,80,0 733 | 45,90,0 734 | 45,82,0 735 | 40,87,0 736 | 45,88,0 737 | 40,67,1 738 | 45,85,0 739 | 50,88,0 740 | 60,75,1 741 | 45,60,1 742 | 35,72,1 743 | 50,77,0 744 | 40,75,1 745 | 55,73,1 746 | 40,63,1 747 | 45,90,0 748 | 45,92,0 749 | 40,98,0 750 | 40,83,0 751 | 60,67,1 752 | 45,88,0 753 | 40,72,1 754 | 45,82,1 755 | 40,82,1 756 | 45,87,0 757 | 45,88,0 758 | 55,67,1 759 | 40,67,1 760 | 45,85,0 761 | 60,75,1 762 | 40,80,1 763 | 45,68,1 764 | 40,93,0 765 | 45,83,0 766 | 45,70,1 767 | 25,73,1 768 | 55,93,0 769 | 55,67,1 770 | 55,62,1 771 | 60,68,1 772 | 50,78,1 773 | 55,78,1 774 | 45,88,0 775 | 50,85,0 776 | 35,83,1 777 | 45,83,0 778 | 45,75,1 779 | 50,70,1 780 | 45,85,0 781 | 50,87,0 782 | 45,78,1 783 | 40,93,0 784 | 30,78,1 785 | 50,70,1 786 | 35,90,0 787 | 45,83,0 788 | 60,62,1 789 | 45,92,0 790 | 40,62,1 791 | 50,75,1 792 | 40,65,1 793 | 50,90,0 794 | 30,75,1 795 | 35,67,1 796 | 40,70,1 797 | 40,78,0 798 | 50,93,0 799 | 45,87,0 800 | 45,90,0 801 | 45,85,0 802 | 40,77,1 803 | 50,95,0 804 | 45,90,0 805 | 35,80,1 806 | 45,83,0 807 | 45,90,0 808 | 45,95,0 809 | 35,73,1 810 | 60,70,1 811 | 45,92,0 812 | 45,82,0 813 | 45,70,1 814 | 45,77,1 815 | 30,70,1 816 | 40,85,0 817 | 45,67,1 818 | 55,68,1 819 | 45,80,1 820 | 55,72,1 821 | 35,67,1 822 | 50,78,1 823 | 35,82,0 824 | 50,77,1 825 | 45,92,0 826 | 45,85,0 827 | 45,75,1 828 | 50,88,0 829 | 40,87,0 830 | 40,73,1 831 | 45,63,1 832 | 50,67,1 833 | 55,73,1 834 | 35,82,0 835 | 45,85,0 836 | 45,85,0 837 | 40,65,1 838 | 40,85,0 839 | 45,80,0 840 | 40,87,0 841 | 55,77,1 842 | 40,67,1 843 | 45,82,0 844 | 50,78,1 845 | 50,83,0 846 | 50,65,1 847 | 40,87,0 848 | 45,93,0 849 | 50,88,0 850 | 45,85,0 851 | 45,92,0 852 | 45,68,1 853 | 55,72,1 854 | 40,77,1 855 | 50,65,1 856 | 40,75,1 857 | 40,80,0 858 | 40,92,0 859 | 40,75,1 860 | 45,83,1 861 | 45,87,0 862 | 35,78,1 863 | 50,85,0 864 | 50,65,1 865 | 40,88,0 866 | 45,73,1 867 | 45,87,0 868 | 40,87,0 869 | 50,92,0 870 | 40,87,0 871 | 50,85,0 872 | 45,70,1 873 | 35,83,0 874 | 40,88,0 875 | 20,73,1 876 | 45,60,1 877 | 45,88,0 878 | 55,77,1 879 | 40,87,0 880 | 40,87,0 881 | 45,82,0 882 | 50,80,1 883 | 50,95,0 884 | 40,67,1 885 | 45,67,1 886 | 45,85,0 887 | 45,78,0 888 | 40,88,0 889 | 35,72,1 890 | 45,80,0 891 | 45,85,0 892 | 45,88,0 893 | 40,87,0 894 | 35,70,1 895 | 50,82,0 896 | 45,87,0 897 | 45,80,0 898 | 40,78,0 899 | 45,80,0 900 | 45,72,1 901 | 45,77,1 902 | 40,88,0 903 | 45,87,0 904 | 45,90,0 905 | 45,83,0 906 | 45,88,0 907 | 35,88,0 908 | 60,63,1 909 | 40,80,1 910 | 45,87,0 911 | 45,90,0 912 | 
35,82,0 913 | 40,72,1 914 | 45,72,1 915 | 40,90,0 916 | 55,87,0 917 | 30,77,1 918 | 45,85,0 919 | 45,72,1 920 | 45,87,0 921 | 40,87,0 922 | 40,70,1 923 | 50,83,0 924 | 40,80,0 925 | 45,83,0 926 | 50,82,0 927 | 45,78,0 928 | 45,85,0 929 | 45,90,0 930 | 45,88,0 931 | 45,67,1 932 | 30,75,1 933 | 35,77,1 934 | 40,88,0 935 | 40,75,1 936 | 45,95,0 937 | 40,72,1 938 | 40,65,1 939 | 45,90,0 940 | 40,77,1 941 | 45,90,0 942 | 40,88,0 943 | 45,75,1 944 | 45,73,1 945 | 35,75,1 946 | 55,65,1 947 | 40,95,0 948 | 45,87,0 949 | 45,67,1 950 | 50,80,1 951 | 50,67,1 952 | 45,77,0 953 | 45,92,0 954 | 45,73,1 955 | 40,75,0 956 | 35,70,1 957 | 45,50,1 958 | 45,87,0 959 | 35,73,1 960 | 40,83,1 961 | 45,88,0 962 | 45,73,1 963 | 40,72,1 964 | 45,88,0 965 | 50,88,0 966 | 40,75,1 967 | 50,95,0 968 | 35,68,1 969 | 45,75,1 970 | 50,90,0 971 | 40,88,0 972 | 40,78,1 973 | 55,68,1 974 | 45,85,0 975 | 60,63,1 976 | 40,78,0 977 | 40,92,0 978 | 50,65,1 979 | 45,82,0 980 | 45,97,0 981 | 50,83,1 982 | 45,73,1 983 | 45,75,1 984 | 50,67,1 985 | 45,65,1 986 | 50,73,1 987 | 40,73,1 988 | 45,97,0 989 | 45,90,0 990 | 45,90,0 991 | 45,68,1 992 | 45,87,0 993 | 40,83,0 994 | 40,83,0 995 | 40,60,1 996 | 45,82,0 997 | -------------------------------------------------------------------------------- /examples/data/multiple_linear_regression_dataset.csv: -------------------------------------------------------------------------------- 1 | age,experience,income 2 | 25,1,30450 3 | 30,3,35670 4 | 47,2,31580 5 | 32,5,40130 6 | 43,10,47830 7 | 51,7,41630 8 | 28,5,41340 9 | 33,4,37650 10 | 37,5,40250 11 | 39,8,45150 12 | 29,1,27840 13 | 47,9,46110 14 | 54,5,36720 15 | 51,4,34800 16 | 44,12,51300 17 | 41,6,38900 18 | 58,17,63600 19 | 23,1,30870 20 | 44,9,44190 21 | 37,10,48700 22 | -------------------------------------------------------------------------------- /examples/notebooks/2_gaussian_nb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 109, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax\n", 10 | "import jax.numpy as jnp\n", 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "\n", 15 | "from functools import partial\n", 16 | "from jax import vmap, jit, tree\n", 17 | "\n", 18 | "from sklearn.model_selection import train_test_split" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 110, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "\n", 31 | "Case 3: Dataset with Redundant Features\n", 32 | " Feature1 Feature2 Feature3 Target\n", 33 | "0 A W L 0\n", 34 | "1 B W M 0\n", 35 | "2 C W N 1\n", 36 | "3 A W L 0\n", 37 | "4 B W M 1\n", 38 | "5 C W N 1\n", 39 | "6 A W L 2\n", 40 | "7 B W M 2\n", 41 | "8 C W N 2\n", 42 | "9 A W L 0\n", 43 | "10 B W M 0\n", 44 | "11 C W N 1\n", 45 | "12 A W L 0\n", 46 | "13 B W M 1\n", 47 | "14 C W N 1\n", 48 | "15 A W L 2\n", 49 | "16 B W M 2\n", 50 | "17 C W N 2\n", 51 | "18 A W L 0\n", 52 | "19 B W M 0\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "# Case 3: Dataset with Redundant Features\n", 58 | "data3 = {\n", 59 | " 'Feature1': ['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A',\n", 60 | " 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A', 'B'],\n", 61 | " 'Feature2': ['W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',\n", 62 | " 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W'], # Redundant feature\n", 63 | " 'Feature3': ['L', 'M', 'N', 
'L', 'M', 'N', 'L', 'M', 'N', 'L',\n", 64 | " 'M', 'N', 'L', 'M', 'N', 'L', 'M', 'N', 'L', 'M'],\n", 65 | " 'Target': [0, 0, 1, 0, 1, 1, 2, 2, 2, 0,\n", 66 | " 0, 1, 0, 1, 1, 2, 2, 2, 0, 0]\n", 67 | "}\n", 68 | "\n", 69 | "data = pd.DataFrame(data3)\n", 70 | "print(\"\\nCase 3: Dataset with Redundant Features\")\n", 71 | "print(data)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 111, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | " Feature1 Feature2 Feature3 Target\n", 84 | "0 0 0 0 0\n", 85 | "1 1 0 1 0\n", 86 | "2 2 0 2 1\n", 87 | "3 0 0 0 0\n", 88 | "4 1 0 1 1\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "from sklearn.preprocessing import LabelEncoder\n", 94 | "\n", 95 | "label_encoders = {}\n", 96 | "for column in ['Feature1', 'Feature2', 'Feature3']:\n", 97 | " le = LabelEncoder()\n", 98 | " data[column] = le.fit_transform(data[column])\n", 99 | " label_encoders[column] = le\n", 100 | "\n", 101 | "# Display the updated DataFrame\n", 102 | "print(data.head())" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 112, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "X = jnp.asarray(data.drop(columns=['Target']).to_numpy(dtype=jnp.int32))\n", 112 | "y = jnp.asarray(data['Target'].to_numpy(dtype=jnp.int32))" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 113, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=12)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 114, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "((16, 3), (16,), (20, 3), (20,))" 133 | ] 134 | }, 135 | "execution_count": 114, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "X_train.shape, y_train.shape, X.shape, y.shape" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 115, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "name": "stdout", 151 | "output_type": "stream", 152 | "text": [ 153 | "Classes: [0, 1, 2]\n", 154 | "Categories in each feature column: [Array([0, 1, 2], dtype=int32), Array([0], dtype=int32), Array([0, 1, 2], dtype=int32)]\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "unique_classes = jnp.unique(y) \n", 160 | "unique_categories = list(map(jnp.unique, X.T))\n", 161 | "\n", 162 | "print(f'Classes: {unique_classes.tolist()}')\n", 163 | "print(f'Categories in each feature column: {unique_categories}')" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 116, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "data": { 173 | "text/plain": [ 174 | "Array([0.375 , 0.3125, 0.3125], dtype=float32)" 175 | ] 176 | }, 177 | "execution_count": 116, 178 | "metadata": {}, 179 | "output_type": "execute_result" 180 | } 181 | ], 182 | "source": [ 183 | "@jit\n", 184 | "def compute_priors(y):\n", 185 | " return jnp.unique(y, return_counts=True, size=len(unique_classes))[1] / jnp.sum(jnp.unique(y, return_counts=True, size=len(unique_classes))[1])\n", 186 | "\n", 187 | "priors = compute_priors(y_train)\n", 188 | "priors" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 117, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "indices_for_the_classes = [jnp.where(y_train == class_) for class_ in unique_classes]\n", 
198 | "\n", 199 | "def restructure_matrix_into_blocks(X:jax.Array):\n", 200 | " @jit\n", 201 | " def restructure_by_indices(indices:jax.Array):\n", 202 | " return X[indices]\n", 203 | " return restructure_by_indices\n", 204 | "\n", 205 | "X_train_restructured = tree.flatten(tree.map(restructure_matrix_into_blocks(X_train), indices_for_the_classes))[0]" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 118, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "text/plain": [ 216 | "Array([[1, 0, 1],\n", 217 | " [1, 0, 1],\n", 218 | " [0, 0, 0],\n", 219 | " [0, 0, 0],\n", 220 | " [0, 0, 0],\n", 221 | " [0, 0, 0]], dtype=int32)" 222 | ] 223 | }, 224 | "execution_count": 118, 225 | "metadata": {}, 226 | "output_type": "execute_result" 227 | } 228 | ], 229 | "source": [ 230 | "X_train_restructured[0]" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 119, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | "[Array([0.6666667 , 0.33333334], dtype=float32), Array([1.], dtype=float32), Array([0.6666667 , 0.33333334], dtype=float32)]\n" 243 | ] 244 | } 245 | ], 246 | "source": [ 247 | "def return_likelihoods_for_feature_column(column:jax.Array):\n", 248 | " counts_of_feature_in_block = jnp.unique(column, return_counts=True)[1]\n", 249 | " return counts_of_feature_in_block / jnp.sum(counts_of_feature_in_block)\n", 250 | "\n", 251 | "likelihoods_for_block_0 = list(map(return_likelihoods_for_feature_column, X_train_restructured[0].T))\n", 252 | "print(likelihoods_for_block_0)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 120, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "data": { 262 | "text/plain": [ 263 | "[[Array([0.6666667 , 0.33333334], dtype=float32),\n", 264 | " Array([1.], dtype=float32),\n", 265 | " Array([0.6666667 , 0.33333334], dtype=float32)],\n", 266 | " [Array([0.2, 0.8], dtype=float32),\n", 267 | " Array([1.], dtype=float32),\n", 268 | " Array([0.2, 0.8], dtype=float32)],\n", 269 | " [Array([0.4, 0.2, 0.4], dtype=float32),\n", 270 | " Array([1.], dtype=float32),\n", 271 | " Array([0.4, 0.2, 0.4], dtype=float32)]]" 272 | ] 273 | }, 274 | "execution_count": 120, 275 | "metadata": {}, 276 | "output_type": "execute_result" 277 | } 278 | ], 279 | "source": [ 280 | "def compute_likelihoods_for_blocks(block:jax.Array):\n", 281 | " return list(map(return_likelihoods_for_feature_column, block.T))\n", 282 | "\n", 283 | "likelihoods = tree.map(compute_likelihoods_for_blocks, X_train_restructured)\n", 284 | "likelihoods" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 121, 290 | "metadata": {}, 291 | "outputs": [ 292 | { 293 | "data": { 294 | "text/plain": [ 295 | "Array([[1, 0, 1],\n", 296 | " [1, 0, 1],\n", 297 | " [0, 0, 0],\n", 298 | " [1, 0, 1]], dtype=int32)" 299 | ] 300 | }, 301 | "execution_count": 121, 302 | "metadata": {}, 303 | "output_type": "execute_result" 304 | } 305 | ], 306 | "source": [ 307 | "X_test" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 122, 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "data": { 317 | "text/plain": [ 318 | "(Array([1., 1., 1., 1.], dtype=float32), Array([0, 0, 0, 0], dtype=int32))" 319 | ] 320 | }, 321 | "execution_count": 122, 322 | "metadata": {}, 323 | "output_type": "execute_result" 324 | } 325 | ], 326 | "source": [ 327 | "block = 0\n", 328 | "j = 1\n", 329 | "x_j = X_test.T[j]\n", 330 | 
"likelihoods[block][j][x_j], x_j" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 123, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "data": { 340 | "text/plain": [ 341 | "Array(0.25, dtype=float32)" 342 | ] 343 | }, 344 | "execution_count": 123, 345 | "metadata": {}, 346 | "output_type": "execute_result" 347 | } 348 | ], 349 | "source": [ 350 | "def retrieve_likelihood_for_block_i_feature_j_xij(x_ij:jax.Array, i:int, j:int):\n", 351 | " return likelihoods[i][j][x_ij]\n", 352 | "\n", 353 | "def retrieve_likelihood_for_block_i_feature_j(feature_column:jax.Array, i:int, j:int):\n", 354 | " return vmap(retrieve_likelihood_for_block_i_feature_j_xij, in_axes=(0, None, None))(feature_column, i, j)\n", 355 | "\n", 356 | "def retrieve_likelihood_for_block_i(X:jax.Array, i:int, j:int):\n", 357 | " return vmap(retrieve_likelihood_for_block_i_feature_j, in_axes=(0, None, None))(X, i, j)\n", 358 | "\n", 359 | "block_of_likelihoods = [[] for _ in range(len(unique_classes))]\n", 360 | "posteriors_array = []\n", 361 | "for i in range(unique_classes.shape[0]):\n", 362 | " for j in range(X_test.shape[1]):\n", 363 | " v_array = retrieve_likelihood_for_block_i_feature_j(X_test.T[j], i, j)\n", 364 | " block_of_likelihoods[i].append(v_array)\n", 365 | " posteriors = jnp.prod(jnp.vstack(block_of_likelihoods[i]), axis=0)*priors[i]\n", 366 | " posteriors_array.append(posteriors)\n", 367 | "\n", 368 | "y_pred = jnp.vstack(posteriors_array).argmin(axis=0)\n", 369 | "jnp.where(y_pred == y_test, 1, 0).mean()\n", 370 | "#jnp.prod(jnp.vstack(block_of_likelihoods[0]), axis=0)*priors[0], block_of_likelihoods[0]" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 124, 376 | "metadata": {}, 377 | "outputs": [ 378 | { 379 | "data": { 380 | "text/plain": [ 381 | "Array(0.5, dtype=float32)" 382 | ] 383 | }, 384 | "execution_count": 124, 385 | "metadata": {}, 386 | "output_type": "execute_result" 387 | } 388 | ], 389 | "source": [ 390 | "from sklearn.naive_bayes import MultinomialNB\n", 391 | "\n", 392 | "model_sk = MultinomialNB()\n", 393 | "model_sk_fitted = model_sk.fit(X_train, y_train)\n", 394 | "y_pred_sk = model_sk_fitted.predict(X_test)\n", 395 | "jnp.where(y_pred_sk == y_test, 1, 0).mean()" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 125, 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "data": { 405 | "text/plain": [ 406 | "(Array([2, 1, 2], dtype=int32),\n", 407 | " Array([[1, 0, 1],\n", 408 | " [1, 0, 1],\n", 409 | " [0, 0, 0],\n", 410 | " [1, 0, 1]], dtype=int32))" 411 | ] 412 | }, 413 | "execution_count": 125, 414 | "metadata": {}, 415 | "output_type": "execute_result" 416 | } 417 | ], 418 | "source": [ 419 | "class MultinomialNaiveBayes():\n", 420 | " def fit(self, X:jax.Array, y:jax.Array):\n", 421 | " # Computing priors\n", 422 | " self.priors = compute_priors(y)\n", 423 | "\n", 424 | " # Computing likelihoods\n", 425 | " self.unique_classes = jnp.unique(y)\n", 426 | " self.num_classes = len(self.unique_classes)\n", 427 | "\n", 428 | " indices_for_the_classes = [jnp.where(y == class_) for class_ in self.unique_classes]\n", 429 | "\n", 430 | " self.X_restructured = tree.flatten(tree.map(restructure_matrix_into_blocks(X), indices_for_the_classes))[0]\n", 431 | " \n", 432 | " self.blocks_of_likelihoods = tree.map(compute_likelihoods_for_blocks, self.X_restructured)\n", 433 | "\n", 434 | " return self\n", 435 | " \n", 436 | " def predict(self, X:jax.Array):\n", 437 | " self.log_posteriors = 
jnp.zeros(shape=(X.shape[1], self.num_classes))\n", 438 | " for i in range(self.unique_classes.shape[0]):\n", 439 | " for j in range(X.shape[1]):\n", 440 | " array_of_prior_and_likelihoods = jnp.hstack((priors[i], retrieve_likelihood_for_block_i_feature_j(X.T[j], i, j)))\n", 441 | " log_posterior = jnp.sum(jnp.log(array_of_prior_and_likelihoods))\n", 442 | " self.log_posteriors = self.log_posteriors.at[j, i].set(log_posterior)\n", 443 | "\n", 444 | " return self.log_posteriors.argmin(axis=1)\n", 445 | " \n", 446 | "model = MultinomialNaiveBayes()\n", 447 | "model_fitted = model.fit(X_train, y_train)\n", 448 | "model_fitted.predict(X_test), X_test" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [] 457 | } 458 | ], 459 | "metadata": { 460 | "kernelspec": { 461 | "display_name": ".venv", 462 | "language": "python", 463 | "name": "python3" 464 | }, 465 | "language_info": { 466 | "codemirror_mode": { 467 | "name": "ipython", 468 | "version": 3 469 | }, 470 | "file_extension": ".py", 471 | "mimetype": "text/x-python", 472 | "name": "python", 473 | "nbconvert_exporter": "python", 474 | "pygments_lexer": "ipython3", 475 | "version": "3.11.9" 476 | } 477 | }, 478 | "nbformat": 4, 479 | "nbformat_minor": 2 480 | } 481 | -------------------------------------------------------------------------------- /examples/notebooks/experimenting_decomposing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 59, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax\n", 10 | "import jax.numpy as jnp\n", 11 | "from jax.numpy.linalg import norm, svd\n", 12 | "\n", 13 | "jax.config.update('jax_enable_x64', False)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 60, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "def normalize_vector(v:jax.Array):\n", 23 | " return v / norm(v)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 61, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "def householder_reflection(x:jax.Array):\n", 33 | " w_0 = x[0] + jnp.sign(x[0]) * norm(x) if x[0] != 0 else norm(x)\n", 34 | " w = x.at[0].set(w_0)\n", 35 | " w = normalize_vector(w)\n", 36 | " return jnp.identity(n=len(w)) - 2*jnp.linalg.outer(w, w) " 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 62, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "def apply_householder_reflection(H, x):\n", 46 | " return jnp.dot(H, x)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 63, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "Array([3., 5.], dtype=float32)" 58 | ] 59 | }, 60 | "execution_count": 63, 61 | "metadata": {}, 62 | "output_type": "execute_result" 63 | } 64 | ], 65 | "source": [ 66 | "x = jnp.array([1.,2.])\n", 67 | "y = jnp.array([2.,3.])\n", 68 | "\n", 69 | "x + y" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 64, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "H = householder_reflection(x) " 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 65, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "text/plain": [ 89 | "Array([-2.2360678e+00, 1.7881393e-07], dtype=float32)" 90 | ] 91 | }, 92 | "execution_count": 65, 93 | "metadata": {}, 94 | "output_type": 
"execute_result" 95 | } 96 | ], 97 | "source": [ 98 | "apply_householder_reflection(H, x)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 66, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "Array([[1., 2., 3.],\n", 110 | " [4., 5., 6.],\n", 111 | " [7., 8., 9.]], dtype=float32)" 112 | ] 113 | }, 114 | "execution_count": 66, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "A = jnp.array([[1,2,3], [4,5,6], [7,8,9]], dtype=jnp.float32)\n", 121 | "A" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 67, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "Array([[-8.1240387e+00, -9.6011372e+00, -1.1078235e+01],\n", 133 | " [-1.7881393e-07, -8.5965633e-02, -1.7193133e-01],\n", 134 | " [ 1.7881393e-07, -9.0043950e-01, -1.8008795e+00]], dtype=float32)" 135 | ] 136 | }, 137 | "execution_count": 67, 138 | "metadata": {}, 139 | "output_type": "execute_result" 140 | } 141 | ], 142 | "source": [ 143 | "H1 = householder_reflection(A.T[0])\n", 144 | "A2 = jnp.dot(H1, A)\n", 145 | "A2" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 68, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "data": { 155 | "text/plain": [ 156 | "Array([[-8.1240387e+00, -9.6011372e+00, -1.1078235e+01],\n", 157 | " [-1.7881393e-07, -8.5965633e-02, -1.7193133e-01],\n", 158 | " [ 1.7881393e-07, -9.0043950e-01, -1.8008795e+00]], dtype=float32)" 159 | ] 160 | }, 161 | "execution_count": 68, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | } 165 | ], 166 | "source": [ 167 | "H1 @ A" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 69, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "def qr_decomposition(A:jax.Array):\n", 177 | " n, m = A.shape\n", 178 | "\n", 179 | " R = A.copy()\n", 180 | " Q = jnp.identity(n)\n", 181 | " for i in range(m-1):\n", 182 | " H_i = householder_reflection(R[i:, i:].T[0])\n", 183 | " H_i = jax.scipy.linalg.block_diag(jnp.eye(i), H_i) if i != 0 else H_i\n", 184 | " R = jnp.dot(H_i, R)\n", 185 | " Q = jnp.dot(Q, H_i.T)\n", 186 | " \n", 187 | " return Q, R" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 70, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "text/plain": [ 198 | "Array([[1.2999997 , 1.9999998 ],\n", 199 | " [1.9999998 , 0.99999976]], dtype=float32)" 200 | ] 201 | }, 202 | "execution_count": 70, 203 | "metadata": {}, 204 | "output_type": "execute_result" 205 | } 206 | ], 207 | "source": [ 208 | "Q, R = qr_decomposition(jnp.array([[1.3,2.],[2.,1.]]))\n", 209 | "jnp.dot(Q, R)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 71, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "A = jnp.array([[1,2,3], [4,5,6], [7,8,9]], dtype=jnp.float32)\n", 219 | "\n", 220 | "def bidiagonalisation_decomposition(A:jax.Array):\n", 221 | " n, m = A.shape\n", 222 | " Q_2 = jnp.identity(m)\n", 223 | " Q_1 = jnp.identity(n)\n", 224 | " B = A.copy()\n", 225 | "\n", 226 | " for i in range(min(n, m)):\n", 227 | " if i <= n-1:\n", 228 | " H_1 = householder_reflection(B[i:, i:].T[0])\n", 229 | " H_1 = jax.scipy.linalg.block_diag(jnp.eye(i), H_1)\n", 230 | " B = jnp.dot(H_1, B)\n", 231 | " Q_1 = jnp.dot(Q_1, H_1.T)\n", 232 | "\n", 233 | " if i < m-1:\n", 234 | " H_2 = householder_reflection(B[i:, i+1:][0])\n", 235 | " H_2 = 
jax.scipy.linalg.block_diag(jnp.eye(i+1), H_2)\n", 236 | " B = jnp.dot(B, H_2.T)\n", 237 | " Q_2 = jnp.dot(H_2, Q_2)\n", 238 | "\n", 239 | " return Q_1, B, Q_2" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 72, 245 | "metadata": {}, 246 | "outputs": [ 247 | { 248 | "data": { 249 | "text/plain": [ 250 | "Array([[1.0000021, 2.0000029, 2.9999998, 4.000003 , 5.0000033, 6.0000043],\n", 251 | " [6.000002 , 5.000002 , 4.0000024, 3.0000021, 2.000002 , 1.0000017]], dtype=float32)" 252 | ] 253 | }, 254 | "execution_count": 72, 255 | "metadata": {}, 256 | "output_type": "execute_result" 257 | } 258 | ], 259 | "source": [ 260 | "Q_1, B, Q_2 = bidiagonalisation_decomposition(jnp.array([[1,2,3,4,5,6], [6,5,4,3,2,1]], dtype=jnp.float32))\n", 261 | "Q_1 @ B @ Q_2" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 73, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "B = jnp.array([[1, 2, 0, 0],\n", 271 | " [3, 4, 5, 0],\n", 272 | " [0, 6, 7, 8],\n", 273 | " [0, 0, 9, 10]])" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 74, 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "name": "stdout", 283 | "output_type": "stream", 284 | "text": [ 285 | "Block 11:\n", 286 | "[[1 2]\n", 287 | " [3 4]]\n", 288 | "\n", 289 | "Block 12:\n", 290 | "[[0 0]\n", 291 | " [5 0]]\n", 292 | "\n", 293 | "Block 21:\n", 294 | "[[0 6]\n", 295 | " [0 0]]\n", 296 | "\n", 297 | "Block 22:\n", 298 | "[[ 7 8]\n", 299 | " [ 9 10]]\n" 300 | ] 301 | } 302 | ], 303 | "source": [ 304 | "# Define block size (example: 2x2)\n", 305 | "block_size = 2\n", 306 | "\n", 307 | "# Extract blocks\n", 308 | "block_11 = B[:block_size, :block_size] # Top-left block\n", 309 | "block_12 = B[:block_size, block_size:] # Top-right block\n", 310 | "block_21 = B[block_size:, :block_size] # Bottom-left block\n", 311 | "block_22 = B[block_size:, block_size:] # Bottom-right block\n", 312 | "\n", 313 | "print(\"Block 11:\")\n", 314 | "print(block_11)\n", 315 | "print(\"\\nBlock 12:\")\n", 316 | "print(block_12)\n", 317 | "print(\"\\nBlock 21:\")\n", 318 | "print(block_21)\n", 319 | "print(\"\\nBlock 22:\")\n", 320 | "print(block_22)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 75, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "def split_matrix_into_blocks(B, block_size:int=2):\n", 330 | "\n", 331 | " assert block_size > 0, f'The block size should be greater than 0. Instead {block_size}'\n", 332 | " assert block_size <= min(B.shape), f'The block size should be less than or equal to the size of the matrix. 
Instead {block_size} > {min(B.shape)}'\n", 333 | "\n", 334 | " block_11 = B[:block_size, :block_size] # Top-left block\n", 335 | " block_12 = B[:block_size, block_size:] # Top-right block\n", 336 | " block_21 = B[block_size:, :block_size] # Bottom-left block\n", 337 | " block_22 = B[block_size:, block_size:] # Bottom-right block\n", 338 | "\n", 339 | " return [block_11, block_12, block_21, block_22]" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 76, 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "data": { 349 | "text/plain": [ 350 | "[Array([[1, 2],\n", 351 | " [3, 4]], dtype=int32),\n", 352 | " Array([[0, 0],\n", 353 | " [5, 0]], dtype=int32),\n", 354 | " Array([[0, 6],\n", 355 | " [0, 0]], dtype=int32),\n", 356 | " Array([[ 7, 8],\n", 357 | " [ 9, 10]], dtype=int32)]" 358 | ] 359 | }, 360 | "execution_count": 76, 361 | "metadata": {}, 362 | "output_type": "execute_result" 363 | } 364 | ], 365 | "source": [ 366 | "split_matrix_into_blocks(B)" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 77, 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "def perform_svd_on_blocks(blocks: list):\n", 376 | " U_list, S_list, Vh_list = [], [], []\n", 377 | " singular_values = []\n", 378 | "\n", 379 | " # Perform SVD on each block\n", 380 | " for block in blocks:\n", 381 | " U_block, S_block, Vh_block = svd(block, full_matrices=False)\n", 382 | " U_list.append(U_block)\n", 383 | " singular_values.append(S_block)\n", 384 | " Vh_list.append(Vh_block)\n", 385 | "\n", 386 | " # Convert lists to arrays\n", 387 | " U = jnp.hstack(U_list)\n", 388 | " Vh = jnp.vstack(Vh_list)\n", 389 | "\n", 390 | " # Concatenate singular values and sort them if needed\n", 391 | " S = jnp.concatenate(singular_values)\n", 392 | " \n", 393 | " # Sorting singular values in descending order\n", 394 | " sorted_indices = jnp.argsort(S)[::-1]\n", 395 | " S_sorted = jnp.sort(S)[::-1]\n", 396 | " Vh_sorted = Vh[sorted_indices]\n", 397 | "\n", 398 | " # Construct diagonal matrix for singular values\n", 399 | " S_diag = jnp.diag(S_sorted)\n", 400 | "\n", 401 | " return U, S_diag, Vh_sorted" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 78, 407 | "metadata": {}, 408 | "outputs": [ 409 | { 410 | "data": { 411 | "text/plain": [ 412 | "Array([[17.146032 , 0. , 0. , 0. , 0. ,\n", 413 | " 0. , 0. , 0. ],\n", 414 | " [ 0. , 6. , 0. , 0. , 0. ,\n", 415 | " 0. , 0. , 0. ],\n", 416 | " [ 0. , 0. , 5.4649854 , 0. , 0. ,\n", 417 | " 0. , 0. , 0. ],\n", 418 | " [ 0. , 0. , 0. , 5. , 0. ,\n", 419 | " 0. , 0. , 0. ],\n", 420 | " [ 0. , 0. , 0. , 0. , 0.36596614,\n", 421 | " 0. , 0. , 0. ],\n", 422 | " [ 0. , 0. , 0. , 0. , 0. ,\n", 423 | " 0.11664554, 0. , 0. ],\n", 424 | " [ 0. , 0. , 0. , 0. , 0. ,\n", 425 | " 0. , 0. , 0. ],\n", 426 | " [ 0. , 0. , 0. , 0. , 0. ,\n", 427 | " 0. , 0. , 0. 
]], dtype=float32)" 428 | ] 429 | }, 430 | "execution_count": 78, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | ], 435 | "source": [ 436 | "B = jnp.array([[1, 2, 0, 0],\n", 437 | " [3, 4, 5, 0],\n", 438 | " [0, 6, 7, 8],\n", 439 | " [0, 0, 9, 10]])\n", 440 | "\n", 441 | "blocks = split_matrix_into_blocks(B)\n", 442 | "U,S,V = perform_svd_on_blocks(blocks)\n", 443 | "S" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 79, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "B_scaled = (B - jnp.mean(B))/jnp.std(B)" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 80, 458 | "metadata": {}, 459 | "outputs": [ 460 | { 461 | "data": { 462 | "text/plain": [ 463 | "(4, 4)" 464 | ] 465 | }, 466 | "execution_count": 80, 467 | "metadata": {}, 468 | "output_type": "execute_result" 469 | } 470 | ], 471 | "source": [ 472 | "U, S, Vt = svd(B)\n", 473 | "Vt.T.shape" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 81, 479 | "metadata": {}, 480 | "outputs": [], 481 | "source": [ 482 | "explained_variance = S**2/(len(B) - 1)" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": 82, 488 | "metadata": {}, 489 | "outputs": [], 490 | "source": [ 491 | "explained_variance_ratio = explained_variance / jnp.sum(explained_variance)" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 83, 497 | "metadata": {}, 498 | "outputs": [ 499 | { 500 | "data": { 501 | "text/plain": [ 502 | "Array([0.8529526 , 0.11453851, 0.0312005 , 0.00130828], dtype=float32)" 503 | ] 504 | }, 505 | "execution_count": 83, 506 | "metadata": {}, 507 | "output_type": "execute_result" 508 | } 509 | ], 510 | "source": [ 511 | "explained_variance_ratio" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": 84, 517 | "metadata": {}, 518 | "outputs": [], 519 | "source": [ 520 | "B_pca = jnp.dot(B, Vt.T[:, :3])" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 85, 526 | "metadata": {}, 527 | "outputs": [ 528 | { 529 | "data": { 530 | "text/plain": [ 531 | "(Array([[-0.03292128, 0.294175 , -0.17891642, 0.9382783 ],\n", 532 | " [-0.25488862, 0.7409277 , 0.6085924 , -0.12519355],\n", 533 | " [-0.65354645, 0.27927098, -0.66242605, -0.23680505],\n", 534 | " [-0.7119168 , -0.5352525 , 0.39849195, 0.21882348]], dtype=float32),\n", 535 | " Array([[-0.03292139, -0.29417548, -0.93827826, -0.17891629],\n", 536 | " [-0.25488853, -0.74092764, 0.1251936 , 0.6085925 ],\n", 537 | " [-0.6535463 , -0.27927062, 0.23680519, -0.6624261 ],\n", 538 | " [-0.7119166 , 0.53525215, -0.21882358, 0.39849186]], dtype=float32))" 539 | ] 540 | }, 541 | "execution_count": 85, 542 | "metadata": {}, 543 | "output_type": "execute_result" 544 | } 545 | ], 546 | "source": [ 547 | "U2 = jnp.linalg.eig(jnp.dot(B, B.T))[1].astype(jnp.float32)\n", 548 | "S2 = jnp.sqrt(jnp.linalg.eig(jnp.dot(B.T, B))[0].astype(jnp.float32))\n", 549 | "sorted_indices = jnp.argsort(S2, descending=True)\n", 550 | "U, U2" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 86, 556 | "metadata": {}, 557 | "outputs": [ 558 | { 559 | "data": { 560 | "text/plain": [ 561 | "((Array([18.121445 , 6.6405816, 3.465861 , 0.7097109], dtype=float32),\n", 562 | " Array([[-0.04401344, -0.27628452, -0.6763542 , -0.681377 ],\n", 563 | " [ 0.37902662, 0.7872333 , 0.12683554, -0.46959096],\n", 564 | " [ 0.47516632, -0.547633 , 0.5748661 , -0.37926766],\n", 565 | " [ 0.79285467, 
-0.06347415, -0.44270378, 0.41396353]], dtype=float32)),\n", 566 | " (Array([ 0.7097096, 3.4658616, 6.6405826, 18.121445 ], dtype=float32),\n", 567 | " Array([[-0.04401337, -0.27628452, -0.6763542 , -0.68137705],\n", 568 | " [-0.3790269 , -0.78723294, -0.12683566, 0.4695908 ],\n", 569 | " [-0.7928546 , 0.06347422, 0.4427038 , -0.41396356],\n", 570 | " [-0.47516638, 0.5476332 , -0.57486606, 0.3792676 ]], dtype=float32)))" 571 | ] 572 | }, 573 | "execution_count": 86, 574 | "metadata": {}, 575 | "output_type": "execute_result" 576 | } 577 | ], 578 | "source": [ 579 | "S2, Vt2 = jnp.linalg.eig(jnp.dot(B.T, B))\n", 580 | "S2 = jnp.sqrt(S2).astype(jnp.float32).sort()\n", 581 | "Vt2 = Vt2.astype(jnp.float32).T\n", 582 | "(S, Vt), (S2, Vt2)" 583 | ] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": 87, 588 | "metadata": {}, 589 | "outputs": [], 590 | "source": [ 591 | "def my_svd(B:jax.Array):\n", 592 | " U = jnp.linalg.eig(jnp.dot(B, B.T))[1].astype(jnp.float32)\n", 593 | " S_u = jnp.sqrt(jnp.linalg.eig(jnp.dot(B, B.T))[0].astype(jnp.float32))\n", 594 | " Vt = jnp.linalg.eig(jnp.dot(B.T, B))[1].T.astype(jnp.float32)\n", 595 | " S_vt = jnp.sqrt(jnp.linalg.eig(jnp.dot(B.T, B))[0].astype(jnp.float32))\n", 596 | "\n", 597 | " sorted_indices_u = jnp.argsort(S_u, descending=True)\n", 598 | " sorted_indices_vt = jnp.argsort(S_vt, descending=True)\n", 599 | "\n", 600 | " return U[:, sorted_indices_u], jnp.diag(S_u[sorted_indices_u]), Vt[:, sorted_indices_vt]" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 88, 606 | "metadata": {}, 607 | "outputs": [ 608 | { 609 | "data": { 610 | "text/plain": [ 611 | "Array([[ 1.5747491e+00, 1.2986512e+00, -5.0670344e-01, 7.5956029e-01],\n", 612 | " [ 3.5359454e-01, 5.3320231e+00, -2.7045100e-03, 4.6308179e+00],\n", 613 | " [ 2.9646111e+00, 4.6783400e+00, 8.2129717e+00, 7.1324124e+00],\n", 614 | " [-1.8006245e+00, 7.6882309e-01, 9.8288898e+00, 8.9754982e+00]], dtype=float32)" 615 | ] 616 | }, 617 | "execution_count": 88, 618 | "metadata": {}, 619 | "output_type": "execute_result" 620 | } 621 | ], 622 | "source": [ 623 | "U2, S2, Vt2 = my_svd(B)\n", 624 | "U, S, Vt = svd(B)\n", 625 | "\n", 626 | "U2 @ S2 @ Vt2" 627 | ] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "execution_count": 89, 632 | "metadata": {}, 633 | "outputs": [ 634 | { 635 | "name": "stdout", 636 | "output_type": "stream", 637 | "text": [ 638 | "(7, 7) (5,) (8, 2)\n" 639 | ] 640 | }, 641 | { 642 | "data": { 643 | "text/plain": [ 644 | "Array([[ 0.79758686, 0.3347571 , -0.29629523],\n", 645 | " [ 0.8001173 , -1.0575135 , 0.9375937 ],\n", 646 | " [-1.51121 , -2.5334764 , -0.49882543],\n", 647 | " [-4.263436 , 0.9500487 , 0.13105209],\n", 648 | " [ 0.79758686, 0.3347571 , -0.29629523],\n", 649 | " [ 0.79758686, 0.3347571 , -0.29629523],\n", 650 | " [ 0.79758686, 0.3347571 , -0.29629523]], dtype=float32)" 651 | ] 652 | }, 653 | "execution_count": 89, 654 | "metadata": {}, 655 | "output_type": "execute_result" 656 | } 657 | ], 658 | "source": [ 659 | "B = jnp.array([[1, 2, 0, 0, 0],\n", 660 | " [0, 4, 5, 0, 0],\n", 661 | " [0, 0, 7, 8, 0],\n", 662 | " [0, 0, 0, 10, 11],\n", 663 | " [1, 2, 0, 0, 0],\n", 664 | " [1, 2, 0, 0, 0],\n", 665 | " [1, 2, 0, 0, 0]])\n", 666 | "\n", 667 | "B = (B - B.mean())/B.std()\n", 668 | "n, m = B.shape\n", 669 | "\n", 670 | "num_components = 3\n", 671 | "\n", 672 | "U, S, Vt = jax.scipy.linalg.svd(B, full_matrices=True)\n", 673 | "\n", 674 | "print(U.shape, S.shape, V.shape)\n", 675 | "\n", 676 | "if n < m:\n", 677 | " S = 
jnp.concatenate((jnp.diag(S), jnp.zeros((n, m-n))), axis=1)\n", 678 | "elif n > m:\n", 679 | " S = jnp.concatenate((jnp.diag(S), jnp.zeros((n-m, m))), axis=0)\n", 680 | "\n", 681 | "\n", 682 | "B @ Vt[:num_components].T" 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": 90, 688 | "metadata": {}, 689 | "outputs": [ 690 | { 691 | "data": { 692 | "text/plain": [ 693 | "(Array([[-0.03292139, -0.29417548, -0.17891629, -0.93827826],\n", 694 | " [-0.25488853, -0.74092764, 0.6085925 , 0.1251936 ],\n", 695 | " [-0.6535463 , -0.27927062, -0.6624261 , 0.23680519],\n", 696 | " [-0.7119166 , 0.53525215, 0.39849186, -0.21882358]], dtype=float32),\n", 697 | " Array([[ 1.64022967e-01, 1.12288654e-01, -2.42227316e-01,\n", 698 | " 3.89632732e-01, 8.66025329e-01, 0.00000000e+00,\n", 699 | " -8.96705643e-09],\n", 700 | " [ 1.64543360e-01, -3.54724795e-01, 7.66501904e-01,\n", 701 | " 5.09480536e-01, 1.56532224e-07, 1.80300059e-08,\n", 702 | " 9.59345225e-09],\n", 703 | " [-3.10778737e-01, -8.49811375e-01, -4.07799214e-01,\n", 704 | " 1.22215040e-01, -2.22817107e-08, -7.23325755e-09,\n", 705 | " -2.47567788e-09],\n", 706 | " [-8.76771212e-01, 3.18677604e-01, 1.07137300e-01,\n", 707 | " 3.43857735e-01, -1.11688678e-07, -1.03446371e-08,\n", 708 | " -6.83913148e-09],\n", 709 | " [ 1.64022967e-01, 1.12288617e-01, -2.42227197e-01,\n", 710 | " 3.89632732e-01, -2.88675159e-01, -5.77350259e-01,\n", 711 | " -5.77350259e-01],\n", 712 | " [ 1.64022967e-01, 1.12288617e-01, -2.42227197e-01,\n", 713 | " 3.89632732e-01, -2.88675159e-01, 7.88675129e-01,\n", 714 | " -2.11324871e-01],\n", 715 | " [ 1.64022967e-01, 1.12288617e-01, -2.42227197e-01,\n", 716 | " 3.89632732e-01, -2.88675159e-01, -2.11324871e-01,\n", 717 | " 7.88675129e-01]], dtype=float32))" 718 | ] 719 | }, 720 | "execution_count": 90, 721 | "metadata": {}, 722 | "output_type": "execute_result" 723 | } 724 | ], 725 | "source": [ 726 | "U2, U" 727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": 91, 732 | "metadata": {}, 733 | "outputs": [ 734 | { 735 | "data": { 736 | "text/plain": [ 737 | "(Array([[18.121445 , 0. , 0. , 0. ],\n", 738 | " [ 0. , 6.6405807 , 0. , 0. ],\n", 739 | " [ 0. , 0. , 3.4658618 , 0. ],\n", 740 | " [ 0. , 0. , 0. 
, 0.70970935]], dtype=float32),\n", 741 | " Array([[4.8626552e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n", 742 | " 0.0000000e+00],\n", 743 | " [0.0000000e+00, 2.9812212e+00, 0.0000000e+00, 0.0000000e+00,\n", 744 | " 0.0000000e+00],\n", 745 | " [0.0000000e+00, 0.0000000e+00, 1.2232124e+00, 0.0000000e+00,\n", 746 | " 0.0000000e+00],\n", 747 | " [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.8522401e-01,\n", 748 | " 0.0000000e+00],\n", 749 | " [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n", 750 | " 1.1885693e-07],\n", 751 | " [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n", 752 | " 0.0000000e+00],\n", 753 | " [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n", 754 | " 0.0000000e+00]], dtype=float32))" 755 | ] 756 | }, 757 | "execution_count": 91, 758 | "metadata": {}, 759 | "output_type": "execute_result" 760 | } 761 | ], 762 | "source": [ 763 | "S2, S" 764 | ] 765 | }, 766 | { 767 | "cell_type": "code", 768 | "execution_count": 92, 769 | "metadata": {}, 770 | "outputs": [ 771 | { 772 | "data": { 773 | "text/plain": [ 774 | "(Array([[-0.04401337, -0.27628452, -0.68137705, -0.6763542 ],\n", 775 | " [-0.3790269 , -0.78723294, 0.4695908 , -0.12683566],\n", 776 | " [-0.7928546 , 0.06347422, -0.41396356, 0.4427038 ],\n", 777 | " [-0.47516638, 0.5476332 , 0.3792676 , -0.57486606]], dtype=float32),\n", 778 | " Array([[ 0.08774175, 0.17972684, -0.05285533, -0.74583775, -0.6331922 ],\n", 779 | " [ 0.13246639, 0.02176011, -0.80038613, -0.33112183, 0.4813729 ],\n", 780 | " [-0.04163036, 0.5418476 , 0.5000364 , -0.38165998, 0.5558481 ],\n", 781 | " [-0.88722956, 0.35512844, -0.2501087 , 0.09995639, -0.1190044 ],\n", 782 | " [-0.4310973 , -0.73993903, 0.20975612, -0.4224085 , 0.21028274]], dtype=float32))" 783 | ] 784 | }, 785 | "execution_count": 92, 786 | "metadata": {}, 787 | "output_type": "execute_result" 788 | } 789 | ], 790 | "source": [ 791 | "Vt2, Vt" 792 | ] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "execution_count": 93, 797 | "metadata": {}, 798 | "outputs": [ 799 | { 800 | "data": { 801 | "text/plain": [ 802 | "Array([[ 1.5747491e+00, 1.2986512e+00, -5.0670344e-01, 7.5956029e-01],\n", 803 | " [ 3.5359454e-01, 5.3320231e+00, -2.7045100e-03, 4.6308179e+00],\n", 804 | " [ 2.9646111e+00, 4.6783400e+00, 8.2129717e+00, 7.1324124e+00],\n", 805 | " [-1.8006245e+00, 7.6882309e-01, 9.8288898e+00, 8.9754982e+00]], dtype=float32)" 806 | ] 807 | }, 808 | "execution_count": 93, 809 | "metadata": {}, 810 | "output_type": "execute_result" 811 | } 812 | ], 813 | "source": [ 814 | "U2 @ S2 @ Vt2" 815 | ] 816 | }, 817 | { 818 | "cell_type": "code", 819 | "execution_count": 97, 820 | "metadata": {}, 821 | "outputs": [], 822 | "source": [ 823 | "class myPCA():\n", 824 | " def __init__(self, num_components:int):\n", 825 | " self.num_components = num_components\n", 826 | " self.mean = None\n", 827 | " self.principal_components = None\n", 828 | " self.explained_variance = None\n", 829 | "\n", 830 | " def fit(self, X:jax.Array):\n", 831 | " n, m = X.shape\n", 832 | " \n", 833 | " self.mean = X.mean(axis=0)\n", 834 | " X_centred = X - self.mean\n", 835 | " S, self.principal_components = svd(X_centred, full_matrices=True)[1:]\n", 836 | "\n", 837 | " if n < m:\n", 838 | " S = jnp.concatenate((jnp.diag(S), jnp.zeros((n, m-n))), axis=1)\n", 839 | " elif n > m:\n", 840 | " S = jnp.concatenate((jnp.diag(S), jnp.zeros((n-m, m))), axis=0)\n", 841 | "\n", 842 | " self.explained_variance = S**2 / jnp.sum(S**2)\n", 843 | "\n", 844 | " def transform(self, 
X:jax.Array):\n", 845 | " if self.principal_components is None:\n", 846 | " raise RuntimeError('Must fit before transforming.')\n", 847 | " \n", 848 | " X_centred = X - self.mean # centre with the mean learned in fit, not the batch mean\n", 849 | " return jnp.dot(X_centred, self.principal_components[:self.num_components].T)\n", 850 | " \n", 851 | " def fit_transform(self, X:jax.Array):\n", 852 | " if self.mean is None:\n", 853 | " self.mean = X.mean(axis=0)\n", 854 | "\n", 855 | " X_centred = X - self.mean\n", 856 | "\n", 857 | " self.principal_components = svd(X_centred, full_matrices=True)[2]\n", 858 | "\n", 859 | " return jnp.dot(X_centred, self.principal_components[:self.num_components].T)\n", 860 | "\n", 861 | " def inverse_transform(self, X_transformed:jax.Array):\n", 862 | " if self.principal_components is None:\n", 863 | " raise RuntimeError('Must fit before inverse transforming.')\n", 864 | " \n", 865 | " return jnp.dot(X_transformed, self.principal_components[:self.num_components]) + self.mean" 866 | ] 867 | }, 868 | { 869 | "cell_type": "code", 870 | "execution_count": 104, 871 | "metadata": {}, 872 | "outputs": [ 873 | { 874 | "data": { 875 | "text/plain": [ 876 | "Array([ True, True, True, True, True, True, True], dtype=bool)" 877 | ] 878 | }, 879 | "execution_count": 104, 880 | "metadata": {}, 881 | "output_type": "execute_result" 882 | } 883 | ], 884 | "source": [ 885 | "mymodel= myPCA(num_components=3)\n", 886 | "B_transformed = mymodel.fit_transform(B)\n", 887 | "jnp.isclose(B[:,0], B_transformed[:, 0], rtol=10)" 888 | ] 889 | } 890 | ], 891 | "metadata": { 892 | "kernelspec": { 893 | "display_name": ".venv", 894 | "language": "python", 895 | "name": "python3" 896 | }, 897 | "language_info": { 898 | "codemirror_mode": { 899 | "name": "ipython", 900 | "version": 3 901 | }, 902 | "file_extension": ".py", 903 | "mimetype": "text/x-python", 904 | "name": "python", 905 | "nbconvert_exporter": "python", 906 | "pygments_lexer": "ipython3", 907 | "version": "3.11.9" 908 | } 909 | }, 910 | "nbformat": 4, 911 | "nbformat_minor": 2 912 | } 913 | -------------------------------------------------------------------------------- /examples/notebooks/experimenting_gaussiannb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 19, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax\n", 10 | "import jax.numpy as jnp\n", 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "\n", 15 | "from jax import jit, vmap" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 20, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | " Feature1 Feature2 Feature3 Target\n", 28 | "0 A X L 0\n", 29 | "1 B Y M 1\n", 30 | "2 A X N 0\n", 31 | "3 C Z L 1\n", 32 | "4 B X M 0\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "# Define the data\n", 38 | "data = {\n", 39 | " 'Feature1': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C',\n", 40 | " 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C', 'B',\n", 41 | " 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A',\n", 42 | " 'C', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C',\n", 43 | " 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A', 'C', 'B'],\n", 44 | " 'Feature2': ['X', 'Y', 'X', 'Z', 'X', 'Y', 'X', 'Z', 'Y', 'X',\n", 45 | " 'X', 'Z', 'Y', 'X', 'Y', 'Z', 'Y', 'X', 'Y', 'X',\n", 46 | " 'X', 'Y', 'X', 'Z', 'Y', 'X', 'Z', 'Y', 'X', 'Y',\n", 47 | " 'Z', 'X', 'Y', 'Z', 'X', 'Y', 'X', 
'Z', 'X', 'Y',\n", 48 | " 'X', 'Y', 'X', 'Z', 'X', 'Y', 'X', 'Z', 'Y', 'X'],\n", 49 | " 'Feature3': ['L', 'M', 'N', 'L', 'M', 'N', 'L', 'M', 'N', 'L',\n", 50 | " 'L', 'M', 'N', 'L', 'M', 'L', 'N', 'M', 'L', 'N',\n", 51 | " 'L', 'M', 'N', 'L', 'M', 'N', 'L', 'M', 'L', 'N',\n", 52 | " 'L', 'M', 'L', 'M', 'N', 'L', 'M', 'L', 'N', 'L',\n", 53 | " 'M', 'N', 'L', 'M', 'N', 'L', 'M', 'N', 'L', 'M'],\n", 54 | " 'Target': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n", 55 | " 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n", 56 | " 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n", 57 | " 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n", 58 | " 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]\n", 59 | "}\n", 60 | "\n", 61 | "# Create DataFrame\n", 62 | "data = pd.DataFrame(data)\n", 63 | "\n", 64 | "# Display the DataFrame\n", 65 | "print(data.head())" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 21, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | " Feature1 Feature2 Feature3 Target\n", 78 | "0 0 0 0 0\n", 79 | "1 1 1 1 1\n", 80 | "2 0 0 2 0\n", 81 | "3 2 2 0 1\n", 82 | "4 1 0 1 0\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "from sklearn.preprocessing import LabelEncoder\n", 88 | "\n", 89 | "# Initialize LabelEncoder\n", 90 | "label_encoders = {}\n", 91 | "for column in ['Feature1', 'Feature2', 'Feature3']:\n", 92 | " le = LabelEncoder()\n", 93 | " data[column] = le.fit_transform(data[column])\n", 94 | " label_encoders[column] = le\n", 95 | "\n", 96 | "# Display the updated DataFrame\n", 97 | "print(data.head())" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 22, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "X = data.drop(columns=['Target']).to_numpy()\n", 107 | "y = data['Target'].to_numpy()\n", 108 | "\n", 109 | "X, y = map(jnp.array, (\n", 110 | " X, y\n", 111 | "))\n" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 23, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "def split_data(data, val_size=0.1, test_size=0.2):\n", 121 | " \"\"\" \n", 122 | " Splits data.\n", 123 | " \"\"\"\n", 124 | " split_index_test = int(len(data) * (1-test_size))\n", 125 | "\n", 126 | " data_non_test = data[:split_index_test]\n", 127 | " data_test = data[split_index_test:]\n", 128 | "\n", 129 | " split_index_val = int(len(data_non_test) * (1-val_size))\n", 130 | "\n", 131 | " data_train = data_non_test[:split_index_val]\n", 132 | " data_val = data_non_test[split_index_val:]\n", 133 | "\n", 134 | " return data_train, data_val, data_test" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 24, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "(X_train, X_val, X_test), (y_train, y_val, y_test) = map(\n", 144 | " split_data,\n", 145 | " (X, y)\n", 146 | ")" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 25, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "data": { 156 | "text/plain": [ 157 | "[0.8089011311531067, 0.6781419515609741, 0.8975274562835693]" 158 | ] 159 | }, 160 | "execution_count": 25, 161 | "metadata": {}, 162 | "output_type": "execute_result" 163 | } 164 | ], 165 | "source": [ 166 | "unique_classes = jnp.unique(y).tolist()\n", 167 | "indices_for_each_class = [jnp.where(y_train==class_) for class_ in unique_classes]\n", 168 | "\n", 169 | "dictionary_of_stds = dict(zip(unique_classes,\n", 170 | " [[jnp.std(X_train[collection_of_indices][:,j]).item() for j in range(X_train.shape[1])] for collection_of_indices in 
indices_for_each_class]))\n", 171 | "\n", 172 | "dictionary_of_stds[0]" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 26, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "data": { 182 | "text/plain": [ 183 | "{'a': 0, 'b': 1, 'c': 2}" 184 | ] 185 | }, 186 | "execution_count": 26, 187 | "metadata": {}, 188 | "output_type": "execute_result" 189 | } 190 | ], 191 | "source": [ 192 | "val = [1,2,3]\n", 193 | "keys = ['a', 'b', 'c']\n", 194 | "\n", 195 | "dict(zip(keys, [i for i in range(len(keys))]))" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 27, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "def compute_priors(y:jax.Array):\n", 205 | " \"\"\"\n", 206 | " Obtain prior probabilities.\n", 207 | "\n", 208 | " Args:\n", 209 | " y (jax.Array): Label vector.\n", 210 | " \n", 211 | " Returns:\n", 212 | " prior_probabilities (jax.Array): Vector of prior probabilities.\n", 213 | " \"\"\"\n", 214 | " unique_classes = jnp.unique_values(y)\n", 215 | " prior_probabilities = []\n", 216 | " \n", 217 | " for index, class_ in enumerate(unique_classes.tolist()):\n", 218 | " prior_probabilities.append(jnp.mean(jnp.where(y==class_, 1, 0)))\n", 219 | "\n", 220 | " return jnp.array(prior_probabilities)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 28, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "data": { 230 | "text/plain": [ 231 | "Array([0.5, 0.5], dtype=float32)" 232 | ] 233 | }, 234 | "execution_count": 28, 235 | "metadata": {}, 236 | "output_type": "execute_result" 237 | } 238 | ], 239 | "source": [ 240 | "compute_priors(y)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 29, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "def gaussian_pdf(x, mean, std):\n", 250 | " return jnp.exp(-0.5 * ((x-mean)/std)**2 )/(std*jnp.sqrt(2*jnp.pi))" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 30, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "def compute_means(X:jax.Array, y:jax.Array, random_state=12):\n", 260 | " \"\"\" \n", 261 | " Computes means.\n", 262 | " \"\"\"\n", 263 | " np.random.seed(random_state)\n", 264 | "\n", 265 | " unique_classes = jnp.unique(y).tolist()\n", 266 | " indices_for_each_class = [jnp.where(y==class_) for class_ in unique_classes]\n", 267 | "\n", 268 | " dictionary_of_means = dict(zip(unique_classes, \n", 269 | " [[jnp.mean(X[collection_of_indices][:,j]).item() for j in range(X.shape[1])] for collection_of_indices in indices_for_each_class]))\n", 270 | " \n", 271 | " return dictionary_of_means" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 31, 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "def compute_stds(X:jax.Array, y:jax.Array, random_state=12):\n", 281 | " \"\"\" \n", 282 | " Compute stds.\n", 283 | " \"\"\"\n", 284 | " np.random.seed(random_state)\n", 285 | "\n", 286 | " unique_classes = jnp.unique(y).tolist()\n", 287 | " indices_for_each_class = [jnp.where(y==class_) for class_ in unique_classes]\n", 288 | "\n", 289 | " dictionary_of_stds = dict(zip(unique_classes, \n", 290 | " [[jnp.std(X[collection_of_indices][:,j]).item() for j in range(X.shape[1])] for collection_of_indices in indices_for_each_class]))\n", 291 | " \n", 292 | " return dictionary_of_stds" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 32, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | 
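"# A hedged sketch of the intended per-class score (not what this cell computes): for a sample x,\n", "# Gaussian naive Bayes scores class c as log p(c|x) = log(prior[c]) + sum_j log(gaussian_pdf(x[j], mean[c][j], std[c][j])), up to a constant,\n", "# giving one score per class, with predictions taken as the argmax over classes.\n", "# Below, `likelihoods` has shape (n_classes, n_features) and is dotted with the (n_classes,) prior vector,\n", "# which mismatches whenever n_features != n_classes; the traceback in the next cell shows exactly that failure.\n",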
"def compute_posterior(X:jax.Array, y:jax.Array):\n", 302 | " \"\"\" \n", 303 | " Computes posteriors to compute predictions.\n", 304 | " \"\"\"\n", 305 | " posteriors = []\n", 306 | " \n", 307 | " dictionary_of_means = compute_means(X, y)\n", 308 | " dictionary_of_stds = compute_stds(X, y)\n", 309 | " \n", 310 | " prior_probabilites = compute_priors(y)\n", 311 | "\n", 312 | " for x in X:\n", 313 | " likelihoods = jnp.array([gaussian_pdf(x, jnp.array(means), jnp.array(stds)) for means, stds in zip(dictionary_of_means.values(), dictionary_of_stds.values())])\n", 314 | " vector_of_posteriors = jnp.log(jnp.dot(likelihoods, prior_probabilites))\n", 315 | " posteriors.append(vector_of_posteriors)\n", 316 | "\n", 317 | " return posteriors" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 33, 323 | "metadata": {}, 324 | "outputs": [ 325 | { 326 | "ename": "TypeError", 327 | "evalue": "dot_general requires contracting dimensions to have the same shape, got (3,) and (2,).", 328 | "output_type": "error", 329 | "traceback": [ 330 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 331 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 332 | "Cell \u001b[0;32mIn[33], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m posteriors \u001b[38;5;241m=\u001b[39m \u001b[43mcompute_posterior\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_test\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39margmax(jnp\u001b[38;5;241m.\u001b[39marray(posteriors), axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n", 333 | "Cell \u001b[0;32mIn[32], line 14\u001b[0m, in \u001b[0;36mcompute_posterior\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m X:\n\u001b[1;32m 13\u001b[0m likelihoods \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray([gaussian_pdf(x, jnp\u001b[38;5;241m.\u001b[39marray(means), jnp\u001b[38;5;241m.\u001b[39marray(stds)) \u001b[38;5;28;01mfor\u001b[39;00m means, stds \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(dictionary_of_means\u001b[38;5;241m.\u001b[39mvalues(), dictionary_of_stds\u001b[38;5;241m.\u001b[39mvalues())])\n\u001b[0;32m---> 14\u001b[0m vector_of_posteriors \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mlog(\u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdot\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlikelihoods\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprior_probabilites\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 15\u001b[0m posteriors\u001b[38;5;241m.\u001b[39mappend(vector_of_posteriors)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m posteriors\n", 334 | " \u001b[0;31m[... 
skipping hidden 11 frame]\u001b[0m\n", 335 | "File \u001b[0;32m~/Desktop/scikit-jax/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:5761\u001b[0m, in \u001b[0;36mdot\u001b[0;34m(a, b, precision, preferred_element_type)\u001b[0m\n\u001b[1;32m 5759\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 5760\u001b[0m contract_dims \u001b[38;5;241m=\u001b[39m ((a_ndim \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m,), (b_ndim \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m2\u001b[39m,))\n\u001b[0;32m-> 5761\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mlax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdot_general\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdimension_numbers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcontract_dims\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_dims\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5762\u001b[0m \u001b[43m \u001b[49m\u001b[43mprecision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprecision\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreferred_element_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpreferred_element_type\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5763\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m lax_internal\u001b[38;5;241m.\u001b[39m_convert_element_type(result, preferred_element_type, output_weak_type)\n", 336 | " \u001b[0;31m[... skipping hidden 7 frame]\u001b[0m\n", 337 | "File \u001b[0;32m~/Desktop/scikit-jax/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py:2723\u001b[0m, in \u001b[0;36m_dot_general_shape_rule\u001b[0;34m(lhs, rhs, dimension_numbers, precision, preferred_element_type)\u001b[0m\n\u001b[1;32m 2720\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m core\u001b[38;5;241m.\u001b[39mdefinitely_equal_shape(lhs_contracting_shape, rhs_contracting_shape):\n\u001b[1;32m 2721\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdot_general requires contracting dimensions to have the same \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 2722\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshape, got \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m and \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 2723\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg\u001b[38;5;241m.\u001b[39mformat(lhs_contracting_shape, rhs_contracting_shape))\n\u001b[1;32m 2725\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _dot_general_shape_computation(lhs\u001b[38;5;241m.\u001b[39mshape, rhs\u001b[38;5;241m.\u001b[39mshape, dimension_numbers)\n", 338 | "\u001b[0;31mTypeError\u001b[0m: dot_general requires contracting dimensions to have the same shape, got (3,) and (2,)." 
339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "posteriors = compute_posterior(X_test, y_test)\n", 344 | "y_pred = jnp.argmax(jnp.array(posteriors), axis=1)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": {}, 351 | "outputs": [ 352 | { 353 | "data": { 354 | "text/plain": [ 355 | "Array(87.43718, dtype=float32)" 356 | ] 357 | }, 358 | "execution_count": 106, 359 | "metadata": {}, 360 | "output_type": "execute_result" 361 | } 362 | ], 363 | "source": [ 364 | "jnp.mean(y_pred == y_test) * 100" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": {}, 371 | "outputs": [ 372 | { 373 | "data": { 374 | "text/plain": [ 375 | "Array(92.46231, dtype=float32)" 376 | ] 377 | }, 378 | "execution_count": 107, 379 | "metadata": {}, 380 | "output_type": "execute_result" 381 | } 382 | ], 383 | "source": [ 384 | "from sklearn.naive_bayes import GaussianNB\n", 385 | "\n", 386 | "model = GaussianNB()\n", 387 | "model_fitted = model.fit(X_train, y_train)\n", 388 | "y_pred_2 = model_fitted.predict(X_test)\n", 389 | "\n", 390 | "jnp.mean(y_pred_2 == y_test) * 100" 391 | ] 392 | } 393 | ], 394 | "metadata": { 395 | "kernelspec": { 396 | "display_name": ".venv", 397 | "language": "python", 398 | "name": "python3" 399 | }, 400 | "language_info": { 401 | "codemirror_mode": { 402 | "name": "ipython", 403 | "version": 3 404 | }, 405 | "file_extension": ".py", 406 | "mimetype": "text/x-python", 407 | "name": "python", 408 | "nbconvert_exporter": "python", 409 | "pygments_lexer": "ipython3", 410 | "version": "3.11.9" 411 | } 412 | }, 413 | "nbformat": 4, 414 | "nbformat_minor": 2 415 | } 416 | -------------------------------------------------------------------------------- /examples/notebooks/linear_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 45, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax\n", 10 | "import jax.numpy as jnp\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "import pandas as pd\n", 14 | "import matplotlib.pyplot as plt" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 46, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "class MyLinearRegression():\n", 24 | " def __init__(self, random_state:int=12):\n", 25 | " self.key = jax.random.key(random_state)\n", 26 | "\n", 27 | " def fit(self, X:jax.Array, y:jax.Array):\n", 28 | " self.variance = jnp.var(y) # variance of the target variable y and the residuals\n", 29 | "\n", 30 | " self.design_matrix = jnp.vstack((jnp.ones(X.shape[0]), X.T)).T\n", 31 | "\n", 32 | " lhs = jnp.dot(self.design_matrix.T, self.design_matrix) # Gram Matrix\n", 33 | " rhs = jnp.dot(self.design_matrix.T, y)\n", 34 | "\n", 35 | " inverse_gram_matrix = jnp.linalg.inv(lhs)\n", 36 | "\n", 37 | " self.coeff = jnp.dot(inverse_gram_matrix, rhs) # Solving Normal Equation\n", 38 | " return self\n", 39 | "\n", 40 | " def predict(self, X:jax.Array):\n", 41 | " self.design_matrix = jnp.vstack((jnp.ones(X.shape[0]), X.T)).T\n", 42 | " return jnp.dot(self.design_matrix, self.coeff) " 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "### DATA 1: Standard Linear Relationship" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 47, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "(100,)" 61 | ] 62 | }, 63 | 
"execution_count": 47, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "x = jnp.linspace(0, 10, 100)\n", 70 | "y = 2 * x + 1 + jax.random.normal(jax.random.PRNGKey(12), shape=(x.shape[0],))\n", 71 | "df1 = pd.DataFrame({'x': x, 'y': y})\n", 72 | "\n", 73 | "y.shape" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 48, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "Array([0.76906997, 2.043794 ], dtype=float32)" 85 | ] 86 | }, 87 | "execution_count": 48, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "model_fitted = MyLinearRegression().fit(x, y)\n", 94 | "y_pred = model_fitted.predict(x)\n", 95 | "\n", 96 | "model_fitted.coeff" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 49, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "Array(True, dtype=bool)" 108 | ] 109 | }, 110 | "execution_count": 49, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "X = model_fitted.design_matrix\n", 117 | "\n", 118 | "jnp.isclose(jnp.trace(X @ jnp.linalg.inv(X.T@X) @ X.T), X.shape[1])" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 50, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbIAAAGsCAYAAAC4ryL3AAAAP3RFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMS5wb3N0MSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8kixA/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABoF0lEQVR4nO3dd3hUZdrH8e/09ElvkITeexVlFRUFXnVVbCgWFHVXcRVZdcWCXda6lnUtq2JHcBXEBgIKiAJSpHcIISG9Tvq08/5xZiYZkkACSSaT3J/rynUxM2dOngmYn/dz7vM8GkVRFIQQQgg/pfX1AIQQQojTIUEmhBDCr0mQCSGE8GsSZEIIIfyaBJkQQgi/JkEmhBDCr0mQCSGE8Gt6Xw/geE6nk8zMTEJDQ9FoNL4ejhBCCB9RFIXS0lISExPRahuuu9pckGVmZpKUlOTrYQghhGgj0tPT6dy5c4Ovt7kgCw0NBdSBh4WF+Xg0QgghfMVisZCUlOTJhYa0uSBzTyeGhYVJkAkhhDjpZSZp9hBCCOHXJMiEEEL4NQkyIYQQfq3NXSNrLIfDgc1m8/UwRCswGAzodDpfD0MI0Ub5XZApikJ2djbFxcW+HopoReHh4cTHx8u9hUKIOvwuyNwhFhsbS1BQkPxia+cURaGiooLc3FwAEhISfDwiIURb41dB5nA4PCEWFRXl6+GIVhIYGAhAbm4usbGxMs0ohPDiV80e7mtiQUFBPh6JaG3uv3O5LiqEOJ5fBZmbTCd2PPJ3LoRoiF8GmRBCCOEmQSaEEMKvSZC1M6tWrUKj0cjtCUKIDkOCrBVoNJoTfj3++OOndN5x48Yxc+bMZh2rEEL4G79qv/dXWVlZnj8vWLCAOXPmsG/fPs9zISEhnj8rioLD4UCvl78aIYT/OJhbRueIQAIMrX97jN9XZIqiUGG1++RLUZRGjTE+Pt7zZTab0Wg0nsd79+4lNDSUH374geHDh2MymVi7di3Tpk3jsssu8zrPzJkzGTduHADTpk1j9erVvPrqq57K7siRI55jN2/ezIgRIwgKCuLMM8/0Ck4hhGhOvx3KZ/zLq3ns610++f5+/7/9lTYH/eYs88n33v3kBIKMzfMjfPDBB3nxxRfp1q0bERERJz3+1VdfZf/+/QwYMIAnn3wSgJiYGE+YPfzww7z00kvExMTw17/+lVtuuYVff/21WcYqhBC17c60ALD5aJFPvr/fB1l78eSTT3LBBRc0+niz2YzRaCQoKIj4+Pg6rz/zzDOcc845gBqSF110EVVVVQQEBDTbmIUQAiCvrBqAtIJyHE4FnbZ17/v0+yALNOjY/eQEn33v5jJixIhmOxfAoEGDPH92r0+Ym5tLcnJys34fIYTIL7UCYHMoHCuqJDmqdVdf8vsg02g0zTa950vBwcFej7VabZ1rcE1ZnslgMHj+7F4Vw+l0nsYIhRCifu6KDOBwflmrB5nfN3u0VzExMV7djgBbt271emw0GnE4HK04KiGEqCu/tCbIUvPLW/37S5C1Ueeddx6bNm3io48+4sCBAzz22GPs3LnT65guXbqwYcMGjhw5Qn5+vlRcQgifqF2RSZAJjwkTJvDoo4/ywAMPMHLkSEpLS7nxxhu9jrnvvvvQ6XT069ePmJgYjh496qPRCiE6KodToaD21GJe6weZRmnszVCtxGKxYDabKSkpISwszOu1qqoqUlNT6dq1q3TfdTDydy9E25RfVs2Ip1d4HncKD+TXB89rlnOfKA9qk4pMCCHEKctzXR8z6tU4OVZcSZWtda/dS5AJIYQ4Ze4g6xoVTFiA2kF+pKB1pxclyIQQQpyyfNf1sZhQE11j1HVjU1v5OpkEmRBCiFPmrshiQk10i1bvhz3cyp2L/n8nsRBCCJ9xV2TRIUbCAtSFGFq7BV+CTAghxCmrXZElhgcCrR9kMrUohBDilOXVvkbmmlqUIBNCCOE33AsGR4eY6BKlBllhuZXiCmurjUGCrJ05fkPOcePGMXPmzNM6Z3OcQwjRPtWuyIJNeuLD1AULWrPhQ4KslUybNs2zk7PRaKRHjx48+eST2O32Fv2+X331FU899VSjjl21ahUajYbi4uJTPocQouOwOZwUltdUZED
N9GIrtuBLkLWiiRMnkpWVxYEDB/j73//O448/zgsvvFDnOKu1+UryyMhIQkNDfX4OIUT74w4xnVZDRJARgK4xrX+dTIKsFZlMJuLj40lJSeGOO+5g/PjxLFmyxDMd+Mwzz5CYmEjv3r0BSE9P5+qrryY8PJzIyEguvfRSjhw54jmfw+Fg1qxZhIeHExUVxQMPPFBnD7PjpwWrq6v5xz/+QVJSEiaTiR49evDee+9x5MgRzj33XAAiIiLQaDRMmzat3nMUFRVx4403EhERQVBQEJMmTeLAgQOe1z/44APCw8NZtmwZffv2JSQkxBPibqtWrWLUqFEEBwcTHh7OWWedRVpaWjP9pIUQrcHdsRgVbPTsCt3NBw0f/h9kigLWct98neZ6y4GBgZ7qa+XKlezbt4/ly5fz7bffYrPZmDBhAqGhofzyyy/8+uuvnkBwv+ell17igw8+4P3332ft2rUUFhayaNGiE37PG2+8kfnz5/Paa6+xZ88e3n77bUJCQkhKSuLLL78EYN++fWRlZfHqq6/We45p06axadMmlixZwrp161AUhf/7v//z2vizoqKCF198kY8//pg1a9Zw9OhR7rvvPgDsdjuXXXYZ55xzDtu3b2fdunXcfvvtng1AhRD+Ic9zD5nJ81xXH9wU7f/3kdkq4NlE33zvhzLBGHzy446jKAorV65k2bJl/O1vfyMvL4/g4GDeffddjEa1PP/kk09wOp28++67nl/w8+bNIzw8nFWrVnHhhRfyyiuvMHv2bCZPngzAW2+9xbJlyxr8vvv372fhwoUsX76c8ePHA9CtWzfP65GRkQDExsYSHh5e7zkOHDjAkiVL+PXXXznzzDMB+PTTT0lKSmLx4sVcddVVgLqb9VtvvUX37t0BuOuuu3jyyScBdUXrkpISLr74Ys/rffv2bfLPUQjhW7XvIXNzB9mR/HKcTgWttuX/B9X/KzI/8u233xISEkJAQACTJk3immuu4fHHHwdg4MCBnhAD2LZtGwcPHiQ0NJSQkBBCQkKIjIykqqqKQ4cOUVJSQlZWFqNHj/a8R6/XM2LEiAa//9atW9HpdJxzzjmn/Bn27NmDXq/3+r5RUVH07t2bPXv2eJ4LCgryhBRAQkICubm5gBqY06ZNY8KECVxyySW8+uqrdXbDFkK0ffUFWVJkEHqthkqbg5zSqlYZh/9XZIYgtTLy1fdugnPPPZc333wTo9FIYmIien3Njz842LuyKysrY/jw4Xz66ad1zhMTE3NKww0MDDyl950Kg8Hg9Vij0Xhdv5s3bx533303S5cuZcGCBTzyyCMsX76cM844o9XGKIQ4Pfn1TC0adFqSI4M4nF9Oal45CeaW/73j/0Gm0ZzS9J4vBAcH06NHj0YdO2zYMBYsWEBsbGyDG8olJCSwYcMGzj77bEC99rR582aGDRtW7/EDBw7E6XSyevVqz9Ribe6K0OFoeC+hvn37Yrfb2bBhg2dqsaCggH379tGvX79GfTa3oUOHMnToUGbPns2YMWP47LPPJMiE8CP1VWQAZ/WIpltMCCaDrlXGIVOLbdTUqVOJjo7m0ksv5ZdffiE1NZVVq1Zx9913k5GRAcA999zDP//5TxYvXszevXu5884769wDVluXLl246aabuOWWW1i8eLHnnAsXLgQgJSUFjUbDt99+S15eHmVlZXXO0bNnTy699FJuu+021q5dy7Zt27j++uvp1KkTl156aaM+W2pqKrNnz2bdunWkpaXx448/cuDAAblOJoSfcQdZdIjR6/mnLhvAuzeNYHhKRKuMQ4KsjQoKCmLNmjUkJyczefJk+vbty/Tp06mqqvJUaH//+9+54YYbuOmmmxgzZgyhoaFcfvnlJzzvm2++yZVXXsmdd95Jnz59uO222ygvV7uLOnXqxBNPPMGDDz5IXFwcd911V73nmDdvHsOHD+fiiy9mzJgxKIrC999/X2c68USfbe/evVxxxRX06tWL22+/nRkzZvCXv/ylCT8hIYSv1d6LzJc0yvE3HvmYxWLBbDZTUlJSZ0qtqqqK1NRUunbtSkBAgI9GKHxB/u6FaHsGPb4MS5WdFbPOpkds8y+acKI8qE0qMiGEEE1WbXdgqVKX2Kvd7OELEmRCCNGBlVfbsTmcTX5ffpm6MINBp8EceNxlhfTfYcnfwNlw41hzkiATQogOqqTCxpn//Imp725o8ns9HYshJu9VeXZ/DR9eAls+gvVvNtdQT0iCTAghOqhdmSWUVNrYklaE09m0dol8d8eiu9FDUeC3f8PCm8BeBb0mwvBpzTzi+vn/fWRCCCFOSVphBQB2p4KlykZ4kPEk76jh2YcsxAQOOyx9EDb+V31x5G0w6TnQts59ZH4ZZE5n0+dzhX+Tv3Mhmt+RgpqFffPLqpsWZK6KLDHYCQumwv6lgAYufBrGzFAXq2glfhVkRqMRrVZLZmYmMTExGI1GWTG9nVMUBavVSl5eHlqt1ms9SiFE/X7em8uG1EL+fmEvDLqGryAdLajw/Dm/zEqP2MZ/j/yyamIoYsaRJ6F8L+gDYPI70K9xCyM0J78KMq1WS9euXcnKyiIz00frKwqfCAoKIjk5Ga1WLusKcTLPfL+Hg7ll9E0I5dIhnRo87ohXkFU36XvoCvaxyPQY8eX5EBQF134OSaNOecynw6+CDNSqLDk5GbvdfsI1AUX7odPp0Ov1Un0L0UgFrlBasSe3wSBTFIWjtaYWC8qasDN96hruS/8bwZpyykJSCLl5EUR1P/n7WojfBRmoK6kbDIZGL4kkhBAdhdOpUFKpbnK7al8uNoez3unF/DIr5VZHrcfeFZmiKNzxyRYA3rx+WM3/SG6dD0v+RrBiY6OzF7qL5jPMhyEG0n4vhBDtSmmVHXcnfWmVnY2phfUed7TQewfn/OMqsryyapbuymbprmxyLNVqe/2q52DxX8Fp43vnGVxvfYjImIQW+RxNIUEmhBDtSHGldyCt2JNb73FH8iu8Hh9fkeWU1Dw+mlsEX8+AVc8CUD7yLmZY78KmMZIY3nr7HDZEgkwIIdqR4gqb1+MVe3Kob234NNf1schgtRO44Lggy7aouzuHUkHKsmmw9VPQ6ODif7Gr399R0NIpIhCj3vcx4vsRCCGEaDbFrutjKVFBGPVajhZWcDC37t6C7puhhyWre4YdP7WYXVJJIvn8z/g4cfnrwRAM1y2AEbd4QjAlsm1saixBJoQQ7UhxhRpIieZAzuweBcDyPTl1jnO33rs3vzy+InNmbmWRaQ69tRkU66Phlh+g5wUAHHWFYHJUUMt8iCaSIBNCiHbEPbUYHmRgfN84AFbsrhtk7tZ7d5CVWx1UursY9y9jys6/EqcpZq8zib+HvQQJgz3vdYdgSqQEmRBCiGZWE2RGzu+rLtXxR3qxVzNHSaWNItdx/RPDPNe58suqYeO7MH8KJqWSNY6BXGV9jC3F3oHlDsGUKD+cWpw7dy4jR44kNDSU2NhYLrvsMvbt2+d1TFVVFTNmzCAqKoqQkBCuuOIKcnLq/t+AEEKI5ufuWgwPMpBgDmRApzAUBX7aW9O96F
6aKjrERLBJr27FghPjz4/Dd38HxckPhvHcYrufUoIoqrBhqappInFfX0vxx6nF1atXM2PGDNavX8/y5cux2WxceOGFlJfX3I9w77338s033/DFF1+wevVqMjMzmTx5crMPXAghRF0l7orMtdllfdOL7sWCu7iCKCFY4XXD68TteFs94NxHuK/6VuzoPWv/usOvpMLmqfqS28jUYpNW9li6dKnX4w8++IDY2Fg2b97M2WefTUlJCe+99x6fffYZ5513HgDz5s2jb9++rF+/njPOOKP5Ri6EEKIOd9dieJAaZBf2i+eVFQdYtS+P4gor4UFG72aN8gLmlj1KT90uHBo9usv+Q2nvyZT/8CMAveNC2ZtdSlpBBQM6mUlz3UjtrubagtO6RlZSUgJAZGQkAJs3b8ZmszF+/HjPMX369CE5OZl169bVe47q6mosFovXlxBCiFNT5OpaNAeq94f1Swyjb0IYVoeTb7api60fyVfDaHBQAbw3np7VuyhRglgy6A0YfA3ZJeo9ZGEBevomhAF4AizNVZl1aSPTinAaQeZ0Opk5cyZnnXUWAwYMACA7Oxuj0Uh4eLjXsXFxcWRnZ9d7nrlz52I2mz1fSUlJpzokIYTo8EoqvCsygKuGdwbgi80ZgHqNa5hmP1O23wKFhyk2JTDZ+gTb9QOBmpuh480BnulD99RiW2u9h9MIshkzZrBz504+//zz0xrA7NmzKSkp8Xylp6ef1vmEEKIjc08tRtTaJPPSIYnotRq2Z5SwL7uU7nnLmW98BpO1GBKH8t2oTzikdPLcFO2uyOLNgZ6GDncl1tZuhoZTDLK77rqLb7/9lp9//pnOnTt7no+Pj8dqtVJcXOx1fE5ODvHx8fWey2QyERYW5vUlhBCi6ZxOxXNDdO2KLCrE5GrFV0hd8ixz7S9h0tiw9ZgI074jJEpd+Nd9U7QnyMJMniBzV2Kee8j8tSJTFIW77rqLRYsW8dNPP9G1a1ev14cPH47BYGDlypWe5/bt28fRo0cZM2ZM84xYCCFEvcqsNSvfmwO9t7m6amgCT+vfZ2LmfwD4lIkYrvsMjMFEh5iAmoWDPVOLYQEkuyqvzJJKqu0OzxRjWwqyJrWczJgxg88++4yvv/6a0NBQz3Uvs9lMYGAgZrOZ6dOnM2vWLCIjIwkLC+Nvf/sbY8aMkY5FIYRoYcXl6rRigEFLgEFX80J1GedtvQetfiVORcPT9uvZFH8NU7XqMVEh7oWD604tRocYCTLqqLA6OJRb7gm5tnIzNDQxyN58800Axo0b5/X8vHnzmDZtGgD/+te/0Gq1XHHFFVRXVzNhwgT+85//NMtghRBCNMxzM3RgzfUxSrPhs6vRZm3DpjFyl/VOljlHcXGtIHJXZIUVVuwOZ61mDxMajYbkyCD2Zpfy68F8AEJNeiKC2s7Gxk0Ksvq2AjheQEAAb7zxBm+88cYpD0oIIcSJ/WfVQdLyK/jnFQM9uzcXH9+xmLMbPr0KLBkQFM2xCe+xbL77ZuiaIIsIMqLVgFOBogobOZ6pRXWvsZQoNcjWHMgD1I5Fz47RbYCstSiEEH7Ganfy0o/7WbApnQO1tmjxuhn68Cp4f4IaYlE94NbldBk8jkGdzQD0jAvxvE+n1Xj2JcsqqfR0L8abA4CaacQNrt2mu7ShaUWQIBNCCL9zpKAch6urI6OoZqdnd8fiRNvP8MkVUG2B5DEwfTlEdgPgtSlDefTiflw0MMHrnFHB6vTi7kx1UQqjXuuZPnTfS2a1O9XHbajRAyTIhBDC7xyqVYUdK6r0/Lm43MpM/f+YlvccOO0w4Aq4YTEERXqO6RIdzPSxXdHrvH/9R4eqFdnOTHXFpviwAM/04fEdim1l+xa3trFQlhBCiEarveNzhjvI7FbO2T2Hwfof1MdjZ8F5j4K2cfWKuyLbeUytyOLDAjyvHX/zs1RkQgghTsvBvFpBVlwJlcXwyWQGF/6AXdGyuvfDMP6xRocY1HQu7s12BZm5JsgSwwPQa2uaO+QamRBCiNNSuyKz5h9RmzqO/EKVJpDptvvJ6j6lyed030tWZVOvg9UOMr1OS6cItYPRqNd6VWttgQSZEEK0Ue+sOcQHv6Z6Ped0KhzOU1voB2oOM7fwXsjbC6EJPBT+PKudg72Wp2qsGFdF5hZ3XFi5Gz6SIgLRattO6z1IkAkhRJuUY6ni2e/38vg3u8krrfY8n1lSSaXNwfm6LSwwPkU0xThj+8OtK9lqTwZqtnBpCndF5pZgrj/I2tKKHm4SZEII0QYdqnUdbHNaoefPB3PLuF63nHcMLxOkqWaNYyBpl34J5k71buHSWNEnqchGdVU7H4enRDT53C1NuhaFEKINOpJfc3/YxiNFTByQAE4nUb8+xdOGjwBYahjPXVU3Mq9CTxdFqXcLl8Y6viKLP64i+/PgRIanRNApPLDJ525pEmRCCNEGHXHt+wWwKa0IbJXw1e0MPLoEgLXJf2UBk7GX5pNRVElZtd1zk/TpVmQaDcSGeldoGo2GzhFtq+3eTaYWhRCiDUrNrwmyrGPpOD64GPYswYaBe6x3kj/0bjq5rlsdK6r0rLNo0h+38n0jBRh0hJjU2iY6xIRB5z/x4D8jFUKIDuSIK8i6arJYqJ+D7tgmCDBzh/YRvnaOpUdsiKdCOlZcWXfB4FMQ7ZpePL7Ro62TIBNCiDbG6VRIK6xghGYvXwc8ThdtDiWmRIqv+54VFT0B6BYT7Lledayosv4tXJooyjW9eHyjR1snQSaEEK2o2u4gs7jyhMdkllRyofNXPjXOJUwpZauzGw9Hv8IBZyIAncIDCTLqPTcpZxRVNGtF1tZueD4ZCTIhhGhFD321k7Oe+4lt6cX1H6Ao2Fb/i38bX8eksWFJuZAp1kdZlQH7sksB6B6rbsHS2RVk2ZYq8svUe81OJ8i6RKv3iPWIDTnJkW2LBJkQQrSSsmo732zPRFHg99TCugc47PDtvXTd+jwAy0MnE3zDfPSmYMqq7Xy/IwuAHjFq0EQHmzDqtTgV2JulhtzpTC3edW4P3rlhONeMTDrlc/iCBJkQQrSS1fvyPHt6pdZqrweguhTmXwOb56Gg4QnbDazvfT86vZ6hyeEA/HaoAKipmLRajec6mXv7ldOpyEIDDFzYP/6Uuh59SYJMCCFayY+7sz1/TqsdZJZMmDcJDq4AfSD/jnmceY5Jnqm+ESmRXuepPfXnDrL9OWpFZj6NIPNXEmRCCNEKrHYnP+3N9Tz2rNyRswveHQ/ZOyA4BqZ9x6LKIQB0cwXZyC7ey0J1j6lZ79B9nczmUG+GPpVVPfydBJkQQjTCseJKHl60g/TCipMfXI91hwsorbIT6rrpOLOkEuu+FfDeBLAcg6ieMH059oShHHV9D3dFNiQ5HJ1rxfmIIIOnTR6os2RUeKBUZEIIIerx1qpDfLrhKE99u/uU3r9slzqteMmQREJMeq7UrsKw4BqwlkLKWTD9R4jsyrHiSuxOBZNeS4KrDT7IqKd/YhhQt6PQ3YLvJlOLQggh6rU9oxiAn/bmklta1aT3Op0Ky3fnADChXxyPB
H3JC4Z30DjtMPAquGERBKnXwdxLU6VEBXnt+zWqi/p67/hQr3Mfv/7h6XQt+itZNFgIIU7Caneyx9XebncqfLXlGH89p3uj3/9HehF5pdVEBiiM3fEwusqF6vNdbmXo5BfVVXpd3EtTdTlu36+7zuuByaDlhjO6eD1/fEV2Ol2L/koqMiGEOIn9OaVYHU7P44Ub01EUpdHv/3FXDmGUsSDoBXQ7F+JExz9st/Fl+DSvEAM4UqBeH+sa7R1k4UFG7p/Qp872KnGhJs/1M5BmDyGEEPXYnqHeozU0OZwgo47D+eVsPFLUqPcqisK2Hdv40vgEPSu2gjGUtaP/wwLHuV57jrm5pxa7RDduJ2a9TutZUsqo1xJg6Hi/1jveJxZCiCbacawYgDHdorhkkLre4ecbjzbqvWk7fuX1igfoqT2GMzQBbllKYN8LAe89x9zczx0/tXgi7hb88EADmuMqvI5AgkwIIU7CXZEN6mzmatfyTd/vyMJSZTvxG/d+T6fFVxCjKeGooTva236C+AGkRKkNGpnFlVTbHZ7DbQ4nGUXqgsLHTy2eiPs6WUe8PgYSZEIIcUJVNodnsd6BncMZlhxOz9gQqmxOlmzNbPiNG96Gz6/D4KxilWMwa8/+GMLUai4mxESwUYdTgfTCmpXw0wsrcDgVAg064sJMDZ25js7h7oqs410fAwkyIYQ4ob3ZpdidClHBRhLNAWg0Gs+iuu//msp7a1P5aksGa/bnUWVzgNMBSx+CHx4AFOY7zme67T7G9uvqOadGoyHFNXVYe6kq97RiSlRQk6YIe8er95glRQad5Mj2SdrvhRDiBHa47h8b2NnsCZfLh3bi+aX7OJxX7nWD9NRh0TzjfA32fgvAvgF/Z/amYXSLDiE5yjtkukQHsTvL4mnuADic59oVugnTigATB8Tz7o0jGJ4ScfKD2yEJMiGEOIFt7utjncye56JCTLx1wzB+2ptLcYWNgjIr+w8f5qpdc0BzEHRGuOxNPjjQC0jn7F4xdc7bxVOR1XQuHsorU19rYpDptBrG94tr6kdrNyTIhBDiBHa4gmxg53Cv58/rE8d5fdTwUPL2k/3mTSQ4s6nSmwm44XOU5DGs+vYnAMb1bjjI3NOJVTYHS3eqy1gdv0iwODG5RiaEEA2osNo5kKs2egzqbK7/oLTf0Lx3AQnObNKcsfzF9CxK8hgO5JaRVVKFSa/ljG5Rdd7mrrrcQbZsVzZFFTYSzAGc0yu2ZT5QOyUVmRBCNGB3pgWnAnFhJuLCAuoesON/sPgOcFhxJA7n+oy/kF4QwvrDhZ57z87oFlXvRpVdXNfMjhVVYrU7+fz3dACuHpHktVKHODmpyIQQwsXucJJRVOFZfsp9/9jATuHeByoK/PISfDkdHFboewm6m79j7JB+AHz2+1FW7csD4Jx6ro8BxISaCHK14P9yII91hwvQaPDcpyYaT4JMCCFcXl15gLHP/czVb69jR0YJO47V3Ajt4bDBN3fDyifVx2Pugqs+BEMgU0cnA7B0ZxYbjxQC9V8fA+8W/BeW7VOP7RVTZ38xcXIytSiE8EuPLt5JcaWN16YMabZlmb7bngXAxiNF/PmNtZj06v/rD3QHWZUFvpgGh1aCRgsT/wmj/+J5/4BOZgZ3Nns6HZMiA0/YSt8lKog9WRb2um64vnZUcrN8jo5GKjIhhN+xVNn4eH0a32zL9CzpdLpyS6s4nF+ORgMXDUpAUaDKpq54P7CTGSyZMG+SGmL6QLjmU68Qc7tudE0YjesVe8KQTam1nmJsqInz+kiTx6mQIBNC+J20WqvGZ1uatsllQza5VrPvHRfKG9cN46s7z+RPPaO5dlQS0WUH4L/nQ85OCI6Fm7+DPv9X73kuGZxIqEmd7GpoWtGta3TNTdJXj0hCr5NfyadCphaFEH4nrbBmNYzskuYJst9T1Wtao7uqOzEPS47g4+mj4eAKeH8aWEshujdM/QIiUho8T5BRz+vXDWVXpoVze5+4wqpdkV0jTR6nTIJMCOF3aq+G0VxBtsEVZCNdQQbAlo/gm5mgOKDLn+CajyHw5Dcrj+sdy7iThBioYXl+n1j6JoR12HUSm4MEmRDC7xyptT5hc0wtllTa2JttAWBUl0i1vf6np+GXF9UDBl0Df34d9I1fkb4xjHot700b2azn7IgkyIQQfserImuGINucVoiiqIv1xgZp4KvbYMcX6otn3w/nPgwdcMNKfyFBJoTwO7V3Vm6OqUX3tOLZSXr4eDKkrQWtHi55FYZef9rnFy1LgkwI4VcqrHZyS6s9j5sjyDamFtJZk8vf0x+BssNgCoOrP4Lu5572uUXLk15PIYRfOVqoTiu6lyPMsVThdCoNHl9aZWPGp1v4fkdWva9XWh1wbDOLjHMIKzsMYZ3hlqUSYn5EgkwI4VeOuO4h65sQhlYDdqdCQbm1weO/3prJdzuy+Mf/tlNcUfe4I78t5FP9U8RoLCjxg+DWFRDXv8XGL5qfBJkQwq+kua6P9YgNITpE7SI80fTi5jT1RufSajv//eWw94vr36LPqjsI1FjZFXwGmpt/gLCElhm4aDESZEIIv3LE1bGYEhlEglndWuVEnYvuIAOY9+sRCsqqwemAHx6Epf9Ag8Kn9vPZetZ/wBTSsoMXLUKCTAjhV9wVWUpUsGePsOyS+tdbzC2t4mhhBRqNuvRUhdXBez/vhoU3woY3AXjReR0P229hZHdZ59BfSZAJIfyK+x6yLtEnr8i2pBUDaog9OKkP0ZQwYdN02Pstis7Ev6Me4d/Wi4kMNtEjRqoxfyXt90IIv1Ftd5Dpqr5SooKJcwVZVgPXyLYcVacVh6VEMC6qiG+DniDemU2ZNpR7+QfLj3UjwKDl6csGoJVdmf2WBJkQwm+kF1aiKBBi0hMVbCTeNbWY00BFtsm1ueWEoINo3ruHeGcJR5xx3Fz9AKlKAj1iQ3jjumH0jg9ttc8gmp8EmRDCb7ivjyVHBqHRaIg/QUVWZXOw85iFP2t/5ewN/wWnFaXzSJ533k/qYTtXDe/ME5f2J8govwb9nfwNCiH8xpFa18cAT0WWXVKFoihem1juOlbMbXzF/caF4AT6/hnN5Hd4GSP3FVfSTa6JtRvS7CGE8Bu1OxYBT0VWYXVQWm2vOdBhI3jZLO43LFQfj7kLrvoQDIEEGHQSYu2MBJkQwm94Ohaj1IosyKgnLECdWMpxTy9WWeCzq+mTtRiHomFtz3/AhGdAK7/u2iv5mxVC+I3jKzKABHMg4LpOVnIM5k2CQz9RiYnbbbMIOOuvPhmraD0SZEIIv2BzOMkocrfe1+ym7G7Br0rfCu+eDzk7sQfFclX1o/yiGcmATmZfDFe0IgkyIYRfyCyuxO5UMOm1xIUGeJ5PCAvgbO02xv12I5RmQUwflp/5KTuVbgzoFEaAQefDUYvWIEEmhPALnjUWo4K8bl4eX/kD7xtewOiogC5/gluWsTZPnW4cnhLhk7GK1iVBJoTwC3WujzmdsOIJLjj0LHqNk1+DL4Drv0IJMHt2fJYg6xjkPjIhhF9w70OWEhkE9mpYfCfs
/B8Ar9gn86PhZr7XG9mSVsjB3DJMei1ndIvy5ZBFK2lyRbZmzRouueQSEhMT0Wg0LF682Ov1adOmodFovL4mTpzYXOMVQnRQR1wVWW+zDT66TA0xrZ7McS/ziv1KckqrAfjwtzQALh2SSHiQ0VfDFa2oyUFWXl7O4MGDeeONNxo8ZuLEiWRlZXm+5s+ff1qDFEKIw3llJGtyuHjjTXD0NzCFwfVfEjjyBgAKyq2kF1bw/Y4sAG4c08WHoxWtqclTi5MmTWLSpEknPMZkMhEfH3/KgxJCiNqsdidRxdt52/gigRYLmJPguoUQ149wRe1krLY7eWXFAexOhREpEdJ234G0SLPHqlWriI2NpXfv3txxxx0UFBQ0eGx1dTUWi8XrSwghaivc9D8+1T9FtMaCkjAYbl0Bcf0AvBYP/uqPDABuPLOLr4YqfKDZg2zixIl89NFHrFy5kueee47Vq1czadIkHA5HvcfPnTsXs9ns+UpKSmruIQkh2pCVe3J4aNEOqmz1/07woiiw7g3ilt5OgMbG74aRaKZ9D6HeMz7uxYMVBWJCTUzsLzNCHUmzdy1OmTLF8+eBAwcyaNAgunfvzqpVqzj//PPrHD979mxmzZrleWyxWCTMhGjHXli2j73ZpQzubOaakckNH+h0wNLZ8PvbaICP7eP5veeDjDLVXfDXXZEBTB2djFEvdxZ1JC3+t92tWzeio6M5ePBgva+bTCbCwsK8voQQ7dcx1zJTaw82fMkBazksuB5+fxuA7xPu5FH7zXSJqf/3gzvI9FoN1406QTiKdqnFgywjI4OCggISEhJa+lsJIdqI9YcLWLgxvc7zpVU2z3Yrvx3Mx+lU6r65NAc+uAj2fQ86E1z1AR9p/gxo6BodXPd4oF+CGnCXDulEbFhAvceI9qvJU4tlZWVe1VVqaipbt24lMjKSyMhInnjiCa644gri4+M5dOgQDzzwAD169GDChAnNOnAhRNt174KtZJVUMSwlnB6xoZ7na+/kXFBuZU+2hf6JtboL8/bBJ1dCyVEIjIRrP4fk0aR+vQKgwSC7ZFAisaEBDE0Ob5HPI9q2JldkmzZtYujQoQwdOhSAWbNmMXToUObMmYNOp2P79u38+c9/plevXkyfPp3hw4fzyy+/YDKZmn3wQoi2x2p3egLrUF6512uZxZVej389mF/zIHUNvHeBGmKR3dTOxOTRlFfbybGoNzs3FGRarYYx3aNkgeAOqskV2bhx41CUeqYDXJYtW3ZaAxJCtB0Op4JWo7a4N1Z+WbXnz+5tV9xqV2SgXie7/ezusG0BfD0DnDZIGg1T5kOwurxUar4ahpHBRlmpQ9RLWnuEEPWqsjk498VV3Prhpia9L7e0dpBVeL2W5arIRrgW8/09NR/bT/+ERberIdbvMrhxiSfEoCbIGqrGhJBFg4UQ9TqYW8bRwgoyiipwOBV02sZVZbmWmqorvdC7Ist0VWTn9Ioho8DCrKr/YFizWn3xzLth/BOg9f7/awkycTISZEKIehWUWwFwKlBQVt3obsC8shNUZCVqsCUH2/nQ9Dy97ZtxokV70Qsw8tZ6zydBJk5GphaF6GCySir55UDeSY/LrzVFWHu68GRyLd7XyGpfU88qriKBAsavv4ne5ZspV0w8HTanwRADOOwKsm4SZKIBEmRCdDAzP9/KDe/9zrb04hMeV1BeO8iqTnCkt9qhV1Ztp6TSBoCiKJhLdrPY9CjBxftwBMdxtXUO8/J6UVJhq/dciqKQmlcGQNcYCTJRPwkyITqYg7lqMGw/VnLC4wrKrJ4/166yTibvuNBzdy6W7fyeT7SPE6cpxhnTB91tK6mOGagup3g4v54zQWG5FUuVegN1lygJMlE/CTIhOpBqu8Nz7euwq9JpSH6tIMtpQpAdPw2ZXlgBG98j5KvrCdZUs0EzEO30HyE8ibE9ogH4fkc25a4VP2pzXx/rFB4o94iJBkmQCdGB1K6sjr9Z+Xi17wdr0tSi63t0iQpCg5PEjf+E72ahUZx8YT+bf0Y8BQHqah7uIFuyLZPBT/zIlW/+xusrD1BhVUPtsDR6iEaQIBOiA8mu1Rp/KPfEFZn3NbLGVWROp+IJwFGdg3jd8G8GH/0AgK3d7+B++1+ICa9Zsmpc7ximndmFzhGB2J0Km9KKeGn5fu74ZAs2h1M6FkWjSPu9EB1Idq2VNTJLKqm0Ogg01j9l53WNrJFBVlhhxe5UiNRYuC/neWJ1W7GjR3/5GyzNHAIcIjE80HO8Xqfl8T/35/E/9ye9sILV+/N4+rvdrN6fxyOLdlJcqY5BgkyciFRkQnQgObUqMkWpuQZ1PEVRjmv2aNzUYq6lmhRNNotMTxBbvBWLEsRDwY/B4Cmee8gSzPXfj5YUGcT1Z6Tw72uHodXAgk3prNiTC0jHojgxCTIhOpDs49Y6PNRAw4elyo7V4fQ8ziutrn/LleNUpa7jK+NjpJCFPbQTk62P842lF4qikFWsfu+EWhVZfcb3i+OpywYA6lqPIPeQiROTIBOiA8lyVVbu1aYON9DwUeC6zhXo6hS0OxWKKqz1Huux+2sGrbyBKE0pqcaeOKav4BCdqbQ5KCy3kumqyBIbqMhqmzo6hRnndveModNJwk90bHKNTIgOJMdVkQ3oZGZ7RkmDFZm7RT8uzISlyk5huZXc0mqiQurZjklRYN0b8OMj6FFY7hjGz93n8mx4InGhe8i2VJFWWOGZ1jxZReZ234W9iTcHEh8WgF4n/88tGib/OoToQNxdi2d2V9veD+c3EGSuiiwqxERsqBpe9TZ8OB3w/f3w48OAwoboyfzFNotwczgAnSPU0Np6tBibQ90SJi60cXsTajQabjgjhQv6xTX244kOSoJMiA7C6VQ893id2V3dJuVQbnm9177yXI0eUcFGz2LBOcc3fFjL4fOpsPG/gAYmPMsH5hk40XrCLykyCICNRwoBiJPqSrQA+RclRAdRWGHF6nCi0cDILpHotRoqbQ6ve8vc6qvI8mpVZE5LNodfOAf2/4CiD4CrP4QxM8h1BaA7/NwVmTvIGupYFOJ0SJAJ0UG4Oxajgk0EGnUkR6nVUn0NH+7W+5gQY83Uojvwcvfg/O/5dLMdoEAJJfWi+dDvUvUl1wog7ve4g8y93FVjr48J0RQSZEJ0EO6pwXizGjLdY0KA+lvw3at61LlGlroG3puAvjSDw854Lrc+yQZbD0C998w9dRkbqlZeSRFBXudtTMeiEE0lQSZEB+GeQox3Tft1c91kXF+QuSuoqJCaa2R9cr6DjydDdQmZYYOZbH2Co0qcZzsYS5Wdart671mMpyLzDrIEs1RkovlJ+70QHYS79T7eVRW5K7L6phbd6yVGBZsw6uBu3VfcU/o/9cX+l/Nc1V8pzi0AYFuGuh2Me/uWUJPes+xVQngAWo26yzRAYrhUZKL5SUUmRAeRVeJdkXU/QUXmuUYWBH02zGaWQQ0x5ayZcMX77MytaRDZn1NKpdXhmVaMCatprzfotF5VmFRkoiVIkAnRQbinFuPcU4vRakWWVVLltReY1e6kpNJGGOWk/DC
N4D0LcCgaHrJNx3LWI1Q5FM8ajUFGHQ6nwq7MEvLK3NfHvO8T6xRRK8ikIhMtQIJMiA6iptlDDZOIYCORwUbAe/HgogorieTzP+MTGI6uAUMwf9P8g88c55NbWsXB3DKcCkQEGTw3Vm9NL67T6OHmbvgw6DREBzfuZmghmkKCTIgOIvu4qUWof3qxNHUTi01z6KXNgJB4uOUHDoSdCag7Re/NLgWgd3woQ5LUDTK3Z5TUab13c7fgx5sD0LoXeRSiGUmQCeGncixVjd5epcJqx1KlTh/Gm2sHmbsF31WR7V9G12+uIlZTTKo2BW5bCQmDiQ1zt+BXsS/bAkCf+DAGdQ4HYFtGsWcJq9gw7yDrEq1WZMe34gvRXKRrUQg/ZLU7mfTqL2g1sG72+RhOsuyTuxoLNuoIDTB4nvdqwd/4Lnx/PzrFyS+OAXzc6SneMXcGaqYLc0trKrI+8aEM6qxWZGkFFQQZ9V7Huk3sn8Des0uZOCD+dD+2EPWSIBPCD2WVVFLoWqE+r7Taa9fl+ngaPY67Ibl7TAganJxz5DXYvwiAfQmXcnPqFVwcGuE5rmZ1D++pxfAgI12igjhSUMGeLIvXsW6BRh2z/6/vqX5UIU5KphaF8EO1N8isd1X64+RY6l4fAxgcb+J1w+tcbVVDjHMf4avOD2JH77Vli/sG573ZFs+ai73iQtVzJIV7nfP4qUUhWpoEmRB+qPZCv3VWpa/v+BI1fGpfH6O8gOivruZi3Qasio5tI5+Hc+4nv9wGqKt6uLlb9jcdKQIgJSqIYJM6oeO+TuYWEyot9qJ1ydSiEH6odng1piLLdu3O7KnICg7Bp1dC4WEqtSHcUnUPfZxjGUzNOovRtSoy93Sh1aEuQdXbVY0Bns5FAJNeS1iA/FoRrUv+xQnhh9wVFtCozsXs2veQHd0A86dAZSGYk9kw8g3WfVtK8WF1qxX3qh7RtSqy2OOmJPvE1wRZvwQzOq0Gh1MhNsyERiMt9qJ1ydSiEH4o21Lp+bP7RuQTH68eM6jkZ/jwEjXEEofCrSvoP3g0oF7/Kq6weq2z6HZ8A0fv+DDPnwONOk+FdnzHohCtQYJMCD9Uu9kjp/TkFVlOcSW3675hyPqZ4KiG3v8H076D0DhiQk10jwlGUWBDaqGnIqt9jSzYpCfEVDOB0ych1Ov8g13Ti8cHnhCtQYJMCD+UY6k9tXjiisxus3JX1Zs8ZJivPjH6r3DNJ2AM9hxzRrcoAFbszvFcB6t9jQxqQsqk19IlKtjrtYsHJRJk1DGud8ypfSAhToMEmRB+xulUjmv2OEFFVl2G/bNruV63AqeiwXnhszDpOdDqvA5zB9mPu3MACDHpCTB4H+Nuwe8ZF4LuuKWmzuoRzc7HJ3DNyORT/lxCnCoJMiH8TEG5Fbt7gy/3Y1cV5cWSBfMmEZC6girFwEOG+9GeOaPec47uFglASWXd1ns3zwabta6P1SbrKApfkSATws+4q7HoECN6rQZFqdnRueag3fDueMjeTrUxkinWR9kbMa7Bc8aGBniWqwKICq4bZKO6qmEn04eirZH2eyH8jHuDzARzIAadlqySKnIsVTU3Ox/6GRbeCNUWbBHduapkFtuVCKYdtwLH8c7oFuXZLToqpG7Txg1npHDpkETCaq3VKERbIBWZEH6m9gaZnjUQ3TdF//GpeqNztQV75zO42v4k2ysi6J8Yxv0Tep/wvKNdFRfUbfRwkxATbZFUZEL4mRxPRRaAeu9xCTkllfDzs7D6OQAc/a/glqJp/JFXSnxYAO/dNNKzpFRD3A0f4H0ztBBtnQSZEH6m9iodTkXBgJ1hW2ZD3vfqAWNn8WTZ5aw5nE6wUcf700Z6r7HYgLiwALpFB3M4v7zea2RCtFUytSiEn8mpNbXYOdDGh4Z/0i/ve9Do4JJXsZ37KJ9tzADglSlD6ZdYf5dhfaaekUJMqIkx3aNbZOxCtASpyITwM+5VPVK0+Vyw81bMukNUaQIJuO4T6Dmew9ml2BwKISY94/vGNunc08d2ZfrYri0xbCFajFRkQviZbEsVAzWHGfrjlZjLDpGlRHJf6HPQczygrpkI0CsuRBbwFR2CVGRC+JHyajujrBt43fhv9JXVVEX24fLMGTgqEz3H7M9x7+Dc+ClFIfyZVGRCtCEVVjtpBeUNv/7rW7xjeJkgTTV0P4/S674lmygKyqo9q3vsy1aDrPZWK0K0ZxJkQrQBDqfC/N+PcvbzP3POC6tYeyDf+wCnE5Y9TMyah9FpFL4zXADXLSQyMhqdVoNTUZeqAtib7a7IJMhExyBTi0L42LpDBTzxzS5PAAF8tSWDsT1dnYO2SvjqdtizBIDnbVezrdMtXKQzoEO95yvHUk2upZpgk56MInWvstq7OAvRnklFJoQPHcgp5Yb3NrA3u5SwAD3XjVZXj/9pX646VVier26EuWcJ6Iws7/s0/3FcRpw50HOOONdivjmWKs+0YmyoiQi5F0x0EFKRCeFDG1ILsTsVBnYy8+EtowgL0PPDjiyKKmzs2L6ZoWtug6JUCAiHKZ+xdpsZSCOh1g3OtZepci9VJdOKoiORikwIHzqYWwbAGd0iiQw2otdpOa9PHCM0e+n97WQ1xMJTYPpy6HKWZ8Hg+LBaQeZVkamt99LoIToSqciE8KFDeWqQ9YgN8Tx3Y9hm+hifxeSwo3QajubaBRCibp1Se1UPt9oV2WHX+aT1XnQkUpEJ4UPubVO6x4SAosDafzF4/b2YNHaWOUawf+J8T4iB9zqLbu5Qy7VUsS9HWu9FxyNBJoSPVFjtHCtWOwy7RwXAtzNhxeMALA+bzB22mSzbb/Ecb3c4yXNdA4uv5xrZzswSiitsaDXeFZ4Q7Z0EmRA+4q7GOgc5iPj6Btj8AaCBic9R+KcncKJl+e4cz/H5ZVacCui1GqKDa/YLq+laVEOuS3QwAQZdq30OIXxNrpEJ4SOH8sqIpYiPtS/BwcOgD4Qr34M+F3FeaTUazQ52HCshq6SSBHOgZ1oxNtSEVluzhqK7InOTaUXR0UhFJoSPlKRuZbHpUbraD0NwDEz7DvpcBEBMqInhyREArHBVZe5V7+OO21ssKsRErVyjd5w0eoiORSoyIXzh0E9cteNWAjUVFAd1IfzWryGii9chF/SLY1NaER+tS8McZORIvjoVWbv1HkCn1RAdYpJ7yESHJRWZEK3tj0/g06sIdFawwdmHHRO+qBNiABMHxKPVwIHcMu6e/wcvL98PUO9uz7FhNdOLMrUoOhoJMiFai6LAT0/D1zPAaWeJ8yxusM4mpXPneg9PiQrmqzvP4vazu9ErrqYLsX+iuc6xcaFquAUYtCRHBrXM+IVoo2RqUYgWsGxXNrmWKq4/I0Xd3NJuhSV3wfYFAJSMuJt71o7CqNfTKSKwwfMMSQpnSFI4D/1fX44VV5JZXMkw17Wz2twVWa+4UK9GECE6AgkyIZqZoijMWrCVcquD7rEhnJmogwU3wJFfQKODi//F5uCJKGs30TU6GF0jg6dTeCCdwusPvc4RahXWP1EaPUTHI0EmRDMrLLdSbnUA8N
2aDZxZ/gTk7wNjKFz9IfQ4n0NrDgPQvZluXL52VDJOp8IVw+ufphSiPWvyNbI1a9ZwySWXkJiYiEajYfHixV6vK4rCnDlzSEhIIDAwkPHjx3PgwIHmGq8QbZ57Yd+BmsPMPHKHGmKhiXDLUuhxPlCzWHD3mOYJsshgI387vyeJDVRsQrRnTQ6y8vJyBg8ezBtvvFHv688//zyvvfYab731Fhs2bCA4OJgJEyZQVVV12oMVwh/kWKoYr93MAuNTxGhKyA3qCbethPgBnmPciwV3jwn21TCFaDeaPLU4adIkJk2aVO9riqLwyiuv8Mgjj3DppZcC8NFHHxEXF8fixYuZMmVKnfdUV1dTXV3teWyxWOocI4Q/Cd0+j7cNL6PTKKxyDGZO9f2sCIqn9jaX9a16L4Q4Nc3afp+amkp2djbjx4/3PGc2mxk9ejTr1q2r9z1z587FbDZ7vpKSkppzSEK0HqcDlj7EqD1z0WkUfo+8hAdND3G0XMuyXdmewwrLrRRV2ADoFi1BJsTpatYgy85W/2ONi4vzej4uLs7z2vFmz55NSUmJ5ys9Pb05hyRE67BWwMIbYb065f6cbQobBzzG1aO7AfDxujTPoe5qrFN4IIFGWdxXiNPl8xuiTSYTYWFhXl9C+JWyPPjwEtj7LeiM/DtyNm86/kxCeCBTRyej12r4/Ughe127N3saPWRaUYhm0axBFh8fD0BOTo7X8zk5OZ7XhGhX8g/Au+fDsU0QEA43fs0i2xmAuiZiXFgAE/qr//b/+vFmPl53hO0ZJQD0aKaORSE6umYNsq5duxIfH8/KlSs9z1ksFjZs2MCYMWOa81sJ4Xtpv8G746E4TV0r8dYVkHKmZ18w95qIfzu/B+ZAA0cKKnj0613M//0oAN1jpWNRiObQ5K7FsrIyDh486HmcmprK1q1biYyMJDk5mZkzZ/L000/Ts2dPunbtyqOPPkpiYiKXXXZZc45biFZ3rLiS77dncfWIJMyHvobFd4DDCp1GwLWfQ0gMpVU2yqrtQE2Q9YkPY+0/zuWrLcf48LcjHHatYj+gnjUThRBN1+Qg27RpE+eee67n8axZswC46aab+OCDD3jggQcoLy/n9ttvp7i4mLFjx7J06VICAuqu2C2EP3ll+X6+2JxO0O+vMrXsA/XJPhfD5P+CUV0iyr1nWFiAniBjzX9eoQEGbjqzCzeckcLag/lU2hwMTgpv5U8gRPvU5CAbN24ciqI0+LpGo+HJJ5/kySefPK2BCdHWpOWV8Kz+Xa4r+xkA5Yw70Vz4NGhrOg/duzgnmOtfYUOr1XB2r5iWH6wQHYistShEY1SXcm/eHMbo/8ChaHjSfiPhupu5V+vdPu9enqq+PcOEEC3D5+33QrQmRVF4dPFO/vzvtRRXWBv3ppJjON+fyBjlDyoVI8sGvMiHjgm8uvIAi/7I8DrUPbV4/C7OQoiWI0EmOpQl2zL5eH0a2zNK+N/mjJO/IXsnvDsebc5O8hQzNzgfY9KV0/nL2eqNzo8s2km13VFzuEUqMiFamwSZ6DDySqt5fMkuz+Mvtxw78RsOroT3J0JpJhXmHlxufZLiiIFoNBr+MbEP4UEGyq0O9meXed6SLVOLQrQ6CTLRIbinFIsqbPSKC8Go07Iny8KuzJI6x1VY7bDlI/j0KrCWQpc/sXT0B2QoMZ6NLbVaDQM7qe3zO47VnEOukQnR+iTIRIfw3Y4slu7KRq/V8K9rhjC+XywAX272rspmLdjKvKemw5K/geKAQVPg+q9ILVPXru8cUdONOMATZMWe53I8XYsSZEK0Fgky0e4VlFUz52t1SnHGuT3on2jmimHqTspfbz2GzeEEYPn2o5y962Fm6Barbzz7Abj8LdAbySiqBKBzRJDnvIOOq8iqbA4Ky9UGEmn2EKL1SJCJdu+bbZkUllvpHRfKjHN7AHB2rxiiQ4wUlFtZvS+P8uI8ohddw+W6X7EpOhYkPgjnPQwaDQDHPEFWtyLbl11Ktd3hqcYCDFrMgYbW/IhCdGgSZKLdSyusAGBcnxiMevWfvEGn5bIhnQBYtf53Kt86j6HKbkqVQKbZHmC+7Wyvc2QUqefoVCvIOkcEEhFkwOZQ2Jdd6mn0SDAHonEFoBCi5UmQiXavvmlBgCuGd2aw5iD3pt1JdNVRjilRLB39Ab86B3Iwt8yzgo3N4fS01deuyDQajacq255R4jkmLszU4p9JCFFDgky0e54gC/deNqpv8WoWmJ4hSmNhlzOF17u+yWUTLkSv1VBWbSfTVWFll1ThVMCk1xIT4h1S7s7FncdKPB2LDS1PJYRoGRJkot075poWrF1Nse4/sOAGAqjmJ8cQbtY8yczLz8Gg09I1Wt1eZX9OKQDp7mnF8LpThoM61zR8uKcW46TRQ4hWJWstinatpNKGpUrdVqVTRCA4HbDsIdjwFgD2odNYq7mFF/okeu796hUXyoHcMg7mlHFu71hPRVf7+phb7YaP2FC1WpPWeyFalwSZaNfc3YZRwUaCsMKCW2Hfd+qL459Af9Y9zDmuyuoZFwI7aiqyYw1cYwO1SosMNlJYbmX94UJAboYWorXJ1KJo19zdhv3NVfDBRWqI6Uxw5TwYO9PTXl9br7hQAPbnlrnOUbf13q12w0elTV1zUe4hE6J1SZCJdi2jqJLummO8bLkfMrdAYATc+DUMmNzge3rGhgBwMKcURVE8YVhfkEHNjdFuMrUoROuSIBN+o8rmYOexkpMfWIvu6K98ZXyMaHsWRHSFW1dCypgTvqdLdDAGnYZyq4PMkqoTVmRQc50MQKfVEBUi7fdCtCYJMtFsHM6Gdw5vrEqrg4O5pfW+9tzSvVz8+lo+23C0cSfb/gXX7Z+JWVNBbvhguHUFRHU/6dtqdy7uybTUuoes7jUygIGda4IsLtSETis3QwvRmiTIRLP4ZlsmfR9dyvc7sk75HFkllVz0+i+Mf3kNfxwtqvP62gP5ALzx80HsrvUR3f44WsR1/13PjowSUBRY8wJ8dSsGbHznGMXu8R9DcHSjx9LTdZ3slwN5OJwKRl3de8jcEs0BRAWriwpLo4cQrU+CTDSLXw/mY3U4eWXFfs+KGE2RXljB1W+v43BeOVATWm4VVjuH8tTmi2PFlXxXKzCtdid/X7iN3w4V8NGvB9WV6396GoB5/Jm7bHeTGBPZpPG4r5P9vC8PgMTwALQNVFq1Gz4kyIRofRJkolmUVNoA2J9Txu+phU167+G8Mq5+ex3phZWeabltGcVex+zJslB75vLt1Yc9gfnJ+jQO55cTQgXX7J8Ff3wMGi2VFz7PE1VTUNB69hFrLHfn4tFCd6NH/dOKbqO6qkHZIyakSd9HCHH6JMhEs7BU2Tx//qSx17CA3NIqrn57PVklVfSIDeH1a4cCsDW9xKuy23nMAsCw5HACDTp2Z1lYezCfonIrr6zYTzwFfGF8ghGOrSiGIJgynyNdrwUgIshAsKlpt0z2ivMOpJMF4fSxXXnr+uH85ZyTX4MTQjQvCTLRLNwVGcDSnVnklVY36n3fb88iv6yabjHBfH77GZzXJxadV
kN+WbVn7UKo2fNrbM8YrhmZBKhV2Ssr9tO5+iDfBj5GX206uUo428d/Br0nNrhYcGOkRKmdi24NdSy6BRh0TBwQ3+TAFEKcPgky0SwsleoyUKEmPTaHwsJN6Y163x/pxQBcNqQT0SEmAgw6+sSr03rbXK8Bnrb7AYlhTB/bFZ1Ww9qD+aRtWMJC45NEK4VkGrtwefUT/FqpBt3J7v86kdqdiwCdI2UhYCHaKgky0SzcFdn1Y1IA+GzD0Ua14/9xtBiAocnhnucGdVb/vC2jZuflA65VNgZ2NpMUGcTFgxKYovuJdw0vEKKpgq5ns3z0RxwjxhOA9W2G2RTuzkWATuFNr+qEEK1DgkycNqdTodR1jey6UcmEBxk4VlzJz3tzT/i+/LJqTzOFO7wAhiSpHYDuQNqTZcHhVIgOMarLPzmdPBKwkH8a3kWvcVLW+yqY+iV9uyW53qcGoGex3yY2erj1iq0JslMNQyFEy5MgE6etzGr3dBTGhJq4eoQaKB+vTzvh+7a6qrEesSGYAw2e5wcnhQPqdTGnU2Fnptro0T/RjMZeDV/dSszW/wCQMXgmIVP+C3ojAzqFodVAtqWKHEsVGcWN6zhsiLvhQ6/VyNYsQrRhEmTitFlc04pGvZYAg46po5MBWHMgj5IKW4Pv+yNdvel5qCu43HrEhBBo0FFWbedwfhk7XVOMI2MV+Pgy2PklaPVw2Zt0vvwJz8K/QUa9p21+W3pxTbPHKV7fGpIcjlGnZUAns6zWIUQbJkEmTpv7+pi7qkqJCiYm1ISiQFpheYPvq7k+FuH1vF6n9ey8vDW9hJ2ZJSRrcrh57+1wdB2YzHD9lzDkujrnHOyaovz1YD7FrhA91anFBHMgP913Dp/cOvqU3i+EaB0SZOK0HR9kAEmua0ruquh4DqfCdlelVbvRw22w6zrZxtRCgnK38JXxMYLLjoA5CaYvg27j6j2ve1ryh53ZAIQHGQgNMNR7bGN0jggiRFrqhWjTJMjEaXO33ocF1PzCd1+XSnc1cxzvYG4ZZdV2gow6z3Rgbe7mj6odX/Ox7imiNRaUBNfCv7F9GxyLOwBzXfexnWo1JoTwHxJk4rRZ6qvIIk9ckbkXBR7cObze609DOpuZrvuef/ESARobWwNGoZn2PYTGn3AsveJCCTDU/LOWbkMh2j8JMnHa3MtThdUKMk9FVlR/RVbf/WMeTgedNzzOo4ZP0GoUPrGfz4+D/wWmk69jaNBpGZBYs63KqXYsCiH8hwSZOG31XyM78dSip2PxuEYPrOXw+VQ0v78DwLO2a3nEfgv9O0U1ejyDa3VBSkUmRPsnV7HFaXNPLYYF1K7IaqYWFUVBo6mZPrRU2TwrdQyp3XpfmgPzr4HMP0Bn4tsej/PONnWlkIG1dmE+Ge8gk4pMiPZOKjJx2uqryBLDA9FooNruJK/MewHh7eklKIoadjGhrs0qc/fCu+PVEAuMhJu+IXjolYDaRJLUhHvBhtRaJUSaPYRo/6QiE6etviAz6rXEhwWQVVJFRlElsaE1K2O4Gz0804qpa+Dz66G6BCK7wdT/QVR3xjqcXDc6mWHJEV4V3ckkRQbSNyGMonIr3WKCT/4GIYRfkyATp81S5Wq/D/T+55QUEURWSRXphRUMq3UtzL3i/dCkcNj2OXx9FzhtkDQapsyHYPV6mEGn5dnLBzZ5PBqNhm/uOgu7UyHAoDu1DyWE8BsytShOm7siq921CDVLQ9VuwVcUha3pxYDCxMKPYNFf1BDrdxncuMQTYqdLr9NKiAnRQUiQCS87j5Xw4W9HvHZnPpn6mj2gptEio1YLfkZRJZbyCl40vkPilpfVJ8+8G66cBwZZmFcI0XQytSi8zP5qBzuOlZBgDuDC/t43H+/JsrBkWyZ3jOvuFVr1XSOD+pep2pWazjzD8/xJuxM0Wvi/F2DkrS31cYQQHYBUZMLD4VTYl1MKwOa0ojqvP/v9Ht5cdYhvt2V5nquyOai2O4F6phaPv5esOJ3hK6bwJ91OqrWBcO3nEmJCiNMmQSY8jhVVYnWFkrshw83ucHrCrfZUoXtVD40GQo9bXNfdMn+suBLHsa3w7nhiKg+Tq4Sz+qwPodeEFvokQoiORIJMeBzKK/P8eXtGMXaH0/N4T1YpFVYHANklVZ7naxYMNqA9bs3E+LAAdFoNZylb0HwwCcqyOaB05rLqJ0npf2ZLfhQhRAciQSY8agdZlc3J3uxSz+NNaYWeP2fVCrKajsW6l1v1Oi13BK/iPcOLaG0VlHcay+Tqxyk2xtEj9uTrJgohRGNIkAmP2kEG3tOLm47UXDPLttSuyOpv9MDphOVzuM/2NjqNQlrSZfww+HVKCZIdl4UQzUqCTHgcylV3c+4SpTZpuFfgUBTluIqs0tOe71n5vnbrva0KvrwFfn0VgJdsV7I4+WG2ZqrnH9y58esmCiHEyUiQCQ93RXbl8M4AbHVttZJRVEmOpRq9q4qqsjk9U4p1Wu8rCuGjS2HXItAa+LHXE7zumExGcaVnR+hBtdZCFEKI0yVBJgAoKrdSUG4F4PJhapAdzi+nqNzqqcYGdDITGWwEaq6Ted0MXXBIXfg3fT2YzHD9l5T3vdJzrj1ZFuC4Fe+FEOI0SZAJAA7nq9VYojmATuGBdItWF9vdmlHMRtf1sZFdIogPU1ffcHcuuiuyfo698N4FUHgIzMkw/Ufodo7nXrI/jhZhcyhEBBlkjzAhRLOSIBNAzfWx7q5uwiGunZv/OFrMZleQDU+JJMGsBllNRWZnknYDU/fOgIoCSBgCt66A2D5AzQabTteKV4M6hzdpJXshhDgZCTIB1Fwf6x6jBpl7i5XV+3LZn6u24Y/oEkG82V2RVYKiMDLrU94wvIZesUKvSXDz9xAa5zlvbKgJo67mn9lgmVYUQjQzCTIB1A4ydUpxqCtwtmWom2B2jQ4mOsTkqchyisvg+/u4suAttBqFw12vgymfgtF7/y+tVkOnWlOJ0rEohGhuEmQCgEN5rqlFV0XWJz6UAEPNP4/hKWqFFm8OJIgqphyeDRvfxYmGp2zXkzb6cdDWv21K7Wti0rEohGhuEmSCaruDo66Ffd0rbuh1Wq/QGdlFDbJkg4UFxicZWrUB9AE8aryf9xz/hznI2OD53Q0fncIDiQk1tdCnEEJ0VBJkgqMFFTicCqEmvVfQDHU1fIDa6EHuHob+eCUDtUcoVELhpm/5xjoCqLsXWW1do9Ugk7Z7IURLkP3IBAdz1etj3WJDvDoKhyapVVhEkIHuZZtgwY0Yqks47Ixnmu0ffBM9hNLqH4F6lqiq5eoRSVgq7Uwe1qkFP4UQoqOSIBN1Gj3czusTy9TRyVyhW4PmkzngtEPyGKYdvY2jSgAHc0txbyRd36LBbuFBRu6b0LvFxi+E6NhkalHUafRwM+o0PBPxHcO2PKSGWP/JcMNigsJjADyr4wcYtJj09Td6CCFES5MgE3XuIQPAboXFd8KquerjsffCFe+BIcBzL9l+V5CdaFpRCCFamkwtdnCKonDIdY2sR6xrarGyGBbeAKlrQKODi16EEbd43uO+
l8xdkZ2o0UMIIVqaBFkHl2OpptzqQKfVkBwZDMXp8OlVkLcHDMFw9YfQ8wKv98SHqfeF7cuRikwI4XsSZB2ce1oxJTIIY+52+OxqKMuB0AS4bgEkDK7zHndFVlzh3h1agkwI4TtyjayD25ZRDMAVITth3v+pIRbbT134t54QAzzXyNykIhNC+FKzB9njjz+ORqPx+urTp09zfxvRTNYdKuB63XLuyH4UbOXQ7Vy4ZSmYOzf4noTjgiwsQAp7IYTvtMhvoP79+7NixYqab6KXX3RtkdVmZ1za60w3fKM+MeR6uOQV0J24wpKKTAjRlrRIwuj1euLj4xt1bHV1NdXV1Z7HFoulJYYkjmerpOyTm5mu/QEA57iH0Z5zPzRir7DQAAMhJj1l1XZArpEJIXyrRa6RHThwgMTERLp168bUqVM5evRog8fOnTsXs9ns+UpKSmqJIYnaygvgo0uJTPsBq6Ljg7jZaMc90KgQc6tdlUmQCSF8qdmDbPTo0XzwwQcsXbqUN998k9TUVP70pz9RWlpa7/GzZ8+mpKTE85Went7cQxK1FRyC98ZD+gbKNSHcaJuNbsiUJp+m9nUymVoUQvhSs08tTpo0yfPnQYMGMXr0aFJSUli4cCHTp0+vc7zJZMJkkq09WsXRDTB/ClQWopiTuKrgHnY7E3m6e1STTxUfVqsikxuihRA+1OLt9+Hh4fTq1YuDBw+29LcSJ7JrEXx4CVQWQuJQtlz4P3bbE4kJNdVZY7ExpCITQrQVLR5kZWVlHDp0iISEhJb+VqI+igK/vgpfTANHNfT+P5j2Hasz1b/6M7pFeW3d0ljx5ppdn0+08r0QQrS0Zg+y++67j9WrV3PkyBF+++03Lr/8cnQ6Hddee21zfytxMg47fDcLls9RH4/6C1zzCRiDWX+oAIAx3Zo+rQhSkQkh2o5m/1/pjIwMrr32WgoKCoiJiWHs2LGsX7+emJiY5v5WopbCcitfbcng2lHJBJv0UF0G/7sZDvwIaGDCszDmTgAqrQ7+SC8CYMwpXB+Dmq5FrQZCTFKRCSF8p9l/A33++efNfUrRCM98t4cvt2RgqbIza3SoumZi9nbQB8AV70LfSzzHbk4rwuZQiA8LoEtU0Cl9v+4xIQzubCY5KviUpiaFEKK5yP9KtwMOp8LP+3IByNy/GXY8DZYMCIpWF/7tPMLr+HWH8wG1GjvVEDLqtXx919jTG7gQQjQDCbJ2YHtGMYXlVs7S7uCx3FdAUwlRPWDqFxDZrc7xvx48vetjQgjRlkiQtQM/783lSt1q5urfxaBxUB4/iuAbF0BQZJ1jf9qbw9b0YnRaDWN7RvtgtEII0bxkGxd/pygk/PEvXjS8jUHjYIljDIsG/LveECuvtvPo4l0ATB/blcTwwDrHCCGEv5Ega8NyLVW8vHw/P+7Krv8Au5XKL27j2sr5AKyOu5F7bDPYmFFR7+H/Wr6fY8WVdAoPZOb4ni01bCGEaFUytdgGlVbZeGfNYd79JZVKmwOAiwYm8OSl/YkKcS3nVVkMC64n8Mgv2BUtb4XOYPD4e1De+50tR4vqnHNHRgnv/5oKwNOXDyDIKH/1Qoj2QX6b+ZiiKGxILeRQXhk5lmpyLVUs351DQbkVgN5xoRzMK+O7HVmsO1zAU5cO4KIkq9pen7eXKk0gt1vvZsjgKxmSFI5GA+mFleSWVhEbqt7rZXc4mb1oO04FLh6UwLm9Y335kYUQollJkPnYT3tzmf7hpjrPd4sO5oGJfZjQP45dmRbu+2Ibe7NLeWv+/zg39F8EWQtQQhOZapnJZmdn7u0dQ2iAgd5xoezNLmVLWjETB6h7ws3fmM7OYxbCAvTMuaRfa39EIYRoURJkPrY9owSArtHBnNk9itjQALrHBjOhfzwGnXoJc0AnM0vuGsvXC9/jon1PEWStpjisF4fGz2PzZ2lEBRsZ3DkcgKHJEWqQHS1i4oB4bA4nb606BMCsC3p5qjQhhGgvJMh8LK2gHICrRyRxx7juDR5n3PIeVx14ADROVjsGcU/+PfRap+7xdk6vGLRa9cbm4SkRzP/9KFvS1OtkS7Zmcqy4kugQI1NGJbfwpxFCiNYnQeZjaYVqh2GDS0U5nbD8UVj3bwCUoTfydcWNFG/L5ffUQgDG9am55jUsORyA7cdKqLI5+M8qdfuc6WO7EWDQtdCnEEII35Eg87GjBWqQJdcXZLZK+Op22LNEfXz+HDRjZ/FPh0JexUZ+OZCPVgNn17qxuWt0MBFBBooqbPxr+X4O5ZUTGqDn+jOkGhNCtE8SZD5UWmXzdCemRAV7v1ier+7mnLERdEa49D8w6CoAjHoNb14/nDmLd9IjLoTwIKPnbRqNhmHJEazcm8s7vxwGYNqZXQiVXZyFEO2UBJkPpbmqsegQo/dWKPkH4dMroOgIBITDlM+gy1le7w0x6Xn5miH1nndYihpkigKBBh03n9W1ZT6AEEK0ARJkPuQOsuTIWtOKab/B59dBZRGEp8D1X0J001bhGJYc4fnztaOSiQw2nuBoIYTwbxJkPpRWqHYsdnFPK+78Ehb9FRxW6DQcrl0AIU3fkHRwkpnQAD02h5PbzpZqTAjRvkmQ+ZCn0SMyENb+C1Y8rr7Q52KY/F8wntqml0FGPYvuPBNFgQSzLAwshGjfJMh86EhBOTocXHrsRTiyUH3yjDvhwqdBe3qt8j1iQ5thhEII0fZJkPlQfn4B7xlepOuRbYAGJv4Tzvirr4clhBB+RYLMR6oK0nm16iH669JQ9IForngX+l7s62EJIYTfkSDzhZxd6D66gv7aLAoUM5HTvoLOI3w9KiGE8EuysWZrO/QTvDcBQ3kWB52J3B/+EhoJMSGEOGUSZK1py8fw6VVgLSUrfDiTrY8TGNvwQsFCCCFOTqYWW4OiwM/PwJoX1McDr+a/2juwZGfVv8aiEEKIRpMgawZl1Xb+tXw//9ucQZXNgVNRcCpwZvcoPrpxMJolf4Mdrvb6s++Hcx/m0LyNAKRESpAJIcTpkCA7TT/uyuaxJbvIKqmq89q2A0eo/uBxAo79BhodXPIKDLsRgKOu7VvqLBYshBCiSSTITpHTqXDvwq18vTUTgKTIQB67uD+940PRaTXMfv87Hil+goBjx8AYCld/CD3OB8DucJLuCTKpyIQQ4nRIkJ2i348U8vXWTPRaDbef3Y2/ndeTQKNrNY5jW/h3xf2EaguxGGMJu2URxA/wvDerpAq7U8Go1xIfFuCjTyCEEO2DdC2eok1H1N2ZJwyI54GJfWpCbO/38MFFhNoL2e1M4bGYV71CDLxXvddqNa06biGEaG+kIjtFvx8pAmBkSs2WKWx4G374B6BQ0ukcrjp0I8G5dbdQOVLgXvVephWFEOJ0SUV2ChxOhS1priDrGglOByx9CH54AFBg+DT01y+kQhNIbmk1uaXejSDuRo/kSGn0EEKI0yVBdgr2Zlsoq7YTYtLTJ0oPC2+E9W+oL45/HC5+heDAALpFq0G1K9Pi9f4j+WpFJo0eQghx+iTITsEm17TiOZ1B99GfYe+3oDPCFe/B2HtBo17
36p9oBmD3cUF2VDoWhRCi2UiQnYKNRwrppsnk2YJ74dgmCIyAG7+GgVd6Hdc/MQyAXZklnucsVTYO56kVWbfokNYbtBBCtFPS7NFEiqJgO7yWr4z/xFxVDhFdYOr/ILpnnWPdFdnOYzUV2bKd2VgdTnrGhpAUKbs3CyHE6eoQFVlGUQVzvt5Jmqtb8HQUbJjPa7YnCNeU40wcDtNX1BtiUFORHS2swFJlA2DJNvUG6kuHJKLRSOu9EEKcrg4RZJ+sP8pH69J44+eDp34SRYFfXiZ66R2YNHbWmc5Ee/N3EBLT4Fsigo10Clerrt2ZFnJLq/j1YD4Afx7c6dTHIoQQwqNDTC0eK64EYLOrZb7JHHb4bhZs+RCA/9r/j7yRjzDGcPKpwX6JYRwrrmRXpoU9WRacCgxNDpdV74UQopl0iIosu0QNskN55RSVW5v25upSmH+NGmIaLa8F/IVn7NczoktUo97uafg4VuJZl/HSwYlNG4MQQogGdYggyyyuuSF5y9EmVGWWTHh/EhxcAYYgSi/7iJeLzwFgRJfIRp3C3fCxen8eW9OL0WrgokESZEII0VzafZA5nQo5lpogO9n04tb0Yub/fhQlewf893zI2QHBsTDtO9YbRgHQIzaEyOC6S0/VZ0AntSIrcFWCZ/WIJibUdCofRQghRD3a/TWy/PJq7E7F8/hEQVZtdzD9g430q9zEVUGvo7eXQ3RvmPoFRKSwadseAEZ2iWjwHMeLDwsgMthIoSvILh0iTR5CCNGc2n1Flu3a8FLvWmV+W0YxNoez3mN/3JXD+VXLmGd4Xg2xLn+C6csgIgWAta6OwxEpjZtWBNBoNJ7rZEa9lgn94075swghhKir3QeZe+fmfolhhAcZqLI56ywZBYCiYF/xJM8b/ote42SpbhzK9V+qq3YA6YUV7Mq0oNXAuN4Nt9zXZ1Bn9TrZ+L6xhAYYTu8DCSGE8NLupxbdFVmiOZDoEBM/7c1lc1oRg5PCaw6yV1O28C9cXroIgH87JvNi1RUsL7DSM069nvXj7hwARnaJJCqkade4bv9Td3QaDVPPSDn9DySEEMJLh6nIEsIDGO7aO2xz7c7FikL4+HJC9i/Cpuh4P+o+fu96B6Dhp725nsOW7cwGYEL/+CaPwRxkYNaFvYmT3aCFEKLZtfsgc99DlmAOYFiyGmTuvcQoTIX3LoS0XykliGm2B0g6/3bOc00d/rxPDbL8smo2pqk7Ql8o17iEEKJNafdB5q7I4s2BDE4yo9NqyCqpIm/vr/DeBVBwgMrAeK6snsOhkJGc2zuG8/qoYbXpSBGWKhsrduegKDCwk5nOEbIihxBCtCUdJsgSzAEEGfX0SwjjQu1GIr6YDOV5ED+Qv4e9zD4lmatHJqHXaUmOCqJ7TDB2p8Iv+/NZuss9rSjVmBBCtDXtOsgURfE0e8S7rk/dGfgjbxleQe+oQulxAWvGfsT3aaDVwDUjkzzvPa9PLABLth3jt4MFAEwc0PTrY0IIIVpWuw6ywnIrVtc9Y3EhBvjhH0zKeBWtRmGx7kKuKL6bGz9Rb3I+r0+cZ6V6gHNdQbZsVw5Wh5NuMcH0iA1t/Q8hhBDihNp1+717WrFTMBi/mgZ7vwVgru1a3q66GMpLMem1XD0iiVkX9PJ678gukYSa9JRW24FT61YUQgjR8tp1kGWXVBFNCfP4F+zdDzoTXP4Wuzd0IvxYCVNHJzPtzK71rn1o0Gn5U69ovt+hXh+bKEEmhBBtUrsOsvLMPXxlnEOyI09doePazyH5DD4e0Lj3n9s7lu93ZJNgDvCsziGEEKJtab9BdmQtE9ZdT4C2lAJjJ6Ju/QaiujfpFJcO6cShvHLG9ohGo9G00ECFEEKcjvYZZGm/wceXE+CwstnZk+0j3uLmJoYYqIv8PjipTwsMUAghRHNpn12LnYZDpxFsCPgT11kfJiImwdcjEkII0ULaZ5DpTTB1IbN1s6jGSLxZ1jgUQoj2qn0GGaAYQ8i0VAPqqh5CCCHap3YbZCWVNqpsrpuhZdV5IYRot9ptkLlvho4MNhJg0Pl4NEIIIVpKuw2y49dYFEII0T612yBzV2SJ4RJkQgjRnrXbIHNvqCkdi0II0b61WJC98cYbdOnShYCAAEaPHs3vv//eUt+qXjX7kAWe5EghhBD+rEWCbMGCBcyaNYvHHnuMLVu2MHjwYCZMmEBubm5LfLt6Zck1MiGE6BBaJMhefvllbrvtNm6++Wb69evHW2+9RVBQEO+//36dY6urq7FYLF5fzSHLNbUo95AJIUT71uxBZrVa2bx5M+PHj6/5Jlot48ePZ926dXWOnzt3Lmaz2fOVlJRU55imUhSlpiKTIBNCiHat2YMsPz8fh8NBXFyc1/NxcXFkZ2fXOX727NmUlJR4vtLT0097DKXVdiqsDkCCTAgh2jufr35vMpkwmepubHk6yqvtDEsOp7zaQZDR5x9RCCFEC2r23/LR0dHodDpycnK8ns/JySE+vnV2WU4wB/LVnWe1yvcSQgjhW80+tWg0Ghk+fDgrV670POd0Olm5ciVjxoxp7m8nhBCig2uRebdZs2Zx0003MWLECEaNGsUrr7xCeXk5N998c0t8OyGEEB1YiwTZNddcQ15eHnPmzCE7O5shQ4awdOnSOg0gQgghxOnSKIqi+HoQtVksFsxmMyUlJYSFhfl6OEIIIXyksXnQbtdaFEII0TFIkAkhhPBrEmRCCCH8mgSZEEIIvyZBJoQQwq9JkAkhhPBrEmRCCCH8mgSZEEIIvyZBJoQQwq9JkAkhhPBrbW6zLveKWRaLxccjEUII4UvuHDjZSoptLshKS0sBSEpK8vFIhBBCtAWlpaWYzeYGX29ziwY7nU4yMzMJDQ1Fo9Gc8nksFgtJSUmkp6fL4sO1yM+lYfKzqZ/8XBomP5v6NdfPRVEUSktLSUxMRKtt+EpYm6vItFotnTt3brbzhYWFyT+wesjPpWHys6mf/FwaJj+b+jXHz+VElZibNHsIIYTwaxJkQggh/Fq7DTKTycRjjz2GyWTy9VDaFPm5NEx+NvWTn0vD5GdTv9b+ubS5Zg8hhBCiKdptRSaEEKJjkCATQgjh1yTIhBBC+DUJMiGEEH5NgkwIIYRfa5dB9sYbb9ClSxcCAgIYPXo0v//+u6+H5HNz585l5MiRhIaGEhsby2WXXca+fft8Paw255///CcajYaZM2f6eihtwrFjx7j++uuJiooiMDCQgQMHsmnTJl8Py6ccDgePPvooXbt2JTAwkO7du/PUU0+ddGHb9mjNmjVccsklJCYmotFoWLx4sdfriqIwZ84cEhISCAwMZPz48Rw4cKDZx9HugmzBggXMmjWLxx57jC1btjB48GAmTJhAbm6ur4fmU6tXr2bGjBmsX7+e5cuXY7PZuPDCCykvL/f10NqMjRs38vbbbzNo0CBfD6VNKCoq4qyzzsJgMPDDDz+we/duXnrpJSIiInw9NJ967rnnePPNN/n3v//Nnj17eO6553j++ed5/fXXfT20VldeXs7gwYN544036n39+eef57XXXuOtt95iw4
YNBAcHM2HCBKqqqpp3IEo7M2rUKGXGjBmexw6HQ0lMTFTmzp3rw1G1Pbm5uQqgrF692tdDaRNKS0uVnj17KsuXL1fOOecc5Z577vH1kHzuH//4hzJ27FhfD6PNueiii5RbbrnF67nJkycrU6dO9dGI2gZAWbRokeex0+lU4uPjlRdeeMHzXHFxsWIymZT58+c36/duVxWZ1Wpl8+bNjB8/3vOcVqtl/PjxrFu3zocja3tKSkoAiIyM9PFI2oYZM2Zw0UUXef3b6eiWLFnCiBEjuOqqq4iNjWXo0KH897//9fWwfO7MM89k5cqV7N+/H4Bt27axdu1aJk2a5OORtS2pqalkZ2d7/TdlNpsZPXp0s/8+bnOr35+O/Px8HA4HcXFxXs/HxcWxd+9eH42q7XE6ncycOZOzzjqLAQMG+Ho4Pvf555+zZcsWNm7c6OuhtCmHDx/mzTffZNasWTz00ENs3LiRu+++G6PRyE033eTr4fnMgw8+iMVioU+fPuh0OhwOB8888wxTp0719dDalOzsbIB6fx+7X2su7SrIROPMmDGDnTt3snbtWl8PxefS09O55557WL58OQEBAb4eTpvidDoZMWIEzz77LABDhw5l586dvPXWWx06yBYuXMinn37KZ599Rv/+/dm6dSszZ84kMTGxQ/9cfKldTS1GR0ej0+nIycnxej4nJ4f4+Hgfjaptueuuu/j222/5+eefm3XfN3+1efNmcnNzGTZsGHq9Hr1ez+rVq3nttdfQ6/U4HA5fD9FnEhIS6Nevn9dzffv25ejRoz4aUdtw//338+CDDzJlyhQGDhzIDTfcwL333svcuXN9PbQ2xf07tzV+H7erIDMajQwfPpyVK1d6nnM6naxcuZIxY8b4cGS+pygKd911F4sWLeKnn36ia9euvh5Sm3D++eezY8cOtm7d6vkaMWIEU6dOZevWreh0Ol8P0WfOOuusOrdo7N+/n5SUFB+NqG2oqKios1uxTqfD6XT6aERtU9euXYmPj/f6fWyxWNiwYUOz/z5ud1OLs2bN4qabbmLEiBGMGjWKV155hfLycm6++WZfD82nZsyYwWeffcbXX39NaGioZ47abDYTGBjo49H5TmhoaJ3rhMHBwURFRXX464f33nsvZ555Js8++yxXX301v//+O++88w7vvPOOr4fmU5dccgnPPPMMycnJ9O/fnz/++IOXX36ZW265xddDa3VlZWUcPHjQ8zg1NZWtW7cSGRlJcnIyM2fO5Omnn6Znz5507dqVRx99lMTERC677LLmHUiz9kC2Ea+//rqSnJysGI1GZdSoUcr69et9PSSfA+r9mjdvnq+H1uZI+32Nb775RhkwYIBiMpmUPn36KO+8846vh+RzFotFueeee5Tk5GQlICBA6datm/Lwww8r1dXVvh5aq/v555/r/b1y0003KYqituA/+uijSlxcnGIymZTzzz9f2bdvX7OPQ/YjE0II4dfa1TUyIYQQHY8EmRBCCL8mQSaEEMKvSZAJIYTwaxJkQggh/JoEmRBCCL8mQSaEEMKvSZAJIYTwaxJkQggh/JoEmRBCCL8mQSaEEMKv/T96bJh7TO9IZgAAAABJRU5ErkJggg==", 129 | "text/plain": [ 130 | "
" 131 | ] 132 | }, 133 | "metadata": {}, 134 | "output_type": "display_data" 135 | } 136 | ], 137 | "source": [ 138 | "plt.figure(figsize=(5,5))\n", 139 | "plt.plot(x, y, label='Truth')\n", 140 | "plt.plot(x, y_pred, label='Predictions')\n", 141 | "plt.legend()\n", 142 | "plt.show()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "### DATA 2: Higher Dimensions" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 54, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "Array([130.9917 , -41.129025, 28.312298], dtype=float32)" 161 | ] 162 | }, 163 | "execution_count": 54, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "x1 = jnp.linspace(0,10,100)\n", 170 | "x2 = jnp.linspace(-10,10,100)\n", 171 | "X = jnp.vstack([x1, x2]).T\n", 172 | "\n", 173 | "design_matrix = jnp.vstack((jnp.ones(X.shape[0]), X.T)).T \n", 174 | "\n", 175 | "y = design_matrix @ jnp.array([5,3,2]) + jax.random.normal(jax.random.PRNGKey(12), shape=(len(x1),))\n", 176 | "\n", 177 | "design_matrix = jnp.vstack((jnp.ones(X.shape[0]), X.T)).T \n", 178 | "_lambda = jnp.linalg.det(design_matrix.T @ design_matrix)\n", 179 | "\n", 180 | "jnp.linalg.inv(design_matrix.T @ design_matrix) @ design_matrix.T @ y" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 52, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "text/plain": [ 191 | "(Array(-180.12497, dtype=float32),\n", 192 | " Array(-13.431677, dtype=float32),\n", 193 | " Array([102.53715 , -42.046654, 28.41555 ], dtype=float32))" 194 | ] 195 | }, 196 | "execution_count": 52, 197 | "metadata": {}, 198 | "output_type": "execute_result" 199 | } 200 | ], 201 | "source": [ 202 | "model = MyLinearRegression().fit(X, y)\n", 203 | "\n", 204 | "model.predict(X)[1], y[1], model.coeff" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 53, 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "data": { 214 | "text/plain": [ 215 | "(np.float32(-13.807932),\n", 216 | " Array(-178.63159, dtype=float32),\n", 217 | " Array(-13.971927, dtype=float32),\n", 218 | " array([1.4087583, 2.8175173], dtype=float32))" 219 | ] 220 | }, 221 | "execution_count": 53, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "from sklearn.linear_model import LinearRegression\n", 228 | "\n", 229 | "model_sk_fitted = LinearRegression().fit(X, y)\n", 230 | "\n", 231 | "model_sk_fitted.predict(X)[2], model.predict(X)[2], y[2], model_sk_fitted.coef_" 232 | ] 233 | } 234 | ], 235 | "metadata": { 236 | "kernelspec": { 237 | "display_name": ".venv", 238 | "language": "python", 239 | "name": "python3" 240 | }, 241 | "language_info": { 242 | "codemirror_mode": { 243 | "name": "ipython", 244 | "version": 3 245 | }, 246 | "file_extension": ".py", 247 | "mimetype": "text/x-python", 248 | "name": "python", 249 | "nbconvert_exporter": "python", 250 | "pygments_lexer": "ipython3", 251 | "version": "3.11.9" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 2 256 | } 257 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: scikit-jax 2 | theme: 3 | name: material 4 | font: 5 | text: kanit 6 | logo: assets/scikit_jax_logo.png 7 | palette: 8 | # Palette toggle for automatic mode 9 | - media: 
"(prefers-color-scheme)" 10 | toggle: 11 | icon: material/brightness-auto 12 | name: Switch to light mode 13 | 14 | # Palette toggle for light mode 15 | - media: "(prefers-color-scheme: light)" 16 | scheme: default 17 | primary: black 18 | toggle: 19 | icon: material/brightness-7 20 | name: Switch to dark mode 21 | 22 | # Palette toggle for dark mode 23 | - media: "(prefers-color-scheme: dark)" 24 | scheme: slate 25 | primary: black 26 | toggle: 27 | icon: material/brightness-4 28 | name: Switch to system preference 29 | 30 | extra: 31 | social: 32 | - type: github 33 | icon: assets/github-brands-solid.svg 34 | link: https://github.com/LiibanMo 35 | - type: linkedin 36 | link: https://www.linkedin.com/in/lmohamud12/ 37 | 38 | nav: 39 | - Home: "index.md" 40 | - API: 41 | - skjax.clustering: 42 | - KMeans: "api/k_means.md" 43 | - skjax.decomposition: 44 | - PCA: "api/pca.md" 45 | - skjax.linear_model: 46 | - LinearRegression: "api/linear_regression.md" 47 | - skjax.naive_bayes: 48 | - MultinomialNaiveBayes: "api/multinomial_naive_bayes.md" 49 | - GaussianNaiveBayes: "api/gaussian_naive_bayes.md" 50 | 51 | -------------------------------------------------------------------------------- /readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.9" 7 | 8 | mkdocs: 9 | configuration: mkdocs.yml 10 | 11 | python: 12 | install: 13 | - requirements: requirements.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax == 0.4.31 2 | pandas == 2.2.2 3 | numpy == 2.0.1 4 | matplotlib == 3.9.0 -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | wheel == 0.44.0 2 | pytest == 8.3.2 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="scikit-jax", 5 | version="0.0.3dev1", 6 | author="Liiban Mohamud", 7 | author_email="liibanmohamud12@gmail.com", 8 | description="Classical machine learning algorithms on the GPU/TPU.", 9 | long_description=open("README.md").read(), 10 | long_description_content_type="text/markdown", 11 | url="https://github.com/LiibanMo/scikit-jax", 12 | packages=find_packages(), 13 | install_requires=[ 14 | "jax", 15 | "pandas", 16 | "numpy", 17 | "matplotlib", 18 | "seaborn", 19 | ], 20 | extras_require={ 21 | "dev": ["pytest>=6.0"], 22 | "test": ["pytest>=6.0"], 23 | }, 24 | classifiers=[ 25 | "Development Status :: 3 - Alpha", 26 | "Intended Audience :: Developers", 27 | "Intended Audience :: Science/Research", 28 | "Intended Audience :: Education", 29 | "Topic :: Software Development :: Build Tools", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | "License :: OSI Approved :: MIT License", 32 | "Programming Language :: Python :: 3", 33 | ], 34 | keywords="jax classical machine learning", 35 | python_requires=">=3.9", 36 | ) 37 | -------------------------------------------------------------------------------- /skjax/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/.DS_Store -------------------------------------------------------------------------------- /skjax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/__init__.py -------------------------------------------------------------------------------- /skjax/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/__pycache__/clustering.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/__pycache__/clustering.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/__pycache__/decomposition.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/__pycache__/decomposition.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/__pycache__/linear_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/__pycache__/linear_model.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/__pycache__/naive_bayes.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/__pycache__/naive_bayes.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/_utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/_utils/.DS_Store -------------------------------------------------------------------------------- /skjax/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/_utils/__init__.py -------------------------------------------------------------------------------- /skjax/_utils/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/_utils/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/_utils/__pycache__/_helper_functions.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/_utils/__pycache__/_helper_functions.cpython-311.pyc 
-------------------------------------------------------------------------------- /skjax/_utils/__pycache__/config.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/_utils/__pycache__/config.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/_utils/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/_utils/helpers/__init__.py -------------------------------------------------------------------------------- /skjax/_utils/helpers/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/_utils/helpers/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/_utils/helpers/__pycache__/_clustering.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/_utils/helpers/__pycache__/_clustering.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/_utils/helpers/__pycache__/_helper_functions.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/skjax/_utils/helpers/__pycache__/_helper_functions.cpython-311.pyc -------------------------------------------------------------------------------- /skjax/_utils/helpers/_clustering.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.numpy.linalg import norm 4 | 5 | # ------------------------------------------------------------------------------------------ # 6 | ### KMeans 7 | # ------------------------------------------------------------------------------------------ # 8 | 9 | 10 | def initialize_centroids(X: jax.Array, num_clusters: int, random_state: int = 12): 11 | # k-means++ initialisation 12 | 13 | assert ( 14 | num_clusters > 0 15 | ), f"num_clusters should be a natural number greater than 0. 
Instead got {num_clusters}" 16 | 17 | X_without_centroids = X.copy() 18 | 19 | key = jax.random.PRNGKey(random_state) 20 | init_index = jax.random.choice(key, X.shape[0]) 21 | 22 | initialised_centroids = {cluster: None for cluster in range(num_clusters)} 23 | 24 | init_centroid = X[init_index] 25 | initialised_centroids[0] = init_centroid 26 | 27 | if num_clusters == 1: 28 | return init_centroid 29 | else: 30 | for cluster in range(1, num_clusters): 31 | X_without_centroids = jnp.vstack( 32 | [ 33 | X_without_centroids[:init_index], 34 | X_without_centroids[init_index + 1 :], 35 | ] 36 | ) 37 | squared_distances_from_init_centroid = ( 38 | norm(X_without_centroids - init_centroid, axis=1) ** 2 39 | ) 40 | prob_dist_of_centroid_chosen = ( 41 | squared_distances_from_init_centroid 42 | / squared_distances_from_init_centroid.sum() 43 | ) 44 | key, subkey = jax.random.split(key) 45 | init_index = jax.random.choice( 46 | subkey, X_without_centroids.shape[0], p=prob_dist_of_centroid_chosen 47 | ) 48 | init_centroid = X_without_centroids[init_index] 49 | initialised_centroids[cluster] = init_centroid 50 | 51 | return initialised_centroids 52 | 53 | 54 | # ------------------------------------------------------------------------------------------ # 55 | 56 | 57 | def assign_clusters_to_data(X: jax.Array, centroids: dict): 58 | 59 | distances_matrix = jnp.zeros(shape=(len(centroids), X.shape[0])) 60 | 61 | for cluster, centroid in centroids.items(): 62 | distances_from_cluster = norm(X - centroid, axis=1) 63 | distances_matrix = distances_matrix.at[cluster].set(distances_from_cluster) 64 | 65 | assigned_clusters = distances_matrix.argmin(axis=0) 66 | 67 | return assigned_clusters 68 | 69 | 70 | # ------------------------------------------------------------------------------------------ # 71 | 72 | 73 | def calculate_new_centroids( 74 | X: jax.Array, assigned_centroids: jax.Array, num_clusters: int 75 | ): 76 | 77 | centroids = {cluster: None for cluster in range(num_clusters)} 78 | new_distances = jnp.zeros(shape=(num_clusters, X.shape[0])) 79 | 80 | for cluster in range(num_clusters): 81 | indices_for_cluster = jnp.where(assigned_centroids == cluster)[0] 82 | X_at_cluster = X[indices_for_cluster] 83 | new_centroid = jnp.mean(X_at_cluster, axis=0) 84 | centroids[cluster] = new_centroid 85 | new_distances_for_cluster = norm(X - new_centroid, axis=1) 86 | new_distances = new_distances.at[cluster].set(new_distances_for_cluster) 87 | 88 | updated_centroids_for_X = new_distances.argmin(axis=0) 89 | return centroids, updated_centroids_for_X 90 | -------------------------------------------------------------------------------- /skjax/_utils/helpers/_data.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | def split_data(data, val_size=0.1, test_size=0.2): 6 | """ 7 | Splits the data into training, validation, and test sets. 8 | 9 | Args: 10 | data (jax.Array): The dataset to split. 11 | val_size (float, optional): Proportion of data to use for validation. Defaults to 0.1. 12 | test_size (float, optional): Proportion of data to use for testing. Defaults to 0.2. 13 | 14 | Returns: 15 | tuple: A tuple containing the training data, validation data, and test data as jax.Array. 
16 | """ 17 | split_index_test = int(len(data) * (1 - test_size)) 18 | 19 | data_non_test = data[:split_index_test] 20 | data_test = data[split_index_test:] 21 | 22 | split_index_val = int(len(data_non_test) * (1 - val_size)) 23 | 24 | data_train = data_non_test[:split_index_val] 25 | data_val = data_non_test[split_index_val:] 26 | 27 | return jnp.asarray(data_train), jnp.asarray(data_val), jnp.asarray(data_test) 28 | -------------------------------------------------------------------------------- /skjax/_utils/helpers/_helper_functions.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from jax import jit 5 | 6 | # ------------------------------------------------------------------------------------------ # 7 | 8 | 9 | def split_data(data, val_size=0.1, test_size=0.2): 10 | """ 11 | Splits the data into training, validation, and test sets. 12 | 13 | Args: 14 | data (jax.Array): The dataset to split. 15 | val_size (float, optional): Proportion of data to use for validation. Defaults to 0.1. 16 | test_size (float, optional): Proportion of data to use for testing. Defaults to 0.2. 17 | 18 | Returns: 19 | tuple: A tuple containing the training data, validation data, and test data as jax.Array. 20 | """ 21 | split_index_test = int(len(data) * (1 - test_size)) 22 | 23 | data_non_test = data[:split_index_test] 24 | data_test = data[split_index_test:] 25 | 26 | split_index_val = int(len(data_non_test) * (1 - val_size)) 27 | 28 | data_train = data_non_test[:split_index_val] 29 | data_val = data_non_test[split_index_val:] 30 | 31 | return jnp.asarray(data_train), jnp.asarray(data_val), jnp.asarray(data_test) 32 | 33 | 34 | # ------------------------------------------------------------------------------------------ # 35 | ### LINEAR REGRESSION 36 | # ------------------------------------------------------------------------------------------ # 37 | 38 | 39 | @jit 40 | def compute_mse(y_true, y_pred): 41 | return jnp.mean((y_pred - y_true) ** 2) 42 | 43 | 44 | # ------------------------------------------------------------------------------------------ # 45 | 46 | 47 | @jit 48 | def calculate_loss_gradients( 49 | beta: jax.Array, X: jax.Array, y: jax.Array, p: int = 2, lambda_: float = 0.01 50 | ): 51 | """ 52 | Forward pass on data X. 53 | 54 | Args: 55 | beta (jax.Array): weights and bias. 56 | X (jax.Array): data 57 | """ 58 | return (2 / len(y)) * X.T @ (jnp.dot(X, beta) - y) + lambda_ * p * jnp.sign( 59 | beta 60 | ) * jnp.insert(beta[1:] ** (p - 1), 0, 0) 61 | 62 | 63 | # ------------------------------------------------------------------------------------------ # 64 | ### K-MEANS 65 | # ------------------------------------------------------------------------------------------ # 66 | 67 | 68 | def compute_euclidean_distance(x1, x2): 69 | """ 70 | Computes the Euclidean distance between two vectors. 71 | 72 | Args: 73 | x1 (jax.Array): First vector. 74 | x2 (jax.Array): Second vector. 75 | 76 | Returns: 77 | jax.Array: Euclidean distance between the two vectors. 78 | """ 79 | return jnp.sqrt(jnp.sum((x1 - x2) ** 2)) 80 | 81 | 82 | # ------------------------------------------------------------------------------------------ # 83 | 84 | 85 | def kmeans_plus_initialization(key: jax.Array, num_clusters: int, X: jax.Array): 86 | """ 87 | Initializes centroids using the k-means++ algorithm. 88 | 89 | Args: 90 | key (jax.Array): Random key for initialization. 
91 | num_clusters (int): Number of clusters (centroids) to initialize. 92 | X (jax.Array): Data points to initialize centroids from. 93 | 94 | Returns: 95 | jax.Array: Initialized centroids as jax.Array. 96 | """ 97 | indices_of_available_centroids = np.arange( 98 | len(X) 99 | ) # Used to keep track which centroid has been chosen. 100 | 101 | key, subkey = jax.random.split(key) 102 | 103 | init_index = jax.random.choice(key, indices_of_available_centroids) 104 | init_centroids = [X[init_index]] 105 | indices_of_available_centroids = np.delete( 106 | indices_of_available_centroids, init_index 107 | ) 108 | 109 | index = 0 110 | while len(init_centroids) < num_clusters: 111 | distance_between_points_and_init_centroid = jnp.asarray( 112 | [compute_euclidean_distance(init_centroids[index], x) for x in X] 113 | ) 114 | # Computing probabilities of choosing centroids proportional to distance 115 | probabilities = distance_between_points_and_init_centroid[ 116 | indices_of_available_centroids 117 | ] / jnp.sum( 118 | distance_between_points_and_init_centroid[indices_of_available_centroids] 119 | ) 120 | index_of_centroid_chosen = jax.random.choice( 121 | subkey, indices_of_available_centroids, p=probabilities, replace=False 122 | ) 123 | init_centroids.append(X[index_of_centroid_chosen]) 124 | indices_of_available_centroids = jnp.delete( 125 | indices_of_available_centroids, index_of_centroid_chosen 126 | ) 127 | index += 1 128 | 129 | return jnp.asarray(init_centroids) 130 | 131 | 132 | # ------------------------------------------------------------------------------------------ # 133 | 134 | 135 | def initialize_k_centroids(num_clusters, X, init: str = "random", seed: int = 12): 136 | """ 137 | Initializes centroids for k-means clustering. 138 | 139 | Args: 140 | num_clusters (int): Number of clusters (centroids) to initialize. 141 | X (jax.Array): Data points to initialize centroids from. 142 | init (str, optional): Initialization method ('random' or 'k-means++'). Defaults to 'random'. 143 | seed (int, optional): Random seed for initialization. Defaults to 12. 144 | 145 | Returns: 146 | jax.Array: Initialized centroids as jax.Array. 147 | """ 148 | key = jax.random.key(seed) 149 | initialization = { 150 | "random": jax.random.choice(key, X, shape=(num_clusters,), replace=False), 151 | "k-means++": kmeans_plus_initialization(key, num_clusters, X), 152 | } 153 | return initialization[init] 154 | 155 | 156 | # ------------------------------------------------------------------------------------------ # 157 | 158 | 159 | def calculating_distances_between_centroids_and_points(centroids, X): 160 | """ 161 | Calculates the distance between each centroid and each data point. 162 | 163 | Args: 164 | centroids (jax.Array): Centroids for the k-means algorithm. 165 | X (jax.Array): Data points. 166 | 167 | Returns: 168 | jax.Array: Distance matrix where each entry (i, j) represents the distance between the i-th centroid and the j-th data point. 169 | """ 170 | return jnp.asarray( 171 | [[compute_euclidean_distance(centroid, x) for x in X] for centroid in centroids] 172 | ) 173 | 174 | 175 | # ------------------------------------------------------------------------------------------ # 176 | 177 | 178 | def calculate_new_centroids(centroids, X): 179 | """ 180 | Computes the new centroids based on the current centroids and data points. 181 | 182 | Args: 183 | centroids (jax.Array): Current centroids. 184 | X (jax.Array): Data points. 
185 | 186 | Returns: 187 | jax.Array: New centroids computed as the mean of the data points assigned to each centroid. 188 | """ 189 | distances_between_centroids_and_points = ( 190 | calculating_distances_between_centroids_and_points(centroids, X) 191 | ) 192 | labels_of_each_point = jnp.argmin(distances_between_centroids_and_points.T, axis=1) 193 | indices_of_each_cluster = [ 194 | jnp.where(labels_of_each_point == label) 195 | for label in jnp.unique(labels_of_each_point) 196 | ] 197 | new_centroids = jnp.asarray( 198 | [ 199 | jnp.mean(X[collection_of_indices].T, axis=1) 200 | for collection_of_indices in indices_of_each_cluster 201 | ] 202 | ) 203 | return new_centroids 204 | 205 | 206 | # ------------------------------------------------------------------------------------------ # 207 | 208 | 209 | def calculate_stds_in_each_cluster(centroids, X): 210 | """ 211 | Calculates the standard deviation of data points in each cluster. 212 | 213 | Args: 214 | centroids (jax.Array): Centroids for the k-means algorithm. 215 | X (jax.Array): Data points. 216 | 217 | Returns: 218 | jax.Array: Sum of the standard deviations of data points in each cluster. 219 | """ 220 | distances_between_centroids_and_points = ( 221 | calculating_distances_between_centroids_and_points(centroids, X) 222 | ) 223 | labels_of_each_point = jnp.argmin(distances_between_centroids_and_points.T, axis=1) 224 | indices_of_each_cluster = [ 225 | jnp.where(labels_of_each_point == label) 226 | for label in jnp.unique(labels_of_each_point) 227 | ] 228 | return jnp.sum( 229 | jnp.asarray( 230 | [ 231 | jnp.std(X[collection_of_indices]) 232 | for collection_of_indices in indices_of_each_cluster 233 | ] 234 | ) 235 | ) 236 | 237 | 238 | # ------------------------------------------------------------------------------------------ # 239 | ### NAIVE BAYES 240 | # ------------------------------------------------------------------------------------------ # 241 | 242 | 243 | def compute_priors(y: jax.Array) -> jax.Array: 244 | """ 245 | Computes the prior probabilities of each class. 246 | 247 | Args: 248 | y (jax.Array): Array of class labels. 249 | 250 | Returns: 251 | jax.Array: Array of prior probabilities for each class. 252 | """ 253 | unique_classes = jnp.unique(y) 254 | prior_probabilities = [] 255 | 256 | for class_ in unique_classes.tolist(): 257 | prior_probabilities.append(jnp.mean(jnp.where(y == class_, 1, 0))) 258 | 259 | return jnp.asarray(prior_probabilities) 260 | 261 | 262 | # ------------------------------------------------------------------------------------------ # 263 | 264 | 265 | def compute_likelihoods(X: jax.Array, y: jax.Array, alpha: int = 0) -> dict: 266 | """ 267 | Computes the likelihoods of each feature given each class. 268 | 269 | Args: 270 | X (jax.Array): Feature matrix. 271 | y (jax.Array): Array of class labels. 272 | alpha (int, optional): Laplace smoothing parameter. Defaults to 0. 273 | 274 | Returns: 275 | dict: Dictionary of likelihoods where each key is a class label and each value is a list of dictionaries, 276 | each representing the probability of each category given the class. 
277 | """ 278 | unique_classes = jnp.unique_values(y) 279 | unique_categories_in_every_feature = [jnp.unique(x).tolist() for x in X.T] 280 | collection_of_indices_of_each_class = [ 281 | jnp.where(y == class_) for class_ in unique_classes.tolist() 282 | ] 283 | likelihoods_of_each_class_per_category = { 284 | class_: [] for class_ in unique_classes.tolist() 285 | } 286 | 287 | for class_, collection_of_indices in zip( 288 | unique_classes.tolist(), collection_of_indices_of_each_class 289 | ): 290 | for j, categories in enumerate(unique_categories_in_every_feature): 291 | likelihoods_per_feature = [ 292 | ( 293 | jnp.sum(jnp.where(X[collection_of_indices][:, j] == category, 1, 0)) 294 | + alpha 295 | ) 296 | / (len(X[collection_of_indices][:, j]) + alpha * X.shape[1]) 297 | for category in categories 298 | ] 299 | likelihoods_of_each_class_per_category[class_].append( 300 | { 301 | category: likelihoods_per_feature[ith_category].item() 302 | for ith_category, category in enumerate(categories) 303 | } 304 | ) 305 | 306 | return likelihoods_of_each_class_per_category 307 | 308 | 309 | # ------------------------------------------------------------------------------------------ # 310 | 311 | 312 | def compute_posteriors(X: jax.Array, priors: jax.Array, likelihoods: dict) -> jax.Array: 313 | """ 314 | Computes the posterior probabilities for each class given the feature matrix. 315 | 316 | Args: 317 | X (jax.Array): Feature matrix. 318 | priors (jax.Array): Array of prior probabilities for each class. 319 | likelihoods (dict): Dictionary of likelihoods for each class and feature. 320 | 321 | Returns: 322 | jax.Array: Matrix of posterior probabilities where each row corresponds to a data point 323 | """ 324 | vector_of_posteriors_for_data_point_i = jnp.zeros(len(likelihoods)) 325 | matrix_of_posteriors = [] 326 | 327 | for x in X: 328 | for i in range(len(likelihoods)): 329 | posterior = jnp.log(priors[i]) + jnp.sum( 330 | jnp.log( 331 | jnp.asarray( 332 | [likelihoods[i][j][x_ij.item()] for j, x_ij in enumerate(x)] 333 | ) 334 | ) 335 | ) 336 | vector_of_posteriors_for_data_point_i = ( 337 | vector_of_posteriors_for_data_point_i.at[i].set(posterior) 338 | ) 339 | matrix_of_posteriors.append(vector_of_posteriors_for_data_point_i) 340 | 341 | return jnp.asarray(matrix_of_posteriors) 342 | 343 | 344 | # ------------------------------------------------------------------------------------------ # 345 | 346 | 347 | def gaussian_pdf(x, mean, std) -> jax.Array: 348 | """ 349 | Computes the probability density function of a Gaussian distribution. 350 | 351 | Args: 352 | x (jax.Array): Data points for which to compute the probability density. 353 | mean (jax.Array): Mean of the Gaussian distribution. 354 | std (jax.Array): Standard deviation of the Gaussian distribution. 355 | 356 | Returns: 357 | jax.Array: Probability density values for the given data points. 358 | """ 359 | return jnp.exp(-0.5 * ((x - mean) / std) ** 2) / (std * jnp.sqrt(2 * jnp.pi)) 360 | 361 | 362 | # ------------------------------------------------------------------------------------------ # 363 | 364 | 365 | def compute_means(X: jax.Array, y: jax.Array, random_state=12) -> dict: 366 | """ 367 | Computes the mean of each feature for each class. 368 | 369 | Args: 370 | X (jax.Array): Feature matrix where rows represent samples and columns represent features. 371 | y (jax.Array): Array of class labels corresponding to each sample in X. 372 | random_state (int, optional): Seed for the random number generator. Defaults to 12. 
373 | 374 | Returns: 375 | dict: A dictionary where keys are class labels and values are lists of means of features for each class. 376 | """ 377 | np.random.seed(random_state) 378 | 379 | unique_classes = jnp.unique(y).tolist() 380 | indices_for_each_class = [jnp.where(y == class_) for class_ in unique_classes] 381 | 382 | dictionary_of_means = dict( 383 | zip( 384 | unique_classes, 385 | [ 386 | [ 387 | jnp.mean(X[collection_of_indices][:, j]).item() 388 | for j in range(X.shape[1]) 389 | ] 390 | for collection_of_indices in indices_for_each_class 391 | ], 392 | ) 393 | ) 394 | 395 | return dictionary_of_means 396 | 397 | 398 | # ------------------------------------------------------------------------------------------ # 399 | 400 | 401 | def compute_stds(X: jax.Array, y: jax.Array, random_state=12) -> dict: 402 | """ 403 | Computes the standard deviation of each feature for each class. 404 | 405 | Args: 406 | X (jax.Array): Feature matrix where rows represent samples and columns represent features. 407 | y (jax.Array): Array of class labels corresponding to each sample in X. 408 | random_state (int, optional): Seed for the random number generator. Defaults to 12. 409 | 410 | Returns: 411 | dict: A dictionary where keys are class labels and values are lists of standard deviations of features for each class. 412 | """ 413 | np.random.seed(random_state) 414 | 415 | unique_classes = jnp.unique(y).tolist() 416 | indices_for_each_class = [jnp.where(y == class_) for class_ in unique_classes] 417 | 418 | dictionary_of_stds = dict( 419 | zip( 420 | unique_classes, 421 | [ 422 | [ 423 | jnp.std(X[collection_of_indices][:, j]).item() 424 | for j in range(X.shape[1]) 425 | ] 426 | for collection_of_indices in indices_for_each_class 427 | ], 428 | ) 429 | ) 430 | 431 | return dictionary_of_stds 432 | -------------------------------------------------------------------------------- /skjax/_utils/helpers/_linear_model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import jit 4 | 5 | # ------------------------------------------------------------------------------------------ # 6 | ### LinearRegression 7 | # ------------------------------------------------------------------------------------------ # 8 | 9 | 10 | @jit 11 | def compute_mse(y_true, y_pred): 12 | return jnp.mean((y_pred - y_true) ** 2) 13 | 14 | 15 | # ------------------------------------------------------------------------------------------ # 16 | 17 | 18 | @jit 19 | def calculate_loss_gradients( 20 | beta: jax.Array, X: jax.Array, y: jax.Array, p: int = 2, lambda_: float = 0.01 21 | ): 22 | """ 23 | Forward pass on data X. 24 | 25 | Args: 26 | beta (jax.Array): weights and bias. 
27 | X (jax.Array): data 28 | """ 29 | return (2 / len(y)) * X.T @ (jnp.dot(X, beta) - y) + lambda_ * p * jnp.sign( 30 | beta 31 | ) * jnp.insert(beta[1:] ** (p - 1), 0, 0) 32 | -------------------------------------------------------------------------------- /skjax/_utils/helpers/_naive_bayes.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | # ------------------------------------------------------------------------------------------ # 6 | ### MultinomialNB 7 | # ------------------------------------------------------------------------------------------ # 8 | 9 | 10 | def compute_priors(y: jax.Array) -> jax.Array: 11 | """ 12 | Computes the prior probabilities of each class. 13 | 14 | Args: 15 | y (jax.Array): Array of class labels. 16 | 17 | Returns: 18 | jax.Array: Array of prior probabilities for each class. 19 | """ 20 | unique_classes = jnp.unique(y) 21 | prior_probabilities = [] 22 | 23 | for class_ in unique_classes.tolist(): 24 | prior_probabilities.append(jnp.mean(jnp.where(y == class_, 1, 0))) 25 | 26 | return jnp.asarray(prior_probabilities) 27 | 28 | 29 | # ------------------------------------------------------------------------------------------ # 30 | 31 | 32 | def compute_likelihoods(X: jax.Array, y: jax.Array, alpha: int = 0) -> dict: 33 | """ 34 | Computes the likelihoods of each feature given each class. 35 | 36 | Args: 37 | X (jax.Array): Feature matrix. 38 | y (jax.Array): Array of class labels. 39 | alpha (int, optional): Laplace smoothing parameter. Defaults to 0. 40 | 41 | Returns: 42 | dict: Dictionary of likelihoods where each key is a class label and each value is a list of dictionaries, 43 | each representing the probability of each category given the class. 44 | """ 45 | unique_classes = jnp.unique_values(y) 46 | unique_categories_in_every_feature = [jnp.unique(x).tolist() for x in X.T] 47 | collection_of_indices_of_each_class = [ 48 | jnp.where(y == class_) for class_ in unique_classes.tolist() 49 | ] 50 | likelihoods_of_each_class_per_category = { 51 | class_: [] for class_ in unique_classes.tolist() 52 | } 53 | 54 | for class_, collection_of_indices in zip( 55 | unique_classes.tolist(), collection_of_indices_of_each_class 56 | ): 57 | for j, categories in enumerate(unique_categories_in_every_feature): 58 | likelihoods_per_feature = [ 59 | ( 60 | jnp.sum(jnp.where(X[collection_of_indices][:, j] == category, 1, 0)) 61 | + alpha 62 | ) 63 | / (len(X[collection_of_indices][:, j]) + alpha * X.shape[1]) 64 | for category in categories 65 | ] 66 | likelihoods_of_each_class_per_category[class_].append( 67 | { 68 | category: likelihoods_per_feature[ith_category].item() 69 | for ith_category, category in enumerate(categories) 70 | } 71 | ) 72 | 73 | return likelihoods_of_each_class_per_category 74 | 75 | 76 | # ------------------------------------------------------------------------------------------ # 77 | 78 | 79 | def compute_posteriors(X: jax.Array, priors: jax.Array, likelihoods: dict) -> jax.Array: 80 | """ 81 | Computes the posterior probabilities for each class given the feature matrix. 82 | 83 | Args: 84 | X (jax.Array): Feature matrix. 85 | priors (jax.Array): Array of prior probabilities for each class. 86 | likelihoods (dict): Dictionary of likelihoods for each class and feature. 
87 | 88 | Returns: 89 | jax.Array: Matrix of posterior probabilities where each row corresponds to a data point 90 | """ 91 | vector_of_posteriors_for_data_point_i = jnp.zeros(len(likelihoods)) 92 | matrix_of_posteriors = [] 93 | 94 | for x in X: 95 | for i in range(len(likelihoods)): 96 | posterior = jnp.log(priors[i]) + jnp.sum( 97 | jnp.log( 98 | jnp.asarray( 99 | [likelihoods[i][j][x_ij.item()] for j, x_ij in enumerate(x)] 100 | ) 101 | ) 102 | ) 103 | vector_of_posteriors_for_data_point_i = ( 104 | vector_of_posteriors_for_data_point_i.at[i].set(posterior) 105 | ) 106 | matrix_of_posteriors.append(vector_of_posteriors_for_data_point_i) 107 | 108 | return jnp.asarray(matrix_of_posteriors) 109 | 110 | 111 | # ------------------------------------------------------------------------------------------ # 112 | ### GaussianNB 113 | # ------------------------------------------------------------------------------------------ # 114 | 115 | 116 | def gaussian_pdf(x, mean, std) -> jax.Array: 117 | """ 118 | Computes the probability density function of a Gaussian distribution. 119 | 120 | Args: 121 | x (jax.Array): Data points for which to compute the probability density. 122 | mean (jax.Array): Mean of the Gaussian distribution. 123 | std (jax.Array): Standard deviation of the Gaussian distribution. 124 | 125 | Returns: 126 | jax.Array: Probability density values for the given data points. 127 | """ 128 | return jnp.exp(-0.5 * ((x - mean) / std) ** 2) / (std * jnp.sqrt(2 * jnp.pi)) 129 | 130 | 131 | # ------------------------------------------------------------------------------------------ # 132 | 133 | 134 | def compute_means(X: jax.Array, y: jax.Array, random_state=12) -> dict: 135 | """ 136 | Computes the mean of each feature for each class. 137 | 138 | Args: 139 | X (jax.Array): Feature matrix where rows represent samples and columns represent features. 140 | y (jax.Array): Array of class labels corresponding to each sample in X. 141 | random_state (int, optional): Seed for the random number generator. Defaults to 12. 142 | 143 | Returns: 144 | dict: A dictionary where keys are class labels and values are lists of means of features for each class. 145 | """ 146 | np.random.seed(random_state) 147 | 148 | unique_classes = jnp.unique(y).tolist() 149 | indices_for_each_class = [jnp.where(y == class_) for class_ in unique_classes] 150 | 151 | dictionary_of_means = dict( 152 | zip( 153 | unique_classes, 154 | [ 155 | [ 156 | jnp.mean(X[collection_of_indices][:, j]).item() 157 | for j in range(X.shape[1]) 158 | ] 159 | for collection_of_indices in indices_for_each_class 160 | ], 161 | ) 162 | ) 163 | 164 | return dictionary_of_means 165 | 166 | 167 | # ------------------------------------------------------------------------------------------ # 168 | 169 | 170 | def compute_stds(X: jax.Array, y: jax.Array, random_state=12) -> dict: 171 | """ 172 | Computes the standard deviation of each feature for each class. 173 | 174 | Args: 175 | X (jax.Array): Feature matrix where rows represent samples and columns represent features. 176 | y (jax.Array): Array of class labels corresponding to each sample in X. 177 | random_state (int, optional): Seed for the random number generator. Defaults to 12. 178 | 179 | Returns: 180 | dict: A dictionary where keys are class labels and values are lists of standard deviations of features for each class. 
181 | """ 182 | np.random.seed(random_state) 183 | 184 | unique_classes = jnp.unique(y).tolist() 185 | indices_for_each_class = [jnp.where(y == class_) for class_ in unique_classes] 186 | 187 | dictionary_of_stds = dict( 188 | zip( 189 | unique_classes, 190 | [ 191 | [ 192 | jnp.std(X[collection_of_indices][:, j]).item() 193 | for j in range(X.shape[1]) 194 | ] 195 | for collection_of_indices in indices_for_each_class 196 | ], 197 | ) 198 | ) 199 | 200 | return dictionary_of_stds 201 | -------------------------------------------------------------------------------- /skjax/clustering.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from skjax._utils.helpers._clustering import (assign_clusters_to_data, 7 | calculate_new_centroids, 8 | initialize_centroids) 9 | 10 | # ------------------------------------------------------------------------------------------ # 11 | 12 | 13 | class KMeans: 14 | def __init__( 15 | self, 16 | num_clusters: int, 17 | epochs: int = 5, 18 | random_state: int = 5, 19 | ): 20 | """ 21 | Initialize the KMeans clustering algorithm. 22 | 23 | Args: 24 | num_clusters (int): The number of clusters to form. 25 | epochs (int, optional): The number of iterations to run. Default is 25. 26 | init (str, optional): Method for initializing centroids ('random' or other methods). Default is 'random'. 27 | max_patience (int, optional): The maximum number of epochs to wait for improvement before stopping early. Default is 5. 28 | seed (int, optional): Random seed for reproducibility. Default is 12. 29 | """ 30 | self.num_clusters: int = num_clusters 31 | self.epochs: int = epochs 32 | self.random_state: int = random_state 33 | 34 | def fit(self, X: jax.Array) -> None: 35 | """ 36 | Compute the KMeans clustering. 37 | 38 | Args: 39 | X (jax.Array): Input data, where each row is a data point. 40 | 41 | Returns: 42 | self: The instance of the KMeans object with fitted centroids. 43 | """ 44 | self.init_centroids = initialize_centroids(X, num_clusters=self.num_clusters) 45 | 46 | centroids_for_each_data_point = assign_clusters_to_data(X, self.init_centroids) 47 | print(centroids_for_each_data_point) 48 | 49 | for epoch in range(self.epochs): 50 | centroids, centroids_for_each_data_point = calculate_new_centroids( 51 | X, centroids_for_each_data_point, self.num_clusters 52 | ) 53 | 54 | self.centroids = jnp.asarray(list(centroids.values())) 55 | 56 | return self 57 | -------------------------------------------------------------------------------- /skjax/decomposition.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.scipy.linalg import svd 4 | 5 | # ------------------------------------------------------------------------------------------ # 6 | 7 | 8 | class PCA: 9 | """ 10 | Principal Component Analysis (PCA) for dimensionality reduction. 11 | 12 | Attributes: 13 | num_components (int): Number of principal components to keep. 14 | mean (jax.Array, optional): Mean of each feature in the training data. 15 | principal_components (jax.Array, optional): Principal components (eigenvectors) of the training data. 16 | explained_variance (jax.Array, optional): Explained variance of each principal component. 17 | """ 18 | 19 | def __init__(self, num_components: int): 20 | """ 21 | Initialize PCA with the number of components to keep. 
22 | 23 | Args: 24 | num_components (int): Number of principal components to retain. 25 | """ 26 | self.num_components = num_components 27 | self.mean = None 28 | self.principal_components = None 29 | self.explained_variance = None 30 | 31 | def fit(self, X: jax.Array): 32 | n, m = X.shape 33 | 34 | if self.mean is None: 35 | self.mean = X.mean(axis=0) 36 | 37 | X_centred = X - self.mean 38 | S, self.principal_components = svd(X_centred, full_matrices=True)[1:] 39 | 40 | self.explained_variance = S**2 / jnp.sum(S**2) 41 | 42 | def transform(self, X: jax.Array): 43 | if self.principal_components is None: 44 | raise RuntimeError("Must fit before transforming.") 45 | 46 | X_centred = X - X.mean(axis=0) 47 | return jnp.dot(X_centred, self.principal_components[: self.num_components].T) 48 | 49 | def fit_transform(self, X: jax.Array): 50 | if self.mean is None: 51 | self.mean = X.mean(axis=0) 52 | 53 | X_centred = X - self.mean 54 | 55 | self.principal_components = svd(X_centred, full_matrices=True)[2] 56 | 57 | return jnp.dot(X_centred, self.principal_components[: self.num_components].T) 58 | 59 | def inverse_transform(self, X_transformed: jax.Array): 60 | if self.principal_components is None: 61 | raise RuntimeError("Must fit before transforming.") 62 | 63 | return ( 64 | jnp.dot(X_transformed, self.principal_components[: self.num_components]) 65 | + self.mean 66 | ) 67 | -------------------------------------------------------------------------------- /skjax/linear_model.py: -------------------------------------------------------------------------------- 1 | """ Model(s) """ 2 | 3 | from typing import Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | from skjax._utils.helpers._helper_functions import (calculate_loss_gradients, 11 | compute_mse) 12 | 13 | # ================================================================================================================ # 14 | 15 | 16 | class LinearRegression: 17 | def fit(self, X: jax.Array, y: jax.Array): 18 | self.design_matrix = jnp.vstack((jnp.ones(X.shape[0]), X)).T 19 | self.hat_matrix = X @ jnp.linalg.inv(X.T @ X) @ X.T 20 | 21 | max_rank_of_design_matrix = jnp.min(X.shape) 22 | rank_of_hat_matrix = jnp.trace( 23 | self.hat_matrix 24 | ) # Rank of hat_matrix = Trace of hat_matrix 25 | 26 | assert jnp.isclose( 27 | rank_of_hat_matrix == max_rank_of_design_matrix 28 | ), f"Rank of the design matrix is not full. Should have {max_rank_of_design_matrix}. Instead got {rank_of_hat_matrix}." 29 | 30 | lhs = jnp.dot(self.design_matrix.T, self.design_matrix) # Gram Matrix 31 | rhs = jnp.dot(self.design_matrix.T, y) 32 | self.coeff = jnp.linalg.solve(lhs, rhs) # Solving Normal Equation 33 | return self 34 | 35 | def predict(self, X: jax.Array): 36 | return jnp.dot(self.design_matrix, self.coeff) 37 | 38 | 39 | # ================================================================================================================ # 40 | class LinearRegressionSGD: 41 | """ 42 | Linear Regression model with options for various weight initialization methods and dropout regularization. 43 | 44 | Attributes: 45 | weights (str): Initialization method for weights ('zero', 'random', 'lecun', 'xavier', 'he'). 46 | epochs (int): Number of epochs for training. 47 | learning_rate (float): Learning rate for the optimizer. 48 | p (int): Number of features in the input data. 49 | lambda_ (float): Regularization parameter for L2 regularization. 
50 | max_patience (int): Number of epochs to wait for improvement before early stopping. 51 | dropout (float): Dropout rate to prevent overfitting. 52 | random_state (int): Seed for random number generation. 53 | losses_in_training_data (np.ndarray): Array to store training losses for each epoch. 54 | losses_in_validation_data (np.ndarray): Array to store validation losses for each epoch. 55 | stopped_at (int): Epoch at which training stopped, either due to completion or early stopping. 56 | """ 57 | 58 | def __init__( 59 | self, 60 | weights_init: str = "zero", 61 | epochs: int = 2000, 62 | learning_rate: float = 5e-3, 63 | p: int = 2, 64 | lambda_: float = 0.0, 65 | max_patience: int = 200, 66 | dropout: float = 0.0, 67 | random_state: int = 41, 68 | ): 69 | """ 70 | Initialize the LinearRegressionModel. 71 | 72 | Args: 73 | weights_init (str): Method to initialize weights ('zero', 'random', 'lecun', 'xavier', 'he'). Default is 'zero'. 74 | epochs (int): Number of epochs for training. Default is 2000. 75 | learning_rate (float): Learning rate for optimization. Default is 0.0005. 76 | p (int): Number of features in the dataset. Default is 2. 77 | lambda_ (float): Regularization parameter for L2 regularization. Default is 0. 78 | max_patience (int): Maximum number of epochs to wait for improvement before early stopping. Default is 200. 79 | dropout (float): Dropout rate to prevent overfitting. Default is 0. 80 | random_state (int): Seed for random number generation. Default is 41. 81 | """ 82 | self.weights = weights_init 83 | self.epochs = epochs 84 | self.learning_rate = learning_rate 85 | self.p = p 86 | self.lambda_ = lambda_ 87 | self.max_patience = max_patience 88 | self.dropout = dropout 89 | self.random_state = random_state 90 | self.losses_in_training_data = np.zeros(epochs) 91 | self.losses_in_validation_data = np.zeros(epochs) 92 | self.stopped_at = epochs # Records epoch stopped at. First assumes training using every epoch. 93 | 94 | def fit( 95 | self, 96 | X_train: jax.Array, 97 | y_train: jax.Array, 98 | X_val: Optional[jax.Array] = None, 99 | y_val: Optional[jax.Array] = None, 100 | ): 101 | """ 102 | Fit the model to the training data. 103 | 104 | Args: 105 | X_train (jax.Array): Training data features. 106 | y_train (jax.Array): Training data labels. 107 | X_val (jax.Array, optional): Validation data features. Default is None. 108 | y_val (jax.Array, optional): Validation data labels. Default is None. 109 | 110 | Returns: 111 | self: The instance of the LinearRegression object with fitted weights. 
112 | """ 113 | # Initializing parameters 114 | best_beta: Optional[jax.Array] = None 115 | 116 | n = len(X_train) 117 | 118 | best_mse = jnp.inf 119 | 120 | patience_counter = 0 121 | 122 | key = jax.random.key(self.random_state) 123 | 124 | # Defining the weights initializers 125 | weights_init_dict = { 126 | "zero": jnp.zeros(X_train.shape[1]), 127 | "random": jax.random.normal(key, shape=(X_train.shape[1],)), 128 | "lecun": jax.random.normal(key, shape=(X_train.shape[1],)) 129 | * jnp.sqrt(1 / X_train.shape[0]), 130 | "xavier": jax.random.normal(key, shape=(X_train.shape[1],)) 131 | * jnp.sqrt(2 / (X_train.shape[0] + y_train.shape[0])), 132 | "he": jax.random.normal(key, shape=(X_train.shape[1],)) 133 | * jnp.sqrt(2 / X_train.shape[0]), 134 | } 135 | 136 | self.weights = weights_init_dict[self.weights] 137 | 138 | # Training Loop 139 | for epoch in range(self.epochs): 140 | # Dropout 141 | key, subkey = jax.random.split(key) 142 | dropout_mask = jax.random.bernoulli( 143 | subkey, p=(1 - self.dropout), shape=(n, 1) 144 | ) 145 | X_train_dropout = X_train * dropout_mask 146 | 147 | # Calculating loss on training data 148 | mse_train = compute_mse( 149 | y_pred=jnp.dot(X_train_dropout, self.weights), y_true=y_train 150 | ) 151 | self.losses_in_training_data[epoch] = mse_train 152 | 153 | # Calculate loss gradients 154 | loss_gradient_wrt_beta = calculate_loss_gradients( 155 | self.weights, X_train_dropout, y_train, self.p, self.lambda_ 156 | ) 157 | 158 | # Optimiser step 159 | self.weights -= self.learning_rate * loss_gradient_wrt_beta 160 | 161 | if X_val is not None and y_val is not None: 162 | # Validation step 163 | mse_val = compute_mse(y_val, jnp.dot(X_val, self.weights)) 164 | self.losses_in_validation_data[epoch] = mse_val 165 | 166 | # Potential early stopping 167 | if mse_val < best_mse: 168 | best_mse = mse_val 169 | patience_counter = 0 170 | best_beta = self.weights 171 | else: 172 | patience_counter += 1 173 | 174 | if patience_counter >= self.max_patience: 175 | print(f"Stopped at epoch {epoch+1}.") 176 | self.stopped_at = epoch + 1 177 | break 178 | 179 | if X_val is not None and y_val is not None: 180 | self.weights = best_beta 181 | 182 | return self 183 | 184 | def predict(self, X_test: jax.Array): 185 | """ 186 | Predict the labels for the given test data. 187 | 188 | Args: 189 | X_test (jax.Array): Test data features. 190 | 191 | Returns: 192 | jax.Array: Predicted labels for the test data. 193 | """ 194 | return jnp.dot(X_test, self.weights) 195 | 196 | def plot_losses(self): 197 | """ 198 | Plot training and validation losses over epochs. 199 | 200 | Displays a plot of Mean Squared Error (MSE) for both training and validation data across epochs. 
201 | """ 202 | plt.figure(figsize=(10, 5)) 203 | # Plotting training losses 204 | plt.title("MSE vs Epochs") 205 | plt.plot( 206 | range(self.stopped_at), 207 | self.losses_in_training_data[: self.stopped_at], 208 | c="blue", 209 | label="Training", 210 | ) 211 | # Plotting validation losses 212 | plt.plot( 213 | range(self.stopped_at), 214 | self.losses_in_validation_data[: self.stopped_at], 215 | c="orange", 216 | label="Valdation", 217 | ) 218 | plt.legend() 219 | plt.show() 220 | -------------------------------------------------------------------------------- /skjax/naive_bayes.py: -------------------------------------------------------------------------------- 1 | """ Implementing Multinomial Naive Bayes """ 2 | 3 | from typing import Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | from jax import tree 8 | 9 | from skjax._utils.helpers._helper_functions import (compute_likelihoods, 10 | compute_means, 11 | compute_posteriors, 12 | compute_priors, 13 | compute_stds, gaussian_pdf) 14 | 15 | # ------------------------------------------------------------------------------------------ # 16 | 17 | 18 | class MultinomialNaiveBayes: 19 | """ 20 | A class implementing the Multinomial Naive Bayes classifier. 21 | 22 | This classifier is suitable for discrete data, such as text classification tasks where features represent word counts or frequencies. It calculates prior probabilities and feature likelihoods from the training data and uses these to predict the class of new data points. 23 | 24 | Attributes: 25 | alpha (Optional[jax.Array]): The smoothing parameter for the likelihood estimates. If None, no smoothing is applied. Smoothing is used to handle cases where some feature values might not appear in the training data for a given class. 26 | priors (jax.Array): The prior probabilities of each class, computed from the training labels. 27 | likelihoods (dict): A dictionary where each key is a class label and each value is a JAX array of feature likelihoods for that class. 28 | 29 | Methods: 30 | fit(X: jax.Array, y: jax.Array): 31 | Trains the Multinomial Naive Bayes model on the provided feature matrix and labels. 32 | 33 | predict(X: jax.Array): 34 | Predicts class labels for the provided feature matrix using the trained model. 35 | 36 | """ 37 | 38 | def __init__(self, alpha: int = 0): 39 | """ 40 | Initializes the MultinomialNaiveBayes classifier. 41 | 42 | Args: 43 | alpha (Optional[jax.Array]): The smoothing parameter for likelihood estimation. If None, no smoothing is applied. Defaults to None. 44 | """ 45 | self.alpha = alpha 46 | self.priors = None 47 | self.likelihoods = None 48 | 49 | def fit(self, X: jax.Array, y: jax.Array): 50 | """ 51 | Fits the model to the training data. 52 | 53 | Args: 54 | X (jax.Array): The feature matrix for the training data, where each feature represents discrete counts or frequencies. 55 | y (jax.Array): The vector of class labels corresponding to the training data. 56 | 57 | Returns: 58 | self: The fitted MultinomialNaiveBayes instance. 59 | """ 60 | ### CALCULATING PRIORS 61 | self.priors = compute_priors(y) 62 | 63 | ### CALCULATING LIKELIHOODS 64 | self.likelihoods = compute_likelihoods(X, y, self.alpha) 65 | 66 | return self 67 | 68 | def predict(self, X: jax.Array): 69 | """ 70 | Predicts class labels for the given data using the trained model. 71 | 72 | Args: 73 | X (jax.Array): The feature matrix for which predictions are to be made. Each feature represents discrete counts or frequencies. 
74 | 
75 |         Returns:
76 |             jax.Array: An array of predicted class labels for each sample in X.
77 |         """
78 |         return jnp.argmax(compute_posteriors(X, self.priors, self.likelihoods), axis=1)
79 | 
80 | 
81 | # ------------------------------------------------------------------------------------------ #
82 | 
83 | 
84 | class GaussianNaiveBayes:
85 |     """
86 |     A class implementing the Gaussian Naive Bayes classifier.
87 | 
88 |     This classifier assumes that the features follow a Gaussian distribution and is used for classification tasks. It calculates the priors, means, and standard deviations of each class from the training data and uses these statistics to predict the class of new data points.
89 | 
90 |     Attributes:
91 |         priors (jax.Array): The prior probabilities of each class, computed from the training labels.
92 |         means (dict): A dictionary where each key is a class label and each value is a JAX array of feature means for that class.
93 |         stds (dict): A dictionary where each key is a class label and each value is a JAX array of feature standard deviations for that class.
94 |         seed (int): The random seed for reproducibility.
95 | 
96 |     Methods:
97 |         fit(X: jax.Array, y: jax.Array):
98 |             Trains the Gaussian Naive Bayes model on the provided feature matrix and labels.
99 | 
100 |         predict(X: jax.Array):
101 |             Predicts class labels for the provided feature matrix using the trained model.
102 | 
103 |     """
104 | 
105 |     def __init__(self, seed: int = 12):
106 |         """
107 |         Initializes the GaussianNaiveBayes classifier.
108 | 
109 |         Args:
110 |             seed (int): The random seed for initializing random number generators used in the computations. Defaults to 12.
111 |         """
112 |         self.priors: Optional[jax.Array] = None
113 |         self.means: Optional[dict] = None
114 |         self.stds: Optional[dict] = None
115 |         self.seed: int = seed
116 | 
117 |     def fit(self, X: jax.Array, y: jax.Array):
118 |         """
119 |         Fits the model to the training data.
120 | 
121 |         Args:
122 |             X (jax.Array): The feature matrix for the training data.
123 |             y (jax.Array): The vector of class labels corresponding to the training data.
124 | 
125 |         Returns:
126 |             self: The fitted GaussianNaiveBayes instance.
127 |         """
128 |         self.priors = compute_priors(y)
129 |         self.means = compute_means(X, y, self.seed)
130 |         self.stds = compute_stds(X, y, self.seed)
131 | 
132 |         return self
133 | 
134 |     def predict(self, X: jax.Array):
135 |         """
136 |         Predicts class labels for the given data using the trained model.
137 | 
138 |         Args:
139 |             X (jax.Array): The feature matrix for which predictions are to be made.
140 | 
141 |         Returns:
142 |             jax.Array: An array of predicted class labels for each sample in X.
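143 | 
144 |         Example (illustrative sketch; ``X_train``, ``y_train``, ``X_new`` are hypothetical arrays):
145 |             >>> model = GaussianNaiveBayes(seed=42).fit(X_train, y_train)
146 |             >>> y_pred = model.predict(X_new)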
143 | """ 144 | posteriors = [] 145 | for x in X: 146 | likelihoods = jnp.array( 147 | [ 148 | gaussian_pdf(x, jnp.array(means), jnp.array(stds)) 149 | for means, stds in zip(self.means.values(), self.stds.values()) 150 | ] 151 | ) 152 | vector_of_posteriors = jnp.log(jnp.dot(likelihoods, self.priors)) 153 | posteriors.append(vector_of_posteriors) 154 | 155 | return jnp.argmax(jnp.array(posteriors), axis=1) 156 | -------------------------------------------------------------------------------- /tests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/.DS_Store -------------------------------------------------------------------------------- /tests/__pycache__/test_PCA.cpython-311-pytest-8.3.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_PCA.cpython-311-pytest-8.3.2.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_PCA.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_PCA.cpython-311.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_clustering.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_clustering.cpython-311.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_decomposition.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_decomposition.cpython-311.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_gaussian_naive_bayes.cpython-311-pytest-8.3.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_gaussian_naive_bayes.cpython-311-pytest-8.3.2.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_gaussian_naive_bayes.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_gaussian_naive_bayes.cpython-311.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_k_means.cpython-311-pytest-8.3.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_k_means.cpython-311-pytest-8.3.2.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_k_means.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_k_means.cpython-311.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_linear_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_linear_model.cpython-311.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_linear_regression.cpython-311-pytest-8.3.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_linear_regression.cpython-311-pytest-8.3.2.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_linear_regression.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_linear_regression.cpython-311.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_linear_regression_sgd.cpython-311-pytest-8.3.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_linear_regression_sgd.cpython-311-pytest-8.3.2.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_multinomial_naive_bayes.cpython-311-pytest-8.3.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_multinomial_naive_bayes.cpython-311-pytest-8.3.2.pyc -------------------------------------------------------------------------------- /tests/__pycache__/test_multinomial_naive_bayes.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiibanMo/scikit-jax/03a6b57134f335c67f255083b850b4ef76367e13/tests/__pycache__/test_multinomial_naive_bayes.cpython-311.pyc -------------------------------------------------------------------------------- /tests/files/multiple_linear_regression_dataset.csv: -------------------------------------------------------------------------------- 1 | age,experience,income 2 | 25,1,30450 3 | 30,3,35670 4 | 47,2,31580 5 | 32,5,40130 6 | 43,10,47830 7 | 51,7,41630 8 | 28,5,41340 9 | 33,4,37650 10 | 37,5,40250 11 | 39,8,45150 12 | 29,1,27840 13 | 47,9,46110 14 | 54,5,36720 15 | 51,4,34800 16 | 44,12,51300 17 | 41,6,38900 18 | 58,17,63600 19 | 23,1,30870 20 | 44,9,44190 21 | 37,10,48700 22 | -------------------------------------------------------------------------------- /tests/test_PCA.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pytest 3 | 4 | from skjax.decomposition import PCA 5 | 6 | num_components = 2 7 | 8 | 9 | @pytest.fixture 10 | def setup(): 11 | """Sets up test data and PCA instance.""" 12 | X = jnp.array( 13 | [ 14 | [2.5, 2.4, 3.5], 15 | [0.5, 0.7, 2.1], 16 | [2.2, 2.9, 3.3], 17 | [1.9, 2.2, 2.9], 18 | 
[3.1, 3.0, 3.8], 19 | ] 20 | ) 21 | 22 | pca = PCA(num_components=num_components) 23 | 24 | return X, pca 25 | 26 | 27 | def test_initialization(setup): 28 | """Tests initialization of PCA instance.""" 29 | 30 | _, pca = setup 31 | 32 | assert pca.mean is None 33 | assert pca.principal_components is None 34 | assert pca.explained_variance is None 35 | 36 | 37 | def test_fit(setup): 38 | """Tests the fit method of PCA.""" 39 | X, pca = setup 40 | pca.fit(X) 41 | 42 | assert pca.mean is not None 43 | assert pca.principal_components is not None 44 | assert pca.explained_variance is not None 45 | assert pca.principal_components.shape[0] == X.shape[1] 46 | assert pca.explained_variance.shape[0] == X.shape[1] 47 | 48 | 49 | def test_transform(setup): 50 | """Tests the transform method of PCA.""" 51 | X, pca = setup 52 | pca.fit(X) 53 | 54 | X_transformed = pca.transform(X) 55 | 56 | assert X_transformed.shape[1] == num_components 57 | assert jnp.all(jnp.isfinite(X_transformed)) 58 | 59 | 60 | def test_fit_transform(setup): 61 | """Tests the fit_transform method of PCA.""" 62 | X, pca = setup 63 | X_transformed = pca.fit_transform(X) 64 | 65 | assert X_transformed.shape[1] == num_components 66 | assert jnp.all(jnp.isfinite(X_transformed)) 67 | assert pca.mean is not None 68 | assert pca.principal_components is not None 69 | 70 | 71 | def test_inverse_transform(setup): 72 | """Tests the inverse_transform method of PCA.""" 73 | X, pca = setup 74 | pca.fit(X) 75 | 76 | X_transformed = pca.transform(X) 77 | X_reconstructed = pca.inverse_transform(X_transformed) 78 | assert X_reconstructed.shape == X.shape 79 | assert jnp.all(jnp.isfinite(X_reconstructed)) 80 | assert jnp.allclose(X[:, 0], X_reconstructed[:, 0], atol=10) 81 | 82 | 83 | def test_exceptions(setup): 84 | """Tests that exceptions are raised when fitting is not done before transforming.""" 85 | X, pca = setup 86 | 87 | with pytest.raises(RuntimeError): 88 | pca.transform(X) 89 | 90 | with pytest.raises(RuntimeError): 91 | pca.inverse_transform(X) 92 | -------------------------------------------------------------------------------- /tests/test_gaussian_naive_bayes.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pytest 3 | 4 | from skjax.naive_bayes import GaussianNaiveBayes 5 | 6 | 7 | @pytest.fixture 8 | def setup(): 9 | model = GaussianNaiveBayes(seed=42) 10 | X = jnp.array([[1.0, 2.0], [1.5, 1.8], [2.0, 2.2], [3.0, 3.0], [3.5, 3.5]]) 11 | y = jnp.array([0, 0, 0, 1, 1]) 12 | return X, y, model 13 | 14 | 15 | def test_initialization(setup): 16 | """Test initialization of GaussianNaiveBayes instance.""" 17 | _, _, model = setup 18 | 19 | assert model.priors is None 20 | assert model.means is None 21 | assert model.stds is None 22 | 23 | 24 | def test_fit(setup): 25 | """Test the fit method of GaussianNaiveBayes.""" 26 | X, y, model = setup 27 | model.fit(X, y) 28 | assert model.priors is not None 29 | assert model.means is not None 30 | assert model.stds is not None 31 | 32 | # Check that priors, means, and stds are not empty 33 | assert len(model.priors) > 0 34 | assert len(model.means) > 0 35 | assert len(model.stds) > 0 36 | 37 | 38 | def test_predict(setup): 39 | """Test the predict method of GaussianNaiveBayes.""" 40 | X, y, model = setup 41 | model.fit(X, y) 42 | predictions = model.predict(X) 43 | assert predictions.shape == y.shape 44 | assert jnp.all(predictions >= 0) 45 | assert jnp.all(predictions < len(jnp.unique(y))) 46 | 
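47 | 
48 | def test_refit_determinism(setup):
49 |     """Sanity sketch: refitting on the same data with the same fixed seed should reproduce identical predictions (assumes fit and predict are deterministic given the seed)."""
50 |     X, y, model = setup
51 | 
52 |     first_predictions = model.fit(X, y).predict(X)
53 |     second_predictions = model.fit(X, y).predict(X)
54 | 
55 |     assert jnp.array_equal(first_predictions, second_predictions)
56 | 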
--------------------------------------------------------------------------------
/tests/test_k_means.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import jax.numpy as jnp
4 | import pandas as pd
5 | import pytest
6 | 
7 | from skjax.clustering import KMeans
8 | 
9 | X1 = pd.read_csv("tests/files/basic5.csv").drop(columns="color").to_numpy(jnp.float32)
10 | X2 = jnp.array(
11 |     [
12 |         [1, 2],
13 |         [1, 3],
14 |         [2, 2],
15 |         [2, 3],  # Cluster 1
16 |         [8, 9],
17 |         [8, 8],
18 |         [9, 9],
19 |         [9, 8],  # Cluster 2
20 |     ]
21 | )
22 | X3 = jnp.array(
23 |     [
24 |         [1, 1],
25 |         [1, 2],
26 |         [2, 1],
27 |         [2, 2],  # Cluster 1
28 |         [5, 5],
29 |         [5, 6],
30 |         [6, 5],
31 |         [6, 6],  # Cluster 2
32 |         [8, 1],
33 |         [8, 2],
34 |         [9, 1],
35 |         [9, 2],  # Cluster 3
36 |     ]
37 | )
38 | 
39 | 
40 | def test_01_fit_time():
41 |     """Tests that fitting KMeans on the CSV dataset completes within a reasonable time."""
42 |     model = KMeans(num_clusters=3)
43 |     start_time = time.time()
44 |     model.fit(X1)
45 |     end_time = time.time()
46 | 
47 |     time_taken = end_time - start_time
48 |     # Assert that the time taken is within a reasonable limit, e.g., 5 seconds
49 |     assert time_taken < 5.0, f"Model fitting took too long: {time_taken:.4f}s"
50 | 
51 | 
52 | def test_02_fit():
53 |     """Tests the fit method of KMeans on two well-separated clusters."""
54 |     model = KMeans(num_clusters=2).fit(X2)
55 |     true_centroids = jnp.array([[1.5, 2.5], [8.5, 8.5]]).sort(axis=0)
56 |     pred_centroids = model.centroids.sort(axis=0)
57 |     assert jnp.allclose(true_centroids, pred_centroids)
58 | 
59 | 
60 | def test_03_fit():
61 |     """Tests the fit method of KMeans on three well-separated clusters."""
62 |     model = KMeans(num_clusters=3).fit(X3)
63 |     true_centroids = jnp.array([[8.5, 1.5], [1.5, 1.5], [5.5, 5.5]]).sort(axis=0)
64 |     pred_centroids = model.centroids.sort(axis=0)
65 |     assert jnp.allclose(true_centroids, pred_centroids)
66 | 
--------------------------------------------------------------------------------
/tests/test_linear_regression.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import pytest
5 | 
6 | from skjax.linear_model import LinearRegression
7 | 
8 | np.random.seed(42)
9 | 
10 | 
11 | @pytest.fixture
12 | def generate_basic_linear():
13 |     X = np.linspace(0, 10, 100)
14 |     y = 2 * X + 1 + np.random.normal(0, 0.5, X.shape)
15 |     return X, y
16 | 
17 | 
18 | def data_for_higher_dimensional_fitting():
19 |     pass
20 | 
--------------------------------------------------------------------------------
/tests/test_linear_regression_sgd.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import numpy as np
3 | import pytest
4 | 
5 | from skjax.linear_model import LinearRegressionSGD
6 | 
7 | 
8 | @pytest.fixture
9 | def setup_data():
10 |     """Fixture to set up test data and LinearRegressionSGD instance."""
11 |     X_train = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
12 |     y_train = jnp.array([1.0, 2.0, 3.0])
13 |     X_val = jnp.array([[7.0, 8.0], [9.0, 10.0]])
14 |     y_val = jnp.array([4.0, 5.0])
15 | 
16 |     model = LinearRegressionSGD(
17 |         weights_init="random",
18 |         epochs=10,
19 |         learning_rate=0.01,
20 |         p=2,
21 |         lambda_=0.1,
22 |         max_patience=2,
23 |         dropout=0.1,
24 |         random_state=42,
25 |     )
26 |     return X_train, y_train, X_val, y_val, model
27 | 
28 | 
29 | def test_initialization(setup_data):
30 |     """Test initialization of LinearRegressionSGD instance."""
31 |     _, _, _, _, model = setup_data
32 | 
33 |     assert model.weights == "random"
34 |     assert model.epochs == 10
35 |     assert model.learning_rate == 0.01
36 |
assert model.p == 2 37 | assert model.lambda_ == 0.1 38 | assert model.max_patience == 2 39 | assert model.dropout == 0.1 40 | assert model.random_state == 42 41 | 42 | assert isinstance(model.weights, str) 43 | assert isinstance(model.epochs, int) 44 | assert isinstance(model.learning_rate, float) 45 | assert isinstance(model.p, int) 46 | assert isinstance(model.dropout, float) 47 | assert isinstance(model.random_state, int) 48 | assert isinstance(model.losses_in_training_data, np.ndarray) 49 | assert isinstance(model.losses_in_validation_data, np.ndarray) 50 | assert isinstance(model.stopped_at, int) 51 | 52 | assert model.epochs >= 1 53 | assert model.dropout <= 1 54 | assert model.lambda_ >= 0 55 | assert model.stopped_at >= 1 56 | 57 | 58 | def test_fit(setup_data): 59 | """Test the fit method.""" 60 | X_train, y_train, _, _, model = setup_data 61 | 62 | model = model.fit(X_train, y_train) 63 | 64 | assert isinstance(model, LinearRegressionSGD) 65 | assert len(model.losses_in_training_data) == model.epochs 66 | assert len(model.losses_in_validation_data) == model.epochs 67 | assert model.stopped_at >= model.epochs 68 | 69 | 70 | def test_predict(setup_data): 71 | """Test the predict method.""" 72 | X_train, y_train, X_val, y_val, model = setup_data 73 | 74 | model.fit(X_train, y_train) 75 | predictions = model.predict(X_val) 76 | 77 | assert predictions.shape == y_val.shape 78 | assert jnp.all(jnp.isfinite(predictions)) 79 | 80 | 81 | def test_losses_plot(setup_data): 82 | """Test if plotting method runs without error.""" 83 | X_train, y_train, X_val, y_val, model = setup_data 84 | 85 | model.fit(X_train, y_train, X_val, y_val) 86 | try: 87 | model.plot_losses() 88 | except Exception as e: 89 | pytest.fail(f"plot_losses() raised an exception: {e}") 90 | -------------------------------------------------------------------------------- /tests/test_multinomial_naive_bayes.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pytest 3 | 4 | from skjax.naive_bayes import MultinomialNaiveBayes 5 | 6 | 7 | @pytest.fixture 8 | def setup(): 9 | """Set up test data and MultinomialNaiveBayes instance.""" 10 | X = jnp.array([[1, 2, 0], [2, 1, 1], [0, 1, 3], [1, 0, 2], [2, 2, 1]]) 11 | y = jnp.array([0, 1, 0, 1, 0]) 12 | 13 | model = MultinomialNaiveBayes(alpha=1.0) 14 | 15 | return X, y, model 16 | 17 | 18 | def test_initialization(setup): 19 | """Test initialization of MultinomialNaiveBayes instance.""" 20 | _, _, model = setup 21 | 22 | assert model.alpha == 1.0 23 | assert model.priors is None 24 | assert model.likelihoods is None 25 | 26 | 27 | def test_fit(setup): 28 | """Test the fit method of MultinomialNaiveBayes.""" 29 | X, y, model = setup 30 | 31 | model.fit(X, y) 32 | assert model.priors is not None 33 | assert model.likelihoods is not None 34 | assert model.priors.shape[0] == len(jnp.unique(y)) 35 | for class_label in model.likelihoods.keys(): 36 | assert len(model.likelihoods[class_label]) == X.shape[1] 37 | 38 | 39 | def test_predict(setup): 40 | """Test the predict method of MultinomialNaiveBayes.""" 41 | X, y, model = setup 42 | 43 | model.fit(X, y) 44 | predictions = model.predict(X) 45 | 46 | assert predictions.shape == y.shape 47 | assert jnp.all(predictions >= 0) 48 | assert jnp.all(predictions < len(jnp.unique(y))) 49 | 50 | 51 | def test_smoothing(setup): 52 | """Test the effect of smoothing parameter on fit.""" 53 | X, y, model = setup 54 | 55 | model = MultinomialNaiveBayes(alpha=0.5) 56 | model.fit(X, 
y) 57 | 58 | assert model.likelihoods is not None 59 | 60 | 61 | def test_no_smoothing(setup): 62 | """Test fitting without smoothing.""" 63 | X, y, model = setup 64 | 65 | model = MultinomialNaiveBayes(alpha=0) 66 | model.fit(X, y) 67 | assert model.likelihoods is not None 68 | --------------------------------------------------------------------------------