├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── .idea ├── .gitignore ├── dictionaries │ └── umang.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── mypy.xml └── vcs.xml ├── LICENSE ├── README.md ├── assets ├── logo.png ├── quickstart_code_acc_curves.png └── quickstart_code_loss_curves.png ├── examples ├── mnist_classifier.py └── tiny_mnist_binary_classifier.py ├── kerax ├── __init__.py ├── data │ └── __init__.py ├── datasets │ ├── __init__.py │ ├── binary_tiny_mnist │ │ ├── __init__.py │ │ ├── test.csv │ │ └── train.csv │ └── mnist │ │ └── __init__.py ├── layers │ ├── __init__.py │ ├── activations.py │ └── layers.py ├── losses │ └── __init__.py ├── metrics │ └── __init__.py ├── models │ ├── __init__.py │ └── sequential.py ├── optimizers │ └── __init__.py └── utils │ ├── __init__.py │ ├── interpreter.py │ ├── serialization.py │ ├── tensor.py │ └── trainer.py ├── requirements.txt ├── setup.py └── tests ├── test_data.py ├── test_datasets.py ├── test_layers.py ├── test_losses.py ├── test_metrics.py └── test_sequential.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.mypy_cache/3.7/@plugins_snapshot.json 2 | /.mypy_cache/3.7/__future__.data.json 3 | /.mypy_cache/3.7/__future__.meta.json 4 | /.mypy_cache/3.7/_ast.data.json 5 | /.mypy_cache/3.7/_ast.meta.json 6 | /.mypy_cache/3.7/_importlib_modulespec.data.json 7 | /.mypy_cache/3.7/_importlib_modulespec.meta.json 8 | /.mypy_cache/3.7/_warnings.data.json 9 | /.mypy_cache/3.7/_warnings.meta.json 10 | /.mypy_cache/3.7/_weakref.data.json 11 | /.mypy_cache/3.7/_weakref.meta.json 12 | /.mypy_cache/3.7/_weakrefset.data.json 13 | /.mypy_cache/3.7/_weakrefset.meta.json 14 | /.mypy_cache/3.7/abc.data.json 15 | /.mypy_cache/3.7/abc.meta.json 16 | /.mypy_cache/3.7/array.data.json 17 | /.mypy_cache/3.7/array.meta.json 18 | /.mypy_cache/3.7/ast.data.json 19 | /.mypy_cache/3.7/ast.meta.json 20 | /.mypy_cache/3.7/builtins.data.json 21 | /.mypy_cache/3.7/builtins.meta.json 22 | /.mypy_cache/3.7/codecs.data.json 23 | /.mypy_cache/3.7/codecs.meta.json 24 | /.mypy_cache/3.7/contextlib.data.json 25 | /.mypy_cache/3.7/contextlib.meta.json 26 | /.mypy_cache/3.7/copy.data.json 27 | /.mypy_cache/3.7/copy.meta.json 28 | /.mypy_cache/3.7/difflib.data.json 29 | /.mypy_cache/3.7/difflib.meta.json 30 | /.mypy_cache/3.7/inspect.data.json 31 | /.mypy_cache/3.7/inspect.meta.json 32 | /.mypy_cache/3.7/io.data.json 33 | /.mypy_cache/3.7/io.meta.json 34 | /.mypy_cache/3.7/itertools.data.json 35 | /.mypy_cache/3.7/itertools.meta.json 36 | /.mypy_cache/3.7/math.data.json 37 | /.mypy_cache/3.7/math.meta.json 38 | /.mypy_cache/3.7/mmap.data.json 39 | /.mypy_cache/3.7/mmap.meta.json 40 | /.mypy_cache/3.7/pathlib.data.json 41 | /.mypy_cache/3.7/pathlib.meta.json 42 | /.mypy_cache/3.7/pickle.data.json 43 | /.mypy_cache/3.7/pickle.meta.json 44 | /.mypy_cache/3.7/posix.data.json 45 | /.mypy_cache/3.7/posix.meta.json 46 | /.mypy_cache/3.7/queue.data.json 47 | /.mypy_cache/3.7/queue.meta.json 48 | /.mypy_cache/3.7/shutil.data.json 49 | /.mypy_cache/3.7/shutil.meta.json 50 | /.mypy_cache/3.7/struct.data.json 51 | /.mypy_cache/3.7/struct.meta.json 52 | /.mypy_cache/3.7/sys.data.json 53 | /.mypy_cache/3.7/sys.meta.json 54 | /.mypy_cache/3.7/tarfile.data.json 55 | /.mypy_cache/3.7/tarfile.meta.json 56 | /.mypy_cache/3.7/tempfile.data.json 57 | /.mypy_cache/3.7/tempfile.meta.json 58 | /.mypy_cache/3.7/test.data.json 59 | /.mypy_cache/3.7/test.meta.json 60 | /.mypy_cache/3.7/threading.data.json 61 | /.mypy_cache/3.7/threading.meta.json 62 | /.mypy_cache/3.7/traceback.data.json 63 | /.mypy_cache/3.7/traceback.meta.json 64 | /.mypy_cache/3.7/types.data.json 65 | /.mypy_cache/3.7/types.meta.json 66 | /.mypy_cache/3.7/typing.data.json 67 | /.mypy_cache/3.7/typing.meta.json 68 | /.mypy_cache/3.7/typing_extensions.data.json 69 | /.mypy_cache/3.7/typing_extensions.meta.json 70 | /.mypy_cache/3.7/warnings.data.json 71 | /.mypy_cache/3.7/warnings.meta.json 72 | /.mypy_cache/3.7/weakref.data.json 73 | /.mypy_cache/3.7/weakref.meta.json 74 | /.mypy_cache/ 75 | /.idea/ 76 | /.idea/dictionaries/ 77 | /datasets/mnist/ 78 | .DS_Store 79 | /venv/ 80 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Default ignored files 3 | /workspace.xml 4 | /dictionaries/ 5 | -------------------------------------------------------------------------------- 
/.idea/dictionaries/umang.xml: --------------------------------------------------------------------------------
1 | <component name="ProjectDictionaryState">
2 |   <dictionary name="umang">
3 |     <words>
4 |       <w>adagrad</w>
5 |       <w>bfloat</w>
6 |       <w>conv</w>
7 |       <w>crossentropy</w>
8 |       <w>dataset</w>
9 |       <w>datasets</w>
10 |       <w>dtype</w>
11 |       <w>fastai</w>
12 |       <w>glorot</w>
13 |       <w>jaxlib</w>
14 |       <w>matplotlib</w>
15 |       <w>mish</w>
16 |       <w>mnist</w>
17 |       <w>nchw</w>
18 |       <w>ndarray</w>
19 |       <w>ndarrays</w>
20 |       <w>nhwc</w>
21 |       <w>npscalar</w>
22 |       <w>pytree</w>
23 |       <w>relu</w>
24 |       <w>rmsprop</w>
25 |       <w>softmax</w>
26 |       <w>softplus</w>
27 |       <w>subkey</w>
28 |       <w>tanh</w>
29 |       <w>tfds</w>
30 |       <w>tqdm</w>
31 |     </words>
32 |   </dictionary>
33 | </component>
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml: --------------------------------------------------------------------------------
1 | 
2 | 
3 | 
7 | 
--------------------------------------------------------------------------------
/.idea/misc.xml: --------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
7 | 
--------------------------------------------------------------------------------
/.idea/modules.xml: --------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | 
6 | 
7 | 
8 | 
--------------------------------------------------------------------------------
/.idea/mypy.xml: --------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
6 | 
--------------------------------------------------------------------------------
/.idea/vcs.xml: --------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/LICENSE: --------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 | 
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 | 
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 | 
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 | 
24 | For more information, please refer to <https://unlicense.org>
25 | 
--------------------------------------------------------------------------------
/README.md: --------------------------------------------------------------------------------
1 | 
2 | ![logo](assets/logo.png)
3 | 
4 | 
5 | # Kerax
6 | 
7 | Keras-like APIs for the [JAX](https://github.com/google/jax) library.
8 | 
9 | ## Features
10 | 
11 | * Enables high-performance machine learning research.
12 | * Built-in support for popular optimization algorithms and activation functions.
13 | * Runs seamlessly on CPU, GPU and even TPU, with no manual configuration required.
14 | 
15 | ## Quickstart
16 | 
17 | ### Code
18 | 
19 | ```python3
20 | from kerax.datasets import binary_tiny_mnist
21 | from kerax.layers import Dense, Relu, Sigmoid
22 | from kerax.losses import BCELoss
23 | from kerax.metrics import binary_accuracy
24 | from kerax.models import Sequential
25 | from kerax.optimizers import SGD
26 | 
27 | data = binary_tiny_mnist.load_dataset(batch_size=200)
28 | model = Sequential([Dense(100), Relu, Dense(1), Sigmoid])
29 | model.compile(loss=BCELoss, optimizer=SGD(step_size=0.003), metrics=[binary_accuracy])
30 | model.fit(data=data, epochs=10)
31 | model.save(file_name="model")
32 | 
33 | interp = model.get_interpretation()
34 | interp.plot_losses()
35 | interp.plot_accuracy()
36 | ```
37 | ### Output
38 | 
39 | ```terminal
40 | WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
41 | Epoch 10: 100%|██████████| 10/10 [00:02<00:00, 3.82it/s, train_loss : 0.192 :: valid_loss : 0.202 :: train_binary_accuracy : 1.000 :: valid_binary_accuracy : 1.000]
42 | 
43 | Process finished with exit code 0
44 | ```
45 | 
46 | ![Quickstart code Loss Curves](assets/quickstart_code_loss_curves.png "Loss Curves")
47 | ![Quickstart code Accuracy Curves](assets/quickstart_code_acc_curves.png "Accuracy Curves")
48 | 
49 | ## Documentation (Coming soon...)
50 | 
51 | ## Developer's Notes
52 | 
53 | This project is developed and maintained by [Umang Patel](https://github.com/umangjpatel).
--------------------------------------------------------------------------------
/assets/logo.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/umangjpatel/kerax/21e1798643a6382a1d7960db4c5f3a22fa19a28a/assets/logo.png
--------------------------------------------------------------------------------
/assets/quickstart_code_acc_curves.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/umangjpatel/kerax/21e1798643a6382a1d7960db4c5f3a22fa19a28a/assets/quickstart_code_acc_curves.png
--------------------------------------------------------------------------------
/assets/quickstart_code_loss_curves.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/umangjpatel/kerax/21e1798643a6382a1d7960db4c5f3a22fa19a28a/assets/quickstart_code_loss_curves.png
--------------------------------------------------------------------------------
/examples/mnist_classifier.py: --------------------------------------------------------------------------------
1 | from kerax.datasets import mnist
2 | from kerax.layers import Flatten, Dense, Relu, LogSoftmax
3 | from kerax.losses import CCELoss
4 | from kerax.metrics import accuracy
5 | from kerax.models import Sequential
6 | from kerax.optimizers import RMSProp
7 | 
8 | data = mnist.load_dataset(batch_size=1024)
9 | model = Sequential([Flatten, Dense(100), Relu, Dense(10), LogSoftmax])
10 | model.compile(loss=CCELoss, optimizer=RMSProp(step_size=0.001), metrics=[accuracy])
11 | model.fit(data, epochs=10)
12 | model.save("tfds_mnist_v1")
13 | interp = model.get_interpretation()
14 | interp.plot_losses()
15 | interp.plot_accuracy()
16 | 
--------------------------------------------------------------------------------
/examples/tiny_mnist_binary_classifier.py: --------------------------------------------------------------------------------
1 | from kerax.datasets import binary_tiny_mnist
2 | from kerax.layers import Dense, Relu, Sigmoid
3 | from kerax.losses import BCELoss
4 | from kerax.metrics import binary_accuracy
5 | from kerax.models import Sequential
6 | from kerax.optimizers import SGD
7 | 
8 | data = binary_tiny_mnist.load_dataset(batch_size=200)
9 | model = Sequential([Dense(100), Relu, Dense(1), Sigmoid])
10 | model.compile(loss=BCELoss, optimizer=SGD(step_size=0.003), metrics=[binary_accuracy])
11 | model.fit(data=data, epochs=10)
12 | model.save(file_name="tiny_mnist_binary_classifier_v1")
13 | interp = model.get_interpretation()
14 | interp.plot_losses()
15 | interp.plot_accuracy()
16 | 
17 | new_model = Sequential()
18 | new_model.load(file_name="tiny_mnist_binary_classifier_v1")
19 | # model already compiled when loaded from serialized file
20 | new_model.fit(data=data, epochs=50)
21 | new_model.save(file_name="tiny_mnist_binary_classifier_v2")
22 | interp = new_model.get_interpretation()
23 | interp.plot_losses()
24 | interp.plot_accuracy()
25 | 
--------------------------------------------------------------------------------
/kerax/__init__.py: --------------------------------------------------------------------------------
1 | from . import data
2 | from . import datasets
3 | from . import layers
4 | from . import losses
5 | from . import metrics
6 | from . import models
7 | from . import optimizers
8 | from . import utils
9 | 
--------------------------------------------------------------------------------
/kerax/data/__init__.py: --------------------------------------------------------------------------------
1 | from typing import Iterator, Tuple
2 | 
3 | import numpy as np
4 | 
5 | 
6 | class Dataloader:
7 |     """
8 |     Dataloader class is a helper class.
9 |     Assists in iterating batches of data during the training process.
10 |     """
11 | 
12 |     def __init__(self, train_data: Iterator[Tuple[np.ndarray, np.ndarray]],
13 |                  val_data: Iterator[Tuple[np.ndarray, np.ndarray]],
14 |                  input_shape: Tuple[int, ...], batch_size: int,
15 |                  num_train_batches: int, num_val_batches: int):
16 |         """
17 |         Initializes the Dataloader class.
18 |         :param train_data: Iterator containing training data in the form of (inputs, labels) tuples.
19 |         :param val_data: Iterator containing validation data in the form of (inputs, labels) tuples.
20 |         :param input_shape: Input shape to initialize the parameters. -1 to be used for expressing batch dimensions.
21 |         :param batch_size: Number of examples to be included in a single batch.
22 |         :param num_train_batches: Number of batches of training data.
23 |         :param num_val_batches: Number of batches of validation data.
24 |         """
25 |         assert train_data is not None, "Training data is empty"
26 |         assert val_data is not None, "Validation data is empty"
27 |         assert input_shape is not None, "Input shape not passed"
28 |         assert batch_size > 0, "Invalid batch size passed"
29 |         assert num_train_batches is not None, "Number of training batches is not passed"
30 |         assert num_val_batches is not None, "Number of validation batches is not passed"
31 | 
32 |         self.train_data = train_data
33 |         self.val_data = val_data
34 |         self.input_shape = input_shape
35 |         self.batch_size = batch_size
36 |         self.num_train_batches = num_train_batches
37 |         self.num_val_batches = num_val_batches
38 | 
39 | 
40 | __all__ = [
41 |     "Dataloader"
42 | ]
--------------------------------------------------------------------------------
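Usage note (an illustration, not a file in the repository): `Dataloader` is a plain container around endlessly repeating batch iterators, so one can be assembled by hand from in-memory NumPy arrays. A minimal sketch; the shapes and the `repeating_batches` helper below are hypothetical:

```python
import numpy as np

from kerax.data import Dataloader


def repeating_batches(x, y, batch_size):
    # Cycle over the arrays forever, yielding (inputs, labels) tuples,
    # which is the iterator contract Dataloader expects.
    num_batches = -(-x.shape[0] // batch_size)  # ceiling division
    while True:
        for i in range(num_batches):
            s = slice(i * batch_size, (i + 1) * batch_size)
            yield x[s], y[s]


x_train, y_train = np.zeros((800, 784)), np.zeros((800, 1))
x_val, y_val = np.zeros((200, 784)), np.zeros((200, 1))

data = Dataloader(train_data=repeating_batches(x_train, y_train, 100),
                  val_data=repeating_batches(x_val, y_val, 100),
                  input_shape=(-1, 784), batch_size=100,
                  num_train_batches=8, num_val_batches=2)
```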
/kerax/datasets/__init__.py: --------------------------------------------------------------------------------
1 | from . import binary_tiny_mnist, mnist
2 | 
3 | __all__ = [
4 |     "binary_tiny_mnist",
5 |     "mnist"
6 | ]
7 | 
--------------------------------------------------------------------------------
/kerax/datasets/binary_tiny_mnist/__init__.py: --------------------------------------------------------------------------------
1 | from ...data import Dataloader as __Dataloader__
2 | from typing import Tuple
3 | 
4 | 
5 | def load_dataset(batch_size: int) -> __Dataloader__:
6 |     """
7 |     Loads a reduced version of the MNIST dataset, containing only the digits 0 and 1.
8 |     :param batch_size: Number of examples to be included in a single batch.
9 |     :return: Dataloader object consisting of training and validation data.
10 |     """
11 |     import numpy as np
12 |     from ...data import Dataloader
13 |     import pandas as pd
14 |     from pathlib import Path
15 | 
16 |     def compute_dataset_info(x, y) -> Tuple[int, ...]:
17 |         """
18 |         Computes the number of examples and number of batches for a given dataset.
19 |         :param x: Inputs
20 |         :param y: Targets
21 |         :return: a tuple consisting of total number of examples and number of batches.
22 |         """
23 |         assert x.shape[0] == y.shape[0], "Numbers of inputs and targets do not match..."
24 |         num_examples = x.shape[0]
25 |         num_complete_batches, leftover = divmod(num_examples, batch_size)
26 |         num_batches = num_complete_batches + bool(leftover)
27 |         return num_examples, num_batches
28 | 
29 |     def data_stream(x, y, data_info):
30 |         """
31 |         Yields an iterator for streaming the dataset into respective batches efficiently.
32 |         :param x: Inputs
33 |         :param y: Targets
34 |         :param data_info: a tuple consisting of total number of examples and number of batches.
35 |         :return: a tuple consisting of a batch of inputs and targets for a dataset.
36 |         """
37 |         num_examples, num_batches = data_info
38 |         rng = np.random.RandomState(0)
39 |         while True:
40 |             perm = rng.permutation(num_examples)
41 |             for i in range(num_batches):
42 |                 batch_idx = perm[i * batch_size:(i + 1) * batch_size]
43 |                 yield x[batch_idx], y[batch_idx]
44 | 
45 |     def load() -> __Dataloader__:
46 |         """
47 |         Loads the dataset.
48 |         :return: a Dataloader object consisting of the training and validation data.
49 | """ 50 | path: Path = Path(__file__).parent 51 | dataset: pd.DataFrame = pd.read_csv(path / "train.csv", header=None) 52 | 53 | # 80% training, 20% validation 54 | train_data: pd.DataFrame = dataset.sample(frac=0.8, random_state=0) 55 | val_data: pd.DataFrame = dataset.drop(train_data.index) 56 | 57 | train_labels: np.ndarray = np.expand_dims(train_data[0].values, axis=1) 58 | train_images: np.ndarray = train_data.iloc[:, 1:].values / 255.0 59 | 60 | val_labels: np.ndarray = np.expand_dims(val_data[0].values, axis=1) 61 | val_images: np.ndarray = val_data.iloc[:, 1:].values / 255.0 62 | 63 | train_data_info = compute_dataset_info(train_images, train_labels) 64 | val_data_info = compute_dataset_info(val_images, val_labels) 65 | 66 | train_gen = data_stream(train_images, train_labels, train_data_info) 67 | val_gen = data_stream(val_images, val_labels, val_data_info) 68 | 69 | input_shape = tuple([-1] + list(train_images.shape)[1:]) 70 | 71 | return Dataloader(train_data=train_gen, val_data=val_gen, 72 | batch_size=batch_size, input_shape=input_shape, 73 | num_train_batches=train_data_info[1], 74 | num_val_batches=val_data_info[1]) 75 | 76 | return load() 77 | -------------------------------------------------------------------------------- /kerax/datasets/mnist/__init__.py: -------------------------------------------------------------------------------- 1 | def load_dataset(batch_size: int): 2 | """ 3 | Loads the complete MNIST dataset (training + validation) 4 | Once the dataset is downloaded, it won't be downloaded again. So, relax... 5 | :param batch_size: Number of examples in a batch 6 | :return: a Dataloader object consisting of the dataset 7 | """ 8 | import asyncio 9 | import tensorflow as tf 10 | from pathlib import Path 11 | import tensorflow_datasets as tfds 12 | import math 13 | 14 | from ...data import Dataloader 15 | 16 | async def tfds_load_data() -> Dataloader: 17 | """ 18 | Loads the dataset using the TensorFlow Datasets API 19 | :return: a Dataloader object consisting of the dataset. 
20 | """ 21 | assert batch_size > 0, "Batch size must be greater than 0" 22 | current_path = Path(__file__).parent 23 | ds, info = tfds.load(name="mnist", split=["train", "test"], as_supervised=True, with_info=True, 24 | shuffle_files=True, data_dir=current_path, batch_size=batch_size) 25 | train_ds, valid_ds = ds 26 | train_ds = train_ds.map(lambda x, y: (tf.divide(tf.cast(x, dtype=tf.float32), 255.0), tf.one_hot(y, depth=10))) 27 | valid_ds = valid_ds.map(lambda x, y: (tf.divide(tf.cast(x, dtype=tf.float32), 255.0), tf.one_hot(y, depth=10))) 28 | train_ds, valid_ds = train_ds.cache().repeat(), valid_ds.cache().repeat() 29 | input_shape = tuple([-1] + list(info.features["image"].shape)) 30 | num_train_batches = math.ceil(info.splits["train"].num_examples / batch_size) 31 | num_val_batches = math.ceil(info.splits["test"].num_examples / batch_size) 32 | return Dataloader( 33 | train_data=iter(tfds.as_numpy(train_ds)), val_data=iter(tfds.as_numpy(valid_ds)), 34 | input_shape=input_shape, batch_size=batch_size, 35 | num_train_batches=num_train_batches, num_val_batches=num_val_batches 36 | ) 37 | 38 | return asyncio.run(tfds_load_data()) 39 | -------------------------------------------------------------------------------- /kerax/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import Relu, Sigmoid, LogSoftmax 2 | from .layers import Dense, Dropout, Flatten 3 | 4 | __all__ = [ 5 | "Relu", 6 | "Sigmoid", 7 | "LogSoftmax", 8 | "Dense", 9 | "Flatten", 10 | "Dropout" 11 | ] 12 | -------------------------------------------------------------------------------- /kerax/layers/activations.py: -------------------------------------------------------------------------------- 1 | from jax.experimental.stax import Sigmoid, Relu, LogSoftmax 2 | -------------------------------------------------------------------------------- /kerax/layers/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from jax.experimental.stax import Dense, Flatten 4 | from ..utils import jnp, random 5 | 6 | from ..utils import Tensor 7 | 8 | 9 | def Dropout(rate: Union[Tensor, float]): 10 | """ 11 | Implementation of the Dropout layer. 12 | When using the apply_fun, you can pass a mode kwarg. 13 | This is helpful when you're using the network for validation / prediction. 14 | :param rate: Probability / Rate at which we wish to 'knock off' the neurons. 15 | :return: 16 | """ 17 | 18 | def init_fun(rng, input_shape): 19 | """ 20 | Initializes the function. This layer doesn't perform any parameter initialization. 21 | :param rng: a PRNG key for randomization. 22 | :param input_shape: Shape of the inputs received from the previous layer. 23 | :return: a tuple of the output shape and initialized params. Since no params involved, sent an empty tuple. 24 | """ 25 | return input_shape, () 26 | 27 | def apply_fun(params, inputs, **kwargs): 28 | """ 29 | Performs computation of the layer 30 | :param params: Parameters of the layer 31 | :param inputs: Inputs for the layer 32 | :param kwargs: Keyword arguments for additional info while computing the inputs 33 | :return: the computed outputs. 
34 | """ 35 | mode = kwargs.get('mode', 'train') 36 | rng = random.PRNGKey(seed=0) 37 | if mode == 'train': 38 | keep = random.bernoulli(rng, rate, inputs.shape) 39 | return jnp.where(keep, inputs / rate, 0) 40 | else: 41 | return inputs 42 | 43 | return init_fun, apply_fun 44 | -------------------------------------------------------------------------------- /kerax/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import Tensor, jnp 2 | 3 | 4 | def BCELoss(predictions: Tensor, targets: Tensor) -> Tensor: 5 | """ 6 | BCE or Binary Cross Entropy loss function. 7 | Useful for binary classification tasks. 8 | :param predictions: Outputs of the network. 9 | :param targets: Expected outputs of the network. 10 | :return: binary cross-entropy loss value 11 | """ 12 | return -jnp.mean(a=(targets * jnp.log(predictions) + (1 - targets) * jnp.log(1 - predictions))) 13 | 14 | 15 | def CCELoss(predictions: Tensor, targets: Tensor) -> Tensor: 16 | """ 17 | CCE or Categorical Cross Entropy loss function. 18 | Useful for multi-class classification task. 19 | :param predictions: Outputs of the network 20 | :param targets: Expected outputs of the network. 21 | :return: categorical cross-entopy loss value. 22 | """ 23 | return -jnp.mean(jnp.sum(predictions * targets, axis=1)) 24 | 25 | 26 | __all__ = [ 27 | "BCELoss", 28 | "CCELoss" 29 | ] 30 | -------------------------------------------------------------------------------- /kerax/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils import Tensor, jnp 2 | 3 | 4 | def binary_accuracy(predictions: Tensor, targets: Tensor, **kwargs) -> Tensor: 5 | """ 6 | Computes the accuracy for a binary-classification task. 7 | :param predictions: Outputs of the network. 8 | :param targets: Expected outputs of the network. 9 | :param kwargs: Keyword arguments which contains the 'acc_threshold' value. 10 | :return: the mean accuracy 11 | """ 12 | threshold = kwargs.get("acc_thresh", 0.50) 13 | assert 0 < threshold < 1, "Threshold should be between 0 and 1" 14 | predictions = jnp.where(predictions > threshold, 1.0, 0.0) 15 | return jnp.mean(predictions == targets) 16 | 17 | 18 | def accuracy(predictions: Tensor, targets: Tensor, **kwargs) -> Tensor: 19 | """ 20 | Computes the accuracy for a multi-class classification task 21 | :param predictions: Outputs of the network. 22 | :param targets: Expected outputs of the network. 
/kerax/metrics/__init__.py: --------------------------------------------------------------------------------
1 | from ..utils import Tensor, jnp
2 | 
3 | 
4 | def binary_accuracy(predictions: Tensor, targets: Tensor, **kwargs) -> Tensor:
5 |     """
6 |     Computes the accuracy for a binary-classification task.
7 |     :param predictions: Outputs of the network.
8 |     :param targets: Expected outputs of the network.
9 |     :param kwargs: Keyword arguments, which may contain the 'acc_thresh' threshold value (defaults to 0.5).
10 |     :return: the mean accuracy
11 |     """
12 |     threshold = kwargs.get("acc_thresh", 0.50)
13 |     assert 0 < threshold < 1, "Threshold should be between 0 and 1"
14 |     predictions = jnp.where(predictions > threshold, 1.0, 0.0)
15 |     return jnp.mean(predictions == targets)
16 | 
17 | 
18 | def accuracy(predictions: Tensor, targets: Tensor, **kwargs) -> Tensor:
19 |     """
20 |     Computes the accuracy for a multi-class classification task.
21 |     :param predictions: Outputs of the network.
22 |     :param targets: Expected outputs of the network.
23 |     :param kwargs: Keyword arguments (if any)
24 |     :return: the mean accuracy
25 |     """
26 |     predicted_class = jnp.argmax(predictions, axis=1)
27 |     target_class = jnp.argmax(targets, axis=1)
28 |     return jnp.mean(predicted_class == target_class)
29 | 
30 | 
31 | __all__ = [
32 |     "binary_accuracy",
33 |     "accuracy"
34 | ]
35 | 
--------------------------------------------------------------------------------
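As a small illustration (not repository code), `binary_accuracy` thresholds predicted probabilities before comparing them with the targets:

```python
from kerax.metrics import binary_accuracy
from kerax.utils import jnp

preds = jnp.array([[0.9], [0.4], [0.6]])
targets = jnp.array([[1.0], [0.0], [0.0]])

print(binary_accuracy(preds, targets))                  # 2/3: 0.6 crosses the default 0.5 threshold
print(binary_accuracy(preds, targets, acc_thresh=0.7))  # 1.0: only 0.9 counts as a positive now
```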
/kerax/models/__init__.py: --------------------------------------------------------------------------------
1 | from .sequential import Sequential
2 | 
3 | __all__ = [
4 |     "Sequential"
5 | ]
6 | 
--------------------------------------------------------------------------------
/kerax/models/sequential.py: --------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Callable, Tuple, List, Dict, Optional
3 | 
4 | from ..data import Dataloader
5 | from ..utils import Interpreter, Tensor, Trainer, convert_to_tensor, serialization, stax
6 | 
7 | 
8 | class Sequential:
9 |     """
10 |     Sequential / serial model.
11 |     """
12 | 
13 |     def __init__(self, layers=None):
14 |         """
15 |         Initializes the model with layers.
16 |         :param layers: List of layers for the neural network
17 |         """
18 |         self._layers: List[Tuple[Callable, Callable]] = [] if layers is None else layers
19 |         self._epochs: int = 1
20 |         self._trained_params: List[Optional[Tuple[Tensor, Tensor]]] = []
21 |         self._loss_fn: Optional[Callable[[Tensor, Tensor], Tensor]] = None
22 |         self._optimizer: Optional[Callable] = None
23 |         self._metrics: Dict[str, Dict[str, List[float]]] = {"loss": defaultdict(list),
24 |                                                              "loss_per_epoch": defaultdict(list)}
25 |         self._metrics_fn: Optional[List[Callable]] = []
26 |         self._seed: int = 0
27 | 
28 |     def __add__(self, other):
29 |         """
30 |         Overridden method to support combining one model's layers with another's.
31 |         :param other: a Sequential model consisting of some layers.
32 |         :return: a Sequential model with the combined layers.
33 |         """
34 |         assert type(other) == Sequential, "Type is not 'Sequential'"
35 |         assert other._layers, "Layers not provided"
36 |         assert len(other._layers) > 0, "No layers found"
37 |         layers = self._layers + other._layers
38 |         return Sequential(layers=layers)
39 | 
40 |     def add(self, other):
41 |         """
42 |         Another helper method for adding layers into the model.
43 |         :param other: Either a Sequential model or a list of layers.
44 |         :return: None if the instance is not as expected in the API.
45 |         """
46 |         if isinstance(other, Sequential) and len(other._layers) > 0:
47 |             self._layers += other._layers
48 |         elif isinstance(other, list) and len(other) > 0:
49 |             self._layers += other
50 |         else:
51 |             return None
52 | 
53 |     def compile(self, loss: Callable, optimizer: Callable, metrics: List[Callable] = None):
54 |         """
55 |         Compiles the model.
56 |         :param loss: the loss function to be used.
57 |         :param optimizer: the optimizer to be used.
58 |         :param metrics: the metrics to be used.
59 |         """
60 |         self._loss_fn = loss
61 |         self._optimizer = optimizer
62 |         self._metrics_fn = metrics
63 |         for metric_fn in self._metrics_fn:
64 |             self._metrics[metric_fn.__name__] = defaultdict(list)
65 |             self._metrics[metric_fn.__name__ + "_per_epoch"] = defaultdict(list)
66 | 
67 |     def fit(self, data: Dataloader, epochs: int, seed: int = 0):
68 |         """
69 |         Trains the model.
70 |         :param data: Dataloader object containing the dataset.
71 |         :param epochs: Number of times the entire dataset is used for training the model.
72 |         :param seed: Seed for randomization.
73 |         """
74 |         assert self._optimizer, "Call .compile() before .fit()"
75 |         assert self._loss_fn, "Call .compile() before .fit()"
76 |         assert epochs > 0, "Number of epochs must be greater than 0"
77 |         self._epochs = epochs
78 |         self._seed = seed
79 |         self.__dict__ = Trainer(self.__dict__).train(data)
80 | 
81 |     def predict(self, inputs: Tensor):
82 |         """
83 |         Uses the trained model for prediction.
84 |         :param inputs: Inputs to be used for prediction
85 |         :return: the outputs computed by the trained model.
86 |         """
87 |         assert self._trained_params, "Module not yet trained / trained params not found"
88 |         _, forward_pass = stax.serial(*self._layers)
89 |         return forward_pass(self._trained_params, inputs, mode="predict")
90 | 
91 |     def get_interpretation(self) -> Interpreter:
92 |         """
93 |         Fetches the Interpreter object for graphical analysis of the training process.
94 |         :return: the Interpreter object containing relevant information of the training results.
95 |         """
96 |         return Interpreter(epochs=self._epochs, metrics=self._metrics)
97 | 
98 |     def save(self, file_name: str):
99 |         """
100 |         Saves the model onto the disk.
101 |         By default, it will be saved in the current directory.
102 |         :param file_name: File name of the model to be saved (without the file extension)
103 |         """
104 |         assert self._layers, "Layers not provided"
105 |         assert self._loss_fn, "Loss function not provided"
106 |         assert self._metrics_fn, "Metric functions not provided"
107 |         assert self._optimizer, "Optimizer not provided"
108 |         assert self._trained_params, "Model not trained yet..."
109 |         serialization.save_module(file_name, layers=self._layers,
110 |                                   loss=self._loss_fn,
111 |                                   metrics=self._metrics_fn,
112 |                                   optimizer=self._optimizer,
113 |                                   params=self._trained_params)
114 | 
115 |     def load(self, file_name: str):
116 |         """
117 |         Loads the model from the disk.
118 |         By default, it will be loaded from the current directory.
119 |         :param file_name: File name of the model to be loaded (without the file extension)
120 |         """
121 |         deserialized_config = serialization.load_module(file_name)
122 |         self._layers = deserialized_config.get("layers")
123 |         self.compile(loss=deserialized_config.get("loss"),
124 |                      optimizer=deserialized_config.get("optimizer"),
125 |                      metrics=deserialized_config.get("metrics"))
126 |         self._trained_params = convert_to_tensor(deserialized_config.get("params"))
127 | 
--------------------------------------------------------------------------------
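Because `__add__` and `add` simply concatenate layer lists, models can be composed before compiling. A small sketch (an illustration, not repository code):

```python
from kerax.layers import Dense, Relu, Sigmoid
from kerax.models import Sequential

backbone = Sequential([Dense(100), Relu])
head = Sequential([Dense(1), Sigmoid])

model = backbone + head            # new Sequential with all four layers
backbone.add([Dense(1), Sigmoid])  # or extend an existing model in place with a list of layers
```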
/kerax/optimizers/__init__.py: --------------------------------------------------------------------------------
1 | from jax.experimental.optimizers import OptimizerState
2 | from jax.experimental.optimizers import adagrad as Adagrad
3 | from jax.experimental.optimizers import adam as Adam
4 | from jax.experimental.optimizers import adamax as Adamax
5 | from jax.experimental.optimizers import rmsprop as RMSProp
6 | from jax.experimental.optimizers import sgd as SGD
7 | from jax.experimental.optimizers import sm3 as SM3
8 | 
9 | __all__ = [
10 |     "Adam",
11 |     "Adagrad",
12 |     "Adamax",
13 |     "OptimizerState",
14 |     "RMSProp",
15 |     "SGD",
16 |     "SM3"
17 | ]
18 | 
--------------------------------------------------------------------------------
/kerax/utils/__init__.py: --------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import jit, grad, vmap, pmap, random, device_put
3 | from jax.experimental import stax
4 | 
5 | from .interpreter import Interpreter
6 | from .serialization import load_module, save_module
7 | from .tensor import Tensor, convert_to_tensor
8 | from .trainer import Trainer
9 | 
10 | __all__ = [
11 |     "Interpreter",
12 |     "Tensor",
13 |     "convert_to_tensor",
14 |     "load_module",
15 |     "save_module",
16 |     "Trainer",
17 |     "jnp",
18 |     "jit",
19 |     "random",
20 |     "vmap",
21 |     "pmap",
22 |     "grad",
23 |     "device_put",
24 |     "stax"
25 | ]
26 | 
--------------------------------------------------------------------------------
/kerax/utils/interpreter.py: --------------------------------------------------------------------------------
1 | from typing import Iterable, List
2 | 
3 | import matplotlib.pyplot as plt
4 | 
5 | 
6 | class Interpreter:
7 |     """
8 |     Interpreter class which is used for analysis of the training results.
9 |     """
10 | 
11 |     def __init__(self, **config):
12 |         """
13 |         Initializes the class.
14 |         :param config: a dictionary consisting of the training results.
15 |         """
16 |         self._config = config
17 | 
18 |     def plot_losses(self):
19 |         """
20 |         Plots the loss curves (both training and validation) in Matplotlib.
21 |         :return: None; the chart is displayed via plt.show().
22 |         """
23 |         epochs: Iterable[int] = range(1, self._config.get("epochs") + 1)
24 |         train_losses: List[float] = self._config.get("metrics").get("loss_per_epoch").get("train")
25 |         assert len(train_losses) == len(epochs), "Length of losses and number of epochs do not match"
26 |         val_losses: List[float] = self._config.get("metrics").get("loss_per_epoch").get("valid")
27 |         assert len(train_losses) == len(val_losses), "Unequal length of the losses"
28 |         plt.plot(epochs, train_losses, color="red", label="Training")
29 |         plt.plot(epochs, val_losses, color="green", label="Validation")
30 |         plt.title("Loss Curve")
31 |         plt.xlabel("Epochs")
32 |         plt.ylabel("Loss")
33 |         plt.legend()
34 |         plt.show()
35 | 
36 |     def plot_accuracy(self):
37 |         """
38 |         Plots the accuracy curves (both training and validation) in Matplotlib.
39 |         :return: None; the chart is displayed via plt.show().
40 | """ 41 | epochs: Iterable[int] = range(1, self._config.get("epochs") + 1) 42 | if "binary_accuracy" in self._config.get("metrics").keys(): 43 | train_acc: List[float] = self._config.get("metrics").get("binary_accuracy_per_epoch").get("train") 44 | assert len(train_acc) == len(epochs), "Length of accuracy values and number of epochs do not match" 45 | val_acc: List[float] = self._config.get("metrics").get("binary_accuracy_per_epoch").get("valid") 46 | assert len(train_acc) == len(val_acc), "Unequal length of the accuracy values" 47 | elif "accuracy" in self._config.get("metrics").keys(): 48 | train_acc: List[float] = self._config.get("metrics").get("accuracy_per_epoch").get("train") 49 | assert len(train_acc) == len(epochs), "Length of accuracy values and number of epochs do not match" 50 | val_acc: List[float] = self._config.get("metrics").get("accuracy_per_epoch").get("valid") 51 | assert len(train_acc) == len(val_acc), "Unequal length of the accuracy values" 52 | else: 53 | return None 54 | plt.plot(epochs, train_acc, color="red", label="Training") 55 | plt.plot(epochs, val_acc, color="green", label="Validation") 56 | plt.title("Accuracy Curve") 57 | plt.ylim([0.0, 1.05]) 58 | plt.xlabel("Epochs") 59 | plt.ylabel("Accuracy") 60 | plt.legend() 61 | plt.show() 62 | -------------------------------------------------------------------------------- /kerax/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import dill 4 | import msgpack 5 | 6 | 7 | def save_module(file_name: str, **config): 8 | """ 9 | Serializes the module with the specified file name into MessagePack format. 10 | :param file_name: The name of the file (without the extension). 11 | :param config: Properties of the module to be serialized. 12 | """ 13 | serialized_config: Dict[str, Any] = {} 14 | for k, v in config.items(): 15 | item_dill: bytes = dill.dumps(v) 16 | item_msgpack: bytes = msgpack.packb(item_dill, use_bin_type=True) 17 | serialized_config[k] = item_msgpack 18 | 19 | with open(f"{file_name}.msgpack", "wb") as f: 20 | serialized_data: bytes = msgpack.packb(serialized_config) 21 | f.write(serialized_data) 22 | 23 | 24 | def load_module(file_name: str) -> Dict[str, Any]: 25 | """ 26 | Deserializes the module with the specified file name from MessagePack format. 27 | :param file_name: The name of the file (without the extension). 28 | :return: properties of the deserialized module in a dictionary form. 29 | """ 30 | with open(f"{file_name}.msgpack", "rb") as f: 31 | deserialized_data: bytes = f.read() 32 | deserialized_config: Dict[str, Any] = msgpack.unpackb(deserialized_data) 33 | for k in list(deserialized_config): 34 | item_dill: bytes = msgpack.unpackb(deserialized_config.pop(k)) 35 | deserialized_config[k.decode("utf-8")] = dill.loads(item_dill) 36 | return deserialized_config 37 | -------------------------------------------------------------------------------- /kerax/utils/tensor.py: -------------------------------------------------------------------------------- 1 | from jax.numpy import ndarray as Tensor 2 | 3 | 4 | def convert_to_tensor(data): 5 | """ 6 | Converts the given data into Tensors 7 | :param data: Data to be converted 8 | :return: the data in the form of Tensors. 
9 | """ 10 | from jax.tree_util import tree_flatten, tree_unflatten 11 | from jax import device_put 12 | flat_data, data_tree_struct = tree_flatten(data) 13 | for i, item in enumerate(flat_data): 14 | flat_data[i] = device_put(item) 15 | return tree_unflatten(data_tree_struct, flat_data) 16 | -------------------------------------------------------------------------------- /kerax/utils/trainer.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from functools import partial 3 | from typing import Tuple, List, Optional, Dict, Any 4 | 5 | from tqdm import tqdm 6 | 7 | from . import Tensor, jit, grad, device_put, jnp, random, stax 8 | from ..data import Dataloader 9 | from ..optimizers import OptimizerState 10 | 11 | 12 | class Trainer: 13 | """ 14 | Trainer utility class for training the models. 15 | """ 16 | 17 | def __init__(self, config: Dict[str, Any]): 18 | """ 19 | Initializes the Trainer class with particular configuration. 20 | :param config: Configuration containing the optimizer function and the layers. 21 | """ 22 | self.config: Dict[str, Any] = config 23 | self.mode = "train" 24 | self.opt_init, self.opt_update, self.fetch_params = config.get("_optimizer") 25 | self.setup_params, self.forward_pass = stax.serial(*config.get("_layers")) 26 | 27 | def initialize_params(self, input_shape: List[int]): 28 | """ 29 | Initializes the network parameters. 30 | If already trained, then it will return the trained parameters. 31 | :param input_shape: Shape of the inputs for properly initializing the parms. 32 | :return: the network parameters. 33 | """ 34 | trained_params: List[Optional[Tuple[Tensor, Tensor]]] = self.config.get("_trained_params") 35 | if len(trained_params) > 0: 36 | return trained_params 37 | else: 38 | rng = random.PRNGKey(self.config.get("_seed")) 39 | input_shape[0] = -1 40 | input_shape = tuple(input_shape) 41 | _, params = self.setup_params(rng=rng, input_shape=input_shape) 42 | return params 43 | 44 | def train(self, data: Dataloader): 45 | """ 46 | Trains the network 47 | :param data: Dataloader object containing the dataset. 48 | :return: the configuration of the training network in the form of dictionary. 49 | """ 50 | network_params = self.initialize_params(list(data.input_shape)) 51 | opt_state: OptimizerState = self.opt_init(network_params) 52 | iter_count = itertools.count() 53 | progress_bar = tqdm(iterable=range(self.config.get("_epochs")), 54 | desc="Training model", leave=True) 55 | for epoch in progress_bar: 56 | progress_bar.set_description(desc=f"Epoch {epoch + 1}") 57 | self.mode = "train" 58 | for _ in range(data.num_train_batches): 59 | train_batch = device_put(next(data.train_data)) 60 | opt_state = self.step(next(iter_count), opt_state, train_batch) 61 | network_params = self.fetch_params(opt_state) 62 | self.calculate_metrics(network_params, train_batch) 63 | network_params = self.fetch_params(opt_state) 64 | self.mode = "valid" 65 | for _ in range(data.num_val_batches): 66 | valid_batch = device_put(next(data.val_data)) 67 | self.calculate_metrics(network_params, valid_batch) 68 | self.calculate_epoch_metrics(data) 69 | progress_bar.set_postfix_str(self.pretty_print_metrics()) 70 | progress_bar.refresh() 71 | self.config["_trained_params"] = self.fetch_params(opt_state) 72 | return self.config 73 | 74 | @partial(jit, static_argnums=(0,)) 75 | def step(self, i, opt_state, batch): 76 | """ 77 | Training step for the optimization process. 
/kerax/utils/trainer.py: --------------------------------------------------------------------------------
1 | import itertools
2 | from functools import partial
3 | from typing import Tuple, List, Optional, Dict, Any
4 | 
5 | from tqdm import tqdm
6 | 
7 | from . import Tensor, jit, grad, device_put, jnp, random, stax
8 | from ..data import Dataloader
9 | from ..optimizers import OptimizerState
10 | 
11 | 
12 | class Trainer:
13 |     """
14 |     Trainer utility class for training the models.
15 |     """
16 | 
17 |     def __init__(self, config: Dict[str, Any]):
18 |         """
19 |         Initializes the Trainer class with a particular configuration.
20 |         :param config: Configuration containing the optimizer function and the layers.
21 |         """
22 |         self.config: Dict[str, Any] = config
23 |         self.mode = "train"
24 |         self.opt_init, self.opt_update, self.fetch_params = config.get("_optimizer")
25 |         self.setup_params, self.forward_pass = stax.serial(*config.get("_layers"))
26 | 
27 |     def initialize_params(self, input_shape: List[int]):
28 |         """
29 |         Initializes the network parameters.
30 |         If already trained, then it will return the trained parameters.
31 |         :param input_shape: Shape of the inputs for properly initializing the params.
32 |         :return: the network parameters.
33 |         """
34 |         trained_params: List[Optional[Tuple[Tensor, Tensor]]] = self.config.get("_trained_params")
35 |         if len(trained_params) > 0:
36 |             return trained_params
37 |         else:
38 |             rng = random.PRNGKey(self.config.get("_seed"))
39 |             input_shape[0] = -1
40 |             input_shape = tuple(input_shape)
41 |             _, params = self.setup_params(rng=rng, input_shape=input_shape)
42 |             return params
43 | 
44 |     def train(self, data: Dataloader):
45 |         """
46 |         Trains the network.
47 |         :param data: Dataloader object containing the dataset.
48 |         :return: the configuration of the trained network in the form of a dictionary.
49 |         """
50 |         network_params = self.initialize_params(list(data.input_shape))
51 |         opt_state: OptimizerState = self.opt_init(network_params)
52 |         iter_count = itertools.count()
53 |         progress_bar = tqdm(iterable=range(self.config.get("_epochs")),
54 |                             desc="Training model", leave=True)
55 |         for epoch in progress_bar:
56 |             progress_bar.set_description(desc=f"Epoch {epoch + 1}")
57 |             self.mode = "train"
58 |             for _ in range(data.num_train_batches):
59 |                 train_batch = device_put(next(data.train_data))
60 |                 opt_state = self.step(next(iter_count), opt_state, train_batch)
61 |                 network_params = self.fetch_params(opt_state)
62 |                 self.calculate_metrics(network_params, train_batch)
63 |             network_params = self.fetch_params(opt_state)
64 |             self.mode = "valid"
65 |             for _ in range(data.num_val_batches):
66 |                 valid_batch = device_put(next(data.val_data))
67 |                 self.calculate_metrics(network_params, valid_batch)
68 |             self.calculate_epoch_metrics(data)
69 |             progress_bar.set_postfix_str(self.pretty_print_metrics())
70 |             progress_bar.refresh()
71 |         self.config["_trained_params"] = self.fetch_params(opt_state)
72 |         return self.config
73 | 
74 |     @partial(jit, static_argnums=(0,))
75 |     def step(self, i, opt_state, batch):
76 |         """
77 |         Training step for the optimization process.
78 |         :param i: Iteration count
79 |         :param opt_state: State of the optimizer
80 |         :param batch: Batch of data for the optimizer to work with.
81 |         :return: the updated state of the optimizer.
82 |         """
83 |         params = self.fetch_params(opt_state)
84 |         grads = grad(self.compute_loss)(params, batch)
85 |         return self.opt_update(i, grads, opt_state)
86 | 
87 |     @partial(jit, static_argnums=(0,))
88 |     def compute_loss(self, params, batch):
89 |         """
90 |         Helper function to compute the forward pass as well as the loss at every step in the training process.
91 |         :param params: Network parameters
92 |         :param batch: Batch of data to compute predictions and loss value
93 |         :return: the computed loss value
94 |         """
95 |         inputs, targets = batch
96 |         predictions = self.forward_pass(params, inputs, mode=self.mode)
97 |         return jit(self.config.get("_loss_fn"))(predictions, targets)
98 | 
99 |     def calculate_metrics(self, params, batch):
100 |         """
101 |         Helper function that computes the metrics at every step in the training process.
102 |         :param params: Network parameters
103 |         :param batch: Batch of data to compute the metrics.
104 |         """
105 |         inputs, targets = batch
106 |         predictions = self.forward_pass(params, inputs, mode=self.mode)
107 |         self.config.get("_metrics")["loss"][self.mode].append(self.compute_loss(params, batch))
108 |         for metric_fn in self.config.get("_metrics_fn"):
109 |             self.config.get("_metrics")[metric_fn.__name__][self.mode].append(jit(metric_fn)(predictions, targets))
110 | 
111 |     def calculate_epoch_metrics(self, data: Dataloader):
112 |         """
113 |         Calculates the metrics values (both training and validation) after every epoch of the training process.
114 |         :param data: Dataloader object (used to fetch the number of batches)
115 |         """
116 |         self.config.get("_metrics")["loss_per_epoch"]["train"].append(
117 |             jnp.mean(jnp.array(self.config.get("_metrics")["loss"]["train"][-data.num_train_batches:]))
118 |         )
119 |         self.config.get("_metrics")["loss_per_epoch"]["valid"].append(
120 |             jnp.mean(jnp.array(self.config.get("_metrics")["loss"]["valid"][-data.num_val_batches:]))
121 |         )
122 |         for metric_fn in self.config.get("_metrics_fn"):
123 |             self.config.get("_metrics")[metric_fn.__name__ + "_per_epoch"]["train"]\
124 |                 .append(self.config.get("_metrics")[metric_fn.__name__]["train"][-1])
125 |             self.config.get("_metrics")[metric_fn.__name__ + "_per_epoch"]["valid"] \
126 |                 .append(self.config.get("_metrics")[metric_fn.__name__]["valid"][-1])
127 | 
128 |     def pretty_print_metrics(self) -> str:
129 |         """
130 |         Helper function to display the results (loss + metrics) during the training process.
131 |         :return: a string containing the values of the results.
132 | """ 133 | return " :: ".join([f"{metric_type}_{metric_name} : {metric.get(metric_type)[-1]:.3f}" 134 | for metric_name, metric in self.config.get("_metrics").items() 135 | for metric_type in metric.keys() if "epoch" not in metric_name]) 136 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=1.2.0 2 | tensorflow==2.5.1 3 | tensorflow_datasets==4.2.0 4 | matplotlib>=3.3.3 5 | jax>=0.2.7 6 | jaxlib>=0.1.57 7 | numpy>=1.19.5 8 | dill>=0.3.3 9 | tqdm>=4.55.1 10 | msgpack>=1.0.2 11 | msgpack_python>=0.5.6 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __version__ = None 4 | 5 | setup( 6 | name='kerax', 7 | version=__version__, 8 | packages=find_packages(exclude=["examples"]), 9 | python_requires="==3.7", 10 | install_requires=[ 11 | "pandas>=1.2.0", 12 | "tensorflow==2.4.0", 13 | "tensorflow_datasets==4.2.0", 14 | "matplotlib>=3.3.3", 15 | "jax>=0.2.7", 16 | "jaxlib>=0.1.57", 17 | "numpy>=1.19.5", 18 | "dill>=0.3.3", 19 | "tqdm>=4.55.1", 20 | "msgpack_python>=0.5.6". 21 | "msgpack>=1.0.2" 22 | ], 23 | url='https://github.com/umangjpatel/kerax', 24 | license='', 25 | author='Umang Patel (umangjpatel)', 26 | author_email='umangpatel1947@gmail.com', 27 | description='Keras-like APIs powered with JAX library', 28 | classifiers=[ 29 | "Programming Language :: Python :: 3.7" 30 | ] 31 | ) 32 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | 5 | from kerax.datasets import binary_tiny_mnist 6 | 7 | 8 | class TestDataloader(TestCase): 9 | 10 | def setUp(self) -> None: 11 | self.data_loader = binary_tiny_mnist.load_dataset(batch_size=200) 12 | 13 | def tearDown(self) -> None: 14 | del self.data_loader 15 | 16 | def test_dataloader_attrs(self): 17 | self.assertEqual(self.data_loader.num_train_batches, 17) 18 | self.assertEqual(self.data_loader.num_val_batches, 5) 19 | self.assertEqual(self.data_loader.input_shape, (-1, 784)) 20 | self.assertEqual(self.data_loader.batch_size, 200) 21 | 22 | def test_data_shapes(self): 23 | self.assertEqual(next(self.data_loader.train_data)[0].shape, (200, 784)) 24 | self.assertEqual(next(self.data_loader.train_data)[1].shape, (200, 1)) 25 | self.assertEqual(next(self.data_loader.val_data)[0].shape, (200, 784)) 26 | self.assertEqual(next(self.data_loader.val_data)[1].shape, (200, 1)) 27 | 28 | def test_data_items(self): 29 | train_item = next(self.data_loader.train_data) 30 | val_item = next(self.data_loader.train_data) 31 | train_inputs, train_labels = train_item 32 | val_inputs, val_labels = val_item 33 | self.assertIsInstance(train_item, tuple) 34 | self.assertIsInstance(val_item, tuple) 35 | self.assertIsInstance(train_inputs, np.ndarray) 36 | self.assertIsInstance(train_labels, np.ndarray) 37 | self.assertIsInstance(val_inputs, np.ndarray) 38 | self.assertIsInstance(val_labels, np.ndarray) 39 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | from kerax.datasets import mnist, binary_tiny_mnist 2 | from 
unittest import TestCase 3 | 4 | 5 | class TestDatasets(TestCase): 6 | 7 | def setUp(self) -> None: 8 | pass 9 | 10 | def tearDown(self) -> None: 11 | pass 12 | 13 | def test_binary_mnist(self) -> None: 14 | data = binary_tiny_mnist.load_dataset(batch_size=200) 15 | self.assertEqual(data.batch_size, 200) 16 | self.assertEqual(data.num_train_batches, 17) 17 | self.assertEqual(data.num_val_batches, 5) 18 | inputs, targets = next(data.train_data) 19 | self.assertEqual(inputs.shape, (200, 784)) 20 | self.assertEqual(targets.shape, (200, 1)) 21 | inputs, targets = next(data.val_data) 22 | self.assertEqual(inputs.shape, (200, 784)) 23 | self.assertEqual(targets.shape, (200, 1)) 24 | self.assertEqual(data.input_shape, (-1, 784)) 25 | 26 | def test_mnist(self) -> None: 27 | data = mnist.load_dataset(batch_size=1000) 28 | self.assertEqual(data.batch_size, 1000) 29 | self.assertEqual(data.num_train_batches, 60) 30 | self.assertEqual(data.num_val_batches, 10) 31 | inputs, targets = next(data.train_data) 32 | self.assertEqual(inputs.shape, (1000, 28, 28, 1)) 33 | self.assertEqual(targets.shape, (1000, 10)) 34 | inputs, targets = next(data.val_data) 35 | self.assertEqual(inputs.shape, (1000, 28, 28, 1)) 36 | self.assertEqual(targets.shape, (1000, 10)) 37 | self.assertEqual(data.input_shape, (-1, 28, 28, 1)) 38 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from kerax.layers import Dense, Relu, Sigmoid, LogSoftmax, Flatten, Dropout 4 | from kerax.utils import stax, random, jnp 5 | 6 | 7 | class TestLayers(TestCase): 8 | 9 | def setUp(self) -> None: 10 | pass 11 | 12 | def tearDown(self) -> None: 13 | pass 14 | 15 | @staticmethod 16 | def get_dummy_network(layers, inputs): 17 | init_params, forward_pass = stax.serial(*layers) 18 | input_shape = tuple([-1] + list(inputs.shape)[1:]) 19 | _, params = init_params(random.PRNGKey(10), input_shape) 20 | return params, forward_pass 21 | 22 | @staticmethod 23 | def get_dummy_inputs(input_shape): 24 | return random.normal(key=random.PRNGKey(42), shape=input_shape) 25 | 26 | def test_dense_layer(self): 27 | inputs = self.get_dummy_inputs(input_shape=(200, 784)) 28 | params, forward_pass = self.get_dummy_network([Dense(10)], inputs) 29 | outputs = forward_pass(params, inputs) 30 | self.assertEqual(outputs.shape, (200, 10)) 31 | 32 | def test_relu_layer(self): 33 | inputs = self.get_dummy_inputs(input_shape=(200, 784)) 34 | params, forward_pass = self.get_dummy_network([Relu], inputs) 35 | outputs = forward_pass(params, inputs) 36 | self.assertEqual(inputs.shape, outputs.shape) 37 | 38 | def test_sigmoid_layer(self): 39 | inputs = self.get_dummy_inputs(input_shape=(200, 784)) 40 | params, forward_pass = self.get_dummy_network([Sigmoid], inputs) 41 | outputs = forward_pass(params, inputs) 42 | self.assertEqual(inputs.shape, outputs.shape) 43 | 44 | def test_flatten_layer_already_flat(self): 45 | inputs = self.get_dummy_inputs(input_shape=(200, 784)) 46 | params, forward_pass = self.get_dummy_network([Flatten], inputs) 47 | outputs = forward_pass(params, inputs) 48 | self.assertEqual(inputs.shape, outputs.shape) 49 | 50 | def test_flatten_layer_not_already_flat(self): 51 | inputs = self.get_dummy_inputs(input_shape=(200, 28, 28, 1)) 52 | params, forward_pass = self.get_dummy_network([Flatten], inputs) 53 | outputs = forward_pass(params, inputs) 54 | self.assertEqual(outputs.shape, (200, 
784)) 55 | 56 | def test_log_softmax_layer(self): 57 | inputs = self.get_dummy_inputs(input_shape=(200, 784)) 58 | params, forward_pass = self.get_dummy_network([LogSoftmax], inputs) 59 | outputs = forward_pass(params, inputs) 60 | self.assertEqual(outputs.shape, (200, 784)) 61 | 62 | def test_dropout_layer_mode_train(self): 63 | inputs = self.get_dummy_inputs(input_shape=(1, 5)) 64 | params, forward_pass = self.get_dummy_network([Dropout(rate=0.0)], inputs) 65 | outputs = forward_pass(params, inputs) 66 | self.assertEqual(len(outputs == 0.0), len(inputs)) 67 | 68 | def test_dropout_layer_mode_predict(self): 69 | inputs = self.get_dummy_inputs(input_shape=(1, 5)) 70 | params, forward_pass = self.get_dummy_network([Dropout(rate=0.0)], inputs) 71 | outputs = forward_pass(params, inputs, mode="predict") 72 | self.assertEqual(len(outputs != 0.0), len(inputs)) 73 | 74 | -------------------------------------------------------------------------------- /tests/test_losses.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from jax.interpreters.xla import _DeviceArray 4 | 5 | from kerax.losses import BCELoss, CCELoss 6 | from kerax.utils import jnp, random 7 | 8 | 9 | class TestLossFunctions(TestCase): 10 | 11 | def setUp(self) -> None: 12 | self.keys = random.split(random.PRNGKey(42), 4) 13 | 14 | def testBCELoss(self): 15 | k1, k2 = self.keys[0], self.keys[1] 16 | binary_predictions = random.uniform(key=k1, shape=(100, 1)) 17 | binary_labels = random.permutation(key=k2, 18 | x=jnp.concatenate((jnp.zeros(shape=(50,)), jnp.ones(shape=(50,))))) 19 | loss = BCELoss(predictions=binary_predictions, targets=binary_labels) 20 | self.assertIsInstance(loss, _DeviceArray) 21 | 22 | def testCCELoss(self): 23 | k3, k4 = self.keys[2], self.keys[3] 24 | softmax_predictions = random.randint(key=k3, shape=(100, 1), minval=0, maxval=9) 25 | softmax_labels = random.randint(key=k4, shape=(100, 1), minval=0, maxval=9) 26 | loss = CCELoss(predictions=softmax_predictions, targets=softmax_labels) 27 | self.assertIsInstance(loss, _DeviceArray) 28 | 29 | def tearDown(self) -> None: 30 | del self.keys 31 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from jax.interpreters.xla import _DeviceArray 4 | from kerax.utils import jnp, random 5 | from kerax.metrics import binary_accuracy, accuracy 6 | 7 | 8 | class TestMetricsFunctions(TestCase): 9 | 10 | def setUp(self) -> None: 11 | self.keys = random.split(random.PRNGKey(42), 4) 12 | 13 | def tearDown(self) -> None: 14 | del self.keys 15 | 16 | def test_binary_accuracy(self) -> None: 17 | k1, k2 = self.keys[0], self.keys[1] 18 | predictions = random.uniform(key=k1, shape=(100, 1), minval=0, maxval=1) 19 | labels = random.randint(key=k2, shape=(100, 1), minval=0, maxval=1) 20 | acc = binary_accuracy(predictions=predictions, targets=labels, acc_thresh=0.5) 21 | self.assertIsInstance(acc, _DeviceArray) 22 | 23 | def test_accuracy(self) -> None: 24 | k1, k2 = self.keys[2], self.keys[3] 25 | predictions = random.randint(key=k1, shape=(100, 1), minval=0, maxval=9) 26 | labels = random.randint(key=k2, shape=(100, 1), minval=0, maxval=9) 27 | acc = accuracy(predictions=predictions, targets=labels) 28 | self.assertIsInstance(acc, _DeviceArray) 29 | 30 | -------------------------------------------------------------------------------- 
/tests/test_sequential.py: --------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | 
3 | from kerax.datasets import binary_tiny_mnist
4 | from kerax.layers import Dense, Sigmoid
5 | from kerax.losses import BCELoss
6 | from kerax.metrics import binary_accuracy
7 | from kerax.models import Sequential
8 | from kerax.optimizers import SGD
9 | from kerax.utils import Interpreter, device_put
10 | 
11 | 
12 | class TestSequential(TestCase):
13 | 
14 |     def setUp(self) -> None:
15 |         self.data = binary_tiny_mnist.load_dataset(batch_size=200)
16 |         self.binary_model = Sequential(layers=[Dense(1), Sigmoid])
17 | 
18 |     def tearDown(self) -> None:
19 |         del self.data
20 |         del self.binary_model
21 | 
22 |     def test_attrs_after_init(self) -> None:
23 |         self.assertNotEqual(self.binary_model._layers, None)
24 |         self.assertEqual(self.binary_model._epochs, 1)
25 |         self.assertEqual(self.binary_model._trained_params, [])
26 |         self.assertEqual(self.binary_model._loss_fn, None)
27 |         self.assertEqual(self.binary_model._optimizer, None)
28 |         self.assertEqual(", ".join(self.binary_model._metrics.keys()), "loss, loss_per_epoch")
29 |         self.assertEqual(self.binary_model._metrics_fn, [])
30 |         self.assertEqual(self.binary_model._seed, 0)
31 | 
32 |     def test_add_layers(self) -> None:
33 |         new_binary_model = Sequential([Dense(100)]) + self.binary_model
34 |         self.assertEqual(len(new_binary_model._layers), 3)
35 |         another_binary_model = Sequential([Dense(100)])
36 |         another_binary_model.add(self.binary_model)
37 |         self.assertEqual(len(another_binary_model._layers), 3)
38 |         one_more_binary_model = Sequential([Dense(100)])
39 |         one_more_binary_model.add([Dense(10), Sigmoid])
40 |         self.assertEqual(len(one_more_binary_model._layers), 3)
41 |         self.assertEqual(one_more_binary_model.add(1), None)
42 | 
43 |     def test_compile(self) -> None:
44 |         loss_fn, opt_fn, metrics_fn = BCELoss, SGD(step_size=0.001), [binary_accuracy]
45 |         self.binary_model.compile(loss=loss_fn, optimizer=opt_fn, metrics=metrics_fn)
46 |         self.assertEqual(self.binary_model._loss_fn, loss_fn)
47 |         self.assertEqual(self.binary_model._optimizer, opt_fn)
48 |         self.assertEqual(self.binary_model._metrics_fn, metrics_fn)
49 |         self.assertEqual(", ".join(self.binary_model._metrics.keys()),
50 |                          "loss, loss_per_epoch, binary_accuracy, binary_accuracy_per_epoch")
51 | 
52 |     def test_fit(self) -> None:
53 |         self.binary_model.compile(loss=BCELoss, optimizer=SGD(step_size=0.01), metrics=[binary_accuracy])
54 |         self.binary_model.fit(data=self.data, epochs=1)
55 |         self.assertNotEqual(self.binary_model._trained_params, None)
56 |         self.assertEqual(len(self.binary_model._metrics.keys()), 4)
57 | 
58 |     def test_predict(self) -> None:
59 |         self.binary_model.compile(loss=BCELoss, optimizer=SGD(step_size=0.01), metrics=[binary_accuracy])
60 |         self.binary_model.fit(data=self.data, epochs=1)
61 |         inputs = device_put(next(self.data.train_data)[0])
62 |         self.assertEqual(inputs.shape, (200, 784))
63 |         outputs = self.binary_model.predict(inputs)
64 |         self.assertEqual(outputs.shape, (200, 1))
65 | 
66 |     def test_save_and_load_and_train(self) -> None:
67 |         from pathlib import Path
68 |         self.binary_model.compile(loss=BCELoss, optimizer=SGD(step_size=0.01), metrics=[binary_accuracy])
69 |         self.binary_model.fit(data=self.data, epochs=1)
70 |         self.binary_model.save("dummy_binary_model")
71 |         path = Path(__file__).parent / "dummy_binary_model.msgpack"
72 |         self.assertTrue(path.exists())
73 |         loaded_model = Sequential()
74 | 
loaded_model.load("dummy_binary_model") 75 | self.assertEqual(len(loaded_model._layers), 2) 76 | self.assertNotEqual(loaded_model._optimizer, None) 77 | self.assertEqual(len(loaded_model._trained_params), 2) 78 | self.assertEqual(len(loaded_model._metrics_fn), 1) 79 | self.assertNotEqual(loaded_model._loss_fn, None) 80 | model = Sequential() 81 | model.load("dummy_binary_model") 82 | model.fit(data=self.data, epochs=1) 83 | 84 | def test_get_interpreter(self) -> None: 85 | self.binary_model.compile(loss=BCELoss, optimizer=SGD(step_size=0.01), metrics=[binary_accuracy]) 86 | self.binary_model.fit(data=self.data, epochs=1) 87 | interp = self.binary_model.get_interpretation() 88 | self.assertIsInstance(interp, Interpreter) 89 | self.assertEqual(interp._config.get("epochs"), 1) 90 | self.assertEqual(len(interp._config.get("metrics").keys()), 4) 91 | --------------------------------------------------------------------------------