├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.md
│       └── feature_request.md
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── dictionaries
│   │   └── umang.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── mypy.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── assets
│   ├── logo.png
│   ├── quickstart_code_acc_curves.png
│   └── quickstart_code_loss_curves.png
├── examples
│   ├── mnist_classifier.py
│   └── tiny_mnist_binary_classifier.py
├── kerax
│   ├── __init__.py
│   ├── data
│   │   └── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── binary_tiny_mnist
│   │   │   ├── __init__.py
│   │   │   ├── test.csv
│   │   │   └── train.csv
│   │   └── mnist
│   │       └── __init__.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   └── layers.py
│   ├── losses
│   │   └── __init__.py
│   ├── metrics
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   └── sequential.py
│   ├── optimizers
│   │   └── __init__.py
│   └── utils
│       ├── __init__.py
│       ├── interpreter.py
│       ├── serialization.py
│       ├── tensor.py
│       └── trainer.py
├── requirements.txt
├── setup.py
└── tests
    ├── test_data.py
    ├── test_datasets.py
    ├── test_layers.py
    ├── test_losses.py
    ├── test_metrics.py
    └── test_sequential.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /.mypy_cache/3.7/@plugins_snapshot.json
2 | /.mypy_cache/3.7/__future__.data.json
3 | /.mypy_cache/3.7/__future__.meta.json
4 | /.mypy_cache/3.7/_ast.data.json
5 | /.mypy_cache/3.7/_ast.meta.json
6 | /.mypy_cache/3.7/_importlib_modulespec.data.json
7 | /.mypy_cache/3.7/_importlib_modulespec.meta.json
8 | /.mypy_cache/3.7/_warnings.data.json
9 | /.mypy_cache/3.7/_warnings.meta.json
10 | /.mypy_cache/3.7/_weakref.data.json
11 | /.mypy_cache/3.7/_weakref.meta.json
12 | /.mypy_cache/3.7/_weakrefset.data.json
13 | /.mypy_cache/3.7/_weakrefset.meta.json
14 | /.mypy_cache/3.7/abc.data.json
15 | /.mypy_cache/3.7/abc.meta.json
16 | /.mypy_cache/3.7/array.data.json
17 | /.mypy_cache/3.7/array.meta.json
18 | /.mypy_cache/3.7/ast.data.json
19 | /.mypy_cache/3.7/ast.meta.json
20 | /.mypy_cache/3.7/builtins.data.json
21 | /.mypy_cache/3.7/builtins.meta.json
22 | /.mypy_cache/3.7/codecs.data.json
23 | /.mypy_cache/3.7/codecs.meta.json
24 | /.mypy_cache/3.7/contextlib.data.json
25 | /.mypy_cache/3.7/contextlib.meta.json
26 | /.mypy_cache/3.7/copy.data.json
27 | /.mypy_cache/3.7/copy.meta.json
28 | /.mypy_cache/3.7/difflib.data.json
29 | /.mypy_cache/3.7/difflib.meta.json
30 | /.mypy_cache/3.7/inspect.data.json
31 | /.mypy_cache/3.7/inspect.meta.json
32 | /.mypy_cache/3.7/io.data.json
33 | /.mypy_cache/3.7/io.meta.json
34 | /.mypy_cache/3.7/itertools.data.json
35 | /.mypy_cache/3.7/itertools.meta.json
36 | /.mypy_cache/3.7/math.data.json
37 | /.mypy_cache/3.7/math.meta.json
38 | /.mypy_cache/3.7/mmap.data.json
39 | /.mypy_cache/3.7/mmap.meta.json
40 | /.mypy_cache/3.7/pathlib.data.json
41 | /.mypy_cache/3.7/pathlib.meta.json
42 | /.mypy_cache/3.7/pickle.data.json
43 | /.mypy_cache/3.7/pickle.meta.json
44 | /.mypy_cache/3.7/posix.data.json
45 | /.mypy_cache/3.7/posix.meta.json
46 | /.mypy_cache/3.7/queue.data.json
47 | /.mypy_cache/3.7/queue.meta.json
48 | /.mypy_cache/3.7/shutil.data.json
49 | /.mypy_cache/3.7/shutil.meta.json
50 | /.mypy_cache/3.7/struct.data.json
51 | /.mypy_cache/3.7/struct.meta.json
52 | /.mypy_cache/3.7/sys.data.json
53 | /.mypy_cache/3.7/sys.meta.json
54 | /.mypy_cache/3.7/tarfile.data.json
55 | /.mypy_cache/3.7/tarfile.meta.json
56 | /.mypy_cache/3.7/tempfile.data.json
57 | /.mypy_cache/3.7/tempfile.meta.json
58 | /.mypy_cache/3.7/test.data.json
59 | /.mypy_cache/3.7/test.meta.json
60 | /.mypy_cache/3.7/threading.data.json
61 | /.mypy_cache/3.7/threading.meta.json
62 | /.mypy_cache/3.7/traceback.data.json
63 | /.mypy_cache/3.7/traceback.meta.json
64 | /.mypy_cache/3.7/types.data.json
65 | /.mypy_cache/3.7/types.meta.json
66 | /.mypy_cache/3.7/typing.data.json
67 | /.mypy_cache/3.7/typing.meta.json
68 | /.mypy_cache/3.7/typing_extensions.data.json
69 | /.mypy_cache/3.7/typing_extensions.meta.json
70 | /.mypy_cache/3.7/warnings.data.json
71 | /.mypy_cache/3.7/warnings.meta.json
72 | /.mypy_cache/3.7/weakref.data.json
73 | /.mypy_cache/3.7/weakref.meta.json
74 | /.mypy_cache/
75 | /.idea/
76 | /.idea/dictionaries/
77 | /datasets/mnist/
78 | .DS_Store
79 | /venv/
80 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Default ignored files
3 | /workspace.xml
4 | /dictionaries/
5 |
--------------------------------------------------------------------------------
/.idea/dictionaries/umang.xml:
--------------------------------------------------------------------------------
1 | <component name="ProjectDictionaryState">
2 |   <dictionary name="umang">
3 |     <words>
4 |       <w>adagrad</w>
5 |       <w>bfloat</w>
6 |       <w>conv</w>
7 |       <w>crossentropy</w>
8 |       <w>dataset</w>
9 |       <w>datasets</w>
10 |       <w>dtype</w>
11 |       <w>fastai</w>
12 |       <w>glorot</w>
13 |       <w>jaxlib</w>
14 |       <w>matplotlib</w>
15 |       <w>mish</w>
16 |       <w>mnist</w>
17 |       <w>nchw</w>
18 |       <w>ndarray</w>
19 |       <w>ndarrays</w>
20 |       <w>nhwc</w>
21 |       <w>npscalar</w>
22 |       <w>pytree</w>
23 |       <w>relu</w>
24 |       <w>rmsprop</w>
25 |       <w>softmax</w>
26 |       <w>softplus</w>
27 |       <w>subkey</w>
28 |       <w>tanh</w>
29 |       <w>tfds</w>
30 |       <w>tqdm</w>
31 |     </words>
32 |   </dictionary>
33 | </component>
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/mypy.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to <https://unlicense.org>
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![Kerax logo](assets/logo.png)
2 |
3 |
4 |
5 | # Kerax
6 |
7 | Keras-like APIs for the [JAX](https://github.com/google/jax) library.
8 |
9 | ## Features
10 |
11 | * Enables high-performance machine learning research.
12 | * Built-in support for popular optimization algorithms and activation functions.
13 | * Runs seamlessly on CPU, GPU, and even TPU, with no manual configuration required.
14 |
15 | ## Quickstart
16 |
17 | ### Code
18 |
19 | ```python3
20 | from kerax.datasets import binary_tiny_mnist
21 | from kerax.layers import Dense, Relu, Sigmoid
22 | from kerax.losses import BCELoss
23 | from kerax.metrics import binary_accuracy
24 | from kerax.models import Sequential
25 | from kerax.optimizers import SGD
26 |
27 | data = binary_tiny_mnist.load_dataset(batch_size=200)
28 | model = Sequential([Dense(100), Relu, Dense(1), Sigmoid])
29 | model.compile(loss=BCELoss, optimizer=SGD(step_size=0.003), metrics=[binary_accuracy])
30 | model.fit(data=data, epochs=10)
31 | model.save(file_name="model")
32 |
33 | interp = model.get_interpretation()
34 | interp.plot_losses()
35 | interp.plot_accuracy()
36 | ```
37 | ### Output
38 |
39 | ```terminal
40 | WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
41 | Epoch 10: 100%|██████████| 10/10 [00:02<00:00, 3.82it/s, train_loss : 0.192 :: valid_loss : 0.202 :: train_binary_accuracy : 1.000 :: valid_binary_accuracy : 1.000]
42 |
43 | Process finished with exit code 0
44 | ```
45 |
46 | ![Loss curves](assets/quickstart_code_loss_curves.png)
47 | ![Accuracy curves](assets/quickstart_code_acc_curves.png)
48 |
49 | ## Documentation (Coming soon...)
50 |
51 | ## Developer's Notes
52 |
53 | This project is developed and maintained by [Umang Patel](https://github.com/umangjpatel)
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/umangjpatel/kerax/21e1798643a6382a1d7960db4c5f3a22fa19a28a/assets/logo.png
--------------------------------------------------------------------------------
/assets/quickstart_code_acc_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/umangjpatel/kerax/21e1798643a6382a1d7960db4c5f3a22fa19a28a/assets/quickstart_code_acc_curves.png
--------------------------------------------------------------------------------
/assets/quickstart_code_loss_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/umangjpatel/kerax/21e1798643a6382a1d7960db4c5f3a22fa19a28a/assets/quickstart_code_loss_curves.png
--------------------------------------------------------------------------------
/examples/mnist_classifier.py:
--------------------------------------------------------------------------------
1 | from kerax.datasets import mnist
2 | from kerax.layers import Flatten, Dense, Relu, LogSoftmax
3 | from kerax.losses import CCELoss
4 | from kerax.metrics import accuracy
5 | from kerax.models import Sequential
6 | from kerax.optimizers import RMSProp
7 |
8 | data = mnist.load_dataset(batch_size=1024)
9 | model = Sequential([Flatten, Dense(100), Relu, Dense(10), LogSoftmax])
10 | model.compile(loss=CCELoss, optimizer=RMSProp(step_size=0.001), metrics=[accuracy])
11 | model.fit(data, epochs=10)
12 | model.save("tfds_mnist_v1")
13 | interp = model.get_interpretation()
14 | interp.plot_losses()
15 | interp.plot_accuracy()
16 |
--------------------------------------------------------------------------------
/examples/tiny_mnist_binary_classifier.py:
--------------------------------------------------------------------------------
1 | from kerax.datasets import binary_tiny_mnist
2 | from kerax.layers import Dense, Relu, Sigmoid
3 | from kerax.losses import BCELoss
4 | from kerax.metrics import binary_accuracy
5 | from kerax.models import Sequential
6 | from kerax.optimizers import SGD
7 |
8 | data = binary_tiny_mnist.load_dataset(batch_size=200)
9 | model = Sequential([Dense(100), Relu, Dense(1), Sigmoid])
10 | model.compile(loss=BCELoss, optimizer=SGD(step_size=0.003), metrics=[binary_accuracy])
11 | model.fit(data=data, epochs=10)
12 | model.save(file_name="tiny_mnist_binary_classifier_v1")
13 | interp = model.get_interpretation()
14 | interp.plot_losses()
15 | interp.plot_accuracy()
16 |
17 | new_model = Sequential()
18 | new_model.load(file_name="tiny_mnist_binary_classifier_v1")
19 | # model already compiled when loaded from serialized file
20 | new_model.fit(data=data, epochs=50)
21 | new_model.save(file_name="tiny_mnist_binary_classifier_v2")
22 | interp = new_model.get_interpretation()
23 | interp.plot_losses()
24 | interp.plot_accuracy()
25 |
--------------------------------------------------------------------------------
/kerax/__init__.py:
--------------------------------------------------------------------------------
1 | from . import data
2 | from . import datasets
3 | from . import layers
4 | from . import losses
5 | from . import metrics
6 | from . import models
7 | from . import optimizers
8 | from . import utils
9 |
--------------------------------------------------------------------------------
/kerax/data/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, Tuple
2 |
3 | import numpy as np
4 |
5 |
6 | class Dataloader:
7 |     """
8 |     Dataloader is a helper class.
9 |     It assists in iterating over batches of data during the training process.
10 |     """
11 | def __init__(self, train_data: Iterator[Tuple[np.ndarray, np.ndarray]],
12 | val_data: Iterator[Tuple[np.ndarray, np.ndarray]],
13 | input_shape: Tuple[int, ...], batch_size: int,
14 | num_train_batches: int, num_val_batches: int):
15 | """
16 | Initializes the Dataloader class.
17 | :param train_data: Iterator containing training data in the form of (inputs, labels) tuples.
18 | :param val_data: Iterator containing validation data in the form of (inputs, labels) tuples.
19 | :param input_shape: Input shape to initialize the parameters. -1 to be used for expressing batch dimensions.
20 | :param batch_size: Number of examples to be included in a single batch.
21 | :param num_train_batches: Number of batches of training data
22 | :param num_val_batches: Number of batches of validation data
23 | """
24 | assert train_data is not None, "Training data is empty"
25 | assert val_data is not None, "Validation data is empty"
26 | assert input_shape is not None, "Input shape not passed"
27 | assert batch_size > 0, "Invalid batch size passed"
28 | assert num_train_batches is not None, "Number of training batches is not passed"
29 | assert num_val_batches is not None, "Number of validation batches is not passed"
30 |
31 | self.train_data = train_data
32 | self.val_data = val_data
33 | self.input_shape = input_shape
34 | self.batch_size = batch_size
35 | self.num_train_batches = num_train_batches
36 | self.num_val_batches = num_val_batches
37 |
38 |
39 | __all__ = [
40 | "Dataloader"
41 | ]
42 |
--------------------------------------------------------------------------------
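
`Dataloader` is a plain container: it validates and stores whatever iterators it is given. A minimal sketch of constructing one by hand from in-memory NumPy arrays (the `make_batches` helper here is hypothetical, not part of kerax):

```python
import numpy as np
from kerax.data import Dataloader

def make_batches(x, y, batch_size):
    # Hypothetical helper: yields (inputs, labels) tuples endlessly,
    # matching how the bundled datasets stream their batches.
    while True:
        for i in range(0, len(x), batch_size):
            yield x[i:i + batch_size], y[i:i + batch_size]

x_train, y_train = np.ones((400, 784)), np.zeros((400, 1))
x_val, y_val = np.ones((100, 784)), np.zeros((100, 1))

loader = Dataloader(train_data=make_batches(x_train, y_train, 100),
                    val_data=make_batches(x_val, y_val, 100),
                    input_shape=(-1, 784), batch_size=100,
                    num_train_batches=4, num_val_batches=1)
```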
/kerax/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from . import binary_tiny_mnist, mnist
2 |
3 | __all__ = [
4 | "binary_tiny_mnist",
5 | "mnist"
6 | ]
7 |
--------------------------------------------------------------------------------
/kerax/datasets/binary_tiny_mnist/__init__.py:
--------------------------------------------------------------------------------
1 | from ...data import Dataloader as __Dataloader__
2 | from typing import Tuple
3 |
4 |
5 | def load_dataset(batch_size: int) -> __Dataloader__:
6 | """
7 |     Loads a reduced version of the MNIST dataset that contains only the digits 0 and 1.
8 | :param batch_size: Number of examples to be included in a single batch.
9 | :return: Dataloader object consisting of training and validation data.
10 | """
11 | import numpy as np
12 | from ...data import Dataloader
13 | import pandas as pd
14 | from pathlib import Path
15 |
16 | def compute_dataset_info(x, y) -> Tuple[int, ...]:
17 | """
18 | Computes the number of examples and number of batches for a given dataset
19 | :param x: Inputs
20 | :param y: Targets
21 | :return: a tuple consisting of total number of examples and number of batches.
22 | """
23 |         assert x.shape[0] == y.shape[0], "Number of examples does not match..."
24 | num_examples = x.shape[0]
25 | num_complete_batches, leftover = divmod(num_examples, batch_size)
26 | num_batches = num_complete_batches + bool(leftover)
27 | return num_examples, num_batches
28 |
29 | def data_stream(x, y, data_info):
30 | """
31 | Yields an iterator for streaming the dataset into respective batches efficiently.
32 | :param x: Inputs
33 | :param y: Targets
34 | :param data_info: a tuple consisting of total number of examples and number of batches.
35 | :return: a tuple consisting of a batch of inputs and targets for a dataset.
36 | """
37 | num_examples, num_batches = data_info
38 | rng = np.random.RandomState(0)
39 | while True:
40 | perm = rng.permutation(num_examples)
41 | for i in range(num_batches):
42 | batch_idx = perm[i * batch_size:(i + 1) * batch_size]
43 | yield x[batch_idx], y[batch_idx]
44 |
45 | def load() -> __Dataloader__:
46 | """
47 | Loads the dataset
48 | :return: a Dataloader object consisting of the training and validation data.
49 | """
50 | path: Path = Path(__file__).parent
51 | dataset: pd.DataFrame = pd.read_csv(path / "train.csv", header=None)
52 |
53 | # 80% training, 20% validation
54 | train_data: pd.DataFrame = dataset.sample(frac=0.8, random_state=0)
55 | val_data: pd.DataFrame = dataset.drop(train_data.index)
56 |
57 | train_labels: np.ndarray = np.expand_dims(train_data[0].values, axis=1)
58 | train_images: np.ndarray = train_data.iloc[:, 1:].values / 255.0
59 |
60 | val_labels: np.ndarray = np.expand_dims(val_data[0].values, axis=1)
61 | val_images: np.ndarray = val_data.iloc[:, 1:].values / 255.0
62 |
63 | train_data_info = compute_dataset_info(train_images, train_labels)
64 | val_data_info = compute_dataset_info(val_images, val_labels)
65 |
66 | train_gen = data_stream(train_images, train_labels, train_data_info)
67 | val_gen = data_stream(val_images, val_labels, val_data_info)
68 |
69 | input_shape = tuple([-1] + list(train_images.shape)[1:])
70 |
71 | return Dataloader(train_data=train_gen, val_data=val_gen,
72 | batch_size=batch_size, input_shape=input_shape,
73 | num_train_batches=train_data_info[1],
74 | num_val_batches=val_data_info[1])
75 |
76 | return load()
77 |
--------------------------------------------------------------------------------
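
The batch count in `compute_dataset_info` comes from `divmod`: any leftover examples form one extra, smaller batch. For example, assuming 3,400 training rows and 850 validation rows at `batch_size=200` (consistent with the 17/5 batch counts asserted in the tests):

```python
# 3,400 training examples -> 17 complete batches, none left over
num_complete, leftover = divmod(3400, 200)
print(num_complete + bool(leftover))  # 17

# 850 validation examples -> 4 complete batches plus one partial batch
num_complete, leftover = divmod(850, 200)
print(num_complete + bool(leftover))  # 5
```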
/kerax/datasets/mnist/__init__.py:
--------------------------------------------------------------------------------
1 | def load_dataset(batch_size: int):
2 | """
3 | Loads the complete MNIST dataset (training + validation)
4 | Once the dataset is downloaded, it won't be downloaded again. So, relax...
5 | :param batch_size: Number of examples in a batch
6 | :return: a Dataloader object consisting of the dataset
7 | """
8 | import asyncio
9 | import tensorflow as tf
10 | from pathlib import Path
11 | import tensorflow_datasets as tfds
12 | import math
13 |
14 | from ...data import Dataloader
15 |
16 | async def tfds_load_data() -> Dataloader:
17 | """
18 | Loads the dataset using the TensorFlow Datasets API
19 | :return: a Dataloader object consisting of the dataset.
20 | """
21 | assert batch_size > 0, "Batch size must be greater than 0"
22 | current_path = Path(__file__).parent
23 | ds, info = tfds.load(name="mnist", split=["train", "test"], as_supervised=True, with_info=True,
24 | shuffle_files=True, data_dir=current_path, batch_size=batch_size)
25 | train_ds, valid_ds = ds
26 | train_ds = train_ds.map(lambda x, y: (tf.divide(tf.cast(x, dtype=tf.float32), 255.0), tf.one_hot(y, depth=10)))
27 | valid_ds = valid_ds.map(lambda x, y: (tf.divide(tf.cast(x, dtype=tf.float32), 255.0), tf.one_hot(y, depth=10)))
28 | train_ds, valid_ds = train_ds.cache().repeat(), valid_ds.cache().repeat()
29 | input_shape = tuple([-1] + list(info.features["image"].shape))
30 | num_train_batches = math.ceil(info.splits["train"].num_examples / batch_size)
31 | num_val_batches = math.ceil(info.splits["test"].num_examples / batch_size)
32 | return Dataloader(
33 | train_data=iter(tfds.as_numpy(train_ds)), val_data=iter(tfds.as_numpy(valid_ds)),
34 | input_shape=input_shape, batch_size=batch_size,
35 | num_train_batches=num_train_batches, num_val_batches=num_val_batches
36 | )
37 |
38 | return asyncio.run(tfds_load_data())
39 |
--------------------------------------------------------------------------------
/kerax/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .activations import Relu, Sigmoid, LogSoftmax
2 | from .layers import Dense, Dropout, Flatten
3 |
4 | __all__ = [
5 | "Relu",
6 | "Sigmoid",
7 | "LogSoftmax",
8 | "Dense",
9 | "Flatten",
10 | "Dropout"
11 | ]
12 |
--------------------------------------------------------------------------------
/kerax/layers/activations.py:
--------------------------------------------------------------------------------
1 | from jax.experimental.stax import Sigmoid, Relu, LogSoftmax
2 |
--------------------------------------------------------------------------------
/kerax/layers/layers.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from jax.experimental.stax import Dense, Flatten
4 | from ..utils import jnp, random
5 |
6 | from ..utils import Tensor
7 |
8 |
9 | def Dropout(rate: Union[Tensor, float]):
10 |     """
11 |     Implementation of the Dropout layer.
12 |     When using the apply_fun, you can pass a 'mode' kwarg, which is helpful
13 |     when you're using the network for validation / prediction.
14 |     :param rate: Probability of keeping a neuron active (the complement of the drop rate).
15 |     :return: a tuple of (init_fun, apply_fun) in the stax layer convention.
16 |     """
17 |
18 | def init_fun(rng, input_shape):
19 | """
20 | Initializes the function. This layer doesn't perform any parameter initialization.
21 | :param rng: a PRNG key for randomization.
22 | :param input_shape: Shape of the inputs received from the previous layer.
23 | :return: a tuple of the output shape and initialized params. Since no params involved, sent an empty tuple.
24 | """
25 | return input_shape, ()
26 |
27 | def apply_fun(params, inputs, **kwargs):
28 | """
29 | Performs computation of the layer
30 | :param params: Parameters of the layer
31 | :param inputs: Inputs for the layer
32 | :param kwargs: Keyword arguments for additional info while computing the inputs
33 | :return: the computed outputs.
34 | """
35 | mode = kwargs.get('mode', 'train')
36 |         rng = random.PRNGKey(seed=0)  # fixed key: the same mask is drawn on every call
37 | if mode == 'train':
38 | keep = random.bernoulli(rng, rate, inputs.shape)
39 | return jnp.where(keep, inputs / rate, 0)
40 | else:
41 | return inputs
42 |
43 | return init_fun, apply_fun
44 |
--------------------------------------------------------------------------------
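
A short sketch of how this Dropout composes with the stax-based layers; at prediction time the `mode` kwarg bypasses the mask (here `rate=0.8` means a keep probability of 0.8, per the docstring above):

```python
from kerax.layers import Dense, Relu, Dropout
from kerax.utils import stax, random

init_fun, apply_fun = stax.serial(Dense(64), Relu, Dropout(rate=0.8), Dense(1))
_, params = init_fun(random.PRNGKey(0), (-1, 784))

inputs = random.normal(random.PRNGKey(1), (32, 784))
train_out = apply_fun(params, inputs)                    # mode defaults to 'train', mask applied
predict_out = apply_fun(params, inputs, mode="predict")  # identity, no masking
```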
/kerax/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ..utils import Tensor, jnp
2 |
3 |
4 | def BCELoss(predictions: Tensor, targets: Tensor) -> Tensor:
5 | """
6 | BCE or Binary Cross Entropy loss function.
7 | Useful for binary classification tasks.
8 | :param predictions: Outputs of the network.
9 | :param targets: Expected outputs of the network.
10 | :return: binary cross-entropy loss value
11 | """
12 | return -jnp.mean(a=(targets * jnp.log(predictions) + (1 - targets) * jnp.log(1 - predictions)))
13 |
14 |
15 | def CCELoss(predictions: Tensor, targets: Tensor) -> Tensor:
16 |     """
17 |     CCE or Categorical Cross Entropy loss function.
18 |     Useful for multi-class classification tasks.
19 |     :param predictions: Outputs of the network, expected as log-probabilities (e.g. LogSoftmax outputs).
20 |     :param targets: Expected outputs of the network (one-hot encoded).
21 |     :return: categorical cross-entropy loss value.
22 |     """
23 |     return -jnp.mean(jnp.sum(predictions * targets, axis=1))
24 |
25 |
26 | __all__ = [
27 | "BCELoss",
28 | "CCELoss"
29 | ]
30 |
--------------------------------------------------------------------------------
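
Both losses assume a particular output activation: `BCELoss` expects probabilities in (0, 1) (e.g. Sigmoid outputs), while `CCELoss` expects log-probabilities (e.g. LogSoftmax outputs) against one-hot targets. A quick sketch with illustrative values:

```python
from kerax.losses import BCELoss, CCELoss
from kerax.utils import jnp

# Binary case: sigmoid-style probabilities vs. 0/1 targets
preds = jnp.array([[0.9], [0.2], [0.7]])
targets = jnp.array([[1.0], [0.0], [1.0]])
print(BCELoss(predictions=preds, targets=targets))  # small positive scalar

# Multi-class case: log-softmax outputs vs. one-hot targets
log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                               [0.1, 0.8, 0.1]]))
one_hot = jnp.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]])
print(CCELoss(predictions=log_probs, targets=one_hot))  # -(log 0.7 + log 0.8) / 2 ≈ 0.29
```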
/kerax/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from ..utils import Tensor, jnp
2 |
3 |
4 | def binary_accuracy(predictions: Tensor, targets: Tensor, **kwargs) -> Tensor:
5 | """
6 | Computes the accuracy for a binary-classification task.
7 | :param predictions: Outputs of the network.
8 | :param targets: Expected outputs of the network.
9 |     :param kwargs: Keyword arguments, which may contain the 'acc_thresh' value.
10 | :return: the mean accuracy
11 | """
12 | threshold = kwargs.get("acc_thresh", 0.50)
13 | assert 0 < threshold < 1, "Threshold should be between 0 and 1"
14 | predictions = jnp.where(predictions > threshold, 1.0, 0.0)
15 | return jnp.mean(predictions == targets)
16 |
17 |
18 | def accuracy(predictions: Tensor, targets: Tensor, **kwargs) -> Tensor:
19 | """
20 | Computes the accuracy for a multi-class classification task
21 | :param predictions: Outputs of the network.
22 | :param targets: Expected outputs of the network.
23 | :param kwargs: Keyword arguments (if any)
24 | :return: the mean accuracy
25 | """
26 | predicted_class = jnp.argmax(predictions, axis=1)
27 | target_class = jnp.argmax(targets, axis=1)
28 | return jnp.mean(predicted_class == target_class)
29 |
30 |
31 | __all__ = [
32 | "binary_accuracy",
33 | "accuracy"
34 | ]
35 |
--------------------------------------------------------------------------------
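
A short sketch of both metrics with illustrative values; `acc_thresh` (note the kwarg name) only applies to the binary case:

```python
from kerax.metrics import binary_accuracy, accuracy
from kerax.utils import jnp

preds = jnp.array([[0.9], [0.3], [0.6]])
targets = jnp.array([[1.0], [0.0], [0.0]])
print(binary_accuracy(predictions=preds, targets=targets, acc_thresh=0.5))  # 2/3

scores = jnp.array([[0.1, 0.9], [0.8, 0.2]])
one_hot = jnp.array([[0.0, 1.0], [1.0, 0.0]])
print(accuracy(predictions=scores, targets=one_hot))  # 1.0
```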
/kerax/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .sequential import Sequential
2 |
3 | __all__ = [
4 | "Sequential"
5 | ]
6 |
--------------------------------------------------------------------------------
/kerax/models/sequential.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Callable, Tuple, List, Dict, Optional
3 |
4 | from ..data import Dataloader
5 | from ..utils import Interpreter, Tensor, Trainer, convert_to_tensor, serialization, stax
6 |
7 |
8 | class Sequential:
9 | """
10 | Sequential / serial model.
11 | """
12 |
13 | def __init__(self, layers=None):
14 | """
15 | Initializes the model with layers
16 | :param layers: List of layers for the neural network
17 | """
18 | self._layers: List[Tuple[Callable, Callable]] = [] if layers is None else layers
19 | self._epochs: int = 1
20 | self._trained_params: List[Optional[Tuple[Tensor, Tensor]]] = []
21 | self._loss_fn: Optional[Callable[[Tensor, Tensor], Tensor]] = None
22 |         self._optimizer: Optional[Callable] = None
23 | self._metrics: Dict[str, Dict[str, List[float]]] = {"loss": defaultdict(list),
24 | "loss_per_epoch": defaultdict(list)}
25 | self._metrics_fn: Optional[List[Callable]] = []
26 | self._seed: int = 0
27 |
28 | def __add__(self, other):
29 | """
30 |         Overridden method to support adding one model's layers to another's.
31 | :param other: a Sequential model consisting of some layers.
32 | :return: a Sequential model with the combined layers.
33 | """
34 | assert type(other) == Sequential, "Type is not 'Sequential'"
35 | assert other._layers, "Layers not provided"
36 | assert len(other._layers) > 0, "No layers found"
37 | layers = self._layers + other._layers
38 | return Sequential(layers=layers)
39 |
40 | def add(self, other):
41 | """
42 | Another helper method for adding layers into the model.
43 |         :param other: Either a Sequential model or a list of layers.
44 | :return: None if the instance is not as expected in the API.
45 | """
46 | if isinstance(other, Sequential) and len(other._layers) > 0:
47 | self._layers += other._layers
48 | elif isinstance(other, list) and len(other) > 0:
49 | self._layers += other
50 | else:
51 | return None
52 |
53 | def compile(self, loss: Callable, optimizer: Callable, metrics: List[Callable] = None):
54 | """
55 | Compiles the model.
56 | :param loss: the loss function to be used.
57 | :param optimizer: the optimizer to be used.
58 | :param metrics: the metrics to be used.
59 | """
60 | self._loss_fn = loss
61 | self._optimizer = optimizer
62 |         self._metrics_fn = metrics if metrics is not None else []
63 | for metric_fn in self._metrics_fn:
64 | self._metrics[metric_fn.__name__] = defaultdict(list)
65 | self._metrics[metric_fn.__name__ + "_per_epoch"] = defaultdict(list)
66 |
67 | def fit(self, data: Dataloader, epochs: int, seed: int = 0):
68 | """
69 | Trains the model
70 | :param data: Dataloader object containing the dataset.
71 | :param epochs: Number of times the entire dataset is used for training the model.
72 | :param seed: Seed for randomization.
73 | """
74 | assert self._optimizer, "Call .compile() before .fit()"
75 | assert self._loss_fn, "Call .compile() before .fit()"
76 | assert epochs > 0, "Number of epochs must be greater than 0"
77 | self._epochs = epochs
78 | self._seed = seed
79 | self.__dict__ = Trainer(self.__dict__).train(data)
80 |
81 | def predict(self, inputs: Tensor):
82 | """
83 | Uses the trained model for prediction
84 | :param inputs: Inputs to be used for prediction
85 | :return: the outputs computed by the trained model.
86 | """
87 | assert self._trained_params, "Module not yet trained / trained params not found"
88 | _, forward_pass = stax.serial(*self._layers)
89 | return forward_pass(self._trained_params, inputs, mode="predict")
90 |
91 | def get_interpretation(self) -> Interpreter:
92 | """
93 | Fetches the Interpreter object for graphical analysis of the training process.
94 | :return: the Interpreter object containing relevant information of the training results.
95 | """
96 | return Interpreter(epochs=self._epochs, metrics=self._metrics)
97 |
98 | def save(self, file_name: str):
99 | """
100 | Saves the model onto the disk.
101 | By default, it will be saved in the current directory.
102 | :param file_name: File name of the model to be saved (without the file extension)
103 | """
104 | assert self._layers, "Layers not provided"
105 | assert self._loss_fn, "Loss function not provided"
106 | assert self._metrics_fn, "Metric functions not provided"
107 | assert self._optimizer, "Optimizer not provided"
108 | assert self._trained_params, "Model not trained yet..."
109 | serialization.save_module(file_name, layers=self._layers,
110 | loss=self._loss_fn,
111 | metrics=self._metrics_fn,
112 | optimizer=self._optimizer,
113 | params=self._trained_params)
114 |
115 | def load(self, file_name: str):
116 | """
117 | Loads the model from the disk.
118 | By default, it will be loaded from the current directory.
119 | :param file_name: File name of the model to be loaded (without the file extension)
120 | """
121 | deserialized_config = serialization.load_module(file_name)
122 | self._layers = deserialized_config.get("layers")
123 | self.compile(loss=deserialized_config.get("loss"),
124 | optimizer=deserialized_config.get("optimizer"),
125 | metrics=deserialized_config.get("metrics"))
126 | self._trained_params = convert_to_tensor(deserialized_config.get("params"))
127 |
--------------------------------------------------------------------------------
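
Since `__add__` and `add` only concatenate layer lists, a model can be assembled piecewise before `compile` is called. A minimal sketch:

```python
from kerax.layers import Dense, Relu, Sigmoid
from kerax.models import Sequential

backbone = Sequential([Dense(100), Relu])
head = Sequential([Dense(1), Sigmoid])

model = backbone + head           # a new model with all four layers
model2 = Sequential([Dense(100), Relu])
model2.add([Dense(1), Sigmoid])   # in place: appends a list of layers
```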
/kerax/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from jax.experimental.optimizers import OptimizerState
2 | from jax.experimental.optimizers import adagrad as Adagrad
3 | from jax.experimental.optimizers import adam as Adam
4 | from jax.experimental.optimizers import adamax as Adamax
5 | from jax.experimental.optimizers import rmsprop as RMSProp
6 | from jax.experimental.optimizers import sgd as SGD
7 | from jax.experimental.optimizers import sm3 as SM3
8 |
9 | __all__ = [
10 | "Adam",
11 | "Adagrad",
12 | "Adamax",
13 | "OptimizerState",
14 | "RMSProp",
15 | "SGD",
16 | "SM3"
17 | ]
18 |
--------------------------------------------------------------------------------
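
These re-exports follow the functional `jax.experimental.optimizers` API: each factory returns an `(init_fun, update_fun, get_params)` triple rather than a stateful object, which is exactly how `Trainer` unpacks `_optimizer`. A sketch of one manual SGD step:

```python
from kerax.optimizers import SGD
from kerax.utils import jnp

init_fun, update_fun, get_params = SGD(step_size=0.1)

params = jnp.array([1.0, 2.0])
opt_state = init_fun(params)
grads = jnp.array([0.5, 0.5])                # stand-in gradients
opt_state = update_fun(0, grads, opt_state)  # (step index, grads, state)
print(get_params(opt_state))                 # [0.95, 1.95]
```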
/kerax/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import jit, grad, vmap, pmap, random, device_put
3 | from jax.experimental import stax
4 |
5 | from .interpreter import Interpreter
6 | from .serialization import load_module, save_module
7 | from .tensor import Tensor, convert_to_tensor
8 | from .trainer import Trainer
9 |
10 | __all__ = [
11 | "Interpreter",
12 | "Tensor",
13 | "convert_to_tensor",
14 | "load_module",
15 | "save_module",
16 | "Trainer",
17 | "jnp",
18 | "jit",
19 | "random",
20 | "vmap",
21 | "pmap",
22 | "grad",
23 | "device_put",
24 | "stax"
25 | ]
26 |
--------------------------------------------------------------------------------
/kerax/utils/interpreter.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, List
2 |
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | class Interpreter:
7 | """
8 | Interpreter class which is used for analysis of the training results.
9 | """
10 |
11 | def __init__(self, **config):
12 | """
13 | Initializes the class.
14 | :param config: a dictionary consisting of the training results.
15 | """
16 | self._config = config
17 |
18 | def plot_losses(self):
19 | """
20 | Plots the loss curves (both training and validation) in Matplotlib.
21 |         :return: None; the chart is displayed via plt.show().
22 | """
23 | epochs: Iterable[int] = range(1, self._config.get("epochs") + 1)
24 | train_losses: List[float] = self._config.get("metrics").get("loss_per_epoch").get("train")
25 | assert len(train_losses) == len(epochs), "Length of losses and number of epochs do not match"
26 | val_losses: List[float] = self._config.get("metrics").get("loss_per_epoch").get("valid")
27 | assert len(train_losses) == len(val_losses), "Unequal length of the losses"
28 | plt.plot(epochs, train_losses, color="red", label="Training")
29 | plt.plot(epochs, val_losses, color="green", label="Validation")
30 | plt.title("Loss Curve")
31 | plt.xlabel("Epochs")
32 | plt.ylabel("Loss")
33 | plt.legend()
34 | plt.show()
35 |
36 | def plot_accuracy(self):
37 | """
38 | Plots the accuracy curves (both training and validation) in Matplotlib.
39 |         :return: None; the chart is displayed via plt.show().
40 | """
41 | epochs: Iterable[int] = range(1, self._config.get("epochs") + 1)
42 | if "binary_accuracy" in self._config.get("metrics").keys():
43 | train_acc: List[float] = self._config.get("metrics").get("binary_accuracy_per_epoch").get("train")
44 | assert len(train_acc) == len(epochs), "Length of accuracy values and number of epochs do not match"
45 | val_acc: List[float] = self._config.get("metrics").get("binary_accuracy_per_epoch").get("valid")
46 | assert len(train_acc) == len(val_acc), "Unequal length of the accuracy values"
47 | elif "accuracy" in self._config.get("metrics").keys():
48 | train_acc: List[float] = self._config.get("metrics").get("accuracy_per_epoch").get("train")
49 | assert len(train_acc) == len(epochs), "Length of accuracy values and number of epochs do not match"
50 | val_acc: List[float] = self._config.get("metrics").get("accuracy_per_epoch").get("valid")
51 | assert len(train_acc) == len(val_acc), "Unequal length of the accuracy values"
52 | else:
53 | return None
54 | plt.plot(epochs, train_acc, color="red", label="Training")
55 | plt.plot(epochs, val_acc, color="green", label="Validation")
56 | plt.title("Accuracy Curve")
57 | plt.ylim([0.0, 1.05])
58 | plt.xlabel("Epochs")
59 | plt.ylabel("Accuracy")
60 | plt.legend()
61 | plt.show()
62 |
--------------------------------------------------------------------------------
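
`Interpreter` only reads the `epochs` count and the `*_per_epoch` entries of the metrics dictionary, so it can also be driven by hand. A small sketch with synthetic values:

```python
from kerax.utils import Interpreter

metrics = {
    "loss_per_epoch": {"train": [0.9, 0.5, 0.3], "valid": [1.0, 0.6, 0.4]},
    "accuracy": {},  # the presence of this key selects the multi-class branch
    "accuracy_per_epoch": {"train": [0.5, 0.7, 0.9], "valid": [0.4, 0.6, 0.8]},
}
interp = Interpreter(epochs=3, metrics=metrics)
interp.plot_losses()    # training vs. validation loss
interp.plot_accuracy()  # training vs. validation accuracy
```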
/kerax/utils/serialization.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 | import dill
4 | import msgpack
5 |
6 |
7 | def save_module(file_name: str, **config):
8 | """
9 | Serializes the module with the specified file name into MessagePack format.
10 | :param file_name: The name of the file (without the extension).
11 | :param config: Properties of the module to be serialized.
12 | """
13 | serialized_config: Dict[str, Any] = {}
14 | for k, v in config.items():
15 | item_dill: bytes = dill.dumps(v)
16 | item_msgpack: bytes = msgpack.packb(item_dill, use_bin_type=True)
17 | serialized_config[k] = item_msgpack
18 |
19 | with open(f"{file_name}.msgpack", "wb") as f:
20 | serialized_data: bytes = msgpack.packb(serialized_config)
21 | f.write(serialized_data)
22 |
23 |
24 | def load_module(file_name: str) -> Dict[str, Any]:
25 | """
26 | Deserializes the module with the specified file name from MessagePack format.
27 | :param file_name: The name of the file (without the extension).
28 | :return: properties of the deserialized module in a dictionary form.
29 | """
30 | with open(f"{file_name}.msgpack", "rb") as f:
31 | deserialized_data: bytes = f.read()
32 | deserialized_config: Dict[str, Any] = msgpack.unpackb(deserialized_data)
33 |     for k in list(deserialized_config):
34 |         item_dill: bytes = msgpack.unpackb(deserialized_config.pop(k))
35 |         deserialized_config[k.decode("utf-8") if isinstance(k, bytes) else k] = dill.loads(item_dill)  # msgpack may return str or bytes keys
36 | return deserialized_config
37 |
--------------------------------------------------------------------------------
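
Because every value goes through `dill`, most Python objects (functions included) survive the round trip. A minimal sketch against the helpers above:

```python
from kerax.utils.serialization import save_module, load_module

def double(x):
    return 2 * x

save_module("demo", numbers=[1, 2, 3], fn=double)  # writes demo.msgpack to the CWD
config = load_module("demo")
print(config["numbers"])  # [1, 2, 3]
print(config["fn"](21))   # 42
```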
/kerax/utils/tensor.py:
--------------------------------------------------------------------------------
1 | from jax.numpy import ndarray as Tensor
2 |
3 |
4 | def convert_to_tensor(data):
5 | """
6 | Converts the given data into Tensors
7 | :param data: Data to be converted
8 | :return: the data in the form of Tensors.
9 | """
10 | from jax.tree_util import tree_flatten, tree_unflatten
11 | from jax import device_put
12 | flat_data, data_tree_struct = tree_flatten(data)
13 | for i, item in enumerate(flat_data):
14 | flat_data[i] = device_put(item)
15 | return tree_unflatten(data_tree_struct, flat_data)
16 |
--------------------------------------------------------------------------------
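
`convert_to_tensor` works on arbitrary pytrees: it flattens the structure, pushes every leaf onto the default device, and rebuilds the tree. A small sketch:

```python
import numpy as np
from kerax.utils import convert_to_tensor

params = [(np.ones((2, 3)), np.zeros(3)), ()]  # stax-style (W, b) pairs
tensors = convert_to_tensor(params)
print(type(tensors[0][0]))  # a JAX device array rather than numpy.ndarray
```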
/kerax/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from functools import partial
3 | from typing import Tuple, List, Optional, Dict, Any
4 |
5 | from tqdm import tqdm
6 |
7 | from . import Tensor, jit, grad, device_put, jnp, random, stax
8 | from ..data import Dataloader
9 | from ..optimizers import OptimizerState
10 |
11 |
12 | class Trainer:
13 | """
14 | Trainer utility class for training the models.
15 | """
16 |
17 | def __init__(self, config: Dict[str, Any]):
18 | """
19 | Initializes the Trainer class with particular configuration.
20 | :param config: Configuration containing the optimizer function and the layers.
21 | """
22 | self.config: Dict[str, Any] = config
23 | self.mode = "train"
24 | self.opt_init, self.opt_update, self.fetch_params = config.get("_optimizer")
25 | self.setup_params, self.forward_pass = stax.serial(*config.get("_layers"))
26 |
27 | def initialize_params(self, input_shape: List[int]):
28 | """
29 | Initializes the network parameters.
30 | If already trained, then it will return the trained parameters.
31 |         :param input_shape: Shape of the inputs for properly initializing the params.
32 | :return: the network parameters.
33 | """
34 | trained_params: List[Optional[Tuple[Tensor, Tensor]]] = self.config.get("_trained_params")
35 | if len(trained_params) > 0:
36 | return trained_params
37 | else:
38 | rng = random.PRNGKey(self.config.get("_seed"))
39 | input_shape[0] = -1
40 | input_shape = tuple(input_shape)
41 | _, params = self.setup_params(rng=rng, input_shape=input_shape)
42 | return params
43 |
44 | def train(self, data: Dataloader):
45 | """
46 | Trains the network
47 | :param data: Dataloader object containing the dataset.
48 |         :return: the configuration of the trained network in the form of a dictionary.
49 | """
50 | network_params = self.initialize_params(list(data.input_shape))
51 | opt_state: OptimizerState = self.opt_init(network_params)
52 | iter_count = itertools.count()
53 | progress_bar = tqdm(iterable=range(self.config.get("_epochs")),
54 | desc="Training model", leave=True)
55 | for epoch in progress_bar:
56 | progress_bar.set_description(desc=f"Epoch {epoch + 1}")
57 | self.mode = "train"
58 | for _ in range(data.num_train_batches):
59 | train_batch = device_put(next(data.train_data))
60 | opt_state = self.step(next(iter_count), opt_state, train_batch)
61 | network_params = self.fetch_params(opt_state)
62 | self.calculate_metrics(network_params, train_batch)
63 | network_params = self.fetch_params(opt_state)
64 | self.mode = "valid"
65 | for _ in range(data.num_val_batches):
66 | valid_batch = device_put(next(data.val_data))
67 | self.calculate_metrics(network_params, valid_batch)
68 | self.calculate_epoch_metrics(data)
69 | progress_bar.set_postfix_str(self.pretty_print_metrics())
70 | progress_bar.refresh()
71 | self.config["_trained_params"] = self.fetch_params(opt_state)
72 | return self.config
73 |
74 | @partial(jit, static_argnums=(0,))
75 | def step(self, i, opt_state, batch):
76 | """
77 | Training step for the optimization process.
78 | :param i: Iteration count
79 | :param opt_state: State of the optimizer
80 | :param batch: Batch of data for the optimizer to work with.
81 |         :return: the updated state of the optimizer.
82 | """
83 | params = self.fetch_params(opt_state)
84 | grads = grad(self.compute_loss)(params, batch)
85 | return self.opt_update(i, grads, opt_state)
86 |
87 | @partial(jit, static_argnums=(0,))
88 | def compute_loss(self, params, batch):
89 | """
90 | Helper function to compute forward pass as well as the loss at every step in the training process.
91 | :param params: Network parameters
92 | :param batch: Batch of data to compute predictions and loss value
93 | :return: the computed loss value
94 | """
95 | inputs, targets = batch
96 | predictions = self.forward_pass(params, inputs, mode=self.mode)
97 | return jit(self.config.get("_loss_fn"))(predictions, targets)
98 |
99 | def calculate_metrics(self, params, batch):
100 | """
101 | Helper function that computes the metrics at every step in the training process.
102 | :param params: Network parameters
103 | :param batch: Batch of data to compute the metrics.
104 | """
105 | inputs, targets = batch
106 | predictions = self.forward_pass(params, inputs, mode=self.mode)
107 | self.config.get("_metrics")["loss"][self.mode].append(self.compute_loss(params, batch))
108 | for metric_fn in self.config.get("_metrics_fn"):
109 | self.config.get("_metrics")[metric_fn.__name__][self.mode].append(jit(metric_fn)(predictions, targets))
110 |
111 | def calculate_epoch_metrics(self, data: Dataloader):
112 | """
113 | Calculates the metrics values (both training and validation) after every epoch of the training process.
114 | :param data: Dataloader object (used to fetch the number of batches)
115 | """
116 | self.config.get("_metrics")["loss_per_epoch"]["train"].append(
117 | jnp.mean(jnp.array(self.config.get("_metrics")["loss"]["train"][-data.num_train_batches:]))
118 | )
119 | self.config.get("_metrics")["loss_per_epoch"]["valid"].append(
120 | jnp.mean(jnp.array(self.config.get("_metrics")["loss"]["valid"][-data.num_val_batches:]))
121 | )
122 | for metric_fn in self.config.get("_metrics_fn"):
123 | self.config.get("_metrics")[metric_fn.__name__ + "_per_epoch"]["train"]\
124 | .append(self.config.get("_metrics")[metric_fn.__name__]["train"][-1])
125 | self.config.get("_metrics")[metric_fn.__name__ + "_per_epoch"]["valid"] \
126 | .append(self.config.get("_metrics")[metric_fn.__name__]["valid"][-1])
127 |
128 | def pretty_print_metrics(self) -> str:
129 | """
130 | Helper function to display the results (loss + metrics) during the training process
131 | :return: a string containing the values of the results.
132 | """
133 | return " :: ".join([f"{metric_type}_{metric_name} : {metric.get(metric_type)[-1]:.3f}"
134 | for metric_name, metric in self.config.get("_metrics").items()
135 | for metric_type in metric.keys() if "epoch" not in metric_name])
136 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas>=1.2.0
2 | tensorflow==2.5.1
3 | tensorflow_datasets==4.2.0
4 | matplotlib>=3.3.3
5 | jax>=0.2.7
6 | jaxlib>=0.1.57
7 | numpy>=1.19.5
8 | dill>=0.3.3
9 | tqdm>=4.55.1
10 | msgpack>=1.0.2
11 | msgpack_python>=0.5.6
12 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | __version__ = None
4 |
5 | setup(
6 | name='kerax',
7 | version=__version__,
8 | packages=find_packages(exclude=["examples"]),
9 |     python_requires="==3.7.*",
10 | install_requires=[
11 | "pandas>=1.2.0",
12 |         "tensorflow==2.5.1",
13 | "tensorflow_datasets==4.2.0",
14 | "matplotlib>=3.3.3",
15 | "jax>=0.2.7",
16 | "jaxlib>=0.1.57",
17 | "numpy>=1.19.5",
18 | "dill>=0.3.3",
19 | "tqdm>=4.55.1",
20 |         "msgpack_python>=0.5.6",
21 | "msgpack>=1.0.2"
22 | ],
23 | url='https://github.com/umangjpatel/kerax',
24 | license='',
25 | author='Umang Patel (umangjpatel)',
26 | author_email='umangpatel1947@gmail.com',
27 | description='Keras-like APIs powered with JAX library',
28 | classifiers=[
29 | "Programming Language :: Python :: 3.7"
30 | ]
31 | )
32 |
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import numpy as np
4 |
5 | from kerax.datasets import binary_tiny_mnist
6 |
7 |
8 | class TestDataloader(TestCase):
9 |
10 | def setUp(self) -> None:
11 | self.data_loader = binary_tiny_mnist.load_dataset(batch_size=200)
12 |
13 | def tearDown(self) -> None:
14 | del self.data_loader
15 |
16 | def test_dataloader_attrs(self):
17 | self.assertEqual(self.data_loader.num_train_batches, 17)
18 | self.assertEqual(self.data_loader.num_val_batches, 5)
19 | self.assertEqual(self.data_loader.input_shape, (-1, 784))
20 | self.assertEqual(self.data_loader.batch_size, 200)
21 |
22 | def test_data_shapes(self):
23 | self.assertEqual(next(self.data_loader.train_data)[0].shape, (200, 784))
24 | self.assertEqual(next(self.data_loader.train_data)[1].shape, (200, 1))
25 | self.assertEqual(next(self.data_loader.val_data)[0].shape, (200, 784))
26 | self.assertEqual(next(self.data_loader.val_data)[1].shape, (200, 1))
27 |
28 | def test_data_items(self):
29 | train_item = next(self.data_loader.train_data)
30 |         val_item = next(self.data_loader.val_data)
31 | train_inputs, train_labels = train_item
32 | val_inputs, val_labels = val_item
33 | self.assertIsInstance(train_item, tuple)
34 | self.assertIsInstance(val_item, tuple)
35 | self.assertIsInstance(train_inputs, np.ndarray)
36 | self.assertIsInstance(train_labels, np.ndarray)
37 | self.assertIsInstance(val_inputs, np.ndarray)
38 | self.assertIsInstance(val_labels, np.ndarray)
39 |
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | from kerax.datasets import mnist, binary_tiny_mnist
2 | from unittest import TestCase
3 |
4 |
5 | class TestDatasets(TestCase):
6 |
7 | def setUp(self) -> None:
8 | pass
9 |
10 | def tearDown(self) -> None:
11 | pass
12 |
13 | def test_binary_mnist(self) -> None:
14 | data = binary_tiny_mnist.load_dataset(batch_size=200)
15 | self.assertEqual(data.batch_size, 200)
16 | self.assertEqual(data.num_train_batches, 17)
17 | self.assertEqual(data.num_val_batches, 5)
18 | inputs, targets = next(data.train_data)
19 | self.assertEqual(inputs.shape, (200, 784))
20 | self.assertEqual(targets.shape, (200, 1))
21 | inputs, targets = next(data.val_data)
22 | self.assertEqual(inputs.shape, (200, 784))
23 | self.assertEqual(targets.shape, (200, 1))
24 | self.assertEqual(data.input_shape, (-1, 784))
25 |
26 | def test_mnist(self) -> None:
27 | data = mnist.load_dataset(batch_size=1000)
28 | self.assertEqual(data.batch_size, 1000)
29 | self.assertEqual(data.num_train_batches, 60)
30 | self.assertEqual(data.num_val_batches, 10)
31 | inputs, targets = next(data.train_data)
32 | self.assertEqual(inputs.shape, (1000, 28, 28, 1))
33 | self.assertEqual(targets.shape, (1000, 10))
34 | inputs, targets = next(data.val_data)
35 | self.assertEqual(inputs.shape, (1000, 28, 28, 1))
36 | self.assertEqual(targets.shape, (1000, 10))
37 | self.assertEqual(data.input_shape, (-1, 28, 28, 1))
38 |
--------------------------------------------------------------------------------
/tests/test_layers.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from kerax.layers import Dense, Relu, Sigmoid, LogSoftmax, Flatten, Dropout
4 | from kerax.utils import stax, random, jnp
5 |
6 |
7 | class TestLayers(TestCase):
8 |
9 | def setUp(self) -> None:
10 | pass
11 |
12 | def tearDown(self) -> None:
13 | pass
14 |
15 | @staticmethod
16 | def get_dummy_network(layers, inputs):
17 | init_params, forward_pass = stax.serial(*layers)
18 | input_shape = tuple([-1] + list(inputs.shape)[1:])
19 | _, params = init_params(random.PRNGKey(10), input_shape)
20 | return params, forward_pass
21 |
22 | @staticmethod
23 | def get_dummy_inputs(input_shape):
24 | return random.normal(key=random.PRNGKey(42), shape=input_shape)
25 |
26 | def test_dense_layer(self):
27 | inputs = self.get_dummy_inputs(input_shape=(200, 784))
28 | params, forward_pass = self.get_dummy_network([Dense(10)], inputs)
29 | outputs = forward_pass(params, inputs)
30 | self.assertEqual(outputs.shape, (200, 10))
31 |
32 | def test_relu_layer(self):
33 | inputs = self.get_dummy_inputs(input_shape=(200, 784))
34 | params, forward_pass = self.get_dummy_network([Relu], inputs)
35 | outputs = forward_pass(params, inputs)
36 | self.assertEqual(inputs.shape, outputs.shape)
37 |
38 | def test_sigmoid_layer(self):
39 | inputs = self.get_dummy_inputs(input_shape=(200, 784))
40 | params, forward_pass = self.get_dummy_network([Sigmoid], inputs)
41 | outputs = forward_pass(params, inputs)
42 | self.assertEqual(inputs.shape, outputs.shape)
43 |
44 | def test_flatten_layer_already_flat(self):
45 | inputs = self.get_dummy_inputs(input_shape=(200, 784))
46 | params, forward_pass = self.get_dummy_network([Flatten], inputs)
47 | outputs = forward_pass(params, inputs)
48 | self.assertEqual(inputs.shape, outputs.shape)
49 |
50 | def test_flatten_layer_not_already_flat(self):
51 | inputs = self.get_dummy_inputs(input_shape=(200, 28, 28, 1))
52 | params, forward_pass = self.get_dummy_network([Flatten], inputs)
53 | outputs = forward_pass(params, inputs)
54 | self.assertEqual(outputs.shape, (200, 784))
55 |
56 | def test_log_softmax_layer(self):
57 | inputs = self.get_dummy_inputs(input_shape=(200, 784))
58 | params, forward_pass = self.get_dummy_network([LogSoftmax], inputs)
59 | outputs = forward_pass(params, inputs)
60 | self.assertEqual(outputs.shape, (200, 784))
61 |
62 | def test_dropout_layer_mode_train(self):
63 | inputs = self.get_dummy_inputs(input_shape=(1, 5))
64 | params, forward_pass = self.get_dummy_network([Dropout(rate=0.0)], inputs)
65 | outputs = forward_pass(params, inputs)
66 |         self.assertTrue(jnp.all(outputs == 0.0))  # keep-rate 0.0 drops every unit in train mode
67 |
68 | def test_dropout_layer_mode_predict(self):
69 | inputs = self.get_dummy_inputs(input_shape=(1, 5))
70 | params, forward_pass = self.get_dummy_network([Dropout(rate=0.0)], inputs)
71 | outputs = forward_pass(params, inputs, mode="predict")
72 |         self.assertTrue(jnp.all(outputs == inputs))  # dropout is the identity at predict time
73 |
74 |
--------------------------------------------------------------------------------
/tests/test_losses.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from jax.interpreters.xla import _DeviceArray
4 |
5 | from kerax.losses import BCELoss, CCELoss
6 | from kerax.utils import jnp, random
7 |
8 |
9 | class TestLossFunctions(TestCase):
10 |
11 | def setUp(self) -> None:
12 | self.keys = random.split(random.PRNGKey(42), 4)
13 |
14 | def testBCELoss(self):
15 | k1, k2 = self.keys[0], self.keys[1]
16 | binary_predictions = random.uniform(key=k1, shape=(100, 1))
17 | binary_labels = random.permutation(key=k2,
18 | x=jnp.concatenate((jnp.zeros(shape=(50,)), jnp.ones(shape=(50,)))))
19 | loss = BCELoss(predictions=binary_predictions, targets=binary_labels)
20 | self.assertIsInstance(loss, _DeviceArray)
21 |
22 | def testCCELoss(self):
23 | k3, k4 = self.keys[2], self.keys[3]
24 | softmax_predictions = random.randint(key=k3, shape=(100, 1), minval=0, maxval=9)
25 | softmax_labels = random.randint(key=k4, shape=(100, 1), minval=0, maxval=9)
26 | loss = CCELoss(predictions=softmax_predictions, targets=softmax_labels)
27 | self.assertIsInstance(loss, _DeviceArray)
28 |
29 | def tearDown(self) -> None:
30 | del self.keys
31 |
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from jax.interpreters.xla import _DeviceArray
4 | from kerax.utils import jnp, random
5 | from kerax.metrics import binary_accuracy, accuracy
6 |
7 |
8 | class TestMetricsFunctions(TestCase):
9 |
10 | def setUp(self) -> None:
11 | self.keys = random.split(random.PRNGKey(42), 4)
12 |
13 | def tearDown(self) -> None:
14 | del self.keys
15 |
16 | def test_binary_accuracy(self) -> None:
17 | k1, k2 = self.keys[0], self.keys[1]
18 | predictions = random.uniform(key=k1, shape=(100, 1), minval=0, maxval=1)
19 | labels = random.randint(key=k2, shape=(100, 1), minval=0, maxval=1)
20 | acc = binary_accuracy(predictions=predictions, targets=labels, acc_thresh=0.5)
21 | self.assertIsInstance(acc, _DeviceArray)
22 |
23 | def test_accuracy(self) -> None:
24 | k1, k2 = self.keys[2], self.keys[3]
25 | predictions = random.randint(key=k1, shape=(100, 1), minval=0, maxval=9)
26 | labels = random.randint(key=k2, shape=(100, 1), minval=0, maxval=9)
27 | acc = accuracy(predictions=predictions, targets=labels)
28 | self.assertIsInstance(acc, _DeviceArray)
29 |
30 |
--------------------------------------------------------------------------------
/tests/test_sequential.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from kerax.datasets import binary_tiny_mnist
4 | from kerax.layers import Dense, Sigmoid
5 | from kerax.losses import BCELoss
6 | from kerax.metrics import binary_accuracy
7 | from kerax.models import Sequential
8 | from kerax.optimizers import SGD
9 | from kerax.utils import Interpreter, device_put
10 |
11 |
12 | class TestSequential(TestCase):
13 |
14 | def setUp(self) -> None:
15 | self.data = binary_tiny_mnist.load_dataset(batch_size=200)
16 | self.binary_model = Sequential(layers=[Dense(1), Sigmoid])
17 |
18 | def tearDown(self) -> None:
19 | del self.data
20 | del self.binary_model
21 |
22 | def test_attrs_after_init(self) -> None:
23 | self.assertNotEqual(self.binary_model._layers, None)
24 | self.assertEqual(self.binary_model._epochs, 1)
25 | self.assertEqual(self.binary_model._trained_params, [])
26 | self.assertEqual(self.binary_model._loss_fn, None)
27 | self.assertEqual(self.binary_model._optimizer, None)
28 | self.assertEqual(", ".join(self.binary_model._metrics.keys()), "loss, loss_per_epoch")
29 | self.assertEqual(self.binary_model._metrics_fn, [])
30 | self.assertEqual(self.binary_model._seed, 0)
31 |
32 | def test_add_layers(self) -> None:
33 | new_binary_model = Sequential([Dense(100)]) + self.binary_model
34 | self.assertEqual(len(new_binary_model._layers), 3)
35 | another_binary_model = Sequential([Dense(100)])
36 | another_binary_model.add(self.binary_model)
37 | self.assertEqual(len(another_binary_model._layers), 3)
38 |         one_more_binary_model = Sequential([Dense(100)])
39 | one_more_binary_model.add([Dense(10), Sigmoid])
40 | self.assertEqual(len(one_more_binary_model._layers), 3)
41 | self.assertEqual(one_more_binary_model.add(1), None)
42 |
43 | def test_compile(self) -> None:
44 | loss_fn, opt_fn, metrics_fn = BCELoss, SGD(step_size=0.001), [binary_accuracy]
45 | self.binary_model.compile(loss=loss_fn, optimizer=opt_fn, metrics=metrics_fn)
46 | self.assertEqual(self.binary_model._loss_fn, loss_fn)
47 | self.assertEqual(self.binary_model._optimizer, opt_fn)
48 | self.assertEqual(self.binary_model._metrics_fn, metrics_fn)
49 | self.assertEqual(", ".join(self.binary_model._metrics.keys()),
50 | "loss, loss_per_epoch, binary_accuracy, binary_accuracy_per_epoch")
51 |
52 | def test_fit(self) -> None:
53 | self.binary_model.compile(loss=BCELoss, optimizer=SGD(step_size=0.01), metrics=[binary_accuracy])
54 | self.binary_model.fit(data=self.data, epochs=1)
55 | self.assertNotEqual(self.binary_model._trained_params, None)
56 | self.assertEqual(len(self.binary_model._metrics.keys()), 4)
57 |
58 | def test_predict(self) -> None:
59 | self.binary_model.compile(loss=BCELoss, optimizer=SGD(step_size=0.01), metrics=[binary_accuracy])
60 | self.binary_model.fit(data=self.data, epochs=1)
61 | inputs = device_put(next(self.data.train_data)[0])
62 | self.assertEqual(inputs.shape, (200, 784))
63 | outputs = self.binary_model.predict(inputs)
64 | self.assertEqual(outputs.shape, (200, 1))
65 |
66 | def test_save_and_load_and_train(self) -> None:
67 | from pathlib import Path
68 | self.binary_model.compile(loss=BCELoss, optimizer=SGD(step_size=0.01), metrics=[binary_accuracy])
69 | self.binary_model.fit(data=self.data, epochs=1)
70 | self.binary_model.save("dummy_binary_model")
71 | path = Path(__file__).parent / "dummy_binary_model.msgpack"
72 | self.assertTrue(path.exists())
73 | loaded_model = Sequential()
74 | loaded_model.load("dummy_binary_model")
75 | self.assertEqual(len(loaded_model._layers), 2)
76 | self.assertNotEqual(loaded_model._optimizer, None)
77 | self.assertEqual(len(loaded_model._trained_params), 2)
78 | self.assertEqual(len(loaded_model._metrics_fn), 1)
79 | self.assertNotEqual(loaded_model._loss_fn, None)
80 | model = Sequential()
81 | model.load("dummy_binary_model")
82 | model.fit(data=self.data, epochs=1)
83 |
84 | def test_get_interpreter(self) -> None:
85 | self.binary_model.compile(loss=BCELoss, optimizer=SGD(step_size=0.01), metrics=[binary_accuracy])
86 | self.binary_model.fit(data=self.data, epochs=1)
87 | interp = self.binary_model.get_interpretation()
88 | self.assertIsInstance(interp, Interpreter)
89 | self.assertEqual(interp._config.get("epochs"), 1)
90 | self.assertEqual(len(interp._config.get("metrics").keys()), 4)
91 |
--------------------------------------------------------------------------------