├── example_binary.py ├── example_multiclass.py ├── LICENSE ├── setup.py ├── .gitignore ├── keras_balanced_batch_generator.py └── README.md /example_binary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import keras 3 | from keras_balanced_batch_generator import make_generator 4 | 5 | def example_binary(): 6 | num_samples = 100 7 | num_classes = 2 8 | input_shape = (2,) 9 | batch_size = 16 10 | 11 | x = np.random.rand(num_samples, *input_shape) 12 | y = np.random.randint(low=0, high=num_classes, size=num_samples) 13 | y = keras.utils.to_categorical(y) 14 | 15 | generator = make_generator(x, y, batch_size, categorical=False) 16 | 17 | model = keras.models.Sequential() 18 | model.add(keras.layers.Dense(32, input_shape=input_shape, activation='relu')) 19 | model.add(keras.layers.Dense(1, activation='sigmoid')) 20 | model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['binary_accuracy']) 21 | model.fit(generator, steps_per_epoch=10, epochs=5) 22 | 23 | if __name__ == '__main__': 24 | example_binary() 25 | -------------------------------------------------------------------------------- /example_multiclass.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import keras 3 | from keras_balanced_batch_generator import make_generator 4 | 5 | def example_multiclass(): 6 | num_samples = 100 7 | num_classes = 3 8 | input_shape = (2,) 9 | batch_size = 16 10 | 11 | x = np.random.rand(num_samples, *input_shape) 12 | y = np.random.randint(low=0, high=num_classes, size=num_samples) 13 | y = keras.utils.to_categorical(y) 14 | 15 | generator = make_generator(x, y, batch_size) 16 | 17 | model = keras.models.Sequential() 18 | model.add(keras.layers.Dense(32, input_shape=input_shape, activation='relu')) 19 | model.add(keras.layers.Dense(num_classes, activation='softmax')) 20 | model.compile(optimizer='adam', loss='categorical_crossentropy', 
metrics=['accuracy']) 21 | model.fit(generator, steps_per_epoch=10, epochs=5) 22 | 23 | if __name__ == '__main__': 24 | example_multiclass() 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 Soroush Javadi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open('README.md', 'r') as f: 4 | long_description = f.read() 5 | 6 | setuptools.setup( 7 | name='keras-balanced-batch-generator', 8 | version='0.0.3', 9 | url='https://github.com/soroushj/keras-balanced-batch-generator', 10 | author='Soroush Javadi', 11 | author_email='soroush.javadi@gmail.com', 12 | license='MIT', 13 | description='A Keras-compatible generator for creating balanced batches', 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | keywords=[ 17 | 'keras', 18 | 'generator', 19 | ], 20 | classifiers=[ 21 | 'Development Status :: 4 - Beta', 22 | 'Programming Language :: Python :: 3', 23 | 'License :: OSI Approved :: MIT License', 24 | 'Operating System :: OS Independent', 25 | 'Topic :: Software Development :: Libraries', 26 | 'Topic :: Software Development :: Libraries :: Python Modules', 27 | 'Intended Audience :: Developers', 28 | 'Intended Audience :: Education', 29 | 'Intended Audience :: Science/Research', 30 | ], 31 | install_requires=[ 32 | 'numpy>=1.0.0', 33 | ], 34 | python_requires='>=3.0', 35 | py_modules=['keras_balanced_batch_generator'], 36 | ) 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a 
python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | -------------------------------------------------------------------------------- /keras_balanced_batch_generator.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | def make_generator(x, y, batch_size, 5 | categorical=True, 6 | seed=None): 7 | """A Keras-compatible generator for creating balanced batches. 8 | 9 | This generator loops over its data indefinitely and yields balanced, 10 | shuffled batches. 11 | 12 | Args: 13 | x (numpy.ndarray): Input data. Must have the same length as `y`. 
14 | y (numpy.ndarray): Target data. Must be a binary class matrix (i.e., 15 | shape `(num_samples, num_classes)`). You can use 16 | `tf.keras.utils.to_categorical` to convert a class vector to a binary 17 | class matrix. 18 | batch_size (int): Batch size. 19 | categorical (bool): If true, generates binary class matrices 20 | (i.e., shape `(num_samples, num_classes)`) for batch targets. 21 | Otherwise, generates class vectors (i.e., shape `(num_samples,)`). 22 | seed: Random seed. 23 | Returns a Keras-compatible generator yielding batches as `(x, y)` tuples. 24 | """ 25 | if type(x) is not np.ndarray: 26 | raise ValueError('Arg x must be of type numpy.ndarray.') 27 | if type(y) is not np.ndarray: 28 | raise ValueError('Arg y must be of type numpy.ndarray.') 29 | if x.shape[0] != y.shape[0]: 30 | raise ValueError('Args x and y must have the same length.') 31 | if x.shape[0] < 1: 32 | raise ValueError('Args x and y must not be empty.') 33 | if len(y.shape) != 2: 34 | raise ValueError( 35 | 'Arg y must have a shape of (num_samples, num_classes). ' + 36 | 'You can use tf.keras.utils.to_categorical to convert a class vector ' + 37 | 'to a binary class matrix.' 
38 | ) 39 | if type(batch_size) is not int: 40 | raise ValueError('Arg batch_size must be of type int.') 41 | if batch_size < 1: 42 | raise ValueError('Arg batch_size must be positive.') 43 | num_samples = y.shape[0] 44 | num_classes = y.shape[1] 45 | batch_x_shape = (batch_size, *x.shape[1:]) 46 | batch_y_shape = (batch_size, num_classes) if categorical else (batch_size,) 47 | indexes = [0 for _ in range(num_classes)] 48 | samples = [[] for _ in range(num_classes)] 49 | for i in range(num_samples): 50 | samples[np.argmax(y[i])].append(x[i]) 51 | for c, s in enumerate(samples): 52 | if len(s) < 1: 53 | raise ValueError('Class {} has no samples.'.format(c)) 54 | rand = random.Random(seed) 55 | while True: 56 | batch_x = np.ndarray(shape=batch_x_shape, dtype=x.dtype) 57 | batch_y = np.zeros(shape=batch_y_shape, dtype=y.dtype) 58 | for i in range(batch_size): 59 | random_class = rand.randrange(num_classes) 60 | current_index = indexes[random_class] 61 | indexes[random_class] = (current_index + 1) % len(samples[random_class]) 62 | if current_index == 0: 63 | rand.shuffle(samples[random_class]) 64 | batch_x[i] = samples[random_class][current_index] 65 | if categorical: 66 | batch_y[i][random_class] = 1 67 | else: 68 | batch_y[i] = random_class 69 | yield (batch_x, batch_y) 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # keras-balanced-batch-generator: A Keras-compatible generator for creating balanced batches 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/keras-balanced-batch-generator.svg)](https://pypi.org/project/keras-balanced-batch-generator/) 4 | [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) 5 | 6 | ## Installation 7 | 8 | ```bash 9 | pip install keras-balanced-batch-generator 10 | ``` 11 | 12 | ## Overview 13 | 14 | This module implements an over-sampling 
algorithm to address the issue of class imbalance. 15 | It generates *balanced batches*, i.e., batches in which the number of samples from each class is on average the same. 16 | Generated batches are also shuffled. 17 | 18 | The generator can easily be used with Keras models' 19 | [`fit`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit) method. 20 | 21 | Currently, only [NumPy arrays](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) for single-input, single-output models are supported. 22 | 23 | ## API 24 | 25 | ```python 26 | make_generator(x, y, batch_size, 27 | categorical=True, 28 | seed=None) 29 | ``` 30 | 31 | - **`x`** *(numpy.ndarray)* Input data. Must have the same length as `y`. 32 | - **`y`** *(numpy.ndarray)* Target data. Must be a binary class matrix (i.e., shape `(num_samples, num_classes)`); each sample's class is taken as the argmax of its row. 33 | You can use [`keras.utils.to_categorical`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/to_categorical) to convert a class vector to a binary class matrix. 34 | - **`batch_size`** *(int)* Batch size. 35 | - **`categorical`** *(bool)* If true, generates binary class matrices (i.e., shape `(batch_size, num_classes)`) for batch targets. 36 | Otherwise, generates class vectors of class indices (i.e., shape `(batch_size,)`). 37 | - **`seed`** Random seed (see the [docs](https://docs.python.org/3/library/random.html#random.seed)). 38 | - Returns a Keras-compatible generator yielding batches as `(x, y)` tuples. 39 | 40 | ## Usage 41 | 42 | ```python 43 | import keras 44 | from keras_balanced_batch_generator import make_generator 45 | 46 | x = ... 47 | y = ... 48 | batch_size = ... 49 | steps_per_epoch = ... 50 | model = keras.models.Sequential(...)
51 | 52 | generator = make_generator(x, y, batch_size) 53 | model.fit(generator, steps_per_epoch=steps_per_epoch) 54 | ``` 55 | 56 | ## Example: Multiclass Classification 57 | 58 | ```python 59 | import numpy as np 60 | import keras 61 | from keras_balanced_batch_generator import make_generator 62 | 63 | num_samples = 100 64 | num_classes = 3 65 | input_shape = (2,) 66 | batch_size = 16 67 | 68 | x = np.random.rand(num_samples, *input_shape) 69 | y = np.random.randint(low=0, high=num_classes, size=num_samples) 70 | y = keras.utils.to_categorical(y, num_classes=num_classes) 71 | 72 | generator = make_generator(x, y, batch_size) 73 | 74 | model = keras.models.Sequential([keras.Input(shape=input_shape)]) 75 | model.add(keras.layers.Dense(32, activation='relu')) 76 | model.add(keras.layers.Dense(num_classes, activation='softmax')) 77 | model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) 78 | model.fit(generator, steps_per_epoch=10, epochs=5) 79 | ``` 80 | 81 | ## Example: Binary Classification 82 | 83 | ```python 84 | import numpy as np 85 | import keras 86 | from keras_balanced_batch_generator import make_generator 87 | 88 | num_samples = 100 89 | num_classes = 2 90 | input_shape = (2,) 91 | batch_size = 16 92 | 93 | x = np.random.rand(num_samples, *input_shape) 94 | y = np.random.randint(low=0, high=num_classes, size=num_samples) 95 | y = keras.utils.to_categorical(y, num_classes=num_classes) 96 | 97 | generator = make_generator(x, y, batch_size, categorical=False) 98 | 99 | model = keras.models.Sequential([keras.Input(shape=input_shape)]) 100 | model.add(keras.layers.Dense(32, activation='relu')) 101 | model.add(keras.layers.Dense(1, activation='sigmoid')) 102 | model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['binary_accuracy']) 103 | model.fit(generator, steps_per_epoch=10, epochs=5) 104 | ``` 105 | --------------------------------------------------------------------------------