├── .coveragerc
├── .github
│   └── ISSUE_TEMPLATE
│       ├── a--image.md
│       └── b--text.md
├── .gitignore
├── .travis.yml
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── PULL_REQUEST_TEMPLATE.md
├── README.md
├── keras_preprocessing
│   ├── __init__.py
│   ├── image
│   │   ├── __init__.py
│   │   ├── affine_transformations.py
│   │   ├── dataframe_iterator.py
│   │   ├── directory_iterator.py
│   │   ├── image_data_generator.py
│   │   ├── iterator.py
│   │   ├── numpy_array_iterator.py
│   │   └── utils.py
│   ├── sequence.py
│   └── text.py
├── setup.cfg
├── setup.py
└── tests
    ├── image
    │   ├── affine_transformations_test.py
    │   ├── dataframe_iterator_test.py
    │   ├── directory_iterator_test.py
    │   ├── image_data_generator_test.py
    │   ├── iterator_test.py
    │   ├── numpy_array_iterator_test.py
    │   ├── test_image_api.py
    │   └── utils_test.py
    ├── sequence_test.py
    ├── test_api.py
    ├── test_documentation.py
    └── text_test.py

/.coveragerc:
--------------------------------------------------------------------------------
1 | [report]
2 | fail_under = 85
3 | show_missing = True
4 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/a--image.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: a) Issues related to images
3 | about: Select this if your issue is related to images (default).
4 | labels: image
5 | ---
6 | 
7 | Please make sure that the boxes below are checked before you submit your issue.
8 | If your issue is an **implementation question**, please ask your question on [StackOverflow](http://stackoverflow.com/questions/tagged/keras) or [on the Keras Slack channel](https://keras-slack-autojoin.herokuapp.com/) instead of opening a GitHub issue.
9 | 
10 | Thank you!
11 | 
12 | - [ ] Check that you are up-to-date with the master branch of keras-preprocessing. You can update with:
13 | `pip install git+git://github.com/keras-team/keras-preprocessing.git --upgrade --no-deps`
14 | 
15 | - [ ] Provide a link to a GitHub Gist of a Python script that can reproduce your issue (or just copy the script here if it is short).
16 | 
17 | 
18 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/b--text.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: b) Issues related to NLP
3 | about: Select this if your issue is related to natural language processing (NLP).
4 | labels: text
5 | ---
6 | 
7 | Please make sure that the boxes below are checked before you submit your issue.
8 | If your issue is an **implementation question**, please ask your question on [StackOverflow](http://stackoverflow.com/questions/tagged/keras) or [on the Keras Slack channel](https://keras-slack-autojoin.herokuapp.com/) instead of opening a GitHub issue.
9 | 
10 | Thank you!
11 | 
12 | - [ ] Check that you are up-to-date with the master branch of keras-preprocessing. You can update with:
13 | `pip install git+git://github.com/keras-team/keras-preprocessing.git --upgrade --no-deps`
14 | 
15 | - [ ] Provide a link to a GitHub Gist of a Python script that can reproduce your issue (or just copy the script here if it is short).
16 | 
17 | 
18 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_Store
2 | *.pyc
3 | dist/*
4 | build/*
5 | tags
6 | Keras_Preprocessing.egg-info
7 | 
8 | # test-related
9 | .coverage
10 | .cache
11 | .pytest_cache
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 | language: python
3 | matrix:
4 |   include:
5 |     # check code style and Python 3.6
6 |     - python: 3.6
7 |       env: TEST_MODE=PEP8
8 |     # run tests with keras from source and Python 3.6
9 |     - python: 3.6
10 |       env: KERAS_HEAD=true TEST_MODE=TESTS
11 |     # run tests with keras from PyPI and Python 3.6
12 |     - python: 3.6
13 |       env: TEST_MODE=TESTS
14 |     # run import test and Python 3.6
15 |     - python: 3.6
16 |       env: TEST_MODE=IMPORTS
17 | 
18 | 
19 | before_install:
20 |   - sudo apt-get update
21 |   - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
22 |   - bash miniconda.sh -b -p $HOME/miniconda
23 |   - export PATH="$HOME/miniconda/bin:$PATH"
24 |   - hash -r
25 |   - conda config --set always_yes yes --set changeps1 no
26 |   - conda update -q conda
27 |   # Useful for debugging any issues with conda
28 |   - conda info -a
29 |   - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
30 |   - source activate test-environment
31 | 
32 | install:
33 |   - if [[ $KERAS_HEAD == "true" ]]; then
34 |       pip install --no-deps git+https://github.com/keras-team/keras.git --upgrade;
35 |     fi
36 |   - if [[ "$TEST_MODE" == "PEP8" ]]; then
37 |       pip install -e .[pep8];
38 |     elif [[ "$TEST_MODE" == "TESTS" ]]; then
39 |       pip install -e .[tests];
40 |     elif [[ "$TEST_MODE" == "IMPORTS" ]]; then
41 |       pip install .;
42 |     fi
43 | 
44 | script:
45 |   - if [[ "$TEST_MODE" == "PEP8" ]]; then
46 |       flake8 -v --count;
47 |     elif [[ "$TEST_MODE" == "TESTS" ]]; then
48 |       py.test --cov-config .coveragerc --cov=keras_preprocessing tests;
49 |     elif [[ "$TEST_MODE" == "IMPORTS" ]]; then
50 |       python -c "import keras_preprocessing; from keras_preprocessing import image; from keras_preprocessing import sequence; from keras_preprocessing import text";
51 |     fi
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # On GitHub Issues and Pull Requests
2 | 
3 | Found a bug? Want to contribute changes to the codebase? Make sure to read this first.
4 | 
5 | ## Update Your Environment
6 | 
7 | To easily update Keras: `pip install git+https://www.github.com/keras-team/keras.git --upgrade`
8 | 
9 | To easily update Keras-Preprocessing: `pip install git+https://www.github.com/keras-team/keras-preprocessing.git --upgrade`
10 | 
11 | To easily update Theano: `pip install git+git://github.com/Theano/Theano.git --upgrade`
12 | 
13 | To update TensorFlow: See [TensorFlow Installation instructions](https://github.com/tensorflow/tensorflow#installation)
14 | 
15 | ## Bug reporting
16 | 
17 | Your code doesn't work, **and you have determined that the issue lies with Keras-Preprocessing**? Follow these steps to report a bug.
18 | 
19 | 1. Your bug may already be fixed. Make sure to update to the current Keras master branch and Keras-Preprocessing master branch, as well as the latest Theano/TensorFlow master branch.
20 | 
21 | 2. [Search for similar issues](https://github.com/keras-team/keras-preprocessing/issues?utf8=%E2%9C%93&q=is%3Aissue). It's possible somebody has encountered this bug already. Still having a problem? Open an issue on GitHub to let us know.
22 | 
23 | 3. Make sure you provide us with useful information about your configuration: what OS are you using? What Keras backend are you using? Are you running on GPU? If so, what are your versions of CUDA and cuDNN? What is your GPU?
24 | 
25 | 4. Provide us with a script to reproduce the issue. This script should be runnable as-is and should not require external data download (use randomly generated data if you need to run a model on some test data). We recommend that you use GitHub Gists to post your code. Any issue that cannot be reproduced is likely to be closed.
26 | 
27 | 5. If possible, take a stab at fixing the bug yourself, if you can!
28 | 
29 | The more information you provide, the easier it is for us to validate that there is a bug and the faster we'll be able to take action. If you want your issue to be resolved quickly, following the steps above is crucial.
30 | 
31 | ## Pull Requests
32 | 
33 | We love pull requests. Here's a quick guide:
34 | 
35 | 1. If your PR introduces a change in functionality, make sure you start by opening an issue to discuss whether the change should be made, and how to handle it. This will save you from having your PR closed down the road! Of course, if your PR is a simple bug fix, you don't need to do that.
36 | 
37 | 2. Ensure that your environment (Keras, Keras-Preprocessing, and your backend) is up to date. See "Update Your Environment". Create a new branch for your changes.
38 | 
39 | 3. Write the code (or get others to write it). This is the hard part!
40 | 
41 | 4. Make sure any new function or class you introduce has proper docstrings. Make sure any code you touch still has up-to-date docstrings and documentation. **Docstring style should be respected.** In particular, they should be formatted in MarkDown, and there should be sections for `Arguments`, `Returns`, `Raises` (if applicable). Look at other docstrings in the codebase for examples.
42 | 
43 | 5. Write tests. Your code should have full unit test coverage. If you want to see your PR merged promptly, this is crucial. If your PR is a bug fix, it is advisable to add a new test, which, without your fix in this PR, would have failed.
44 | 
45 | 6. Run our test suite locally. It's easy: from the Keras-Preprocessing folder, simply run: `py.test tests/`.
46 |    - You will need to install the test requirements as well: `pip install -e .[tests]`.
47 | 
48 | 7. Make sure all tests are passing:
49 |    - with the Theano backend, on Python 2.7 and Python 3.6. Make sure you have the development version of Theano.
50 |    - with the TensorFlow backend, on Python 2.7 and Python 3.6. Make sure you have the development version of TensorFlow.
51 |    - with the CNTK backend, on Python 2.7 and Python 3.6. Make sure you have the development version of CNTK.
52 |    - **Please Note:** all tests run on top of the very latest Keras master branch.
53 | 
54 | 8. We use PEP8 syntax conventions, but we aren't dogmatic when it comes to line length. Make sure your lines stay reasonably sized, though. To make your life easier, we recommend running a PEP8 linter:
55 |    - Install PEP8 packages: `pip install pep8 pytest-pep8 autopep8`
56 |    - Run a standalone PEP8 check: `py.test --pep8 -m pep8`
57 |    - You can automatically fix some PEP8 errors by running: `autopep8 -i --select <errors> <file>`, for example: `autopep8 -i --select E128 tests/keras/backend/test_backends.py`
58 | 
59 | 9. When committing, use appropriate, descriptive commit messages. Make sure that your branch history is not a string of "bug fix", "fix", "oops", etc. When submitting your PR, squash your commits into a single commit with an appropriate commit message, to make sure the project history stays clean and readable. See ['rebase and squash'](http://rebaseandsqua.sh/) for technical help on how to squash your commits.
60 | 
61 | 10. Update the documentation. If introducing new functionality, make sure you include code snippets demonstrating the usage of your new feature.
62 | 
63 | 11. Submit your PR. If your changes have been approved in a previous discussion, and if you have complete (and passing) unit tests, your PR is likely to be merged promptly.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | COPYRIGHT
2 | 
3 | Copyright (c) 2015 - 2018, the respective contributors.
4 | All rights reserved.
5 | 
6 | Each contributor holds copyright over their respective contributions.
7 | The project versioning (Git) records all such contribution source information.
8 | The initial code of this repository came from https://github.com/keras-team/keras
9 | (the Keras repository), hence, for author information regarding commits
10 | that occurred earlier than the first commit in the present repository,
11 | please see the original Keras repository.
12 | 
13 | LICENSE
14 | 
15 | The MIT License (MIT)
16 | 
17 | Permission is hereby granted, free of charge, to any person obtaining a copy
18 | of this software and associated documentation files (the "Software"), to deal
19 | in the Software without restriction, including without limitation the rights
20 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
21 | copies of the Software, and to permit persons to whom the Software is
22 | furnished to do so, subject to the following conditions:
23 | 
24 | The above copyright notice and this permission notice shall be included in all
25 | copies or substantial portions of the Software.
26 | 
27 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
28 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
29 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
30 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
31 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
33 | SOFTWARE.
34 | 
35 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include CONTRIBUTING.md
4 | graft tests
--------------------------------------------------------------------------------
/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ### Summary
2 | 
3 | ### Related Issues
4 | 
5 | ### PR Overview
6 | 
7 | - [ ] This PR requires new unit tests [y/n] (make sure tests are included)
8 | - [ ] This PR requires updating the documentation [y/n] (make sure the docs are up-to-date)
9 | - [ ] This PR is backwards compatible [y/n]
10 | - [ ] This PR changes the current API [y/n] (all API changes need to be approved by fchollet)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Keras Preprocessing
2 | 
3 | ⚠️ This GitHub repository is now deprecated -- all Keras Preprocessing symbols have
4 | moved into the core Keras [repository](https://github.com/keras-team/keras)
5 | and the TensorFlow [`pip` package](https://www.tensorflow.org/install). All code
6 | changes and discussion should move to the Keras repository.
7 | 
8 | For users looking for a place to start preprocessing data, consult the
9 | [preprocessing layers guide](https://keras.io/guides/preprocessing_layers/)
10 | and refer to the [data loading utilities API](https://keras.io/api/data_loading/).
--------------------------------------------------------------------------------
/keras_preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | """Enables dynamic setting of underlying Keras module.
2 | """
3 | 
4 | _KERAS_BACKEND = None
5 | _KERAS_UTILS = None
6 | 
7 | 
8 | def set_keras_submodules(backend, utils):
9 |     # Deprecated, will be removed in the future.
10 |     global _KERAS_BACKEND
11 |     global _KERAS_UTILS
12 |     _KERAS_BACKEND = backend
13 |     _KERAS_UTILS = utils
14 | 
15 | 
16 | def get_keras_submodule(name):
17 |     # Deprecated, will be removed in the future.
18 |     if name not in {'backend', 'utils'}:
19 |         raise ImportError(
20 |             'Can only retrieve "backend" and "utils". '
21 |             'Requested: %s' % name)
22 |     if _KERAS_BACKEND is None:
23 |         raise ImportError('You need to first `import keras` '
24 |                           'in order to use `keras_preprocessing`. '
25 |                           'For instance, you can do:\n\n'
26 |                           '```\n'
27 |                           'import keras\n'
28 |                           'from keras_preprocessing import image\n'
29 |                           '```\n\n'
30 |                           'Or, preferably, this equivalent formulation:\n\n'
31 |                           '```\n'
32 |                           'from keras import preprocessing\n'
33 |                           '```\n')
34 |     if name == 'backend':
35 |         return _KERAS_BACKEND
36 |     elif name == 'utils':
37 |         return _KERAS_UTILS
38 | 
39 | 
40 | __version__ = '1.1.2'
--------------------------------------------------------------------------------
/keras_preprocessing/image/__init__.py:
--------------------------------------------------------------------------------
1 | """Utilities for real-time image data augmentation.
2 | """ 3 | # flake8: noqa:F401 4 | from .affine_transformations import * 5 | from .dataframe_iterator import DataFrameIterator 6 | from .directory_iterator import DirectoryIterator 7 | from .image_data_generator import ImageDataGenerator 8 | from .iterator import Iterator 9 | from .numpy_array_iterator import NumpyArrayIterator 10 | from .utils import * 11 | -------------------------------------------------------------------------------- /keras_preprocessing/image/affine_transformations.py: -------------------------------------------------------------------------------- 1 | """Utilities for performing affine transformations on image data. 2 | """ 3 | import numpy as np 4 | 5 | from .utils import array_to_img, img_to_array 6 | 7 | try: 8 | import scipy 9 | # scipy.ndimage cannot be accessed until explicitly imported 10 | from scipy import ndimage 11 | except ImportError: 12 | scipy = None 13 | 14 | try: 15 | from PIL import Image as pil_image 16 | from PIL import ImageEnhance 17 | except ImportError: 18 | pil_image = None 19 | ImageEnhance = None 20 | 21 | 22 | def flip_axis(x, axis): 23 | x = np.asarray(x).swapaxes(axis, 0) 24 | x = x[::-1, ...] 25 | x = x.swapaxes(0, axis) 26 | return x 27 | 28 | 29 | def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0, 30 | fill_mode='nearest', cval=0., interpolation_order=1): 31 | """Performs a random rotation of a Numpy image tensor. 32 | 33 | # Arguments 34 | x: Input tensor. Must be 3D. 35 | rg: Rotation range, in degrees. 36 | row_axis: Index of axis for rows in the input tensor. 37 | col_axis: Index of axis for columns in the input tensor. 38 | channel_axis: Index of axis for channels in the input tensor. 39 | fill_mode: Points outside the boundaries of the input 40 | are filled according to the given mode 41 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 42 | cval: Value used for points outside the boundaries 43 | of the input if `mode='constant'`. 44 | interpolation_order: int, order of spline interpolation. 45 | see `ndimage.interpolation.affine_transform` 46 | 47 | # Returns 48 | Rotated Numpy image tensor. 49 | """ 50 | theta = np.random.uniform(-rg, rg) 51 | x = apply_affine_transform(x, 52 | theta=theta, 53 | row_axis=row_axis, 54 | col_axis=col_axis, 55 | channel_axis=channel_axis, 56 | fill_mode=fill_mode, 57 | cval=cval, 58 | order=interpolation_order) 59 | return x 60 | 61 | 62 | def random_shift(x, wrg, hrg, row_axis=1, col_axis=2, channel_axis=0, 63 | fill_mode='nearest', cval=0., interpolation_order=1): 64 | """Performs a random spatial shift of a Numpy image tensor. 65 | 66 | # Arguments 67 | x: Input tensor. Must be 3D. 68 | wrg: Width shift range, as a float fraction of the width. 69 | hrg: Height shift range, as a float fraction of the height. 70 | row_axis: Index of axis for rows in the input tensor. 71 | col_axis: Index of axis for columns in the input tensor. 72 | channel_axis: Index of axis for channels in the input tensor. 73 | fill_mode: Points outside the boundaries of the input 74 | are filled according to the given mode 75 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 76 | cval: Value used for points outside the boundaries 77 | of the input if `mode='constant'`. 78 | interpolation_order: int, order of spline interpolation. 79 | see `ndimage.interpolation.affine_transform` 80 | 81 | # Returns 82 | Shifted Numpy image tensor. 
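
    # Example
    A minimal usage sketch (not part of the original docstring): it assumes a
    channels-last input, hence `row_axis=0, col_axis=1, channel_axis=2`, that
    SciPy is installed, and uses randomly generated image values:

    ```python
    import numpy as np

    # hypothetical 64x64 RGB image with values in [0, 255]
    img = np.random.uniform(0, 255, (64, 64, 3))
    # shift randomly by up to 10% of the image width and height
    shifted = random_shift(img, wrg=0.1, hrg=0.1,
                           row_axis=0, col_axis=1, channel_axis=2)
    ```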
83 | """ 84 | h, w = x.shape[row_axis], x.shape[col_axis] 85 | tx = np.random.uniform(-hrg, hrg) * h 86 | ty = np.random.uniform(-wrg, wrg) * w 87 | x = apply_affine_transform(x, 88 | tx=tx, 89 | ty=ty, 90 | row_axis=row_axis, 91 | col_axis=col_axis, 92 | channel_axis=channel_axis, 93 | fill_mode=fill_mode, 94 | cval=cval, 95 | order=interpolation_order) 96 | return x 97 | 98 | 99 | def random_shear(x, intensity, row_axis=1, col_axis=2, channel_axis=0, 100 | fill_mode='nearest', cval=0., interpolation_order=1): 101 | """Performs a random spatial shear of a Numpy image tensor. 102 | 103 | # Arguments 104 | x: Input tensor. Must be 3D. 105 | intensity: Transformation intensity in degrees. 106 | row_axis: Index of axis for rows in the input tensor. 107 | col_axis: Index of axis for columns in the input tensor. 108 | channel_axis: Index of axis for channels in the input tensor. 109 | fill_mode: Points outside the boundaries of the input 110 | are filled according to the given mode 111 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 112 | cval: Value used for points outside the boundaries 113 | of the input if `mode='constant'`. 114 | interpolation_order: int, order of spline interpolation. 115 | see `ndimage.interpolation.affine_transform` 116 | 117 | # Returns 118 | Sheared Numpy image tensor. 119 | """ 120 | shear = np.random.uniform(-intensity, intensity) 121 | x = apply_affine_transform(x, 122 | shear=shear, 123 | row_axis=row_axis, 124 | col_axis=col_axis, 125 | channel_axis=channel_axis, 126 | fill_mode=fill_mode, 127 | cval=cval, 128 | order=interpolation_order) 129 | return x 130 | 131 | 132 | def random_zoom(x, zoom_range, row_axis=1, col_axis=2, channel_axis=0, 133 | fill_mode='nearest', cval=0., interpolation_order=1): 134 | """Performs a random spatial zoom of a Numpy image tensor. 135 | 136 | # Arguments 137 | x: Input tensor. Must be 3D. 138 | zoom_range: Tuple of floats; zoom range for width and height. 139 | row_axis: Index of axis for rows in the input tensor. 140 | col_axis: Index of axis for columns in the input tensor. 141 | channel_axis: Index of axis for channels in the input tensor. 142 | fill_mode: Points outside the boundaries of the input 143 | are filled according to the given mode 144 | (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). 145 | cval: Value used for points outside the boundaries 146 | of the input if `mode='constant'`. 147 | interpolation_order: int, order of spline interpolation. 148 | see `ndimage.interpolation.affine_transform` 149 | 150 | # Returns 151 | Zoomed Numpy image tensor. 152 | 153 | # Raises 154 | ValueError: if `zoom_range` isn't a tuple. 155 | """ 156 | if len(zoom_range) != 2: 157 | raise ValueError('`zoom_range` should be a tuple or list of two' 158 | ' floats. Received: %s' % (zoom_range,)) 159 | 160 | if zoom_range[0] == 1 and zoom_range[1] == 1: 161 | zx, zy = 1, 1 162 | else: 163 | zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2) 164 | x = apply_affine_transform(x, 165 | zx=zx, 166 | zy=zy, 167 | row_axis=row_axis, 168 | col_axis=col_axis, 169 | channel_axis=channel_axis, 170 | fill_mode=fill_mode, 171 | cval=cval, 172 | order=interpolation_order) 173 | return x 174 | 175 | 176 | def apply_channel_shift(x, intensity, channel_axis=0): 177 | """Performs a channel shift. 178 | 179 | # Arguments 180 | x: Input tensor. Must be 3D. 181 | intensity: Transformation intensity. 182 | channel_axis: Index of axis for channels in the input tensor. 183 | 184 | # Returns 185 | Numpy image tensor. 
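
    # Example
    A minimal sketch (an added illustration, not source text), assuming a
    channels-first input (`channel_axis=0`); shifted values are clipped back
    to the input's original value range:

    ```python
    import numpy as np

    # hypothetical 3x32x32 image with values in [0, 255]
    img = np.random.uniform(0, 255, (3, 32, 32))
    shifted = apply_channel_shift(img, intensity=20., channel_axis=0)
    ```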
186 | 
187 |     """
188 |     x = np.rollaxis(x, channel_axis, 0)
189 |     min_x, max_x = np.min(x), np.max(x)
190 |     channel_images = [
191 |         np.clip(x_channel + intensity,
192 |                 min_x,
193 |                 max_x)
194 |         for x_channel in x]
195 |     x = np.stack(channel_images, axis=0)
196 |     x = np.rollaxis(x, 0, channel_axis + 1)
197 |     return x
198 | 
199 | 
200 | def random_channel_shift(x, intensity_range, channel_axis=0):
201 |     """Performs a random channel shift.
202 | 
203 |     # Arguments
204 |         x: Input tensor. Must be 3D.
205 |         intensity_range: Transformation intensity.
206 |         channel_axis: Index of axis for channels in the input tensor.
207 | 
208 |     # Returns
209 |         Numpy image tensor.
210 |     """
211 |     intensity = np.random.uniform(-intensity_range, intensity_range)
212 |     return apply_channel_shift(x, intensity, channel_axis=channel_axis)
213 | 
214 | 
215 | def apply_brightness_shift(x, brightness, scale=True):
216 |     """Performs a brightness shift.
217 | 
218 |     # Arguments
219 |         x: Input tensor. Must be 3D.
220 |         brightness: Float. The new brightness value.
221 |         scale: Whether to rescale the image such that minimum and maximum values
222 |             are 0 and 255 respectively.
223 |             Default: True.
224 | 
225 |     # Returns
226 |         Numpy image tensor.
227 | 
228 |     # Raises
229 |         ImportError: if PIL is not available.
230 |     """
231 |     if ImageEnhance is None:
232 |         raise ImportError('Using brightness shifts requires PIL. '
233 |                           'Install PIL or Pillow.')
234 |     x_min, x_max = np.min(x), np.max(x)
235 |     local_scale = (x_min < 0) or (x_max > 255)
236 |     x = array_to_img(x, scale=local_scale or scale)
237 |     brightness_enhancer = ImageEnhance.Brightness(x)
238 |     x = brightness_enhancer.enhance(brightness)
239 |     x = img_to_array(x)
240 |     if not scale and local_scale:
241 |         x = x / 255 * (x_max - x_min) + x_min
242 |     return x
243 | 
244 | 
245 | def random_brightness(x, brightness_range, scale=True):
246 |     """Performs a random brightness shift.
247 | 
248 |     # Arguments
249 |         x: Input tensor. Must be 3D.
250 |         brightness_range: Tuple of floats; brightness range.
251 |         scale: Whether to rescale the image such that minimum and maximum values
252 |             are 0 and 255 respectively.
253 |             Default: True.
254 | 
255 |     # Returns
256 |         Numpy image tensor.
257 | 
258 |     # Raises
259 |         ValueError: if `brightness_range` isn't a tuple.
260 |     """
261 |     if len(brightness_range) != 2:
262 |         raise ValueError(
263 |             '`brightness_range` should be a tuple or list of two floats. '
264 |             'Received: %s' % (brightness_range,))
265 | 
266 |     u = np.random.uniform(brightness_range[0], brightness_range[1])
267 |     return apply_brightness_shift(x, u, scale)
268 | 
269 | 
270 | def transform_matrix_offset_center(matrix, x, y):
271 |     o_x = float(x) / 2 - 0.5
272 |     o_y = float(y) / 2 - 0.5
273 |     offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
274 |     reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
275 |     transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
276 |     return transform_matrix
277 | 
278 | 
279 | def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
280 |                            row_axis=1, col_axis=2, channel_axis=0,
281 |                            fill_mode='nearest', cval=0., order=1):
282 |     """Applies an affine transformation specified by the parameters given.
283 | 
284 |     # Arguments
285 |         x: 3D numpy array - a 2D image with one or more channels.
286 |         theta: Rotation angle in degrees.
287 |         tx: Width shift.
288 |         ty: Height shift.
289 |         shear: Shear angle in degrees.
290 |         zx: Zoom in x direction.
291 |         zy: Zoom in y direction.
292 |         row_axis: Index of axis for rows (aka Y axis) in the input image.
293 |             Direction: top to bottom.
294 |         col_axis: Index of axis for columns (aka X axis) in the input image.
295 |             Direction: left to right.
296 |         channel_axis: Index of axis for channels in the input image.
297 |         fill_mode: Points outside the boundaries of the input
298 |             are filled according to the given mode
299 |             (one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
300 |         cval: Value used for points outside the boundaries
301 |             of the input if `mode='constant'`.
302 |         order: int, order of interpolation.
303 | 
304 |     # Returns
305 |         The transformed version of the input.
306 |     """
307 |     if scipy is None:
308 |         raise ImportError('Image transformations require SciPy. '
309 |                           'Install SciPy.')
310 | 
311 |     # Input sanity checks:
312 |     # 1. x must be a 2D image with one or more channels (i.e., a 3D tensor)
313 |     # 2. channels must be either the first or the last dimension
314 |     if np.unique([row_axis, col_axis, channel_axis]).size != 3:
315 |         raise ValueError("'row_axis', 'col_axis', and 'channel_axis'"
316 |                          " must be distinct")
317 | 
318 |     # TODO: shall we support negative indices?
319 |     valid_indices = set([0, 1, 2])
320 |     actual_indices = set([row_axis, col_axis, channel_axis])
321 |     if actual_indices != valid_indices:
322 |         raise ValueError(
323 |             f"Invalid axis indices: {actual_indices - valid_indices}")
324 | 
325 |     if x.ndim != 3:
326 |         raise ValueError("Input arrays must be multi-channel 2D images.")
327 |     if channel_axis not in [0, 2]:
328 |         raise ValueError("Channels must be in the first or last dimension.")
329 | 
330 |     transform_matrix = None
331 |     if theta != 0:
332 |         theta = np.deg2rad(theta)
333 |         rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
334 |                                     [np.sin(theta), np.cos(theta), 0],
335 |                                     [0, 0, 1]])
336 |         transform_matrix = rotation_matrix
337 | 
338 |     if tx != 0 or ty != 0:
339 |         shift_matrix = np.array([[1, 0, tx],
340 |                                  [0, 1, ty],
341 |                                  [0, 0, 1]])
342 |         if transform_matrix is None:
343 |             transform_matrix = shift_matrix
344 |         else:
345 |             transform_matrix = np.dot(transform_matrix, shift_matrix)
346 | 
347 |     if shear != 0:
348 |         shear = np.deg2rad(shear)
349 |         shear_matrix = np.array([[1, -np.sin(shear), 0],
350 |                                  [0, np.cos(shear), 0],
351 |                                  [0, 0, 1]])
352 |         if transform_matrix is None:
353 |             transform_matrix = shear_matrix
354 |         else:
355 |             transform_matrix = np.dot(transform_matrix, shear_matrix)
356 | 
357 |     if zx != 1 or zy != 1:
358 |         zoom_matrix = np.array([[zx, 0, 0],
359 |                                 [0, zy, 0],
360 |                                 [0, 0, 1]])
361 |         if transform_matrix is None:
362 |             transform_matrix = zoom_matrix
363 |         else:
364 |             transform_matrix = np.dot(transform_matrix, zoom_matrix)
365 | 
366 |     if transform_matrix is not None:
367 |         h, w = x.shape[row_axis], x.shape[col_axis]
368 |         transform_matrix = transform_matrix_offset_center(
369 |             transform_matrix, h, w)
370 |         x = np.rollaxis(x, channel_axis, 0)
371 | 
372 |         # Matrix construction assumes that coordinates are x, y (in that order).
373 |         # However, regular numpy arrays use y,x (aka i,j) indexing.
374 |         # Possible solution is:
375 |         # 1. Swap the x and y axes.
376 |         # 2. Apply transform.
377 |         # 3. Swap the x and y axes again to restore image-like data ordering.
378 |         # Mathematically, it is equivalent to the following transformation:
379 |         # M' = PMP, where P is the permutation matrix and M is the original
380 |         # transformation matrix.
381 |         if col_axis > row_axis:
382 |             transform_matrix[:, [0, 1]] = transform_matrix[:, [1, 0]]
383 |             transform_matrix[[0, 1]] = transform_matrix[[1, 0]]
384 |         final_affine_matrix = transform_matrix[:2, :2]
385 |         final_offset = transform_matrix[:2, 2]
386 | 
387 |         channel_images = [ndimage.interpolation.affine_transform(
388 |             x_channel,
389 |             final_affine_matrix,
390 |             final_offset,
391 |             order=order,
392 |             mode=fill_mode,
393 |             cval=cval) for x_channel in x]
394 |         x = np.stack(channel_images, axis=0)
395 |         x = np.rollaxis(x, 0, channel_axis + 1)
396 |     return x
397 | 
--------------------------------------------------------------------------------
/keras_preprocessing/image/dataframe_iterator.py:
--------------------------------------------------------------------------------
1 | """Utilities for real-time data augmentation on image data.
2 | """
3 | import os
4 | import warnings
5 | from collections import OrderedDict
6 | 
7 | import numpy as np
8 | 
9 | from .iterator import BatchFromFilesMixin, Iterator
10 | from .utils import validate_filename
11 | 
12 | 
13 | class DataFrameIterator(BatchFromFilesMixin, Iterator):
14 |     """Iterator capable of reading images from a directory on disk
15 |     through a dataframe.
16 | 
17 |     # Arguments
18 |         dataframe: Pandas dataframe containing the filepaths relative to
19 |             `directory` (or absolute paths if `directory` is None) of the
20 |             images in a string column. It should include other columns
21 |             depending on the `class_mode`:
22 |             - if `class_mode` is `"categorical"` (default value) it must
23 |                 include the `y_col` column with the class(es) of each image.
24 |                 Values in the column can be a string or list/tuple for a
25 |                 single class, or a list/tuple for multiple classes.
26 |             - if `class_mode` is `"binary"` or `"sparse"` it must include
27 |                 the given `y_col` column with class values as strings.
28 |             - if `class_mode` is `"raw"` or `"multi_output"` it should contain
29 |                 the columns specified in `y_col`.
30 |             - if `class_mode` is `"input"` or `None` no extra column is needed.
31 |         directory: string, path to the directory to read images from. If `None`,
32 |             data in `x_col` column should be absolute paths.
33 |         image_data_generator: Instance of `ImageDataGenerator` to use for
34 |             random transformations and normalization. If None, no transformations
35 |             and normalizations are made.
36 |         x_col: string, column in `dataframe` that contains the filenames (or
37 |             absolute paths if `directory` is `None`).
38 |         y_col: string or list, column(s) in `dataframe` that contain the target data.
39 |         weight_col: string, column in `dataframe` that contains the sample
40 |             weights. Default: `None`.
41 |         target_size: tuple of integers, dimensions to resize input images to.
42 |         color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
43 |             Color mode to read images.
44 |         classes: Optional list of strings, classes to use (e.g. `["dogs", "cats"]`).
45 |             If None, all classes in `y_col` will be used.
46 |         class_mode: one of "binary", "categorical", "input", "multi_output",
47 |             "raw", "sparse" or None. Default: "categorical".
48 |             Mode for yielding the targets:
49 |             - `"binary"`: 1D numpy array of binary labels,
50 |             - `"categorical"`: 2D numpy array of one-hot encoded labels.
51 |                 Supports multi-label output.
52 | - `"input"`: images identical to input images (mainly used to 53 | work with autoencoders), 54 | - `"multi_output"`: list with the values of the different columns, 55 | - `"raw"`: numpy array of values in `y_col` column(s), 56 | - `"sparse"`: 1D numpy array of integer labels, 57 | - `None`, no targets are returned (the generator will only yield 58 | batches of image data, which is useful to use in 59 | `model.predict_generator()`). 60 | batch_size: Integer, size of a batch. 61 | shuffle: Boolean, whether to shuffle the data between epochs. 62 | seed: Random seed for data shuffling. 63 | data_format: String, one of `channels_first`, `channels_last`. 64 | save_to_dir: Optional directory where to save the pictures 65 | being yielded, in a viewable format. This is useful 66 | for visualizing the random transformations being 67 | applied, for debugging purposes. 68 | save_prefix: String prefix to use for saving sample 69 | images (if `save_to_dir` is set). 70 | save_format: Format to use for saving sample images 71 | (if `save_to_dir` is set). 72 | subset: Subset of data (`"training"` or `"validation"`) if 73 | validation_split is set in ImageDataGenerator. 74 | interpolation: Interpolation method used to resample the image if the 75 | target size is different from that of the loaded image. 76 | Supported methods are "nearest", "bilinear", and "bicubic". 77 | If PIL version 1.1.3 or newer is installed, "lanczos" is also 78 | supported. If PIL version 3.4.0 or newer is installed, "box" and 79 | "hamming" are also supported. By default, "nearest" is used. 80 | keep_aspect_ratio: Boolean, whether to resize images to a target size 81 | without aspect ratio distortion. The image is cropped in the center 82 | with target aspect ratio before resizing. 83 | dtype: Dtype to use for the generated arrays. 84 | validate_filenames: Boolean, whether to validate image filenames in 85 | `x_col`. If `True`, invalid images will be ignored. Disabling this option 86 | can lead to speed-up in the instantiation of this class. Default: `True`. 
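
    # Example
    A minimal usage sketch (added here for illustration). In practice this
    iterator is usually created through
    `ImageDataGenerator.flow_from_dataframe()` rather than instantiated
    directly; the dataframe contents and the `data/` directory below are
    hypothetical:

    ```python
    import pandas as pd
    from keras_preprocessing.image import ImageDataGenerator

    df = pd.DataFrame({'filename': ['cat001.jpg', 'dog001.jpg'],
                       'class': ['cat', 'dog']})
    datagen = ImageDataGenerator(rescale=1. / 255)
    iterator = datagen.flow_from_dataframe(df, directory='data/',
                                           x_col='filename', y_col='class',
                                           class_mode='categorical',
                                           target_size=(256, 256))
    batch_x, batch_y = next(iterator)
    ```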
87 | """ 88 | allowed_class_modes = { 89 | 'binary', 'categorical', 'input', 'multi_output', 'raw', 'sparse', None 90 | } 91 | 92 | def __new__(cls, *args, **kwargs): 93 | try: 94 | from tensorflow.keras.utils import Sequence as TFSequence 95 | if TFSequence not in cls.__bases__: 96 | cls.__bases__ = cls.__bases__ + (TFSequence,) 97 | except ImportError: 98 | pass 99 | return super(DataFrameIterator, cls).__new__(cls) 100 | 101 | def __init__(self, 102 | dataframe, 103 | directory=None, 104 | image_data_generator=None, 105 | x_col="filename", 106 | y_col="class", 107 | weight_col=None, 108 | target_size=(256, 256), 109 | color_mode='rgb', 110 | classes=None, 111 | class_mode='categorical', 112 | batch_size=32, 113 | shuffle=True, 114 | seed=None, 115 | data_format='channels_last', 116 | save_to_dir=None, 117 | save_prefix='', 118 | save_format='png', 119 | subset=None, 120 | interpolation='nearest', 121 | keep_aspect_ratio=False, 122 | dtype='float32', 123 | validate_filenames=True): 124 | 125 | super(DataFrameIterator, self).set_processing_attrs(image_data_generator, 126 | target_size, 127 | color_mode, 128 | data_format, 129 | save_to_dir, 130 | save_prefix, 131 | save_format, 132 | subset, 133 | interpolation, 134 | keep_aspect_ratio) 135 | df = dataframe.copy() 136 | self.directory = directory or '' 137 | self.class_mode = class_mode 138 | self.dtype = dtype 139 | # check that inputs match the required class_mode 140 | self._check_params(df, x_col, y_col, weight_col, classes) 141 | if validate_filenames: # check which image files are valid and keep them 142 | df = self._filter_valid_filepaths(df, x_col) 143 | if class_mode not in ["input", "multi_output", "raw", None]: 144 | df, classes = self._filter_classes(df, y_col, classes) 145 | num_classes = len(classes) 146 | # build an index of all the unique classes 147 | self.class_indices = dict(zip(classes, range(len(classes)))) 148 | # retrieve only training or validation set 149 | if self.split: 150 | num_files = len(df) 151 | start = int(self.split[0] * num_files) 152 | stop = int(self.split[1] * num_files) 153 | df = df.iloc[start: stop, :] 154 | # get labels for each observation 155 | if class_mode not in ["input", "multi_output", "raw", None]: 156 | self.classes = self.get_classes(df, y_col) 157 | self.filenames = df[x_col].tolist() 158 | self._sample_weight = df[weight_col].values if weight_col else None 159 | 160 | if class_mode == "multi_output": 161 | self._targets = [np.array(df[col].tolist()) for col in y_col] 162 | if class_mode == "raw": 163 | self._targets = df[y_col].values 164 | self.samples = len(self.filenames) 165 | validated_string = 'validated' if validate_filenames else 'non-validated' 166 | if class_mode in ["input", "multi_output", "raw", None]: 167 | print('Found {} {} image filenames.' 168 | .format(self.samples, validated_string)) 169 | else: 170 | print('Found {} {} image filenames belonging to {} classes.' 
171 | .format(self.samples, validated_string, num_classes)) 172 | self._filepaths = [ 173 | os.path.join(self.directory, fname) for fname in self.filenames 174 | ] 175 | super(DataFrameIterator, self).__init__(self.samples, 176 | batch_size, 177 | shuffle, 178 | seed) 179 | 180 | def _check_params(self, df, x_col, y_col, weight_col, classes): 181 | # check class mode is one of the currently supported 182 | if self.class_mode not in self.allowed_class_modes: 183 | raise ValueError('Invalid class_mode: {}; expected one of: {}' 184 | .format(self.class_mode, self.allowed_class_modes)) 185 | # check that y_col has several column names if class_mode is multi_output 186 | if (self.class_mode == 'multi_output') and not isinstance(y_col, list): 187 | raise TypeError( 188 | 'If class_mode="{}", y_col must be a list. Received {}.' 189 | .format(self.class_mode, type(y_col).__name__) 190 | ) 191 | # check that filenames/filepaths column values are all strings 192 | if not all(df[x_col].apply(lambda x: isinstance(x, str))): 193 | raise TypeError('All values in column x_col={} must be strings.' 194 | .format(x_col)) 195 | # check labels are string if class_mode is binary or sparse 196 | if self.class_mode in {'binary', 'sparse'}: 197 | if not all(df[y_col].apply(lambda x: isinstance(x, str))): 198 | raise TypeError('If class_mode="{}", y_col="{}" column ' 199 | 'values must be strings.' 200 | .format(self.class_mode, y_col)) 201 | # check that if binary there are only 2 different classes 202 | if self.class_mode == 'binary': 203 | if classes: 204 | classes = set(classes) 205 | if len(classes) != 2: 206 | raise ValueError('If class_mode="binary" there must be 2 ' 207 | 'classes. {} class/es were given.' 208 | .format(len(classes))) 209 | elif df[y_col].nunique() != 2: 210 | raise ValueError('If class_mode="binary" there must be 2 classes. ' 211 | 'Found {} classes.'.format(df[y_col].nunique())) 212 | # check values are string, list or tuple if class_mode is categorical 213 | if self.class_mode == 'categorical': 214 | types = (str, list, tuple) 215 | if not all(df[y_col].apply(lambda x: isinstance(x, types))): 216 | raise TypeError('If class_mode="{}", y_col="{}" column ' 217 | 'values must be type string, list or tuple.' 218 | .format(self.class_mode, y_col)) 219 | # raise warning if classes are given but will be unused 220 | if classes and self.class_mode in {"input", "multi_output", "raw", None}: 221 | warnings.warn('`classes` will be ignored given the class_mode="{}"' 222 | .format(self.class_mode)) 223 | # check that if weight column that the values are numerical 224 | if weight_col and not issubclass(df[weight_col].dtype.type, np.number): 225 | raise TypeError('Column weight_col={} must be numeric.' 
226 |                             .format(weight_col))
227 | 
228 |     def get_classes(self, df, y_col):
229 |         labels = []
230 |         for label in df[y_col]:
231 |             if isinstance(label, (list, tuple)):
232 |                 labels.append([self.class_indices[lbl] for lbl in label])
233 |             else:
234 |                 labels.append(self.class_indices[label])
235 |         return labels
236 | 
237 |     @staticmethod
238 |     def _filter_classes(df, y_col, classes):
239 |         df = df.copy()
240 | 
241 |         def remove_classes(labels, classes):
242 |             if isinstance(labels, (list, tuple)):
243 |                 labels = [cls for cls in labels if cls in classes]
244 |                 return labels or None
245 |             elif isinstance(labels, str):
246 |                 return labels if labels in classes else None
247 |             else:
248 |                 raise TypeError(
249 |                     "Expected string, list or tuple but found {} in column {}"
250 |                     .format(type(labels), y_col)
251 |                 )
252 | 
253 |         if classes:
254 |             # prepare for membership lookup
255 |             classes = list(OrderedDict.fromkeys(classes).keys())
256 |             df[y_col] = df[y_col].apply(lambda x: remove_classes(x, classes))
257 |         else:
258 |             classes = set()
259 |             for v in df[y_col]:
260 |                 if isinstance(v, (list, tuple)):
261 |                     classes.update(v)
262 |                 else:
263 |                     classes.add(v)
264 |             classes = sorted(classes)
265 |         return df.dropna(subset=[y_col]), classes
266 | 
267 |     def _filter_valid_filepaths(self, df, x_col):
268 |         """Keep only dataframe rows with valid filenames
269 | 
270 |         # Arguments
271 |             df: Pandas dataframe containing filenames in a column
272 |             x_col: string, column in `df` that contains the filenames or filepaths
273 | 
274 |         # Returns
275 |             Dataframe containing only the rows with valid image filenames.
276 |         """
277 |         filepaths = df[x_col].map(
278 |             lambda fname: os.path.join(self.directory, fname)
279 |         )
280 |         mask = filepaths.apply(validate_filename, args=(self.white_list_formats,))
281 |         n_invalid = (~mask).sum()
282 |         if n_invalid:
283 |             warnings.warn(
284 |                 'Found {} invalid image filename(s) in x_col="{}". '
285 |                 'These filename(s) will be ignored.'
286 |                 .format(n_invalid, x_col)
287 |             )
288 |         return df[mask]
289 | 
290 |     @property
291 |     def filepaths(self):
292 |         return self._filepaths
293 | 
294 |     @property
295 |     def labels(self):
296 |         if self.class_mode in {"multi_output", "raw"}:
297 |             return self._targets
298 |         else:
299 |             return self.classes
300 | 
301 |     @property
302 |     def sample_weight(self):
303 |         return self._sample_weight
304 | 
--------------------------------------------------------------------------------
/keras_preprocessing/image/directory_iterator.py:
--------------------------------------------------------------------------------
1 | """Utilities for real-time data augmentation on image data.
2 | """
3 | import multiprocessing.pool
4 | import os
5 | 
6 | import numpy as np
7 | 
8 | from .iterator import BatchFromFilesMixin, Iterator
9 | from .utils import _list_valid_filenames_in_directory
10 | 
11 | 
12 | class DirectoryIterator(BatchFromFilesMixin, Iterator):
13 |     """Iterator capable of reading images from a directory on disk.
14 | 
15 |     # Arguments
16 |         directory: string, path to the directory to read images from.
17 |             Each subdirectory in this directory will be
18 |             considered to contain images from one class,
19 |             or alternatively you could specify class subdirectories
20 |             via the `classes` argument.
21 |         image_data_generator: Instance of `ImageDataGenerator`
22 |             to use for random transformations and normalization.
23 |         target_size: tuple of integers, dimensions to resize input images to.
24 |         color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
25 |             Color mode to read images.
26 | classes: Optional list of strings, names of subdirectories 27 | containing images from each class (e.g. `["dogs", "cats"]`). 28 | It will be computed automatically if not set. 29 | class_mode: Mode for yielding the targets: 30 | `"binary"`: binary targets (if there are only two classes), 31 | `"categorical"`: categorical targets, 32 | `"sparse"`: integer targets, 33 | `"input"`: targets are images identical to input images (mainly 34 | used to work with autoencoders), 35 | `None`: no targets get yielded (only input images are yielded). 36 | batch_size: Integer, size of a batch. 37 | shuffle: Boolean, whether to shuffle the data between epochs. 38 | If set to False, sorts the data in alphanumeric order. 39 | seed: Random seed for data shuffling. 40 | data_format: String, one of `channels_first`, `channels_last`. 41 | save_to_dir: Optional directory where to save the pictures 42 | being yielded, in a viewable format. This is useful 43 | for visualizing the random transformations being 44 | applied, for debugging purposes. 45 | save_prefix: String prefix to use for saving sample 46 | images (if `save_to_dir` is set). 47 | save_format: Format to use for saving sample images 48 | (if `save_to_dir` is set). 49 | follow_links: Boolean, follow symbolic links to subdirectories 50 | subset: Subset of data (`"training"` or `"validation"`) if 51 | validation_split is set in ImageDataGenerator. 52 | interpolation: Interpolation method used to resample the image if the 53 | target size is different from that of the loaded image. 54 | Supported methods are "nearest", "bilinear", and "bicubic". 55 | If PIL version 1.1.3 or newer is installed, "lanczos" is also 56 | supported. If PIL version 3.4.0 or newer is installed, "box" and 57 | "hamming" are also supported. By default, "nearest" is used. 58 | keep_aspect_ratio: Boolean, whether to resize images to a target size 59 | without aspect ratio distortion. The image is cropped in the center 60 | with target aspect ratio before resizing. 61 | dtype: Dtype to use for generated arrays. 62 | """ 63 | allowed_class_modes = {'categorical', 'binary', 'sparse', 'input', None} 64 | 65 | def __new__(cls, *args, **kwargs): 66 | try: 67 | from tensorflow.keras.utils import Sequence as TFSequence 68 | if TFSequence not in cls.__bases__: 69 | cls.__bases__ = cls.__bases__ + (TFSequence,) 70 | except ImportError: 71 | pass 72 | return super(DirectoryIterator, cls).__new__(cls) 73 | 74 | def __init__(self, 75 | directory, 76 | image_data_generator, 77 | target_size=(256, 256), 78 | color_mode='rgb', 79 | classes=None, 80 | class_mode='categorical', 81 | batch_size=32, 82 | shuffle=True, 83 | seed=None, 84 | data_format='channels_last', 85 | save_to_dir=None, 86 | save_prefix='', 87 | save_format='png', 88 | follow_links=False, 89 | subset=None, 90 | interpolation='nearest', 91 | keep_aspect_ratio=False, 92 | dtype='float32'): 93 | super(DirectoryIterator, self).set_processing_attrs(image_data_generator, 94 | target_size, 95 | color_mode, 96 | data_format, 97 | save_to_dir, 98 | save_prefix, 99 | save_format, 100 | subset, 101 | interpolation, 102 | keep_aspect_ratio) 103 | self.directory = directory 104 | self.classes = classes 105 | if class_mode not in self.allowed_class_modes: 106 | raise ValueError('Invalid class_mode: {}; expected one of: {}' 107 | .format(class_mode, self.allowed_class_modes)) 108 | self.class_mode = class_mode 109 | self.dtype = dtype 110 | # First, count the number of samples and classes. 
111 | self.samples = 0 112 | 113 | if not classes: 114 | classes = [] 115 | for subdir in sorted(os.listdir(directory)): 116 | if os.path.isdir(os.path.join(directory, subdir)): 117 | classes.append(subdir) 118 | self.num_classes = len(classes) 119 | self.class_indices = dict(zip(classes, range(len(classes)))) 120 | 121 | pool = multiprocessing.pool.ThreadPool() 122 | 123 | # Second, build an index of the images 124 | # in the different class subfolders. 125 | results = [] 126 | self.filenames = [] 127 | i = 0 128 | for dirpath in (os.path.join(directory, subdir) for subdir in classes): 129 | results.append( 130 | pool.apply_async(_list_valid_filenames_in_directory, 131 | (dirpath, self.white_list_formats, self.split, 132 | self.class_indices, follow_links))) 133 | classes_list = [] 134 | for res in results: 135 | classes, filenames = res.get() 136 | classes_list.append(classes) 137 | self.filenames += filenames 138 | self.samples = len(self.filenames) 139 | self.classes = np.zeros((self.samples,), dtype='int32') 140 | for classes in classes_list: 141 | self.classes[i:i + len(classes)] = classes 142 | i += len(classes) 143 | 144 | print('Found %d images belonging to %d classes.' % 145 | (self.samples, self.num_classes)) 146 | pool.close() 147 | pool.join() 148 | self._filepaths = [ 149 | os.path.join(self.directory, fname) for fname in self.filenames 150 | ] 151 | super(DirectoryIterator, self).__init__(self.samples, 152 | batch_size, 153 | shuffle, 154 | seed) 155 | 156 | @property 157 | def filepaths(self): 158 | return self._filepaths 159 | 160 | @property 161 | def labels(self): 162 | return self.classes 163 | 164 | @property # mixin needs this property to work 165 | def sample_weight(self): 166 | # no sample weights will be returned 167 | return None 168 | -------------------------------------------------------------------------------- /keras_preprocessing/image/iterator.py: -------------------------------------------------------------------------------- 1 | """Utilities for real-time data augmentation on image data. 2 | """ 3 | import os 4 | import threading 5 | 6 | import numpy as np 7 | 8 | from keras_preprocessing import get_keras_submodule 9 | 10 | try: 11 | IteratorType = get_keras_submodule('utils').Sequence 12 | except ImportError: 13 | IteratorType = object 14 | 15 | from .utils import array_to_img, img_to_array, load_img 16 | 17 | 18 | class Iterator(IteratorType): 19 | """Base class for image data iterators. 20 | 21 | Every `Iterator` must implement the `_get_batches_of_transformed_samples` 22 | method. 23 | 24 | # Arguments 25 | n: Integer, total number of samples in the dataset to loop over. 26 | batch_size: Integer, size of a batch. 27 | shuffle: Boolean, whether to shuffle the data between epochs. 28 | seed: Random seeding for data shuffling. 
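
    # Example
    A minimal sketch of a concrete subclass (added illustration;
    `ArrayIterator` is a hypothetical name). The only hook a subclass must
    implement is `_get_batches_of_transformed_samples`:

    ```python
    import numpy as np
    from keras_preprocessing.image import Iterator

    class ArrayIterator(Iterator):
        def __init__(self, x, batch_size=32, shuffle=True, seed=None):
            self.x = np.asarray(x)
            super(ArrayIterator, self).__init__(len(self.x), batch_size,
                                                shuffle, seed)

        def _get_batches_of_transformed_samples(self, index_array):
            # no augmentation here; just gather the requested samples
            return self.x[index_array]

    it = ArrayIterator(np.arange(10).reshape(10, 1), batch_size=4)
    first_batch = it[0]  # array of shape (4, 1)
    ```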
29 | """ 30 | white_list_formats = ('png', 'jpg', 'jpeg', 'bmp', 'ppm', 'tif', 'tiff') 31 | 32 | def __init__(self, n, batch_size, shuffle, seed): 33 | self.n = n 34 | self.batch_size = batch_size 35 | self.seed = seed 36 | self.shuffle = shuffle 37 | self.batch_index = 0 38 | self.total_batches_seen = 0 39 | self.lock = threading.Lock() 40 | self.index_array = None 41 | self.index_generator = self._flow_index() 42 | 43 | def _set_index_array(self): 44 | self.index_array = np.arange(self.n) 45 | if self.shuffle: 46 | self.index_array = np.random.permutation(self.n) 47 | 48 | def __getitem__(self, idx): 49 | if idx >= len(self): 50 | raise ValueError('Asked to retrieve element {idx}, ' 51 | 'but the Sequence ' 52 | 'has length {length}'.format(idx=idx, 53 | length=len(self))) 54 | if self.seed is not None: 55 | np.random.seed(self.seed + self.total_batches_seen) 56 | self.total_batches_seen += 1 57 | if self.index_array is None: 58 | self._set_index_array() 59 | index_array = self.index_array[self.batch_size * idx: 60 | self.batch_size * (idx + 1)] 61 | return self._get_batches_of_transformed_samples(index_array) 62 | 63 | def __len__(self): 64 | return (self.n + self.batch_size - 1) // self.batch_size # round up 65 | 66 | def on_epoch_end(self): 67 | self._set_index_array() 68 | 69 | def reset(self): 70 | self.batch_index = 0 71 | 72 | def _flow_index(self): 73 | # Ensure self.batch_index is 0. 74 | self.reset() 75 | while 1: 76 | if self.seed is not None: 77 | np.random.seed(self.seed + self.total_batches_seen) 78 | if self.batch_index == 0: 79 | self._set_index_array() 80 | 81 | if self.n == 0: 82 | # Avoiding modulo by zero error 83 | current_index = 0 84 | else: 85 | current_index = (self.batch_index * self.batch_size) % self.n 86 | if self.n > current_index + self.batch_size: 87 | self.batch_index += 1 88 | else: 89 | self.batch_index = 0 90 | self.total_batches_seen += 1 91 | yield self.index_array[current_index: 92 | current_index + self.batch_size] 93 | 94 | def __iter__(self): 95 | # Needed if we want to do something like: 96 | # for x, y in data_gen.flow(...): 97 | return self 98 | 99 | def __next__(self, *args, **kwargs): 100 | return self.next(*args, **kwargs) 101 | 102 | def next(self): 103 | """For python 2.x. 104 | 105 | # Returns 106 | The next batch. 107 | """ 108 | with self.lock: 109 | index_array = next(self.index_generator) 110 | # The transformation of images is not under thread lock 111 | # so it can be done in parallel 112 | return self._get_batches_of_transformed_samples(index_array) 113 | 114 | def _get_batches_of_transformed_samples(self, index_array): 115 | """Gets a batch of transformed samples. 116 | 117 | # Arguments 118 | index_array: Array of sample indices to include in batch. 119 | 120 | # Returns 121 | A batch of transformed samples. 122 | """ 123 | raise NotImplementedError 124 | 125 | 126 | class BatchFromFilesMixin(): 127 | """Adds methods related to getting batches from filenames 128 | 129 | It includes the logic to transform image files to batches. 130 | """ 131 | 132 | def set_processing_attrs(self, 133 | image_data_generator, 134 | target_size, 135 | color_mode, 136 | data_format, 137 | save_to_dir, 138 | save_prefix, 139 | save_format, 140 | subset, 141 | interpolation, 142 | keep_aspect_ratio): 143 | """Sets attributes to use later for processing files into a batch. 144 | 145 | # Arguments 146 | image_data_generator: Instance of `ImageDataGenerator` 147 | to use for random transformations and normalization. 
148 | target_size: tuple of integers, dimensions to resize input images to. 149 | color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. 150 | Color mode to read images. 151 | data_format: String, one of `channels_first`, `channels_last`. 152 | save_to_dir: Optional directory where to save the pictures 153 | being yielded, in a viewable format. This is useful 154 | for visualizing the random transformations being 155 | applied, for debugging purposes. 156 | save_prefix: String prefix to use for saving sample 157 | images (if `save_to_dir` is set). 158 | save_format: Format to use for saving sample images 159 | (if `save_to_dir` is set). 160 | subset: Subset of data (`"training"` or `"validation"`) if 161 | validation_split is set in ImageDataGenerator. 162 | interpolation: Interpolation method used to resample the image if the 163 | target size is different from that of the loaded image. 164 | Supported methods are "nearest", "bilinear", and "bicubic". 165 | If PIL version 1.1.3 or newer is installed, "lanczos" is also 166 | supported. If PIL version 3.4.0 or newer is installed, "box" and 167 | "hamming" are also supported. By default, "nearest" is used. 168 | """ 169 | self.image_data_generator = image_data_generator 170 | self.target_size = tuple(target_size) 171 | self.keep_aspect_ratio = keep_aspect_ratio 172 | if color_mode not in {'rgb', 'rgba', 'grayscale'}: 173 | raise ValueError('Invalid color mode:', color_mode, 174 | '; expected "rgb", "rgba", or "grayscale".') 175 | self.color_mode = color_mode 176 | self.data_format = data_format 177 | if self.color_mode == 'rgba': 178 | if self.data_format == 'channels_last': 179 | self.image_shape = self.target_size + (4,) 180 | else: 181 | self.image_shape = (4,) + self.target_size 182 | elif self.color_mode == 'rgb': 183 | if self.data_format == 'channels_last': 184 | self.image_shape = self.target_size + (3,) 185 | else: 186 | self.image_shape = (3,) + self.target_size 187 | else: 188 | if self.data_format == 'channels_last': 189 | self.image_shape = self.target_size + (1,) 190 | else: 191 | self.image_shape = (1,) + self.target_size 192 | self.save_to_dir = save_to_dir 193 | self.save_prefix = save_prefix 194 | self.save_format = save_format 195 | self.interpolation = interpolation 196 | if subset is not None: 197 | validation_split = self.image_data_generator._validation_split 198 | if subset == 'validation': 199 | split = (0, validation_split) 200 | elif subset == 'training': 201 | split = (validation_split, 1) 202 | else: 203 | raise ValueError( 204 | 'Invalid subset name: %s;' 205 | 'expected "training" or "validation"' % (subset,)) 206 | else: 207 | split = None 208 | self.split = split 209 | self.subset = subset 210 | 211 | def _get_batches_of_transformed_samples(self, index_array): 212 | """Gets a batch of transformed samples. 213 | 214 | # Arguments 215 | index_array: Array of sample indices to include in batch. 216 | 217 | # Returns 218 | A batch of transformed samples. 
219 | """ 220 | batch_x = np.zeros((len(index_array),) + self.image_shape, dtype=self.dtype) 221 | # build batch of image data 222 | # self.filepaths is dynamic, is better to call it once outside the loop 223 | filepaths = self.filepaths 224 | for i, j in enumerate(index_array): 225 | img = load_img(filepaths[j], 226 | color_mode=self.color_mode, 227 | target_size=self.target_size, 228 | interpolation=self.interpolation, 229 | keep_aspect_ratio=self.keep_aspect_ratio) 230 | x = img_to_array(img, data_format=self.data_format) 231 | # Pillow images should be closed after `load_img`, 232 | # but not PIL images. 233 | if hasattr(img, 'close'): 234 | img.close() 235 | if self.image_data_generator: 236 | params = self.image_data_generator.get_random_transform(x.shape) 237 | x = self.image_data_generator.apply_transform(x, params) 238 | x = self.image_data_generator.standardize(x) 239 | batch_x[i] = x 240 | # optionally save augmented images to disk for debugging purposes 241 | if self.save_to_dir: 242 | for i, j in enumerate(index_array): 243 | img = array_to_img(batch_x[i], self.data_format, scale=True) 244 | fname = '{prefix}_{index}_{hash}.{format}'.format( 245 | prefix=self.save_prefix, 246 | index=j, 247 | hash=np.random.randint(1e7), 248 | format=self.save_format) 249 | img.save(os.path.join(self.save_to_dir, fname)) 250 | # build batch of labels 251 | if self.class_mode == 'input': 252 | batch_y = batch_x.copy() 253 | elif self.class_mode in {'binary', 'sparse'}: 254 | batch_y = np.empty(len(batch_x), dtype=self.dtype) 255 | for i, n_observation in enumerate(index_array): 256 | batch_y[i] = self.classes[n_observation] 257 | elif self.class_mode == 'categorical': 258 | batch_y = np.zeros((len(batch_x), len(self.class_indices)), 259 | dtype=self.dtype) 260 | for i, n_observation in enumerate(index_array): 261 | batch_y[i, self.classes[n_observation]] = 1. 262 | elif self.class_mode == 'multi_output': 263 | batch_y = [output[index_array] for output in self.labels] 264 | elif self.class_mode == 'raw': 265 | batch_y = self.labels[index_array] 266 | else: 267 | return batch_x 268 | if self.sample_weight is None: 269 | return batch_x, batch_y 270 | else: 271 | return batch_x, batch_y, self.sample_weight[index_array] 272 | 273 | @property 274 | def filepaths(self): 275 | """List of absolute paths to image files""" 276 | raise NotImplementedError( 277 | '`filepaths` property method has not been implemented in {}.' 278 | .format(type(self).__name__) 279 | ) 280 | 281 | @property 282 | def labels(self): 283 | """Class labels of every observation""" 284 | raise NotImplementedError( 285 | '`labels` property method has not been implemented in {}.' 286 | .format(type(self).__name__) 287 | ) 288 | 289 | @property 290 | def sample_weight(self): 291 | raise NotImplementedError( 292 | '`sample_weight` property method has not been implemented in {}.' 293 | .format(type(self).__name__) 294 | ) 295 | -------------------------------------------------------------------------------- /keras_preprocessing/image/numpy_array_iterator.py: -------------------------------------------------------------------------------- 1 | """Utilities for real-time data augmentation on image data. 2 | """ 3 | import os 4 | import warnings 5 | 6 | import numpy as np 7 | 8 | from .iterator import Iterator 9 | from .utils import array_to_img 10 | 11 | 12 | class NumpyArrayIterator(Iterator): 13 | """Iterator yielding data from a Numpy array. 14 | 15 | # Arguments 16 | x: Numpy array of input data or tuple. 
17 | If tuple, the second elements is either 18 | another numpy array or a list of numpy arrays, 19 | each of which gets passed 20 | through as an output without any modifications. 21 | y: Numpy array of targets data. 22 | image_data_generator: Instance of `ImageDataGenerator` 23 | to use for random transformations and normalization. 24 | batch_size: Integer, size of a batch. 25 | shuffle: Boolean, whether to shuffle the data between epochs. 26 | sample_weight: Numpy array of sample weights. 27 | seed: Random seed for data shuffling. 28 | data_format: String, one of `channels_first`, `channels_last`. 29 | save_to_dir: Optional directory where to save the pictures 30 | being yielded, in a viewable format. This is useful 31 | for visualizing the random transformations being 32 | applied, for debugging purposes. 33 | save_prefix: String prefix to use for saving sample 34 | images (if `save_to_dir` is set). 35 | save_format: Format to use for saving sample images 36 | (if `save_to_dir` is set). 37 | subset: Subset of data (`"training"` or `"validation"`) if 38 | validation_split is set in ImageDataGenerator. 39 | ignore_class_split: Boolean (default: False), ignore difference 40 | in number of classes in labels across train and validation 41 | split (useful for non-classification tasks) 42 | dtype: Dtype to use for the generated arrays. 43 | """ 44 | 45 | def __new__(cls, *args, **kwargs): 46 | try: 47 | from tensorflow.keras.utils import Sequence as TFSequence 48 | if TFSequence not in cls.__bases__: 49 | cls.__bases__ = cls.__bases__ + (TFSequence,) 50 | except ImportError: 51 | pass 52 | return super(NumpyArrayIterator, cls).__new__(cls) 53 | 54 | def __init__(self, 55 | x, 56 | y, 57 | image_data_generator, 58 | batch_size=32, 59 | shuffle=False, 60 | sample_weight=None, 61 | seed=None, 62 | data_format='channels_last', 63 | save_to_dir=None, 64 | save_prefix='', 65 | save_format='png', 66 | subset=None, 67 | ignore_class_split=False, 68 | dtype='float32'): 69 | self.dtype = dtype 70 | if (type(x) is tuple) or (type(x) is list): 71 | if type(x[1]) is not list: 72 | x_misc = [np.asarray(x[1])] 73 | else: 74 | x_misc = [np.asarray(xx) for xx in x[1]] 75 | x = x[0] 76 | for xx in x_misc: 77 | if len(x) != len(xx): 78 | raise ValueError( 79 | 'All of the arrays in `x` ' 80 | 'should have the same length. ' 81 | 'Found a pair with: len(x[0]) = %s, len(x[?]) = %s' % 82 | (len(x), len(xx))) 83 | else: 84 | x_misc = [] 85 | 86 | if y is not None and len(x) != len(y): 87 | raise ValueError('`x` (images tensor) and `y` (labels) ' 88 | 'should have the same length. ' 89 | 'Found: x.shape = %s, y.shape = %s' % 90 | (np.asarray(x).shape, np.asarray(y).shape)) 91 | if sample_weight is not None and len(x) != len(sample_weight): 92 | raise ValueError('`x` (images tensor) and `sample_weight` ' 93 | 'should have the same length. ' 94 | 'Found: x.shape = %s, sample_weight.shape = %s' % 95 | (np.asarray(x).shape, np.asarray(sample_weight).shape)) 96 | if subset is not None: 97 | if subset not in {'training', 'validation'}: 98 | raise ValueError('Invalid subset name:', subset, 99 | '; expected "training" or "validation".') 100 | split_idx = int(len(x) * image_data_generator._validation_split) 101 | 102 | if (y is not None and not ignore_class_split and not 103 | np.array_equal(np.unique(y[:split_idx]), 104 | np.unique(y[split_idx:]))): 105 | raise ValueError('Training and validation subsets ' 106 | 'have different number of classes after ' 107 | 'the split. 
If your numpy arrays are ' 108 | 'sorted by the label, you might want ' 109 | 'to shuffle them.') 110 | 111 | if subset == 'validation': 112 | x = x[:split_idx] 113 | x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc] 114 | if y is not None: 115 | y = y[:split_idx] 116 | else: 117 | x = x[split_idx:] 118 | x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc] 119 | if y is not None: 120 | y = y[split_idx:] 121 | 122 | self.x = np.asarray(x, dtype=self.dtype) 123 | self.x_misc = x_misc 124 | if self.x.ndim != 4: 125 | raise ValueError('Input data in `NumpyArrayIterator` ' 126 | 'should have rank 4. You passed an array ' 127 | 'with shape', self.x.shape) 128 | channels_axis = 3 if data_format == 'channels_last' else 1 129 | if self.x.shape[channels_axis] not in {1, 3, 4}: 130 | warnings.warn('NumpyArrayIterator is set to use the ' 131 | 'data format convention "' + data_format + '" ' 132 | '(channels on axis ' + str(channels_axis) + 133 | '), i.e. expected either 1, 3, or 4 ' 134 | 'channels on axis ' + str(channels_axis) + '. ' 135 | 'However, it was passed an array with shape ' + 136 | str(self.x.shape) + ' (' + 137 | str(self.x.shape[channels_axis]) + ' channels).') 138 | if y is not None: 139 | self.y = np.asarray(y) 140 | else: 141 | self.y = None 142 | if sample_weight is not None: 143 | self.sample_weight = np.asarray(sample_weight) 144 | else: 145 | self.sample_weight = None 146 | self.image_data_generator = image_data_generator 147 | self.data_format = data_format 148 | self.save_to_dir = save_to_dir 149 | self.save_prefix = save_prefix 150 | self.save_format = save_format 151 | super(NumpyArrayIterator, self).__init__(x.shape[0], 152 | batch_size, 153 | shuffle, 154 | seed) 155 | 156 | def _get_batches_of_transformed_samples(self, index_array): 157 | batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]), 158 | dtype=self.dtype) 159 | for i, j in enumerate(index_array): 160 | x = self.x[j] 161 | params = self.image_data_generator.get_random_transform(x.shape) 162 | x = self.image_data_generator.apply_transform( 163 | x.astype(self.dtype), params) 164 | x = self.image_data_generator.standardize(x) 165 | batch_x[i] = x 166 | 167 | if self.save_to_dir: 168 | for i, j in enumerate(index_array): 169 | img = array_to_img(batch_x[i], self.data_format, scale=True) 170 | fname = '{prefix}_{index}_{hash}.{format}'.format( 171 | prefix=self.save_prefix, 172 | index=j, 173 | hash=np.random.randint(1e4), 174 | format=self.save_format) 175 | img.save(os.path.join(self.save_to_dir, fname)) 176 | batch_x_miscs = [xx[index_array] for xx in self.x_misc] 177 | output = (batch_x if batch_x_miscs == [] 178 | else [batch_x] + batch_x_miscs,) 179 | if self.y is None: 180 | return output[0] 181 | output += (self.y[index_array],) 182 | if self.sample_weight is not None: 183 | output += (self.sample_weight[index_array],) 184 | return output 185 | -------------------------------------------------------------------------------- /keras_preprocessing/image/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for real-time data augmentation on image data. 
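For example, a minimal sketch of the typical load-and-convert round trip using the helpers defined below (`photo.jpg` is a placeholder path):

```python
from keras_preprocessing.image.utils import img_to_array, load_img

img = load_img('photo.jpg', target_size=(224, 224))  # placeholder file
x = img_to_array(img)  # float32 array of shape (224, 224, 3)
```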
2 | """ 3 | import io 4 | import os 5 | import warnings 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | 10 | try: 11 | from PIL import Image as pil_image 12 | from PIL import ImageEnhance 13 | except ImportError: 14 | pil_image = None 15 | ImageEnhance = None 16 | 17 | 18 | if pil_image is not None: 19 | _PIL_INTERPOLATION_METHODS = { 20 | 'nearest': pil_image.NEAREST, 21 | 'bilinear': pil_image.BILINEAR, 22 | 'bicubic': pil_image.BICUBIC, 23 | } 24 | # These methods were only introduced in version 3.4.0 (2016). 25 | if hasattr(pil_image, 'HAMMING'): 26 | _PIL_INTERPOLATION_METHODS['hamming'] = pil_image.HAMMING 27 | if hasattr(pil_image, 'BOX'): 28 | _PIL_INTERPOLATION_METHODS['box'] = pil_image.BOX 29 | # This method is new in version 1.1.3 (2013). 30 | if hasattr(pil_image, 'LANCZOS'): 31 | _PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS 32 | 33 | 34 | def validate_filename(filename, white_list_formats): 35 | """Check if a filename refers to a valid file. 36 | 37 | # Arguments 38 | filename: String, absolute path to a file 39 | white_list_formats: Set, allowed file extensions 40 | 41 | # Returns 42 | A boolean value indicating if the filename is valid or not 43 | """ 44 | return (filename.lower().endswith(white_list_formats) and 45 | os.path.isfile(filename)) 46 | 47 | 48 | def save_img(path, 49 | x, 50 | data_format='channels_last', 51 | file_format=None, 52 | scale=True, 53 | **kwargs): 54 | """Saves an image stored as a Numpy array to a path or file object. 55 | 56 | # Arguments 57 | path: Path or file object. 58 | x: Numpy array. 59 | data_format: Image data format, 60 | either "channels_first" or "channels_last". 61 | file_format: Optional file format override. If omitted, the 62 | format to use is determined from the filename extension. 63 | If a file object was used instead of a filename, this 64 | parameter should always be used. 65 | scale: Whether to rescale image values to be within `[0, 255]`. 66 | **kwargs: Additional keyword arguments passed to `PIL.Image.save()`. 67 | """ 68 | img = array_to_img(x, data_format=data_format, scale=scale) 69 | if img.mode == 'RGBA' and (file_format == 'jpg' or file_format == 'jpeg'): 70 | warnings.warn('The JPG format does not support ' 71 | 'RGBA images, converting to RGB.') 72 | img = img.convert('RGB') 73 | img.save(path, format=file_format, **kwargs) 74 | 75 | 76 | def load_img(path, grayscale=False, color_mode='rgb', target_size=None, 77 | interpolation='nearest', keep_aspect_ratio=False): 78 | """Loads an image into PIL format. 79 | 80 | # Arguments 81 | path: Path (string), pathlib.Path object, or io.BytesIO stream to image file. 82 | grayscale: DEPRECATED use `color_mode="grayscale"`. 83 | color_mode: The desired image format. One of "grayscale", "rgb", "rgba". 84 | "grayscale" supports 8-bit images and 32-bit signed integer images. 85 | Default: "rgb". 86 | target_size: Either `None` (default to original size) 87 | or tuple of ints `(img_height, img_width)`. 88 | interpolation: Interpolation method used to resample the image if the 89 | target size is different from that of the loaded image. 90 | Supported methods are "nearest", "bilinear", and "bicubic". 91 | If PIL version 1.1.3 or newer is installed, "lanczos" is also 92 | supported. If PIL version 3.4.0 or newer is installed, "box" and 93 | "hamming" are also supported. 94 | Default: "nearest". 95 | keep_aspect_ratio: Boolean, whether to resize images to a target 96 | size without aspect ratio distortion. 
The image is cropped in 97 | the center with target aspect ratio before resizing. 98 | 99 | # Returns 100 | A PIL Image instance. 101 | 102 | # Raises 103 | ImportError: if PIL is not available. 104 | ValueError: if interpolation method is not supported. 105 | TypeError: type of 'path' should be path-like or io.Byteio. 106 | """ 107 | if grayscale is True: 108 | warnings.warn('grayscale is deprecated. Please use ' 109 | 'color_mode = "grayscale"') 110 | color_mode = 'grayscale' 111 | if pil_image is None: 112 | raise ImportError('Could not import PIL.Image. ' 113 | 'The use of `load_img` requires PIL.') 114 | if isinstance(path, io.BytesIO): 115 | img = pil_image.open(path) 116 | elif isinstance(path, (Path, bytes, str)): 117 | if isinstance(path, Path): 118 | path = str(path.resolve()) 119 | with open(path, 'rb') as f: 120 | img = pil_image.open(io.BytesIO(f.read())) 121 | else: 122 | raise TypeError('path should be path-like or io.BytesIO' 123 | ', not {}'.format(type(path))) 124 | 125 | if color_mode == 'grayscale': 126 | # if image is not already an 8-bit, 16-bit or 32-bit grayscale image 127 | # convert it to an 8-bit grayscale image. 128 | if img.mode not in ('L', 'I;16', 'I'): 129 | img = img.convert('L') 130 | elif color_mode == 'rgba': 131 | if img.mode != 'RGBA': 132 | img = img.convert('RGBA') 133 | elif color_mode == 'rgb': 134 | if img.mode != 'RGB': 135 | img = img.convert('RGB') 136 | else: 137 | raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"') 138 | if target_size is not None: 139 | width_height_tuple = (target_size[1], target_size[0]) 140 | if img.size != width_height_tuple: 141 | if interpolation not in _PIL_INTERPOLATION_METHODS: 142 | raise ValueError( 143 | 'Invalid interpolation method {} specified. Supported ' 144 | 'methods are {}'.format( 145 | interpolation, 146 | ", ".join(_PIL_INTERPOLATION_METHODS.keys()))) 147 | resample = _PIL_INTERPOLATION_METHODS[interpolation] 148 | 149 | if keep_aspect_ratio: 150 | width, height = img.size 151 | target_width, target_height = width_height_tuple 152 | 153 | crop_height = (width * target_height) // target_width 154 | crop_width = (height * target_width) // target_height 155 | 156 | # Set back to input height / width 157 | # if crop_height / crop_width is not smaller. 158 | crop_height = min(height, crop_height) 159 | crop_width = min(width, crop_width) 160 | 161 | crop_box_hstart = (height - crop_height) // 2 162 | crop_box_wstart = (width - crop_width) // 2 163 | crop_box_wend = crop_box_wstart + crop_width 164 | crop_box_hend = crop_box_hstart + crop_height 165 | crop_box = [ 166 | crop_box_wstart, crop_box_hstart, crop_box_wend, 167 | crop_box_hend 168 | ] 169 | img = img.resize(width_height_tuple, resample, box=crop_box) 170 | else: 171 | img = img.resize(width_height_tuple, resample) 172 | return img 173 | 174 | 175 | def list_pictures(directory, ext=('jpg', 'jpeg', 'bmp', 'png', 'ppm', 'tif', 176 | 'tiff')): 177 | """Lists all pictures in a directory, including all subdirectories. 
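For example, a sketch (`/data/photos` is a placeholder directory; extensions are matched case-insensitively):

```python
from keras_preprocessing.image.utils import list_pictures

paths = list_pictures('/data/photos', ext=('jpg', 'png'))
```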
178 | 179 | # Arguments 180 | directory: string, absolute path to the directory 181 | ext: tuple of strings or single string, extensions of the pictures 182 | 183 | # Returns 184 | a list of paths 185 | """ 186 | ext = tuple('.%s' % e for e in ((ext,) if isinstance(ext, str) else ext)) 187 | return [os.path.join(root, f) 188 | for root, _, files in os.walk(directory) for f in files 189 | if f.lower().endswith(ext)] 190 | 191 | 192 | def _iter_valid_files(directory, white_list_formats, follow_links): 193 | """Iterates over files with an extension in `white_list_formats` contained in `directory`. 194 | 195 | # Arguments 196 | directory: Absolute path to the directory 197 | containing files to be counted. 198 | white_list_formats: Set of strings containing allowed extensions for 199 | the files to be counted. 200 | follow_links: Boolean, follow symbolic links to subdirectories. 201 | 202 | # Yields 203 | Tuple of (root, filename) with extension in `white_list_formats`. 204 | """ 205 | def _recursive_list(subpath): 206 | return sorted(os.walk(subpath, followlinks=follow_links), 207 | key=lambda x: x[0]) 208 | 209 | for root, _, files in _recursive_list(directory): 210 | for fname in sorted(files): 211 | if fname.lower().endswith('.tiff'): 212 | warnings.warn('Using ".tiff" files with multiple bands ' 213 | 'will cause distortion. Please verify your output.') 214 | if fname.lower().endswith(white_list_formats): 215 | yield root, fname 216 | 217 | 218 | def _list_valid_filenames_in_directory(directory, white_list_formats, split, 219 | class_indices, follow_links): 220 | """Lists paths of files in `directory` with extensions in `white_list_formats`. 221 | 222 | # Arguments 223 | directory: absolute path to a directory containing the files to list. 224 | The directory name is used as class label 225 | and must be a key of `class_indices`. 226 | white_list_formats: set of strings containing allowed extensions for 227 | the files to be counted. 228 | split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into 229 | account a certain fraction of files in each directory. 230 | E.g.: `split=(0.6, 1.0)` would only account for the last 40 percent 231 | of images in each directory. 232 | class_indices: dictionary mapping a class name to its index. 233 | follow_links: boolean, follow symbolic links to subdirectories. 234 | 235 | # Returns 236 | classes: a list of class indices 237 | filenames: the paths of valid files in `directory`, relative from 238 | `directory`'s parent (e.g., if `directory` is "dataset/class1", 239 | the filenames will be 240 | `["class1/file1.jpg", "class1/file2.jpg", ...]`). 241 | """ 242 | dirname = os.path.basename(directory) 243 | if split: 244 | all_files = list(_iter_valid_files(directory, white_list_formats, 245 | follow_links)) 246 | num_files = len(all_files) 247 | start, stop = int(split[0] * num_files), int(split[1] * num_files) 248 | valid_files = all_files[start: stop] 249 | else: 250 | valid_files = _iter_valid_files( 251 | directory, white_list_formats, follow_links) 252 | classes = [] 253 | filenames = [] 254 | for root, fname in valid_files: 255 | classes.append(class_indices[dirname]) 256 | absolute_path = os.path.join(root, fname) 257 | relative_path = os.path.join( 258 | dirname, os.path.relpath(absolute_path, directory)) 259 | filenames.append(relative_path) 260 | 261 | return classes, filenames 262 | 263 | 264 | def array_to_img(x, data_format='channels_last', scale=True, dtype='float32'): 265 | """Converts a 3D Numpy array to a PIL Image instance.
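For example, a sketch with random data:

```python
import numpy as np
from keras_preprocessing.image.utils import array_to_img

x = np.random.rand(32, 32, 3)  # values in [0, 1]
img = array_to_img(x)  # rescaled to [0, 255]; PIL mode 'RGB'
```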
266 | 267 | # Arguments 268 | x: Input Numpy array. 269 | data_format: Image data format, either "channels_first" or "channels_last". 270 | Default: "channels_last". 271 | scale: Whether to rescale the image such that minimum and maximum values 272 | are 0 and 255 respectively. 273 | Default: True. 274 | dtype: Dtype to use. 275 | Default: "float32". 276 | 277 | # Returns 278 | A PIL Image instance. 279 | 280 | # Raises 281 | ImportError: if PIL is not available. 282 | ValueError: if invalid `x` or `data_format` is passed. 283 | """ 284 | if pil_image is None: 285 | raise ImportError('Could not import PIL.Image. ' 286 | 'The use of `array_to_img` requires PIL.') 287 | x = np.asarray(x, dtype=dtype) 288 | if x.ndim != 3: 289 | raise ValueError('Expected image array to have rank 3 (single image). ' 290 | 'Got array with shape: %s' % (x.shape,)) 291 | 292 | if data_format not in {'channels_first', 'channels_last'}: 293 | raise ValueError('Invalid data_format: %s' % data_format) 294 | 295 | # Original Numpy array x has format (height, width, channel) 296 | # or (channel, height, width) 297 | # but target PIL image has format (width, height, channel) 298 | if data_format == 'channels_first': 299 | x = x.transpose(1, 2, 0) 300 | if scale: 301 | x = x - np.min(x) 302 | x_max = np.max(x) 303 | if x_max != 0: 304 | x /= x_max 305 | x *= 255 306 | if x.shape[2] == 4: 307 | # RGBA 308 | return pil_image.fromarray(x.astype('uint8'), 'RGBA') 309 | elif x.shape[2] == 3: 310 | # RGB 311 | return pil_image.fromarray(x.astype('uint8'), 'RGB') 312 | elif x.shape[2] == 1: 313 | # grayscale 314 | if np.max(x) > 255: 315 | # 32-bit signed integer grayscale image. PIL mode "I" 316 | return pil_image.fromarray(x[:, :, 0].astype('int32'), 'I') 317 | return pil_image.fromarray(x[:, :, 0].astype('uint8'), 'L') 318 | else: 319 | raise ValueError('Unsupported channel number: %s' % (x.shape[2],)) 320 | 321 | 322 | def img_to_array(img, data_format='channels_last', dtype='float32'): 323 | """Converts a PIL Image instance to a Numpy array. 324 | 325 | # Arguments 326 | img: PIL Image instance. 327 | data_format: Image data format, 328 | either "channels_first" or "channels_last". 329 | dtype: Dtype to use for the returned array. 330 | 331 | # Returns 332 | A 3D Numpy array. 333 | 334 | # Raises 335 | ValueError: if invalid `img` or `data_format` is passed. 336 | """ 337 | if data_format not in {'channels_first', 'channels_last'}: 338 | raise ValueError('Unknown data_format: %s' % data_format) 339 | # Numpy array x has format (height, width, channel) 340 | # or (channel, height, width) 341 | # but original PIL image has format (width, height, channel) 342 | x = np.asarray(img, dtype=dtype) 343 | if len(x.shape) == 3: 344 | if data_format == 'channels_first': 345 | x = x.transpose(2, 0, 1) 346 | elif len(x.shape) == 2: 347 | if data_format == 'channels_first': 348 | x = x.reshape((1, x.shape[0], x.shape[1])) 349 | else: 350 | x = x.reshape((x.shape[0], x.shape[1], 1)) 351 | else: 352 | raise ValueError('Unsupported image shape: %s' % (x.shape,)) 353 | return x 354 | -------------------------------------------------------------------------------- /keras_preprocessing/sequence.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Utilities for preprocessing sequence data. 
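For example, a minimal sketch of the padding utility defined below:

```python
from keras_preprocessing.sequence import pad_sequences

seqs = [[1, 2, 3], [4, 5]]
pad_sequences(seqs)                  # [[1, 2, 3], [0, 4, 5]] (pre-padding)
pad_sequences(seqs, padding='post')  # [[1, 2, 3], [4, 5, 0]]
```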
3 | """ 4 | import json 5 | import random 6 | 7 | import numpy as np 8 | 9 | 10 | def pad_sequences(sequences, maxlen=None, dtype='int32', 11 | padding='pre', truncating='pre', value=0.): 12 | """Pads sequences to the same length. 13 | 14 | This function transforms a list of 15 | `num_samples` sequences (lists of integers) 16 | into a 2D Numpy array of shape `(num_samples, num_timesteps)`. 17 | `num_timesteps` is either the `maxlen` argument if provided, 18 | or the length of the longest sequence otherwise. 19 | 20 | Sequences that are shorter than `num_timesteps` 21 | are padded with `value` until they reach `num_timesteps` in length, 22 | either at the beginning (default) or at the end (if `padding='post'`). 23 | 24 | Sequences longer than `num_timesteps` are truncated 25 | so that they fit the desired length. 26 | The position where padding or truncation happens is determined by 27 | the arguments `padding` and `truncating`, respectively. 28 | 29 | Pre-padding is the default. 30 | 31 | # Arguments 32 | sequences: List of lists, where each element is a sequence. 33 | maxlen: Int, maximum length of all sequences. 34 | dtype: Type of the output sequences. 35 | To pad sequences with variable-length strings, you can use `object`. 36 | padding: String, 'pre' or 'post': 37 | pad either before or after each sequence. 38 | truncating: String, 'pre' or 'post': 39 | remove values from sequences larger than 40 | `maxlen`, either at the beginning or at the end of the sequences. 41 | value: Float or String, padding value. 42 | 43 | # Returns 44 | x: Numpy array with shape `(len(sequences), maxlen)` 45 | 46 | # Raises 47 | ValueError: In case of invalid values for `truncating` or `padding`, 48 | or in case of invalid shape for a `sequences` entry. 49 | """ 50 | if not hasattr(sequences, '__len__'): 51 | raise ValueError('`sequences` must be iterable.') 52 | num_samples = len(sequences) 53 | 54 | lengths = [] 55 | sample_shape = () 56 | flag = True 57 | 58 | # take the sample shape from the first non-empty sequence; 59 | # consistency is checked in the main loop below. 60 | 61 | for x in sequences: 62 | try: 63 | lengths.append(len(x)) 64 | if flag and len(x): 65 | sample_shape = np.asarray(x).shape[1:] 66 | flag = False 67 | except TypeError: 68 | raise ValueError('`sequences` must be a list of iterables. ' 69 | 'Found non-iterable: ' + str(x)) 70 | 71 | if maxlen is None: 72 | maxlen = np.max(lengths) 73 | 74 | is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_) 75 | if isinstance(value, str) and dtype != object and not is_dtype_str: 76 | raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n" 77 | "You should set `dtype=object` for variable length strings."
78 | .format(dtype, type(value))) 79 | 80 | x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype) 81 | for idx, s in enumerate(sequences): 82 | if not len(s): 83 | continue # empty list/array was found 84 | if truncating == 'pre': 85 | trunc = s[-maxlen:] 86 | elif truncating == 'post': 87 | trunc = s[:maxlen] 88 | else: 89 | raise ValueError('Truncating type "%s" ' 90 | 'not understood' % truncating) 91 | 92 | # check `trunc` has expected shape 93 | trunc = np.asarray(trunc, dtype=dtype) 94 | if trunc.shape[1:] != sample_shape: 95 | raise ValueError('Shape of sample %s of sequence at position %s ' 96 | 'is different from expected shape %s' % 97 | (trunc.shape[1:], idx, sample_shape)) 98 | 99 | if padding == 'post': 100 | x[idx, :len(trunc)] = trunc 101 | elif padding == 'pre': 102 | x[idx, -len(trunc):] = trunc 103 | else: 104 | raise ValueError('Padding type "%s" not understood' % padding) 105 | return x 106 | 107 | 108 | def make_sampling_table(size, sampling_factor=1e-5): 109 | """Generates a word rank-based probabilistic sampling table. 110 | 111 | Used for generating the `sampling_table` argument for `skipgrams`. 112 | `sampling_table[i]` is the probability of sampling 113 | the word i-th most common word in a dataset 114 | (more common words should be sampled less frequently, for balance). 115 | 116 | The sampling probabilities are generated according 117 | to the sampling distribution used in word2vec: 118 | 119 | ``` 120 | p(word) = (min(1, sqrt(word_frequency / sampling_factor) / 121 | (word_frequency / sampling_factor))) 122 | ``` 123 | 124 | We assume that the word frequencies follow Zipf's law (s=1) to derive 125 | a numerical approximation of frequency(rank): 126 | 127 | `frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank))` 128 | where `gamma` is the Euler-Mascheroni constant. 129 | 130 | # Arguments 131 | size: Int, number of possible words to sample. 132 | sampling_factor: The sampling factor in the word2vec formula. 133 | 134 | # Returns 135 | A 1D Numpy array of length `size` where the ith entry 136 | is the probability that a word of rank i should be sampled. 137 | """ 138 | gamma = 0.577 139 | rank = np.arange(size) 140 | rank[0] = 1 141 | inv_fq = rank * (np.log(rank) + gamma) + 0.5 - 1. / (12. * rank) 142 | f = sampling_factor * inv_fq 143 | 144 | return np.minimum(1., f / np.sqrt(f)) 145 | 146 | 147 | def skipgrams(sequence, vocabulary_size, 148 | window_size=4, negative_samples=1., shuffle=True, 149 | categorical=False, sampling_table=None, seed=None): 150 | """Generates skipgram word pairs. 151 | 152 | This function transforms a sequence of word indexes (list of integers) 153 | into tuples of words of the form: 154 | 155 | - (word, word in the same window), with label 1 (positive samples). 156 | - (word, random word from the vocabulary), with label 0 (negative samples). 157 | 158 | Read more about Skipgram in this gnomic paper by Mikolov et al.: 159 | [Efficient Estimation of Word Representations in 160 | Vector Space](http://arxiv.org/pdf/1301.3781v3.pdf) 161 | 162 | # Arguments 163 | sequence: A word sequence (sentence), encoded as a list 164 | of word indices (integers). If using a `sampling_table`, 165 | word indices are expected to match the rank 166 | of the words in a reference dataset (e.g. 10 would encode 167 | the 10-th most frequently occurring token). 168 | Note that index 0 is expected to be a non-word and will be skipped. 
169 | vocabulary_size: Int, maximum possible word index + 1 170 | window_size: Int, size of sampling windows (technically half-window). 171 | The window of a word `w_i` will be 172 | `[i - window_size, i + window_size+1]`. 173 | negative_samples: Float >= 0. 0 for no negative (i.e. random) samples. 174 | 1 for same number as positive samples. 175 | shuffle: Whether to shuffle the word couples before returning them. 176 | categorical: bool. if False, labels will be 177 | integers (eg. `[0, 1, 1 .. ]`), 178 | if `True`, labels will be categorical, e.g. 179 | `[[1,0],[0,1],[0,1] .. ]`. 180 | sampling_table: 1D array of size `vocabulary_size` where the entry i 181 | encodes the probability to sample a word of rank i. 182 | seed: Random seed. 183 | 184 | # Returns 185 | couples, labels: where `couples` are int pairs and 186 | `labels` are either 0 or 1. 187 | 188 | # Note 189 | By convention, index 0 in the vocabulary is 190 | a non-word and will be skipped. 191 | """ 192 | couples = [] 193 | labels = [] 194 | for i, wi in enumerate(sequence): 195 | if not wi: 196 | continue 197 | if sampling_table is not None: 198 | if sampling_table[wi] < random.random(): 199 | continue 200 | 201 | window_start = max(0, i - window_size) 202 | window_end = min(len(sequence), i + window_size + 1) 203 | for j in range(window_start, window_end): 204 | if j != i: 205 | wj = sequence[j] 206 | if not wj: 207 | continue 208 | couples.append([wi, wj]) 209 | if categorical: 210 | labels.append([0, 1]) 211 | else: 212 | labels.append(1) 213 | 214 | if negative_samples > 0: 215 | num_negative_samples = int(len(labels) * negative_samples) 216 | words = [c[0] for c in couples] 217 | random.shuffle(words) 218 | 219 | couples += [[words[i % len(words)], 220 | random.randint(1, vocabulary_size - 1)] 221 | for i in range(num_negative_samples)] 222 | if categorical: 223 | labels += [[1, 0]] * num_negative_samples 224 | else: 225 | labels += [0] * num_negative_samples 226 | 227 | if shuffle: 228 | if seed is None: 229 | seed = random.randint(0, 10e6) 230 | random.seed(seed) 231 | random.shuffle(couples) 232 | random.seed(seed) 233 | random.shuffle(labels) 234 | 235 | return couples, labels 236 | 237 | 238 | def _remove_long_seq(maxlen, seq, label): 239 | """Removes sequences that exceed the maximum length. 240 | 241 | # Arguments 242 | maxlen: Int, maximum length of the output sequences. 243 | seq: List of lists, where each sublist is a sequence. 244 | label: List where each element is an integer. 245 | 246 | # Returns 247 | new_seq, new_label: shortened lists for `seq` and `label`. 248 | """ 249 | new_seq, new_label = [], [] 250 | for x, y in zip(seq, label): 251 | if len(x) < maxlen: 252 | new_seq.append(x) 253 | new_label.append(y) 254 | return new_seq, new_label 255 | 256 | 257 | class TimeseriesGenerator(object): 258 | """Utility class for generating batches of temporal data. 259 | 260 | This class takes in a sequence of data-points gathered at 261 | equal intervals, along with time series parameters such as 262 | stride, length of history, etc., to produce batches for 263 | training/validation. 264 | 265 | # Arguments 266 | data: Indexable generator (such as list or Numpy array) 267 | containing consecutive data points (timesteps). 268 | The data should be at 2D, and axis 0 is expected 269 | to be the time dimension. 270 | targets: Targets corresponding to timesteps in `data`. 271 | It should have same length as `data`. 272 | length: Length of the output sequences (in number of timesteps). 
273 | sampling_rate: Period between successive individual timesteps 274 | within sequences. For rate `r`, timesteps 275 | `data[i]`, `data[i-r]`, ... `data[i - length]` 276 | are used to create a sample sequence. 277 | stride: Period between successive output sequences. 278 | For stride `s`, consecutive output samples would 279 | be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc. 280 | start_index: Data points earlier than `start_index` will not be used 281 | in the output sequences. This is useful to reserve part of the 282 | data for test or validation. 283 | end_index: Data points later than `end_index` will not be used 284 | in the output sequences. This is useful to reserve part of the 285 | data for test or validation. 286 | shuffle: Whether to shuffle output samples, 287 | or instead draw them in chronological order. 288 | reverse: Boolean: if `True`, timesteps in each output sample will be 289 | in reverse chronological order. 290 | batch_size: Number of timeseries samples in each batch 291 | (except maybe the last one). 292 | 293 | # Returns 294 | A [Sequence](/utils/#sequence) instance. 295 | 296 | # Examples 297 | 298 | ```python 299 | from keras.preprocessing.sequence import TimeseriesGenerator 300 | import numpy as np 301 | 302 | data = np.array([[i] for i in range(50)]) 303 | targets = np.array([[i] for i in range(50)]) 304 | 305 | data_gen = TimeseriesGenerator(data, targets, 306 | length=10, sampling_rate=2, 307 | batch_size=2) 308 | assert len(data_gen) == 20 309 | 310 | batch_0 = data_gen[0] 311 | x, y = batch_0 312 | assert np.array_equal(x, 313 | np.array([[[0], [2], [4], [6], [8]], 314 | [[1], [3], [5], [7], [9]]])) 315 | assert np.array_equal(y, 316 | np.array([[10], [11]])) 317 | ``` 318 | """ 319 | 320 | def __init__(self, data, targets, length, 321 | sampling_rate=1, 322 | stride=1, 323 | start_index=0, 324 | end_index=None, 325 | shuffle=False, 326 | reverse=False, 327 | batch_size=128): 328 | 329 | if len(data) != len(targets): 330 | raise ValueError('Data and targets have to be' + 331 | ' of the same length. ' 332 | 'Data length is {}'.format(len(data)) + 333 | ' while target length is {}'.format(len(targets))) 334 | 335 | self.data = data 336 | self.targets = targets 337 | self.length = length 338 | self.sampling_rate = sampling_rate 339 | self.stride = stride 340 | self.start_index = start_index + length 341 | if end_index is None: 342 | end_index = len(data) - 1 343 | self.end_index = end_index 344 | self.shuffle = shuffle 345 | self.reverse = reverse 346 | self.batch_size = batch_size 347 | 348 | if self.start_index > self.end_index: 349 | raise ValueError('`start_index+length=%i > end_index=%i` ' 350 | 'is disallowed, as no part of the sequence ' 351 | 'would be left to be used as current step.'
352 | % (self.start_index, self.end_index)) 353 | 354 | def __len__(self): 355 | return (self.end_index - self.start_index + 356 | self.batch_size * self.stride) // (self.batch_size * self.stride) 357 | 358 | def __getitem__(self, index): 359 | if self.shuffle: 360 | rows = np.random.randint( 361 | self.start_index, self.end_index + 1, size=self.batch_size) 362 | else: 363 | i = self.start_index + self.batch_size * self.stride * index 364 | rows = np.arange(i, min(i + self.batch_size * 365 | self.stride, self.end_index + 1), self.stride) 366 | 367 | samples = np.array([self.data[row - self.length:row:self.sampling_rate] 368 | for row in rows]) 369 | targets = np.array([self.targets[row] for row in rows]) 370 | 371 | if self.reverse: 372 | return samples[:, ::-1, ...], targets 373 | return samples, targets 374 | 375 | def get_config(self): 376 | '''Returns the TimeseriesGenerator configuration as Python dictionary. 377 | 378 | # Returns 379 | A Python dictionary with the TimeseriesGenerator configuration. 380 | ''' 381 | data = self.data 382 | if type(self.data).__module__ == np.__name__: 383 | data = self.data.tolist() 384 | try: 385 | json_data = json.dumps(data) 386 | except TypeError: 387 | raise TypeError('Data not JSON Serializable:', data) 388 | 389 | targets = self.targets 390 | if type(self.targets).__module__ == np.__name__: 391 | targets = self.targets.tolist() 392 | try: 393 | json_targets = json.dumps(targets) 394 | except TypeError: 395 | raise TypeError('Targets not JSON Serializable:', targets) 396 | 397 | return { 398 | 'data': json_data, 399 | 'targets': json_targets, 400 | 'length': self.length, 401 | 'sampling_rate': self.sampling_rate, 402 | 'stride': self.stride, 403 | 'start_index': self.start_index, 404 | 'end_index': self.end_index, 405 | 'shuffle': self.shuffle, 406 | 'reverse': self.reverse, 407 | 'batch_size': self.batch_size 408 | } 409 | 410 | def to_json(self, **kwargs): 411 | """Returns a JSON string containing the timeseries generator 412 | configuration. To load a generator from a JSON string, use 413 | `keras.preprocessing.sequence.timeseries_generator_from_json(json_string)`. 414 | 415 | # Arguments 416 | **kwargs: Additional keyword arguments 417 | to be passed to `json.dumps()`. 418 | 419 | # Returns 420 | A JSON string containing the tokenizer configuration. 421 | """ 422 | config = self.get_config() 423 | timeseries_generator_config = { 424 | 'class_name': self.__class__.__name__, 425 | 'config': config 426 | } 427 | return json.dumps(timeseries_generator_config, **kwargs) 428 | 429 | 430 | def timeseries_generator_from_json(json_string): 431 | """Parses a JSON timeseries generator configuration file and 432 | returns a timeseries generator instance. 433 | 434 | # Arguments 435 | json_string: JSON string encoding a timeseries 436 | generator configuration. 437 | 438 | # Returns 439 | A Keras TimeseriesGenerator instance 440 | """ 441 | full_config = json.loads(json_string) 442 | config = full_config.get('config') 443 | 444 | data = json.loads(config.pop('data')) 445 | config['data'] = data 446 | targets = json.loads(config.pop('targets')) 447 | config['targets'] = targets 448 | 449 | return TimeseriesGenerator(**config) 450 | -------------------------------------------------------------------------------- /keras_preprocessing/text.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Utilities for text input preprocessing. 
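For example, a minimal sketch of the Tokenizer workflow defined below:

```python
from keras_preprocessing.text import Tokenizer

tokenizer = Tokenizer(num_words=100)
tokenizer.fit_on_texts(['the cat sat', 'the cat ran'])
tokenizer.texts_to_sequences(['the cat'])  # [[1, 2]]
```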
3 | """ 4 | import json 5 | import warnings 6 | from collections import OrderedDict, defaultdict 7 | from hashlib import md5 8 | 9 | import numpy as np 10 | 11 | maketrans = str.maketrans 12 | 13 | 14 | def text_to_word_sequence(text, 15 | filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', 16 | lower=True, split=" "): 17 | """Converts a text to a sequence of words (or tokens). 18 | 19 | # Arguments 20 | text: Input text (string). 21 | filters: list (or concatenation) of characters to filter out, such as 22 | punctuation. Default: ``!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\t\\n``, 23 | includes basic punctuation, tabs, and newlines. 24 | lower: boolean. Whether to convert the input to lowercase. 25 | split: str. Separator for word splitting. 26 | 27 | # Returns 28 | A list of words (or tokens). 29 | """ 30 | if lower: 31 | text = text.lower() 32 | 33 | translate_dict = {c: split for c in filters} 34 | translate_map = maketrans(translate_dict) 35 | text = text.translate(translate_map) 36 | 37 | seq = text.split(split) 38 | return [i for i in seq if i] 39 | 40 | 41 | def one_hot(text, n, 42 | filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', 43 | lower=True, 44 | split=' ', 45 | analyzer=None): 46 | """One-hot encodes a text into a list of word indexes of size n. 47 | 48 | This is a wrapper to the `hashing_trick` function using `hash` as the 49 | hashing function; unicity of word to index mapping non-guaranteed. 50 | 51 | # Arguments 52 | text: Input text (string). 53 | n: int. Size of vocabulary. 54 | filters: list (or concatenation) of characters to filter out, such as 55 | punctuation. Default: ``!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\t\\n``, 56 | includes basic punctuation, tabs, and newlines. 57 | lower: boolean. Whether to set the text to lowercase. 58 | split: str. Separator for word splitting. 59 | analyzer: function. Custom analyzer to split the text 60 | 61 | # Returns 62 | List of integers in [1, n]. Each integer encodes a word 63 | (unicity non-guaranteed). 64 | """ 65 | return hashing_trick(text, n, 66 | hash_function=hash, 67 | filters=filters, 68 | lower=lower, 69 | split=split, 70 | analyzer=analyzer) 71 | 72 | 73 | def hashing_trick(text, n, 74 | hash_function=None, 75 | filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', 76 | lower=True, 77 | split=' ', 78 | analyzer=None): 79 | """Converts a text to a sequence of indexes in a fixed-size hashing space. 80 | 81 | # Arguments 82 | text: Input text (string). 83 | n: Dimension of the hashing space. 84 | hash_function: defaults to python `hash` function, can be 'md5' or 85 | any function that takes in input a string and returns a int. 86 | Note that 'hash' is not a stable hashing function, so 87 | it is not consistent across different runs, while 'md5' 88 | is a stable hashing function. 89 | filters: list (or concatenation) of characters to filter out, such as 90 | punctuation. Default: ``!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\t\\n``, 91 | includes basic punctuation, tabs, and newlines. 92 | lower: boolean. Whether to set the text to lowercase. 93 | split: str. Separator for word splitting. 94 | analyzer: function. Custom analyzer to split the text 95 | 96 | # Returns 97 | A list of integer word indices (unicity non-guaranteed). 98 | 99 | `0` is a reserved index that won't be assigned to any word. 100 | 101 | Two or more words may be assigned to the same index, due to possible 102 | collisions by the hashing function. 
103 | The [probability]( 104 | https://en.wikipedia.org/wiki/Birthday_problem#Probability_table) 105 | of a collision depends on the dimension of the hashing space and 106 | the number of distinct objects. 107 | """ 108 | if hash_function is None: 109 | hash_function = hash 110 | elif hash_function == 'md5': 111 | def hash_function(w): 112 | return int(md5(w.encode()).hexdigest(), 16) 113 | 114 | if analyzer is None: 115 | seq = text_to_word_sequence(text, 116 | filters=filters, 117 | lower=lower, 118 | split=split) 119 | else: 120 | seq = analyzer(text) 121 | 122 | return [(hash_function(w) % (n - 1) + 1) for w in seq] 123 | 124 | 125 | class Tokenizer(object): 126 | """Text tokenization utility class. 127 | 128 | This class allows you to vectorize a text corpus, by turning each 129 | text into either a sequence of integers (each integer being the index 130 | of a token in a dictionary) or into a vector where the coefficient 131 | for each token can be binary, based on word count, or based on tf-idf. 132 | 133 | # Arguments 134 | num_words: the maximum number of words to keep, based 135 | on word frequency. Only the most common `num_words-1` words will 136 | be kept. 137 | filters: a string where each element is a character that will be 138 | filtered from the texts. The default is all punctuation, plus 139 | tabs and line breaks, minus the `'` character. 140 | lower: boolean. Whether to convert the texts to lowercase. 141 | split: str. Separator for word splitting. 142 | char_level: if True, every character will be treated as a token. 143 | oov_token: if given, it will be added to word_index and used to 144 | replace out-of-vocabulary words during `texts_to_sequences` calls. 145 | analyzer: function. Custom analyzer to split the text. 146 | The default analyzer is `text_to_word_sequence`. 147 | 148 | By default, all punctuation is removed, turning the texts into 149 | space-separated sequences of words 150 | (words may include the `'` character). These sequences are then 151 | split into lists of tokens. They will then be indexed or vectorized. 152 | 153 | `0` is a reserved index that won't be assigned to any word. 154 | """ 155 | 156 | def __init__(self, num_words=None, 157 | filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', 158 | lower=True, 159 | split=' ', 160 | char_level=False, 161 | oov_token=None, 162 | analyzer=None, 163 | **kwargs): 164 | # Legacy support 165 | if 'nb_words' in kwargs: 166 | warnings.warn('The `nb_words` argument in `Tokenizer` ' 167 | 'has been renamed `num_words`.') 168 | num_words = kwargs.pop('nb_words') 169 | document_count = kwargs.pop('document_count', 0) 170 | if kwargs: 171 | raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) 172 | 173 | self.word_counts = OrderedDict() 174 | self.word_docs = defaultdict(int) 175 | self.filters = filters 176 | self.split = split 177 | self.lower = lower 178 | self.num_words = num_words 179 | self.document_count = document_count 180 | self.char_level = char_level 181 | self.oov_token = oov_token 182 | self.index_docs = defaultdict(int) 183 | self.word_index = {} 184 | self.index_word = {} 185 | self.analyzer = analyzer 186 | 187 | def fit_on_texts(self, texts): 188 | """Updates internal vocabulary based on a list of texts. 189 | 190 | In the case where texts contains lists, 191 | we assume each entry of the lists to be a token. 192 | 193 | Required before using `texts_to_sequences` or `texts_to_matrix`.
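For example, a sketch of the pre-tokenized case mentioned above:

```python
from keras_preprocessing.text import Tokenizer

tokenizer = Tokenizer()
tokenizer.fit_on_texts([['NYC', 'is', 'big'], ['Rome', 'is', 'old']])
```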
194 | 195 | # Arguments 196 | texts: can be a list of strings, 197 | a generator of strings (for memory-efficiency), 198 | or a list of list of strings. 199 | """ 200 | for text in texts: 201 | self.document_count += 1 202 | if self.char_level or isinstance(text, list): 203 | if self.lower: 204 | if isinstance(text, list): 205 | text = [text_elem.lower() for text_elem in text] 206 | else: 207 | text = text.lower() 208 | seq = text 209 | else: 210 | if self.analyzer is None: 211 | seq = text_to_word_sequence(text, 212 | filters=self.filters, 213 | lower=self.lower, 214 | split=self.split) 215 | else: 216 | seq = self.analyzer(text) 217 | for w in seq: 218 | if w in self.word_counts: 219 | self.word_counts[w] += 1 220 | else: 221 | self.word_counts[w] = 1 222 | for w in set(seq): 223 | # In how many documents each word occurs 224 | self.word_docs[w] += 1 225 | 226 | wcounts = list(self.word_counts.items()) 227 | wcounts.sort(key=lambda x: x[1], reverse=True) 228 | # forcing the oov_token to index 1 if it exists 229 | if self.oov_token is None: 230 | sorted_voc = [] 231 | else: 232 | sorted_voc = [self.oov_token] 233 | sorted_voc.extend(wc[0] for wc in wcounts) 234 | 235 | # note that index 0 is reserved, never assigned to an existing word 236 | self.word_index = dict( 237 | zip(sorted_voc, list(range(1, len(sorted_voc) + 1)))) 238 | 239 | self.index_word = {c: w for w, c in self.word_index.items()} 240 | 241 | for w, c in list(self.word_docs.items()): 242 | self.index_docs[self.word_index[w]] = c 243 | 244 | def fit_on_sequences(self, sequences): 245 | """Updates internal vocabulary based on a list of sequences. 246 | 247 | Required before using `sequences_to_matrix` 248 | (if `fit_on_texts` was never called). 249 | 250 | # Arguments 251 | sequences: A list of sequence. 252 | A "sequence" is a list of integer word indices. 253 | """ 254 | self.document_count += len(sequences) 255 | for seq in sequences: 256 | seq = set(seq) 257 | for i in seq: 258 | self.index_docs[i] += 1 259 | 260 | def texts_to_sequences(self, texts): 261 | """Transforms each text in texts to a sequence of integers. 262 | 263 | Only top `num_words-1` most frequent words will be taken into account. 264 | Only words known by the tokenizer will be taken into account. 265 | 266 | # Arguments 267 | texts: A list of texts (strings). 268 | 269 | # Returns 270 | A list of sequences. 271 | """ 272 | return list(self.texts_to_sequences_generator(texts)) 273 | 274 | def texts_to_sequences_generator(self, texts): 275 | """Transforms each text in `texts` to a sequence of integers. 276 | 277 | Each item in texts can also be a list, 278 | in which case we assume each item of that list to be a token. 279 | 280 | Only top `num_words-1` most frequent words will be taken into account. 281 | Only words known by the tokenizer will be taken into account. 282 | 283 | # Arguments 284 | texts: A list of texts (strings). 285 | 286 | # Yields 287 | Yields individual sequences. 
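# Examples
A sketch showing how out-of-vocabulary words are handled when an `oov_token` is set:

```python
from keras_preprocessing.text import Tokenizer

tokenizer = Tokenizer(oov_token='<unk>')
tokenizer.fit_on_texts(['one two two'])
list(tokenizer.texts_to_sequences_generator(['two three']))
# [[2, 1]] -- 'three' is unknown and maps to the oov index 1
```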
288 | """ 289 | num_words = self.num_words 290 | oov_token_index = self.word_index.get(self.oov_token) 291 | for text in texts: 292 | if self.char_level or isinstance(text, list): 293 | if self.lower: 294 | if isinstance(text, list): 295 | text = [text_elem.lower() for text_elem in text] 296 | else: 297 | text = text.lower() 298 | seq = text 299 | else: 300 | if self.analyzer is None: 301 | seq = text_to_word_sequence(text, 302 | filters=self.filters, 303 | lower=self.lower, 304 | split=self.split) 305 | else: 306 | seq = self.analyzer(text) 307 | vect = [] 308 | for w in seq: 309 | i = self.word_index.get(w) 310 | if i is not None: 311 | if num_words and i >= num_words: 312 | if oov_token_index is not None: 313 | vect.append(oov_token_index) 314 | else: 315 | vect.append(i) 316 | elif self.oov_token is not None: 317 | vect.append(oov_token_index) 318 | yield vect 319 | 320 | def sequences_to_texts(self, sequences): 321 | """Transforms each sequence into a text (string). 322 | 323 | Only top `num_words-1` most frequent words will be taken into account. 324 | Only words known by the tokenizer will be taken into account. 325 | 326 | # Arguments 327 | sequences: A list of sequences (list of integers). 328 | 329 | # Returns 330 | A list of texts (strings) 331 | """ 332 | return list(self.sequences_to_texts_generator(sequences)) 333 | 334 | def sequences_to_texts_generator(self, sequences): 335 | """Transforms each sequence in `sequences` to a text (string). 336 | 337 | Each sequence has to be a list of integers. 338 | In other words, `sequences` should be a list of sequences. 339 | 340 | Only top `num_words-1` most frequent words will be taken into account. 341 | Only words known by the tokenizer will be taken into account. 342 | 343 | # Arguments 344 | sequences: A list of sequences. 345 | 346 | # Yields 347 | Yields individual texts. 348 | """ 349 | num_words = self.num_words 350 | oov_token_index = self.word_index.get(self.oov_token) 351 | for seq in sequences: 352 | vect = [] 353 | for num in seq: 354 | word = self.index_word.get(num) 355 | if word is not None: 356 | if num_words and num >= num_words: 357 | if oov_token_index is not None: 358 | vect.append(self.index_word[oov_token_index]) 359 | else: 360 | vect.append(word) 361 | elif self.oov_token is not None: 362 | vect.append(self.index_word[oov_token_index]) 363 | vect = ' '.join(vect) 364 | yield vect 365 | 366 | def texts_to_matrix(self, texts, mode='binary'): 367 | """Converts a list of texts to a Numpy matrix. 368 | 369 | # Arguments 370 | texts: list of strings. 371 | mode: one of "binary", "count", "tfidf", "freq". 372 | 373 | # Returns 374 | A Numpy matrix. 375 | """ 376 | sequences = self.texts_to_sequences(texts) 377 | return self.sequences_to_matrix(sequences, mode=mode) 378 | 379 | def sequences_to_matrix(self, sequences, mode='binary'): 380 | """Converts a list of sequences into a Numpy matrix. 381 | 382 | # Arguments 383 | sequences: list of sequences 384 | (a sequence is a list of integer word indices). 385 | mode: one of "binary", "count", "tfidf", "freq" 386 | 387 | # Returns 388 | A Numpy matrix. 389 | 390 | # Raises 391 | ValueError: In case of invalid `mode` argument, 392 | or if the Tokenizer has not been fit on sample data.
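# Examples
A sketch of count vectorization:

```python
from keras_preprocessing.text import Tokenizer

tokenizer = Tokenizer(num_words=4)
tokenizer.fit_on_texts(['a b b', 'a c'])  # word_index: {'a': 1, 'b': 2, 'c': 3}
tokenizer.sequences_to_matrix([[1, 2, 2]], mode='count')
# array([[0., 1., 2., 0.]])
```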
393 | """ 394 | if not self.num_words: 395 | if self.word_index: 396 | num_words = len(self.word_index) + 1 397 | else: 398 | raise ValueError('Specify a dimension (`num_words` argument), ' 399 | 'or fit on some text data first.') 400 | else: 401 | num_words = self.num_words 402 | 403 | if mode == 'tfidf' and not self.document_count: 404 | raise ValueError('Fit the Tokenizer on some data ' 405 | 'before using tfidf mode.') 406 | 407 | x = np.zeros((len(sequences), num_words)) 408 | for i, seq in enumerate(sequences): 409 | if not seq: 410 | continue 411 | counts = defaultdict(int) 412 | for j in seq: 413 | if j >= num_words: 414 | continue 415 | counts[j] += 1 416 | for j, c in list(counts.items()): 417 | if mode == 'count': 418 | x[i][j] = c 419 | elif mode == 'freq': 420 | x[i][j] = c / len(seq) 421 | elif mode == 'binary': 422 | x[i][j] = 1 423 | elif mode == 'tfidf': 424 | # Use weighting scheme 2 in 425 | # https://en.wikipedia.org/wiki/Tf%E2%80%93idf 426 | tf = 1 + np.log(c) 427 | idf = np.log(1 + self.document_count / 428 | (1 + self.index_docs.get(j, 0))) 429 | x[i][j] = tf * idf 430 | else: 431 | raise ValueError('Unknown vectorization mode:', mode) 432 | return x 433 | 434 | def get_config(self): 435 | '''Returns the tokenizer configuration as Python dictionary. 436 | The word count dictionaries used by the tokenizer get serialized 437 | into plain JSON, so that the configuration can be read by other 438 | projects. 439 | 440 | # Returns 441 | A Python dictionary with the tokenizer configuration. 442 | ''' 443 | json_word_counts = json.dumps(self.word_counts) 444 | json_word_docs = json.dumps(self.word_docs) 445 | json_index_docs = json.dumps(self.index_docs) 446 | json_word_index = json.dumps(self.word_index) 447 | json_index_word = json.dumps(self.index_word) 448 | 449 | return { 450 | 'num_words': self.num_words, 451 | 'filters': self.filters, 452 | 'lower': self.lower, 453 | 'split': self.split, 454 | 'char_level': self.char_level, 455 | 'oov_token': self.oov_token, 456 | 'document_count': self.document_count, 457 | 'word_counts': json_word_counts, 458 | 'word_docs': json_word_docs, 459 | 'index_docs': json_index_docs, 460 | 'index_word': json_index_word, 461 | 'word_index': json_word_index 462 | } 463 | 464 | def to_json(self, **kwargs): 465 | """Returns a JSON string containing the tokenizer configuration. 466 | To load a tokenizer from a JSON string, use 467 | `keras.preprocessing.text.tokenizer_from_json(json_string)`. 468 | 469 | # Arguments 470 | **kwargs: Additional keyword arguments 471 | to be passed to `json.dumps()`. 472 | 473 | # Returns 474 | A JSON string containing the tokenizer configuration. 475 | """ 476 | config = self.get_config() 477 | tokenizer_config = { 478 | 'class_name': self.__class__.__name__, 479 | 'config': config 480 | } 481 | return json.dumps(tokenizer_config, **kwargs) 482 | 483 | 484 | def tokenizer_from_json(json_string): 485 | """Parses a JSON tokenizer configuration file and returns a 486 | tokenizer instance. 487 | 488 | # Arguments 489 | json_string: JSON string encoding a tokenizer configuration. 
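A sketch of the expected round trip:

```python
from keras_preprocessing.text import Tokenizer, tokenizer_from_json

tokenizer = Tokenizer()
tokenizer.fit_on_texts(['hello world'])
restored = tokenizer_from_json(tokenizer.to_json())
assert restored.word_index == tokenizer.word_index
```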
490 | 491 | # Returns 492 | A Keras Tokenizer instance 493 | """ 494 | tokenizer_config = json.loads(json_string) 495 | config = tokenizer_config.get('config') 496 | 497 | word_counts = json.loads(config.pop('word_counts')) 498 | word_docs = json.loads(config.pop('word_docs')) 499 | index_docs = json.loads(config.pop('index_docs')) 500 | # Integer indexing gets converted to strings with json.dumps() 501 | index_docs = {int(k): v for k, v in index_docs.items()} 502 | index_word = json.loads(config.pop('index_word')) 503 | index_word = {int(k): v for k, v in index_word.items()} 504 | word_index = json.loads(config.pop('word_index')) 505 | 506 | tokenizer = Tokenizer(**config) 507 | tokenizer.word_counts = word_counts 508 | tokenizer.word_docs = word_docs 509 | tokenizer.index_docs = index_docs 510 | tokenizer.word_index = word_index 511 | tokenizer.index_word = index_word 512 | 513 | return tokenizer 514 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Configuration of py.test 2 | [tool:pytest] 3 | addopts=-v -n 2 --durations=20 4 | # Do not run tests in the build folder 5 | norecursedirs=build 6 | 7 | [flake8] 8 | # Use 85 as max line length in PEP8 test. 9 | max-line-length=85 10 | # do not run pep8 test in the build folder 11 | exclude=build 12 | # PEP-8 The following are ignored: 13 | # E731 do not assign a lambda expression, use a def 14 | # E402 module level import not at top of file 15 | pep8ignore=* E731 \ 16 | * E402 \ 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | long_description = ''' 4 | Keras Preprocessing is the data preprocessing 5 | and data augmentation module of the Keras deep learning library. 6 | It provides utilities for working with image data, text data, 7 | and sequence data. 8 | 9 | Read the documentation at: https://keras.io/ 10 | 11 | Keras Preprocessing may be imported directly 12 | from an up-to-date installation of Keras: 13 | 14 | ``` 15 | from keras import preprocessing 16 | ``` 17 | 18 | Keras Preprocessing is compatible with Python 3.6 19 | and is distributed under the MIT license.
20 | ''' 21 | 22 | setup(name='Keras_Preprocessing', 23 | version='1.1.2', 24 | description='Easy data preprocessing and data augmentation ' 25 | 'for deep learning models', 26 | long_description=long_description, 27 | author='Keras Team', 28 | url='https://github.com/keras-team/keras-preprocessing', 29 | download_url='https://github.com/keras-team/' 30 | 'keras-preprocessing/tarball/1.1.2', 31 | license='MIT', 32 | install_requires=['numpy>=1.9.1'], 33 | extras_require={ 34 | 'tests': ['pandas', 35 | 'Pillow', 36 | 'tensorflow', # CPU version 37 | 'keras', 38 | 'pytest', 39 | 'pytest-xdist', 40 | 'pytest-cov'], 41 | 'pep8': ['flake8'], 42 | 'image': ['scipy>=0.14', 43 | 'Pillow>=5.2.0'], 44 | }, 45 | classifiers=[ 46 | 'Development Status :: 5 - Production/Stable', 47 | 'Intended Audience :: Developers', 48 | 'Intended Audience :: Education', 49 | 'Intended Audience :: Science/Research', 50 | 'License :: OSI Approved :: MIT License', 51 | 'Programming Language :: Python :: 3', 52 | 'Programming Language :: Python :: 3.6', 53 | 'Topic :: Software Development :: Libraries', 54 | 'Topic :: Software Development :: Libraries :: Python Modules' 55 | ], 56 | packages=find_packages()) 57 | -------------------------------------------------------------------------------- /tests/image/affine_transformations_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from keras_preprocessing.image import affine_transformations 5 | 6 | 7 | def test_random_transforms(): 8 | x = np.random.random((2, 28, 28)) 9 | assert affine_transformations.random_rotation(x, 45).shape == (2, 28, 28) 10 | assert affine_transformations.random_shift(x, 1, 1).shape == (2, 28, 28) 11 | assert affine_transformations.random_shear(x, 20).shape == (2, 28, 28) 12 | assert affine_transformations.random_channel_shift(x, 20).shape == (2, 28, 28) 13 | 14 | 15 | def test_deterministic_transform(): 16 | x = np.ones((3, 3, 3)) 17 | x_rotated = np.array([[[0., 0., 0.], 18 | [1., 1., 1.], 19 | [0., 0., 0.]], 20 | [[1., 1., 1.], 21 | [1., 1., 1.], 22 | [1., 1., 1.]], 23 | [[0., 0., 0.], 24 | [1., 1., 1.], 25 | [0., 0., 0.]]]) 26 | assert np.allclose( 27 | affine_transformations.apply_affine_transform(x, 28 | theta=45, 29 | row_axis=0, 30 | col_axis=1, 31 | channel_axis=2, 32 | fill_mode='constant'), 33 | x_rotated) 34 | 35 | 36 | def test_matrix_center(): 37 | x = np.expand_dims(np.array([ 38 | [0, 1], 39 | [0, 0], 40 | ]), -1) 41 | x_rotated90 = np.expand_dims(np.array([ 42 | [1, 0], 43 | [0, 0], 44 | ]), -1) 45 | 46 | assert np.allclose( 47 | affine_transformations.apply_affine_transform(x, 48 | theta=90, 49 | row_axis=0, 50 | col_axis=1, 51 | channel_axis=2), 52 | x_rotated90) 53 | 54 | 55 | def test_translation(): 56 | x = np.array([ 57 | [0, 0, 0, 0], 58 | [0, 1, 0, 0], 59 | [0, 0, 0, 0], 60 | ]) 61 | x_up = np.array([ 62 | [0, 1, 0, 0], 63 | [0, 0, 0, 0], 64 | [0, 0, 0, 0], 65 | ]) 66 | x_dn = np.array([ 67 | [0, 0, 0, 0], 68 | [0, 0, 0, 0], 69 | [0, 1, 0, 0], 70 | ]) 71 | x_left = np.array([ 72 | [0, 0, 0, 0], 73 | [1, 0, 0, 0], 74 | [0, 0, 0, 0], 75 | ]) 76 | x_right = np.array([ 77 | [0, 0, 0, 0], 78 | [0, 0, 1, 0], 79 | [0, 0, 0, 0], 80 | ]) 81 | 82 | # Channels first 83 | x_test = np.expand_dims(x, 0) 84 | 85 | # Horizontal translation 86 | assert np.alltrue(x_left == np.squeeze( 87 | affine_transformations.apply_affine_transform(x_test, tx=1))) 88 | assert np.alltrue(x_right == np.squeeze( 89 | affine_transformations.apply_affine_transform(x_test, 
tx=-1))) 90 | 91 | # change axes: x<->y 92 | assert np.alltrue(x_left == np.squeeze( 93 | affine_transformations.apply_affine_transform( 94 | x_test, ty=1, row_axis=2, col_axis=1))) 95 | assert np.alltrue(x_right == np.squeeze( 96 | affine_transformations.apply_affine_transform( 97 | x_test, ty=-1, row_axis=2, col_axis=1))) 98 | 99 | # Vertical translation 100 | assert np.alltrue(x_up == np.squeeze( 101 | affine_transformations.apply_affine_transform(x_test, ty=1))) 102 | assert np.alltrue(x_dn == np.squeeze( 103 | affine_transformations.apply_affine_transform(x_test, ty=-1))) 104 | 105 | # change axes: x<->y 106 | assert np.alltrue(x_up == np.squeeze( 107 | affine_transformations.apply_affine_transform( 108 | x_test, tx=1, row_axis=2, col_axis=1))) 109 | assert np.alltrue(x_dn == np.squeeze( 110 | affine_transformations.apply_affine_transform( 111 | x_test, tx=-1, row_axis=2, col_axis=1))) 112 | 113 | # Channels last 114 | x_test = np.expand_dims(x, -1) 115 | 116 | # Horizontal translation 117 | assert np.alltrue(x_left == np.squeeze( 118 | affine_transformations.apply_affine_transform( 119 | x_test, tx=1, row_axis=0, col_axis=1, channel_axis=2))) 120 | assert np.alltrue(x_right == np.squeeze( 121 | affine_transformations.apply_affine_transform( 122 | x_test, tx=-1, row_axis=0, col_axis=1, channel_axis=2))) 123 | 124 | # change axes: x<->y 125 | assert np.alltrue(x_left == np.squeeze( 126 | affine_transformations.apply_affine_transform( 127 | x_test, ty=1, row_axis=1, col_axis=0, channel_axis=2))) 128 | assert np.alltrue(x_right == np.squeeze( 129 | affine_transformations.apply_affine_transform( 130 | x_test, ty=-1, row_axis=1, col_axis=0, channel_axis=2))) 131 | 132 | # Vertical translation 133 | assert np.alltrue(x_up == np.squeeze( 134 | affine_transformations.apply_affine_transform( 135 | x_test, ty=1, row_axis=0, col_axis=1, channel_axis=2))) 136 | assert np.alltrue(x_dn == np.squeeze( 137 | affine_transformations.apply_affine_transform( 138 | x_test, ty=-1, row_axis=0, col_axis=1, channel_axis=2))) 139 | 140 | # change axes: x<->y 141 | assert np.alltrue(x_up == np.squeeze( 142 | affine_transformations.apply_affine_transform( 143 | x_test, tx=1, row_axis=1, col_axis=0, channel_axis=2))) 144 | assert np.alltrue(x_dn == np.squeeze( 145 | affine_transformations.apply_affine_transform( 146 | x_test, tx=-1, row_axis=1, col_axis=0, channel_axis=2))) 147 | 148 | 149 | def test_random_zoom(): 150 | x = np.random.random((2, 28, 28)) 151 | assert affine_transformations.random_zoom(x, (5, 5)).shape == (2, 28, 28) 152 | assert np.allclose(x, affine_transformations.random_zoom(x, (1, 1))) 153 | 154 | 155 | def test_random_zoom_error(): 156 | with pytest.raises(ValueError): 157 | affine_transformations.random_zoom(0, zoom_range=[0]) 158 | 159 | 160 | def test_apply_brightness_shift_error(monkeypatch): 161 | monkeypatch.setattr(affine_transformations, 'ImageEnhance', None) 162 | with pytest.raises(ImportError): 163 | affine_transformations.apply_brightness_shift(0, [0]) 164 | 165 | 166 | def test_random_brightness(monkeypatch): 167 | monkeypatch.setattr(affine_transformations, 168 | 'apply_brightness_shift', lambda x, y, z: (x, y)) 169 | assert (0, 3.) 
== affine_transformations.random_brightness(0, (3, 3)) 170 | 171 | 172 | def test_random_brightness_error(): 173 | with pytest.raises(ValueError): 174 | affine_transformations.random_brightness(0, [0]) 175 | 176 | 177 | def test_random_brightness_scale(): 178 | img = np.ones((1, 1, 3)) * 128 179 | zeros = np.zeros((1, 1, 3)) 180 | must_be_128 = affine_transformations.random_brightness(img, [1, 1], False) 181 | assert np.array_equal(img, must_be_128) 182 | must_be_0 = affine_transformations.random_brightness(img, [1, 1], True) 183 | assert np.array_equal(zeros, must_be_0) 184 | 185 | 186 | def test_random_brightness_scale_outside_range_positive(): 187 | img = np.ones((1, 1, 3)) * 1024 188 | zeros = np.zeros((1, 1, 3)) 189 | must_be_1024 = affine_transformations.random_brightness(img, [1, 1], False) 190 | assert np.array_equal(img, must_be_1024) 191 | must_be_0 = affine_transformations.random_brightness(img, [1, 1], True) 192 | assert np.array_equal(zeros, must_be_0) 193 | 194 | 195 | def test_random_brightness_scale_outside_range_negative(): 196 | img = np.ones((1, 1, 3)) * -1024 197 | zeros = np.zeros((1, 1, 3)) 198 | must_be_neg_1024 = affine_transformations.random_brightness(img, [1, 1], False) 199 | assert np.array_equal(img, must_be_neg_1024) 200 | must_be_0 = affine_transformations.random_brightness(img, [1, 1], True) 201 | assert np.array_equal(zeros, must_be_0) 202 | 203 | 204 | def test_apply_affine_transform_error(monkeypatch): 205 | monkeypatch.setattr(affine_transformations, 'scipy', None) 206 | with pytest.raises(ImportError): 207 | affine_transformations.apply_affine_transform(0) 208 | -------------------------------------------------------------------------------- /tests/image/directory_iterator_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | 5 | import numpy as np 6 | import pytest 7 | from PIL import Image 8 | 9 | from keras_preprocessing.image import image_data_generator 10 | 11 | 12 | @pytest.fixture(scope='module') 13 | def all_test_images(): 14 | img_w = img_h = 20 15 | rgb_images = [] 16 | rgba_images = [] 17 | gray_images = [] 18 | gray_images_16bit = [] 19 | gray_images_32bit = [] 20 | for n in range(8): 21 | bias = np.random.rand(img_w, img_h, 1) * 64 22 | variance = np.random.rand(img_w, img_h, 1) * (255 - 64) 23 | # RGB 24 | imarray = np.random.rand(img_w, img_h, 3) * variance + bias 25 | im = Image.fromarray(imarray.astype('uint8')).convert('RGB') 26 | rgb_images.append(im) 27 | # RGBA 28 | imarray = np.random.rand(img_w, img_h, 4) * variance + bias 29 | im = Image.fromarray(imarray.astype('uint8')).convert('RGBA') 30 | rgba_images.append(im) 31 | # 8-bit grayscale 32 | imarray = np.random.rand(img_w, img_h, 1) * variance + bias 33 | im = Image.fromarray(imarray.astype('uint8').squeeze()).convert('L') 34 | gray_images.append(im) 35 | # 16-bit grayscale 36 | imarray = np.array( 37 | np.random.randint(-2147483648, 2147483647, (img_w, img_h)) 38 | ) 39 | im = Image.fromarray(imarray.astype('uint16')) 40 | gray_images_16bit.append(im) 41 | # 32-bit grayscale 42 | im = Image.fromarray(imarray.astype('uint32')) 43 | gray_images_32bit.append(im) 44 | 45 | return [rgb_images, rgba_images, 46 | gray_images, gray_images_16bit, gray_images_32bit] 47 | 48 | 49 | def test_directory_iterator(all_test_images, tmpdir): 50 | num_classes = 2 51 | 52 | # create folders and subfolders 53 | paths = [] 54 | for cl in range(num_classes): 55 | class_directory = 'class-{}'.format(cl) 56 | 
classpaths = [ 57 | class_directory, 58 | os.path.join(class_directory, 'subfolder-1'), 59 | os.path.join(class_directory, 'subfolder-2'), 60 | os.path.join(class_directory, 'subfolder-1', 'sub-subfolder') 61 | ] 62 | for path in classpaths: 63 | tmpdir.join(path).mkdir() 64 | paths.append(classpaths) 65 | 66 | # save the images in the paths 67 | count = 0 68 | filenames = [] 69 | for test_images in all_test_images: 70 | for im in test_images: 71 | # rotate image class 72 | im_class = count % num_classes 73 | # rotate subfolders 74 | classpaths = paths[im_class] 75 | filename = os.path.join( 76 | classpaths[count % len(classpaths)], 77 | 'image-{}.png'.format(count)) 78 | filenames.append(filename) 79 | im.save(str(tmpdir / filename)) 80 | count += 1 81 | 82 | # create iterator 83 | generator = image_data_generator.ImageDataGenerator() 84 | dir_iterator = generator.flow_from_directory(str(tmpdir)) 85 | 86 | # check number of classes and images 87 | assert len(dir_iterator.class_indices) == num_classes 88 | assert len(dir_iterator.classes) == count 89 | assert set(dir_iterator.filenames) == set(filenames) 90 | 91 | # Test invalid use cases 92 | with pytest.raises(ValueError): 93 | generator.flow_from_directory(str(tmpdir), color_mode='cmyk') 94 | with pytest.raises(ValueError): 95 | generator.flow_from_directory(str(tmpdir), class_mode='output') 96 | 97 | def preprocessing_function(x): 98 | """This will fail if not provided by a Numpy array. 99 | Note: This is made to enforce backward compatibility. 100 | """ 101 | 102 | assert x.shape == (26, 26, 3) 103 | assert type(x) is np.ndarray 104 | 105 | return np.zeros_like(x) 106 | 107 | # Test usage as Sequence 108 | generator = image_data_generator.ImageDataGenerator( 109 | preprocessing_function=preprocessing_function) 110 | dir_seq = generator.flow_from_directory(str(tmpdir), 111 | target_size=(26, 26), 112 | color_mode='rgb', 113 | batch_size=3, 114 | class_mode='categorical') 115 | assert len(dir_seq) == np.ceil(count / 3.) 
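# As a Sequence, `dir_seq` supports random access: `dir_seq[i]` builds
# batch `i` on demand. With 40 images and `batch_size=3` there are
# ceil(40 / 3) == 14 batches, so valid indices run from 0 to 13 and the
# out-of-range access further below is expected to raise a ValueError.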
116 | x1, y1 = dir_seq[1] 117 | assert x1.shape == (3, 26, 26, 3) 118 | assert y1.shape == (3, num_classes) 119 | x1, y1 = dir_seq[5] 120 | assert (x1 == 0).all() 121 | 122 | with pytest.raises(ValueError): 123 | x1, y1 = dir_seq[14] # there are 40 images and batch size is 3 124 | 125 | 126 | def test_directory_iterator_class_mode_input(all_test_images, tmpdir): 127 | tmpdir.join('class-1').mkdir() 128 | 129 | # save the images in the paths 130 | count = 0 131 | for test_images in all_test_images: 132 | for im in test_images: 133 | filename = str( 134 | tmpdir / 'class-1' / 'image-{}.png'.format(count)) 135 | im.save(filename) 136 | count += 1 137 | 138 | # create iterator 139 | generator = image_data_generator.ImageDataGenerator() 140 | dir_iterator = generator.flow_from_directory(str(tmpdir), 141 | class_mode='input') 142 | batch = next(dir_iterator) 143 | 144 | # check if input and output have the same shape 145 | assert(batch[0].shape == batch[1].shape) 146 | # check if the input and output images are not the same numpy array 147 | input_img = batch[0][0] 148 | output_img = batch[1][0] 149 | output_img[0][0][0] += 1 150 | assert(input_img[0][0][0] != output_img[0][0][0]) 151 | 152 | 153 | @pytest.mark.parametrize('validation_split,num_training', [ 154 | (0.25, 30), 155 | (0.50, 20), 156 | (0.75, 10), 157 | ]) 158 | def test_directory_iterator_with_validation_split(all_test_images, 159 | validation_split, 160 | num_training): 161 | num_classes = 2 162 | tmp_folder = tempfile.mkdtemp(prefix='test_images') 163 | 164 | # create folders and subfolders 165 | paths = [] 166 | for cl in range(num_classes): 167 | class_directory = 'class-{}'.format(cl) 168 | classpaths = [ 169 | class_directory, 170 | os.path.join(class_directory, 'subfolder-1'), 171 | os.path.join(class_directory, 'subfolder-2'), 172 | os.path.join(class_directory, 'subfolder-1', 'sub-subfolder') 173 | ] 174 | for path in classpaths: 175 | os.mkdir(os.path.join(tmp_folder, path)) 176 | paths.append(classpaths) 177 | 178 | # save the images in the paths 179 | count = 0 180 | filenames = [] 181 | for test_images in all_test_images: 182 | for im in test_images: 183 | # rotate image class 184 | im_class = count % num_classes 185 | # rotate subfolders 186 | classpaths = paths[im_class] 187 | filename = os.path.join( 188 | classpaths[count % len(classpaths)], 189 | 'image-{}.png'.format(count)) 190 | filenames.append(filename) 191 | im.save(os.path.join(tmp_folder, filename)) 192 | count += 1 193 | 194 | # create iterator 195 | generator = image_data_generator.ImageDataGenerator( 196 | validation_split=validation_split 197 | ) 198 | 199 | with pytest.raises(ValueError): 200 | generator.flow_from_directory(tmp_folder, subset='foo') 201 | 202 | train_iterator = generator.flow_from_directory(tmp_folder, 203 | subset='training') 204 | assert train_iterator.samples == num_training 205 | 206 | valid_iterator = generator.flow_from_directory(tmp_folder, 207 | subset='validation') 208 | assert valid_iterator.samples == count - num_training 209 | 210 | # check number of classes and images 211 | assert len(train_iterator.class_indices) == num_classes 212 | assert len(train_iterator.classes) == num_training 213 | assert len(set(train_iterator.filenames) & 214 | set(filenames)) == num_training 215 | 216 | shutil.rmtree(tmp_folder) 217 | 218 | 219 | if __name__ == '__main__': 220 | pytest.main([__file__]) 221 | -------------------------------------------------------------------------------- /tests/image/image_data_generator_test.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from PIL import Image 4 | 5 | from keras_preprocessing.image import image_data_generator, utils 6 | 7 | 8 | @pytest.fixture(scope='module') 9 | def all_test_images(): 10 | img_w = img_h = 20 11 | rgb_images = [] 12 | rgba_images = [] 13 | gray_images = [] 14 | for n in range(8): 15 | bias = np.random.rand(img_w, img_h, 1) * 64 16 | variance = np.random.rand(img_w, img_h, 1) * (255 - 64) 17 | imarray = np.random.rand(img_w, img_h, 3) * variance + bias 18 | im = Image.fromarray(imarray.astype('uint8')).convert('RGB') 19 | rgb_images.append(im) 20 | 21 | imarray = np.random.rand(img_w, img_h, 4) * variance + bias 22 | im = Image.fromarray(imarray.astype('uint8')).convert('RGBA') 23 | rgba_images.append(im) 24 | 25 | imarray = np.random.rand(img_w, img_h, 1) * variance + bias 26 | im = Image.fromarray( 27 | imarray.astype('uint8').squeeze()).convert('L') 28 | gray_images.append(im) 29 | 30 | return [rgb_images, rgba_images, gray_images] 31 | 32 | 33 | def test_image_data_generator(all_test_images): 34 | for test_images in all_test_images: 35 | img_list = [] 36 | for im in test_images: 37 | img_list.append(utils.img_to_array(im)[None, ...]) 38 | 39 | image_data_generator.ImageDataGenerator( 40 | featurewise_center=True, 41 | samplewise_center=True, 42 | featurewise_std_normalization=True, 43 | samplewise_std_normalization=True, 44 | zca_whitening=True, 45 | rotation_range=90., 46 | width_shift_range=0.1, 47 | height_shift_range=0.1, 48 | shear_range=0.5, 49 | zoom_range=0.2, 50 | channel_shift_range=0., 51 | brightness_range=(1, 5), 52 | fill_mode='nearest', 53 | cval=0.5, 54 | horizontal_flip=True, 55 | vertical_flip=True, 56 | interpolation_order=1 57 | ) 58 | 59 | 60 | def test_image_data_generator_with_validation_split(all_test_images): 61 | for test_images in all_test_images: 62 | img_list = [] 63 | for im in test_images: 64 | img_list.append(utils.img_to_array(im)[None, ...]) 65 | 66 | images = np.vstack(img_list) 67 | labels = np.concatenate([ 68 | np.zeros((int(len(images) / 2),)), 69 | np.ones((int(len(images) / 2),))]) 70 | generator = image_data_generator.ImageDataGenerator(validation_split=0.5) 71 | 72 | # training and validation sets would have different 73 | # number of classes, because labels are sorted 74 | with pytest.raises(ValueError, 75 | match='Training and validation subsets ' 76 | 'have different number of classes after ' 77 | 'the split.*'): 78 | generator.flow(images, labels, 79 | shuffle=False, batch_size=10, 80 | subset='validation') 81 | 82 | # test non categorical labels with validation split 83 | generator.flow(images, labels, 84 | shuffle=False, batch_size=10, 85 | ignore_class_split=True, 86 | subset='validation') 87 | 88 | labels = np.concatenate([ 89 | np.zeros((int(len(images) / 4),)), 90 | np.ones((int(len(images) / 4),)), 91 | np.zeros((int(len(images) / 4),)), 92 | np.ones((int(len(images) / 4),)) 93 | ]) 94 | 95 | seq = generator.flow(images, labels, 96 | shuffle=False, batch_size=10, 97 | subset='validation') 98 | 99 | x, y = seq[0] 100 | assert 2 == len(np.unique(y)) 101 | 102 | seq = generator.flow(images, labels, 103 | shuffle=False, batch_size=10, 104 | subset='training') 105 | x2, y2 = seq[0] 106 | assert 2 == len(np.unique(y2)) 107 | 108 | with pytest.raises(ValueError): 109 | generator.flow(images, np.arange(images.shape[0]), 110 | shuffle=False, batch_size=3, 111 | subset='foo') 112 | 113 | 114 | def 
test_image_data_generator_with_split_value_error(): 115 | with pytest.raises(ValueError): 116 | image_data_generator.ImageDataGenerator(validation_split=5) 117 | 118 | 119 | def test_image_data_generator_invalid_data(): 120 | generator = image_data_generator.ImageDataGenerator( 121 | featurewise_center=True, 122 | samplewise_center=True, 123 | featurewise_std_normalization=True, 124 | samplewise_std_normalization=True, 125 | zca_whitening=True, 126 | data_format='channels_last') 127 | # Test fit with invalid data 128 | with pytest.raises(ValueError): 129 | x = np.random.random((3, 10, 10)) 130 | generator.fit(x) 131 | 132 | # Test flow with invalid data 133 | with pytest.raises(ValueError): 134 | x = np.random.random((32, 10, 10)) 135 | generator.flow(np.arange(x.shape[0])) 136 | 137 | 138 | def test_image_data_generator_fit(): 139 | generator = image_data_generator.ImageDataGenerator( 140 | featurewise_center=True, 141 | samplewise_center=True, 142 | featurewise_std_normalization=True, 143 | samplewise_std_normalization=True, 144 | zca_whitening=True, 145 | rotation_range=90., 146 | width_shift_range=0.1, 147 | height_shift_range=0.1, 148 | shear_range=0.5, 149 | zoom_range=(0.2, 0.2), 150 | channel_shift_range=0., 151 | brightness_range=(1, 5), 152 | fill_mode='nearest', 153 | cval=0.5, 154 | horizontal_flip=True, 155 | vertical_flip=True, 156 | interpolation_order=1, 157 | data_format='channels_last' 158 | ) 159 | x = np.random.random((32, 10, 10, 3)) 160 | generator.fit(x, augment=True) 161 | # Test grayscale 162 | x = np.random.random((32, 10, 10, 1)) 163 | generator.fit(x) 164 | # Test RGB 165 | x = np.random.random((32, 10, 10, 3)) 166 | generator.fit(x) 167 | # Test more samples than dims 168 | x = np.random.random((32, 4, 4, 1)) 169 | generator.fit(x) 170 | generator = image_data_generator.ImageDataGenerator( 171 | featurewise_center=True, 172 | samplewise_center=True, 173 | featurewise_std_normalization=True, 174 | samplewise_std_normalization=True, 175 | zca_whitening=True, 176 | rotation_range=90., 177 | width_shift_range=0.1, 178 | height_shift_range=0.1, 179 | shear_range=0.5, 180 | zoom_range=(0.2, 0.2), 181 | channel_shift_range=0., 182 | brightness_range=(1, 5), 183 | fill_mode='nearest', 184 | cval=0.5, 185 | horizontal_flip=True, 186 | vertical_flip=True, 187 | interpolation_order=1, 188 | data_format='channels_first' 189 | ) 190 | x = np.random.random((32, 10, 10, 3)) 191 | generator.fit(x, augment=True) 192 | # Test grayscale 193 | x = np.random.random((32, 1, 10, 10)) 194 | generator.fit(x) 195 | # Test RGB 196 | x = np.random.random((32, 3, 10, 10)) 197 | generator.fit(x) 198 | # Test more samples than dims 199 | x = np.random.random((32, 1, 4, 4)) 200 | generator.fit(x) 201 | 202 | 203 | def test_image_data_generator_flow(all_test_images, tmpdir): 204 | for test_images in all_test_images: 205 | img_list = [] 206 | for im in test_images: 207 | img_list.append(utils.img_to_array(im)[None, ...]) 208 | 209 | images = np.vstack(img_list) 210 | dsize = images.shape[0] 211 | generator = image_data_generator.ImageDataGenerator( 212 | featurewise_center=True, 213 | samplewise_center=True, 214 | featurewise_std_normalization=True, 215 | samplewise_std_normalization=True, 216 | zca_whitening=True, 217 | rotation_range=90., 218 | width_shift_range=0.1, 219 | height_shift_range=0.1, 220 | shear_range=0.5, 221 | zoom_range=0.2, 222 | channel_shift_range=0., 223 | brightness_range=(1, 5), 224 | fill_mode='nearest', 225 | cval=0.5, 226 | horizontal_flip=True, 227 | 
vertical_flip=True, 228 | interpolation_order=1 229 | ) 230 | 231 | generator.flow( 232 | images, 233 | np.arange(images.shape[0]), 234 | shuffle=False, 235 | save_to_dir=str(tmpdir), 236 | batch_size=3 237 | ) 238 | 239 | generator.flow( 240 | images, 241 | np.arange(images.shape[0]), 242 | shuffle=False, 243 | sample_weight=np.arange(images.shape[0]) + 1, 244 | save_to_dir=str(tmpdir), 245 | batch_size=3 246 | ) 247 | 248 | # Test with `shuffle=True` 249 | generator.flow( 250 | images, np.arange(images.shape[0]), 251 | shuffle=True, 252 | save_to_dir=str(tmpdir), 253 | batch_size=3, 254 | seed=42 255 | ) 256 | 257 | # Test without y 258 | generator.flow( 259 | images, 260 | None, 261 | shuffle=True, 262 | save_to_dir=str(tmpdir), 263 | batch_size=3 264 | ) 265 | 266 | # Test with a single miscellaneous input data array 267 | x_misc1 = np.random.random(dsize) 268 | generator.flow( 269 | (images, x_misc1), 270 | np.arange(dsize), 271 | shuffle=False, 272 | batch_size=2 273 | ) 274 | 275 | # Test with two miscellaneous inputs 276 | x_misc2 = np.random.random((dsize, 3, 3)) 277 | generator.flow( 278 | (images, [x_misc1, x_misc2]), 279 | np.arange(dsize), 280 | shuffle=False, 281 | batch_size=2 282 | ) 283 | 284 | # Test cases with `y = None` 285 | generator.flow(images, None, batch_size=3) 286 | generator.flow((images, x_misc1), None, batch_size=3, shuffle=False) 287 | generator.flow( 288 | (images, [x_misc1, x_misc2]), 289 | None, 290 | batch_size=3, 291 | shuffle=False 292 | ) 293 | generator = image_data_generator.ImageDataGenerator(validation_split=0.2) 294 | generator.flow(images, batch_size=3) 295 | 296 | # Test some failure cases: 297 | x_misc_err = np.random.random((dsize + 1, 3, 3)) 298 | with pytest.raises(ValueError) as e_info: 299 | generator.flow((images, x_misc_err), np.arange(dsize), batch_size=3) 300 | assert str(e_info.value).find('All of the arrays in') != -1 301 | 302 | with pytest.raises(ValueError) as e_info: 303 | generator.flow((images, x_misc1), np.arange(dsize + 1), batch_size=3) 304 | assert str(e_info.value).find('`x` (images tensor) and `y` (labels) ') != -1 305 | 306 | # Test `flow` behavior as Sequence 307 | generator.flow( 308 | images, 309 | np.arange(images.shape[0]), 310 | shuffle=False, 311 | save_to_dir=str(tmpdir), 312 | batch_size=3 313 | ) 314 | 315 | # Test with `shuffle=True` 316 | generator.flow( 317 | images, 318 | np.arange(images.shape[0]), 319 | shuffle=True, save_to_dir=str(tmpdir), 320 | batch_size=3, seed=123 321 | ) 322 | 323 | # test order_interpolation 324 | labels = np.array([[2, 2, 0, 2, 2], 325 | [1, 3, 2, 3, 1], 326 | [2, 1, 0, 1, 2], 327 | [3, 1, 0, 2, 0], 328 | [3, 1, 3, 2, 1]]) 329 | 330 | label_generator = image_data_generator.ImageDataGenerator( 331 | rotation_range=90., 332 | interpolation_order=0 333 | ) 334 | label_generator.flow( 335 | x=labels[np.newaxis, ..., np.newaxis], 336 | seed=123 337 | ) 338 | 339 | 340 | def test_valid_args(): 341 | with pytest.raises(ValueError): 342 | image_data_generator.ImageDataGenerator(brightness_range=0.1) 343 | 344 | 345 | def test_batch_standardize(all_test_images): 346 | # ImageDataGenerator.standardize should work on batches 347 | for test_images in all_test_images: 348 | img_list = [] 349 | for im in test_images: 350 | img_list.append(utils.img_to_array(im)[None, ...]) 351 | 352 | images = np.vstack(img_list) 353 | generator = image_data_generator.ImageDataGenerator( 354 | featurewise_center=True, 355 | samplewise_center=True, 356 | featurewise_std_normalization=True, 357 | 
samplewise_std_normalization=True, 358 | zca_whitening=True, 359 | rotation_range=90., 360 | width_shift_range=0.1, 361 | height_shift_range=0.1, 362 | shear_range=0.5, 363 | zoom_range=0.2, 364 | channel_shift_range=0., 365 | brightness_range=(1, 5), 366 | fill_mode='nearest', 367 | cval=0.5, 368 | horizontal_flip=True, 369 | vertical_flip=True) 370 | generator.fit(images, augment=True) 371 | 372 | transformed = np.copy(images) 373 | for i, im in enumerate(transformed): 374 | transformed[i] = generator.random_transform(im) 375 | transformed = generator.standardize(transformed) 376 | 377 | 378 | def test_deterministic_transform(): 379 | x = np.ones((32, 32, 3)) 380 | generator = image_data_generator.ImageDataGenerator( 381 | rotation_range=90, 382 | fill_mode='constant') 383 | x = np.random.random((32, 32, 3)) 384 | assert np.allclose(generator.apply_transform(x, {'flip_vertical': True}), 385 | x[::-1, :, :]) 386 | assert np.allclose(generator.apply_transform(x, {'flip_horizontal': True}), 387 | x[:, ::-1, :]) 388 | x = np.ones((3, 3, 3)) 389 | x_rotated = np.array([[[0., 0., 0.], 390 | [1., 1., 1.], 391 | [0., 0., 0.]], 392 | [[1., 1., 1.], 393 | [1., 1., 1.], 394 | [1., 1., 1.]], 395 | [[0., 0., 0.], 396 | [1., 1., 1.], 397 | [0., 0., 0.]]]) 398 | assert np.allclose(generator.apply_transform(x, {'theta': 45}), 399 | x_rotated) 400 | 401 | 402 | def test_random_transforms(): 403 | x = np.random.random((2, 28, 28)) 404 | # Test get_random_transform with predefined seed 405 | seed = 1 406 | generator = image_data_generator.ImageDataGenerator( 407 | rotation_range=90., 408 | width_shift_range=0.1, 409 | height_shift_range=0.1, 410 | shear_range=0.5, 411 | zoom_range=0.2, 412 | channel_shift_range=0.1, 413 | brightness_range=(1, 5), 414 | horizontal_flip=True, 415 | vertical_flip=True) 416 | transform_dict = generator.get_random_transform(x.shape, seed) 417 | transform_dict2 = generator.get_random_transform(x.shape, seed * 2) 418 | assert transform_dict['theta'] != 0 419 | assert transform_dict['theta'] != transform_dict2['theta'] 420 | assert transform_dict['tx'] != 0 421 | assert transform_dict['tx'] != transform_dict2['tx'] 422 | assert transform_dict['ty'] != 0 423 | assert transform_dict['ty'] != transform_dict2['ty'] 424 | assert transform_dict['shear'] != 0 425 | assert transform_dict['shear'] != transform_dict2['shear'] 426 | assert transform_dict['zx'] != 0 427 | assert transform_dict['zx'] != transform_dict2['zx'] 428 | assert transform_dict['zy'] != 0 429 | assert transform_dict['zy'] != transform_dict2['zy'] 430 | assert transform_dict['channel_shift_intensity'] != 0 431 | assert (transform_dict['channel_shift_intensity'] != 432 | transform_dict2['channel_shift_intensity']) 433 | assert transform_dict['brightness'] != 0 434 | assert transform_dict['brightness'] != transform_dict2['brightness'] 435 | 436 | # Test get_random_transform without any randomness 437 | generator = image_data_generator.ImageDataGenerator() 438 | transform_dict = generator.get_random_transform(x.shape, seed) 439 | assert transform_dict['theta'] == 0 440 | assert transform_dict['tx'] == 0 441 | assert transform_dict['ty'] == 0 442 | assert transform_dict['shear'] == 0 443 | assert transform_dict['zx'] == 1 444 | assert transform_dict['zy'] == 1 445 | assert transform_dict['channel_shift_intensity'] is None 446 | assert transform_dict['brightness'] is None 447 | 448 | 449 | def test_fit_rescale(all_test_images): 450 | rescale = 1. 
/ 255 451 | 452 | for test_images in all_test_images: 453 | img_list = [] 454 | for im in test_images: 455 | img_list.append(utils.img_to_array(im)[None, ...]) 456 | images = np.vstack(img_list) 457 | 458 | # featurewise_center test 459 | generator = image_data_generator.ImageDataGenerator( 460 | rescale=rescale, 461 | featurewise_center=True, 462 | dtype='float64') 463 | generator.fit(images) 464 | batch = generator.flow(images, batch_size=8).next() 465 | assert abs(np.mean(batch)) < 1e-6 466 | 467 | # featurewise_std_normalization test 468 | generator = image_data_generator.ImageDataGenerator( 469 | rescale=rescale, 470 | featurewise_center=True, 471 | featurewise_std_normalization=True, 472 | dtype='float64') 473 | generator.fit(images) 474 | batch = generator.flow(images, batch_size=8).next() 475 | assert abs(np.mean(batch)) < 1e-6 476 | assert abs(1 - np.std(batch)) < 1e-5 477 | 478 | # zca_whitening test 479 | generator = image_data_generator.ImageDataGenerator( 480 | rescale=rescale, 481 | featurewise_center=True, 482 | zca_whitening=True, 483 | dtype='float64') 484 | generator.fit(images) 485 | batch = generator.flow(images, batch_size=8).next() 486 | batch = np.reshape(batch, 487 | (batch.shape[0], 488 | batch.shape[1] * batch.shape[2] * batch.shape[3])) 489 | # Y * Y_T = n * I, where Y = W * X 490 | identity = np.dot(batch, batch.T) / batch.shape[0] 491 | assert ((np.abs(identity) - np.identity(identity.shape[0])) 492 | < 1e-6).all() 493 | 494 | 495 | if __name__ == '__main__': 496 | pytest.main([__file__]) 497 | -------------------------------------------------------------------------------- /tests/image/iterator_test.py: -------------------------------------------------------------------------------- 1 | from keras_preprocessing.image import iterator 2 | 3 | 4 | def test_iterator_empty_directory(): 5 | # Testing with different batch sizes 6 | for batch_size in [0, 32]: 7 | data_iterator = iterator.Iterator(0, batch_size, False, 0) 8 | ret = next(data_iterator.index_generator) 9 | assert ret.size == 0 10 | -------------------------------------------------------------------------------- /tests/image/numpy_array_iterator_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from PIL import Image 4 | 5 | from keras_preprocessing.image import numpy_array_iterator, utils 6 | from keras_preprocessing.image.image_data_generator import ImageDataGenerator 7 | 8 | 9 | @pytest.fixture(scope='module') 10 | def all_test_images(): 11 | img_w = img_h = 20 12 | rgb_images = [] 13 | rgba_images = [] 14 | gray_images = [] 15 | for n in range(8): 16 | bias = np.random.rand(img_w, img_h, 1) * 64 17 | variance = np.random.rand(img_w, img_h, 1) * (255 - 64) 18 | imarray = np.random.rand(img_w, img_h, 3) * variance + bias 19 | im = Image.fromarray(imarray.astype('uint8')).convert('RGB') 20 | rgb_images.append(im) 21 | 22 | imarray = np.random.rand(img_w, img_h, 4) * variance + bias 23 | im = Image.fromarray(imarray.astype('uint8')).convert('RGBA') 24 | rgba_images.append(im) 25 | 26 | imarray = np.random.rand(img_w, img_h, 1) * variance + bias 27 | im = Image.fromarray( 28 | imarray.astype('uint8').squeeze()).convert('L') 29 | gray_images.append(im) 30 | 31 | return [rgb_images, rgba_images, gray_images] 32 | 33 | 34 | @pytest.fixture(scope='module') 35 | def image_data_generator(): 36 | return ImageDataGenerator( 37 | featurewise_center=True, 38 | samplewise_center=True, 39 | featurewise_std_normalization=True, 40 | 
samplewise_std_normalization=True, 41 | zca_whitening=True, 42 | rotation_range=90., 43 | width_shift_range=0.1, 44 | height_shift_range=0.1, 45 | shear_range=0.5, 46 | zoom_range=0.2, 47 | channel_shift_range=0., 48 | brightness_range=(1, 5), 49 | fill_mode='nearest', 50 | cval=0.5, 51 | horizontal_flip=True, 52 | vertical_flip=True, 53 | interpolation_order=1 54 | ) 55 | 56 | 57 | def test_numpy_array_iterator(image_data_generator, all_test_images, tmpdir): 58 | for test_images in all_test_images: 59 | img_list = [] 60 | for im in test_images: 61 | img_list.append(utils.img_to_array(im)[None, ...]) 62 | images = np.vstack(img_list) 63 | dsize = images.shape[0] 64 | 65 | iterator = numpy_array_iterator.NumpyArrayIterator( 66 | images, 67 | np.arange(images.shape[0]), 68 | image_data_generator, 69 | shuffle=False, 70 | save_to_dir=str(tmpdir), 71 | batch_size=3 72 | ) 73 | x, y = next(iterator) 74 | assert x.shape == images[:3].shape 75 | assert list(y) == [0, 1, 2] 76 | 77 | # Test with sample weights 78 | iterator = numpy_array_iterator.NumpyArrayIterator( 79 | images, 80 | np.arange(images.shape[0]), 81 | image_data_generator, 82 | shuffle=False, 83 | sample_weight=np.arange(images.shape[0]) + 1, 84 | save_to_dir=str(tmpdir), 85 | batch_size=3 86 | ) 87 | x, y, w = iterator.next() 88 | assert x.shape == images[:3].shape 89 | assert list(y) == [0, 1, 2] 90 | assert list(w) == [1, 2, 3] 91 | 92 | # Test with `shuffle=True` 93 | iterator = numpy_array_iterator.NumpyArrayIterator( 94 | images, 95 | np.arange(images.shape[0]), 96 | image_data_generator, 97 | shuffle=True, 98 | save_to_dir=str(tmpdir), 99 | batch_size=3, 100 | seed=42 101 | ) 102 | x, y = iterator.next() 103 | assert x.shape == images[:3].shape 104 | # Check that the sequence is shuffled. 
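# (The fixed seed=42 makes the shuffled order deterministic here. Without
# a seed, the check below could in principle draw the identity order
# [0, 1, 2] (probability 1/336 for these 8 samples) and fail spuriously.)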
105 | assert list(y) != [0, 1, 2] 106 | 107 | # Test without y 108 | iterator = numpy_array_iterator.NumpyArrayIterator( 109 | images, 110 | None, 111 | image_data_generator, 112 | shuffle=True, 113 | save_to_dir=str(tmpdir), 114 | batch_size=3 115 | ) 116 | x = iterator.next() 117 | assert type(x) is np.ndarray 118 | assert x.shape == images[:3].shape 119 | 120 | # Test with a single miscellaneous input data array 121 | x_misc1 = np.random.random(dsize) 122 | iterator = numpy_array_iterator.NumpyArrayIterator( 123 | (images, x_misc1), 124 | np.arange(dsize), 125 | image_data_generator, 126 | shuffle=False, 127 | batch_size=2 128 | ) 129 | for i, (x, y) in enumerate(iterator): 130 | assert x[0].shape == images[:2].shape 131 | assert (x[1] == x_misc1[(i * 2):((i + 1) * 2)]).all() 132 | if i == 2: 133 | break 134 | 135 | # Test with two miscellaneous inputs 136 | x_misc2 = np.random.random((dsize, 3, 3)) 137 | iterator = numpy_array_iterator.NumpyArrayIterator( 138 | (images, [x_misc1, x_misc2]), 139 | np.arange(dsize), 140 | image_data_generator, 141 | shuffle=False, 142 | batch_size=2 143 | ) 144 | for i, (x, y) in enumerate(iterator): 145 | assert x[0].shape == images[:2].shape 146 | assert (x[1] == x_misc1[(i * 2):((i + 1) * 2)]).all() 147 | assert (x[2] == x_misc2[(i * 2):((i + 1) * 2)]).all() 148 | if i == 2: 149 | break 150 | 151 | # Test cases with `y = None` 152 | iterator = numpy_array_iterator.NumpyArrayIterator( 153 | images, 154 | None, 155 | image_data_generator, 156 | batch_size=3 157 | ) 158 | x = iterator.next() 159 | assert type(x) is np.ndarray 160 | assert x.shape == images[:3].shape 161 | 162 | iterator = numpy_array_iterator.NumpyArrayIterator( 163 | (images, x_misc1), 164 | None, 165 | image_data_generator, 166 | batch_size=3, 167 | shuffle=False 168 | ) 169 | x = iterator.next() 170 | assert type(x) is list 171 | assert x[0].shape == images[:3].shape 172 | assert (x[1] == x_misc1[:3]).all() 173 | 174 | iterator = numpy_array_iterator.NumpyArrayIterator( 175 | (images, [x_misc1, x_misc2]), 176 | None, 177 | image_data_generator, 178 | batch_size=3, 179 | shuffle=False 180 | ) 181 | x = iterator.next() 182 | assert type(x) is list 183 | assert x[0].shape == images[:3].shape 184 | assert (x[1] == x_misc1[:3]).all() 185 | assert (x[2] == x_misc2[:3]).all() 186 | 187 | # Test with validation split 188 | generator = ImageDataGenerator(validation_split=0.2) 189 | iterator = numpy_array_iterator.NumpyArrayIterator( 190 | images, 191 | None, 192 | generator, 193 | batch_size=3 194 | ) 195 | x = iterator.next() 196 | assert isinstance(x, np.ndarray) 197 | assert x.shape == images[:3].shape 198 | 199 | # Test some failure cases: 200 | x_misc_err = np.random.random((dsize + 1, 3, 3)) 201 | 202 | with pytest.raises(ValueError) as e_info: 203 | numpy_array_iterator.NumpyArrayIterator( 204 | (images, x_misc_err), 205 | np.arange(dsize), 206 | generator, 207 | batch_size=3 208 | ) 209 | assert str(e_info.value).find('All of the arrays in') != -1 210 | 211 | with pytest.raises(ValueError) as e_info: 212 | numpy_array_iterator.NumpyArrayIterator( 213 | (images, x_misc1), 214 | np.arange(dsize + 1), 215 | generator, 216 | batch_size=3 217 | ) 218 | assert str(e_info.value).find('`x` (images tensor) and `y` (labels) ') != -1 219 | 220 | # Test `flow` behavior as Sequence 221 | seq = numpy_array_iterator.NumpyArrayIterator( 222 | images, 223 | np.arange(images.shape[0]), 224 | generator, 225 | shuffle=False, save_to_dir=str(tmpdir), 226 | batch_size=3 227 | ) 228 | assert len(seq) == 
images.shape[0] // 3 + 1 229 | x, y = seq[0] 230 | assert x.shape == images[:3].shape 231 | assert list(y) == [0, 1, 2] 232 | 233 | # Test with `shuffle=True` 234 | seq = numpy_array_iterator.NumpyArrayIterator( 235 | images, 236 | np.arange(images.shape[0]), 237 | generator, 238 | shuffle=True, 239 | save_to_dir=str(tmpdir), 240 | batch_size=3, 241 | seed=123 242 | ) 243 | x, y = seq[0] 244 | # Check that the sequence is shuffled. 245 | assert list(y) != [0, 1, 2] 246 | # `on_epoch_end` should reshuffle the sequence. 247 | seq.on_epoch_end() 248 | x2, y2 = seq[0] 249 | assert list(y) != list(y2) 250 | 251 | # test order_interpolation 252 | labels = np.array([[2, 2, 0, 2, 2], 253 | [1, 3, 2, 3, 1], 254 | [2, 1, 0, 1, 2], 255 | [3, 1, 0, 2, 0], 256 | [3, 1, 3, 2, 1]]) 257 | label_generator = ImageDataGenerator( 258 | rotation_range=90., 259 | interpolation_order=0 260 | ) 261 | labels_gen = numpy_array_iterator.NumpyArrayIterator( 262 | labels[np.newaxis, ..., np.newaxis], 263 | None, 264 | label_generator, 265 | seed=123 266 | ) 267 | assert (np.unique(labels) == np.unique(next(labels_gen))).all() 268 | -------------------------------------------------------------------------------- /tests/image/test_image_api.py: -------------------------------------------------------------------------------- 1 | from keras_preprocessing import image 2 | 3 | 4 | def test_api_classes(): 5 | expected_exposed_classes = [ 6 | 'DataFrameIterator', 7 | 'DirectoryIterator', 8 | 'ImageDataGenerator', 9 | 'Iterator', 10 | 'NumpyArrayIterator', 11 | ] 12 | for _class in expected_exposed_classes: 13 | assert hasattr(image, _class) 14 | 15 | 16 | def test_api_functions(): 17 | expected_exposed_functions = [ 18 | 'flip_axis', 19 | 'random_rotation', 20 | 'random_shift', 21 | 'random_shear', 22 | 'random_zoom', 23 | 'apply_channel_shift', 24 | 'random_channel_shift', 25 | 'apply_brightness_shift', 26 | 'random_brightness', 27 | 'transform_matrix_offset_center', 28 | 'apply_affine_transform', 29 | 'validate_filename', 30 | 'save_img', 31 | 'load_img', 32 | 'list_pictures', 33 | 'array_to_img', 34 | 'img_to_array' 35 | ] 36 | for function in expected_exposed_functions: 37 | assert hasattr(image, function) 38 | -------------------------------------------------------------------------------- /tests/image/utils_test.py: -------------------------------------------------------------------------------- 1 | import io 2 | import resource 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import PIL 7 | import pytest 8 | 9 | from keras_preprocessing.image import utils 10 | 11 | 12 | def test_validate_filename(tmpdir): 13 | valid_extensions = ('png', 'jpg') 14 | filename = tmpdir.ensure('test.png') 15 | assert utils.validate_filename(str(filename), valid_extensions) 16 | 17 | filename = tmpdir.ensure('test.PnG') 18 | assert utils.validate_filename(str(filename), valid_extensions) 19 | 20 | filename = tmpdir.ensure('test.some_extension') 21 | assert not utils.validate_filename(str(filename), valid_extensions) 22 | assert not utils.validate_filename('some_test_file.png', valid_extensions) 23 | 24 | 25 | def test_load_img(tmpdir): 26 | filename_rgb = str(tmpdir / 'rgb_utils.png') 27 | filename_rgba = str(tmpdir / 'rgba_utils.png') 28 | filename_grayscale_8bit = str(tmpdir / 'grayscale_8bit_utils.png') 29 | filename_grayscale_16bit = str(tmpdir / 'grayscale_16bit_utils.tiff') 30 | filename_grayscale_32bit = str(tmpdir / 'grayscale_32bit_utils.tiff') 31 | 32 | original_rgb_array = np.array(255 * np.random.rand(100, 100, 3), 33 
| dtype=np.uint8) 34 | original_rgb = utils.array_to_img(original_rgb_array, scale=False) 35 | original_rgb.save(filename_rgb) 36 | 37 | original_rgba_array = np.array(255 * np.random.rand(100, 100, 4), 38 | dtype=np.uint8) 39 | original_rgba = utils.array_to_img(original_rgba_array, scale=False) 40 | original_rgba.save(filename_rgba) 41 | 42 | original_grayscale_8bit_array = np.array(255 * np.random.rand(100, 100, 1), 43 | dtype=np.uint8) 44 | original_grayscale_8bit = utils.array_to_img(original_grayscale_8bit_array, 45 | scale=False) 46 | original_grayscale_8bit.save(filename_grayscale_8bit) 47 | 48 | original_grayscale_16bit_array = np.array( 49 | np.random.randint(-2147483648, 2147483647, (100, 100, 1)), dtype=np.int16 50 | ) 51 | original_grayscale_16bit = utils.array_to_img(original_grayscale_16bit_array, 52 | scale=False, dtype='int16') 53 | original_grayscale_16bit.save(filename_grayscale_16bit) 54 | 55 | original_grayscale_32bit_array = np.array( 56 | np.random.randint(-2147483648, 2147483647, (100, 100, 1)), dtype=np.int32 57 | ) 58 | original_grayscale_32bit = utils.array_to_img(original_grayscale_32bit_array, 59 | scale=False, dtype='int32') 60 | original_grayscale_32bit.save(filename_grayscale_32bit) 61 | 62 | # Test that loaded image is exactly equal to original. 63 | 64 | loaded_im = utils.load_img(filename_rgb) 65 | loaded_im_array = utils.img_to_array(loaded_im) 66 | assert loaded_im_array.shape == original_rgb_array.shape 67 | assert np.all(loaded_im_array == original_rgb_array) 68 | 69 | loaded_im = utils.load_img(filename_rgba, color_mode='rgba') 70 | loaded_im_array = utils.img_to_array(loaded_im) 71 | assert loaded_im_array.shape == original_rgba_array.shape 72 | assert np.all(loaded_im_array == original_rgba_array) 73 | 74 | loaded_im = utils.load_img(filename_rgb, color_mode='grayscale') 75 | loaded_im_array = utils.img_to_array(loaded_im) 76 | assert loaded_im_array.shape == (original_rgb_array.shape[0], 77 | original_rgb_array.shape[1], 1) 78 | 79 | loaded_im = utils.load_img(filename_grayscale_8bit, color_mode='grayscale') 80 | loaded_im_array = utils.img_to_array(loaded_im) 81 | assert loaded_im_array.shape == original_grayscale_8bit_array.shape 82 | assert np.all(loaded_im_array == original_grayscale_8bit_array) 83 | 84 | loaded_im = utils.load_img(filename_grayscale_16bit, color_mode='grayscale') 85 | loaded_im_array = utils.img_to_array(loaded_im, dtype='int16') 86 | assert loaded_im_array.shape == original_grayscale_16bit_array.shape 87 | assert np.all(loaded_im_array == original_grayscale_16bit_array) 88 | # test casting int16 image to float32 89 | loaded_im_array = utils.img_to_array(loaded_im) 90 | assert np.allclose(loaded_im_array, original_grayscale_16bit_array) 91 | 92 | loaded_im = utils.load_img(filename_grayscale_32bit, color_mode='grayscale') 93 | loaded_im_array = utils.img_to_array(loaded_im, dtype='int32') 94 | assert loaded_im_array.shape == original_grayscale_32bit_array.shape 95 | assert np.all(loaded_im_array == original_grayscale_32bit_array) 96 | # test casting int32 image to float32 97 | loaded_im_array = utils.img_to_array(loaded_im) 98 | assert np.allclose(loaded_im_array, original_grayscale_32bit_array) 99 | 100 | # Test that nothing is changed when target size is equal to original. 
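# (`load_img` only resamples when the requested target size differs from
# the image's actual size, so a (100, 100) target on these 100x100 images
# must return pixel-identical data.)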
101 | 102 | loaded_im = utils.load_img(filename_rgb, target_size=(100, 100)) 103 | loaded_im_array = utils.img_to_array(loaded_im) 104 | assert loaded_im_array.shape == original_rgb_array.shape 105 | assert np.all(loaded_im_array == original_rgb_array) 106 | 107 | loaded_im = utils.load_img(filename_rgba, color_mode='rgba', 108 | target_size=(100, 100)) 109 | loaded_im_array = utils.img_to_array(loaded_im) 110 | assert loaded_im_array.shape == original_rgba_array.shape 111 | assert np.all(loaded_im_array == original_rgba_array) 112 | 113 | loaded_im = utils.load_img(filename_rgb, color_mode='grayscale', 114 | target_size=(100, 100)) 115 | loaded_im_array = utils.img_to_array(loaded_im) 116 | assert loaded_im_array.shape == (original_rgba_array.shape[0], 117 | original_rgba_array.shape[1], 1) 118 | 119 | loaded_im = utils.load_img(filename_grayscale_8bit, color_mode='grayscale', 120 | target_size=(100, 100)) 121 | loaded_im_array = utils.img_to_array(loaded_im) 122 | assert loaded_im_array.shape == original_grayscale_8bit_array.shape 123 | assert np.all(loaded_im_array == original_grayscale_8bit_array) 124 | 125 | loaded_im = utils.load_img(filename_grayscale_16bit, color_mode='grayscale', 126 | target_size=(100, 100)) 127 | loaded_im_array = utils.img_to_array(loaded_im, dtype='int16') 128 | assert loaded_im_array.shape == original_grayscale_16bit_array.shape 129 | assert np.all(loaded_im_array == original_grayscale_16bit_array) 130 | 131 | loaded_im = utils.load_img(filename_grayscale_32bit, color_mode='grayscale', 132 | target_size=(100, 100)) 133 | loaded_im_array = utils.img_to_array(loaded_im, dtype='int32') 134 | assert loaded_im_array.shape == original_grayscale_32bit_array.shape 135 | assert np.all(loaded_im_array == original_grayscale_32bit_array) 136 | 137 | # Test down-sampling with bilinear interpolation. 138 | 139 | loaded_im = utils.load_img(filename_rgb, target_size=(25, 25)) 140 | loaded_im_array = utils.img_to_array(loaded_im) 141 | assert loaded_im_array.shape == (25, 25, 3) 142 | 143 | loaded_im = utils.load_img(filename_rgba, color_mode='rgba', 144 | target_size=(25, 25)) 145 | loaded_im_array = utils.img_to_array(loaded_im) 146 | assert loaded_im_array.shape == (25, 25, 4) 147 | 148 | loaded_im = utils.load_img(filename_rgb, color_mode='grayscale', 149 | target_size=(25, 25)) 150 | loaded_im_array = utils.img_to_array(loaded_im) 151 | assert loaded_im_array.shape == (25, 25, 1) 152 | 153 | loaded_im = utils.load_img(filename_grayscale_8bit, color_mode='grayscale', 154 | target_size=(25, 25)) 155 | loaded_im_array = utils.img_to_array(loaded_im) 156 | assert loaded_im_array.shape == (25, 25, 1) 157 | 158 | loaded_im = utils.load_img(filename_grayscale_16bit, color_mode='grayscale', 159 | target_size=(25, 25)) 160 | loaded_im_array = utils.img_to_array(loaded_im, dtype='int16') 161 | assert loaded_im_array.shape == (25, 25, 1) 162 | 163 | loaded_im = utils.load_img(filename_grayscale_32bit, color_mode='grayscale', 164 | target_size=(25, 25)) 165 | loaded_im_array = utils.img_to_array(loaded_im, dtype='int32') 166 | assert loaded_im_array.shape == (25, 25, 1) 167 | 168 | # Test down-sampling with nearest neighbor interpolation. 
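# (Nearest-neighbour resampling copies individual source pixels instead of
# averaging neighbourhoods, so on random image data it is all but certain
# to differ from the bilinear result somewhere; the `np.any(... != ...)`
# assertions below rely on exactly that.)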
169 | 170 | loaded_im_nearest = utils.load_img(filename_rgb, target_size=(25, 25), 171 | interpolation="nearest") 172 | loaded_im_array_nearest = utils.img_to_array(loaded_im_nearest) 173 | assert loaded_im_array_nearest.shape == (25, 25, 3) 174 | assert np.any(loaded_im_array_nearest != loaded_im_array) 175 | 176 | loaded_im_nearest = utils.load_img(filename_rgba, color_mode='rgba', 177 | target_size=(25, 25), 178 | interpolation="nearest") 179 | loaded_im_array_nearest = utils.img_to_array(loaded_im_nearest) 180 | assert loaded_im_array_nearest.shape == (25, 25, 4) 181 | assert np.any(loaded_im_array_nearest != loaded_im_array) 182 | 183 | loaded_im = utils.load_img(filename_grayscale_8bit, color_mode='grayscale', 184 | target_size=(25, 25), interpolation="nearest") 185 | loaded_im_array = utils.img_to_array(loaded_im) 186 | assert loaded_im_array.shape == (25, 25, 1) 187 | 188 | loaded_im = utils.load_img(filename_grayscale_16bit, color_mode='grayscale', 189 | target_size=(25, 25), interpolation="nearest") 190 | loaded_im_array = utils.img_to_array(loaded_im, dtype='int16') 191 | assert loaded_im_array.shape == (25, 25, 1) 192 | 193 | loaded_im = utils.load_img(filename_grayscale_32bit, color_mode='grayscale', 194 | target_size=(25, 25), interpolation="nearest") 195 | loaded_im_array = utils.img_to_array(loaded_im, dtype='int32') 196 | assert loaded_im_array.shape == (25, 25, 1) 197 | 198 | # Test different path type 199 | with open(filename_grayscale_32bit, 'rb') as f: 200 | _path = io.BytesIO(f.read()) # io.Bytesio 201 | loaded_im = utils.load_img(_path, color_mode='grayscale') 202 | loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32) 203 | assert np.all(loaded_im_array == original_grayscale_32bit_array) 204 | 205 | _path = filename_grayscale_32bit # str 206 | loaded_im = utils.load_img(_path, color_mode='grayscale') 207 | loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32) 208 | assert np.all(loaded_im_array == original_grayscale_32bit_array) 209 | 210 | _path = filename_grayscale_32bit.encode() # bytes 211 | loaded_im = utils.load_img(_path, color_mode='grayscale') 212 | loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32) 213 | assert np.all(loaded_im_array == original_grayscale_32bit_array) 214 | 215 | _path = Path(tmpdir / 'grayscale_32bit_utils.tiff') # Path 216 | loaded_im = utils.load_img(_path, color_mode='grayscale') 217 | loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32) 218 | assert np.all(loaded_im_array == original_grayscale_32bit_array) 219 | 220 | # Check that exception is raised if interpolation not supported. 
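# (Note that the first call below, without a `target_size`, is expected
# *not* to raise: no resizing takes place, so the `interpolation` argument
# is never validated. Only the resizing call should raise a ValueError.)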
221 | 222 | loaded_im = utils.load_img(filename_rgb, interpolation="unsupported") 223 | with pytest.raises(ValueError): 224 | loaded_im = utils.load_img(filename_rgb, target_size=(25, 25), 225 | interpolation="unsupported") 226 | 227 | # Check that the aspect ratio of a square is the same 228 | 229 | filename_red_square = str(tmpdir / 'red_square_utils.png') 230 | A = np.zeros((50, 100, 3), dtype=np.uint8) # rectangle image 100x50 231 | A[20:30, 45:55, 0] = 255 # red square 10x10 232 | red_square_array = np.array(A) 233 | red_square = utils.array_to_img(red_square_array, scale=False) 234 | red_square.save(filename_red_square) 235 | 236 | loaded_im = utils.load_img(filename_red_square, target_size=(25, 25), 237 | keep_aspect_ratio=True) 238 | loaded_im_array = utils.img_to_array(loaded_im) 239 | assert loaded_im_array.shape == (25, 25, 3) 240 | 241 | red_channel_arr = loaded_im_array[:, :, 0].astype(np.bool) 242 | square_width = np.sum(np.sum(red_channel_arr, axis=0)) 243 | square_height = np.sum(np.sum(red_channel_arr, axis=1)) 244 | aspect_ratio_result = square_width / square_height 245 | 246 | # original square had 1:1 ratio 247 | assert aspect_ratio_result == pytest.approx(1.0) 248 | 249 | 250 | def test_list_pictures(tmpdir): 251 | filenames = ['test.png', 'test0.jpg', 'test-1.jpeg', '2test.bmp', 252 | '2-test.ppm', '3.png', '1.jpeg', 'test.bmp', 'test0.ppm', 253 | 'test4.tiff', '5-test.tif', 'test.txt', 'foo.csv', 254 | 'face.gif', 'bar.txt'] 255 | subdirs = ['', 'subdir1', 'subdir2'] 256 | filenames = [tmpdir.ensure(subdir, f) for subdir in subdirs 257 | for f in filenames] 258 | 259 | found_images = utils.list_pictures(str(tmpdir)) 260 | assert len(found_images) == 33 261 | 262 | found_images = utils.list_pictures(str(tmpdir), ext='png') 263 | assert len(found_images) == 6 264 | 265 | 266 | def test_array_to_img_and_img_to_array(): 267 | height, width = 10, 8 268 | 269 | # Test the data format 270 | # Test RGB 3D 271 | x = np.random.random((3, height, width)) 272 | img = utils.array_to_img(x, data_format='channels_first') 273 | assert img.size == (width, height) 274 | 275 | x = utils.img_to_array(img, data_format='channels_first') 276 | assert x.shape == (3, height, width) 277 | 278 | # Test RGBA 3D 279 | x = np.random.random((4, height, width)) 280 | img = utils.array_to_img(x, data_format='channels_first') 281 | assert img.size == (width, height) 282 | 283 | x = utils.img_to_array(img, data_format='channels_first') 284 | assert x.shape == (4, height, width) 285 | 286 | # Test 2D 287 | x = np.random.random((1, height, width)) 288 | img = utils.array_to_img(x, data_format='channels_first') 289 | assert img.size == (width, height) 290 | 291 | x = utils.img_to_array(img, data_format='channels_first') 292 | assert x.shape == (1, height, width) 293 | 294 | # grayscale 32-bit signed integer 295 | x = np.array( 296 | np.random.randint(-2147483648, 2147483647, (1, height, width)), 297 | dtype=np.int32 298 | ) 299 | img = utils.array_to_img(x, data_format='channels_first') 300 | assert img.size == (width, height) 301 | 302 | x = utils.img_to_array(img, data_format='channels_first') 303 | assert x.shape == (1, height, width) 304 | 305 | # Test tf data format 306 | # Test RGB 3D 307 | x = np.random.random((height, width, 3)) 308 | img = utils.array_to_img(x, data_format='channels_last') 309 | assert img.size == (width, height) 310 | 311 | x = utils.img_to_array(img, data_format='channels_last') 312 | assert x.shape == (height, width, 3) 313 | 314 | # Test RGBA 3D 315 | x = 
np.random.random((height, width, 4)) 316 | img = utils.array_to_img(x, data_format='channels_last') 317 | assert img.size == (width, height) 318 | 319 | x = utils.img_to_array(img, data_format='channels_last') 320 | assert x.shape == (height, width, 4) 321 | 322 | # Test 2D 323 | x = np.random.random((height, width, 1)) 324 | img = utils.array_to_img(x, data_format='channels_last') 325 | assert img.size == (width, height) 326 | 327 | x = utils.img_to_array(img, data_format='channels_last') 328 | assert x.shape == (height, width, 1) 329 | 330 | # grayscale 16-bit signed integer 331 | x = np.array( 332 | np.random.randint(-2147483648, 2147483647, (height, width, 1)), 333 | dtype=np.int16 334 | ) 335 | img = utils.array_to_img(x, data_format='channels_last') 336 | assert img.size == (width, height) 337 | 338 | x = utils.img_to_array(img, data_format='channels_last') 339 | assert x.shape == (height, width, 1) 340 | 341 | # grayscale 32-bit signed integer 342 | x = np.array( 343 | np.random.randint(-2147483648, 2147483647, (height, width, 1)), 344 | dtype=np.int32 345 | ) 346 | img = utils.array_to_img(x, data_format='channels_last') 347 | assert img.size == (width, height) 348 | 349 | x = utils.img_to_array(img, data_format='channels_last') 350 | assert x.shape == (height, width, 1) 351 | 352 | # Test invalid use case 353 | with pytest.raises(ValueError): 354 | x = np.random.random((height, width)) # not 3D 355 | img = utils.array_to_img(x, data_format='channels_first') 356 | 357 | with pytest.raises(ValueError): 358 | x = np.random.random((height, width, 3)) 359 | # unknown data_format 360 | img = utils.array_to_img(x, data_format='channels') 361 | 362 | with pytest.raises(ValueError): 363 | # neither RGB, RGBA, or gray-scale 364 | x = np.random.random((height, width, 5)) 365 | img = utils.array_to_img(x, data_format='channels_last') 366 | 367 | with pytest.raises(ValueError): 368 | x = np.random.random((height, width, 3)) 369 | # unknown data_format 370 | img = utils.img_to_array(x, data_format='channels') 371 | 372 | with pytest.raises(ValueError): 373 | # neither RGB, RGBA, or gray-scale 374 | x = np.random.random((height, width, 5, 3)) 375 | img = utils.img_to_array(x, data_format='channels_last') 376 | 377 | 378 | def write_sample_image(tmpdir): 379 | im = utils.array_to_img(np.random.rand(1, 1, 3)) 380 | path = str(tmpdir / 'sample_image.png') 381 | utils.save_img(path, im) 382 | return path 383 | 384 | 385 | def test_image_file_handlers_close(tmpdir): 386 | path = write_sample_image(tmpdir) 387 | max_open_files, _ = resource.getrlimit(resource.RLIMIT_NOFILE) 388 | for i in range(max_open_files+1): 389 | utils.load_img(path) 390 | 391 | 392 | def test_load_img_returns_image(tmpdir): 393 | path = write_sample_image(tmpdir) 394 | im = utils.load_img(path) 395 | assert isinstance(im, PIL.Image.Image) 396 | 397 | 398 | if __name__ == '__main__': 399 | pytest.main([__file__]) 400 | -------------------------------------------------------------------------------- /tests/sequence_test.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import numpy as np 4 | import pytest 5 | from numpy.testing import assert_allclose, assert_equal, assert_raises 6 | 7 | from keras_preprocessing import sequence 8 | 9 | 10 | def test_pad_sequences(): 11 | a = [[1], [1, 2], [1, 2, 3]] 12 | 13 | # test padding 14 | b = sequence.pad_sequences(a, maxlen=3, padding='pre') 15 | assert_allclose(b, [[0, 0, 1], [0, 1, 2], [1, 2, 3]]) 16 | b = 
17 |     assert_allclose(b, [[1, 0, 0], [1, 2, 0], [1, 2, 3]])
18 | 
19 |     # test truncating
20 |     b = sequence.pad_sequences(a, maxlen=2, truncating='pre')
21 |     assert_allclose(b, [[0, 1], [1, 2], [2, 3]])
22 |     b = sequence.pad_sequences(a, maxlen=2, truncating='post')
23 |     assert_allclose(b, [[0, 1], [1, 2], [1, 2]])
24 | 
25 |     # test value
26 |     b = sequence.pad_sequences(a, maxlen=3, value=1)
27 |     assert_allclose(b, [[1, 1, 1], [1, 1, 2], [1, 2, 3]])
28 | 
29 | 
30 | def test_pad_sequences_str():
31 |     a = [['1'], ['1', '2'], ['1', '2', '3']]
32 | 
33 |     # test padding
34 |     b = sequence.pad_sequences(a, maxlen=3, padding='pre', value='pad', dtype=object)
35 |     assert_equal(b, [['pad', 'pad', '1'], ['pad', '1', '2'], ['1', '2', '3']])
36 |     b = sequence.pad_sequences(a, maxlen=3, padding='post', value='pad', dtype='<U3')
[... lines 37-231 of sequence_test.py are missing from this extract ...]
232 |     assert '`start_index+length=50 > end_index=49` is disallowed' in error
233 | 
234 | 
235 | def test_TimeSeriesGenerator_doesnt_miss_any_sample():
236 |     x = np.array([[i] for i in range(10)])
237 | 
238 |     for length in range(3, 10):
239 |         g = sequence.TimeseriesGenerator(x, x,
240 |                                          length=length,
241 |                                          batch_size=1)
242 |         expected = max(0, len(x) - length)
243 |         actual = len(g)
244 | 
245 |         assert expected == actual
246 | 
247 |         if len(g) > 0:
248 |             # All elements in range(length, 10) should be used as current step
249 |             expected = np.arange(length, 10).reshape(-1, 1)
250 | 
251 |             y = np.concatenate([g[ix][1] for ix in range(len(g))], axis=0)
252 |             assert_allclose(y, expected)
253 | 
254 |     x = np.array([[i] for i in range(23)])
255 | 
256 |     strides = (1, 1, 5, 7, 3, 5, 3)
257 |     lengths = (3, 3, 4, 3, 1, 3, 7)
258 |     batch_sizes = (6, 6, 6, 5, 6, 6, 6)
259 |     shuffles = (False, True, True, False, False, False, False)
260 | 
261 |     for stride, length, batch_size, shuffle in zip(strides,
262 |                                                    lengths,
263 |                                                    batch_sizes,
264 |                                                    shuffles):
265 |         g = sequence.TimeseriesGenerator(x, x,
266 |                                          length=length,
267 |                                          sampling_rate=1,
268 |                                          stride=stride,
269 |                                          start_index=0,
270 |                                          end_index=None,
271 |                                          shuffle=shuffle,
272 |                                          reverse=False,
273 |                                          batch_size=batch_size)
274 |         if shuffle:
275 |             # all batches have the same size when shuffle is True.
276 |             expected_sequences = ceil(
277 |                 (23 - length) / float(batch_size * stride)) * batch_size
278 |         else:
279 |             # last batch will be different if `(samples - length) / stride`
280 |             # is not a multiple of `batch_size`, e.g. stride=7, length=3, batch_size=5 gives ceil(20/7) = 3 sequences in a single short batch.
281 |                 expected_sequences = ceil((23 - length) / float(stride))
282 | 
283 |         expected_batches = ceil(expected_sequences / float(batch_size))
284 | 
285 |         y = [g[ix][1] for ix in range(len(g))]
286 | 
287 |         actual_sequences = sum(len(_y) for _y in y)
288 |         actual_batches = len(y)
289 | 
290 |         assert expected_sequences == actual_sequences
291 |         assert expected_batches == actual_batches
292 | 
293 | 
294 | if __name__ == '__main__':
295 |     pytest.main([__file__])
296 | 
--------------------------------------------------------------------------------
/tests/test_api.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | import keras_preprocessing
4 | 
5 | 
6 | def test_api_modules():
7 |     expected_exposed_modules = [
8 |         'image',
9 |         'sequence',
10 |         'text'
11 |     ]
12 |     for _module in expected_exposed_modules:
13 |         assert hasattr(keras_preprocessing, _module)
14 | 
15 | 
16 | def test_get_keras_submodule(monkeypatch):
17 |     monkeypatch.setattr(keras_preprocessing, '_KERAS_BACKEND', 'backend')
18 |     assert 'backend' == keras_preprocessing.get_keras_submodule('backend')
19 |     monkeypatch.setattr(keras_preprocessing, '_KERAS_UTILS', 'utils')
20 |     assert 'utils' == keras_preprocessing.get_keras_submodule('utils')
21 | 
22 | 
23 | def test_get_keras_submodule_errors(monkeypatch):
24 |     with pytest.raises(ImportError):
25 |         keras_preprocessing.get_keras_submodule('something')
26 | 
27 |     monkeypatch.setattr(keras_preprocessing, '_KERAS_BACKEND', None)
28 |     with pytest.raises(ImportError):
29 |         keras_preprocessing.get_keras_submodule('backend')
30 | 
31 |     with pytest.raises(ImportError):
32 |         keras_preprocessing.get_keras_submodule('utils')
--------------------------------------------------------------------------------
/tests/test_documentation.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import inspect
3 | import re
4 | from itertools import compress
5 | 
6 | import pytest
7 | 
8 | modules = ['keras_preprocessing',
9 |            'keras_preprocessing.image',
10 |            'keras_preprocessing.sequence',
11 |            'keras_preprocessing.text']
12 | 
13 | # Tokenizer is being refactored PR #106
14 | accepted_name = ['set_keras_submodules', 'get_keras_submodule', 'Tokenizer']
15 | accepted_module = []
16 | 
17 | # Functions or classes with fewer than 'MIN_CODE_SIZE' lines can be ignored
18 | MIN_CODE_SIZE = 10
19 | 
20 | 
21 | def handle_class_init(name, member):
22 |     init_args = [
23 |         arg for arg in list(inspect.signature(member.__init__).parameters.keys())
24 |         if arg not in ['self', 'args', 'kwargs']
25 |     ]
26 |     assert_args_presence(init_args, member.__doc__, member, name)
27 | 
28 | 
29 | def handle_class(name, member):
30 |     if is_accepted(name, member):
31 |         return
32 | 
33 |     if member.__doc__ is None and not member_too_small(member):
34 |         raise ValueError("{} class doesn't have any documentation".format(name),
35 |                          member.__module__, inspect.getmodule(member).__file__)
36 | 
37 |     handle_class_init(name, member)
38 | 
39 |     for n, met in inspect.getmembers(member):
40 |         if inspect.ismethod(met):  # in Python 3 this only matches bound methods such as classmethods; plain methods on a class are functions
41 |             handle_method(n, met)
42 | 
43 | 
44 | def handle_function(name, member):
45 |     if is_accepted(name, member) or member_too_small(member):
46 |         # We don't need to check this one.
47 |         return
48 |     doc = member.__doc__
49 |     if doc is None:
50 |         raise ValueError("{} function doesn't have any documentation".format(name),
51 |                          member.__module__, inspect.getmodule(member).__file__)
52 | 
53 |     args = list(inspect.signature(member).parameters.keys())
54 |     assert_args_presence(args, doc, member, name)
55 |     assert_function_style(name, member, doc, args)
56 |     assert_doc_style(name, member, doc)
57 | 
58 | 
59 | def assert_doc_style(name, member, doc):
60 |     lines = doc.split("\n")
61 |     first_line = lines[0]
62 |     if len(first_line.strip()) == 0:
63 |         raise ValueError(
64 |             "{} the documentation should start on the first line.".format(name),
65 |             member.__module__)
66 |     first_blank = [i for i, line in enumerate(lines) if not line.strip()]
67 |     if len(first_blank) > 0:
68 |         if lines[first_blank[0] - 1].strip()[-1] != '.':
69 |             raise ValueError("{} first line should end with a '.'".format(name),
70 |                              member.__module__)
71 | 
72 | 
73 | def assert_function_style(name, member, doc, args):
74 |     code = inspect.getsource(member)
75 |     has_return = re.findall(r"\s*return \S+", code, re.MULTILINE)
76 |     if has_return and "# Returns" not in doc:
77 |         innerfunction = [inspect.getsource(x) for x in member.__code__.co_consts if
78 |                          inspect.iscode(x)]
79 |         return_in_sub = [ret for code_inner in innerfunction for ret in
80 |                          re.findall(r"\s*return \S+", code_inner, re.MULTILINE)]
81 |         if len(return_in_sub) < len(has_return):
82 |             raise ValueError("{} needs a '# Returns' section".format(name),
83 |                              member.__module__)
84 | 
85 |     has_raise = re.findall(r"^\s*raise \S+", code, re.MULTILINE)
86 |     if has_raise and "# Raises" not in doc:
87 |         innerfunction = [inspect.getsource(x) for x in member.__code__.co_consts if
88 |                          inspect.iscode(x)]
89 |         raise_in_sub = [ret for code_inner in innerfunction for ret in
90 |                         re.findall(r"\s*raise \S+", code_inner, re.MULTILINE)]
91 |         if len(raise_in_sub) < len(has_raise):
92 |             raise ValueError("{} needs a '# Raises' section".format(name),
93 |                              member.__module__)
94 | 
95 |     if len(args) > 0 and "# Arguments" not in doc:
96 |         raise ValueError("{} needs an '# Arguments' section".format(name),
97 |                          member.__module__)
98 | 
99 |     assert_blank_before(name, member, doc, ['# Arguments', '# Raises', '# Returns'])
100 | 
101 | 
102 | def assert_blank_before(name, member, doc, keywords):
103 |     doc_lines = [x.strip() for x in doc.split('\n')]
104 |     for keyword in keywords:
105 |         if keyword in doc_lines:
106 |             index = doc_lines.index(keyword)
107 |             if doc_lines[index - 1] != '':
108 |                 raise ValueError(
109 |                     "{} '{}' should have a blank line above.".format(name, keyword),
110 |                     member.__module__)
111 | 
112 | 
113 | def is_accepted(name, member):
114 |     if 'keras_preprocessing' not in str(member.__module__):
115 |         return True
116 |     return name in accepted_name or member.__module__ in accepted_module
117 | 
118 | 
119 | def member_too_small(member):
120 |     code = inspect.getsource(member).split('\n')
121 |     return len(code) < MIN_CODE_SIZE
122 | 
123 | 
124 | def assert_args_presence(args, doc, member, name):
125 |     args_not_in_doc = [arg not in doc for arg in args]
126 |     if any(args_not_in_doc):
127 |         raise ValueError(
128 |             "{} {} arguments are not present in documentation".format(name, list(
129 |                 compress(args, args_not_in_doc))), member.__module__, member)
130 |     words = doc.replace('*', '').split()
131 |     # Check arguments styling
132 |     styles = [arg + ":" not in words for arg in args]
133 |     if any(styles):
134 |         raise ValueError(
135 |             "{} {} are not styled properly 'argument': documentation".format(
documentation".format( 136 | name, 137 | list(compress(args, styles))), 138 | member.__module__) 139 | 140 | # Check arguments order 141 | indexes = [words.index(arg + ":") for arg in args] 142 | if indexes != sorted(indexes): 143 | raise ValueError( 144 | "{} arguments order is different from the documentation".format(name), 145 | member.__module__, indexes) 146 | 147 | 148 | def handle_method(name, member): 149 | if name in accepted_name or member.__module__ in accepted_module: 150 | return 151 | handle_function(name, member) 152 | 153 | 154 | def handle_module(mod): 155 | for name, mem in inspect.getmembers(mod): 156 | if inspect.isclass(mem): 157 | handle_class(name, mem) 158 | elif inspect.isfunction(mem): 159 | handle_function(name, mem) 160 | elif 'keras_preprocessing' in name and inspect.ismodule(mem): 161 | # Only test keras_preprocessing' modules 162 | handle_module(mem) 163 | 164 | 165 | def test_doc(): 166 | for module in modules: 167 | mod = importlib.import_module(module) 168 | handle_module(mod) 169 | 170 | 171 | if __name__ == '__main__': 172 | pytest.main([__file__]) 173 | -------------------------------------------------------------------------------- /tests/text_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import pytest 6 | from tensorflow import keras 7 | 8 | from keras_preprocessing import text 9 | 10 | 11 | def test_one_hot(): 12 | sample_text = 'The cat sat on the mat.' 13 | encoded = text.one_hot(sample_text, 5) 14 | assert len(encoded) == 6 15 | assert np.max(encoded) <= 4 16 | assert np.min(encoded) >= 0 17 | 18 | sample_text = 'The-cat-sat-on-the-mat' 19 | encoded2 = text.one_hot(sample_text, 5, analyzer=lambda t: t.lower().split('-')) 20 | assert encoded == encoded2 21 | assert len(encoded) == 6 22 | assert np.max(encoded) <= 4 23 | assert np.min(encoded) >= 0 24 | 25 | 26 | def test_hashing_trick_hash(): 27 | sample_text = 'The cat sat on the mat.' 28 | encoded = text.hashing_trick(sample_text, 5) 29 | assert len(encoded) == 6 30 | assert np.max(encoded) <= 4 31 | assert np.min(encoded) >= 1 32 | 33 | 34 | def test_hashing_trick_md5(): 35 | sample_text = 'The cat sat on the mat.' 
36 |     encoded = text.hashing_trick(sample_text, 5, hash_function='md5')
37 |     assert len(encoded) == 6
38 |     assert np.max(encoded) <= 4
39 |     assert np.min(encoded) >= 1
40 | 
41 | 
42 | def test_tokenizer():
43 |     sample_texts = ['The cat sat on the mat.',
44 |                     'The dog sat on the log.',
45 |                     'Dogs and cats living together.']
46 |     tokenizer = text.Tokenizer(num_words=10)
47 |     tokenizer.fit_on_texts(sample_texts)
48 | 
49 |     sequences = []
50 |     for seq in tokenizer.texts_to_sequences_generator(sample_texts):
51 |         sequences.append(seq)
52 |     assert max(max(seq) for seq in sequences) < 10  # np.max fails on ragged lists in modern NumPy
53 |     assert min(min(seq) for seq in sequences) == 1
54 | 
55 |     tokenizer.fit_on_sequences(sequences)
56 | 
57 |     for mode in ['binary', 'count', 'tfidf', 'freq']:
58 |         tokenizer.texts_to_matrix(sample_texts, mode)
59 | 
60 | 
61 | def test_tokenizer_serde_no_fitting():
62 |     tokenizer = text.Tokenizer(num_words=100)
63 | 
64 |     tokenizer_json = tokenizer.to_json()
65 |     recovered = text.tokenizer_from_json(tokenizer_json)
66 | 
67 |     assert tokenizer.get_config() == recovered.get_config()
68 | 
69 |     assert tokenizer.word_docs == recovered.word_docs
70 |     assert tokenizer.word_counts == recovered.word_counts
71 |     assert tokenizer.word_index == recovered.word_index
72 |     assert tokenizer.index_word == recovered.index_word
73 |     assert tokenizer.index_docs == recovered.index_docs
74 | 
75 | 
76 | def test_tokenizer_serde_fitting():
77 |     sample_texts = [
78 |         'There was a time that the pieces fit, but I watched them fall away',
79 |         'Mildewed and smoldering, strangled by our coveting',
80 |         'I\'ve done the math enough to know the dangers of our second guessing']
81 |     tokenizer = text.Tokenizer(num_words=100)
82 |     tokenizer.fit_on_texts(sample_texts)
83 | 
84 |     seq_generator = tokenizer.texts_to_sequences_generator(sample_texts)
85 |     sequences = [seq for seq in seq_generator]
86 |     tokenizer.fit_on_sequences(sequences)
87 | 
88 |     tokenizer_json = tokenizer.to_json()
89 |     recovered = text.tokenizer_from_json(tokenizer_json)
90 | 
91 |     assert tokenizer.char_level == recovered.char_level
92 |     assert tokenizer.document_count == recovered.document_count
93 |     assert tokenizer.filters == recovered.filters
94 |     assert tokenizer.lower == recovered.lower
95 |     assert tokenizer.num_words == recovered.num_words
96 |     assert tokenizer.oov_token == recovered.oov_token
97 | 
98 |     assert tokenizer.word_docs == recovered.word_docs
99 |     assert tokenizer.word_counts == recovered.word_counts
100 |     assert tokenizer.word_index == recovered.word_index
101 |     assert tokenizer.index_word == recovered.index_word
102 |     assert tokenizer.index_docs == recovered.index_docs
103 | 
104 | 
105 | def test_sequential_fit():
106 |     texts = ['The cat sat on the mat.',
107 |              'The dog sat on the log.',
108 |              'Dogs and cats living together.']
109 |     word_sequences = [
110 |         ['The', 'cat', 'is', 'sitting'],
111 |         ['The', 'dog', 'is', 'standing']
112 |     ]
113 | 
114 |     tokenizer = text.Tokenizer()
115 |     tokenizer.fit_on_texts(texts)
116 |     tokenizer.fit_on_texts(word_sequences)
117 | 
118 |     assert tokenizer.document_count == 5
119 | 
120 |     tokenizer.texts_to_matrix(texts)
121 |     tokenizer.texts_to_matrix(word_sequences)
122 | 
123 | 
124 | def test_text_to_word_sequence():
125 |     sample_text = 'hello! ? world!'
126 |     assert text.text_to_word_sequence(sample_text) == ['hello', 'world']
127 | 
128 | 
129 | def test_text_to_word_sequence_multichar_split():
130 |     sample_text = 'hello!stop?world!'
131 |     assert text.text_to_word_sequence(
132 |         sample_text, split='stop') == ['hello', 'world']
133 | 
134 | 
135 | def test_text_to_word_sequence_unicode():
136 |     sample_text = u'ali! veli? kırk dokuz elli'
137 |     assert text.text_to_word_sequence(
138 |         sample_text) == [u'ali', u'veli', u'kırk', u'dokuz', u'elli']
139 | 
140 | 
141 | def test_text_to_word_sequence_unicode_multichar_split():
142 |     sample_text = u'ali!stopveli?stopkırkstopdokuzstopelli'
143 |     assert text.text_to_word_sequence(
144 |         sample_text, split='stop') == [u'ali', u'veli', u'kırk', u'dokuz', u'elli']
145 | 
146 | 
147 | def test_tokenizer_unicode():
148 |     sample_texts = [u'ali veli kırk dokuz elli',
149 |                     u'ali veli kırk dokuz elli veli kırk dokuz']
150 |     tokenizer = text.Tokenizer(num_words=5)
151 |     tokenizer.fit_on_texts(sample_texts)
152 | 
153 |     assert len(tokenizer.word_counts) == 5
154 | 
155 | 
156 | def test_tokenizer_oov_flag():
157 |     """Test of Out of Vocabulary (OOV) flag in text.Tokenizer
158 |     """
159 |     x_train = ['This text has only known words']
160 |     x_test = ['This text has some unknown words']  # 2 OOVs: some, unknown
161 | 
162 |     # Default, without OOV flag
163 |     tokenizer = text.Tokenizer()
164 |     tokenizer.fit_on_texts(x_train)
165 |     x_test_seq = tokenizer.texts_to_sequences(x_test)
166 |     assert len(x_test_seq[0]) == 4  # discards 2 OOVs
167 | 
168 |     # With OOV feature
169 |     tokenizer = text.Tokenizer(oov_token='<unk>')
170 |     tokenizer.fit_on_texts(x_train)
171 |     x_test_seq = tokenizer.texts_to_sequences(x_test)
172 |     assert len(x_test_seq[0]) == 6  # OOVs marked in place
173 | 
174 | 
175 | def test_tokenizer_oov_flag_and_num_words():
176 |     x_train = ['This text has only known words this text']
177 |     x_test = ['This text has some unknown words']
178 | 
179 |     tokenizer = keras.preprocessing.text.Tokenizer(num_words=3,
180 |                                                    oov_token='<unk>')
181 |     tokenizer.fit_on_texts(x_train)
182 |     x_test_seq = tokenizer.texts_to_sequences(x_test)
183 |     trans_text = ' '.join(tokenizer.index_word[t] for t in x_test_seq[0])
184 |     assert len(x_test_seq[0]) == 6
185 |     assert trans_text == 'this <unk> <unk> <unk> <unk> <unk>'
186 | 
187 | 
188 | def test_sequences_to_texts_with_num_words_and_oov_token():
189 |     x_train = ['This text has only known words this text']
190 |     x_test = ['This text has some unknown words']
191 | 
192 |     tokenizer = keras.preprocessing.text.Tokenizer(num_words=3,
193 |                                                    oov_token='<unk>')
194 | 
195 |     tokenizer.fit_on_texts(x_train)
196 |     x_test_seq = tokenizer.texts_to_sequences(x_test)
197 |     trans_text = tokenizer.sequences_to_texts(x_test_seq)
198 |     assert trans_text == ['this <unk> <unk> <unk> <unk> <unk>']
199 | 
200 | 
201 | def test_sequences_to_texts_no_num_words():
202 |     x_train = ['This text has only known words this text']
203 |     x_test = ['This text has some unknown words']
204 | 
205 |     tokenizer = keras.preprocessing.text.Tokenizer(oov_token='<unk>')
206 | 
207 |     tokenizer.fit_on_texts(x_train)
208 |     x_test_seq = tokenizer.texts_to_sequences(x_test)
209 |     trans_text = tokenizer.sequences_to_texts(x_test_seq)
210 |     assert trans_text == ['this text has <unk> <unk> words']
211 | 
212 | 
213 | def test_sequences_to_texts_no_oov_token():
214 |     x_train = ['This text has only known words this text']
215 |     x_test = ['This text has some unknown words']
216 | 
217 |     tokenizer = keras.preprocessing.text.Tokenizer(num_words=3)
218 | 
219 |     tokenizer.fit_on_texts(x_train)
220 |     x_test_seq = tokenizer.texts_to_sequences(x_test)
221 |     trans_text = tokenizer.sequences_to_texts(x_test_seq)
222 |     assert trans_text == ['this text']
223 | 
224 | 
225 | def test_sequences_to_texts_no_num_words_no_oov_token():
226 |     x_train = ['This text has only known words this text']
227 |     x_test = ['This text has some unknown words']
228 | 
229 |     tokenizer = keras.preprocessing.text.Tokenizer()
230 | 
231 |     tokenizer.fit_on_texts(x_train)
232 |     x_test_seq = tokenizer.texts_to_sequences(x_test)
233 |     trans_text = tokenizer.sequences_to_texts(x_test_seq)
234 |     assert trans_text == ['this text has words']
235 | 
236 | 
237 | def test_sequences_to_texts():
238 |     texts = [
239 |         'The cat sat on the mat.',
240 |         'The dog sat on the log.',
241 |         'Dogs and cats living together.'
242 |     ]
243 |     tokenizer = keras.preprocessing.text.Tokenizer(num_words=10,
244 |                                                    oov_token='<unk>')
245 |     tokenizer.fit_on_texts(texts)
246 |     tokenized_text = tokenizer.texts_to_sequences(texts)
247 |     trans_text = tokenizer.sequences_to_texts(tokenized_text)
248 |     assert trans_text == ['the cat sat on the mat',
249 |                           'the dog sat on the log',
250 |                           'dogs <unk> <unk> <unk> <unk>']
251 | 
252 | 
253 | def test_tokenizer_lower_flag():
254 |     """Tests for `lower` flag in text.Tokenizer
255 |     """
256 |     # word level tokenizer with sentences as texts
257 |     word_tokenizer = text.Tokenizer(lower=True)
258 |     texts = ['The cat sat on the mat.',
259 |              'The dog sat on the log.',
260 |              'Dog and Cat living Together.']
261 |     word_tokenizer.fit_on_texts(texts)
262 |     expected_word_counts = OrderedDict([('the', 4), ('cat', 2), ('sat', 2),
263 |                                         ('on', 2), ('mat', 1), ('dog', 2),
264 |                                         ('log', 1), ('and', 1), ('living', 1),
265 |                                         ('together', 1)])
266 |     assert word_tokenizer.word_counts == expected_word_counts
267 | 
268 |     # word level tokenizer with word_sequences as texts
269 |     word_tokenizer = text.Tokenizer(lower=True)
270 |     word_sequences = [
271 |         ['The', 'cat', 'is', 'sitting'],
272 |         ['The', 'dog', 'is', 'standing']
273 |     ]
274 |     word_tokenizer.fit_on_texts(word_sequences)
275 |     expected_word_counts = OrderedDict([('the', 2), ('cat', 1), ('is', 2),
276 |                                         ('sitting', 1), ('dog', 1),
277 |                                         ('standing', 1)])
278 |     assert word_tokenizer.word_counts == expected_word_counts
279 | 
280 |     # char level tokenizer with sentences as texts
281 |     char_tokenizer = text.Tokenizer(lower=True, char_level=True)
282 |     texts = ['The cat sat on the mat.',
283 |              'The dog sat on the log.',
284 |              'Dog and Cat living Together.']
285 |     char_tokenizer.fit_on_texts(texts)
286 |     expected_word_counts = OrderedDict([('t', 11), ('h', 5), ('e', 6), (' ', 14),
287 |                                         ('c', 2), ('a', 6), ('s', 2), ('o', 6),
288 |                                         ('n', 4), ('m', 1), ('.', 3), ('d', 3),
289 |                                         ('g', 5), ('l', 2), ('i', 2), ('v', 1),
290 |                                         ('r', 1)])
291 |     assert char_tokenizer.word_counts == expected_word_counts
292 | 
293 | 
294 | if __name__ == '__main__':
295 |     pytest.main([__file__])
296 | 
--------------------------------------------------------------------------------
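Taken together, these suites pin down the library's observable contracts. As a quick cross-check of the pad_sequences behavior asserted in tests/sequence_test.py, a minimal sketch (assuming only that keras_preprocessing and NumPy are installed) would be:

    import numpy as np
    from keras_preprocessing import sequence

    a = [[1], [1, 2], [1, 2, 3]]
    # padding and truncating both default to 'pre': short rows gain leading
    # zeros, long rows lose leading elements
    assert np.array_equal(sequence.pad_sequences(a, maxlen=3),
                          [[0, 0, 1], [0, 1, 2], [1, 2, 3]])
    # truncating='post' drops trailing elements instead
    assert np.array_equal(sequence.pad_sequences(a, maxlen=2, truncating='post'),
                          [[0, 1], [1, 2], [1, 2]])

The expected arrays are copied directly from the assertions in test_pad_sequences above.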
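The num_words/oov_token interplay in tests/text_test.py is subtle enough to deserve the same treatment. This sketch mirrors the assertions in test_tokenizer_oov_flag_and_num_words; the '<unk>' token string is the one those tests use:

    from keras_preprocessing import text

    tok = text.Tokenizer(num_words=3, oov_token='<unk>')
    tok.fit_on_texts(['This text has only known words this text'])
    # the OOV token always claims index 1, so num_words=3 keeps only it and
    # the most frequent word 'this' (index 2); every other word maps to '<unk>'
    seq = tok.texts_to_sequences(['This text has some unknown words'])
    assert tok.sequences_to_texts(seq) == ['this <unk> <unk> <unk> <unk> <unk>']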
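Likewise, the batch arithmetic in test_TimeSeriesGenerator_doesnt_miss_any_sample reduces to two ceilings. A worked instance of one of its parameter tuples (again a sketch, not part of the suite itself):

    from math import ceil
    import numpy as np
    from keras_preprocessing import sequence

    x = np.array([[i] for i in range(23)])
    g = sequence.TimeseriesGenerator(x, x, length=3, stride=7, batch_size=5)
    # ceil((23 - 3) / 7) = 3 sequences, grouped into ceil(3 / 5) = 1 batch
    assert sum(len(g[i][1]) for i in range(len(g))) == 3
    assert len(g) == 1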